完善 retry 装饰器

This commit is contained in:
2024-06-08 11:57:45 +08:00
parent fa05b80e61
commit ce95d8f977

View File

@@ -1,9 +1,9 @@
from asyncio import sleep
from collections.abc import Awaitable, Callable
from collections.abc import Callable, Coroutine
from contextlib import suppress
from datetime import timedelta
from functools import wraps
from typing import ParamSpec, TypeVar, cast
from typing import Any, ParamSpec, TypeVar
from nonebot.log import logger
from nonebot_plugin_alconna.uniseg import SerializeFailed, UniMessage
@@ -17,24 +17,28 @@ def retry(
exception_type: type[BaseException] | tuple[type[BaseException], ...] = Exception,
delay: timedelta | None = None,
reply: str | UniMessage | None = None,
) -> Callable[[Callable[P, Awaitable[T]]], Callable[P, Awaitable[T]]]:
def decorator(func: Callable[P, Awaitable[T]]) -> Callable[P, Awaitable[T]]:
) -> Callable[[Callable[P, Coroutine[Any, Any, T]]], Callable[P, Coroutine[Any, Any, T]]]:
def decorator(func: Callable[P, Coroutine[Any, Any, T]]) -> Callable[P, Coroutine[Any, Any, T]]:
@wraps(func)
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
for i in range(1, max_attempts + 1):
try:
return await func(*args, **kwargs)
except exception_type as e: # noqa: PERF203
logger.exception(e)
if delay is not None:
await sleep(delay.total_seconds())
for i in range(max_attempts + 1):
if i > 0:
message = f'Retrying: {func.__name__} ({i}/{max_attempts})'
logger.debug(message)
with suppress(SerializeFailed):
await UniMessage(reply or message).send()
msg = 'Unexpectedly reached the end of the retry loop'
raise RuntimeError(msg)
if i == max_attempts:
break
try:
return await func(*args, **kwargs)
except exception_type as e:
if i == max_attempts:
raise
logger.exception(e)
if delay is not None:
await sleep(delay.total_seconds())
return await func(*args, **kwargs)
return cast(Callable[P, Awaitable[T]], wrapper)
return wrapper
return decorator