Idempotent Endpoints 구현

Idempotent?

멱등성이란, 연산을 여러 번 적용하더라도 결과가 달라지지 않는 성질을 의미한다.

멱등성을 보장하는 API Endpoints는 같은 요청이 여러 번 일어나더라도 항상 첫 번째 요청과 같은 결과를 받게 된다. 단지 응답값만 같은 것이 아니라 side-effect도 한 번만 일어나야 한다. 이런 성질이 보장되는 건 특정 상황에서 매우 중요할 수 있다.

HTTP 통신은 다양한 이유로 재시도될 수 있고, 이런 재시도가 참사를 만들 수 있다. 결제 API라면 중복 결제가 일어날 수도 있기 때문이다. 이런 상황을 방지하기 위해 멱등성을 보장하는 구현을 통해 안전한 어플리케이션을 구현할 수 있다.

멱등성을 위한 키 (Idempotency Key)

HTTP 요청에서 다양한 이유로 재시도되더라도 서버 어플리케이션에서 같은 요청인지 다른 요청인지 판단할 수 있도록 요청에 키값을 포함시킨다. 바로 개발자들이 곧 잘 사용하는 aws에도 멱등성 확보를 위해 이와 같은 구현이 되어있다.

Retries-safe API Implementation

많은 구현체들이 멱등성 확보를 위해 이와 같은 흐름을 가지고 있다. 본인은 이를 위해 Redis를 데이터베이스로 잘 사용하는데, 그 이유는 속도가 빠르고 락을 아주 간단하게 사용할 수 있기 때문이다.

Redis: SET with NX

SET if not exists : https://redis.io/docs/latest/commands/setnx/

redis의 setnx 를 이용한 락 구현으로 멱등성을 보장할 수 있다. 최신 버전의 레디스는 setnx 명령어는 Deprecated되었고, set 명령의 nx 키워드를 사용한다.

두 가지 과정을 atomic하게 시간복잡도 O(1)으로 수행한다.

  1. 요청에 포함되어 있는 Client Request Id를 키로하여 락이 존재하는지 확인하고, 존재하지 않으면 해당 요청이 처음으로 들어온 것으로 간주하여 실행한다.
  2. 만약 응답이 나가기 전에 같은 값의 Client Request Id를 가진 요청이 또 들어온다면, 락의 존재를 확인 후 같은 요청으로 간주하고 409 Conflict 예외를 발생시킨다.
  3. 1의 요청이 완료되면 응답을 Client Request Id를 키로 캐싱해둔다.
  4. 1의 응답이 처리된 이후에 같은 값의 Client Request Id를 가진 요청이 또 들어온다면, 캐싱된 Equivalent한 응답을 내보내며 실제 구현이 동작하지 않는다.

Implementation

스프링 프레임워크에서는 고수준 API인 RedisTemplate가 setIfAbsent()메서드를 제공한다. 특정 Controller의 메서드에 적용시키기 위해서 아래와 같은 어노테이션 클래스를 구현한다.

import io.swagger.v3.oas.annotations.media.Content
import io.swagger.v3.oas.annotations.media.Schema
import io.swagger.v3.oas.annotations.responses.ApiResponse
import io.binaryflavor.example.base.schemas.ErrorResponse
import kotlin.reflect.KClass

@Target(AnnotationTarget.FUNCTION)
@Retention(AnnotationRetention.RUNTIME)
annotation class Idempotent(
    val lifetime: Long = 3600,
    val responseClass: KClass<*>,
    val value: Array<ApiResponse> = [
        ApiResponse(
            responseCode = "409",
            description = "중복된 요청이 처리 중일 때 (`IDEMPOTENT_REQUEST_PROCESSING`)",
            content = [Content(schema = Schema(implementation = ErrorResponse::class))],
        ),
        ApiResponse(
            responseCode = "422",
            description = """
            | error code | description |
            | --- | --- |
            | `CLIENT_REQUEST_ID_REQUIRED` | 클라이언트 요청 ID가 필요합니다. |
            | `CLIENT_REQUEST_ID_MAX_LENGTH_LIMIT_EXCEEDED` | 클라이언트 요청 ID가 최대 길이를 초과했습니다. (129자 이상인 경우~) |
            | `INVALID_CLIENT_REQUEST_ID` | 부적절한 키 값입니다. (Alphanumeric & Dash 외의 char 포함 경우) |""",
            content = [Content(schema = Schema(implementation = ErrorResponse::class))],
        ),
    ],
)

어노테이션을 처리할 Aspect 구현은 아래와 같다.

import com.google.gson.Gson
import jakarta.servlet.http.HttpServletRequest
import io.binaryflavor.example.core.exception.idempotent.IdempotentRequestProcessing
import io.binaryflavor.example.core.exception.idempotent.InvalidClientRequestId
import io.binaryflavor.example.storages.rediscore.repositories.IdempotentInterceptorStorageRepository
import org.aspectj.lang.ProceedingJoinPoint
import org.aspectj.lang.annotation.Around
import org.aspectj.lang.annotation.Aspect
import org.springframework.stereotype.Component
import org.springframework.web.servlet.HandlerInterceptor
import kotlin.reflect.KClass

@Aspect
@Component
class IdempotentAspect(
    private val idempotentInterceptorStorageRepository: IdempotentInterceptorStorageRepository,
) : HandlerInterceptor {
    private fun getRequest(proceedingJoinPoint: ProceedingJoinPoint): HttpServletRequest =
        proceedingJoinPoint.args.firstOrNull {
            it is HttpServletRequest
        } as HttpServletRequest

    private fun getRequestId(request: HttpServletRequest) = request.getHeader("X-Client-Request-Id").also {
        validateClientRequestId(it)
    }

    private fun buildCacheKey(signature: String, requestId: String, postfix: String) =
        "idempotent:$signature:$requestId:$postfix"

    private fun validateClientRequestId(requestId: String) {
        if (requestId.length > 128) {
            throw InvalidClientRequestId(
                code = "CLIENT_REQUEST_ID_MAX_LENGTH_LIMIT_EXCEEDED",
                message = "Client request id max length limit exceeded",
            )
        }
        if (!requestId.matches(Regex("^[a-zA-Z0-9-]*$"))) {
            throw InvalidClientRequestId(
                code = "INVALID_CLIENT_REQUEST_ID",
                message = "Invalid client request id",
            )
        }
    }

    private fun lockIfAbsent(signature: String, requestId: String, duration: Long): Boolean {
        return idempotentInterceptorStorageRepository.setNx(
            key = buildCacheKey(signature, requestId, "lock"),
            value = "1",
            expireSeconds = duration,
        )
    }

    private fun getCachedResponse(signature: String, requestId: String, responseClass: KClass<*>) =
        idempotentInterceptorStorageRepository.get(buildCacheKey(signature, requestId, "response"))?.let {
            Gson().fromJson(it, responseClass.java)
        }

    private fun cacheResponse(
        signature: String,
        requestId: String,
        response: Any,
        responseClass: KClass<*>,
        duration: Long,
    ) = idempotentInterceptorStorageRepository.set(
        key = buildCacheKey(signature, requestId, "response"),
        value = Gson().toJson(response, responseClass.java),
        expireSeconds = duration,
    )

    private fun unlock(signature: String, requestId: String) {
        idempotentInterceptorStorageRepository.delete(buildCacheKey(signature, requestId, "lock"))
    }

    private fun getSignature(proceedingJoinPoint: ProceedingJoinPoint): String {
        val methodName = proceedingJoinPoint.signature.name
        val args = proceedingJoinPoint.args.joinToString(",") { it.toString() }
        return "$methodName:$args"
    }

    @Around("@annotation(idempotent)")
    @Throws(Throwable::class)
    fun around(proceedingJoinPoint: ProceedingJoinPoint, idempotent: Idempotent): Any {
        val request = getRequest(proceedingJoinPoint)
        val requestId = getRequestId(request) ?: return proceedingJoinPoint.proceed()
        val signature = getSignature(proceedingJoinPoint)
        if (!lockIfAbsent(signature, requestId, idempotent.lifetime)) {
            return getCachedResponse(signature, requestId, idempotent.responseClass)
                ?: throw IdempotentRequestProcessing(message = "Request is being processed")
        } else {
            val result = try {
                proceedingJoinPoint.proceed()
            } catch (e: Throwable) {
                unlock(signature, requestId)
                throw e
            }
            cacheResponse(signature, requestId, result, idempotent.responseClass, idempotent.lifetime)
            return result
        }
    }
    
}