完善 retry 装饰器的类型,并添加消息提示功能

This commit is contained in:
2024-06-04 20:14:19 +08:00
parent b6f6eb1170
commit b7b152d84d

View File

@@ -1,38 +1,40 @@
from asyncio import sleep from asyncio import sleep
from collections.abc import Awaitable, Callable from collections.abc import Awaitable, Callable
from contextlib import suppress
from datetime import timedelta from datetime import timedelta
from functools import wraps from functools import wraps
from typing import TypeVar, cast from typing import ParamSpec, TypeVar, cast
from nonebot.log import logger from nonebot.log import logger
from nonebot_plugin_alconna.uniseg import SerializeFailed, UniMessage
T = TypeVar('T') T = TypeVar('T')
P = ParamSpec('P')
def retry( def retry(
max_attempts: int = 3, max_attempts: int = 3,
exception_type: type[BaseException] | tuple[type[BaseException], ...] = Exception, exception_type: type[BaseException] | tuple[type[BaseException], ...] = Exception,
delay: timedelta | None = None, delay: timedelta | None = None,
) -> Callable[[Callable[..., Awaitable[T]]], Callable[..., Awaitable[T]]]: reply: str | UniMessage | None = None,
def decorator(func: Callable[..., Awaitable[T]]) -> Callable[..., Awaitable[T]]: ) -> Callable[[Callable[P, Awaitable[T]]], Callable[P, Awaitable[T]]]:
def decorator(func: Callable[P, Awaitable[T]]) -> Callable[P, Awaitable[T]]:
@wraps(func) @wraps(func)
async def wrapper(*args, **kwargs) -> T: # noqa: ANN002, ANN003 async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
attempts = 0 for i in range(1, max_attempts + 1):
while attempts < max_attempts + 1:
try: try:
return await func(*args, **kwargs) return await func(*args, **kwargs)
except exception_type as e: # noqa: PERF203 except exception_type as e: # noqa: PERF203
logger.exception(e) logger.exception(e)
attempts += 1
if attempts <= max_attempts:
if delay is not None: if delay is not None:
await sleep(delay.total_seconds()) await sleep(delay.total_seconds())
logger.debug(f'Retrying: {func.__name__} ({attempts}/{max_attempts})') message = f'Retrying: {func.__name__} ({i}/{max_attempts})'
continue logger.debug(message)
raise with suppress(SerializeFailed):
await UniMessage(reply or message).send()
msg = 'Unexpectedly reached the end of the retry loop' msg = 'Unexpectedly reached the end of the retry loop'
raise RuntimeError(msg) raise RuntimeError(msg)
return cast(Callable[..., Awaitable[T]], wrapper) return cast(Callable[P, Awaitable[T]], wrapper)
return decorator return decorator