refactor: Refactor model
This commit is contained in:
@@ -7,7 +7,9 @@ import pymysql
|
|||||||
from loguru import logger
|
from loguru import logger
|
||||||
from pydantic import ValidationError
|
from pydantic import ValidationError
|
||||||
|
|
||||||
from .model import LiveOn, LiveOff, LiveReport, DynamicUpdate, PushTarget, Up, Bot
|
from .model import LiveOn, LiveOff, LiveReport, DynamicUpdate, PushTarget
|
||||||
|
from .room import Up
|
||||||
|
from .sender import Bot
|
||||||
from ..exception.DataSourceException import DataSourceException
|
from ..exception.DataSourceException import DataSourceException
|
||||||
from ..utils import config
|
from ..utils import config
|
||||||
|
|
||||||
|
|||||||
+1
-176
@@ -1,25 +1,11 @@
|
|||||||
"""
|
|
||||||
Bot 配置相关类
|
|
||||||
"""
|
|
||||||
import asyncio
|
|
||||||
import time
|
import time
|
||||||
from asyncio import AbstractEventLoop
|
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import List, Optional, Any, Union, Dict
|
from typing import List, Optional, Any
|
||||||
|
|
||||||
from graia.ariadne import Ariadne
|
|
||||||
from graia.ariadne.connection.config import config as AriadneConfig, HttpClientConfig, WebsocketClientConfig
|
|
||||||
from graia.ariadne.event.lifecycle import ApplicationLaunched
|
|
||||||
from graia.ariadne.message.chain import MessageChain
|
from graia.ariadne.message.chain import MessageChain
|
||||||
from graia.ariadne.message.element import Plain, At, AtAll, Image
|
from graia.ariadne.message.element import Plain, At, AtAll, Image
|
||||||
from graia.ariadne.model import LogConfig, MemberPerm
|
|
||||||
from loguru import logger
|
|
||||||
from pydantic import BaseModel, PrivateAttr
|
from pydantic import BaseModel, PrivateAttr
|
||||||
|
|
||||||
from .live import LiveDanmaku
|
|
||||||
from ..utils import config
|
|
||||||
from ..utils.AsyncEvent import AsyncEvent
|
|
||||||
|
|
||||||
|
|
||||||
class LiveOn(BaseModel):
|
class LiveOn(BaseModel):
|
||||||
"""
|
"""
|
||||||
@@ -200,51 +186,6 @@ class PushTarget(BaseModel):
|
|||||||
return hash(self.id)
|
return hash(self.id)
|
||||||
|
|
||||||
|
|
||||||
class Up(BaseModel):
|
|
||||||
"""
|
|
||||||
主播类
|
|
||||||
"""
|
|
||||||
|
|
||||||
uid: int
|
|
||||||
"""主播 UID"""
|
|
||||||
|
|
||||||
targets: Union[List[PushTarget], Dict[int, PushTarget]]
|
|
||||||
"""主播所需推送的所有好友或群"""
|
|
||||||
|
|
||||||
uname: Optional[str] = None
|
|
||||||
"""主播昵称,无需手动传入,会自动获取"""
|
|
||||||
|
|
||||||
room_id: Optional[int] = None
|
|
||||||
"""主播直播间房间号,无需手动传入,会自动获取"""
|
|
||||||
|
|
||||||
__room: Optional[LiveDanmaku] = PrivateAttr()
|
|
||||||
"""直播间连接实例"""
|
|
||||||
|
|
||||||
__is_reconnect: Optional[bool] = PrivateAttr()
|
|
||||||
"""是否为断线重连"""
|
|
||||||
|
|
||||||
__loop: Optional[AbstractEventLoop] = PrivateAttr()
|
|
||||||
"""asyncio 事件循环"""
|
|
||||||
|
|
||||||
def __init__(self, **data: Any):
|
|
||||||
super().__init__(**data)
|
|
||||||
if isinstance(self.targets, list):
|
|
||||||
self.targets = dict(zip(map(lambda t: t.id, self.targets), self.targets))
|
|
||||||
self.__room = None
|
|
||||||
self.__is_reconnect = False
|
|
||||||
self.__loop = asyncio.get_event_loop()
|
|
||||||
|
|
||||||
def __eq__(self, other):
|
|
||||||
if isinstance(other, Up):
|
|
||||||
return self.uid == other.uid
|
|
||||||
elif isinstance(other, int):
|
|
||||||
return self.uid == other
|
|
||||||
return False
|
|
||||||
|
|
||||||
def __hash__(self):
|
|
||||||
return hash(self.uid)
|
|
||||||
|
|
||||||
|
|
||||||
class Message(BaseModel):
|
class Message(BaseModel):
|
||||||
"""
|
"""
|
||||||
消息封装类
|
消息封装类
|
||||||
@@ -331,119 +272,3 @@ class Message(BaseModel):
|
|||||||
chains.append(chain)
|
chains.append(chain)
|
||||||
|
|
||||||
return chains
|
return chains
|
||||||
|
|
||||||
|
|
||||||
class Bot(BaseModel, AsyncEvent):
|
|
||||||
"""
|
|
||||||
Bot 类,每个实例为一个 QQ 号,可用于配置多 Bot 推送
|
|
||||||
"""
|
|
||||||
|
|
||||||
qq: int
|
|
||||||
"""Bot 的 QQ 号"""
|
|
||||||
|
|
||||||
ups: List[Up]
|
|
||||||
"""Bot 账号下运行的 UP 主列表"""
|
|
||||||
|
|
||||||
__loop: Optional[AbstractEventLoop] = PrivateAttr()
|
|
||||||
"""asyncio 事件循环"""
|
|
||||||
|
|
||||||
__bot: Optional[Ariadne] = PrivateAttr()
|
|
||||||
"""Ariadne 实例"""
|
|
||||||
|
|
||||||
__queue: Optional[List[Message]] = PrivateAttr()
|
|
||||||
"""待发送消息队列"""
|
|
||||||
|
|
||||||
def __init__(self, **data: Any):
|
|
||||||
super().__init__(**data)
|
|
||||||
self.__loop = asyncio.get_event_loop()
|
|
||||||
self.__bot = Ariadne(
|
|
||||||
connection=AriadneConfig(
|
|
||||||
self.qq,
|
|
||||||
"StarBot",
|
|
||||||
HttpClientConfig(host=f"http://localhost:{config.get('MIRAI_PORT')}"),
|
|
||||||
WebsocketClientConfig(host=f"http://localhost:{config.get('MIRAI_PORT')}"),
|
|
||||||
),
|
|
||||||
log_config=LogConfig(log_level="DEBUG")
|
|
||||||
)
|
|
||||||
self.__queue = []
|
|
||||||
|
|
||||||
@self.on("SEND_MESSAGE")
|
|
||||||
async def send_message(msg: Message):
|
|
||||||
self.__queue.append(msg)
|
|
||||||
|
|
||||||
@self.__bot.broadcast.receiver(ApplicationLaunched)
|
|
||||||
async def start_sender():
|
|
||||||
logger.success(f"Bot [{self.qq}] 已启动")
|
|
||||||
self.__loop.create_task(self.__sender())
|
|
||||||
|
|
||||||
async def __sender(self):
|
|
||||||
"""
|
|
||||||
消息发送模块
|
|
||||||
"""
|
|
||||||
while True:
|
|
||||||
if self.__queue:
|
|
||||||
msg = self.__queue[0]
|
|
||||||
if msg.type == PushType.Friend:
|
|
||||||
for message in msg.get_message_chains():
|
|
||||||
logger.info(f"{self.qq} -> 好友[{msg.id}] : {message}")
|
|
||||||
await self.__bot.send_friend_message(msg.id, message)
|
|
||||||
else:
|
|
||||||
for message in await self.group_message_filter(msg):
|
|
||||||
logger.info(f"{self.qq} -> 群[{msg.id}] : {message}")
|
|
||||||
await self.__bot.send_group_message(msg.id, message)
|
|
||||||
self.__queue.pop(0)
|
|
||||||
else:
|
|
||||||
await asyncio.sleep(0.1)
|
|
||||||
|
|
||||||
async def group_message_filter(self, message: Message) -> List[MessageChain]:
|
|
||||||
"""
|
|
||||||
过滤群消息中的非法元素
|
|
||||||
|
|
||||||
Args:
|
|
||||||
message: 源消息链
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
处理后的消息链
|
|
||||||
"""
|
|
||||||
if message.type == PushType.Friend:
|
|
||||||
return message.get_message_chains()
|
|
||||||
|
|
||||||
new_chains = []
|
|
||||||
|
|
||||||
# 过滤 Bot 不在群内的消息
|
|
||||||
group = await self.__bot.get_group(message.id)
|
|
||||||
if group is None:
|
|
||||||
return new_chains
|
|
||||||
|
|
||||||
for chain in message.get_message_chains():
|
|
||||||
if AtAll in chain:
|
|
||||||
# 过滤 Bot 不是群管理员时的 @全体成员 消息
|
|
||||||
bot_info = await self.__bot.get_member(self.qq, message.id)
|
|
||||||
if bot_info.permission < MemberPerm.Administrator:
|
|
||||||
chain = chain.exclude(AtAll)
|
|
||||||
|
|
||||||
# 过滤多余的 @全体成员 消息
|
|
||||||
if chain.count(AtAll) > 1:
|
|
||||||
elements = [e for e in chain.exclude(AtAll)]
|
|
||||||
elements.insert(chain.index(AtAll), AtAll())
|
|
||||||
chain = MessageChain(elements)
|
|
||||||
|
|
||||||
if At in message:
|
|
||||||
# 过滤已不在群内的群成员的 @ 消息
|
|
||||||
member_list = [member.id for member in await self.__bot.get_member_list(message.id)]
|
|
||||||
elements = [e for e in chain if (not isinstance(e, At)) or (e.target in member_list)]
|
|
||||||
chain = MessageChain(elements)
|
|
||||||
|
|
||||||
new_chains.append(chain)
|
|
||||||
|
|
||||||
return new_chains
|
|
||||||
|
|
||||||
def __eq__(self, other):
|
|
||||||
if isinstance(other, Bot):
|
|
||||||
return self.qq == other.qq
|
|
||||||
elif isinstance(other, int):
|
|
||||||
return self.qq == other
|
|
||||||
return False
|
|
||||||
|
|
||||||
def __hash__(self):
|
|
||||||
return hash(self.qq)
|
|
||||||
|
|||||||
@@ -0,0 +1,53 @@
|
|||||||
|
import asyncio
|
||||||
|
from asyncio import AbstractEventLoop
|
||||||
|
from typing import Optional, Union, List, Dict, Any
|
||||||
|
|
||||||
|
from pydantic import BaseModel, PrivateAttr
|
||||||
|
|
||||||
|
from .live import LiveDanmaku
|
||||||
|
from .model import PushTarget
|
||||||
|
|
||||||
|
|
||||||
|
class Up(BaseModel):
|
||||||
|
"""
|
||||||
|
主播类
|
||||||
|
"""
|
||||||
|
|
||||||
|
uid: int
|
||||||
|
"""主播 UID"""
|
||||||
|
|
||||||
|
targets: Union[List[PushTarget], Dict[int, PushTarget]]
|
||||||
|
"""主播所需推送的所有好友或群"""
|
||||||
|
|
||||||
|
uname: Optional[str] = None
|
||||||
|
"""主播昵称,无需手动传入,会自动获取"""
|
||||||
|
|
||||||
|
room_id: Optional[int] = None
|
||||||
|
"""主播直播间房间号,无需手动传入,会自动获取"""
|
||||||
|
|
||||||
|
__room: Optional[LiveDanmaku] = PrivateAttr()
|
||||||
|
"""直播间连接实例"""
|
||||||
|
|
||||||
|
__is_reconnect: Optional[bool] = PrivateAttr()
|
||||||
|
"""是否为断线重连"""
|
||||||
|
|
||||||
|
__loop: Optional[AbstractEventLoop] = PrivateAttr()
|
||||||
|
"""asyncio 事件循环"""
|
||||||
|
|
||||||
|
def __init__(self, **data: Any):
|
||||||
|
super().__init__(**data)
|
||||||
|
if isinstance(self.targets, list):
|
||||||
|
self.targets = dict(zip(map(lambda t: t.id, self.targets), self.targets))
|
||||||
|
self.__room = None
|
||||||
|
self.__is_reconnect = False
|
||||||
|
self.__loop = asyncio.get_event_loop()
|
||||||
|
|
||||||
|
def __eq__(self, other):
|
||||||
|
if isinstance(other, Up):
|
||||||
|
return self.uid == other.uid
|
||||||
|
elif isinstance(other, int):
|
||||||
|
return self.uid == other
|
||||||
|
return False
|
||||||
|
|
||||||
|
def __hash__(self):
|
||||||
|
return hash(self.uid)
|
||||||
@@ -0,0 +1,133 @@
|
|||||||
|
import asyncio
|
||||||
|
from asyncio import AbstractEventLoop
|
||||||
|
from typing import List, Optional, Any
|
||||||
|
|
||||||
|
from graia.ariadne import Ariadne
|
||||||
|
from graia.ariadne.connection.config import config as AriadneConfig, HttpClientConfig, WebsocketClientConfig
|
||||||
|
from graia.ariadne.event.lifecycle import ApplicationLaunched
|
||||||
|
from graia.ariadne.message.chain import MessageChain
|
||||||
|
from graia.ariadne.message.element import At, AtAll
|
||||||
|
from graia.ariadne.model import LogConfig, MemberPerm
|
||||||
|
from loguru import logger
|
||||||
|
from pydantic import BaseModel, PrivateAttr
|
||||||
|
|
||||||
|
from .model import PushType, Message
|
||||||
|
from .room import Up
|
||||||
|
from ..utils import config
|
||||||
|
from ..utils.AsyncEvent import AsyncEvent
|
||||||
|
|
||||||
|
|
||||||
|
class Bot(BaseModel, AsyncEvent):
|
||||||
|
"""
|
||||||
|
Bot 类,每个实例为一个 QQ 号,可用于配置多 Bot 推送
|
||||||
|
"""
|
||||||
|
|
||||||
|
qq: int
|
||||||
|
"""Bot 的 QQ 号"""
|
||||||
|
|
||||||
|
ups: List[Up]
|
||||||
|
"""Bot 账号下运行的 UP 主列表"""
|
||||||
|
|
||||||
|
__loop: Optional[AbstractEventLoop] = PrivateAttr()
|
||||||
|
"""asyncio 事件循环"""
|
||||||
|
|
||||||
|
__bot: Optional[Ariadne] = PrivateAttr()
|
||||||
|
"""Ariadne 实例"""
|
||||||
|
|
||||||
|
__queue: Optional[List[Message]] = PrivateAttr()
|
||||||
|
"""待发送消息队列"""
|
||||||
|
|
||||||
|
def __init__(self, **data: Any):
|
||||||
|
super().__init__(**data)
|
||||||
|
self.__loop = asyncio.get_event_loop()
|
||||||
|
self.__bot = Ariadne(
|
||||||
|
connection=AriadneConfig(
|
||||||
|
self.qq,
|
||||||
|
"StarBot",
|
||||||
|
HttpClientConfig(host=f"http://localhost:{config.get('MIRAI_PORT')}"),
|
||||||
|
WebsocketClientConfig(host=f"http://localhost:{config.get('MIRAI_PORT')}"),
|
||||||
|
),
|
||||||
|
log_config=LogConfig(log_level="DEBUG")
|
||||||
|
)
|
||||||
|
self.__queue = []
|
||||||
|
|
||||||
|
@self.on("SEND_MESSAGE")
|
||||||
|
async def send_message(msg: Message):
|
||||||
|
self.__queue.append(msg)
|
||||||
|
|
||||||
|
@self.__bot.broadcast.receiver(ApplicationLaunched)
|
||||||
|
async def start_sender():
|
||||||
|
logger.success(f"Bot [{self.qq}] 已启动")
|
||||||
|
self.__loop.create_task(self.__sender())
|
||||||
|
|
||||||
|
async def __sender(self):
|
||||||
|
"""
|
||||||
|
消息发送模块
|
||||||
|
"""
|
||||||
|
while True:
|
||||||
|
if self.__queue:
|
||||||
|
msg = self.__queue[0]
|
||||||
|
if msg.type == PushType.Friend:
|
||||||
|
for message in msg.get_message_chains():
|
||||||
|
logger.info(f"{self.qq} -> 好友[{msg.id}] : {message}")
|
||||||
|
await self.__bot.send_friend_message(msg.id, message)
|
||||||
|
else:
|
||||||
|
for message in await self.group_message_filter(msg):
|
||||||
|
logger.info(f"{self.qq} -> 群[{msg.id}] : {message}")
|
||||||
|
await self.__bot.send_group_message(msg.id, message)
|
||||||
|
self.__queue.pop(0)
|
||||||
|
else:
|
||||||
|
await asyncio.sleep(0.1)
|
||||||
|
|
||||||
|
async def group_message_filter(self, message: Message) -> List[MessageChain]:
|
||||||
|
"""
|
||||||
|
过滤群消息中的非法元素
|
||||||
|
|
||||||
|
Args:
|
||||||
|
message: 源消息链
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
处理后的消息链
|
||||||
|
"""
|
||||||
|
if message.type == PushType.Friend:
|
||||||
|
return message.get_message_chains()
|
||||||
|
|
||||||
|
new_chains = []
|
||||||
|
|
||||||
|
# 过滤 Bot 不在群内的消息
|
||||||
|
group = await self.__bot.get_group(message.id)
|
||||||
|
if group is None:
|
||||||
|
return new_chains
|
||||||
|
|
||||||
|
for chain in message.get_message_chains():
|
||||||
|
if AtAll in chain:
|
||||||
|
# 过滤 Bot 不是群管理员时的 @全体成员 消息
|
||||||
|
bot_info = await self.__bot.get_member(self.qq, message.id)
|
||||||
|
if bot_info.permission < MemberPerm.Administrator:
|
||||||
|
chain = chain.exclude(AtAll)
|
||||||
|
|
||||||
|
# 过滤多余的 @全体成员 消息
|
||||||
|
if chain.count(AtAll) > 1:
|
||||||
|
elements = [e for e in chain.exclude(AtAll)]
|
||||||
|
elements.insert(chain.index(AtAll), AtAll())
|
||||||
|
chain = MessageChain(elements)
|
||||||
|
|
||||||
|
if At in message:
|
||||||
|
# 过滤已不在群内的群成员的 @ 消息
|
||||||
|
member_list = [member.id for member in await self.__bot.get_member_list(message.id)]
|
||||||
|
elements = [e for e in chain if (not isinstance(e, At)) or (e.target in member_list)]
|
||||||
|
chain = MessageChain(elements)
|
||||||
|
|
||||||
|
new_chains.append(chain)
|
||||||
|
|
||||||
|
return new_chains
|
||||||
|
|
||||||
|
def __eq__(self, other):
|
||||||
|
if isinstance(other, Bot):
|
||||||
|
return self.qq == other.qq
|
||||||
|
elif isinstance(other, int):
|
||||||
|
return self.qq == other
|
||||||
|
return False
|
||||||
|
|
||||||
|
def __hash__(self):
|
||||||
|
return hash(self.qq)
|
||||||
Reference in New Issue
Block a user