refactor: Some model modifications and code optimization

This commit is contained in:
LWR
2022-11-23 01:07:26 +08:00
parent 35bc6a4a21
commit dc6125f142
5 changed files with 66 additions and 54 deletions
+5 -16
View File
@@ -81,20 +81,8 @@ class StarBot:
await redis.hset("LiveStatus", up.room_id, status) await redis.hset("LiveStatus", up.room_id, status)
await redis.hset("StartTime", up.room_id, base["live_time"]) await redis.hset("StartTime", up.room_id, base["live_time"])
await redis.hset_ifnotexists("EndTime", up.room_id, 0)
await redis.hset_ifnotexists("RoomDanmuCount", up.room_id, 0)
await redis.hset_ifnotexists("RoomDanmuTotal", up.room_id, 0)
await redis.hset_ifnotexists("RoomBoxCount", up.room_id, 0)
await redis.hset_ifnotexists("RoomBoxTotal", up.room_id, 0)
await redis.hset_ifnotexists("RoomBoxProfit", up.room_id, 0)
await redis.hset_ifnotexists("RoomBoxProfitTotal", up.room_id, 0)
await redis.hset_ifnotexists("RoomGiftProfit", up.room_id, 0)
await redis.hset_ifnotexists("RoomGiftTotal", up.room_id, 0)
await redis.hset_ifnotexists("RoomScProfit", up.room_id, 0)
await redis.hset_ifnotexists("RoomScTotal", up.room_id, 0)
await redis.hset_ifnotexists("RoomGuardCount", up.room_id, "0-0-0")
await redis.hset_ifnotexists("RoomGuardTotal", up.room_id, "0-0-0")
# 连接直播间
for up in self.__datasource.get_up_list(): for up in self.__datasource.get_up_list():
try: try:
await up.connect() await up.connect()
@@ -105,7 +93,7 @@ class StarBot:
if config.get("USE_HTTP_API"): if config.get("USE_HTTP_API"):
asyncio.get_event_loop().create_task(http_init(self.__datasource)) asyncio.get_event_loop().create_task(http_init(self.__datasource))
# 启动 Bot # 启动消息推送模块
if not self.__datasource.bots: if not self.__datasource.bots:
logger.error("不存在需要启动的 Bot 账号, 请先在数据源中配置完毕后再重新运行") logger.error("不存在需要启动的 Bot 账号, 请先在数据源中配置完毕后再重新运行")
return return
@@ -113,8 +101,6 @@ class StarBot:
Ariadne.options["default_account"] = self.__datasource.bots[0].qq Ariadne.options["default_account"] = self.__datasource.bots[0].qq
logger.info("开始运行 Ariadne 消息推送模块") logger.info("开始运行 Ariadne 消息推送模块")
logger.disable("graia.ariadne.service")
logger.disable("launart")
for bot in self.__datasource.bots: for bot in self.__datasource.bots:
bot.start_sender() bot.start_sender()
@@ -141,6 +127,9 @@ class StarBot:
) )
logger.remove() logger.remove()
logger.add(sys.stderr, format=logger_format, level="INFO") logger.add(sys.stderr, format=logger_format, level="INFO")
logger.disable("graia.ariadne.model")
logger.disable("graia.ariadne.service")
logger.disable("launart")
bcc = create(Broadcast) bcc = create(Broadcast)
loop = bcc.loop loop = bcc.loop
+4 -5
View File
@@ -42,7 +42,7 @@ class DataSource(metaclass=abc.ABCMeta):
Raises: Raises:
DataSourceException: 配置中包含重复 uid DataSourceException: 配置中包含重复 uid
""" """
self.__up_list = [x for up in map(lambda bot: bot.ups, self.bots) for x in up] self.__up_list = [x for up in map(lambda b: b.ups, self.bots) for x in up]
self.__up_map = dict(zip(map(lambda up: up.uid, self.__up_list), self.__up_list)) self.__up_map = dict(zip(map(lambda up: up.uid, self.__up_list), self.__up_list))
self.__uid_list = list(self.__up_map.keys()) self.__uid_list = list(self.__up_map.keys())
if len(set(self.__uid_list)) < len(self.__uid_list): if len(set(self.__uid_list)) < len(self.__uid_list):
@@ -195,7 +195,6 @@ class MySQLDataSource(DataSource):
self.__port = port or int(config.get("MYSQL_PORT")) self.__port = port or int(config.get("MYSQL_PORT"))
self.__db = db or config.get("MYSQL_DB") self.__db = db or config.get("MYSQL_DB")
self.__pool: Optional[aiomysql.pool.Pool] = None self.__pool: Optional[aiomysql.pool.Pool] = None
self.__loop = asyncio.get_event_loop()
async def __connect(self): async def __connect(self):
""" """
@@ -205,6 +204,7 @@ class MySQLDataSource(DataSource):
DataSourceException: 连接数据库失败 DataSourceException: 连接数据库失败
""" """
try: try:
self.__loop = asyncio.get_event_loop()
self.__pool = await aiomysql.create_pool(host=self.__host, self.__pool = await aiomysql.create_pool(host=self.__host,
port=self.__port, port=self.__port,
user=self.__username, user=self.__username,
@@ -265,7 +265,7 @@ class MySQLDataSource(DataSource):
uid = now_user.get("uid") uid = now_user.get("uid")
live_on = await self.__query( live_on = await self.__query(
"SELECT g.`uid`, g.`uname`, g.`room_id`, `key`, `type`, `num`, `enabled`, `at_all`, `message` " "SELECT g.`uid`, g.`uname`, g.`room_id`, `key`, `type`, `num`, `enabled`, `message` "
"FROM `groups` AS `g` LEFT JOIN `live_on` AS `l` " "FROM `groups` AS `g` LEFT JOIN `live_on` AS `l` "
"ON g.`uid` = l.`uid` AND g.`index` = l.`index` " "ON g.`uid` = l.`uid` AND g.`index` = l.`index` "
f"WHERE g.`uid` = {uid} " f"WHERE g.`uid` = {uid} "
@@ -295,9 +295,8 @@ class MySQLDataSource(DataSource):
targets = [] targets = []
for i, target in enumerate(live_on): for i, target in enumerate(live_on):
if all((live_on[i]["enabled"], live_on[i]["at_all"], live_on[i]["message"])): if all((live_on[i]["enabled"], live_on[i]["message"])):
on = LiveOn(enabled=live_on[i]["enabled"], on = LiveOn(enabled=live_on[i]["enabled"],
at_all=live_on[i]["at_all"],
message=live_on[i]["message"]) message=live_on[i]["message"])
else: else:
on = LiveOn() on = LiveOn()
-8
View File
@@ -1,8 +0,0 @@
s = "Unccl oveguqnl gb zr!"
d = {}
for c in (65, 97):
for i in range(26):
d[chr(i+c)] = chr((i+13) % 26 + c)
print("".join([d.get(c, c) for c in s]))
+22 -10
View File
@@ -6,6 +6,8 @@ from graia.ariadne.message.chain import MessageChain
from graia.ariadne.message.element import Plain, At, AtAll, Image from graia.ariadne.message.element import Plain, At, AtAll, Image
from pydantic import BaseModel, PrivateAttr from pydantic import BaseModel, PrivateAttr
from ..exception import DataSourceException
class LiveOn(BaseModel): class LiveOn(BaseModel):
""" """
@@ -20,13 +22,11 @@ class LiveOn(BaseModel):
enabled: Optional[bool] = False enabled: Optional[bool] = False
"""是否启用开播推送。默认:False""" """是否启用开播推送。默认:False"""
at_all: Optional[bool] = False
"""是否 @全体成员。默认:False"""
message: Optional[str] = "" message: Optional[str] = ""
""" """
开播推送内容模板。 开播推送内容模板。
用占位符:{uname}主播昵称,{title}直播间标题,{url}直播间链接,{cover}直播间封面图{next}消息分条 用占位符:{uname} 主播昵称,{title} 直播间标题,{url} 直播间链接,{cover} 直播间封面图。
通用占位符:{next} 消息分条,{atall} @全体成员,{at114514} @指定QQ号,{urlpic=链接} 网络图片,{pathpic=路径} 本地图片。
默认:"" 默认:""
""" """
@@ -34,12 +34,12 @@ class LiveOn(BaseModel):
def default(cls): def default(cls):
""" """
获取功能全部开启的默认 LiveOn 实例 获取功能全部开启的默认 LiveOn 实例
默认配置:启用开播推送,启用 @全体成员,推送内容模板为 "{uname} 正在直播 {title}\n{url}{next}{cover}" 默认配置:启用开播推送,推送内容模板为 "{uname} 正在直播 {title}\n{url}{next}{cover}"
""" """
return LiveOn(enabled=True, at_all=True, message=LiveOn.DEFAULT_MESSAGE) return LiveOn(enabled=True, message=LiveOn.DEFAULT_MESSAGE)
def __str__(self): def __str__(self):
return f"启用: {self.enabled}\n@全体: {self.at_all}\n推送内容:\n{self.message}" return f"启用: {self.enabled}\n推送内容:\n{self.message}"
class LiveOff(BaseModel): class LiveOff(BaseModel):
@@ -49,7 +49,7 @@ class LiveOff(BaseModel):
或使用 LiveOff.default() 获取功能全部开启的默认配置 或使用 LiveOff.default() 获取功能全部开启的默认配置
""" """
DEFAULT_MESSAGE: Optional[str] = "{uname} 直播结束了\n{time}{next}{danmu_count}{danmu_mvp}{box_profit}" DEFAULT_MESSAGE: Optional[str] = "{uname} 直播结束了"
"""默认消息模板""" """默认消息模板"""
enabled: Optional[bool] = False enabled: Optional[bool] = False
@@ -58,7 +58,8 @@ class LiveOff(BaseModel):
message: Optional[str] = "" message: Optional[str] = ""
""" """
下播推送内容模板。 下播推送内容模板。
用占位符:{uname}主播昵称{time}本次直播时长,{danmu_count}弹幕总数,{danmu_mvp}弹幕MVP{box_profit}宝盒盈亏,{next}消息分条 用占位符:{uname}主播昵称。
通用占位符:{next} 消息分条,{atall} @全体成员,{at114514} @指定QQ号,{urlpic=链接} 网络图片,{pathpic=路径} 本地图片。
默认:"" 默认:""
""" """
@@ -115,7 +116,8 @@ class DynamicUpdate(BaseModel):
message: Optional[str] = "" message: Optional[str] = ""
""" """
动态推送内容模板。 动态推送内容模板。
用占位符:{uname}主播昵称,{action}动态操作类型(发表了新动态,转发了新动态,投稿了新视频...),{url}动态链接(若为发表视频、专栏等则为视频、专栏等对应的链接)。 用占位符:{uname}主播昵称,{action}动态操作类型(发表了新动态,转发了新动态,投稿了新视频...),{url}动态链接(若为发表视频、专栏等则为视频、专栏等对应的链接)。
通用占位符:{next} 消息分条,{atall} @全体成员,{at114514} @指定QQ号,{urlpic=链接} 网络图片,{pathpic=路径} 本地图片。
默认:"" 默认:""
""" """
@@ -174,6 +176,15 @@ class PushTarget(BaseModel):
super().__init__(**data) super().__init__(**data)
if not self.key: if not self.key:
self.key = "-".join([str(self.id), str(self.type.value)]) self.key = "-".join([str(self.id), str(self.type.value)])
self.__raise_for_not_invalid_placeholders()
def __raise_for_not_invalid_placeholders(self):
"""
使用不合法的占位符时抛出异常
"""
if self.type == PushType.Friend:
if "{at" in self.live_on.message or "{at" in self.live_off.message or "{at" in self.dynamic_update.message:
raise DataSourceException(f"好友类型的推送目标 (QQ: {self.id}) 推送内容中不能含有 @ 消息, 请检查配置后重试")
def __eq__(self, other): def __eq__(self, other):
if isinstance(other, PushTarget): if isinstance(other, PushTarget):
@@ -267,6 +278,7 @@ class Message(BaseModel):
chain.append(Plain(msg[:code_end + 1])) chain.append(Plain(msg[:code_end + 1]))
msg = msg[code_end + 1:] msg = msg[code_end + 1:]
next_code = msg.find("{") next_code = msg.find("{")
if len(chain) != 0:
chains.append(chain) chains.append(chain)
return chains return chains
+33 -13
View File
@@ -1,4 +1,4 @@
from typing import Any, Union from typing import Any, Union, Optional, List
import aioredis import aioredis
from loguru import logger from loguru import logger
@@ -27,12 +27,14 @@ async def init():
# String # String
async def get(key: str) -> str: async def delete(key: str):
return str(await __redis.get(key)) await __redis.delete(key)
async def geti(key: str) -> int: # List
return int(await __redis.get(key))
async def rpush(key: str, value: Any):
await __redis.rpush(key, value)
# Hash # Hash
@@ -41,18 +43,36 @@ async def hexists(key: str, hkey: Union[str, int]) -> bool:
return await __redis.hexists(key, hkey) return await __redis.hexists(key, hkey)
async def hget(key: str, hkey: Union[str, int]) -> str:
return str(await __redis.hget(key, hkey))
async def hgeti(key: str, hkey: Union[str, int]) -> int: async def hgeti(key: str, hkey: Union[str, int]) -> int:
return int(await __redis.hget(key, hkey)) result = await __redis.hget(key, hkey)
if result is None:
return 0
return int(result)
async def hgetf1(key: str, hkey: Union[str, int]) -> float:
result = await __redis.hget(key, hkey)
if result is None:
return 0.0
return float("{:.1f}".format(float(result)))
async def hset(key: str, hkey: Union[str, int], value: Any): async def hset(key: str, hkey: Union[str, int], value: Any):
await __redis.hset(key, hkey, value) await __redis.hset(key, hkey, value)
async def hset_ifnotexists(key: str, hkey: Union[str, int], value: Any): async def hincrby(key: str, hkey: Union[str, int], value: Optional[int] = 1) -> int:
if not await hexists(key, hkey): return await __redis.hincrby(key, hkey, value)
await hset(key, hkey, value)
async def hincrbyfloat(key: str, hkey: Union[str, int], value: Optional[float] = 1.0) -> float:
return await __redis.hincrbyfloat(key, hkey, value)
# Zset
async def zunionstore(dest: str, source: Union[str, List[str]]):
if isinstance(source, str):
await __redis.zunionstore(dest, [dest, source])
if isinstance(source, list):
await __redis.zunionstore(dest, source)