refactor: Remove key attribute from PushTarget and rename groups table name
This commit is contained in:
+4
-4
@@ -59,6 +59,10 @@ class StarBot:
|
|||||||
logger.error(ex.msg)
|
logger.error(ex.msg)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
if not self.__datasource.bots:
|
||||||
|
logger.error("数据源配置为空, 请先在数据源中配置完毕后再重新运行")
|
||||||
|
return
|
||||||
|
|
||||||
# 连接 Redis
|
# 连接 Redis
|
||||||
try:
|
try:
|
||||||
await redis.init()
|
await redis.init()
|
||||||
@@ -120,10 +124,6 @@ class StarBot:
|
|||||||
logger.success("用户自定义命令模块载入完毕")
|
logger.success("用户自定义命令模块载入完毕")
|
||||||
|
|
||||||
# 启动消息推送模块
|
# 启动消息推送模块
|
||||||
if not self.__datasource.bots:
|
|
||||||
logger.error("不存在需要启动的 Bot 账号, 请先在数据源中配置完毕后再重新运行")
|
|
||||||
return
|
|
||||||
|
|
||||||
Ariadne.options["default_account"] = self.__datasource.bots[0].qq
|
Ariadne.options["default_account"] = self.__datasource.bots[0].qq
|
||||||
|
|
||||||
logger.info("开始运行 Ariadne 消息推送模块")
|
logger.info("开始运行 Ariadne 消息推送模块")
|
||||||
|
|||||||
+34
-70
@@ -24,9 +24,6 @@ class DataSource(metaclass=abc.ABCMeta):
|
|||||||
self.__up_list: List[Up] = []
|
self.__up_list: List[Up] = []
|
||||||
self.__up_map: Dict[int, Up] = {}
|
self.__up_map: Dict[int, Up] = {}
|
||||||
self.__uid_list: List[int] = []
|
self.__uid_list: List[int] = []
|
||||||
self.__target_list: List[PushTarget] = []
|
|
||||||
self.__target_key_map: Dict[str, PushTarget] = {}
|
|
||||||
self.__target_bot_map: Dict[str, Bot] = {}
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
async def load(self):
|
async def load(self):
|
||||||
@@ -47,13 +44,6 @@ class DataSource(metaclass=abc.ABCMeta):
|
|||||||
self.__uid_list = list(self.__up_map.keys())
|
self.__uid_list = list(self.__up_map.keys())
|
||||||
if len(set(self.__uid_list)) < len(self.__uid_list):
|
if len(set(self.__uid_list)) < len(self.__uid_list):
|
||||||
raise DataSourceException("配置中不可含有重复的 UID")
|
raise DataSourceException("配置中不可含有重复的 UID")
|
||||||
self.__target_list = [x for target in map(lambda up: up.targets, self.__up_list) for x in target]
|
|
||||||
self.__target_key_map = dict(zip(map(lambda target: target.key, self.__target_list), self.__target_list))
|
|
||||||
|
|
||||||
for bot in self.bots:
|
|
||||||
for up in bot.ups:
|
|
||||||
for target in up.targets:
|
|
||||||
self.__target_bot_map[target.key] = bot
|
|
||||||
|
|
||||||
def get_up_list(self) -> List[Up]:
|
def get_up_list(self) -> List[Up]:
|
||||||
"""
|
"""
|
||||||
@@ -91,12 +81,12 @@ class DataSource(metaclass=abc.ABCMeta):
|
|||||||
raise DataSourceException(f"不存在的 UID: {uid}")
|
raise DataSourceException(f"不存在的 UID: {uid}")
|
||||||
return up
|
return up
|
||||||
|
|
||||||
def get_bot(self, qq: int) -> Bot:
|
def get_bot(self, qq: Optional[int] = None) -> Bot:
|
||||||
"""
|
"""
|
||||||
根据 QQ 获取 Bot 实例
|
根据 QQ 获取 Bot 实例
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
qq: 需要获取 Bot 的 QQ
|
qq: 需要获取 Bot 的 QQ,单 Bot 推送时可不传入
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Bot 实例
|
Bot 实例
|
||||||
@@ -104,6 +94,11 @@ class DataSource(metaclass=abc.ABCMeta):
|
|||||||
Raises:
|
Raises:
|
||||||
DataSourceException: QQ 不存在
|
DataSourceException: QQ 不存在
|
||||||
"""
|
"""
|
||||||
|
if qq is None:
|
||||||
|
if len(self.bots) != 1:
|
||||||
|
raise DataSourceException(f"多 Bot 推送时需明确指定要获取的 Bot QQ")
|
||||||
|
return self.bots[0]
|
||||||
|
|
||||||
bot = next((b for b in self.bots if b.qq == qq), None)
|
bot = next((b for b in self.bots if b.qq == qq), None)
|
||||||
if bot is None:
|
if bot is None:
|
||||||
raise DataSourceException(f"不存在的 QQ: {qq}")
|
raise DataSourceException(f"不存在的 QQ: {qq}")
|
||||||
@@ -130,42 +125,6 @@ class DataSource(metaclass=abc.ABCMeta):
|
|||||||
|
|
||||||
return ups
|
return ups
|
||||||
|
|
||||||
def get_target_by_key(self, key: str) -> PushTarget:
|
|
||||||
"""
|
|
||||||
根据推送 key 获取 PushTarget 实例,用于 HTTP API 推送
|
|
||||||
|
|
||||||
Args:
|
|
||||||
key: 需要获取 PushTarget 的推送 key
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
PushTarget 实例
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
DataSourceException: key 不存在
|
|
||||||
"""
|
|
||||||
target = self.__target_key_map.get(key)
|
|
||||||
if target is None:
|
|
||||||
raise DataSourceException(f"不存在的推送 key: {key}")
|
|
||||||
return target
|
|
||||||
|
|
||||||
def get_bot_by_key(self, key: str) -> Bot:
|
|
||||||
"""
|
|
||||||
根据推送 key 获取其所在的 Bot 实例,用于 HTTP API 推送
|
|
||||||
|
|
||||||
Args:
|
|
||||||
key: 需要获取所在 Bot 的推送 key
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Bot 实例
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
DataSourceException: key 不存在
|
|
||||||
"""
|
|
||||||
bot = self.__target_bot_map.get(key)
|
|
||||||
if bot is None:
|
|
||||||
raise DataSourceException(f"不存在的推送 key: {key}")
|
|
||||||
return bot
|
|
||||||
|
|
||||||
async def wait_for_connects(self):
|
async def wait_for_connects(self):
|
||||||
"""
|
"""
|
||||||
等待所有 Up 实例连接直播间完毕
|
等待所有 Up 实例连接直播间完毕
|
||||||
@@ -301,37 +260,37 @@ class MySQLDataSource(DataSource):
|
|||||||
推送目标列表
|
推送目标列表
|
||||||
"""
|
"""
|
||||||
live_on = await self.__query(
|
live_on = await self.__query(
|
||||||
"SELECT g.`uid`, g.`uname`, g.`room_id`, `key`, `type`, `num`, `enabled`, `message` "
|
"SELECT t.`uid`, t.`uname`, t.`room_id`, `type`, `num`, `enabled`, `message` "
|
||||||
"FROM `groups` AS `g` LEFT JOIN `live_on` AS `l` "
|
"FROM `targets` AS `t` LEFT JOIN `live_on` AS `l` "
|
||||||
"ON g.`uid` = l.`uid` AND g.`index` = l.`index` "
|
"ON t.`uid` = l.`uid` AND t.`id` = l.`id` "
|
||||||
f"WHERE g.`uid` = {uid} "
|
f"WHERE t.`uid` = {uid} "
|
||||||
"ORDER BY g.`index`"
|
"ORDER BY t.`id`"
|
||||||
)
|
)
|
||||||
live_off = await self.__query(
|
live_off = await self.__query(
|
||||||
"SELECT g.`uid`, g.`uname`, g.`room_id`, `key`, `type`, `num`, `enabled`, `message` "
|
"SELECT t.`uid`, t.`uname`, t.`room_id`, `type`, `num`, `enabled`, `message` "
|
||||||
"FROM `groups` AS `g` LEFT JOIN `live_off` AS `l` "
|
"FROM `targets` AS `t` LEFT JOIN `live_off` AS `l` "
|
||||||
"ON g.`uid` = l.`uid` AND g.`index` = l.`index` "
|
"ON t.`uid` = l.`uid` AND t.`id` = l.`id` "
|
||||||
f"WHERE g.`uid` = {uid} "
|
f"WHERE t.`uid` = {uid} "
|
||||||
"ORDER BY g.`index`"
|
"ORDER BY t.`id`"
|
||||||
)
|
)
|
||||||
live_report = await self.__query(
|
live_report = await self.__query(
|
||||||
"SELECT g.`uid`, g.`uname`, g.`room_id`, `key`, `type`, `num`, "
|
"SELECT t.`uid`, t.`uname`, t.`room_id`, `type`, `num`, "
|
||||||
"`enabled`, `logo`, `logo_base64`, `time`, `fans_change`, `fans_medal_change`, `guard_change`, "
|
"`enabled`, `logo`, `logo_base64`, `time`, `fans_change`, `fans_medal_change`, `guard_change`, "
|
||||||
"`danmu`, `box`, `gift`, `sc`, `guard`, "
|
"`danmu`, `box`, `gift`, `sc`, `guard`, "
|
||||||
"`danmu_ranking`, `box_ranking`, `box_profit_ranking`, `gift_ranking`, `sc_ranking`, "
|
"`danmu_ranking`, `box_ranking`, `box_profit_ranking`, `gift_ranking`, `sc_ranking`, "
|
||||||
"`guard_list`, `box_profit_diagram`, `danmu_diagram`, `box_diagram`, `gift_diagram`, "
|
"`guard_list`, `box_profit_diagram`, `danmu_diagram`, `box_diagram`, `gift_diagram`, "
|
||||||
"`sc_diagram`, `guard_diagram`, `danmu_cloud` "
|
"`sc_diagram`, `guard_diagram`, `danmu_cloud` "
|
||||||
"FROM `groups` AS `g` LEFT JOIN `live_report` AS `l` "
|
"FROM `targets` AS `t` LEFT JOIN `live_report` AS `l` "
|
||||||
"ON g.`uid` = l.`uid` AND g.`index` = l.`index` "
|
"ON t.`uid` = l.`uid` AND t.`id` = l.`id` "
|
||||||
f"WHERE g.`uid` = {uid} "
|
f"WHERE t.`uid` = {uid} "
|
||||||
"ORDER BY g.`index`"
|
"ORDER BY t.`id`"
|
||||||
)
|
)
|
||||||
dynamic_update = await self.__query(
|
dynamic_update = await self.__query(
|
||||||
"SELECT g.`uid`, g.`uname`, g.`room_id`, `key`, `type`, `num`, `enabled`, `message` "
|
"SELECT t.`uid`, t.`uname`, t.`room_id`, `type`, `num`, `enabled`, `message` "
|
||||||
"FROM `groups` AS `g` LEFT JOIN `dynamic_update` AS `d` "
|
"FROM `targets` AS `t` LEFT JOIN `dynamic_update` AS `d` "
|
||||||
"ON g.`uid` = d.`uid` AND g.`index` = d.`index` "
|
"ON t.`uid` = d.`uid` AND t.`id` = d.`id` "
|
||||||
f"WHERE g.`uid` = {uid} "
|
f"WHERE t.`uid` = {uid} "
|
||||||
"ORDER BY g.`index`"
|
"ORDER BY t.`id`"
|
||||||
)
|
)
|
||||||
|
|
||||||
targets = []
|
targets = []
|
||||||
@@ -430,15 +389,20 @@ class MySQLDataSource(DataSource):
|
|||||||
Args:
|
Args:
|
||||||
uid: 需要追加读取配置的 UID
|
uid: 需要追加读取配置的 UID
|
||||||
"""
|
"""
|
||||||
|
if uid in self.get_uid_list():
|
||||||
|
raise DataSourceException(f"载入 UID: {uid} 的推送配置失败, 不可重复载入")
|
||||||
|
|
||||||
user = await self.__query(f"SELECT * FROM `bot` WHERE uid = {uid}")
|
user = await self.__query(f"SELECT * FROM `bot` WHERE uid = {uid}")
|
||||||
if len(user) == 0:
|
if len(user) == 0:
|
||||||
logger.error(f"载入 UID: {uid} 的推送配置失败, UID 不存在")
|
logger.error(f"载入 UID: {uid} 的推送配置失败, UID 不存在")
|
||||||
raise DataSourceException(f"载入 UID: {uid} 的推送配置失败, UID 不存在")
|
raise DataSourceException(f"载入 UID: {uid} 的推送配置失败, UID 不存在")
|
||||||
|
|
||||||
bot = user[0].get("bot")
|
qq = user[0].get("bot")
|
||||||
targets = await self.__load_targets(uid)
|
targets = await self.__load_targets(uid)
|
||||||
up = Up(uid=uid, targets=targets)
|
up = Up(uid=uid, targets=targets)
|
||||||
self.get_bot(bot).ups.append(up)
|
bot = self.get_bot(qq)
|
||||||
|
bot.ups.append(up)
|
||||||
|
up.inject_bot(bot)
|
||||||
super().format_data()
|
super().format_data()
|
||||||
logger.success(f"已成功载入 UID: {uid} 的推送配置")
|
logger.success(f"已成功载入 UID: {uid} 的推送配置")
|
||||||
|
|
||||||
|
|||||||
@@ -264,13 +264,8 @@ class PushTarget(BaseModel):
|
|||||||
dynamic_update: Optional[DynamicUpdate] = DynamicUpdate()
|
dynamic_update: Optional[DynamicUpdate] = DynamicUpdate()
|
||||||
"""动态推送配置。默认:DynamicUpdate()"""
|
"""动态推送配置。默认:DynamicUpdate()"""
|
||||||
|
|
||||||
key: Optional[str] = None
|
|
||||||
"""推送 Key,可选功能,可使用此 Key 通过 HTTP API 向对应的好友或群推送消息。默认:str(id)-str(type)"""
|
|
||||||
|
|
||||||
def __init__(self, **data: Any):
|
def __init__(self, **data: Any):
|
||||||
super().__init__(**data)
|
super().__init__(**data)
|
||||||
if not self.key:
|
|
||||||
self.key = "-".join([str(self.id), str(self.type.value)])
|
|
||||||
self.__raise_for_not_invalid_placeholders()
|
self.__raise_for_not_invalid_placeholders()
|
||||||
|
|
||||||
def __raise_for_not_invalid_placeholders(self):
|
def __raise_for_not_invalid_placeholders(self):
|
||||||
@@ -287,7 +282,7 @@ class PushTarget(BaseModel):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
def __hash__(self):
|
def __hash__(self):
|
||||||
return hash(self.key)
|
return hash(self.id) ^ hash(self.type.value)
|
||||||
|
|
||||||
|
|
||||||
class Message(BaseModel):
|
class Message(BaseModel):
|
||||||
|
|||||||
+50
-14
@@ -3,10 +3,11 @@ from typing import Optional
|
|||||||
import aiohttp
|
import aiohttp
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
from aiohttp.web_routedef import RouteTableDef
|
from aiohttp.web_routedef import RouteTableDef
|
||||||
|
from graia.ariadne.exception import UnknownTarget
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
from .datasource import DataSource
|
from .datasource import DataSource
|
||||||
from .model import Message
|
from .model import Message, PushType
|
||||||
from ..exception import DataSourceException
|
from ..exception import DataSourceException
|
||||||
from ..utils import config
|
from ..utils import config
|
||||||
|
|
||||||
@@ -14,21 +15,55 @@ routes = web.RouteTableDef()
|
|||||||
datasource: Optional[DataSource] = None
|
datasource: Optional[DataSource] = None
|
||||||
|
|
||||||
|
|
||||||
@routes.get("/send/{key}/{message}")
|
@routes.get("/send/{type}/{key}/{message}")
|
||||||
async def send(request: aiohttp.web.Request) -> aiohttp.web.Response:
|
async def send(request: aiohttp.web.Request, qq: int = None) -> aiohttp.web.Response:
|
||||||
key = request.match_info['key']
|
if len(datasource.bots) == 1:
|
||||||
message = request.match_info['message']
|
bot = datasource.get_bot()
|
||||||
|
else:
|
||||||
|
if qq is None:
|
||||||
|
qq = config.get("HTTP_API_DEAFULT_BOT")
|
||||||
|
if qq is None:
|
||||||
|
logger.warning("HTTP API 推送失败, 多 Bot 推送时使用 HTTP API 需填写 HTTP_API_DEAFULT_BOT 配置项")
|
||||||
|
return web.Response(text="fail")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
target = datasource.get_target_by_key(key)
|
bot = datasource.get_bot(qq)
|
||||||
bot = datasource.get_bot_by_key(key)
|
|
||||||
msg = Message(id=target.id, content=message, type=target.type)
|
|
||||||
await bot.send_message(msg)
|
|
||||||
return web.Response(text="success")
|
|
||||||
except DataSourceException:
|
except DataSourceException:
|
||||||
logger.warning(f"HTTP API 推送失败, 不存在的推送 key: {key}")
|
logger.warning("HTTP API 推送失败, 填写的 HTTP_API_DEAFULT_BOT 配置项不正确")
|
||||||
return web.Response(text="fail")
|
return web.Response(text="fail")
|
||||||
|
|
||||||
|
if not str(request.match_info['key']).isdigit():
|
||||||
|
logger.warning("HTTP API 推送失败, 传入的 QQ 或群号格式不正确")
|
||||||
|
return web.Response(text="fail")
|
||||||
|
|
||||||
|
type_map = {
|
||||||
|
"friend": PushType.Friend,
|
||||||
|
"group": PushType.Group
|
||||||
|
}
|
||||||
|
_type = type_map.get(str(request.match_info['type']), None)
|
||||||
|
if _type is None:
|
||||||
|
logger.warning("HTTP API 推送失败, 传入的推送类型格式不正确")
|
||||||
|
return web.Response(text="fail")
|
||||||
|
|
||||||
|
key = int(request.match_info['key'])
|
||||||
|
message = Message(id=key, content=str(request.match_info['message']), type=_type)
|
||||||
|
|
||||||
|
try:
|
||||||
|
await bot.send_message(message)
|
||||||
|
except UnknownTarget:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return web.Response(text="success")
|
||||||
|
|
||||||
|
|
||||||
|
@routes.get("/send/{bot}/{type}/{key}/{message}")
|
||||||
|
async def send_by_bot(request: aiohttp.web.Request) -> aiohttp.web.Response:
|
||||||
|
if not str(request.match_info['bot']).isdigit():
|
||||||
|
logger.warning("HTTP API 推送失败, 传入的 Bot QQ 格式不正确")
|
||||||
|
return web.Response(text="fail")
|
||||||
|
|
||||||
|
return await send(request, int(request.match_info['bot']))
|
||||||
|
|
||||||
|
|
||||||
def get_routes() -> RouteTableDef:
|
def get_routes() -> RouteTableDef:
|
||||||
"""
|
"""
|
||||||
@@ -43,6 +78,7 @@ def get_routes() -> RouteTableDef:
|
|||||||
async def http_init(source: DataSource):
|
async def http_init(source: DataSource):
|
||||||
global datasource
|
global datasource
|
||||||
datasource = source
|
datasource = source
|
||||||
|
port = config.get("HTTP_API_PORT")
|
||||||
|
|
||||||
logger.info("开始启动 HTTP API 推送服务")
|
logger.info("开始启动 HTTP API 推送服务")
|
||||||
|
|
||||||
@@ -50,10 +86,10 @@ async def http_init(source: DataSource):
|
|||||||
app.add_routes(routes)
|
app.add_routes(routes)
|
||||||
runner = web.AppRunner(app)
|
runner = web.AppRunner(app)
|
||||||
await runner.setup()
|
await runner.setup()
|
||||||
site = web.TCPSite(runner, 'localhost', config.get("HTTP_API_PORT"))
|
site = web.TCPSite(runner, 'localhost', port)
|
||||||
try:
|
try:
|
||||||
await site.start()
|
await site.start()
|
||||||
except OSError:
|
except OSError:
|
||||||
logger.error(f"设定的 HTTP API 端口 {config.get('HTTP_API_PORT')} 已被占用, HTTP API 推送服务启动失败")
|
logger.error(f"设定的 HTTP API 端口 {port} 已被占用, HTTP API 推送服务启动失败")
|
||||||
return
|
return
|
||||||
logger.success("成功启动 HTTP API 推送服务")
|
logger.success(f"成功启动 HTTP API 推送服务: http://localhost:{port}")
|
||||||
|
|||||||
@@ -81,6 +81,8 @@ SIMPLE_CONFIG = {
|
|||||||
"USE_HTTP_API": False,
|
"USE_HTTP_API": False,
|
||||||
# HTTP API 端口
|
# HTTP API 端口
|
||||||
"HTTP_API_PORT": 8088,
|
"HTTP_API_PORT": 8088,
|
||||||
|
# 默认 HTTP API 推送 Bot QQ,多 Bot 推送时必填
|
||||||
|
"HTTP_API_DEAFULT_BOT": None,
|
||||||
|
|
||||||
# 命令触发前缀
|
# 命令触发前缀
|
||||||
"COMMAND_PREFIX": "",
|
"COMMAND_PREFIX": "",
|
||||||
@@ -175,9 +177,11 @@ FULL_CONFIG = {
|
|||||||
"PROXY": "",
|
"PROXY": "",
|
||||||
|
|
||||||
# 是否使用 HTTP API 推送
|
# 是否使用 HTTP API 推送
|
||||||
"USE_HTTP_API": True,
|
"USE_HTTP_API": False,
|
||||||
# HTTP API 端口
|
# HTTP API 端口
|
||||||
"HTTP_API_PORT": 8088,
|
"HTTP_API_PORT": 8088,
|
||||||
|
# 默认 HTTP API 推送 Bot QQ,多 Bot 推送时必填
|
||||||
|
"HTTP_API_DEAFULT_BOT": None,
|
||||||
|
|
||||||
# 命令触发前缀
|
# 命令触发前缀
|
||||||
"COMMAND_PREFIX": "",
|
"COMMAND_PREFIX": "",
|
||||||
|
|||||||
@@ -121,6 +121,10 @@ async def hincrbyfloat(key: str, hkey: Union[str, int], value: float = 1.0) -> f
|
|||||||
return await __redis.hincrbyfloat(key, hkey, value)
|
return await __redis.hincrbyfloat(key, hkey, value)
|
||||||
|
|
||||||
|
|
||||||
|
async def hdel(key: str, hkey: Union[str, int]):
|
||||||
|
await __redis.hdel(key, hkey)
|
||||||
|
|
||||||
|
|
||||||
# Set
|
# Set
|
||||||
|
|
||||||
async def scard(key: str) -> int:
|
async def scard(key: str) -> int:
|
||||||
|
|||||||
Reference in New Issue
Block a user