Python과 FastAPI로 지속 성장 가능한 웹서비스 개발하기 - 2. MySQL & SQLAlchemy

이전 글에 이어서 작성된 글입니다.

MySQL

프로덕션으로는 Amazon Aurora MySQL을 사용하기로 했다. 순수하게 python으로 구현된 MySQL 드라이버 PyMySQL의 대부분을 차용하여 만들어진 aiomysql 비동기 드라이버로 데이터베이스를 사용한다. 순수 파이썬으로 구현된 드라이버는 aws lambda에도 쉽게 올릴 수 있는 장점을 갖고 있다. (나중에 모노레포 형식으로 aws lambda를 사용할 수도 있지 않을까?)

ORM 라이브러리로는 SQLAlchemy를 사용했다. 최근에는 asyncio 지원도 꽤 성숙하게 올라와 있어 아주 만족스럽게 사용하고 있다.

With asyncio scoped session (async_scoped_session)

기본적으로 SQLAlchemy가 제공해주는 api를 충실하게 사용하여 설정했다.

from asyncio import current_task

from sqlalchemy.ext.asyncio import (
    async_scoped_session,
    async_sessionmaker,
    AsyncEngine,
    AsyncSession,
    create_async_engine,
)


class AsyncDatabase:
    engines: dict[str, AsyncEngine] = {}

    def __init__(self, writer_url: str, reader_url: str, echo: bool = False) -> None:
        self.engines = {
            "writer": create_async_engine(writer_url, pool_recycle=3600, echo=echo),
            "reader": create_async_engine(reader_url, pool_recycle=3600, echo=echo),
        }
        self._session = async_scoped_session(
            scopefunc=current_task,
            session_factory=async_sessionmaker(
                class_=AsyncSession,
                sync_session_class=self.get_routing_session(),
            ),  # Here!
        )

sqlalchemy는 async_scoped_session 클래스를 통해 asyncio Task별로 격리된 session을 사용할 수 있게 한다. scopefuncasyncio.current_task()를 사용하였다. 즉 asyncio.Task별로 1개의 격리된 session을 이용하겠다는 설정이다.

FastAPI의 Request 관점에서는 요청마다 격리된 세션을 사용하게 된다. 요청한 코드 안에서 새로운 태스크를 추가한다면, 그 새로운 태스크는 격리된 세션을 사용할 수 있다.

트랜잭션 처리

스프링의 @Transactional, django의 @transaction.atomic과 같은 기능

서비스를 구현하면서 데이터베이스와의 통신을 생각할 때, 꼭 필요한 개념 중 하나인 트랜잭션 처리를 아래와 같이 구현할 수 있다. 이미 바로 밑의 데코레이터 코드는 여기저기서 쉽게 볼 수 있다. 이미 많은 분들이 synchronous 엔진 기반의 SQLAlchemy를 사용하는 동안 이렇게 쓰고 있었다.

class Transactional:
    commit: bool

    def __init__(self, commit=False):
        self.commit = commit

    @inject
    def _get_session(self, session=Provide["infra.rdb.provided.session"]):
        return session

    def __call__(self, func):
        @wraps(func)
        async def _transactional(*args, **kwargs):
            session = self._get_session()
            try:
                result = await func(*args, **kwargs)
            except Exception as e:
                await session.rollback()
                raise e
            else:
                if self.commit:
                    await session.commit()

            return result

        return _transactional

물론 Task 1개를 핸들링할 때만 사용되어야 한다. 한 코루틴 안에 여러 database 테스크가 물려있는 경우 롤백 처리 등이 지원되지 않는다. (어쩌면 당연한 이야기.)

마침 python 3.11에는 nursery 개념이 포함된 TaskGroup 이 포함되었다. 여러 개의 Task들이 실행되고 모두 종료되었을 때 실패한 Task가 1개라도 있으면 모든 Task의 모든 세션은 커밋되지 않은 모든 동작을 롤백하도록 하였다. 간단한 상속과 래핑으로 구현할 수 있었다.

from asyncio import TaskGroup

from dependency_injector.wiring import Provide, inject
from sqlalchemy.ext.asyncio import AsyncSession, async_scoped_session


class TransactionalGroupTask(TaskGroup):
    commit: bool
    sessions = set()

    def __init__(self, commit=False, **kwargs):
        self.commit = commit
        super().__init__(**kwargs)

    @inject
    def _get_session(self, session: async_scoped_session = Provide["infra.rdb.provided.session"]):
        return session

    def _get_session_by_task(self) -> AsyncSession | None:
        proxy = self._get_session()
        scopefunc = proxy.registry.scopefunc
        return proxy.registry.registry.get(scopefunc())

    async def __aenter__(self):
        return await super().__aenter__()

    async def __aexit__(self, exc_type, exc, tb):
        try:
            ret = await super().__aexit__(exc_type, exc, tb)
        except BaseExceptionGroup as e:
            await self.rollback_sessions()
            raise e
        else:
            if self.commit:
                await self.commit_sessions()
            return ret
        finally:
            await self.remove_sessions()

    async def rollback_sessions(self):
        for session in self.sessions:
            await session.rollback()

    async def commit_sessions(self):
        for session in self.sessions:
            await session.commit()

    async def remove_sessions(self):
        for session in self.sessions:
            await session.remove()

    def create_task(self, coro, *, name=None, context=None):
        async def wrap():
            try:
                ret = await coro
            except Exception as e:
                raise e
            else:
                return ret
            finally:
                if _task_local_session := self._get_session_by_task():
                    self.sessions.add(_task_local_session)

        return super().create_task(wrap(), name=name, context=context)

예외 처리를 무시하는 로직이던, 필요에 맞게 변형해 사용하시면 된다. 아래와 같이 이렇게 사용할 수 있다.

async def batch():
    try:
        async with TransactionalGroupTask(commit=True) as ttg:
            ttg.create_task(container.some_service().do_some_write_process(1))
            ttg.create_task(container.some_service().do_some_write_process(2))
            ttg.create_task(container.some_service().do_some_update_process(3))
        pass
    except* BadRequestException as e:
        print(e)

How we write query

SQLAlchemy에서 공식적으로 권장하는 “2.0 style”로 엔티티를 구성하고 쿼리를 작성하고 있다.

stmt = (
    select(ReviewEntity)
    .where(
        ReviewEntity.user_external_id == user_external_id,
        ReviewEntity.is_deleted.is_(False),
    )
    .order_by(ReviewEntity.created_at.desc())
    .limit(limit)
    .offset(offset)
    .options(selectinload(ReviewEntity.images))
)

발생하는 쿼리는 아래와 같다. 개발자가 의도하는 쿼리가 거의 일치하는 모습이다.

--- query 1
SELECT review.id, ...(생략)
FROM review
WHERE review.user_external_id = %s AND review.is_deleted IS false
ORDER BY review.created_at DESC
LIMIT %s, %s
-- (208232, 0, 10)

--- query 2
SELECT review_image.review_id AS review_image_review_id, ...(생략)
FROM review_image
WHERE review_image.review_id IN (%s, %s, %s, %s, %s, %s)
--  (2127, 2126, 1775, 1214, 1051, 827)

Repository Pattern

지난 글에서 아래의 원칙을 소개해드렸다.

  1. 레이어는 위에서 아래로 순방향으로만 참조되어야 한다.
  2. 역류 참조를 절대 금지한다.
  3. 건너뛰는 참조도 금지한다.

Repository Layer도 마찬가지로 위의 원칙을 지켜야 한다. 프로젝트에 Repository Layer를 도입하여…

이를 통해 결합도가 낮은 유연한 구조를 만들 수 있었다. 아래는 간단하게 Repository Class를 위한 abstract Class를 구현한 것이다.

from typing import Generic, Type, TypeVar

from sqlalchemy import func, Label, Result, select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import async_scoped_session

from infra.rdb.base.base_entity import BaseEntity
from infra.rdb.base.mysql_error_exception import MySQLErrorException
from infra.rdb.base.results import ResultWithTotalCount

T = TypeVar("T", bound=BaseEntity)


class BaseRepository(Generic[T]):
    entity: Type[T]
    session: async_scoped_session

    def __init__(self, entity: Type[T], session: async_scoped_session):
        self.entity = entity
        self.session = session

    @property
    def table(self):
        return BaseEntity.metadata.tables[self.entity.__tablename__]

    async def create(self, **kwargs) -> T:
        self.session.add(entity := self.entity(**kwargs))
        try:
            await self.session.flush()
        except IntegrityError as e:
            await self.session.rollback()
            raise MySQLErrorException(e)
        return entity

    async def find_by_id(self, id: int) -> T | None:
        stmt = select(self.entity).where(
            self.entity.id == id,
            self.entity.is_deleted.is_(False),
        )
        result = await self.session.execute(stmt)
        return result.scalars().first()

그렇다면 API endpoints는 이렇게 구현될 것이다. 어떤 상품을 product_id로 조회하는 API를 예시로 들겠다.

@router.get("/{product_id:int}")
@inject
async def get_product_by_id(
    product_id: int,
    product_service: ProductService = Depends(Provide['product_service']),
):
    product = await product_service.get_by_id(product_id)
    return ProductResponse.from_model(product)

엔드포인트 구현은 product_service를 주입받아 사용만 했을 뿐이다. 상위 레이어인 표현 계층에서 ProductService 클래스를 생성하지 않는다.

class ProductService:
    def __init__(product_repository: ProductRepository):
        self.product_repository = product_repository

    async def get_by_id(product_id: int):
        entity = await self.product_repository.find_by_id(product_id)
        if entity is None:
            raise NotFoundException("상품이 존재하지 않습니다")
        return Product.from_entity(entity)

서비스 구현도 쿼리에 대한 이야기는 전혀 없다.단지 Repository 인스턴스에서 인터페이스에 의존해 product를 가져오는 코드일 뿐이다.

메모리 릭

Using current_task() for the “key” in the scope requires that the async_scoped_session.remove() method is called from within the outermost awaitable, to ensure the key is removed from the registry when the task is complete, otherwise the task handle as well as the AsyncSession will remain in memory, essentially creating a memory leak. See the following example which illustrates the correct use of async_scoped_session.remove().

문서 발췌 : Asynchronous I/O (asyncio)

내부적으로 ScopedRegistry가 미리 선언한 scopefunc의 결과를 키로 AsyncSession을 Dictionary 형태로 저장해두고 있다. 많은 Task를 실행할 수록 레지스트리 딕셔너리가 커질 수 밖에 없다. scopefunc가 어떤 형태로 지정되었는지 프레임워크는 알 수 있는 형태가 아니므로, scope의 생명주기에 따라 AsyncSession을 정리해야 한다.

아까 위에서 언급한 TransactionalGroupTask의 구현에 remove_sessions()메소드가 포함된 것도 이와 관련된 일이다. 서브태스크들은 이미 종료되었기에 session을 삭제해주는 로직이 포함되어 있다.

이외 요청에 의해 생성된 Task들은 Middleware를 통해 정리할 수 있다. 간단한 미들웨어 구현은 아래와 같다.

class AsyncSessionMiddleware:
    def __init__(self, app: ASGIApp):
        self.app = app

    @inject
    def session(self, session=Provide["infra.rdb.provided.session"]):
        return session

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        session = self.session()
        try:
            await self.app(scope, receive, send)
        except Exception as e:
            raise e
        else:
            await session.commit()
        finally:
            await session.remove()

다음 글(작성 중)로 이어집니다.