feat: Ariadne message sender support

This commit is contained in:
LWR
2022-10-31 01:40:43 +08:00
parent 7090ba5686
commit f92f02f8b2
4 changed files with 242 additions and 13 deletions

View File

@@ -1,5 +1,8 @@
import sys
from creart import create
from graia.ariadne import Ariadne
from graia.broadcast import Broadcast
from loguru import logger
from .datasource import DataSource
@@ -34,21 +37,11 @@ class StarBot:
"""
self.__datasource = datasource
async def run(self):
async def __main(self):
"""
启动 StarBot
StarBot 入口
"""
# 设置日志格式
logger_format = (
"<green>{time:YYYY-MM-DD HH:mm:ss.SSS}</green> | "
"<level>{level: <8}</level> | "
"<cyan>{name}</cyan>:<cyan>{line}</cyan> | "
"<level>{message}</level>"
)
logger.remove()
logger.add(sys.stderr, format=logger_format, level="INFO")
logger.opt(colors=True, raw=True).info(f"<yellow>{self.STARBOT_ASCII_LOGO}</>")
logger.info("开始启动 StarBot")
@@ -65,3 +58,34 @@ class StarBot:
except RedisException as ex:
logger.error(ex.msg)
return
# 启动 Bot
logger.info("开始启动 Ariadne 消息推送模块")
Ariadne.options["default_account"] = 1499887988
try:
Ariadne.launch_blocking()
except RuntimeError as ex:
if "This event loop is already running" in str(ex):
pass
else:
logger.error(ex)
return
def run(self):
"""
启动 StarBot
"""
logger_format = (
"<green>{time:YYYY-MM-DD HH:mm:ss.SSS}</green> | "
"<level>{level: <8}</level> | "
"<cyan>{name}</cyan>:<cyan>{line}</cyan> | "
"<level>{message}</level>"
)
logger.remove()
logger.add(sys.stderr, format=logger_format, level="INFO")
bcc = create(Broadcast)
loop = bcc.loop
loop.create_task(self.__main())
loop.run_forever()

View File

@@ -2,13 +2,22 @@
Bot 配置相关类
"""
import asyncio
import time
from asyncio import AbstractEventLoop
from enum import Enum
from typing import List, Optional, Any, Union, Dict
from graia.ariadne import Ariadne
from graia.ariadne.connection.config import config as AriadneConfig, HttpClientConfig, WebsocketClientConfig
from graia.ariadne.event.lifecycle import ApplicationLaunched
from graia.ariadne.message.chain import MessageChain
from graia.ariadne.message.element import Plain, At, AtAll, Image
from graia.ariadne.model import LogConfig, MemberPerm
from loguru import logger
from pydantic import BaseModel, PrivateAttr
from .live import LiveDanmaku
from ..utils import config
from ..utils.AsyncEvent import AsyncEvent
@@ -236,6 +245,94 @@ class Up(BaseModel):
return hash(self.uid)
class Message(BaseModel):
"""
消息封装类
"""
id: int
"""目标 QQ 号或目标群号"""
content: str
"""原始消息内容,可包含 {next}{atall} 等占位符"""
type: Optional[PushType] = PushType.Group
"""发送目标类型PushType.Friend 为私聊消息PushType.Group 为群聊消息。默认PushType.Group"""
__time: Optional[int] = PrivateAttr()
"""消息创建时间戳"""
__chains: Optional[List[MessageChain]] = PrivateAttr()
"""原始消息内容自动处理后的消息链列表"""
def __init__(self, **data: Any):
super().__init__(**data)
self.__time = int(time.time())
self.__chains = Message.gen_message_chains(self.content)
def get_time(self) -> int:
return self.__time
def get_message_chains(self) -> List[MessageChain]:
return self.__chains
@classmethod
def gen_message_chains(cls, raw_msg: str) -> List[MessageChain]:
"""
转换 {next}{atall}{at}{pic_url}{pic_path} 元素,将原始消息内容转换为可发送的消息链
Args:
raw_msg: 原始消息文本
Returns:
可直接发送的消息链列表
"""
chains = []
raw_msgs = raw_msg.split("{next}")
for msg in raw_msgs:
chain = MessageChain([])
next_code = msg.find("{")
while msg != "":
if next_code == -1:
chain.append(Plain(msg))
msg = ""
elif next_code != 0:
chain.append(Plain(msg[:next_code]))
msg = msg[next_code:]
next_code = msg.find("{")
else:
code_end = msg.find("}")
if code_end == -1:
chain.append(Plain(msg))
msg = ""
else:
if msg[1:3] == "at":
at_target = msg[3:code_end]
if at_target == "all":
chain.append(AtAll())
elif at_target.isdigit():
chain.append((At(int(at_target))))
chain.append(Plain(" "))
else:
chain.append(Plain("[无效的@参数]"))
elif msg[1:7] == "urlpic":
pic_url = msg[8:code_end]
if pic_url != "":
chain.append(Image(url=pic_url))
elif msg[1:8] == "pathpic":
pic_path = msg[9:code_end]
if pic_path != "":
chain.append(Image(path=pic_path))
else:
chain.append(Plain(msg[:code_end + 1]))
msg = msg[code_end + 1:]
next_code = msg.find("{")
chains.append(chain)
return chains
class Bot(BaseModel, AsyncEvent):
"""
Bot 类,每个实例为一个 QQ 号,可用于配置多 Bot 推送
@@ -247,6 +344,100 @@ class Bot(BaseModel, AsyncEvent):
ups: List[Up]
"""Bot 账号下运行的 UP 主列表"""
__loop: Optional[AbstractEventLoop] = PrivateAttr()
"""asyncio 事件循环"""
__bot: Optional[Ariadne] = PrivateAttr()
"""Ariadne 实例"""
__queue: Optional[List[Message]] = PrivateAttr()
"""待发送消息队列"""
def __init__(self, **data: Any):
super().__init__(**data)
self.__loop = asyncio.get_event_loop()
self.__bot = Ariadne(
connection=AriadneConfig(
self.qq,
"StarBot",
HttpClientConfig(host=f"http://localhost:{config.get('MIRAI_PORT')}"),
WebsocketClientConfig(host=f"http://localhost:{config.get('MIRAI_PORT')}"),
),
log_config=LogConfig(log_level="DEBUG")
)
self.__queue = []
@self.on("SEND_MESSAGE")
async def send_message(msg: Message):
self.__queue.append(msg)
@self.__bot.broadcast.receiver(ApplicationLaunched)
async def start_sender():
logger.success(f"Bot [{self.qq}] 已启动")
self.__loop.create_task(self.__sender())
async def __sender(self):
"""
消息发送模块
"""
while True:
if self.__queue:
msg = self.__queue[0]
if msg.type == PushType.Friend:
for message in msg.get_message_chains():
logger.info(f"{self.qq} -> 好友[{msg.id}] : {message}")
await self.__bot.send_friend_message(msg.id, message)
else:
for message in await self.group_message_filter(msg):
logger.info(f"{self.qq} -> 群[{msg.id}] : {message}")
await self.__bot.send_group_message(msg.id, message)
self.__queue.pop(0)
else:
await asyncio.sleep(0.1)
async def group_message_filter(self, message: Message) -> List[MessageChain]:
"""
过滤群消息中的非法元素
Args:
message: 源消息链
Returns:
处理后的消息链
"""
if message.type == PushType.Friend:
return message.get_message_chains()
new_chains = []
# 过滤 Bot 不在群内的消息
group = await self.__bot.get_group(message.id)
if group is None:
return new_chains
for chain in message.get_message_chains():
if AtAll in chain:
# 过滤 Bot 不是群管理员时的 @全体成员 消息
bot_info = await self.__bot.get_member(self.qq, message.id)
if bot_info.permission < MemberPerm.Administrator:
chain = chain.exclude(AtAll)
# 过滤多余的 @全体成员 消息
if chain.count(AtAll) > 1:
elements = [e for e in chain.exclude(AtAll)]
elements.insert(chain.index(AtAll), AtAll())
chain = MessageChain(elements)
if At in message:
# 过滤已不在群内的群成员的 @ 消息
member_list = [member.id for member in await self.__bot.get_member_list(message.id)]
elements = [e for e in chain if (not isinstance(e, At)) or (e.target in member_list)]
chain = MessageChain(elements)
new_chains.append(chain)
return new_chains
def __eq__(self, other):
if isinstance(other, Bot):
return self.qq == other.qq

View File

@@ -3,7 +3,7 @@
"""
import asyncio
from typing import Any, Coroutine
from typing import Any, Coroutine, Optional, Dict
class AsyncEvent:
@@ -13,6 +13,8 @@ class AsyncEvent:
特殊事件__ALL__ 所有事件均触发
"""
__handlers: Optional[Dict] = {}
def __init__(self):
self.__handlers = {}

View File

@@ -30,6 +30,9 @@ SIMPLE_CONFIG = {
# MySQL 数据库名
"MYSQL_DB": "starbot",
# Mirai HTTP 及 Websocket 端口
"MIRAI_PORT": 7827,
# 登录 B 站账号所需 Cookie 数据 ( 不登录账号将有部分功能不可用 ) 各字段获取方式查看https://bili.moyu.moe/#/get-credential.md
"SESSDATA": None,
"BILI_JCT": None,
@@ -89,6 +92,9 @@ FULL_CONFIG = {
# MySQL 数据库名
"MYSQL_DB": "starbot",
# Mirai HTTP 及 Websocket 端口
"MIRAI_PORT": 7827,
# 登录 B 站账号所需 Cookie 数据 ( 不登录账号将有部分功能不可用 ) 各字段获取方式查看https://bili.moyu.moe/#/get-credential.md
"SESSDATA": None,
"BILI_JCT": None,
@@ -141,6 +147,9 @@ def use_simple_config():
自动检测最新版本
使用 Redis 默认连接配置 (host: "localhost", port: 6379, db: 0, username: "", password: "")
使用 MySQL 默认连接配置 (host: "localhost", port: 3306, db: "starbot", username: "root", password: "123456")
Mirai 连接端口 7827
未设置登录 B 站账号所需 Cookie 数据
未设置 Bot 主人 QQ
不使用 HTTP 代理
不开启 HTTP API 推送
无命令触发前缀
@@ -158,6 +167,9 @@ def use_full_config():
自动检测最新版本
使用 Redis 默认连接配置 (host: "localhost", port: 6379, db: 0, username: "", password: "")
使用 MySQL 默认连接配置 (host: "localhost", port: 3306, db: "starbot", username: "root", password: "123456")
Mirai 连接端口 7827
未设置登录 B 站账号所需 Cookie 数据
未设置 Bot 主人 QQ
不使用 HTTP 代理
开启 HTTP API 推送
无命令触发前缀