🐛 修复 pydantic model 不能被正确反序列化的bug

This commit is contained in:
2023-11-29 11:43:00 +08:00
parent 546369241a
commit 88c2915251
2 changed files with 19 additions and 8 deletions

View File

@@ -45,8 +45,8 @@ def upgrade(name: str = '') -> None:
sa.Column('game_platform', sa.String(length=32), nullable=False), sa.Column('game_platform', sa.String(length=32), nullable=False),
sa.Column('command_type', sa.String(length=16), nullable=False), sa.Column('command_type', sa.String(length=16), nullable=False),
sa.Column('command_args', sa.JSON(), nullable=False), sa.Column('command_args', sa.JSON(), nullable=False),
sa.Column('game_user', PydanticType(), nullable=False), sa.Column('game_user', PydanticType(list), nullable=False),
sa.Column('processed_data', PydanticType(), nullable=False), sa.Column('processed_data', PydanticType(list), nullable=False),
sa.Column('finish_time', sa.DateTime(), nullable=False), sa.Column('finish_time', sa.DateTime(), nullable=False),
sa.PrimaryKeyConstraint('id', name=op.f('pk_nonebot_plugin_tetris_stats_historicaldata')), sa.PrimaryKeyConstraint('id', name=op.f('pk_nonebot_plugin_tetris_stats_historicaldata')),
) )

View File

@@ -1,9 +1,10 @@
from collections.abc import Callable, Sequence
from datetime import datetime from datetime import datetime
from typing import Any from typing import Any
from nonebot.adapters import Message from nonebot.adapters import Message
from nonebot_plugin_orm import Model from nonebot_plugin_orm import Model
from pydantic import BaseModel from pydantic import BaseModel, ValidationError
from sqlalchemy import JSON, DateTime, Dialect, PickleType, String, TypeDecorator from sqlalchemy import JSON, DateTime, Dialect, PickleType, String, TypeDecorator
from sqlalchemy.orm import Mapped, MappedAsDataclass, mapped_column from sqlalchemy.orm import Mapped, MappedAsDataclass, mapped_column
@@ -14,16 +15,24 @@ from ..utils.typing import CommandType, GameType
class PydanticType(TypeDecorator): class PydanticType(TypeDecorator):
impl = JSON impl = JSON
def __init__(self, get_model: Callable[[], Sequence[type[BaseModel]]], *args: Any, **kwargs: Any): # noqa: ANN401
self.get_model = get_model
super().__init__(*args, **kwargs)
def process_bind_param(self, value: Any | None, dialect: Dialect) -> str: # noqa: ANN401 def process_bind_param(self, value: Any | None, dialect: Dialect) -> str: # noqa: ANN401
# 将 Pydantic 模型实例转换为 JSON # 将 Pydantic 模型实例转换为 JSON
if isinstance(value, BaseModel): if isinstance(value, tuple(self.get_model())):
return value.json() return value.json() # type: ignore[union-attr]
raise TypeError raise TypeError
def process_result_value(self, value: Any | None, dialect: Dialect) -> BaseModel: # noqa: ANN401 def process_result_value(self, value: Any | None, dialect: Dialect) -> BaseModel: # noqa: ANN401
# 将 JSON 转换回 Pydantic 模型实例 # 将 JSON 转换回 Pydantic 模型实例
if isinstance(value, str | bytes): if isinstance(value, str | bytes):
return BaseModel.parse_raw(value) for i in self.get_model():
try:
return i.parse_raw(value)
except ValidationError: # noqa: PERF203
...
raise TypeError raise TypeError
@@ -46,6 +55,8 @@ class HistoricalData(MappedAsDataclass, Model):
game_platform: Mapped[GameType] = mapped_column(String(32), index=True, init=False) game_platform: Mapped[GameType] = mapped_column(String(32), index=True, init=False)
command_type: Mapped[CommandType] = mapped_column(String(16), index=True, init=False) command_type: Mapped[CommandType] = mapped_column(String(16), index=True, init=False)
command_args: Mapped[list[str]] = mapped_column(JSON, init=False) command_args: Mapped[list[str]] = mapped_column(JSON, init=False)
game_user: Mapped[BaseUser] = mapped_column(PydanticType, init=False) game_user: Mapped[BaseUser] = mapped_column(PydanticType(get_model=BaseUser.__subclasses__), init=False)
processed_data: Mapped[BaseProcessedData] = mapped_column(PydanticType, init=False) processed_data: Mapped[BaseProcessedData] = mapped_column(
PydanticType(get_model=BaseProcessedData.__subclasses__), init=False
)
finish_time: Mapped[datetime] = mapped_column(DateTime, init=False) finish_time: Mapped[datetime] = mapped_column(DateTime, init=False)