Merge remote-tracking branch 'origin/dev' into dev

This commit is contained in:
LWR
2023-06-19 22:22:41 +08:00
3 changed files with 63 additions and 2 deletions

View File

@@ -1,3 +1,4 @@
import os
import abc
import asyncio
import json
@@ -194,17 +195,19 @@ class JsonDataSource(DataSource):
"""
从 JSON 字符串初始化的 Bot 推送配置数据源
"""
def __init__(self, json_file: Optional[str] = None, json_str: Optional[str] = None):
def __init__(self, json_file: Optional[str] = None, json_str: Optional[str] = None, credential_file: Optional[str] = None):
"""
Args:
json_file: JSON 文件路径,两个参数任选其一传入,全部传入优先使用 json_str
json_str: JSON 配置字符串,两个参数任选其一传入,全部传入优先使用 json_str
credential_file: B站凭据 JSON 文件路径不填默认从运行目录下的credential.json中读取凭据
"""
super().__init__()
self.__config = None
self.__json_file = json_file
self.__json_str = json_str
self.__credential_file = credential_file
async def load(self):
"""
@@ -249,6 +252,15 @@ class JsonDataSource(DataSource):
super().format_data()
logger.success(f"成功从 JSON 中导入了 {len(self.get_up_list())} 个 UP 主")
# 判断用户是否已通过config.set_credential设置凭据若已设置则跳过设置
if config.get("SESSDATA") is None or config.get("BILI_JCT") is None or config.get("BUVID3") is None:
# 用户不填credential_file字段时默认运行目录下credential.json
if self.__credential_file is None:
# 判断运行目录下是否存在credential.json若不存在则不调set_credential_from_json
if os.path.exists("credential.json"):
config.set_credential_from_json("credential.json")
else:
config.set_credential_from_json(self.__credential_file)
class MySQLDataSource(DataSource):
"""

View File

@@ -0,0 +1,11 @@
from .ApiException import ApiException
class CredentialFromJSONException(ApiException):
"""
从JSON文件读取Credential时发生的异常
"""
def __init__(self, msg: str):
super().__init__()
self.msg = msg

View File

@@ -1,4 +1,7 @@
from typing import Any
from loguru import logger
import json
from typing import Any, Optional
from ..exception.CredentialFromJSONException import CredentialFromJSONException
SIMPLE_CONFIG = {
# 是否检测最新 StarBot 版本
@@ -312,6 +315,41 @@ def set_credential(sessdata: str, bili_jct: str, buvid3: str):
set("BILI_JCT", bili_jct)
set("BUVID3", buvid3)
def set_credential_from_json(json_file: Optional[str] = None, json_str: Optional[str] = None):
"""
从JSON读取B站credential
Args:
json_file: JSON 文件路径,两个参数任选其一传入,全部传入优先使用 json_str
json_str: JSON 配置字符串,两个参数任选其一传入,全部传入优先使用 json_str
Raises:
CredentialFromJSONException: JSON 格式错误或缺少必要参数
"""
if json_str is None:
try:
with open(json_file, "r", encoding="utf-8") as file:
json_str = file.read()
except FileNotFoundError:
logger.error("B站凭据 JSON 文件不存在, 请检查文件路径是否正确")
raise CredentialFromJSONException("B站凭据 JSON 文件不存在, 请检查文件路径是否正确")
except UnicodeDecodeError:
logger.error("B站凭据 JSON 文件编码不正确, 请将其转换为 UTF-8 格式编码后重试")
raise CredentialFromJSONException("B站凭据 JSON 文件编码不正确, 请将其转换为 UTF-8 格式编码后重试")
except Exception as ex:
logger.error(f"读取B站凭据 JSON 文件异常 {ex}")
raise CredentialFromJSONException(f"读取B站凭据 JSON 文件异常 {ex}")
try:
config = json.loads(json_str)
except Exception:
logger.error("提供的B站凭据 JSON 字符串格式不正确")
raise CredentialFromJSONException("提供的B站凭据 JSON 字符串格式不正确")
set("SESSDATA",config["sessdata"])
set("BILI_JCT",config["bili_jct"])
set("BUVID3",config["buvid3"])
logger.success("成功从JSON中导入了B站凭据")
def get(key: str) -> Any:
"""