refactor: Refactor model
This commit is contained in:
@@ -7,7 +7,9 @@ import pymysql
|
||||
from loguru import logger
|
||||
from pydantic import ValidationError
|
||||
|
||||
from .model import LiveOn, LiveOff, LiveReport, DynamicUpdate, PushTarget, Up, Bot
|
||||
from .model import LiveOn, LiveOff, LiveReport, DynamicUpdate, PushTarget
|
||||
from .room import Up
|
||||
from .sender import Bot
|
||||
from ..exception.DataSourceException import DataSourceException
|
||||
from ..utils import config
|
||||
|
||||
|
||||
@@ -1,25 +1,11 @@
|
||||
"""
|
||||
Bot 配置相关类
|
||||
"""
|
||||
import asyncio
|
||||
import time
|
||||
from asyncio import AbstractEventLoop
|
||||
from enum import Enum
|
||||
from typing import List, Optional, Any, Union, Dict
|
||||
from typing import List, Optional, Any
|
||||
|
||||
from graia.ariadne import Ariadne
|
||||
from graia.ariadne.connection.config import config as AriadneConfig, HttpClientConfig, WebsocketClientConfig
|
||||
from graia.ariadne.event.lifecycle import ApplicationLaunched
|
||||
from graia.ariadne.message.chain import MessageChain
|
||||
from graia.ariadne.message.element import Plain, At, AtAll, Image
|
||||
from graia.ariadne.model import LogConfig, MemberPerm
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel, PrivateAttr
|
||||
|
||||
from .live import LiveDanmaku
|
||||
from ..utils import config
|
||||
from ..utils.AsyncEvent import AsyncEvent
|
||||
|
||||
|
||||
class LiveOn(BaseModel):
|
||||
"""
|
||||
@@ -200,51 +186,6 @@ class PushTarget(BaseModel):
|
||||
return hash(self.id)
|
||||
|
||||
|
||||
class Up(BaseModel):
|
||||
"""
|
||||
主播类
|
||||
"""
|
||||
|
||||
uid: int
|
||||
"""主播 UID"""
|
||||
|
||||
targets: Union[List[PushTarget], Dict[int, PushTarget]]
|
||||
"""主播所需推送的所有好友或群"""
|
||||
|
||||
uname: Optional[str] = None
|
||||
"""主播昵称,无需手动传入,会自动获取"""
|
||||
|
||||
room_id: Optional[int] = None
|
||||
"""主播直播间房间号,无需手动传入,会自动获取"""
|
||||
|
||||
__room: Optional[LiveDanmaku] = PrivateAttr()
|
||||
"""直播间连接实例"""
|
||||
|
||||
__is_reconnect: Optional[bool] = PrivateAttr()
|
||||
"""是否为断线重连"""
|
||||
|
||||
__loop: Optional[AbstractEventLoop] = PrivateAttr()
|
||||
"""asyncio 事件循环"""
|
||||
|
||||
def __init__(self, **data: Any):
|
||||
super().__init__(**data)
|
||||
if isinstance(self.targets, list):
|
||||
self.targets = dict(zip(map(lambda t: t.id, self.targets), self.targets))
|
||||
self.__room = None
|
||||
self.__is_reconnect = False
|
||||
self.__loop = asyncio.get_event_loop()
|
||||
|
||||
def __eq__(self, other):
|
||||
if isinstance(other, Up):
|
||||
return self.uid == other.uid
|
||||
elif isinstance(other, int):
|
||||
return self.uid == other
|
||||
return False
|
||||
|
||||
def __hash__(self):
|
||||
return hash(self.uid)
|
||||
|
||||
|
||||
class Message(BaseModel):
|
||||
"""
|
||||
消息封装类
|
||||
@@ -331,119 +272,3 @@ class Message(BaseModel):
|
||||
chains.append(chain)
|
||||
|
||||
return chains
|
||||
|
||||
|
||||
class Bot(BaseModel, AsyncEvent):
|
||||
"""
|
||||
Bot 类,每个实例为一个 QQ 号,可用于配置多 Bot 推送
|
||||
"""
|
||||
|
||||
qq: int
|
||||
"""Bot 的 QQ 号"""
|
||||
|
||||
ups: List[Up]
|
||||
"""Bot 账号下运行的 UP 主列表"""
|
||||
|
||||
__loop: Optional[AbstractEventLoop] = PrivateAttr()
|
||||
"""asyncio 事件循环"""
|
||||
|
||||
__bot: Optional[Ariadne] = PrivateAttr()
|
||||
"""Ariadne 实例"""
|
||||
|
||||
__queue: Optional[List[Message]] = PrivateAttr()
|
||||
"""待发送消息队列"""
|
||||
|
||||
def __init__(self, **data: Any):
|
||||
super().__init__(**data)
|
||||
self.__loop = asyncio.get_event_loop()
|
||||
self.__bot = Ariadne(
|
||||
connection=AriadneConfig(
|
||||
self.qq,
|
||||
"StarBot",
|
||||
HttpClientConfig(host=f"http://localhost:{config.get('MIRAI_PORT')}"),
|
||||
WebsocketClientConfig(host=f"http://localhost:{config.get('MIRAI_PORT')}"),
|
||||
),
|
||||
log_config=LogConfig(log_level="DEBUG")
|
||||
)
|
||||
self.__queue = []
|
||||
|
||||
@self.on("SEND_MESSAGE")
|
||||
async def send_message(msg: Message):
|
||||
self.__queue.append(msg)
|
||||
|
||||
@self.__bot.broadcast.receiver(ApplicationLaunched)
|
||||
async def start_sender():
|
||||
logger.success(f"Bot [{self.qq}] 已启动")
|
||||
self.__loop.create_task(self.__sender())
|
||||
|
||||
async def __sender(self):
|
||||
"""
|
||||
消息发送模块
|
||||
"""
|
||||
while True:
|
||||
if self.__queue:
|
||||
msg = self.__queue[0]
|
||||
if msg.type == PushType.Friend:
|
||||
for message in msg.get_message_chains():
|
||||
logger.info(f"{self.qq} -> 好友[{msg.id}] : {message}")
|
||||
await self.__bot.send_friend_message(msg.id, message)
|
||||
else:
|
||||
for message in await self.group_message_filter(msg):
|
||||
logger.info(f"{self.qq} -> 群[{msg.id}] : {message}")
|
||||
await self.__bot.send_group_message(msg.id, message)
|
||||
self.__queue.pop(0)
|
||||
else:
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
async def group_message_filter(self, message: Message) -> List[MessageChain]:
|
||||
"""
|
||||
过滤群消息中的非法元素
|
||||
|
||||
Args:
|
||||
message: 源消息链
|
||||
|
||||
Returns:
|
||||
处理后的消息链
|
||||
"""
|
||||
if message.type == PushType.Friend:
|
||||
return message.get_message_chains()
|
||||
|
||||
new_chains = []
|
||||
|
||||
# 过滤 Bot 不在群内的消息
|
||||
group = await self.__bot.get_group(message.id)
|
||||
if group is None:
|
||||
return new_chains
|
||||
|
||||
for chain in message.get_message_chains():
|
||||
if AtAll in chain:
|
||||
# 过滤 Bot 不是群管理员时的 @全体成员 消息
|
||||
bot_info = await self.__bot.get_member(self.qq, message.id)
|
||||
if bot_info.permission < MemberPerm.Administrator:
|
||||
chain = chain.exclude(AtAll)
|
||||
|
||||
# 过滤多余的 @全体成员 消息
|
||||
if chain.count(AtAll) > 1:
|
||||
elements = [e for e in chain.exclude(AtAll)]
|
||||
elements.insert(chain.index(AtAll), AtAll())
|
||||
chain = MessageChain(elements)
|
||||
|
||||
if At in message:
|
||||
# 过滤已不在群内的群成员的 @ 消息
|
||||
member_list = [member.id for member in await self.__bot.get_member_list(message.id)]
|
||||
elements = [e for e in chain if (not isinstance(e, At)) or (e.target in member_list)]
|
||||
chain = MessageChain(elements)
|
||||
|
||||
new_chains.append(chain)
|
||||
|
||||
return new_chains
|
||||
|
||||
def __eq__(self, other):
|
||||
if isinstance(other, Bot):
|
||||
return self.qq == other.qq
|
||||
elif isinstance(other, int):
|
||||
return self.qq == other
|
||||
return False
|
||||
|
||||
def __hash__(self):
|
||||
return hash(self.qq)
|
||||
|
||||
53
starbot/core/room.py
Normal file
53
starbot/core/room.py
Normal file
@@ -0,0 +1,53 @@
|
||||
import asyncio
|
||||
from asyncio import AbstractEventLoop
|
||||
from typing import Optional, Union, List, Dict, Any
|
||||
|
||||
from pydantic import BaseModel, PrivateAttr
|
||||
|
||||
from .live import LiveDanmaku
|
||||
from .model import PushTarget
|
||||
|
||||
|
||||
class Up(BaseModel):
|
||||
"""
|
||||
主播类
|
||||
"""
|
||||
|
||||
uid: int
|
||||
"""主播 UID"""
|
||||
|
||||
targets: Union[List[PushTarget], Dict[int, PushTarget]]
|
||||
"""主播所需推送的所有好友或群"""
|
||||
|
||||
uname: Optional[str] = None
|
||||
"""主播昵称,无需手动传入,会自动获取"""
|
||||
|
||||
room_id: Optional[int] = None
|
||||
"""主播直播间房间号,无需手动传入,会自动获取"""
|
||||
|
||||
__room: Optional[LiveDanmaku] = PrivateAttr()
|
||||
"""直播间连接实例"""
|
||||
|
||||
__is_reconnect: Optional[bool] = PrivateAttr()
|
||||
"""是否为断线重连"""
|
||||
|
||||
__loop: Optional[AbstractEventLoop] = PrivateAttr()
|
||||
"""asyncio 事件循环"""
|
||||
|
||||
def __init__(self, **data: Any):
|
||||
super().__init__(**data)
|
||||
if isinstance(self.targets, list):
|
||||
self.targets = dict(zip(map(lambda t: t.id, self.targets), self.targets))
|
||||
self.__room = None
|
||||
self.__is_reconnect = False
|
||||
self.__loop = asyncio.get_event_loop()
|
||||
|
||||
def __eq__(self, other):
|
||||
if isinstance(other, Up):
|
||||
return self.uid == other.uid
|
||||
elif isinstance(other, int):
|
||||
return self.uid == other
|
||||
return False
|
||||
|
||||
def __hash__(self):
|
||||
return hash(self.uid)
|
||||
133
starbot/core/sender.py
Normal file
133
starbot/core/sender.py
Normal file
@@ -0,0 +1,133 @@
|
||||
import asyncio
|
||||
from asyncio import AbstractEventLoop
|
||||
from typing import List, Optional, Any
|
||||
|
||||
from graia.ariadne import Ariadne
|
||||
from graia.ariadne.connection.config import config as AriadneConfig, HttpClientConfig, WebsocketClientConfig
|
||||
from graia.ariadne.event.lifecycle import ApplicationLaunched
|
||||
from graia.ariadne.message.chain import MessageChain
|
||||
from graia.ariadne.message.element import At, AtAll
|
||||
from graia.ariadne.model import LogConfig, MemberPerm
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel, PrivateAttr
|
||||
|
||||
from .model import PushType, Message
|
||||
from .room import Up
|
||||
from ..utils import config
|
||||
from ..utils.AsyncEvent import AsyncEvent
|
||||
|
||||
|
||||
class Bot(BaseModel, AsyncEvent):
|
||||
"""
|
||||
Bot 类,每个实例为一个 QQ 号,可用于配置多 Bot 推送
|
||||
"""
|
||||
|
||||
qq: int
|
||||
"""Bot 的 QQ 号"""
|
||||
|
||||
ups: List[Up]
|
||||
"""Bot 账号下运行的 UP 主列表"""
|
||||
|
||||
__loop: Optional[AbstractEventLoop] = PrivateAttr()
|
||||
"""asyncio 事件循环"""
|
||||
|
||||
__bot: Optional[Ariadne] = PrivateAttr()
|
||||
"""Ariadne 实例"""
|
||||
|
||||
__queue: Optional[List[Message]] = PrivateAttr()
|
||||
"""待发送消息队列"""
|
||||
|
||||
def __init__(self, **data: Any):
|
||||
super().__init__(**data)
|
||||
self.__loop = asyncio.get_event_loop()
|
||||
self.__bot = Ariadne(
|
||||
connection=AriadneConfig(
|
||||
self.qq,
|
||||
"StarBot",
|
||||
HttpClientConfig(host=f"http://localhost:{config.get('MIRAI_PORT')}"),
|
||||
WebsocketClientConfig(host=f"http://localhost:{config.get('MIRAI_PORT')}"),
|
||||
),
|
||||
log_config=LogConfig(log_level="DEBUG")
|
||||
)
|
||||
self.__queue = []
|
||||
|
||||
@self.on("SEND_MESSAGE")
|
||||
async def send_message(msg: Message):
|
||||
self.__queue.append(msg)
|
||||
|
||||
@self.__bot.broadcast.receiver(ApplicationLaunched)
|
||||
async def start_sender():
|
||||
logger.success(f"Bot [{self.qq}] 已启动")
|
||||
self.__loop.create_task(self.__sender())
|
||||
|
||||
async def __sender(self):
|
||||
"""
|
||||
消息发送模块
|
||||
"""
|
||||
while True:
|
||||
if self.__queue:
|
||||
msg = self.__queue[0]
|
||||
if msg.type == PushType.Friend:
|
||||
for message in msg.get_message_chains():
|
||||
logger.info(f"{self.qq} -> 好友[{msg.id}] : {message}")
|
||||
await self.__bot.send_friend_message(msg.id, message)
|
||||
else:
|
||||
for message in await self.group_message_filter(msg):
|
||||
logger.info(f"{self.qq} -> 群[{msg.id}] : {message}")
|
||||
await self.__bot.send_group_message(msg.id, message)
|
||||
self.__queue.pop(0)
|
||||
else:
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
async def group_message_filter(self, message: Message) -> List[MessageChain]:
|
||||
"""
|
||||
过滤群消息中的非法元素
|
||||
|
||||
Args:
|
||||
message: 源消息链
|
||||
|
||||
Returns:
|
||||
处理后的消息链
|
||||
"""
|
||||
if message.type == PushType.Friend:
|
||||
return message.get_message_chains()
|
||||
|
||||
new_chains = []
|
||||
|
||||
# 过滤 Bot 不在群内的消息
|
||||
group = await self.__bot.get_group(message.id)
|
||||
if group is None:
|
||||
return new_chains
|
||||
|
||||
for chain in message.get_message_chains():
|
||||
if AtAll in chain:
|
||||
# 过滤 Bot 不是群管理员时的 @全体成员 消息
|
||||
bot_info = await self.__bot.get_member(self.qq, message.id)
|
||||
if bot_info.permission < MemberPerm.Administrator:
|
||||
chain = chain.exclude(AtAll)
|
||||
|
||||
# 过滤多余的 @全体成员 消息
|
||||
if chain.count(AtAll) > 1:
|
||||
elements = [e for e in chain.exclude(AtAll)]
|
||||
elements.insert(chain.index(AtAll), AtAll())
|
||||
chain = MessageChain(elements)
|
||||
|
||||
if At in message:
|
||||
# 过滤已不在群内的群成员的 @ 消息
|
||||
member_list = [member.id for member in await self.__bot.get_member_list(message.id)]
|
||||
elements = [e for e in chain if (not isinstance(e, At)) or (e.target in member_list)]
|
||||
chain = MessageChain(elements)
|
||||
|
||||
new_chains.append(chain)
|
||||
|
||||
return new_chains
|
||||
|
||||
def __eq__(self, other):
|
||||
if isinstance(other, Bot):
|
||||
return self.qq == other.qq
|
||||
elif isinstance(other, int):
|
||||
return self.qq == other
|
||||
return False
|
||||
|
||||
def __hash__(self):
|
||||
return hash(self.qq)
|
||||
Reference in New Issue
Block a user