refactor: split some logic into util func

This commit is contained in:
2026-01-14 21:42:32 +08:00
parent 5a60e381d7
commit e842a9bc1d
14 changed files with 500 additions and 175 deletions
+123
View File
@@ -0,0 +1,123 @@
"""
数据库操作辅助函数
提供统一的资源查询、权限验证等通用功能
"""
from typing import TypeVar, Type, Optional, Any
from sqlalchemy.orm import Session
from fastapi import HTTPException, status
T = TypeVar('T')
def get_or_404(
model: Type[T],
model_id: int,
db: Session,
error_message: Optional[str] = None
) -> T:
"""
查询资源,不存在则抛出 404
Args:
model: SQLAlchemy 模型类
model_id: 资源 ID
db: 数据库会话
error_message: 自定义错误消息
Returns:
查询到的资源对象
Raises:
HTTPException: 404 资源不存在
"""
obj = db.query(model).filter(model.id == model_id).first()
if not obj:
default_message = f"{model.__name__}不存在"
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=error_message or default_message
)
return obj
def get_owned_or_403(
model: Type[T],
model_id: int,
user_id: int,
db: Session,
error_message: Optional[str] = None
) -> T:
"""
查询资源并验证归属,否则抛出 403
Args:
model: SQLAlchemy 模型类(必须有 user_id 字段)
model_id: 资源 ID
user_id: 当前用户 ID
db: 数据库会话
error_message: 自定义错误消息
Returns:
查询到的资源对象
Raises:
HTTPException: 403 无权访问此资源
"""
obj = db.query(model).filter(
model.id == model_id,
model.user_id == user_id
).first()
if not obj:
# 先检查资源是否存在
exists = db.query(model).filter(model.id == model_id).first()
if not exists:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"{model.__name__}不存在"
)
# 资源存在但不属于当前用户
default_message = f"无权访问此{model.__name__}"
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=error_message or default_message
)
return obj
def get_by_field_or_404(
model: Type[T],
field_name: str,
field_value: Any,
db: Session,
error_message: Optional[str] = None
) -> T:
"""
根据字段查询资源,不存在则抛出 404
Args:
model: SQLAlchemy 模型类
field_name: 字段名
field_value: 字段值
db: 数据库会话
error_message: 自定义错误消息
Returns:
查询到的资源对象
Raises:
HTTPException: 404 资源不存在
"""
obj = db.query(model).filter(
getattr(model, field_name) == field_value
).first()
if not obj:
default_message = f"{model.__name__}不存在"
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=error_message or default_message
)
return obj
+103
View File
@@ -0,0 +1,103 @@
"""
JSON 处理辅助函数
提供安全的 JSON 解析和数据提取功能
"""
import json
import logging
from typing import Optional, Any, Dict
logger = logging.getLogger(__name__)
def safe_parse_json(
json_str: Optional[str],
default: Any = None,
log_error: bool = True
) -> Any:
"""
安全解析 JSON 字符串,失败时返回默认值
Args:
json_str: JSON 字符串
default: 解析失败时的默认值
log_error: 是否记录解析错误日志
Returns:
解析后的对象,失败时返回 default
"""
if not json_str:
return default
try:
return json.loads(str(json_str))
except (json.JSONDecodeError, AttributeError, TypeError) as e:
if log_error:
logger.debug(f"JSON 解析失败: {str(e)}, 原始数据: {json_str[:100]}...")
return default
def safe_parse_payload(
payload_config: Optional[str],
default: Optional[Dict] = None
) -> Dict:
"""
安全解析 payload_config,失败时返回默认字典
Args:
payload_config: payload 配置字符串
default: 解析失败时的默认值
Returns:
解析后的字典
"""
result = safe_parse_json(payload_config, default or {})
# 确保返回值是字典类型
if not isinstance(result, dict):
logger.warning(f"payload_config 不是字典类型: {type(result)}")
return default or {}
return result
def extract_thread_id(payload_config: Optional[str]) -> Optional[str]:
"""
从 payload_config 中提取 ThreadId
Args:
payload_config: payload 配置字符串
Returns:
ThreadId 或 None
"""
payload = safe_parse_payload(payload_config)
return payload.get('ThreadId')
def extract_signature(payload_config: Optional[str]) -> Optional[str]:
"""
从 payload_config 中提取 Signature
Args:
payload_config: payload 配置字符串
Returns:
Signature 或 None
"""
payload = safe_parse_payload(payload_config)
return payload.get('Signature')
def build_task_info(task) -> Dict[str, str]:
"""
从 task 对象构建 task_info 字典(用于邮件通知等场景)
Args:
task: CheckInTask 对象
Returns:
包含 thread_id 和 name 的字典
"""
return {
'thread_id': extract_thread_id(getattr(task, 'payload_config', None)) or '未知',
'name': getattr(task, 'name', None) or f'Task-{getattr(task, "id", "Unknown")}'
}
+148
View File
@@ -0,0 +1,148 @@
"""
时间处理辅助函数
提供统一的时间戳处理和格式化功能
"""
from datetime import datetime, timedelta
from typing import Optional
def now_timestamp() -> int:
"""
获取当前时间戳(秒)
Returns:
当前时间戳
"""
return int(datetime.now().timestamp())
def is_timestamp_expired(timestamp: int) -> bool:
"""
检查时间戳是否已过期
Args:
timestamp: 时间戳(秒)
Returns:
是否已过期
"""
return now_timestamp() > timestamp
def seconds_until_expiry(timestamp: int) -> int:
"""
计算距离过期的秒数(负数表示已过期)
Args:
timestamp: 时间戳(秒)
Returns:
距离过期的秒数
"""
return timestamp - now_timestamp()
def days_until_expiry(timestamp: int) -> int:
"""
计算距离过期的天数(负数表示已过期)
Args:
timestamp: 时间戳(秒)
Returns:
距离过期的天数
"""
seconds = seconds_until_expiry(timestamp)
return seconds // 86400
def hours_until_expiry(timestamp: int) -> int:
"""
计算距离过期的小时数(负数表示已过期)
Args:
timestamp: 时间戳(秒)
Returns:
距离过期的小时数
"""
seconds = seconds_until_expiry(timestamp)
return seconds // 3600
def minutes_until_expiry(timestamp: int) -> int:
"""
计算距离过期的分钟数(负数表示已过期)
Args:
timestamp: 时间戳(秒)
Returns:
距离过期的分钟数
"""
seconds = seconds_until_expiry(timestamp)
return seconds // 60
def format_timestamp(timestamp: int, format_str: str = '%Y-%m-%d %H:%M:%S') -> str:
"""
格式化时间戳为人类可读格式
Args:
timestamp: 时间戳(秒)
format_str: 时间格式字符串
Returns:
格式化后的时间字符串
"""
dt = datetime.fromtimestamp(timestamp)
return dt.strftime(format_str)
def format_expiry_time(timestamp: int) -> str:
"""
格式化过期时间为人类可读格式(带中文说明)
Args:
timestamp: 时间戳(秒)
Returns:
格式化后的时间字符串,如 "2024-01-01 12:00:00 (已过期 2 天)"
"""
formatted_time = format_timestamp(timestamp)
days = days_until_expiry(timestamp)
if days > 0:
return f"{formatted_time} (还剩 {days} 天)"
elif days == 0:
hours = hours_until_expiry(timestamp)
if hours > 0:
return f"{formatted_time} (还剩 {hours} 小时)"
else:
minutes = minutes_until_expiry(timestamp)
if minutes > 0:
return f"{formatted_time} (还剩 {minutes} 分钟)"
else:
return f"{formatted_time} (即将过期)"
else:
return f"{formatted_time} (已过期 {abs(days)} 天)"
def parse_jwt_exp(jwt_exp: Optional[str]) -> Optional[int]:
"""
解析 jwt_exp 字段为时间戳
Args:
jwt_exp: jwt_exp 字符串(可能是 "0" 或数字字符串)
Returns:
时间戳,无效时返回 None
"""
if not jwt_exp or jwt_exp == "0":
return None
try:
return int(jwt_exp)
except (ValueError, TypeError):
return None