refactor: Refactor model

This commit is contained in:
LWR
2022-10-31 19:06:51 +08:00
parent f92f02f8b2
commit 5e32504b94
4 changed files with 190 additions and 177 deletions

View File

@@ -7,7 +7,9 @@ import pymysql
from loguru import logger
from pydantic import ValidationError
from .model import LiveOn, LiveOff, LiveReport, DynamicUpdate, PushTarget, Up, Bot
from .model import LiveOn, LiveOff, LiveReport, DynamicUpdate, PushTarget
from .room import Up
from .sender import Bot
from ..exception.DataSourceException import DataSourceException
from ..utils import config

View File

@@ -1,25 +1,11 @@
"""
Bot 配置相关类
"""
import asyncio
import time
from asyncio import AbstractEventLoop
from enum import Enum
from typing import List, Optional, Any, Union, Dict
from typing import List, Optional, Any
from graia.ariadne import Ariadne
from graia.ariadne.connection.config import config as AriadneConfig, HttpClientConfig, WebsocketClientConfig
from graia.ariadne.event.lifecycle import ApplicationLaunched
from graia.ariadne.message.chain import MessageChain
from graia.ariadne.message.element import Plain, At, AtAll, Image
from graia.ariadne.model import LogConfig, MemberPerm
from loguru import logger
from pydantic import BaseModel, PrivateAttr
from .live import LiveDanmaku
from ..utils import config
from ..utils.AsyncEvent import AsyncEvent
class LiveOn(BaseModel):
"""
@@ -200,51 +186,6 @@ class PushTarget(BaseModel):
return hash(self.id)
class Up(BaseModel):
"""
主播类
"""
uid: int
"""主播 UID"""
targets: Union[List[PushTarget], Dict[int, PushTarget]]
"""主播所需推送的所有好友或群"""
uname: Optional[str] = None
"""主播昵称,无需手动传入,会自动获取"""
room_id: Optional[int] = None
"""主播直播间房间号,无需手动传入,会自动获取"""
__room: Optional[LiveDanmaku] = PrivateAttr()
"""直播间连接实例"""
__is_reconnect: Optional[bool] = PrivateAttr()
"""是否为断线重连"""
__loop: Optional[AbstractEventLoop] = PrivateAttr()
"""asyncio 事件循环"""
def __init__(self, **data: Any):
super().__init__(**data)
if isinstance(self.targets, list):
self.targets = dict(zip(map(lambda t: t.id, self.targets), self.targets))
self.__room = None
self.__is_reconnect = False
self.__loop = asyncio.get_event_loop()
def __eq__(self, other):
if isinstance(other, Up):
return self.uid == other.uid
elif isinstance(other, int):
return self.uid == other
return False
def __hash__(self):
return hash(self.uid)
class Message(BaseModel):
"""
消息封装类
@@ -331,119 +272,3 @@ class Message(BaseModel):
chains.append(chain)
return chains
class Bot(BaseModel, AsyncEvent):
"""
Bot 类,每个实例为一个 QQ 号,可用于配置多 Bot 推送
"""
qq: int
"""Bot 的 QQ 号"""
ups: List[Up]
"""Bot 账号下运行的 UP 主列表"""
__loop: Optional[AbstractEventLoop] = PrivateAttr()
"""asyncio 事件循环"""
__bot: Optional[Ariadne] = PrivateAttr()
"""Ariadne 实例"""
__queue: Optional[List[Message]] = PrivateAttr()
"""待发送消息队列"""
def __init__(self, **data: Any):
super().__init__(**data)
self.__loop = asyncio.get_event_loop()
self.__bot = Ariadne(
connection=AriadneConfig(
self.qq,
"StarBot",
HttpClientConfig(host=f"http://localhost:{config.get('MIRAI_PORT')}"),
WebsocketClientConfig(host=f"http://localhost:{config.get('MIRAI_PORT')}"),
),
log_config=LogConfig(log_level="DEBUG")
)
self.__queue = []
@self.on("SEND_MESSAGE")
async def send_message(msg: Message):
self.__queue.append(msg)
@self.__bot.broadcast.receiver(ApplicationLaunched)
async def start_sender():
logger.success(f"Bot [{self.qq}] 已启动")
self.__loop.create_task(self.__sender())
async def __sender(self):
"""
消息发送模块
"""
while True:
if self.__queue:
msg = self.__queue[0]
if msg.type == PushType.Friend:
for message in msg.get_message_chains():
logger.info(f"{self.qq} -> 好友[{msg.id}] : {message}")
await self.__bot.send_friend_message(msg.id, message)
else:
for message in await self.group_message_filter(msg):
logger.info(f"{self.qq} -> 群[{msg.id}] : {message}")
await self.__bot.send_group_message(msg.id, message)
self.__queue.pop(0)
else:
await asyncio.sleep(0.1)
async def group_message_filter(self, message: Message) -> List[MessageChain]:
"""
过滤群消息中的非法元素
Args:
message: 源消息链
Returns:
处理后的消息链
"""
if message.type == PushType.Friend:
return message.get_message_chains()
new_chains = []
# 过滤 Bot 不在群内的消息
group = await self.__bot.get_group(message.id)
if group is None:
return new_chains
for chain in message.get_message_chains():
if AtAll in chain:
# 过滤 Bot 不是群管理员时的 @全体成员 消息
bot_info = await self.__bot.get_member(self.qq, message.id)
if bot_info.permission < MemberPerm.Administrator:
chain = chain.exclude(AtAll)
# 过滤多余的 @全体成员 消息
if chain.count(AtAll) > 1:
elements = [e for e in chain.exclude(AtAll)]
elements.insert(chain.index(AtAll), AtAll())
chain = MessageChain(elements)
if At in message:
# 过滤已不在群内的群成员的 @ 消息
member_list = [member.id for member in await self.__bot.get_member_list(message.id)]
elements = [e for e in chain if (not isinstance(e, At)) or (e.target in member_list)]
chain = MessageChain(elements)
new_chains.append(chain)
return new_chains
def __eq__(self, other):
if isinstance(other, Bot):
return self.qq == other.qq
elif isinstance(other, int):
return self.qq == other
return False
def __hash__(self):
return hash(self.qq)

53
starbot/core/room.py Normal file
View File

@@ -0,0 +1,53 @@
import asyncio
from asyncio import AbstractEventLoop
from typing import Optional, Union, List, Dict, Any
from pydantic import BaseModel, PrivateAttr
from .live import LiveDanmaku
from .model import PushTarget
class Up(BaseModel):
"""
主播类
"""
uid: int
"""主播 UID"""
targets: Union[List[PushTarget], Dict[int, PushTarget]]
"""主播所需推送的所有好友或群"""
uname: Optional[str] = None
"""主播昵称,无需手动传入,会自动获取"""
room_id: Optional[int] = None
"""主播直播间房间号,无需手动传入,会自动获取"""
__room: Optional[LiveDanmaku] = PrivateAttr()
"""直播间连接实例"""
__is_reconnect: Optional[bool] = PrivateAttr()
"""是否为断线重连"""
__loop: Optional[AbstractEventLoop] = PrivateAttr()
"""asyncio 事件循环"""
def __init__(self, **data: Any):
super().__init__(**data)
if isinstance(self.targets, list):
self.targets = dict(zip(map(lambda t: t.id, self.targets), self.targets))
self.__room = None
self.__is_reconnect = False
self.__loop = asyncio.get_event_loop()
def __eq__(self, other):
if isinstance(other, Up):
return self.uid == other.uid
elif isinstance(other, int):
return self.uid == other
return False
def __hash__(self):
return hash(self.uid)

133
starbot/core/sender.py Normal file
View File

@@ -0,0 +1,133 @@
import asyncio
from asyncio import AbstractEventLoop
from typing import List, Optional, Any
from graia.ariadne import Ariadne
from graia.ariadne.connection.config import config as AriadneConfig, HttpClientConfig, WebsocketClientConfig
from graia.ariadne.event.lifecycle import ApplicationLaunched
from graia.ariadne.message.chain import MessageChain
from graia.ariadne.message.element import At, AtAll
from graia.ariadne.model import LogConfig, MemberPerm
from loguru import logger
from pydantic import BaseModel, PrivateAttr
from .model import PushType, Message
from .room import Up
from ..utils import config
from ..utils.AsyncEvent import AsyncEvent
class Bot(BaseModel, AsyncEvent):
"""
Bot 类,每个实例为一个 QQ 号,可用于配置多 Bot 推送
"""
qq: int
"""Bot 的 QQ 号"""
ups: List[Up]
"""Bot 账号下运行的 UP 主列表"""
__loop: Optional[AbstractEventLoop] = PrivateAttr()
"""asyncio 事件循环"""
__bot: Optional[Ariadne] = PrivateAttr()
"""Ariadne 实例"""
__queue: Optional[List[Message]] = PrivateAttr()
"""待发送消息队列"""
def __init__(self, **data: Any):
super().__init__(**data)
self.__loop = asyncio.get_event_loop()
self.__bot = Ariadne(
connection=AriadneConfig(
self.qq,
"StarBot",
HttpClientConfig(host=f"http://localhost:{config.get('MIRAI_PORT')}"),
WebsocketClientConfig(host=f"http://localhost:{config.get('MIRAI_PORT')}"),
),
log_config=LogConfig(log_level="DEBUG")
)
self.__queue = []
@self.on("SEND_MESSAGE")
async def send_message(msg: Message):
self.__queue.append(msg)
@self.__bot.broadcast.receiver(ApplicationLaunched)
async def start_sender():
logger.success(f"Bot [{self.qq}] 已启动")
self.__loop.create_task(self.__sender())
async def __sender(self):
"""
消息发送模块
"""
while True:
if self.__queue:
msg = self.__queue[0]
if msg.type == PushType.Friend:
for message in msg.get_message_chains():
logger.info(f"{self.qq} -> 好友[{msg.id}] : {message}")
await self.__bot.send_friend_message(msg.id, message)
else:
for message in await self.group_message_filter(msg):
logger.info(f"{self.qq} -> 群[{msg.id}] : {message}")
await self.__bot.send_group_message(msg.id, message)
self.__queue.pop(0)
else:
await asyncio.sleep(0.1)
async def group_message_filter(self, message: Message) -> List[MessageChain]:
"""
过滤群消息中的非法元素
Args:
message: 源消息链
Returns:
处理后的消息链
"""
if message.type == PushType.Friend:
return message.get_message_chains()
new_chains = []
# 过滤 Bot 不在群内的消息
group = await self.__bot.get_group(message.id)
if group is None:
return new_chains
for chain in message.get_message_chains():
if AtAll in chain:
# 过滤 Bot 不是群管理员时的 @全体成员 消息
bot_info = await self.__bot.get_member(self.qq, message.id)
if bot_info.permission < MemberPerm.Administrator:
chain = chain.exclude(AtAll)
# 过滤多余的 @全体成员 消息
if chain.count(AtAll) > 1:
elements = [e for e in chain.exclude(AtAll)]
elements.insert(chain.index(AtAll), AtAll())
chain = MessageChain(elements)
if At in message:
# 过滤已不在群内的群成员的 @ 消息
member_list = [member.id for member in await self.__bot.get_member_list(message.id)]
elements = [e for e in chain if (not isinstance(e, At)) or (e.target in member_list)]
chain = MessageChain(elements)
new_chains.append(chain)
return new_chains
def __eq__(self, other):
if isinstance(other, Bot):
return self.qq == other.qq
elif isinstance(other, int):
return self.qq == other
return False
def __hash__(self):
return hash(self.qq)