feat: Ariadne message sender support
This commit is contained in:
+36
-12
@@ -1,5 +1,8 @@
|
|||||||
import sys
|
import sys
|
||||||
|
|
||||||
|
from creart import create
|
||||||
|
from graia.ariadne import Ariadne
|
||||||
|
from graia.broadcast import Broadcast
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
from .datasource import DataSource
|
from .datasource import DataSource
|
||||||
@@ -34,21 +37,11 @@ class StarBot:
|
|||||||
"""
|
"""
|
||||||
self.__datasource = datasource
|
self.__datasource = datasource
|
||||||
|
|
||||||
async def run(self):
|
async def __main(self):
|
||||||
"""
|
"""
|
||||||
启动 StarBot
|
StarBot 入口
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# 设置日志格式
|
|
||||||
logger_format = (
|
|
||||||
"<green>{time:YYYY-MM-DD HH:mm:ss.SSS}</green> | "
|
|
||||||
"<level>{level: <8}</level> | "
|
|
||||||
"<cyan>{name}</cyan>:<cyan>{line}</cyan> | "
|
|
||||||
"<level>{message}</level>"
|
|
||||||
)
|
|
||||||
logger.remove()
|
|
||||||
logger.add(sys.stderr, format=logger_format, level="INFO")
|
|
||||||
|
|
||||||
logger.opt(colors=True, raw=True).info(f"<yellow>{self.STARBOT_ASCII_LOGO}</>")
|
logger.opt(colors=True, raw=True).info(f"<yellow>{self.STARBOT_ASCII_LOGO}</>")
|
||||||
logger.info("开始启动 StarBot")
|
logger.info("开始启动 StarBot")
|
||||||
|
|
||||||
@@ -65,3 +58,34 @@ class StarBot:
|
|||||||
except RedisException as ex:
|
except RedisException as ex:
|
||||||
logger.error(ex.msg)
|
logger.error(ex.msg)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# 启动 Bot
|
||||||
|
logger.info("开始启动 Ariadne 消息推送模块")
|
||||||
|
Ariadne.options["default_account"] = 1499887988
|
||||||
|
try:
|
||||||
|
Ariadne.launch_blocking()
|
||||||
|
except RuntimeError as ex:
|
||||||
|
if "This event loop is already running" in str(ex):
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
logger.error(ex)
|
||||||
|
return
|
||||||
|
|
||||||
|
def run(self):
|
||||||
|
"""
|
||||||
|
启动 StarBot
|
||||||
|
"""
|
||||||
|
|
||||||
|
logger_format = (
|
||||||
|
"<green>{time:YYYY-MM-DD HH:mm:ss.SSS}</green> | "
|
||||||
|
"<level>{level: <8}</level> | "
|
||||||
|
"<cyan>{name}</cyan>:<cyan>{line}</cyan> | "
|
||||||
|
"<level>{message}</level>"
|
||||||
|
)
|
||||||
|
logger.remove()
|
||||||
|
logger.add(sys.stderr, format=logger_format, level="INFO")
|
||||||
|
|
||||||
|
bcc = create(Broadcast)
|
||||||
|
loop = bcc.loop
|
||||||
|
loop.create_task(self.__main())
|
||||||
|
loop.run_forever()
|
||||||
|
|||||||
@@ -2,13 +2,22 @@
|
|||||||
Bot 配置相关类
|
Bot 配置相关类
|
||||||
"""
|
"""
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import time
|
||||||
from asyncio import AbstractEventLoop
|
from asyncio import AbstractEventLoop
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import List, Optional, Any, Union, Dict
|
from typing import List, Optional, Any, Union, Dict
|
||||||
|
|
||||||
|
from graia.ariadne import Ariadne
|
||||||
|
from graia.ariadne.connection.config import config as AriadneConfig, HttpClientConfig, WebsocketClientConfig
|
||||||
|
from graia.ariadne.event.lifecycle import ApplicationLaunched
|
||||||
|
from graia.ariadne.message.chain import MessageChain
|
||||||
|
from graia.ariadne.message.element import Plain, At, AtAll, Image
|
||||||
|
from graia.ariadne.model import LogConfig, MemberPerm
|
||||||
|
from loguru import logger
|
||||||
from pydantic import BaseModel, PrivateAttr
|
from pydantic import BaseModel, PrivateAttr
|
||||||
|
|
||||||
from .live import LiveDanmaku
|
from .live import LiveDanmaku
|
||||||
|
from ..utils import config
|
||||||
from ..utils.AsyncEvent import AsyncEvent
|
from ..utils.AsyncEvent import AsyncEvent
|
||||||
|
|
||||||
|
|
||||||
@@ -236,6 +245,94 @@ class Up(BaseModel):
|
|||||||
return hash(self.uid)
|
return hash(self.uid)
|
||||||
|
|
||||||
|
|
||||||
|
class Message(BaseModel):
|
||||||
|
"""
|
||||||
|
消息封装类
|
||||||
|
"""
|
||||||
|
|
||||||
|
id: int
|
||||||
|
"""目标 QQ 号或目标群号"""
|
||||||
|
|
||||||
|
content: str
|
||||||
|
"""原始消息内容,可包含 {next}、{atall} 等占位符"""
|
||||||
|
|
||||||
|
type: Optional[PushType] = PushType.Group
|
||||||
|
"""发送目标类型,PushType.Friend 为私聊消息,PushType.Group 为群聊消息。默认:PushType.Group"""
|
||||||
|
|
||||||
|
__time: Optional[int] = PrivateAttr()
|
||||||
|
"""消息创建时间戳"""
|
||||||
|
|
||||||
|
__chains: Optional[List[MessageChain]] = PrivateAttr()
|
||||||
|
"""原始消息内容自动处理后的消息链列表"""
|
||||||
|
|
||||||
|
def __init__(self, **data: Any):
|
||||||
|
super().__init__(**data)
|
||||||
|
self.__time = int(time.time())
|
||||||
|
self.__chains = Message.gen_message_chains(self.content)
|
||||||
|
|
||||||
|
def get_time(self) -> int:
|
||||||
|
return self.__time
|
||||||
|
|
||||||
|
def get_message_chains(self) -> List[MessageChain]:
|
||||||
|
return self.__chains
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def gen_message_chains(cls, raw_msg: str) -> List[MessageChain]:
|
||||||
|
"""
|
||||||
|
转换 {next},{atall},{at},{pic_url},{pic_path} 元素,将原始消息内容转换为可发送的消息链
|
||||||
|
|
||||||
|
Args:
|
||||||
|
raw_msg: 原始消息文本
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
可直接发送的消息链列表
|
||||||
|
"""
|
||||||
|
chains = []
|
||||||
|
|
||||||
|
raw_msgs = raw_msg.split("{next}")
|
||||||
|
for msg in raw_msgs:
|
||||||
|
chain = MessageChain([])
|
||||||
|
next_code = msg.find("{")
|
||||||
|
while msg != "":
|
||||||
|
if next_code == -1:
|
||||||
|
chain.append(Plain(msg))
|
||||||
|
msg = ""
|
||||||
|
elif next_code != 0:
|
||||||
|
chain.append(Plain(msg[:next_code]))
|
||||||
|
msg = msg[next_code:]
|
||||||
|
next_code = msg.find("{")
|
||||||
|
else:
|
||||||
|
code_end = msg.find("}")
|
||||||
|
if code_end == -1:
|
||||||
|
chain.append(Plain(msg))
|
||||||
|
msg = ""
|
||||||
|
else:
|
||||||
|
if msg[1:3] == "at":
|
||||||
|
at_target = msg[3:code_end]
|
||||||
|
if at_target == "all":
|
||||||
|
chain.append(AtAll())
|
||||||
|
elif at_target.isdigit():
|
||||||
|
chain.append((At(int(at_target))))
|
||||||
|
chain.append(Plain(" "))
|
||||||
|
else:
|
||||||
|
chain.append(Plain("[无效的@参数]"))
|
||||||
|
elif msg[1:7] == "urlpic":
|
||||||
|
pic_url = msg[8:code_end]
|
||||||
|
if pic_url != "":
|
||||||
|
chain.append(Image(url=pic_url))
|
||||||
|
elif msg[1:8] == "pathpic":
|
||||||
|
pic_path = msg[9:code_end]
|
||||||
|
if pic_path != "":
|
||||||
|
chain.append(Image(path=pic_path))
|
||||||
|
else:
|
||||||
|
chain.append(Plain(msg[:code_end + 1]))
|
||||||
|
msg = msg[code_end + 1:]
|
||||||
|
next_code = msg.find("{")
|
||||||
|
chains.append(chain)
|
||||||
|
|
||||||
|
return chains
|
||||||
|
|
||||||
|
|
||||||
class Bot(BaseModel, AsyncEvent):
|
class Bot(BaseModel, AsyncEvent):
|
||||||
"""
|
"""
|
||||||
Bot 类,每个实例为一个 QQ 号,可用于配置多 Bot 推送
|
Bot 类,每个实例为一个 QQ 号,可用于配置多 Bot 推送
|
||||||
@@ -247,6 +344,100 @@ class Bot(BaseModel, AsyncEvent):
|
|||||||
ups: List[Up]
|
ups: List[Up]
|
||||||
"""Bot 账号下运行的 UP 主列表"""
|
"""Bot 账号下运行的 UP 主列表"""
|
||||||
|
|
||||||
|
__loop: Optional[AbstractEventLoop] = PrivateAttr()
|
||||||
|
"""asyncio 事件循环"""
|
||||||
|
|
||||||
|
__bot: Optional[Ariadne] = PrivateAttr()
|
||||||
|
"""Ariadne 实例"""
|
||||||
|
|
||||||
|
__queue: Optional[List[Message]] = PrivateAttr()
|
||||||
|
"""待发送消息队列"""
|
||||||
|
|
||||||
|
def __init__(self, **data: Any):
|
||||||
|
super().__init__(**data)
|
||||||
|
self.__loop = asyncio.get_event_loop()
|
||||||
|
self.__bot = Ariadne(
|
||||||
|
connection=AriadneConfig(
|
||||||
|
self.qq,
|
||||||
|
"StarBot",
|
||||||
|
HttpClientConfig(host=f"http://localhost:{config.get('MIRAI_PORT')}"),
|
||||||
|
WebsocketClientConfig(host=f"http://localhost:{config.get('MIRAI_PORT')}"),
|
||||||
|
),
|
||||||
|
log_config=LogConfig(log_level="DEBUG")
|
||||||
|
)
|
||||||
|
self.__queue = []
|
||||||
|
|
||||||
|
@self.on("SEND_MESSAGE")
|
||||||
|
async def send_message(msg: Message):
|
||||||
|
self.__queue.append(msg)
|
||||||
|
|
||||||
|
@self.__bot.broadcast.receiver(ApplicationLaunched)
|
||||||
|
async def start_sender():
|
||||||
|
logger.success(f"Bot [{self.qq}] 已启动")
|
||||||
|
self.__loop.create_task(self.__sender())
|
||||||
|
|
||||||
|
async def __sender(self):
|
||||||
|
"""
|
||||||
|
消息发送模块
|
||||||
|
"""
|
||||||
|
while True:
|
||||||
|
if self.__queue:
|
||||||
|
msg = self.__queue[0]
|
||||||
|
if msg.type == PushType.Friend:
|
||||||
|
for message in msg.get_message_chains():
|
||||||
|
logger.info(f"{self.qq} -> 好友[{msg.id}] : {message}")
|
||||||
|
await self.__bot.send_friend_message(msg.id, message)
|
||||||
|
else:
|
||||||
|
for message in await self.group_message_filter(msg):
|
||||||
|
logger.info(f"{self.qq} -> 群[{msg.id}] : {message}")
|
||||||
|
await self.__bot.send_group_message(msg.id, message)
|
||||||
|
self.__queue.pop(0)
|
||||||
|
else:
|
||||||
|
await asyncio.sleep(0.1)
|
||||||
|
|
||||||
|
async def group_message_filter(self, message: Message) -> List[MessageChain]:
|
||||||
|
"""
|
||||||
|
过滤群消息中的非法元素
|
||||||
|
|
||||||
|
Args:
|
||||||
|
message: 源消息链
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
处理后的消息链
|
||||||
|
"""
|
||||||
|
if message.type == PushType.Friend:
|
||||||
|
return message.get_message_chains()
|
||||||
|
|
||||||
|
new_chains = []
|
||||||
|
|
||||||
|
# 过滤 Bot 不在群内的消息
|
||||||
|
group = await self.__bot.get_group(message.id)
|
||||||
|
if group is None:
|
||||||
|
return new_chains
|
||||||
|
|
||||||
|
for chain in message.get_message_chains():
|
||||||
|
if AtAll in chain:
|
||||||
|
# 过滤 Bot 不是群管理员时的 @全体成员 消息
|
||||||
|
bot_info = await self.__bot.get_member(self.qq, message.id)
|
||||||
|
if bot_info.permission < MemberPerm.Administrator:
|
||||||
|
chain = chain.exclude(AtAll)
|
||||||
|
|
||||||
|
# 过滤多余的 @全体成员 消息
|
||||||
|
if chain.count(AtAll) > 1:
|
||||||
|
elements = [e for e in chain.exclude(AtAll)]
|
||||||
|
elements.insert(chain.index(AtAll), AtAll())
|
||||||
|
chain = MessageChain(elements)
|
||||||
|
|
||||||
|
if At in message:
|
||||||
|
# 过滤已不在群内的群成员的 @ 消息
|
||||||
|
member_list = [member.id for member in await self.__bot.get_member_list(message.id)]
|
||||||
|
elements = [e for e in chain if (not isinstance(e, At)) or (e.target in member_list)]
|
||||||
|
chain = MessageChain(elements)
|
||||||
|
|
||||||
|
new_chains.append(chain)
|
||||||
|
|
||||||
|
return new_chains
|
||||||
|
|
||||||
def __eq__(self, other):
|
def __eq__(self, other):
|
||||||
if isinstance(other, Bot):
|
if isinstance(other, Bot):
|
||||||
return self.qq == other.qq
|
return self.qq == other.qq
|
||||||
|
|||||||
@@ -3,7 +3,7 @@
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from typing import Any, Coroutine
|
from typing import Any, Coroutine, Optional, Dict
|
||||||
|
|
||||||
|
|
||||||
class AsyncEvent:
|
class AsyncEvent:
|
||||||
@@ -13,6 +13,8 @@ class AsyncEvent:
|
|||||||
特殊事件:__ALL__ 所有事件均触发
|
特殊事件:__ALL__ 所有事件均触发
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
__handlers: Optional[Dict] = {}
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.__handlers = {}
|
self.__handlers = {}
|
||||||
|
|
||||||
|
|||||||
@@ -30,6 +30,9 @@ SIMPLE_CONFIG = {
|
|||||||
# MySQL 数据库名
|
# MySQL 数据库名
|
||||||
"MYSQL_DB": "starbot",
|
"MYSQL_DB": "starbot",
|
||||||
|
|
||||||
|
# Mirai HTTP 及 Websocket 端口
|
||||||
|
"MIRAI_PORT": 7827,
|
||||||
|
|
||||||
# 登录 B 站账号所需 Cookie 数据 ( 不登录账号将有部分功能不可用 ) 各字段获取方式查看:https://bili.moyu.moe/#/get-credential.md
|
# 登录 B 站账号所需 Cookie 数据 ( 不登录账号将有部分功能不可用 ) 各字段获取方式查看:https://bili.moyu.moe/#/get-credential.md
|
||||||
"SESSDATA": None,
|
"SESSDATA": None,
|
||||||
"BILI_JCT": None,
|
"BILI_JCT": None,
|
||||||
@@ -89,6 +92,9 @@ FULL_CONFIG = {
|
|||||||
# MySQL 数据库名
|
# MySQL 数据库名
|
||||||
"MYSQL_DB": "starbot",
|
"MYSQL_DB": "starbot",
|
||||||
|
|
||||||
|
# Mirai HTTP 及 Websocket 端口
|
||||||
|
"MIRAI_PORT": 7827,
|
||||||
|
|
||||||
# 登录 B 站账号所需 Cookie 数据 ( 不登录账号将有部分功能不可用 ) 各字段获取方式查看:https://bili.moyu.moe/#/get-credential.md
|
# 登录 B 站账号所需 Cookie 数据 ( 不登录账号将有部分功能不可用 ) 各字段获取方式查看:https://bili.moyu.moe/#/get-credential.md
|
||||||
"SESSDATA": None,
|
"SESSDATA": None,
|
||||||
"BILI_JCT": None,
|
"BILI_JCT": None,
|
||||||
@@ -141,6 +147,9 @@ def use_simple_config():
|
|||||||
自动检测最新版本
|
自动检测最新版本
|
||||||
使用 Redis 默认连接配置 (host: "localhost", port: 6379, db: 0, username: "", password: "")
|
使用 Redis 默认连接配置 (host: "localhost", port: 6379, db: 0, username: "", password: "")
|
||||||
使用 MySQL 默认连接配置 (host: "localhost", port: 3306, db: "starbot", username: "root", password: "123456")
|
使用 MySQL 默认连接配置 (host: "localhost", port: 3306, db: "starbot", username: "root", password: "123456")
|
||||||
|
Mirai 连接端口 7827
|
||||||
|
未设置登录 B 站账号所需 Cookie 数据
|
||||||
|
未设置 Bot 主人 QQ
|
||||||
不使用 HTTP 代理
|
不使用 HTTP 代理
|
||||||
不开启 HTTP API 推送
|
不开启 HTTP API 推送
|
||||||
无命令触发前缀
|
无命令触发前缀
|
||||||
@@ -158,6 +167,9 @@ def use_full_config():
|
|||||||
自动检测最新版本
|
自动检测最新版本
|
||||||
使用 Redis 默认连接配置 (host: "localhost", port: 6379, db: 0, username: "", password: "")
|
使用 Redis 默认连接配置 (host: "localhost", port: 6379, db: 0, username: "", password: "")
|
||||||
使用 MySQL 默认连接配置 (host: "localhost", port: 3306, db: "starbot", username: "root", password: "123456")
|
使用 MySQL 默认连接配置 (host: "localhost", port: 3306, db: "starbot", username: "root", password: "123456")
|
||||||
|
Mirai 连接端口 7827
|
||||||
|
未设置登录 B 站账号所需 Cookie 数据
|
||||||
|
未设置 Bot 主人 QQ
|
||||||
不使用 HTTP 代理
|
不使用 HTTP 代理
|
||||||
开启 HTTP API 推送
|
开启 HTTP API 推送
|
||||||
无命令触发前缀
|
无命令触发前缀
|
||||||
|
|||||||
Reference in New Issue
Block a user