style(backend): apply ruff format

This commit is contained in:
2026-05-03 18:14:23 +08:00
parent 738217d9c9
commit ab68f019c5
41 changed files with 960 additions and 970 deletions
+58 -67
View File
@@ -18,6 +18,7 @@ router = APIRouter()
class BatchToggleTasksRequest(BaseModel): class BatchToggleTasksRequest(BaseModel):
"""批量启用/禁用任务请求""" """批量启用/禁用任务请求"""
task_ids: List[int] task_ids: List[int]
is_active: bool is_active: bool
@@ -26,7 +27,7 @@ class BatchToggleTasksRequest(BaseModel):
async def batch_toggle_tasks( async def batch_toggle_tasks(
request: BatchToggleTasksRequest, request: BatchToggleTasksRequest,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_admin_user) current_user: User = Depends(get_current_admin_user),
): ):
""" """
批量启用或禁用任务的自动打卡功能(需要管理员权限) 批量启用或禁用任务的自动打卡功能(需要管理员权限)
@@ -47,12 +48,11 @@ async def batch_toggle_tasks(
return { return {
"success": True, "success": True,
"message": f"{'启用' if request.is_active else '禁用'} {count} 个任务", "message": f"{'启用' if request.is_active else '禁用'} {count} 个任务",
"count": count "count": count,
} }
except Exception as e: except Exception as e:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"批量操作失败: {str(e)}"
detail=f"批量操作失败: {str(e)}"
) )
@@ -60,7 +60,7 @@ async def batch_toggle_tasks(
async def batch_check_in( async def batch_check_in(
request: BatchCheckInRequest, request: BatchCheckInRequest,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_admin_user) current_user: User = Depends(get_current_admin_user),
): ):
""" """
批量触发任务打卡(需要管理员权限) 批量触发任务打卡(需要管理员权限)
@@ -74,15 +74,14 @@ async def batch_check_in(
return result return result
except Exception as e: except Exception as e:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"批量打卡失败: {str(e)}"
detail=f"批量打卡失败: {str(e)}"
) )
@router.get("/logs", summary="获取系统日志") @router.get("/logs", summary="获取系统日志")
async def get_system_logs( async def get_system_logs(
lines: int = Query(200, ge=1, le=2000, description="读取的日志行数"), lines: int = Query(200, ge=1, le=2000, description="读取的日志行数"),
current_user: User = Depends(get_current_admin_user) current_user: User = Depends(get_current_admin_user),
): ):
""" """
获取系统日志(需要管理员权限) 获取系统日志(需要管理员权限)
@@ -95,40 +94,34 @@ async def get_system_logs(
log_file = settings.LOG_FILE log_file = settings.LOG_FILE
if not log_file.exists(): if not log_file.exists():
return { return {"success": True, "message": "日志文件不存在", "logs": "日志文件不存在"}
"success": True,
"message": "日志文件不存在",
"logs": "日志文件不存在"
}
# 使用 deque 高效读取最后 N 行,避免将整个文件加载到内存 # 使用 deque 高效读取最后 N 行,避免将整个文件加载到内存
from collections import deque from collections import deque
with open(log_file, 'r', encoding='utf-8', errors='ignore') as f: with open(log_file, "r", encoding="utf-8", errors="ignore") as f:
# 使用 deque 保持最后 N 行,内存占用固定 # 使用 deque 保持最后 N 行,内存占用固定
last_lines = deque(f, maxlen=lines) last_lines = deque(f, maxlen=lines)
# 返回字符串格式(不是数组) # 返回字符串格式(不是数组)
log_content = ''.join(last_lines) log_content = "".join(last_lines)
return { return {
"success": True, "success": True,
"message": f"读取了最后 {len(last_lines)} 行日志", "message": f"读取了最后 {len(last_lines)} 行日志",
"logs": log_content "logs": log_content,
} }
except Exception as e: except Exception as e:
logger.error(f"读取日志失败: {str(e)}") logger.error(f"读取日志失败: {str(e)}")
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"读取日志失败: {str(e)}"
detail=f"读取日志失败: {str(e)}"
) )
@router.get("/stats", summary="获取系统统计") @router.get("/stats", summary="获取系统统计")
async def get_system_stats( async def get_system_stats(
db: Session = Depends(get_db), db: Session = Depends(get_db), current_user: User = Depends(get_current_admin_user)
current_user: User = Depends(get_current_admin_user)
): ):
""" """
获取系统统计信息(需要管理员权限) 获取系统统计信息(需要管理员权限)
@@ -159,33 +152,39 @@ async def get_system_stats(
# 今日打卡记录数 # 今日打卡记录数
today_start = datetime.now().replace(hour=0, minute=0, second=0, microsecond=0) today_start = datetime.now().replace(hour=0, minute=0, second=0, microsecond=0)
today_records = db.query(CheckInRecord).filter( today_records = (
CheckInRecord.check_in_time >= today_start db.query(CheckInRecord).filter(CheckInRecord.check_in_time >= today_start).count()
).count() )
# 今日成功打卡数 # 今日成功打卡数
today_success = db.query(CheckInRecord).filter( today_success = (
CheckInRecord.check_in_time >= today_start, db.query(CheckInRecord)
CheckInRecord.status == "success" .filter(CheckInRecord.check_in_time >= today_start, CheckInRecord.status == "success")
).count() .count()
)
# 今日失败打卡数 # 今日失败打卡数
today_failure = db.query(CheckInRecord).filter( today_failure = (
CheckInRecord.check_in_time >= today_start, db.query(CheckInRecord)
CheckInRecord.status == "failure" .filter(CheckInRecord.check_in_time >= today_start, CheckInRecord.status == "failure")
).count() .count()
)
# 今日时间范围外打卡数 # 今日时间范围外打卡数
today_out_of_time = db.query(CheckInRecord).filter( today_out_of_time = (
CheckInRecord.check_in_time >= today_start, db.query(CheckInRecord)
CheckInRecord.status == "out_of_time" .filter(
).count() CheckInRecord.check_in_time >= today_start, CheckInRecord.status == "out_of_time"
)
.count()
)
# 今日异常打卡数 # 今日异常打卡数
today_unknown = db.query(CheckInRecord).filter( today_unknown = (
CheckInRecord.check_in_time >= today_start, db.query(CheckInRecord)
CheckInRecord.status == "unknown" .filter(CheckInRecord.check_in_time >= today_start, CheckInRecord.status == "unknown")
).count() .count()
)
# Token 即将过期的用户数(7天内) # Token 即将过期的用户数(7天内)
# 使用 SQL 直接查询,避免 N+1 问题 # 使用 SQL 直接查询,避免 N+1 问题
@@ -198,28 +197,32 @@ async def get_system_stats(
# 条件:authorization 不为空、jwt_exp 不为 "0"、且在未来 7 天内过期 # 条件:authorization 不为空、jwt_exp 不为 "0"、且在未来 7 天内过期
from sqlalchemy import cast, Integer, and_ from sqlalchemy import cast, Integer, and_
expiring_users = db.query(User).filter( expiring_users = (
db.query(User)
.filter(
and_( and_(
User.authorization.isnot(None), User.authorization.isnot(None),
User.authorization != "", User.authorization != "",
User.jwt_exp.isnot(None), User.jwt_exp.isnot(None),
User.jwt_exp != "0", User.jwt_exp != "0",
cast(User.jwt_exp, Integer) > current_timestamp, # 未过期 cast(User.jwt_exp, Integer) > current_timestamp, # 未过期
cast(User.jwt_exp, Integer) < expiring_soon_timestamp # 7天内过期 cast(User.jwt_exp, Integer) < expiring_soon_timestamp, # 7天内过期
)
)
.count()
) )
).count()
return { return {
"users": { "users": {
"total": total_users, "total": total_users,
"admin": admin_users, "admin": admin_users,
"regular": total_users - admin_users, "regular": total_users - admin_users,
"active": approved_users # 使用已审批用户数 "active": approved_users, # 使用已审批用户数
}, },
"tasks": { "tasks": {
"total": total_tasks, "total": total_tasks,
"active": active_tasks, "active": active_tasks,
"inactive": total_tasks - active_tasks "inactive": total_tasks - active_tasks,
}, },
"check_in_records": { "check_in_records": {
"total": total_records, "total": total_records,
@@ -227,24 +230,20 @@ async def get_system_stats(
"today_success": today_success, "today_success": today_success,
"today_failure": today_failure, "today_failure": today_failure,
"today_out_of_time": today_out_of_time, "today_out_of_time": today_out_of_time,
"today_unknown": today_unknown "today_unknown": today_unknown,
}, },
"tokens": { "tokens": {"expiring_soon": expiring_users},
"expiring_soon": expiring_users
}
} }
except Exception as e: except Exception as e:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"获取统计失败: {str(e)}"
detail=f"获取统计失败: {str(e)}"
) )
@router.get("/users/pending", response_model=List[UserResponse], summary="获取待审批用户") @router.get("/users/pending", response_model=List[UserResponse], summary="获取待审批用户")
async def get_pending_users( async def get_pending_users(
db: Session = Depends(get_db), db: Session = Depends(get_db), current_user: User = Depends(get_current_admin_user)
current_user: User = Depends(get_current_admin_user)
): ):
""" """
获取所有待审批的用户(需要管理员权限) 获取所有待审批的用户(需要管理员权限)
@@ -255,7 +254,7 @@ async def get_pending_users(
except Exception as e: except Exception as e:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"获取待审批用户失败: {str(e)}" detail=f"获取待审批用户失败: {str(e)}",
) )
@@ -263,7 +262,7 @@ async def get_pending_users(
async def approve_user( async def approve_user(
user_id: int, user_id: int,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_admin_user) current_user: User = Depends(get_current_admin_user),
): ):
""" """
审批通过指定用户(需要管理员权限) 审批通过指定用户(需要管理员权限)
@@ -272,18 +271,14 @@ async def approve_user(
result = AdminService.approve_user(user_id, db) result = AdminService.approve_user(user_id, db)
if not result["success"]: if not result["success"]:
raise HTTPException( raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=result["message"])
status_code=status.HTTP_400_BAD_REQUEST,
detail=result["message"]
)
return result return result
except HTTPException: except HTTPException:
raise raise
except Exception as e: except Exception as e:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"审批用户失败: {str(e)}"
detail=f"审批用户失败: {str(e)}"
) )
@@ -291,7 +286,7 @@ async def approve_user(
async def reject_user( async def reject_user(
user_id: int, user_id: int,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_admin_user) current_user: User = Depends(get_current_admin_user),
): ):
""" """
拒绝并删除指定用户(需要管理员权限) 拒绝并删除指定用户(需要管理员权限)
@@ -300,16 +295,12 @@ async def reject_user(
result = AdminService.reject_user(user_id, db) result = AdminService.reject_user(user_id, db)
if not result["success"]: if not result["success"]:
raise HTTPException( raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=result["message"])
status_code=status.HTTP_400_BAD_REQUEST,
detail=result["message"]
)
return result return result
except HTTPException: except HTTPException:
raise raise
except Exception as e: except Exception as e:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"拒绝用户失败: {str(e)}"
detail=f"拒绝用户失败: {str(e)}"
) )
+12 -28
View File
@@ -21,10 +21,7 @@ router = APIRouter()
@router.post("/request_qrcode", response_model=dict, summary="请求 QQ 扫码二维码") @router.post("/request_qrcode", response_model=dict, summary="请求 QQ 扫码二维码")
@limiter.limit("10/minute") # 每分钟最多10次请求 @limiter.limit("10/minute") # 每分钟最多10次请求
async def request_qrcode( async def request_qrcode(
request_obj: QRCodeRequest, request_obj: QRCodeRequest, request: Request, response: Response, db: Session = Depends(get_db)
request: Request,
response: Response,
db: Session = Depends(get_db)
): ):
""" """
请求 QQ 扫码二维码 请求 QQ 扫码二维码
@@ -44,7 +41,7 @@ async def request_qrcode(
raise BusinessLogicError( raise BusinessLogicError(
message="注册过于频繁,请 10 分钟后再试", message="注册过于频繁,请 10 分钟后再试",
error_code="RATE_LIMIT_EXCEEDED", error_code="RATE_LIMIT_EXCEEDED",
status_code=429 status_code=429,
) )
else: else:
# 生成新的 Cookie # 生成新的 Cookie
@@ -67,22 +64,18 @@ async def request_qrcode(
value=reg_cookie, value=reg_cookie,
max_age=600, # 10 分钟 max_age=600, # 10 分钟
httponly=True, httponly=True,
samesite="lax" samesite="lax",
) )
return result return result
except Exception as e: except Exception as e:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"创建扫码会话失败: {str(e)}"
detail=f"创建扫码会话失败: {str(e)}"
) )
@router.get("/qrcode_status/{session_id}", response_model=dict, summary="检查二维码扫描状态") @router.get("/qrcode_status/{session_id}", response_model=dict, summary="检查二维码扫描状态")
async def get_qrcode_status( async def get_qrcode_status(session_id: str, db: Session = Depends(get_db)):
session_id: str,
db: Session = Depends(get_db)
):
""" """
检查二维码扫描状态 检查二维码扫描状态
@@ -104,15 +97,12 @@ async def get_qrcode_status(
return result return result
except Exception as e: except Exception as e:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"查询扫码状态失败: {str(e)}"
detail=f"查询扫码状态失败: {str(e)}"
) )
@router.delete("/qrcode_session/{session_id}", response_model=dict, summary="取消二维码登录会话") @router.delete("/qrcode_session/{session_id}", response_model=dict, summary="取消二维码登录会话")
async def cancel_qrcode_session( async def cancel_qrcode_session(session_id: str):
session_id: str
):
""" """
取消二维码登录会话 取消二维码登录会话
@@ -125,16 +115,12 @@ async def cancel_qrcode_session(
return result return result
except Exception as e: except Exception as e:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"取消会话失败: {str(e)}"
detail=f"取消会话失败: {str(e)}"
) )
@router.post("/verify_token", response_model=dict, summary="验证 JWT Token 有效性") @router.post("/verify_token", response_model=dict, summary="验证 JWT Token 有效性")
async def verify_token( async def verify_token(request: TokenVerifyRequest, db: Session = Depends(get_db)):
request: TokenVerifyRequest,
db: Session = Depends(get_db)
):
""" """
验证 JWT Token 有效性(网站登录认证) 验证 JWT Token 有效性(网站登录认证)
@@ -152,8 +138,7 @@ async def verify_token(
return result return result
except Exception as e: except Exception as e:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"验证 Token 失败: {str(e)}"
detail=f"验证 Token 失败: {str(e)}"
) )
@@ -162,7 +147,7 @@ async def verify_token(
async def alias_login( async def alias_login(
login_data: AliasLoginRequest, login_data: AliasLoginRequest,
request: Request, # slowapi需要的request参数 request: Request, # slowapi需要的request参数
db: Session = Depends(get_db) db: Session = Depends(get_db),
): ):
""" """
别名+密码登录(仅限已设置密码的用户) 别名+密码登录(仅限已设置密码的用户)
@@ -187,6 +172,5 @@ async def alias_login(
return result return result
except Exception as e: except Exception as e:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"别名登录失败: {str(e)}"
detail=f"别名登录失败: {str(e)}"
) )
+49 -64
View File
@@ -18,9 +18,7 @@ router = APIRouter()
@router.post("/manual/{task_id}", summary="手动触发打卡(异步)") @router.post("/manual/{task_id}", summary="手动触发打卡(异步)")
async def manual_check_in( async def manual_check_in(
task_id: int, task_id: int, db: Session = Depends(get_db), current_user: User = Depends(get_current_user)
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
): ):
""" """
手动触发指定任务的打卡(异步方式,立即返回) 手动触发指定任务的打卡(异步方式,立即返回)
@@ -31,33 +29,24 @@ async def manual_check_in(
""" """
# 验证任务归属 # 验证任务归属
if not TaskService.verify_task_ownership(task_id, current_user.id, db): if not TaskService.verify_task_ownership(task_id, current_user.id, db):
raise HTTPException( raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="无权访问此任务")
status_code=status.HTTP_403_FORBIDDEN,
detail="无权访问此任务"
)
task = TaskService.get_task(task_id, db) task = TaskService.get_task(task_id, db)
if not task: if not task:
raise HTTPException( raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="任务不存在")
status_code=status.HTTP_404_NOT_FOUND,
detail="任务不存在"
)
try: try:
result = CheckInService.start_async_check_in(task, "manual", db) result = CheckInService.start_async_check_in(task, "manual", db)
return result return result
except Exception as e: except Exception as e:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"启动打卡任务失败: {str(e)}"
detail=f"启动打卡任务失败: {str(e)}"
) )
@router.get("/record/{record_id}/status", summary="查询打卡记录状态") @router.get("/record/{record_id}/status", summary="查询打卡记录状态")
async def get_check_in_record_status( async def get_check_in_record_status(
record_id: int, record_id: int, db: Session = Depends(get_db), current_user: User = Depends(get_current_user)
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
): ):
""" """
查询指定打卡记录的状态 查询指定打卡记录的状态
@@ -73,10 +62,7 @@ async def get_check_in_record_status(
# 验证记录归属(通过任务归属) # 验证记录归属(通过任务归属)
if not TaskService.verify_task_ownership(record.task_id, current_user.id, db): if not TaskService.verify_task_ownership(record.task_id, current_user.id, db):
raise HTTPException( raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="无权访问此记录")
status_code=status.HTTP_403_FORBIDDEN,
detail="无权访问此记录"
)
return { return {
"record_id": record.id, "record_id": record.id,
@@ -85,19 +71,25 @@ async def get_check_in_record_status(
"response_text": record.response_text, "response_text": record.response_text,
"error_message": record.error_message, "error_message": record.error_message,
"trigger_type": record.trigger_type, "trigger_type": record.trigger_type,
"check_in_time": record.check_in_time "check_in_time": record.check_in_time,
} }
@router.get("/task/{task_id}/records", response_model=PaginatedResponse[CheckInRecordResponse], summary="查看任务的打卡记录") @router.get(
"/task/{task_id}/records",
response_model=PaginatedResponse[CheckInRecordResponse],
summary="查看任务的打卡记录",
)
async def get_task_check_in_records( async def get_task_check_in_records(
task_id: int, task_id: int,
skip: int = Query(0, ge=0, description="跳过记录数"), skip: int = Query(0, ge=0, description="跳过记录数"),
limit: int = Query(100, ge=1, le=500, description="限制记录数"), limit: int = Query(100, ge=1, le=500, description="限制记录数"),
status_filter: Optional[str] = Query(None, alias="status", description="过滤状态 (success/failure)"), status_filter: Optional[str] = Query(
None, alias="status", description="过滤状态 (success/failure)"
),
trigger_type: Optional[str] = Query(None, description="过滤触发类型 (scheduler/manual)"), trigger_type: Optional[str] = Query(None, description="过滤触发类型 (scheduler/manual)"),
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user) current_user: User = Depends(get_current_user),
): ):
""" """
查看指定任务的打卡记录 查看指定任务的打卡记录
@@ -112,36 +104,33 @@ async def get_task_check_in_records(
""" """
# 验证任务归属 # 验证任务归属
if not TaskService.verify_task_ownership(task_id, current_user.id, db): if not TaskService.verify_task_ownership(task_id, current_user.id, db):
raise HTTPException( raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="无权访问此任务")
status_code=status.HTTP_403_FORBIDDEN,
detail="无权访问此任务"
)
try: try:
records, total = CheckInService.get_task_records( records, total = CheckInService.get_task_records(
task_id, db, skip, limit, status_filter, trigger_type task_id, db, skip, limit, status_filter, trigger_type
) )
return PaginatedResponse( return PaginatedResponse(records=records, total=total, skip=skip, limit=limit)
records=records,
total=total,
skip=skip,
limit=limit
)
except Exception as e: except Exception as e:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"获取打卡记录失败: {str(e)}"
detail=f"获取打卡记录失败: {str(e)}"
) )
@router.get("/my-records", response_model=PaginatedResponse[CheckInRecordResponse], summary="查看当前用户的所有打卡记录") @router.get(
"/my-records",
response_model=PaginatedResponse[CheckInRecordResponse],
summary="查看当前用户的所有打卡记录",
)
async def get_my_check_in_records( async def get_my_check_in_records(
skip: int = Query(0, ge=0, description="跳过记录数"), skip: int = Query(0, ge=0, description="跳过记录数"),
limit: int = Query(100, ge=1, le=500, description="限制记录数"), limit: int = Query(100, ge=1, le=500, description="限制记录数"),
status_filter: Optional[str] = Query(None, alias="status", description="过滤状态 (success/failure)"), status_filter: Optional[str] = Query(
None, alias="status", description="过滤状态 (success/failure)"
),
trigger_type: Optional[str] = Query(None, description="过滤触发类型 (scheduler/manual)"), trigger_type: Optional[str] = Query(None, description="过滤触发类型 (scheduler/manual)"),
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user) current_user: User = Depends(get_current_user),
): ):
""" """
查看当前用户所有任务的打卡记录 查看当前用户所有任务的打卡记录
@@ -155,28 +144,27 @@ async def get_my_check_in_records(
records, total = CheckInService.get_user_records( records, total = CheckInService.get_user_records(
current_user.id, db, skip, limit, status_filter, trigger_type current_user.id, db, skip, limit, status_filter, trigger_type
) )
return PaginatedResponse( return PaginatedResponse(records=records, total=total, skip=skip, limit=limit)
records=records,
total=total,
skip=skip,
limit=limit
)
except Exception as e: except Exception as e:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"获取打卡记录失败: {str(e)}"
detail=f"获取打卡记录失败: {str(e)}"
) )
@router.get(
@router.get("/records", response_model=PaginatedResponse[CheckInRecordResponse], summary="查看所有打卡记录(管理员)") "/records",
response_model=PaginatedResponse[CheckInRecordResponse],
summary="查看所有打卡记录(管理员)",
)
async def get_all_check_in_records( async def get_all_check_in_records(
skip: int = Query(0, ge=0, description="跳过记录数"), skip: int = Query(0, ge=0, description="跳过记录数"),
limit: int = Query(100, ge=1, le=500, description="限制记录数"), limit: int = Query(100, ge=1, le=500, description="限制记录数"),
task_id: Optional[int] = Query(None, description="过滤任务 ID"), task_id: Optional[int] = Query(None, description="过滤任务 ID"),
status_filter: Optional[str] = Query(None, alias="status", description="过滤状态 (success/failure)"), status_filter: Optional[str] = Query(
None, alias="status", description="过滤状态 (success/failure)"
),
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_admin_user) current_user: User = Depends(get_current_admin_user),
): ):
""" """
查看所有打卡记录(需要管理员权限) 查看所有打卡记录(需要管理员权限)
@@ -189,26 +177,24 @@ async def get_all_check_in_records(
try: try:
records, total = CheckInService.get_all_records(db, skip, limit, task_id, status_filter) records, total = CheckInService.get_all_records(db, skip, limit, task_id, status_filter)
# 为每条记录添加用户和任务信息 # 为每条记录添加用户和任务信息
enriched_records = [CheckInService.enrich_record_with_user_task_info(record, db) for record in records] enriched_records = [
return PaginatedResponse( CheckInService.enrich_record_with_user_task_info(record, db) for record in records
records=enriched_records, ]
total=total, return PaginatedResponse(records=enriched_records, total=total, skip=skip, limit=limit)
skip=skip,
limit=limit
)
except Exception as e: except Exception as e:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"获取打卡记录失败: {str(e)}"
detail=f"获取打卡记录失败: {str(e)}"
) )
@router.get("/records/count", summary="获取打卡记录统计(管理员)") @router.get("/records/count", summary="获取打卡记录统计(管理员)")
async def get_check_in_records_count( async def get_check_in_records_count(
task_id: Optional[int] = Query(None, description="过滤任务 ID"), task_id: Optional[int] = Query(None, description="过滤任务 ID"),
status_filter: Optional[str] = Query(None, alias="status", description="过滤状态 (success/failure)"), status_filter: Optional[str] = Query(
None, alias="status", description="过滤状态 (success/failure)"
),
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_admin_user) current_user: User = Depends(get_current_admin_user),
): ):
""" """
获取打卡记录统计(需要管理员权限) 获取打卡记录统计(需要管理员权限)
@@ -229,6 +215,5 @@ async def get_check_in_records_count(
return {"total": total} return {"total": total}
except Exception as e: except Exception as e:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"获取统计失败: {str(e)}"
detail=f"获取统计失败: {str(e)}"
) )
+19 -34
View File
@@ -14,15 +14,18 @@ router = APIRouter()
class CronValidateRequest(BaseModel): class CronValidateRequest(BaseModel):
"""Cron 表达式验证请求""" """Cron 表达式验证请求"""
cron_expression: str = Field(..., min_length=9, description="Crontab 表达式") cron_expression: str = Field(..., min_length=9, description="Crontab 表达式")
# create_task_from_template: 已在 templates.py 中定义 # create_task_from_template: 已在 templates.py 中定义
@router.get("/", response_model=List[TaskResponse], summary="获取当前用户的任务列表") @router.get("/", response_model=List[TaskResponse], summary="获取当前用户的任务列表")
async def get_tasks( async def get_tasks(
include_inactive: bool = True, include_inactive: bool = True,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: Session = Depends(get_db) db: Session = Depends(get_db),
): ):
""" """
获取当前用户的所有打卡任务 获取当前用户的所有打卡任务
@@ -36,16 +39,13 @@ async def get_tasks(
return enriched_tasks return enriched_tasks
except Exception as e: except Exception as e:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"获取任务列表失败: {str(e)}"
detail=f"获取任务列表失败: {str(e)}"
) )
@router.get("/{task_id}", response_model=TaskResponse, summary="获取任务详情") @router.get("/{task_id}", response_model=TaskResponse, summary="获取任务详情")
async def get_task( async def get_task(
task_id: int, task_id: int, current_user: User = Depends(get_current_user), db: Session = Depends(get_db)
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
): ):
""" """
获取指定任务的详情 获取指定任务的详情
@@ -66,7 +66,7 @@ async def update_task(
task_id: int, task_id: int,
task_data: TaskUpdate, task_data: TaskUpdate,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: Session = Depends(get_db) db: Session = Depends(get_db),
): ):
""" """
更新指定任务的信息 更新指定任务的信息
@@ -82,19 +82,14 @@ async def update_task(
task = TaskService.update_task(task_id, task_data, db) task = TaskService.update_task(task_id, task_data, db)
if not task: if not task:
raise HTTPException( raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="任务不存在")
status_code=status.HTTP_404_NOT_FOUND,
detail="任务不存在"
)
return task return task
@router.delete("/{task_id}", status_code=status.HTTP_204_NO_CONTENT, summary="删除任务") @router.delete("/{task_id}", status_code=status.HTTP_204_NO_CONTENT, summary="删除任务")
async def delete_task( async def delete_task(
task_id: int, task_id: int, current_user: User = Depends(get_current_user), db: Session = Depends(get_db)
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
): ):
""" """
删除指定任务 删除指定任务
@@ -110,17 +105,12 @@ async def delete_task(
success = TaskService.delete_task(task_id, db) success = TaskService.delete_task(task_id, db)
if not success: if not success:
raise HTTPException( raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="任务不存在")
status_code=status.HTTP_404_NOT_FOUND,
detail="任务不存在"
)
@router.post("/{task_id}/toggle", response_model=TaskResponse, summary="切换任务启用状态") @router.post("/{task_id}/toggle", response_model=TaskResponse, summary="切换任务启用状态")
async def toggle_task( async def toggle_task(
task_id: int, task_id: int, current_user: User = Depends(get_current_user), db: Session = Depends(get_db)
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
): ):
""" """
切换任务的启用/禁用状态 切换任务的启用/禁用状态
@@ -136,10 +126,7 @@ async def toggle_task(
task = TaskService.toggle_task(task_id, db) task = TaskService.toggle_task(task_id, db)
if not task: if not task:
raise HTTPException( raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="任务不存在")
status_code=status.HTTP_404_NOT_FOUND,
detail="任务不存在"
)
return task return task
@@ -167,8 +154,7 @@ async def validate_cron_expression(request: CronValidateRequest):
if not cron_expr: if not cron_expr:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST, detail="cron_expression 是必需的"
detail="cron_expression 是必需的"
) )
try: try:
@@ -179,18 +165,17 @@ async def validate_cron_expression(request: CronValidateRequest):
# 生成接下来的 5 个执行时间 # 生成接下来的 5 个执行时间
cron = croniter(cron_expr, datetime.now()) cron = croniter(cron_expr, datetime.now())
next_times = [cron.get_next(datetime).strftime('%Y-%m-%d %H:%M:%S') for _ in range(5)] next_times = [cron.get_next(datetime).strftime("%Y-%m-%d %H:%M:%S") for _ in range(5)]
return { return {
"valid": True, "valid": True,
"message": "有效的 Crontab 表达式", "message": "有效的 Crontab 表达式",
"next_times": next_times, "next_times": next_times,
"description": generate_cron_description(cron_expr) "description": generate_cron_description(cron_expr),
} }
except Exception as e: except Exception as e:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST, detail=f"无效的 Crontab 表达式: {str(e)}"
detail=f"无效的 Crontab 表达式: {str(e)}"
) )
@@ -203,11 +188,11 @@ def generate_cron_description(cron_expr: str) -> str:
minute, hour, day, month, dow = parts minute, hour, day, month, dow = parts
descriptions = [] descriptions = []
if hour == '*' and minute == '*': if hour == "*" and minute == "*":
descriptions.append("每分钟") descriptions.append("每分钟")
elif hour == '*': elif hour == "*":
descriptions.append(f"每小时的第 {minute} 分钟") descriptions.append(f"每小时的第 {minute} 分钟")
elif day == '*' and month == '*' and dow == '*': elif day == "*" and month == "*" and dow == "*":
descriptions.append(f"每天 {hour}:{minute:0>2}") descriptions.append(f"每天 {hour}:{minute:0>2}")
else: else:
descriptions.append(f"复杂的时间表: {cron_expr}") descriptions.append(f"复杂的时间表: {cron_expr}")
+23 -38
View File
@@ -9,7 +9,7 @@ from backend.schemas.template import (
TemplateUpdate, TemplateUpdate,
TemplateResponse, TemplateResponse,
TaskFromTemplateRequest, TaskFromTemplateRequest,
TemplatePreviewResponse TemplatePreviewResponse,
) )
from backend.schemas.task import TaskResponse from backend.schemas.task import TaskResponse
from backend.services.template_service import TemplateService from backend.services.template_service import TemplateService
@@ -23,7 +23,7 @@ async def get_all_templates(
limit: int = Query(100, ge=1, le=500, description="限制记录数"), limit: int = Query(100, ge=1, le=500, description="限制记录数"),
is_active: Optional[bool] = Query(None, description="过滤启用状态"), is_active: Optional[bool] = Query(None, description="过滤启用状态"),
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user) current_user: User = Depends(get_current_user),
): ):
""" """
获取所有模板列表(普通用户可访问) 获取所有模板列表(普通用户可访问)
@@ -37,8 +37,7 @@ async def get_all_templates(
return templates return templates
except Exception as e: except Exception as e:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"获取模板列表失败: {str(e)}"
detail=f"获取模板列表失败: {str(e)}"
) )
@@ -47,7 +46,7 @@ async def get_active_templates(
skip: int = Query(0, ge=0, description="跳过记录数"), skip: int = Query(0, ge=0, description="跳过记录数"),
limit: int = Query(100, ge=1, le=500, description="限制记录数"), limit: int = Query(100, ge=1, le=500, description="限制记录数"),
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user) current_user: User = Depends(get_current_user),
): ):
""" """
获取所有启用的模板(用户创建任务时使用) 获取所有启用的模板(用户创建任务时使用)
@@ -60,16 +59,13 @@ async def get_active_templates(
return templates return templates
except Exception as e: except Exception as e:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"获取模板列表失败: {str(e)}"
detail=f"获取模板列表失败: {str(e)}"
) )
@router.get("/{template_id}", response_model=TemplateResponse, summary="获取单个模板详情") @router.get("/{template_id}", response_model=TemplateResponse, summary="获取单个模板详情")
async def get_template( async def get_template(
template_id: int, template_id: int, db: Session = Depends(get_db), current_user: User = Depends(get_current_user)
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
): ):
""" """
获取单个模板的详细信息(普通用户只能访问启用的模板) 获取单个模板的详细信息(普通用户只能访问启用的模板)
@@ -78,26 +74,22 @@ async def get_template(
""" """
template = TemplateService.get_template(template_id, db) template = TemplateService.get_template(template_id, db)
if not template: if not template:
raise HTTPException( raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="模板不存在")
status_code=status.HTTP_404_NOT_FOUND,
detail="模板不存在"
)
# 普通用户只能访问启用的模板 # 普通用户只能访问启用的模板
if not current_user.is_admin and template.is_active is not True: if not current_user.is_admin and template.is_active is not True:
raise HTTPException( raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="无权访问此模板")
status_code=status.HTTP_403_FORBIDDEN,
detail="无权访问此模板"
)
return template return template
@router.get("/{template_id}/preview", response_model=TemplatePreviewResponse, summary="预览模板生成的 payload") @router.get(
"/{template_id}/preview",
response_model=TemplatePreviewResponse,
summary="预览模板生成的 payload",
)
async def preview_template( async def preview_template(
template_id: int, template_id: int, db: Session = Depends(get_db), current_user: User = Depends(get_current_user)
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
): ):
""" """
预览模板生成的 payload(使用默认值,普通用户只能访问启用的模板) 预览模板生成的 payload(使用默认值,普通用户只能访问启用的模板)
@@ -106,17 +98,11 @@ async def preview_template(
""" """
template = TemplateService.get_template(template_id, db) template = TemplateService.get_template(template_id, db)
if not template: if not template:
raise HTTPException( raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="模板不存在")
status_code=status.HTTP_404_NOT_FOUND,
detail="模板不存在"
)
# 普通用户只能访问启用的模板 # 普通用户只能访问启用的模板
if not current_user.is_admin and template.is_active is not True: if not current_user.is_admin and template.is_active is not True:
raise HTTPException( raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="无权访问此模板")
status_code=status.HTTP_403_FORBIDDEN,
detail="无权访问此模板"
)
try: try:
preview_payload = TemplateService.generate_preview_payload(template, db) preview_payload = TemplateService.generate_preview_payload(template, db)
@@ -127,12 +113,11 @@ async def preview_template(
"template_id": template.id, "template_id": template.id,
"template_name": template.name, "template_name": template.name,
"preview_payload": preview_payload, "preview_payload": preview_payload,
"field_config": merged_config "field_config": merged_config,
} }
except Exception as e: except Exception as e:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"生成预览失败: {str(e)}"
detail=f"生成预览失败: {str(e)}"
) )
@@ -140,7 +125,7 @@ async def preview_template(
async def create_template( async def create_template(
template_data: TemplateCreate, template_data: TemplateCreate,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_admin_user) current_user: User = Depends(get_current_admin_user),
): ):
""" """
创建新的打卡任务模板(仅管理员) 创建新的打卡任务模板(仅管理员)
@@ -158,7 +143,7 @@ async def update_template(
template_id: int, template_id: int,
template_data: TemplateUpdate, template_data: TemplateUpdate,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_admin_user) current_user: User = Depends(get_current_admin_user),
): ):
""" """
更新模板信息(仅管理员) 更新模板信息(仅管理员)
@@ -176,7 +161,7 @@ async def update_template(
async def delete_template( async def delete_template(
template_id: int, template_id: int,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_admin_user) current_user: User = Depends(get_current_admin_user),
): ):
""" """
删除模板(仅管理员) 删除模板(仅管理员)
@@ -191,7 +176,7 @@ async def delete_template(
async def create_task_from_template( async def create_task_from_template(
request: TaskFromTemplateRequest, request: TaskFromTemplateRequest,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user) current_user: User = Depends(get_current_user),
): ):
""" """
从模板创建打卡任务 从模板创建打卡任务
@@ -209,6 +194,6 @@ async def create_task_from_template(
user_id=current_user.id, user_id=current_user.id,
task_name=request.task_name, task_name=request.task_name,
db=db, db=db,
cron_expression=request.cron_expression cron_expression=request.cron_expression,
) )
return task return task
+36 -40
View File
@@ -3,7 +3,13 @@ from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from backend.models import get_db, User from backend.models import get_db, User
from backend.schemas.user import UserCreate, UserUpdate, UserResponse, TokenStatus, UserUpdateProfile from backend.schemas.user import (
UserCreate,
UserUpdate,
UserResponse,
TokenStatus,
UserUpdateProfile,
)
from backend.schemas.task import TaskResponse from backend.schemas.task import TaskResponse
from backend.services.user_service import UserService from backend.services.user_service import UserService
from backend.services.task_service import TaskService from backend.services.task_service import TaskService
@@ -13,11 +19,16 @@ from backend.exceptions import ValidationError, AuthorizationError, ResourceNotF
router = APIRouter() router = APIRouter()
@router.post("", response_model=UserResponse, status_code=status.HTTP_201_CREATED, summary="创建用户(管理员)") @router.post(
"",
response_model=UserResponse,
status_code=status.HTTP_201_CREATED,
summary="创建用户(管理员)",
)
async def create_user( async def create_user(
user_data: UserCreate, user_data: UserCreate,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_admin_user) current_user: User = Depends(get_current_admin_user),
): ):
""" """
创建用户(需要管理员权限) 创建用户(需要管理员权限)
@@ -33,15 +44,12 @@ async def create_user(
raise ValidationError(str(e)) raise ValidationError(str(e))
except Exception as e: except Exception as e:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"创建用户失败: {str(e)}"
detail=f"创建用户失败: {str(e)}"
) )
@router.get("/me", response_model=UserResponse, summary="获取当前用户信息") @router.get("/me", response_model=UserResponse, summary="获取当前用户信息")
async def get_current_user_info( async def get_current_user_info(current_user: User = Depends(get_current_user)):
current_user: User = Depends(get_current_user)
):
""" """
获取当前登录用户的信息 获取当前登录用户的信息
""" """
@@ -61,9 +69,7 @@ async def get_current_user_info(
@router.get("/me/status", response_model=dict, summary="获取当前用户审批状态") @router.get("/me/status", response_model=dict, summary="获取当前用户审批状态")
async def get_user_status( async def get_user_status(current_user: User = Depends(get_current_user)):
current_user: User = Depends(get_current_user)
):
""" """
获取用户审批状态(不要求审批通过) 获取用户审批状态(不要求审批通过)
""" """
@@ -71,7 +77,7 @@ async def get_user_status(
"user_id": current_user.id, "user_id": current_user.id,
"alias": current_user.alias, "alias": current_user.alias,
"is_approved": current_user.is_approved, "is_approved": current_user.is_approved,
"created_at": current_user.created_at.isoformat() if current_user.created_at else None "created_at": current_user.created_at.isoformat() if current_user.created_at else None,
} }
@@ -79,7 +85,7 @@ async def get_user_status(
async def update_current_user_profile( async def update_current_user_profile(
profile_data: UserUpdateProfile, profile_data: UserUpdateProfile,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user) current_user: User = Depends(get_current_user),
): ):
""" """
更新当前用户的个人信息 更新当前用户的个人信息
@@ -99,15 +105,12 @@ async def update_current_user_profile(
raise ValidationError(str(e)) raise ValidationError(str(e))
except Exception as e: except Exception as e:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"更新个人信息失败: {str(e)}"
detail=f"更新个人信息失败: {str(e)}"
) )
@router.get("/me/token_status", response_model=TokenStatus, summary="获取当前用户打卡 Token 状态") @router.get("/me/token_status", response_model=TokenStatus, summary="获取当前用户打卡 Token 状态")
async def get_current_user_token_status( async def get_current_user_token_status(current_user: User = Depends(get_current_user)):
current_user: User = Depends(get_current_user)
):
""" """
获取当前用户的打卡 Token 状态(authorization token,非 JWT 获取当前用户的打卡 Token 状态(authorization token,非 JWT
@@ -123,7 +126,7 @@ async def get_current_user_token_status(
"jwt_exp": current_user.jwt_exp, "jwt_exp": current_user.jwt_exp,
"expires_at": result.get("expires_at"), "expires_at": result.get("expires_at"),
"days_until_expiry": result.get("days_remaining"), "days_until_expiry": result.get("days_remaining"),
"expiring_soon": result.get("expiring_soon", False) "expiring_soon": result.get("expiring_soon", False),
} }
@@ -131,7 +134,7 @@ async def get_current_user_token_status(
async def get_current_user_tasks( async def get_current_user_tasks(
include_inactive: bool = Query(True, description="是否包含未启用的任务"), include_inactive: bool = Query(True, description="是否包含未启用的任务"),
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user) current_user: User = Depends(get_current_user),
): ):
""" """
获取当前登录用户的所有打卡任务 获取当前登录用户的所有打卡任务
@@ -143,8 +146,7 @@ async def get_current_user_tasks(
return tasks return tasks
except Exception as e: except Exception as e:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"获取任务列表失败: {str(e)}"
detail=f"获取任务列表失败: {str(e)}"
) )
@@ -155,7 +157,7 @@ async def get_all_users(
search: Optional[str] = Query(None, description="搜索关键词(alias"), search: Optional[str] = Query(None, description="搜索关键词(alias"),
role: Optional[str] = Query(None, description="过滤角色 (user/admin)"), role: Optional[str] = Query(None, description="过滤角色 (user/admin)"),
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_admin_user) current_user: User = Depends(get_current_admin_user),
): ):
""" """
获取所有用户列表(需要管理员权限) 获取所有用户列表(需要管理员权限)
@@ -170,16 +172,13 @@ async def get_all_users(
return users return users
except Exception as e: except Exception as e:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"获取用户列表失败: {str(e)}"
detail=f"获取用户列表失败: {str(e)}"
) )
@router.get("/{user_id}", response_model=UserResponse, summary="获取指定用户") @router.get("/{user_id}", response_model=UserResponse, summary="获取指定用户")
async def get_user( async def get_user(
user_id: int, user_id: int, db: Session = Depends(get_db), current_user: User = Depends(get_current_user)
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
): ):
""" """
获取指定用户信息 获取指定用户信息
@@ -203,7 +202,7 @@ async def update_user(
user_id: int, user_id: int,
user_data: UserUpdate, user_data: UserUpdate,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user) current_user: User = Depends(get_current_user),
): ):
""" """
更新用户信息 更新用户信息
@@ -236,28 +235,26 @@ async def update_user(
new_approved_value = user.is_approved new_approved_value = user.is_approved
is_approved_now = True if new_approved_value else False is_approved_now = True if new_approved_value else False
is_admin = (current_user.role == "admin") is_admin = current_user.role == "admin"
needs_notification = (is_admin and (not was_approved_before) and is_approved_now) needs_notification = is_admin and (not was_approved_before) and is_approved_now
if needs_notification: if needs_notification:
try: try:
from backend.services.email_service import EmailService from backend.services.email_service import EmailService
EmailService.notify_user_approved(user) EmailService.notify_user_approved(user)
except Exception as e: except Exception as e:
# 邮件发送失败不影响审批操作 # 邮件发送失败不影响审批操作
import logging import logging
logging.getLogger(__name__).error(f"发送审批通过邮件失败: {e}") logging.getLogger(__name__).error(f"发送审批通过邮件失败: {e}")
return user return user
except ValueError as e: except ValueError as e:
raise HTTPException( raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
status_code=status.HTTP_400_BAD_REQUEST,
detail=str(e)
)
except Exception as e: except Exception as e:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"更新用户失败: {str(e)}"
detail=f"更新用户失败: {str(e)}"
) )
@@ -265,7 +262,7 @@ async def update_user(
async def delete_user( async def delete_user(
user_id: int, user_id: int,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_admin_user) current_user: User = Depends(get_current_admin_user),
): ):
""" """
删除用户(需要管理员权限) 删除用户(需要管理员权限)
@@ -277,6 +274,5 @@ async def delete_user(
raise ResourceNotFoundError(str(e)) raise ResourceNotFoundError(str(e))
except Exception as e: except Exception as e:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"删除用户失败: {str(e)}"
detail=f"删除用户失败: {str(e)}"
) )
+2 -2
View File
@@ -11,9 +11,9 @@ class Settings(BaseSettings):
model_config = SettingsConfigDict( model_config = SettingsConfigDict(
env_file=str(BASE_DIR / ".env"), env_file=str(BASE_DIR / ".env"),
env_file_encoding='utf-8', env_file_encoding="utf-8",
case_sensitive=True, case_sensitive=True,
extra='ignore' extra="ignore",
) )
# 项目根目录 # 项目根目录
+11 -14
View File
@@ -11,8 +11,7 @@ logger = logging.getLogger(__name__)
async def get_current_user( async def get_current_user(
authorization: Optional[str] = Header(None), authorization: Optional[str] = Header(None), db: Session = Depends(get_db)
db: Session = Depends(get_db)
) -> User: ) -> User:
""" """
获取当前用户(使用 JWT 认证) 获取当前用户(使用 JWT 认证)
@@ -30,7 +29,11 @@ async def get_current_user(
) )
# 移除 "Bearer " 前缀(如果存在) # 移除 "Bearer " 前缀(如果存在)
token = authorization.replace("Bearer ", "") if authorization.startswith("Bearer ") else authorization token = (
authorization.replace("Bearer ", "")
if authorization.startswith("Bearer ")
else authorization
)
try: try:
# 验证 JWT token # 验证 JWT token
@@ -77,39 +80,33 @@ async def get_current_user(
) )
async def require_approved_user( async def require_approved_user(current_user: User = Depends(get_current_user)) -> User:
current_user: User = Depends(get_current_user)
) -> User:
""" """
要求用户已通过审批 要求用户已通过审批
""" """
if not current_user.is_approved: if not current_user.is_approved:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, status_code=status.HTTP_403_FORBIDDEN,
detail="您的账户正在等待管理员审批,请耐心等待(24小时内)" detail="您的账户正在等待管理员审批,请耐心等待(24小时内)",
) )
return current_user return current_user
async def get_current_admin_user( async def get_current_admin_user(current_user: User = Depends(require_approved_user)) -> User:
current_user: User = Depends(require_approved_user)
) -> User:
""" """
获取当前管理员用户 获取当前管理员用户
验证用户是否具有管理员权限 验证用户是否具有管理员权限
""" """
if current_user.role != "admin": if current_user.role != "admin":
raise HTTPException( raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, status_code=status.HTTP_403_FORBIDDEN, detail="权限不足,需要管理员权限"
detail="权限不足,需要管理员权限"
) )
return current_user return current_user
async def get_optional_user( async def get_optional_user(
authorization: Optional[str] = Header(None), authorization: Optional[str] = Header(None), db: Session = Depends(get_db)
db: Session = Depends(get_db)
) -> Optional[User]: ) -> Optional[User]:
""" """
可选的用户认证 可选的用户认证
+1
View File
@@ -3,6 +3,7 @@
支持Cloudflare Tunnel和其他代理服务 支持Cloudflare Tunnel和其他代理服务
""" """
from slowapi import Limiter from slowapi import Limiter
from fastapi import Request from fastapi import Request
+10 -16
View File
@@ -44,6 +44,7 @@ async def lifespan(app: FastAPI):
# 启动调度器 # 启动调度器
logger.info("正在启动调度器...") logger.info("正在启动调度器...")
from backend.services.scheduler_service import start_scheduler from backend.services.scheduler_service import start_scheduler
start_scheduler() start_scheduler()
logger.info(f"CheckIn API 服务已启动,版本: {settings.VERSION}") logger.info(f"CheckIn API 服务已启动,版本: {settings.VERSION}")
@@ -53,6 +54,7 @@ async def lifespan(app: FastAPI):
# 关闭时执行 # 关闭时执行
logger.info("正在关闭 CheckIn API 服务...") logger.info("正在关闭 CheckIn API 服务...")
from backend.services.scheduler_service import stop_scheduler from backend.services.scheduler_service import stop_scheduler
stop_scheduler() stop_scheduler()
logger.info("CheckIn API 服务已关闭") logger.info("CheckIn API 服务已关闭")
@@ -85,11 +87,8 @@ async def api_exception_handler(request: Request, exc: BaseAPIException):
return JSONResponse( return JSONResponse(
status_code=exc.status_code, status_code=exc.status_code,
content=ErrorResponse( content=ErrorResponse(
error=ErrorDetail( error=ErrorDetail(code=exc.error_code, message=exc.message)
code=exc.error_code, ).model_dump(),
message=exc.message
)
).model_dump()
) )
@@ -105,12 +104,8 @@ async def validation_exception_handler(request: Request, exc: RequestValidationE
return JSONResponse( return JSONResponse(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
content=ErrorResponse( content=ErrorResponse(
error=ErrorDetail( error=ErrorDetail(code="VALIDATION_ERROR", message=message, field=field or None)
code="VALIDATION_ERROR", ).model_dump(),
message=message,
field=field or None
)
).model_dump()
) )
@@ -123,11 +118,8 @@ async def general_exception_handler(request: Request, exc: Exception):
return JSONResponse( return JSONResponse(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
content=ErrorResponse( content=ErrorResponse(
error=ErrorDetail( error=ErrorDetail(code="INTERNAL_ERROR", message="服务器内部错误,请稍后重试")
code="INTERNAL_ERROR", ).model_dump(),
message="服务器内部错误,请稍后重试"
)
).model_dump()
) )
@@ -156,6 +148,7 @@ async def root():
# 注册路由 # 注册路由
from backend.api import auth, users, check_in, admin, tasks, templates from backend.api import auth, users, check_in, admin, tasks, templates
app.include_router(auth.router, prefix=f"{settings.API_PREFIX}/auth", tags=["认证"]) app.include_router(auth.router, prefix=f"{settings.API_PREFIX}/auth", tags=["认证"])
app.include_router(users.router, prefix=f"{settings.API_PREFIX}/users", tags=["用户"]) app.include_router(users.router, prefix=f"{settings.API_PREFIX}/users", tags=["用户"])
app.include_router(tasks.router, prefix=f"{settings.API_PREFIX}/tasks", tags=["打卡任务"]) app.include_router(tasks.router, prefix=f"{settings.API_PREFIX}/tasks", tags=["打卡任务"])
@@ -166,6 +159,7 @@ app.include_router(templates.router, prefix=f"{settings.API_PREFIX}/templates",
if __name__ == "__main__": if __name__ == "__main__":
import uvicorn import uvicorn
uvicorn.run( uvicorn.run(
"backend.main:app", "backend.main:app",
host="0.0.0.0", host="0.0.0.0",
+24 -6
View File
@@ -10,21 +10,39 @@ class CheckInRecord(Base):
__tablename__ = "check_in_records" __tablename__ = "check_in_records"
id = Column(Integer, primary_key=True, index=True, autoincrement=True) id = Column(Integer, primary_key=True, index=True, autoincrement=True)
task_id = Column(Integer, ForeignKey("check_in_tasks.id", ondelete="CASCADE"), nullable=False, index=True, comment="任务 ID") task_id = Column(
status = Column(String(20), nullable=False, index=True, comment="状态: success/failure/out_of_time/unknown/pending") Integer,
ForeignKey("check_in_tasks.id", ondelete="CASCADE"),
nullable=False,
index=True,
comment="任务 ID",
)
status = Column(
String(20),
nullable=False,
index=True,
comment="状态: success/failure/out_of_time/unknown/pending",
)
response_text = Column(Text, default="", comment="响应文本") response_text = Column(Text, default="", comment="响应文本")
error_message = Column(Text, default="", comment="错误信息") error_message = Column(Text, default="", comment="错误信息")
location = Column(Text, default="{}", comment="位置信息 JSON") location = Column(Text, default="{}", comment="位置信息 JSON")
trigger_type = Column(String(50), default="scheduled", comment="触发类型: scheduled/manual/admin") trigger_type = Column(
check_in_time = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), index=True, comment="打卡时间(UTC") String(50), default="scheduled", comment="触发类型: scheduled/manual/admin"
)
check_in_time = Column(
DateTime(timezone=True),
default=lambda: datetime.now(timezone.utc),
index=True,
comment="打卡时间(UTC",
)
# 关联任务 # 关联任务
task = relationship("CheckInTask", back_populates="check_in_records") task = relationship("CheckInTask", back_populates="check_in_records")
# 添加复合索引:加速常见查询 # 添加复合索引:加速常见查询
__table_args__ = ( __table_args__ = (
Index('ix_record_task_time', 'task_id', 'check_in_time'), # 获取任务的打卡记录(按时间排序) Index("ix_record_task_time", "task_id", "check_in_time"), # 获取任务的打卡记录(按时间排序)
Index('ix_record_status_time', 'status', 'check_in_time'), # 按状态和时间查询 Index("ix_record_status_time", "status", "check_in_time"), # 按状态和时间查询
) )
def __repr__(self): def __repr__(self):
+24 -6
View File
@@ -10,11 +10,27 @@ class CheckInTask(Base):
__tablename__ = "check_in_tasks" __tablename__ = "check_in_tasks"
id = Column(Integer, primary_key=True, index=True, autoincrement=True) id = Column(Integer, primary_key=True, index=True, autoincrement=True)
user_id = Column(Integer, ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True, comment="用户 ID") user_id = Column(
payload_config = Column(Text, default="{}", nullable=False, comment="完整的 payload 配置 JSON(从模板生成,包含 ThreadId 和所有字段)") Integer,
ForeignKey("users.id", ondelete="CASCADE"),
nullable=False,
index=True,
comment="用户 ID",
)
payload_config = Column(
Text,
default="{}",
nullable=False,
comment="完整的 payload 配置 JSON(从模板生成,包含 ThreadId 和所有字段)",
)
name = Column(String(100), default="", comment="任务名称(用户自定义)") name = Column(String(100), default="", comment="任务名称(用户自定义)")
is_active = Column(Boolean, default=True, comment="是否启用自动打卡(不影响手动打卡)") is_active = Column(Boolean, default=True, comment="是否启用自动打卡(不影响手动打卡)")
cron_expression = Column(String(100), default="0 20 * * *", nullable=True, comment="Crontab 表达式(NULL 表示禁用自动打卡,否则按表达式执行)") cron_expression = Column(
String(100),
default="0 20 * * *",
nullable=True,
comment="Crontab 表达式(NULL 表示禁用自动打卡,否则按表达式执行)",
)
created_at = Column(DateTime(timezone=True), server_default=func.now(), comment="创建时间") created_at = Column(DateTime(timezone=True), server_default=func.now(), comment="创建时间")
updated_at = Column(DateTime(timezone=True), onupdate=func.now(), comment="更新时间") updated_at = Column(DateTime(timezone=True), onupdate=func.now(), comment="更新时间")
@@ -22,12 +38,14 @@ class CheckInTask(Base):
user = relationship("User", back_populates="tasks") user = relationship("User", back_populates="tasks")
# 关联打卡记录 # 关联打卡记录
check_in_records = relationship("CheckInRecord", back_populates="task", cascade="all, delete-orphan") check_in_records = relationship(
"CheckInRecord", back_populates="task", cascade="all, delete-orphan"
)
# 添加索引:加速查询 # 添加索引:加速查询
__table_args__ = ( __table_args__ = (
Index('ix_task_user_active', 'user_id', 'is_active'), Index("ix_task_user_active", "user_id", "is_active"),
Index('ix_task_cron', 'cron_expression'), # 加速查询启用了定时打卡的任务 Index("ix_task_cron", "cron_expression"), # 加速查询启用了定时打卡的任务
) )
def __repr__(self): def __repr__(self):
+1 -1
View File
@@ -24,7 +24,7 @@ def receive_load(target, context):
"""在从数据库加载对象后,将所有 datetime 字段转换为 timezone-aware (UTC)""" """在从数据库加载对象后,将所有 datetime 字段转换为 timezone-aware (UTC)"""
for attr_name in dir(target): for attr_name in dir(target):
# 跳过私有属性和方法 # 跳过私有属性和方法
if attr_name.startswith('_'): if attr_name.startswith("_"):
continue continue
try: try:
+7 -1
View File
@@ -6,6 +6,7 @@ from backend.models.database import Base
class TaskTemplate(Base): class TaskTemplate(Base):
"""打卡任务模板""" """打卡任务模板"""
__tablename__ = "task_templates" __tablename__ = "task_templates"
id = Column(Integer, primary_key=True, index=True, autoincrement=True) id = Column(Integer, primary_key=True, index=True, autoincrement=True)
@@ -13,7 +14,12 @@ class TaskTemplate(Base):
description = Column(Text, nullable=True, comment="模板描述") description = Column(Text, nullable=True, comment="模板描述")
# 父模板 ID(用于继承) # 父模板 ID(用于继承)
parent_id = Column(Integer, ForeignKey("task_templates.id", ondelete="SET NULL"), nullable=True, comment="父模板 ID") parent_id = Column(
Integer,
ForeignKey("task_templates.id", ondelete="SET NULL"),
nullable=True,
comment="父模板 ID",
)
# 字段配置(JSON 格式) # 字段配置(JSON 格式)
field_config = Column(Text, nullable=False, comment="字段配置(JSON") field_config = Column(Text, nullable=False, comment="字段配置(JSON")
+26 -6
View File
@@ -10,21 +10,41 @@ class User(Base):
__tablename__ = "users" __tablename__ = "users"
id = Column(Integer, primary_key=True, index=True, autoincrement=True) id = Column(Integer, primary_key=True, index=True, autoincrement=True)
jwt_sub = Column(String(200), unique=True, nullable=True, index=True, comment="QQ 扫码登录的唯一用户标识(注册时为空)") jwt_sub = Column(
alias = Column(String(50), unique=True, nullable=False, index=True, comment="用户别名(用于登录)") String(200),
unique=True,
nullable=True,
index=True,
comment="QQ 扫码登录的唯一用户标识(注册时为空)",
)
alias = Column(
String(50), unique=True, nullable=False, index=True, comment="用户别名(用于登录)"
)
email = Column(String(100), nullable=True, comment="用户邮箱(用于接收通知)") email = Column(String(100), nullable=True, comment="用户邮箱(用于接收通知)")
password_hash = Column(String(200), nullable=True, comment="密码哈希(bcrypt加密)") password_hash = Column(String(200), nullable=True, comment="密码哈希(bcrypt加密)")
authorization = Column(Text, nullable=True, comment="当前有效的 QQ Token") authorization = Column(Text, nullable=True, comment="当前有效的 QQ Token")
jwt_exp = Column(String(20), default="0", comment="Token 过期时间戳") jwt_exp = Column(String(20), default="0", comment="Token 过期时间戳")
token_expiring_notified = Column(Boolean, default=False, nullable=False, comment="Token 即将过期提醒是否已发送(过期前30分钟)") token_expiring_notified = Column(
token_expired_notified = Column(Boolean, default=False, nullable=False, comment="Token 已过期提醒是否已发送(过期后30分钟内)") Boolean,
default=False,
nullable=False,
comment="Token 即将过期提醒是否已发送(过期前30分钟)",
)
token_expired_notified = Column(
Boolean,
default=False,
nullable=False,
comment="Token 已过期提醒是否已发送(过期后30分钟内)",
)
role = Column(String(20), default="user", index=True, comment="角色: user/admin") role = Column(String(20), default="user", index=True, comment="角色: user/admin")
is_approved = Column(Boolean, default=False, index=True, comment="是否已通过管理员审批") is_approved = Column(Boolean, default=False, index=True, comment="是否已通过管理员审批")
# 账户锁定相关字段 # 账户锁定相关字段
failed_login_attempts = Column(Integer, default=0, nullable=False, comment="连续登录失败次数") failed_login_attempts = Column(Integer, default=0, nullable=False, comment="连续登录失败次数")
locked_until = Column(DateTime(timezone=True), nullable=True, comment="账户锁定到期时间") locked_until = Column(DateTime(timezone=True), nullable=True, comment="账户锁定到期时间")
last_failed_login = Column(DateTime(timezone=True), nullable=True, comment="最后一次登录失败时间") last_failed_login = Column(
DateTime(timezone=True), nullable=True, comment="最后一次登录失败时间"
)
created_at = Column(DateTime(timezone=True), server_default=func.now(), comment="创建时间") created_at = Column(DateTime(timezone=True), server_default=func.now(), comment="创建时间")
updated_at = Column(DateTime(timezone=True), onupdate=func.now(), comment="更新时间") updated_at = Column(DateTime(timezone=True), onupdate=func.now(), comment="更新时间")
@@ -34,7 +54,7 @@ class User(Base):
# 添加复合索引:加速审批管理查询 # 添加复合索引:加速审批管理查询
__table_args__ = ( __table_args__ = (
Index('ix_user_role_approved', 'role', 'is_approved'), # 管理员查询待审批用户 Index("ix_user_role_approved", "role", "is_approved"), # 管理员查询待审批用户
) )
def __repr__(self): def __repr__(self):
+7
View File
@@ -4,17 +4,20 @@ from pydantic import BaseModel, Field
class QRCodeRequest(BaseModel): class QRCodeRequest(BaseModel):
"""请求二维码 Schema""" """请求二维码 Schema"""
alias: str = Field(..., description="用户别名") alias: str = Field(..., description="用户别名")
class QRCodeResponse(BaseModel): class QRCodeResponse(BaseModel):
"""二维码响应 Schema""" """二维码响应 Schema"""
session_id: str = Field(..., description="会话 ID") session_id: str = Field(..., description="会话 ID")
qrcode_image: str = Field(..., description="二维码 Base64 图片") qrcode_image: str = Field(..., description="二维码 Base64 图片")
class QRCodeStatusResponse(BaseModel): class QRCodeStatusResponse(BaseModel):
"""二维码状态响应 Schema""" """二维码状态响应 Schema"""
status: str = Field(..., description="状态: pending/waiting_scan/success/error") status: str = Field(..., description="状态: pending/waiting_scan/success/error")
message: Optional[str] = Field(None, description="状态消息") message: Optional[str] = Field(None, description="状态消息")
user_id: Optional[int] = Field(None, description="用户 ID (扫码成功时返回)") user_id: Optional[int] = Field(None, description="用户 ID (扫码成功时返回)")
@@ -24,11 +27,13 @@ class QRCodeStatusResponse(BaseModel):
class TokenVerifyRequest(BaseModel): class TokenVerifyRequest(BaseModel):
"""Token 验证请求 Schema""" """Token 验证请求 Schema"""
authorization: str = Field(..., description="Token") authorization: str = Field(..., description="Token")
class TokenVerifyResponse(BaseModel): class TokenVerifyResponse(BaseModel):
"""Token 验证响应 Schema""" """Token 验证响应 Schema"""
is_valid: bool = Field(..., description="Token 是否有效") is_valid: bool = Field(..., description="Token 是否有效")
message: str = Field(..., description="验证消息") message: str = Field(..., description="验证消息")
user_id: Optional[int] = Field(None, description="用户 ID") user_id: Optional[int] = Field(None, description="用户 ID")
@@ -36,12 +41,14 @@ class TokenVerifyResponse(BaseModel):
class AliasLoginRequest(BaseModel): class AliasLoginRequest(BaseModel):
"""别名+密码登录请求 Schema""" """别名+密码登录请求 Schema"""
alias: str = Field(..., min_length=2, max_length=50, description="用户别名") alias: str = Field(..., min_length=2, max_length=50, description="用户别名")
password: str = Field(..., min_length=6, description="密码") password: str = Field(..., min_length=6, description="密码")
class AliasLoginResponse(BaseModel): class AliasLoginResponse(BaseModel):
"""别名+密码登录响应 Schema""" """别名+密码登录响应 Schema"""
success: bool = Field(..., description="登录是否成功") success: bool = Field(..., description="登录是否成功")
message: str = Field(..., description="登录消息") message: str = Field(..., description="登录消息")
user_id: Optional[int] = Field(None, description="用户 ID") user_id: Optional[int] = Field(None, description="用户 ID")
+7 -1
View File
@@ -2,21 +2,24 @@ from datetime import datetime
from typing import Optional, List, Generic, TypeVar from typing import Optional, List, Generic, TypeVar
from pydantic import BaseModel, Field, ConfigDict from pydantic import BaseModel, Field, ConfigDict
T = TypeVar('T') T = TypeVar("T")
class ManualCheckInRequest(BaseModel): class ManualCheckInRequest(BaseModel):
"""手动打卡请求 Schema(已废弃,现在使用路径参数 task_id)""" """手动打卡请求 Schema(已废弃,现在使用路径参数 task_id)"""
task_id: Optional[int] = Field(None, description="任务 ID") task_id: Optional[int] = Field(None, description="任务 ID")
class BatchCheckInRequest(BaseModel): class BatchCheckInRequest(BaseModel):
"""批量打卡请求 Schema""" """批量打卡请求 Schema"""
task_ids: list[int] = Field(..., description="任务 ID 列表") task_ids: list[int] = Field(..., description="任务 ID 列表")
class CheckInRecordResponse(BaseModel): class CheckInRecordResponse(BaseModel):
"""打卡记录响应 Schema""" """打卡记录响应 Schema"""
model_config = ConfigDict(from_attributes=True) model_config = ConfigDict(from_attributes=True)
id: int id: int
@@ -37,6 +40,7 @@ class CheckInRecordResponse(BaseModel):
class CheckInRecordWithTaskInfo(CheckInRecordResponse): class CheckInRecordWithTaskInfo(CheckInRecordResponse):
"""带任务信息的打卡记录响应 Schema""" """带任务信息的打卡记录响应 Schema"""
task_name: str task_name: str
task_signature: str task_signature: str
user_alias: str user_alias: str
@@ -44,6 +48,7 @@ class CheckInRecordWithTaskInfo(CheckInRecordResponse):
class CheckInResultResponse(BaseModel): class CheckInResultResponse(BaseModel):
"""打卡结果响应 Schema""" """打卡结果响应 Schema"""
success: bool success: bool
message: str message: str
record_id: Optional[int] = None record_id: Optional[int] = None
@@ -52,6 +57,7 @@ class CheckInResultResponse(BaseModel):
class PaginatedResponse(BaseModel, Generic[T]): class PaginatedResponse(BaseModel, Generic[T]):
"""分页响应 Schema""" """分页响应 Schema"""
records: List[T] = Field(..., description="记录列表") records: List[T] = Field(..., description="记录列表")
total: int = Field(..., description="总记录数") total: int = Field(..., description="总记录数")
skip: int = Field(..., description="跳过的记录数") skip: int = Field(..., description="跳过的记录数")
+5 -1
View File
@@ -1,15 +1,17 @@
""" """
统一的 API 响应 Schema 统一的 API 响应 Schema
""" """
from typing import Generic, TypeVar, Optional from typing import Generic, TypeVar, Optional
from pydantic import BaseModel from pydantic import BaseModel
T = TypeVar('T') T = TypeVar("T")
class ApiResponse(BaseModel, Generic[T]): class ApiResponse(BaseModel, Generic[T]):
"""统一成功响应""" """统一成功响应"""
success: bool = True success: bool = True
data: Optional[T] = None data: Optional[T] = None
message: Optional[str] = None message: Optional[str] = None
@@ -17,6 +19,7 @@ class ApiResponse(BaseModel, Generic[T]):
class ErrorDetail(BaseModel): class ErrorDetail(BaseModel):
"""错误详情""" """错误详情"""
code: str code: str
message: str message: str
field: Optional[str] = None # 字段验证错误时使用 field: Optional[str] = None # 字段验证错误时使用
@@ -24,5 +27,6 @@ class ErrorDetail(BaseModel):
class ErrorResponse(BaseModel): class ErrorResponse(BaseModel):
"""统一错误响应""" """统一错误响应"""
success: bool = False success: bool = False
error: ErrorDetail error: ErrorDetail
+17 -15
View File
@@ -5,11 +5,14 @@ from pydantic import BaseModel, Field, field_validator
class TaskBase(BaseModel): class TaskBase(BaseModel):
"""打卡任务基础 Schema""" """打卡任务基础 Schema"""
payload_config: str = Field(..., description="完整的 payload 配置 JSON(包含 ThreadId 和所有字段)")
payload_config: str = Field(
..., description="完整的 payload 配置 JSON(包含 ThreadId 和所有字段)"
)
name: Optional[str] = Field("", max_length=100, description="任务名称(用户自定义)") name: Optional[str] = Field("", max_length=100, description="任务名称(用户自定义)")
is_active: Optional[bool] = Field(True, description="是否启用自动打卡") is_active: Optional[bool] = Field(True, description="是否启用自动打卡")
@field_validator('payload_config') @field_validator("payload_config")
@classmethod @classmethod
def validate_payload_config(cls, v: str) -> str: def validate_payload_config(cls, v: str) -> str:
""" """
@@ -38,13 +41,14 @@ class TaskBase(BaseModel):
class TaskCreate(TaskBase): class TaskCreate(TaskBase):
"""创建打卡任务 Schema""" """创建打卡任务 Schema"""
cron_expression: Optional[str] = Field( cron_expression: Optional[str] = Field(
None, None,
max_length=100, max_length=100,
description="Crontab 表达式(例如 '0 20 * * *' 表示每天 20:00)。NULL 表示禁用定时打卡" description="Crontab 表达式(例如 '0 20 * * *' 表示每天 20:00)。NULL 表示禁用定时打卡",
) )
@field_validator('cron_expression') @field_validator("cron_expression")
@classmethod @classmethod
def validate_cron_expression(cls, v: Optional[str]) -> Optional[str]: def validate_cron_expression(cls, v: Optional[str]) -> Optional[str]:
"""验证 Crontab 表达式格式""" """验证 Crontab 表达式格式"""
@@ -56,6 +60,7 @@ class TaskCreate(TaskBase):
try: try:
from croniter import croniter from croniter import croniter
if not croniter.is_valid(v): if not croniter.is_valid(v):
raise ValueError(f"无效的 Crontab 表达式: '{v}'") raise ValueError(f"无效的 Crontab 表达式: '{v}'")
except Exception as e: except Exception as e:
@@ -66,16 +71,15 @@ class TaskCreate(TaskBase):
class TaskUpdate(BaseModel): class TaskUpdate(BaseModel):
"""更新打卡任务 Schema""" """更新打卡任务 Schema"""
payload_config: Optional[str] = None payload_config: Optional[str] = None
name: Optional[str] = None name: Optional[str] = None
is_active: Optional[bool] = None is_active: Optional[bool] = None
cron_expression: Optional[str] = Field( cron_expression: Optional[str] = Field(
None, None, max_length=100, description="Crontab 表达式。NULL 表示禁用定时打卡"
max_length=100,
description="Crontab 表达式。NULL 表示禁用定时打卡"
) )
@field_validator('payload_config') @field_validator("payload_config")
@classmethod @classmethod
def validate_payload_config(cls, v: Optional[str]) -> Optional[str]: def validate_payload_config(cls, v: Optional[str]) -> Optional[str]:
""" """
@@ -104,7 +108,7 @@ class TaskUpdate(BaseModel):
return v return v
@field_validator('cron_expression') @field_validator("cron_expression")
@classmethod @classmethod
def validate_cron_expression(cls, v: Optional[str]) -> Optional[str]: def validate_cron_expression(cls, v: Optional[str]) -> Optional[str]:
"""验证 Crontab 表达式(与 TaskCreate 相同)""" """验证 Crontab 表达式(与 TaskCreate 相同)"""
@@ -116,6 +120,7 @@ class TaskUpdate(BaseModel):
try: try:
from croniter import croniter from croniter import croniter
if not croniter.is_valid(v): if not croniter.is_valid(v):
raise ValueError(f"无效的 Crontab 表达式: '{v}'") raise ValueError(f"无效的 Crontab 表达式: '{v}'")
except Exception as e: except Exception as e:
@@ -126,18 +131,15 @@ class TaskUpdate(BaseModel):
class TaskResponse(TaskBase): class TaskResponse(TaskBase):
"""打卡任务响应 Schema""" """打卡任务响应 Schema"""
id: int id: int
user_id: int user_id: int
created_at: datetime created_at: datetime
updated_at: Optional[datetime] = None updated_at: Optional[datetime] = None
cron_expression: Optional[str] = Field( cron_expression: Optional[str] = Field(
None, None, description="当前 Crontab 表达式(NULL = 禁用定时打卡)"
description="当前 Crontab 表达式(NULL = 禁用定时打卡)"
)
is_scheduled_enabled: Optional[bool] = Field(
None,
description="是否启用了定时打卡"
) )
is_scheduled_enabled: Optional[bool] = Field(None, description="是否启用了定时打卡")
# 新增字段:最后一次打卡信息 # 新增字段:最后一次打卡信息
last_check_in_time: Optional[datetime] = Field(None, description="最后一次打卡时间") last_check_in_time: Optional[datetime] = Field(None, description="最后一次打卡时间")
+29 -15
View File
@@ -6,12 +6,14 @@ import json
class FieldOption(BaseModel): class FieldOption(BaseModel):
"""字段选项(用于 select 类型)""" """字段选项(用于 select 类型)"""
label: str = Field(..., description="选项显示文本") label: str = Field(..., description="选项显示文本")
value: str = Field(..., description="选项值") value: str = Field(..., description="选项值")
class FieldConfigItem(BaseModel): class FieldConfigItem(BaseModel):
"""单个字段配置项""" """单个字段配置项"""
display_name: str = Field(..., description="字段显示名称") display_name: str = Field(..., description="字段显示名称")
field_type: str = Field(..., description="字段输入类型:text, textarea, number, select") field_type: str = Field(..., description="字段输入类型:text, textarea, number, select")
default_value: str = Field(default="", description="默认值") default_value: str = Field(default="", description="默认值")
@@ -21,33 +23,35 @@ class FieldConfigItem(BaseModel):
value_type: str = Field(default="string", description="值类型:string, int, double") value_type: str = Field(default="string", description="值类型:string, int, double")
options: Optional[List[FieldOption]] = Field(None, description="选项列表(仅 select 类型)") options: Optional[List[FieldOption]] = Field(None, description="选项列表(仅 select 类型)")
@field_validator('field_type') @field_validator("field_type")
@classmethod @classmethod
def validate_field_type(cls, v): def validate_field_type(cls, v):
allowed_types = ['text', 'textarea', 'number', 'select'] allowed_types = ["text", "textarea", "number", "select"]
if v not in allowed_types: if v not in allowed_types:
raise ValueError(f'field_type must be one of {allowed_types}') raise ValueError(f"field_type must be one of {allowed_types}")
return v return v
@field_validator('value_type') @field_validator("value_type")
@classmethod @classmethod
def validate_value_type(cls, v): def validate_value_type(cls, v):
allowed_types = ['string', 'int', 'double'] allowed_types = ["string", "int", "double"]
if v not in allowed_types: if v not in allowed_types:
raise ValueError(f'value_type must be one of {allowed_types}') raise ValueError(f"value_type must be one of {allowed_types}")
return v return v
class FieldConfigValues(BaseModel): class FieldConfigValues(BaseModel):
"""Values 字段的嵌套配置(如 location, temperature 等)""" """Values 字段的嵌套配置(如 location, temperature 等)"""
pass pass
class Config: class Config:
extra = 'allow' # 允许任意字段 extra = "allow" # 允许任意字段
class FieldConfig(BaseModel): class FieldConfig(BaseModel):
"""完整的字段配置""" """完整的字段配置"""
signature: Optional[FieldConfigItem] = None signature: Optional[FieldConfigItem] = None
texts: Optional[FieldConfigItem] = None texts: Optional[FieldConfigItem] = None
values: Optional[Dict[str, FieldConfigItem]] = Field(None, description="Values 字段的嵌套配置") values: Optional[Dict[str, FieldConfigItem]] = Field(None, description="Values 字段的嵌套配置")
@@ -55,13 +59,14 @@ class FieldConfig(BaseModel):
class TemplateBase(BaseModel): class TemplateBase(BaseModel):
"""模板基础 Schema""" """模板基础 Schema"""
name: str = Field(..., min_length=1, max_length=100, description="模板名称") name: str = Field(..., min_length=1, max_length=100, description="模板名称")
description: Optional[str] = Field(None, description="模板描述") description: Optional[str] = Field(None, description="模板描述")
parent_id: Optional[int] = Field(None, description="父模板 ID(用于继承)") parent_id: Optional[int] = Field(None, description="父模板 ID(用于继承)")
field_config: Union[str, FieldConfig] = Field(..., description="字段配置(JSON 字符串或对象)") field_config: Union[str, FieldConfig] = Field(..., description="字段配置(JSON 字符串或对象)")
is_active: bool = Field(default=True, description="是否启用") is_active: bool = Field(default=True, description="是否启用")
@field_validator('field_config') @field_validator("field_config")
@classmethod @classmethod
def validate_field_config(cls, v): def validate_field_config(cls, v):
"""验证并转换 field_config""" """验证并转换 field_config"""
@@ -71,7 +76,7 @@ class TemplateBase(BaseModel):
config_dict = json.loads(v) config_dict = json.loads(v)
return json.dumps(config_dict) # 返回格式化的 JSON 字符串 return json.dumps(config_dict) # 返回格式化的 JSON 字符串
except json.JSONDecodeError: except json.JSONDecodeError:
raise ValueError('field_config must be valid JSON string') raise ValueError("field_config must be valid JSON string")
elif isinstance(v, dict): elif isinstance(v, dict):
# 如果是字典,转换为 JSON 字符串 # 如果是字典,转换为 JSON 字符串
return json.dumps(v) return json.dumps(v)
@@ -79,23 +84,27 @@ class TemplateBase(BaseModel):
# 如果是 FieldConfig 对象,转换为 JSON 字符串 # 如果是 FieldConfig 对象,转换为 JSON 字符串
return v.model_dump_json(exclude_none=True) return v.model_dump_json(exclude_none=True)
else: else:
raise ValueError('field_config must be JSON string, dict, or FieldConfig object') raise ValueError("field_config must be JSON string, dict, or FieldConfig object")
class TemplateCreate(TemplateBase): class TemplateCreate(TemplateBase):
"""创建模板 Schema""" """创建模板 Schema"""
pass pass
class TemplateUpdate(BaseModel): class TemplateUpdate(BaseModel):
"""更新模板 Schema""" """更新模板 Schema"""
name: Optional[str] = Field(None, min_length=1, max_length=100, description="模板名称") name: Optional[str] = Field(None, min_length=1, max_length=100, description="模板名称")
description: Optional[str] = Field(None, description="模板描述") description: Optional[str] = Field(None, description="模板描述")
parent_id: Optional[int] = Field(None, description="父模板 ID(用于继承)") parent_id: Optional[int] = Field(None, description="父模板 ID(用于继承)")
field_config: Optional[Union[str, FieldConfig]] = Field(None, description="字段配置(JSON 字符串或对象)") field_config: Optional[Union[str, FieldConfig]] = Field(
None, description="字段配置(JSON 字符串或对象)"
)
is_active: Optional[bool] = Field(None, description="是否启用") is_active: Optional[bool] = Field(None, description="是否启用")
@field_validator('field_config') @field_validator("field_config")
@classmethod @classmethod
def validate_field_config(cls, v): def validate_field_config(cls, v):
"""验证并转换 field_config""" """验证并转换 field_config"""
@@ -107,17 +116,18 @@ class TemplateUpdate(BaseModel):
config_dict = json.loads(v) config_dict = json.loads(v)
return json.dumps(config_dict) return json.dumps(config_dict)
except json.JSONDecodeError: except json.JSONDecodeError:
raise ValueError('field_config must be valid JSON string') raise ValueError("field_config must be valid JSON string")
elif isinstance(v, dict): elif isinstance(v, dict):
return json.dumps(v) return json.dumps(v)
elif isinstance(v, FieldConfig): elif isinstance(v, FieldConfig):
return v.model_dump_json(exclude_none=True) return v.model_dump_json(exclude_none=True)
else: else:
raise ValueError('field_config must be JSON string, dict, or FieldConfig object') raise ValueError("field_config must be JSON string, dict, or FieldConfig object")
class TemplateResponse(BaseModel): class TemplateResponse(BaseModel):
"""模板响应 Schema""" """模板响应 Schema"""
id: int id: int
name: str name: str
description: Optional[str] description: Optional[str]
@@ -133,15 +143,19 @@ class TemplateResponse(BaseModel):
class TaskFromTemplateRequest(BaseModel): class TaskFromTemplateRequest(BaseModel):
"""从模板创建任务的请求 Schema""" """从模板创建任务的请求 Schema"""
template_id: int = Field(..., description="模板 ID") template_id: int = Field(..., description="模板 ID")
thread_id: str = Field(..., min_length=1, description="接龙项目 ID") thread_id: str = Field(..., min_length=1, description="接龙项目 ID")
field_values: Dict[str, Any] = Field(default_factory=dict, description="用户填写的字段值") field_values: Dict[str, Any] = Field(default_factory=dict, description="用户填写的字段值")
task_name: Optional[str] = Field(None, max_length=100, description="任务名称(可选)") task_name: Optional[str] = Field(None, max_length=100, description="任务名称(可选)")
cron_expression: Optional[str] = Field("0 20 * * *", description="Cron 表达式(可选,默认每天 20:00)") cron_expression: Optional[str] = Field(
"0 20 * * *", description="Cron 表达式(可选,默认每天 20:00)"
)
class TemplatePreviewResponse(BaseModel): class TemplatePreviewResponse(BaseModel):
"""模板预览响应 Schema""" """模板预览响应 Schema"""
template_id: int template_id: int
template_name: str template_name: str
preview_payload: Dict[str, Any] = Field(..., description="预览生成的 payload(使用默认值)") preview_payload: Dict[str, Any] = Field(..., description="预览生成的 payload(使用默认值)")
+13 -2
View File
@@ -5,11 +5,13 @@ from pydantic import BaseModel, Field, EmailStr
class UserBase(BaseModel): class UserBase(BaseModel):
"""用户基础 Schema""" """用户基础 Schema"""
alias: str = Field(..., min_length=2, max_length=50, description="用户别名(用于登录)") alias: str = Field(..., min_length=2, max_length=50, description="用户别名(用于登录)")
class UserCreate(UserBase): class UserCreate(UserBase):
"""创建用户 Schema(管理员手动创建,只需要别名)""" """创建用户 Schema(管理员手动创建,只需要别名)"""
role: Optional[str] = Field("user", description="角色: user/admin") role: Optional[str] = Field("user", description="角色: user/admin")
email: Optional[EmailStr] = Field(None, description="邮箱地址") email: Optional[EmailStr] = Field(None, description="邮箱地址")
password: Optional[str] = Field(None, min_length=6, description="初始密码(可选)") password: Optional[str] = Field(None, min_length=6, description="初始密码(可选)")
@@ -18,24 +20,31 @@ class UserCreate(UserBase):
class UserUpdate(BaseModel): class UserUpdate(BaseModel):
"""更新用户 Schema(管理员编辑用户)""" """更新用户 Schema(管理员编辑用户)"""
alias: Optional[str] = Field(None, min_length=2, max_length=50, description="用户别名") alias: Optional[str] = Field(None, min_length=2, max_length=50, description="用户别名")
role: Optional[str] = None role: Optional[str] = None
is_approved: Optional[bool] = None is_approved: Optional[bool] = None
email: Optional[EmailStr] = None email: Optional[EmailStr] = None
password: Optional[str] = Field(None, min_length=6, description="新密码(可选,留空表示不修改)") password: Optional[str] = Field(
None, min_length=6, description="新密码(可选,留空表示不修改)"
)
reset_password: Optional[bool] = Field(False, description="是否清空密码") reset_password: Optional[bool] = Field(False, description="是否清空密码")
class UserUpdateProfile(BaseModel): class UserUpdateProfile(BaseModel):
"""用户更新个人信息 Schema""" """用户更新个人信息 Schema"""
alias: Optional[str] = Field(None, min_length=2, max_length=50, description="新别名") alias: Optional[str] = Field(None, min_length=2, max_length=50, description="新别名")
email: Optional[EmailStr] = Field(None, description="邮箱地址") email: Optional[EmailStr] = Field(None, description="邮箱地址")
current_password: Optional[str] = Field(None, min_length=6, description="当前密码(修改密码时必填)") current_password: Optional[str] = Field(
None, min_length=6, description="当前密码(修改密码时必填)"
)
new_password: Optional[str] = Field(None, min_length=6, description="新密码") new_password: Optional[str] = Field(None, min_length=6, description="新密码")
class UserResponse(BaseModel): class UserResponse(BaseModel):
"""用户响应 Schema""" """用户响应 Schema"""
id: int id: int
alias: str alias: str
role: str role: str
@@ -52,11 +61,13 @@ class UserResponse(BaseModel):
class UserWithToken(UserResponse): class UserWithToken(UserResponse):
"""带 Token 的用户响应 Schema""" """带 Token 的用户响应 Schema"""
authorization: Optional[str] = None authorization: Optional[str] = None
class TokenStatus(BaseModel): class TokenStatus(BaseModel):
"""Token 状态 Schema""" """Token 状态 Schema"""
is_valid: bool is_valid: bool
jwt_exp: str jwt_exp: str
expires_at: Optional[int] = None # Unix 时间戳(秒) expires_at: Optional[int] = None # Unix 时间戳(秒)
+2 -1
View File
@@ -5,6 +5,7 @@
使用方法: 使用方法:
uv run python apps/backend/scripts/create_admin.py uv run python apps/backend/scripts/create_admin.py
""" """
import sys import sys
from pathlib import Path from pathlib import Path
@@ -51,7 +52,7 @@ def create_admin_user(alias: str):
# 升级为管理员 # 升级为管理员
response = input("\n是否将该用户升级为管理员?(y/n): ") response = input("\n是否将该用户升级为管理员?(y/n): ")
if response.lower() == 'y': if response.lower() == "y":
existing_user.role = "admin" existing_user.role = "admin"
existing_user.is_approved = True # 确保已审批 existing_user.is_approved = True # 确保已审批
db.commit() db.commit()
@@ -34,33 +34,31 @@ def migrate():
columns = [row[1] for row in result] columns = [row[1] for row in result]
# 添加 failed_login_attempts 字段 # 添加 failed_login_attempts 字段
if 'failed_login_attempts' not in columns: if "failed_login_attempts" not in columns:
logger.info("添加 failed_login_attempts 字段...") logger.info("添加 failed_login_attempts 字段...")
conn.execute(text( conn.execute(
text(
"ALTER TABLE users ADD COLUMN failed_login_attempts INTEGER DEFAULT 0 NOT NULL" "ALTER TABLE users ADD COLUMN failed_login_attempts INTEGER DEFAULT 0 NOT NULL"
)) )
)
conn.commit() conn.commit()
logger.info("✓ failed_login_attempts 字段添加成功") logger.info("✓ failed_login_attempts 字段添加成功")
else: else:
logger.info("✓ failed_login_attempts 字段已存在,跳过") logger.info("✓ failed_login_attempts 字段已存在,跳过")
# 添加 locked_until 字段 # 添加 locked_until 字段
if 'locked_until' not in columns: if "locked_until" not in columns:
logger.info("添加 locked_until 字段...") logger.info("添加 locked_until 字段...")
conn.execute(text( conn.execute(text("ALTER TABLE users ADD COLUMN locked_until DATETIME"))
"ALTER TABLE users ADD COLUMN locked_until DATETIME"
))
conn.commit() conn.commit()
logger.info("✓ locked_until 字段添加成功") logger.info("✓ locked_until 字段添加成功")
else: else:
logger.info("✓ locked_until 字段已存在,跳过") logger.info("✓ locked_until 字段已存在,跳过")
# 添加 last_failed_login 字段 # 添加 last_failed_login 字段
if 'last_failed_login' not in columns: if "last_failed_login" not in columns:
logger.info("添加 last_failed_login 字段...") logger.info("添加 last_failed_login 字段...")
conn.execute(text( conn.execute(text("ALTER TABLE users ADD COLUMN last_failed_login DATETIME"))
"ALTER TABLE users ADD COLUMN last_failed_login DATETIME"
))
conn.commit() conn.commit()
logger.info("✓ last_failed_login 字段添加成功") logger.info("✓ last_failed_login 字段添加成功")
else: else:
+13 -11
View File
@@ -1,6 +1,7 @@
""" """
测试新的异常处理系统 测试新的异常处理系统
""" """
import sys import sys
from pathlib import Path from pathlib import Path
@@ -16,6 +17,7 @@ from backend.exceptions import (
) )
from backend.schemas.response import ErrorResponse, ErrorDetail from backend.schemas.response import ErrorResponse, ErrorDetail
def test_exceptions(): def test_exceptions():
"""测试自定义异常""" """测试自定义异常"""
print("=" * 60) print("=" * 60)
@@ -32,7 +34,9 @@ def test_exceptions():
try: try:
raise AuthenticationError("Token已过期") raise AuthenticationError("Token已过期")
except AuthenticationError as e: except AuthenticationError as e:
print(f"✅ AuthenticationError: {e.message} (状态码: {e.status_code}, 代码: {e.error_code})") print(
f"✅ AuthenticationError: {e.message} (状态码: {e.status_code}, 代码: {e.error_code})"
)
# 测试 AuthorizationError # 测试 AuthorizationError
try: try:
@@ -44,7 +48,9 @@ def test_exceptions():
try: try:
raise ResourceNotFoundError("用户不存在") raise ResourceNotFoundError("用户不存在")
except ResourceNotFoundError as e: except ResourceNotFoundError as e:
print(f"✅ ResourceNotFoundError: {e.message} (状态码: {e.status_code}, 代码: {e.error_code})") print(
f"✅ ResourceNotFoundError: {e.message} (状态码: {e.status_code}, 代码: {e.error_code})"
)
# 测试 BusinessLogicError # 测试 BusinessLogicError
try: try:
@@ -61,11 +67,7 @@ def test_response_schemas():
# 测试 ErrorResponse # 测试 ErrorResponse
error_response = ErrorResponse( error_response = ErrorResponse(
error=ErrorDetail( error=ErrorDetail(code="VALIDATION_ERROR", message="邮箱格式不正确", field="email")
code="VALIDATION_ERROR",
message="邮箱格式不正确",
field="email"
)
) )
response_dict = error_response.model_dump() response_dict = error_response.model_dump()
@@ -90,18 +92,18 @@ def check_old_exception_patterns():
patterns = { patterns = {
"HTTPException with detail": r'raise HTTPException.*detail=f?".*{', "HTTPException with detail": r'raise HTTPException.*detail=f?".*{',
"except Exception": r'except Exception as', "except Exception": r"except Exception as",
} }
results = {} results = {}
for pattern_name, pattern in patterns.items(): for pattern_name, pattern in patterns.items():
results[pattern_name] = [] results[pattern_name] = []
for root, dirs, files in os.walk(APPS_DIR / 'backend' / 'api'): for root, dirs, files in os.walk(APPS_DIR / "backend" / "api"):
for file in files: for file in files:
if file.endswith('.py'): if file.endswith(".py"):
filepath = os.path.join(root, file) filepath = os.path.join(root, file)
with open(filepath, 'r', encoding='utf-8') as f: with open(filepath, "r", encoding="utf-8") as f:
content = f.read() content = f.read()
matches = re.findall(pattern, content, re.MULTILINE) matches = re.findall(pattern, content, re.MULTILINE)
if matches: if matches:
+13 -18
View File
@@ -14,10 +14,12 @@ class AdminService:
@staticmethod @staticmethod
def get_pending_users(db: Session) -> List[User]: def get_pending_users(db: Session) -> List[User]:
"""获取待审批用户列表""" """获取待审批用户列表"""
users = db.query(User).filter( users = (
User.is_approved == False, db.query(User)
User.role == "user" .filter(User.is_approved == False, User.role == "user")
).order_by(User.created_at.desc()).all() .order_by(User.created_at.desc())
.all()
)
return users return users
@@ -38,11 +40,7 @@ class AdminService:
logger.info(f"管理员审批通过用户: {user.alias} (ID: {user.id})") logger.info(f"管理员审批通过用户: {user.alias} (ID: {user.id})")
return { return {"success": True, "message": "审批成功", "user_id": user.id}
"success": True,
"message": "审批成功",
"user_id": user.id
}
@staticmethod @staticmethod
def reject_user(user_id: int, db: Session) -> Dict[str, Any]: def reject_user(user_id: int, db: Session) -> Dict[str, Any]:
@@ -58,21 +56,18 @@ class AdminService:
logger.info(f"管理员拒绝用户: {alias} (ID: {user_id})") logger.info(f"管理员拒绝用户: {alias} (ID: {user_id})")
return { return {"success": True, "message": "已拒绝并删除用户"}
"success": True,
"message": "已拒绝并删除用户"
}
@staticmethod @staticmethod
def delete_expired_pending_users(db: Session) -> int: def delete_expired_pending_users(db: Session) -> int:
"""删除24小时未审批的用户""" """删除24小时未审批的用户"""
cutoff_time = datetime.now() - timedelta(hours=24) cutoff_time = datetime.now() - timedelta(hours=24)
expired_users = db.query(User).filter( expired_users = (
User.is_approved == False, db.query(User)
User.role == "user", .filter(User.is_approved == False, User.role == "user", User.created_at < cutoff_time)
User.created_at < cutoff_time .all()
).all() )
count = len(expired_users) count = len(expired_users)
+75 -132
View File
@@ -45,10 +45,7 @@ class AuthService:
# 检查是否为空 jwt_sub(测试账号) # 检查是否为空 jwt_sub(测试账号)
if not existing_user.jwt_sub: if not existing_user.jwt_sub:
logger.warning(f"用户 {alias} 是测试账号(未绑定 QQ),禁止扫码登录") logger.warning(f"用户 {alias} 是测试账号(未绑定 QQ),禁止扫码登录")
return { return {"status": "error", "message": "此账户为测试账号,暂未绑定 QQ,无法扫码登录"}
"status": "error",
"message": "此账户为测试账号,暂未绑定 QQ,无法扫码登录"
}
# 老用户:刷新 Token # 老用户:刷新 Token
logger.info(f"老用户 {alias} 请求刷新 Token,会话: {session_id}") logger.info(f"老用户 {alias} 请求刷新 Token,会话: {session_id}")
@@ -57,7 +54,7 @@ class AuthService:
thread = threading.Thread( thread = threading.Thread(
target=get_token_headless, target=get_token_headless,
args=(session_id, existing_user.jwt_sub, alias, client_ip), args=(session_id, existing_user.jwt_sub, alias, client_ip),
daemon=True daemon=True,
) )
thread.start() thread.start()
@@ -67,16 +64,14 @@ class AuthService:
logger.warning(f"用户名 {alias} 已被预占") logger.warning(f"用户名 {alias} 已被预占")
return { return {
"status": "error", "status": "error",
"message": "该用户名正在被其他人注册,请稍后再试或更换用户名" "message": "该用户名正在被其他人注册,请稍后再试或更换用户名",
} }
logger.info(f"新用户 {alias} 请求注册,会话: {session_id},已预占用户名") logger.info(f"新用户 {alias} 请求注册,会话: {session_id},已预占用户名")
# 在后台线程启动 Selenium,不传入 jwt_sub(新用户) # 在后台线程启动 Selenium,不传入 jwt_sub(新用户)
thread = threading.Thread( thread = threading.Thread(
target=get_token_headless, target=get_token_headless, args=(session_id, None, alias, client_ip), daemon=True
args=(session_id, None, alias, client_ip),
daemon=True
) )
thread.start() thread.start()
@@ -96,29 +91,20 @@ class AuthService:
qr_image_data = session_data.get("qr_image_data") qr_image_data = session_data.get("qr_image_data")
if qr_image_data: if qr_image_data:
logger.info(f"会话 {session_id} 的二维码已生成") logger.info(f"会话 {session_id} 的二维码已生成")
return { return {"session_id": session_id, "qrcode_base64": qr_image_data}
"session_id": session_id,
"qrcode_base64": qr_image_data
}
# 如果已经失败,直接返回错误 # 如果已经失败,直接返回错误
elif status == "failed": elif status == "failed":
error_msg = session_data.get("message", "生成二维码失败") error_msg = session_data.get("message", "生成二维码失败")
logger.error(f"会话 {session_id} 生成二维码失败: {error_msg}") logger.error(f"会话 {session_id} 生成二维码失败: {error_msg}")
return { return {"status": "error", "message": error_msg}
"status": "error",
"message": error_msg
}
# 每 0.5 秒检查一次 # 每 0.5 秒检查一次
time.sleep(0.5) time.sleep(0.5)
# 超时 # 超时
logger.error(f"会话 {session_id} 等待二维码生成超时({max_wait_time}秒)") logger.error(f"会话 {session_id} 等待二维码生成超时({max_wait_time}秒)")
return { return {"status": "error", "message": f"生成二维码超时,请重试"}
"status": "error",
"message": f"生成二维码超时,请重试"
}
@staticmethod @staticmethod
def get_qrcode_status(session_id: str, db: Session) -> Dict[str, Any]: def get_qrcode_status(session_id: str, db: Session) -> Dict[str, Any]:
@@ -135,10 +121,7 @@ class AuthService:
session_data = get_session_data(session_id) session_data = get_session_data(session_id)
if not session_data: if not session_data:
return { return {"status": "pending", "message": "会话不存在或正在初始化"}
"status": "pending",
"message": "会话不存在或正在初始化"
}
status = session_data.get("status") status = session_data.get("status")
jwt_sub = session_data.get("jwt_sub") # 使用 jwt_sub 而非 signature jwt_sub = session_data.get("jwt_sub") # 使用 jwt_sub 而非 signature
@@ -147,7 +130,7 @@ class AuthService:
return { return {
"status": "waiting_scan", "status": "waiting_scan",
"message": "请使用手机 QQ 扫描二维码", "message": "请使用手机 QQ 扫描二维码",
"qrcode_image": session_data.get("qr_image_data") "qrcode_image": session_data.get("qr_image_data"),
} }
elif status == "success": elif status == "success":
@@ -160,15 +143,12 @@ class AuthService:
if not token: if not token:
logger.error("Token 为空") logger.error("Token 为空")
return { return {"status": "error", "message": "Token 为空"}
"status": "error",
"message": "Token 为空"
}
try: try:
# 清洗 TokenURL 解码 + 去除 Bearer 前缀(参考 v1 实现) # 清洗 TokenURL 解码 + 去除 Bearer 前缀(参考 v1 实现)
pure_token = unquote(token) # URL 解码 pure_token = unquote(token) # URL 解码
if pure_token.lower().startswith('bearer '): if pure_token.lower().startswith("bearer "):
pure_token = pure_token[7:] # 去除 "Bearer " 前缀 pure_token = pure_token[7:] # 去除 "Bearer " 前缀
decoded = jwt.decode(pure_token, options={"verify_signature": False}) decoded = jwt.decode(pure_token, options={"verify_signature": False})
@@ -177,10 +157,7 @@ class AuthService:
logger.info(f"成功解析 JWT for sub={jwt_sub}, exp={jwt_exp}") logger.info(f"成功解析 JWT for sub={jwt_sub}, exp={jwt_exp}")
except Exception as e: except Exception as e:
logger.error(f"解析 JWT Token 失败: {e}") logger.error(f"解析 JWT Token 失败: {e}")
return { return {"status": "error", "message": f"Token 解析失败: {str(e)}"}
"status": "error",
"message": f"Token 解析失败: {str(e)}"
}
# 查找用户(通过 jwt_sub # 查找用户(通过 jwt_sub
user = db.query(User).filter(User.jwt_sub == jwt_sub).first() user = db.query(User).filter(User.jwt_sub == jwt_sub).first()
@@ -191,12 +168,18 @@ class AuthService:
if alias and alias == user.alias: if alias and alias == user.alias:
# 用户使用别名登录,验证 jwt_sub 是否一致 # 用户使用别名登录,验证 jwt_sub 是否一致
# 如果用户之前的 jwt_sub 不为空且与当前不一致,说明QQ号被换绑了 # 如果用户之前的 jwt_sub 不为空且与当前不一致,说明QQ号被换绑了
existing_jwt_sub = getattr(user, 'jwt_sub', '') existing_jwt_sub = getattr(user, "jwt_sub", "")
if isinstance(existing_jwt_sub, str) and existing_jwt_sub.strip() and existing_jwt_sub != jwt_sub: if (
logger.warning(f"⚠️ 用户 {user.alias} 的 jwt_sub 不匹配!数据库: {existing_jwt_sub}, 当前: {jwt_sub}") isinstance(existing_jwt_sub, str)
and existing_jwt_sub.strip()
and existing_jwt_sub != jwt_sub
):
logger.warning(
f"⚠️ 用户 {user.alias} 的 jwt_sub 不匹配!数据库: {existing_jwt_sub}, 当前: {jwt_sub}"
)
return { return {
"status": "error", "status": "error",
"message": "QQ账号不匹配,请使用正确的QQ号扫码登录" "message": "QQ账号不匹配,请使用正确的QQ号扫码登录",
} }
user.authorization = pure_token # 存储清理后的 token user.authorization = pure_token # 存储清理后的 token
@@ -221,9 +204,9 @@ class AuthService:
"alias": user.alias, "alias": user.alias,
"role": user.role, "role": user.role,
"is_approved": user.is_approved, "is_approved": user.is_approved,
"jwt_sub": user.jwt_sub "jwt_sub": user.jwt_sub,
}, },
"is_new_user": False "is_new_user": False,
} }
else: else:
@@ -233,20 +216,14 @@ class AuthService:
# 验证用户名是否被预占 # 验证用户名是否被预占
if not alias or not registration_manager.is_alias_reserved(alias): if not alias or not registration_manager.is_alias_reserved(alias):
logger.error(f"新用户注册失败:用户名 {alias} 未预占或已过期") logger.error(f"新用户注册失败:用户名 {alias} 未预占或已过期")
return { return {"status": "error", "message": "注册失败:会话已过期,请重新扫码"}
"status": "error",
"message": "注册失败:会话已过期,请重新扫码"
}
# 检查用户名是否已被其他人注册(防止竞态) # 检查用户名是否已被其他人注册(防止竞态)
existing_user_by_alias = db.query(User).filter(User.alias == alias).first() existing_user_by_alias = db.query(User).filter(User.alias == alias).first()
if existing_user_by_alias: if existing_user_by_alias:
registration_manager.release_alias(alias) registration_manager.release_alias(alias)
logger.error(f"新用户注册失败:用户名 {alias} 已被占用") logger.error(f"新用户注册失败:用户名 {alias} 已被占用")
return { return {"status": "error", "message": "注册失败:用户名已被占用,请更换用户名"}
"status": "error",
"message": "注册失败:用户名已被占用,请更换用户名"
}
# 创建新用户(待审批状态) # 创建新用户(待审批状态)
new_user = User( new_user = User(
@@ -270,6 +247,7 @@ class AuthService:
# 发送邮件通知管理员 # 发送邮件通知管理员
try: try:
from backend.services.email_service import EmailService from backend.services.email_service import EmailService
EmailService.notify_new_user_registration(new_user, db) EmailService.notify_new_user_registration(new_user, db)
except Exception as e: except Exception as e:
logger.error(f"发送注册通知邮件失败: {e}") logger.error(f"发送注册通知邮件失败: {e}")
@@ -286,22 +264,16 @@ class AuthService:
"alias": new_user.alias, "alias": new_user.alias,
"role": new_user.role, "role": new_user.role,
"is_approved": new_user.is_approved, "is_approved": new_user.is_approved,
"jwt_sub": new_user.jwt_sub "jwt_sub": new_user.jwt_sub,
}, },
"is_new_user": True "is_new_user": True,
} }
elif status == "error": elif status == "error":
return { return {"status": "error", "message": session_data.get("message", "未知错误")}
"status": "error",
"message": session_data.get("message", "未知错误")
}
else: else:
return { return {"status": "pending", "message": "正在初始化..."}
"status": "pending",
"message": "正在初始化..."
}
@staticmethod @staticmethod
def verify_token(authorization: str, db: Session) -> Dict[str, Any]: def verify_token(authorization: str, db: Session) -> Dict[str, Any]:
@@ -318,7 +290,11 @@ class AuthService:
from backend.utils.jwt import JWTManager from backend.utils.jwt import JWTManager
# 移除 "Bearer " 前缀 # 移除 "Bearer " 前缀
token = authorization.replace("Bearer ", "") if authorization.startswith("Bearer ") else authorization token = (
authorization.replace("Bearer ", "")
if authorization.startswith("Bearer ")
else authorization
)
try: try:
# 验证 JWT token # 验证 JWT token
@@ -326,19 +302,13 @@ class AuthService:
user_id = payload.get("user_id") user_id = payload.get("user_id")
if not user_id: if not user_id:
return { return {"is_valid": False, "message": "Token 格式错误"}
"is_valid": False,
"message": "Token 格式错误"
}
# 从数据库获取用户 # 从数据库获取用户
user = db.query(User).filter(User.id == user_id).first() user = db.query(User).filter(User.id == user_id).first()
if not user: if not user:
return { return {"is_valid": False, "message": "用户不存在"}
"is_valid": False,
"message": "用户不存在"
}
return { return {
"is_valid": True, "is_valid": True,
@@ -346,25 +316,16 @@ class AuthService:
"user_id": user.id, "user_id": user.id,
"alias": user.alias, "alias": user.alias,
"role": user.role, "role": user.role,
"is_approved": user.is_approved "is_approved": user.is_approved,
} }
except jwt.ExpiredSignatureError: except jwt.ExpiredSignatureError:
return { return {"is_valid": False, "message": "JWT Token 已过期"}
"is_valid": False,
"message": "JWT Token 已过期"
}
except jwt.InvalidTokenError: except jwt.InvalidTokenError:
return { return {"is_valid": False, "message": "JWT Token 无效"}
"is_valid": False,
"message": "JWT Token 无效"
}
except Exception as e: except Exception as e:
logger.error(f"验证 JWT Token 失败: {str(e)}") logger.error(f"验证 JWT Token 失败: {str(e)}")
return { return {"is_valid": False, "message": "Token 验证失败"}
"is_valid": False,
"message": "Token 验证失败"
}
@staticmethod @staticmethod
def verify_checkin_authorization(user: User) -> Dict[str, Any]: def verify_checkin_authorization(user: User) -> Dict[str, Any]:
@@ -386,25 +347,17 @@ class AuthService:
is_timestamp_expired, is_timestamp_expired,
days_until_expiry, days_until_expiry,
minutes_until_expiry, minutes_until_expiry,
seconds_until_expiry seconds_until_expiry,
) )
# 检查是否有 authorization token # 检查是否有 authorization token
if not user.authorization or user.authorization == "": if not user.authorization or user.authorization == "":
return { return {"is_valid": False, "message": "未设置打卡凭证", "reason": "no_token"}
"is_valid": False,
"message": "未设置打卡凭证",
"reason": "no_token"
}
# 解析 jwt_exp # 解析 jwt_exp
exp_timestamp = parse_jwt_exp(user.jwt_exp) exp_timestamp = parse_jwt_exp(user.jwt_exp)
if not exp_timestamp: if not exp_timestamp:
return { return {"is_valid": False, "message": "打卡凭证无效", "reason": "invalid_expiry"}
"is_valid": False,
"message": "打卡凭证无效",
"reason": "invalid_expiry"
}
# 检查是否过期 # 检查是否过期
if is_timestamp_expired(exp_timestamp): if is_timestamp_expired(exp_timestamp):
@@ -413,7 +366,7 @@ class AuthService:
"is_valid": False, "is_valid": False,
"message": f"打卡凭证已过期 {days_expired}", "message": f"打卡凭证已过期 {days_expired}",
"reason": "expired", "reason": "expired",
"days_expired": days_expired "days_expired": days_expired,
} }
# Token 有效,计算剩余时间 # Token 有效,计算剩余时间
@@ -430,7 +383,7 @@ class AuthService:
"days_remaining": days_remaining, "days_remaining": days_remaining,
"minutes_remaining": minutes_remaining, "minutes_remaining": minutes_remaining,
"expiring_soon": expiring_soon, "expiring_soon": expiring_soon,
"expires_at": exp_timestamp "expires_at": exp_timestamp,
} }
@staticmethod @staticmethod
@@ -451,10 +404,7 @@ class AuthService:
if not user: if not user:
logger.warning(f"别名登录失败:用户 {alias} 不存在") logger.warning(f"别名登录失败:用户 {alias} 不存在")
return { return {"success": False, "message": "用户名或密码错误"}
"success": False,
"message": "用户名或密码错误"
}
# 检查账户是否被锁定 # 检查账户是否被锁定
if user.locked_until: if user.locked_until:
@@ -462,10 +412,12 @@ class AuthService:
if datetime.now() < user.locked_until: if datetime.now() < user.locked_until:
remaining_seconds = (user.locked_until - datetime.now()).total_seconds() remaining_seconds = (user.locked_until - datetime.now()).total_seconds()
remaining_minutes = int(remaining_seconds / 60) + 1 remaining_minutes = int(remaining_seconds / 60) + 1
logger.warning(f"别名登录失败:用户 {alias} 账户已锁定,剩余 {remaining_minutes} 分钟") logger.warning(
f"别名登录失败:用户 {alias} 账户已锁定,剩余 {remaining_minutes} 分钟"
)
return { return {
"success": False, "success": False,
"message": f"账户已锁定,请 {remaining_minutes} 分钟后再试" "message": f"账户已锁定,请 {remaining_minutes} 分钟后再试",
} }
else: else:
# 锁定时间已过,重置锁定状态 # 锁定时间已过,重置锁定状态
@@ -477,15 +429,12 @@ class AuthService:
# 检查用户是否设置了密码 # 检查用户是否设置了密码
if not user.password_hash: if not user.password_hash:
logger.warning(f"别名登录失败:用户 {alias} 未设置密码") logger.warning(f"别名登录失败:用户 {alias} 未设置密码")
return { return {"success": False, "message": "该用户未设置密码,请使用扫码登录"}
"success": False,
"message": "该用户未设置密码,请使用扫码登录"
}
# 验证密码 # 验证密码
try: try:
password_bytes = password.encode('utf-8') password_bytes = password.encode("utf-8")
hash_bytes = user.password_hash.encode('utf-8') hash_bytes = user.password_hash.encode("utf-8")
if not bcrypt.checkpw(password_bytes, hash_bytes): if not bcrypt.checkpw(password_bytes, hash_bytes):
# 密码错误,增加失败次数 # 密码错误,增加失败次数
@@ -497,24 +446,20 @@ class AuthService:
user.locked_until = datetime.now() + timedelta(minutes=15) user.locked_until = datetime.now() + timedelta(minutes=15)
db.commit() db.commit()
logger.warning(f"别名登录失败:用户 {alias} 密码错误次数过多,账户已锁定15分钟") logger.warning(f"别名登录失败:用户 {alias} 密码错误次数过多,账户已锁定15分钟")
return { return {"success": False, "message": "密码错误次数过多,账户已锁定15分钟"}
"success": False,
"message": "密码错误次数过多,账户已锁定15分钟"
}
db.commit() db.commit()
remaining_attempts = 5 - user.failed_login_attempts remaining_attempts = 5 - user.failed_login_attempts
logger.warning(f"别名登录失败:用户 {alias} 密码错误,剩余尝试次数: {remaining_attempts}") logger.warning(
f"别名登录失败:用户 {alias} 密码错误,剩余尝试次数: {remaining_attempts}"
)
return { return {
"success": False, "success": False,
"message": f"用户名或密码错误,剩余尝试次数: {remaining_attempts}" "message": f"用户名或密码错误,剩余尝试次数: {remaining_attempts}",
} }
except Exception as e: except Exception as e:
logger.error(f"密码验证异常:{e}") logger.error(f"密码验证异常:{e}")
return { return {"success": False, "message": "登录失败,请稍后重试"}
"success": False,
"message": "登录失败,请稍后重试"
}
# 密码正确,重置失败次数 # 密码正确,重置失败次数
user.failed_login_attempts = 0 user.failed_login_attempts = 0
@@ -551,17 +496,21 @@ class AuthService:
"id": user.id, "id": user.id,
"alias": user.alias, "alias": user.alias,
"role": user.role, "role": user.role,
"is_approved": user.is_approved "is_approved": user.is_approved,
} },
} }
# 如果打卡 Token 有问题,添加警告信息(不影响网站使用) # 如果打卡 Token 有问题,添加警告信息(不影响网站使用)
if token_warning: if token_warning:
result["token_warning"] = token_warning result["token_warning"] = token_warning
if token_warning == "token_invalid": if token_warning == "token_invalid":
result["warning_message"] = "登录成功,但检测到打卡凭证无效,无法自动打卡,建议扫码更新" result["warning_message"] = (
"登录成功,但检测到打卡凭证无效,无法自动打卡,建议扫码更新"
)
elif token_warning == "token_expired": elif token_warning == "token_expired":
result["warning_message"] = "登录成功,但检测到打卡凭证已过期,无法自动打卡,建议扫码更新" result["warning_message"] = (
"登录成功,但检测到打卡凭证已过期,无法自动打卡,建议扫码更新"
)
return result return result
@@ -576,10 +525,10 @@ class AuthService:
Returns: Returns:
加密后的密码哈希 加密后的密码哈希
""" """
password_bytes = password.encode('utf-8') password_bytes = password.encode("utf-8")
salt = bcrypt.gensalt() salt = bcrypt.gensalt()
hash_bytes = bcrypt.hashpw(password_bytes, salt) hash_bytes = bcrypt.hashpw(password_bytes, salt)
return hash_bytes.decode('utf-8') return hash_bytes.decode("utf-8")
@staticmethod @staticmethod
def verify_password(password: str, password_hash: str) -> bool: def verify_password(password: str, password_hash: str) -> bool:
@@ -594,8 +543,8 @@ class AuthService:
密码是否正确 密码是否正确
""" """
try: try:
password_bytes = password.encode('utf-8') password_bytes = password.encode("utf-8")
hash_bytes = password_hash.encode('utf-8') hash_bytes = password_hash.encode("utf-8")
return bcrypt.checkpw(password_bytes, hash_bytes) return bcrypt.checkpw(password_bytes, hash_bytes)
except Exception as e: except Exception as e:
logger.error(f"密码验证异常:{e}") logger.error(f"密码验证异常:{e}")
@@ -617,12 +566,6 @@ class AuthService:
success = cancel_session(session_id) success = cancel_session(session_id)
if success: if success:
return { return {"success": True, "message": "会话已取消"}
"success": True,
"message": "会话已取消"
}
else: else:
return { return {"success": False, "message": "取消失败或会话不存在"}
"success": False,
"message": "取消失败或会话不存在"
}
+87 -86
View File
@@ -39,7 +39,9 @@ class CheckInService:
task_info = build_task_info(task) task_info = build_task_info(task)
# 发送打卡失败通知(内容包含 Token 失效说明和刷新指引) # 发送打卡失败通知(内容包含 Token 失效说明和刷新指引)
EmailService.notify_check_in_result(user, task_info, False, "Token 已失效,需要重新授权") EmailService.notify_check_in_result(
user, task_info, False, "Token 已失效,需要重新授权"
)
logger.info(f"已发送 Token 过期邮件到 {user.email}") logger.info(f"已发送 Token 过期邮件到 {user.email}")
# 标记已发送 Token 过期通知 # 标记已发送 Token 过期通知
@@ -63,7 +65,9 @@ class CheckInService:
Returns: Returns:
打卡记录 ID 打卡记录 ID
""" """
logger.info(f"🎯 创建待处理打卡记录 - 任务: {task.name or f'Task-{task.id}'} (ID: {task.id})") logger.info(
f"🎯 创建待处理打卡记录 - 任务: {task.name or f'Task-{task.id}'} (ID: {task.id})"
)
# 创建一个 pending 状态的记录 # 创建一个 pending 状态的记录
record = CheckInRecord( record = CheckInRecord(
@@ -72,7 +76,7 @@ class CheckInService:
response_text="", response_text="",
error_message="", error_message="",
location="{}", location="{}",
trigger_type=trigger_type trigger_type=trigger_type,
) )
db.add(record) db.add(record)
db.commit() db.commit()
@@ -106,10 +110,9 @@ class CheckInService:
# 更新记录状态为失败 # 更新记录状态为失败
record = db.query(CheckInRecord).filter(CheckInRecord.id == record_id).first() record = db.query(CheckInRecord).filter(CheckInRecord.id == record_id).first()
if record: if record:
db.query(CheckInRecord).filter(CheckInRecord.id == record_id).update({ db.query(CheckInRecord).filter(CheckInRecord.id == record_id).update(
"status": "failure", {"status": "failure", "error_message": "任务不存在"}
"error_message": "任务不存在" )
})
db.commit() db.commit()
return return
@@ -121,26 +124,31 @@ class CheckInService:
CheckInService.handle_token_expired(task.user, task, db) CheckInService.handle_token_expired(task.user, task, db)
# 更新记录 # 更新记录
db.query(CheckInRecord).filter(CheckInRecord.id == record_id).update({ db.query(CheckInRecord).filter(CheckInRecord.id == record_id).update(
{
"status": result["status"], "status": result["status"],
"response_text": result["response_text"], "response_text": result["response_text"],
"error_message": result["error_message"] "error_message": result["error_message"],
}) }
)
db.commit() db.commit()
if result["success"]: if result["success"]:
logger.info(f"✅ 后台打卡成功 - Record ID: {record_id}") logger.info(f"✅ 后台打卡成功 - Record ID: {record_id}")
else: else:
logger.error(f"❌ 后台打卡失败 - Record ID: {record_id}, 错误: {result['error_message']}") logger.error(
f"❌ 后台打卡失败 - Record ID: {record_id}, 错误: {result['error_message']}"
)
except Exception as e: except Exception as e:
logger.error(f"💥 后台打卡异常 - Task ID: {task_id}, Record ID: {record_id}, 错误: {str(e)}") logger.error(
f"💥 后台打卡异常 - Task ID: {task_id}, Record ID: {record_id}, 错误: {str(e)}"
)
# 更新记录状态 # 更新记录状态
try: try:
db.query(CheckInRecord).filter(CheckInRecord.id == record_id).update({ db.query(CheckInRecord).filter(CheckInRecord.id == record_id).update(
"status": "failure", {"status": "failure", "error_message": f"后台执行异常: {str(e)}"}
"error_message": f"后台执行异常: {str(e)}" )
})
db.commit() db.commit()
except Exception as inner_e: except Exception as inner_e:
logger.error(f"💥 更新记录失败: {str(inner_e)}") logger.error(f"💥 更新记录失败: {str(inner_e)}")
@@ -175,17 +183,13 @@ class CheckInService:
response_text="", response_text="",
error_message=error_msg, error_message=error_msg,
location="{}", location="{}",
trigger_type=trigger_type trigger_type=trigger_type,
) )
db.add(record) db.add(record)
db.commit() db.commit()
db.refresh(record) db.refresh(record)
return { return {"record_id": record.id, "status": "failure", "message": error_msg}
"record_id": record.id,
"status": "failure",
"message": error_msg
}
# 不再提前验证 Token,交给统一的打卡逻辑处理 # 不再提前验证 Token,交给统一的打卡逻辑处理
# 这样可以确保所有错误(包括 Token 过期)都通过统一的流程处理 # 这样可以确保所有错误(包括 Token 过期)都通过统一的流程处理
@@ -195,10 +199,11 @@ class CheckInService:
# 在后台线程中执行打卡 # 在后台线程中执行打卡
import threading import threading
thread = threading.Thread( thread = threading.Thread(
target=CheckInService.execute_check_in_async, target=CheckInService.execute_check_in_async,
args=(task.id, record_id, user.authorization), args=(task.id, record_id, user.authorization),
daemon=True daemon=True,
) )
thread.start() thread.start()
@@ -207,7 +212,7 @@ class CheckInService:
return { return {
"record_id": record_id, "record_id": record_id,
"status": "pending", "status": "pending",
"message": "打卡任务已启动,正在后台处理" "message": "打卡任务已启动,正在后台处理",
} }
@staticmethod @staticmethod
@@ -223,13 +228,17 @@ class CheckInService:
Returns: Returns:
打卡结果字典 打卡结果字典
""" """
logger.info(f"🎯 开始打卡 - 任务: {task.name or f'Task-{task.id}'} (ID: {task.id}), 触发: {trigger_type}") logger.info(
f"🎯 开始打卡 - 任务: {task.name or f'Task-{task.id}'} (ID: {task.id}), 触发: {trigger_type}"
)
# 获取用户的打卡 Token # 获取用户的打卡 Token
user = task.user user = task.user
if not user or not user.authorization: if not user or not user.authorization:
error_msg = f"用户没有有效的打卡 Token" error_msg = f"用户没有有效的打卡 Token"
logger.error(f"{error_msg} - Task ID: {task.id}, User ID: {user.id if user else 'None'}") logger.error(
f"{error_msg} - Task ID: {task.id}, User ID: {user.id if user else 'None'}"
)
# 记录失败 # 记录失败
record = CheckInRecord( record = CheckInRecord(
@@ -238,20 +247,17 @@ class CheckInService:
response_text="", response_text="",
error_message=error_msg, error_message=error_msg,
location="{}", location="{}",
trigger_type=trigger_type trigger_type=trigger_type,
) )
db.add(record) db.add(record)
db.commit() db.commit()
db.refresh(record) db.refresh(record)
return { return {"success": False, "message": error_msg, "record_id": record.id}
"success": False,
"message": error_msg,
"record_id": record.id
}
# 使用统一的打卡 Token 验证方法 # 使用统一的打卡 Token 验证方法
from backend.services.auth_service import AuthService from backend.services.auth_service import AuthService
token_result = AuthService.verify_checkin_authorization(user) token_result = AuthService.verify_checkin_authorization(user)
if not token_result["is_valid"]: if not token_result["is_valid"]:
@@ -268,7 +274,7 @@ class CheckInService:
response_text="", response_text="",
error_message=error_msg, error_message=error_msg,
location="{}", location="{}",
trigger_type=trigger_type trigger_type=trigger_type,
) )
db.add(record) db.add(record)
db.commit() db.commit()
@@ -277,7 +283,7 @@ class CheckInService:
return { return {
"success": False, "success": False,
"message": f"{error_msg},请重新扫码登录", "message": f"{error_msg},请重新扫码登录",
"record_id": record.id "record_id": record.id,
} }
# 执行打卡(传递 task 对象和用户 token) # 执行打卡(传递 task 对象和用户 token)
@@ -295,7 +301,7 @@ class CheckInService:
response_text=result["response_text"], response_text=result["response_text"],
error_message=result["error_message"], error_message=result["error_message"],
location="{}", location="{}",
trigger_type=trigger_type trigger_type=trigger_type,
) )
db.add(record) db.add(record)
db.commit() db.commit()
@@ -309,7 +315,7 @@ class CheckInService:
return { return {
"success": result["success"], "success": result["success"],
"message": "打卡成功" if result["success"] else f"打卡失败: {result['error_message']}", "message": "打卡成功" if result["success"] else f"打卡失败: {result['error_message']}",
"record_id": record.id "record_id": record.id,
} }
@staticmethod @staticmethod
@@ -326,13 +332,7 @@ class CheckInService:
""" """
logger.info(f"🚀 开始批量打卡,任务数量: {len(task_ids)}") logger.info(f"🚀 开始批量打卡,任务数量: {len(task_ids)}")
results = { results = {"total": len(task_ids), "success": 0, "failure": 0, "skipped": 0, "details": []}
"total": len(task_ids),
"success": 0,
"failure": 0,
"skipped": 0,
"details": []
}
# 优化:一次性查询所有任务,避免 N+1 查询 # 优化:一次性查询所有任务,避免 N+1 查询
tasks = db.query(CheckInTask).filter(CheckInTask.id.in_(task_ids)).all() tasks = db.query(CheckInTask).filter(CheckInTask.id.in_(task_ids)).all()
@@ -344,11 +344,9 @@ class CheckInService:
if not task: if not task:
logger.warning(f"⚠️ 任务 ID {task_id} 不存在,跳过") logger.warning(f"⚠️ 任务 ID {task_id} 不存在,跳过")
results["skipped"] += 1 results["skipped"] += 1
results["details"].append({ results["details"].append(
"task_id": task_id, {"task_id": task_id, "success": False, "message": "任务不存在"}
"success": False, )
"message": "任务不存在"
})
continue continue
# 执行打卡(移除 is_active 检查,允许手动打卡) # 执行打卡(移除 is_active 检查,允许手动打卡)
@@ -361,24 +359,26 @@ class CheckInService:
results["failure"] += 1 results["failure"] += 1
logger.error(f"❌ 任务 {task_id} 批量打卡失败: {result['message']}") logger.error(f"❌ 任务 {task_id} 批量打卡失败: {result['message']}")
results["details"].append({ results["details"].append(
{
"task_id": task_id, "task_id": task_id,
"task_name": task.name or f'Task-{task.id}', "task_name": task.name or f"Task-{task.id}",
"success": result["success"], "success": result["success"],
"message": result["message"], "message": result["message"],
"record_id": result.get("record_id") "record_id": result.get("record_id"),
}) }
)
except Exception as e: except Exception as e:
logger.error(f"💥 任务 {task_id} 处理异常: {str(e)}") logger.error(f"💥 任务 {task_id} 处理异常: {str(e)}")
results["failure"] += 1 results["failure"] += 1
results["details"].append({ results["details"].append(
"task_id": task_id, {"task_id": task_id, "success": False, "message": f"异常: {str(e)}"}
"success": False, )
"message": f"异常: {str(e)}"
})
logger.info(f"📊 批量打卡完成 - 成功: {results['success']}, 失败: {results['failure']}, 跳过: {results['skipped']}") logger.info(
f"📊 批量打卡完成 - 成功: {results['success']}, 失败: {results['failure']}, 跳过: {results['skipped']}"
)
return results return results
@staticmethod @staticmethod
@@ -388,7 +388,7 @@ class CheckInService:
skip: int = 0, skip: int = 0,
limit: int = 100, limit: int = 100,
status: Optional[str] = None, status: Optional[str] = None,
trigger_type: Optional[str] = None trigger_type: Optional[str] = None,
) -> tuple[List[CheckInRecord], int]: ) -> tuple[List[CheckInRecord], int]:
""" """
获取任务的打卡记录 获取任务的打卡记录
@@ -416,9 +416,7 @@ class CheckInService:
total = query.count() total = query.count()
# 获取分页数据 # 获取分页数据
records = query.order_by( records = query.order_by(CheckInRecord.check_in_time.desc()).offset(skip).limit(limit).all()
CheckInRecord.check_in_time.desc()
).offset(skip).limit(limit).all()
return records, total return records, total
@@ -429,7 +427,7 @@ class CheckInService:
skip: int = 0, skip: int = 0,
limit: int = 100, limit: int = 100,
status: Optional[str] = None, status: Optional[str] = None,
trigger_type: Optional[str] = None trigger_type: Optional[str] = None,
) -> tuple[List[CheckInRecord], int]: ) -> tuple[List[CheckInRecord], int]:
""" """
获取用户的所有打卡记录 获取用户的所有打卡记录
@@ -462,9 +460,7 @@ class CheckInService:
total = query.count() total = query.count()
# 获取分页数据 # 获取分页数据
records = query.order_by( records = query.order_by(CheckInRecord.check_in_time.desc()).offset(skip).limit(limit).all()
CheckInRecord.check_in_time.desc()
).offset(skip).limit(limit).all()
return records, total return records, total
@@ -474,7 +470,7 @@ class CheckInService:
skip: int = 0, skip: int = 0,
limit: int = 100, limit: int = 100,
task_id: Optional[int] = None, task_id: Optional[int] = None,
status: Optional[str] = None status: Optional[str] = None,
) -> tuple[List[CheckInRecord], int]: ) -> tuple[List[CheckInRecord], int]:
""" """
获取所有打卡记录(管理员)- 使用联表查询优化性能 获取所有打卡记录(管理员)- 使用联表查询优化性能
@@ -506,9 +502,7 @@ class CheckInService:
total = query.count() total = query.count()
# 获取分页数据 # 获取分页数据
records = query.order_by( records = query.order_by(CheckInRecord.check_in_time.desc()).offset(skip).limit(limit).all()
CheckInRecord.check_in_time.desc()
).offset(skip).limit(limit).all()
return records, total return records, total
@@ -527,8 +521,11 @@ class CheckInService:
包含额外信息的记录字典 包含额外信息的记录字典
""" """
# 尝试使用已加载的关联对象,如果没有则查询 # 尝试使用已加载的关联对象,如果没有则查询
task = record.task if hasattr(record, 'task') and record.task else \ task = (
db.query(CheckInTask).filter(CheckInTask.id == record.task_id).first() record.task
if hasattr(record, "task") and record.task
else db.query(CheckInTask).filter(CheckInTask.id == record.task_id).first()
)
# 获取用户信息 # 获取用户信息
user = None user = None
@@ -537,28 +534,32 @@ class CheckInService:
if task: if task:
# 尝试使用已加载的 user,否则查询 # 尝试使用已加载的 user,否则查询
user = task.user if hasattr(task, 'user') and task.user else \ user = (
db.query(User).filter(User.id == task.user_id).first() task.user
if hasattr(task, "user") and task.user
else db.query(User).filter(User.id == task.user_id).first()
)
task_name = task.name task_name = task.name
# 从 payload_config 提取 ThreadId # 从 payload_config 提取 ThreadId
from backend.utils.json_helpers import extract_thread_id from backend.utils.json_helpers import extract_thread_id
thread_id = extract_thread_id(task.payload_config) # type: ignore thread_id = extract_thread_id(task.payload_config) # type: ignore
# 转换为字典并添加额外字段 # 转换为字典并添加额外字段
record_dict = { record_dict = {
'id': record.id, "id": record.id,
'task_id': record.task_id, "task_id": record.task_id,
'status': record.status, "status": record.status,
'response_text': record.response_text, "response_text": record.response_text,
'error_message': record.error_message, "error_message": record.error_message,
'location': record.location, "location": record.location,
'trigger_type': record.trigger_type, "trigger_type": record.trigger_type,
'check_in_time': record.check_in_time, "check_in_time": record.check_in_time,
'user_id': user.id if user else None, "user_id": user.id if user else None,
'user_email': user.email if user else None, "user_email": user.email if user else None,
'task_name': task_name, "task_name": task_name,
'thread_id': thread_id, "thread_id": thread_id,
} }
return record_dict return record_dict
+27 -13
View File
@@ -69,7 +69,11 @@ class EmailService:
# 安全获取创建时间 # 安全获取创建时间
created_at_value = user.created_at created_at_value = user.created_at
created_time = created_at_value.strftime('%Y-%m-%d %H:%M:%S') if created_at_value is not None else '未知' created_time = (
created_at_value.strftime("%Y-%m-%d %H:%M:%S")
if created_at_value is not None
else "未知"
)
body_html = f""" body_html = f"""
<!DOCTYPE html> <!DOCTYPE html>
@@ -191,7 +195,9 @@ class EmailService:
# 安全获取创建时间 # 安全获取创建时间
user_created_at = user.created_at user_created_at = user.created_at
created_time = user_created_at.strftime('%Y-%m-%d %H:%M:%S') if user_created_at is not None else '未知' created_time = (
user_created_at.strftime("%Y-%m-%d %H:%M:%S") if user_created_at is not None else "未知"
)
body_html = f""" body_html = f"""
<!DOCTYPE html> <!DOCTYPE html>
@@ -270,7 +276,7 @@ class EmailService:
<div class="success-box"> <div class="success-box">
<strong>✅ 审批结果:</strong> 已通过 <strong>✅ 审批结果:</strong> 已通过
<br> <br>
<strong>审批时间:</strong> {datetime.now().strftime('%Y-%m-%d %H:%M:%S')} <strong>审批时间:</strong> {datetime.now().strftime("%Y-%m-%d %H:%M:%S")}
</div> </div>
<table class="info-table"> <table class="info-table">
@@ -391,7 +397,7 @@ class EmailService:
<div class="error-box"> <div class="error-box">
<strong>❌ 审批结果:</strong> 未通过 <strong>❌ 审批结果:</strong> 未通过
<br> <br>
<strong>处理时间:</strong> {datetime.now().strftime('%Y-%m-%d %H:%M:%S')} <strong>处理时间:</strong> {datetime.now().strftime("%Y-%m-%d %H:%M:%S")}
</div> </div>
{reason_html} {reason_html}
@@ -409,7 +415,6 @@ class EmailService:
return EmailService.send_email([str(user_email)], subject, body_html) return EmailService.send_email([str(user_email)], subject, body_html)
@staticmethod @staticmethod
def notify_token_expiring(user: User, jwt_exp: str) -> bool: def notify_token_expiring(user: User, jwt_exp: str) -> bool:
""" """
@@ -640,7 +645,9 @@ class EmailService:
return EmailService.send_email([str(user_email)], subject, body_html) return EmailService.send_email([str(user_email)], subject, body_html)
@staticmethod @staticmethod
def notify_check_in_result(user: User, task_info: dict, success: bool, message: str = "") -> bool: def notify_check_in_result(
user: User, task_info: dict, success: bool, message: str = ""
) -> bool:
""" """
通知用户打卡结果 通知用户打卡结果
@@ -665,9 +672,16 @@ class EmailService:
subject = f"【接龙自动打卡】打卡{status_text} - {user.alias}" subject = f"【接龙自动打卡】打卡{status_text} - {user.alias}"
# 判断是否是 Token 失效导致的失败 # 判断是否是 Token 失效导致的失败
is_token_error = not success and message and ( is_token_error = (
"Token" in message or "token" in message or not success
"失效" in message or "授权" in message or "登录" in message and message
and (
"Token" in message
or "token" in message
or "失效" in message
or "授权" in message
or "登录" in message
)
) )
# Token 失效时的额外提示内容 # Token 失效时的额外提示内容
@@ -768,20 +782,20 @@ class EmailService:
<table class="info-table"> <table class="info-table">
<tr> <tr>
<td>执行时间</td> <td>执行时间</td>
<td>{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}</td> <td>{datetime.now().strftime("%Y-%m-%d %H:%M:%S")}</td>
</tr> </tr>
<tr> <tr>
<td>任务 ID</td> <td>任务 ID</td>
<td>{task_info.get('thread_id', '未知')}</td> <td>{task_info.get("thread_id", "未知")}</td>
</tr> </tr>
<tr> <tr>
<td>打卡状态</td> <td>打卡状态</td>
<td><strong style="color: {status_color};">{status_text}</strong></td> <td><strong style="color: {status_color};">{status_text}</strong></td>
</tr> </tr>
{f'<tr><td>失败原因</td><td>{message}</td></tr>' if message else ''} {f"<tr><td>失败原因</td><td>{message}</td></tr>" if message else ""}
</table> </table>
{token_error_section if is_token_error else '<p>如有问题,请及时检查您的打卡配置。</p>'} {token_error_section if is_token_error else "<p>如有问题,请及时检查您的打卡配置。</p>"}
</div> </div>
<div class="footer"> <div class="footer">
<p>此邮件由系统自动发送,请勿直接回复。</p> <p>此邮件由系统自动发送,请勿直接回复。</p>
+25 -18
View File
@@ -1,6 +1,7 @@
""" """
用户名预占和注册限流管理器 用户名预占和注册限流管理器
""" """
import time import time
import threading import threading
import logging import logging
@@ -47,23 +48,22 @@ class RegistrationManager:
reservation = self._reserved_aliases[alias] reservation = self._reserved_aliases[alias]
# 检查是否过期 # 检查是否过期
if reservation['expire_time'] > current_time: if reservation["expire_time"] > current_time:
# 未过期,检查是否是同一个 session # 未过期,检查是否是同一个 session
if reservation['session_id'] == session_id: if reservation["session_id"] == session_id:
# 同一个 session,更新过期时间 # 同一个 session,更新过期时间
reservation['expire_time'] = expire_time reservation["expire_time"] = expire_time
logger.info(f"用户名 {alias} 预占时间已更新(session: {session_id}") logger.info(f"用户名 {alias} 预占时间已更新(session: {session_id}")
return True return True
else: else:
# 不同 session,预占失败 # 不同 session,预占失败
logger.warning(f"用户名 {alias} 已被占用(session: {reservation['session_id']}") logger.warning(
f"用户名 {alias} 已被占用(session: {reservation['session_id']}"
)
return False return False
# 预占用户名 # 预占用户名
self._reserved_aliases[alias] = { self._reserved_aliases[alias] = {"session_id": session_id, "expire_time": expire_time}
'session_id': session_id,
'expire_time': expire_time
}
logger.info(f"用户名 {alias} 已预占(session: {session_id}, 超时: {timeout_seconds}s") logger.info(f"用户名 {alias} 已预占(session: {session_id}, 超时: {timeout_seconds}s")
return True return True
@@ -85,7 +85,7 @@ class RegistrationManager:
reservation = self._reserved_aliases[alias] reservation = self._reserved_aliases[alias]
# 如果指定了 session_id,则只释放匹配的 # 如果指定了 session_id,则只释放匹配的
if session_id and reservation['session_id'] != session_id: if session_id and reservation["session_id"] != session_id:
logger.warning(f"尝试释放用户名 {alias},但 session 不匹配") logger.warning(f"尝试释放用户名 {alias},但 session 不匹配")
return False return False
@@ -111,7 +111,7 @@ class RegistrationManager:
current_time = time.time() current_time = time.time()
# 检查是否过期 # 检查是否过期
if reservation['expire_time'] <= current_time: if reservation["expire_time"] <= current_time:
# 已过期,自动释放 # 已过期,自动释放
del self._reserved_aliases[alias] del self._reserved_aliases[alias]
return False return False
@@ -138,7 +138,9 @@ class RegistrationManager:
# 检查是否过期 # 检查是否过期
if expire_time > current_time: if expire_time > current_time:
remaining = int(expire_time - current_time) remaining = int(expire_time - current_time)
logger.warning(f"Cookie {cookie_value[:8]}... 在限流期内(剩余 {remaining} 秒)") logger.warning(
f"Cookie {cookie_value[:8]}... 在限流期内(剩余 {remaining} 秒)"
)
return False return False
else: else:
# 已过期,移除记录 # 已过期,移除记录
@@ -168,8 +170,9 @@ class RegistrationManager:
# 清理过期的用户名预占 # 清理过期的用户名预占
expired_aliases = [ expired_aliases = [
alias for alias, reservation in self._reserved_aliases.items() alias
if reservation['expire_time'] <= current_time for alias, reservation in self._reserved_aliases.items()
if reservation["expire_time"] <= current_time
] ]
for alias in expired_aliases: for alias in expired_aliases:
@@ -178,7 +181,8 @@ class RegistrationManager:
# 清理过期的注册限流记录 # 清理过期的注册限流记录
expired_cookies = [ expired_cookies = [
cookie for cookie, expire_time in self._registration_cookies.items() cookie
for cookie, expire_time in self._registration_cookies.items()
if expire_time <= current_time if expire_time <= current_time
] ]
@@ -187,10 +191,13 @@ class RegistrationManager:
logger.debug(f"Cookie {cookie[:8]}... 限流记录已过期,自动清理") logger.debug(f"Cookie {cookie[:8]}... 限流记录已过期,自动清理")
if expired_aliases or expired_cookies: if expired_aliases or expired_cookies:
logger.info(f"清理完成:{len(expired_aliases)} 个用户名,{len(expired_cookies)} 个 Cookie") logger.info(
f"清理完成:{len(expired_aliases)} 个用户名,{len(expired_cookies)} 个 Cookie"
)
def _start_cleanup_thread(self) -> None: def _start_cleanup_thread(self) -> None:
"""启动定期清理线程""" """启动定期清理线程"""
def cleanup_loop(): def cleanup_loop():
while True: while True:
try: try:
@@ -207,9 +214,9 @@ class RegistrationManager:
"""获取当前状态统计""" """获取当前状态统计"""
with self._lock: with self._lock:
return { return {
'reserved_aliases_count': len(self._reserved_aliases), "reserved_aliases_count": len(self._reserved_aliases),
'rate_limited_cookies_count': len(self._registration_cookies), "rate_limited_cookies_count": len(self._registration_cookies),
'reserved_aliases': list(self._reserved_aliases.keys()), "reserved_aliases": list(self._reserved_aliases.keys()),
} }
+25 -20
View File
@@ -39,14 +39,15 @@ def load_scheduled_tasks(db: Session, scheduler_instance):
# 移除所有现有的动态任务(保留系统任务) # 移除所有现有的动态任务(保留系统任务)
for job in scheduler_instance.get_jobs(): for job in scheduler_instance.get_jobs():
if job.id.startswith('task_'): if job.id.startswith("task_"):
scheduler_instance.remove_job(job.id) scheduler_instance.remove_job(job.id)
# 查询所有启用且有 cron 表达式的任务 # 查询所有启用且有 cron 表达式的任务
tasks = db.query(CheckInTask).filter( tasks = (
CheckInTask.is_active == True, db.query(CheckInTask)
CheckInTask.cron_expression.isnot(None) .filter(CheckInTask.is_active == True, CheckInTask.cron_expression.isnot(None))
).all() .all()
)
loaded_count = 0 loaded_count = 0
skipped_count = 0 skipped_count = 0
@@ -76,7 +77,7 @@ def load_scheduled_tasks(db: Session, scheduler_instance):
id=job_id, id=job_id,
name=f"CheckIn-Task-{task.id}", name=f"CheckIn-Task-{task.id}",
args=[task.id], args=[task.id],
replace_existing=True replace_existing=True,
) )
logger.info(f"✅ 加载任务 {task.id}: {task.name} (Cron: {task.cron_expression})") logger.info(f"✅ 加载任务 {task.id}: {task.name} (Cron: {task.cron_expression})")
@@ -90,7 +91,7 @@ def load_scheduled_tasks(db: Session, scheduler_instance):
"loaded": loaded_count, "loaded": loaded_count,
"skipped": skipped_count, "skipped": skipped_count,
"errors": error_count, "errors": error_count,
"total": len(tasks) "total": len(tasks),
} }
logger.info(f"任务加载完成: {result}") logger.info(f"任务加载完成: {result}")
@@ -114,7 +115,9 @@ def scheduled_check_in_task(task_id: int):
return return
if not task.is_scheduled_enabled: if not task.is_scheduled_enabled:
logger.info(f"任务 {task_id} 未启用定时打卡 (is_active={task.is_active}, cron={task.cron_expression})") logger.info(
f"任务 {task_id} 未启用定时打卡 (is_active={task.is_active}, cron={task.cron_expression})"
)
return return
logger.info(f"🤖 执行定时打卡任务 {task_id}") logger.info(f"🤖 执行定时打卡任务 {task_id}")
@@ -184,14 +187,18 @@ def check_token_expiration():
# 计算剩余时间 # 计算剩余时间
time_until_expiry = seconds_until_expiry(exp_timestamp) time_until_expiry = seconds_until_expiry(exp_timestamp)
logger.debug(f"用户 {user.alias}: 剩余 {time_until_expiry} 秒 (即将过期标志={user.token_expiring_notified}, 已过期标志={user.token_expired_notified})") logger.debug(
f"用户 {user.alias}: 剩余 {time_until_expiry} 秒 (即将过期标志={user.token_expiring_notified}, 已过期标志={user.token_expired_notified})"
)
# 情况1:Token 即将过期(过期前 30 分钟内,且还未过期) # 情况1:Token 即将过期(过期前 30 分钟内,且还未过期)
if 0 < time_until_expiry < 1800: # 30分钟 = 1800秒 if 0 < time_until_expiry < 1800: # 30分钟 = 1800秒
# 检查是否已发送过提醒 # 检查是否已发送过提醒
expiring_notified = bool(user.token_expiring_notified) expiring_notified = bool(user.token_expiring_notified)
if not expiring_notified: if not expiring_notified:
logger.info(f"用户 {user.alias} 的打卡 Token 即将过期,发送邮件提醒到 {user_email}...") logger.info(
f"用户 {user.alias} 的打卡 Token 即将过期,发送邮件提醒到 {user_email}..."
)
from backend.services.email_service import EmailService from backend.services.email_service import EmailService
# 发送"即将过期"邮件 # 发送"即将过期"邮件
@@ -212,7 +219,9 @@ def check_token_expiration():
# 检查是否已发送过提醒 # 检查是否已发送过提醒
expired_notified = bool(user.token_expired_notified) expired_notified = bool(user.token_expired_notified)
if not expired_notified: if not expired_notified:
logger.info(f"用户 {user.alias} 的打卡 Token 已过期,发送邮件提醒到 {user_email}...") logger.info(
f"用户 {user.alias} 的打卡 Token 已过期,发送邮件提醒到 {user_email}..."
)
from backend.services.email_service import EmailService from backend.services.email_service import EmailService
# 发送"已过期"邮件 # 发送"已过期"邮件
@@ -320,11 +329,9 @@ def start_scheduler():
minutes=settings.TOKEN_CHECK_INTERVAL_MINUTES, minutes=settings.TOKEN_CHECK_INTERVAL_MINUTES,
id="check_token_expiration", id="check_token_expiration",
name="Token 过期检查任务", name="Token 过期检查任务",
replace_existing=True replace_existing=True,
)
logger.info(
f"已添加 Token 过期检查任务: 每 {settings.TOKEN_CHECK_INTERVAL_MINUTES} 分钟"
) )
logger.info(f"已添加 Token 过期检查任务: 每 {settings.TOKEN_CHECK_INTERVAL_MINUTES} 分钟")
# 添加会话文件清理任务(每隔指定小时) # 添加会话文件清理任务(每隔指定小时)
scheduler.add_job( scheduler.add_job(
@@ -333,11 +340,9 @@ def start_scheduler():
hours=settings.SESSION_CLEANUP_INTERVAL_HOURS, hours=settings.SESSION_CLEANUP_INTERVAL_HOURS,
id="cleanup_old_sessions", id="cleanup_old_sessions",
name="清理旧会话文件任务", name="清理旧会话文件任务",
replace_existing=True replace_existing=True,
)
logger.info(
f"已添加会话清理任务: 每 {settings.SESSION_CLEANUP_INTERVAL_HOURS} 小时"
) )
logger.info(f"已添加会话清理任务: 每 {settings.SESSION_CLEANUP_INTERVAL_HOURS} 小时")
# 添加清理过期未审批用户任务(每小时执行一次) # 添加清理过期未审批用户任务(每小时执行一次)
scheduler.add_job( scheduler.add_job(
@@ -346,7 +351,7 @@ def start_scheduler():
hours=1, hours=1,
id="cleanup_expired_pending_users", id="cleanup_expired_pending_users",
name="清理过期未审批用户任务", name="清理过期未审批用户任务",
replace_existing=True replace_existing=True,
) )
logger.info("已添加清理过期未审批用户任务: 每 1 小时") logger.info("已添加清理过期未审批用户任务: 每 1 小时")
+40 -32
View File
@@ -36,22 +36,22 @@ class TaskService:
from backend.utils.json_helpers import safe_parse_payload, extract_thread_id from backend.utils.json_helpers import safe_parse_payload, extract_thread_id
payload = safe_parse_payload(task_data.payload_config) payload = safe_parse_payload(task_data.payload_config)
thread_id = payload.get('ThreadId') thread_id = payload.get("ThreadId")
if not thread_id: if not thread_id:
raise ValueError("payload_config 中缺少 ThreadId") raise ValueError("payload_config 中缺少 ThreadId")
# 3. 验证唯一性:同一用户在同一个接龙中不能有重复的任务 # 3. 验证唯一性:同一用户在同一个接龙中不能有重复的任务
existing_tasks = db.query( existing_tasks = (
CheckInTask.payload_config db.query(CheckInTask.payload_config).filter(CheckInTask.user_id == user_id).all()
).filter( )
CheckInTask.user_id == user_id
).all()
for (payload_config,) in existing_tasks: for (payload_config,) in existing_tasks:
existing_thread_id = extract_thread_id(payload_config) existing_thread_id = extract_thread_id(payload_config)
# extract_thread_id 已处理异常,失败时返回 None # extract_thread_id 已处理异常,失败时返回 None
if existing_thread_id and existing_thread_id == thread_id: if existing_thread_id and existing_thread_id == thread_id:
logger.warning(f"⚠️ 任务创建冲突 - User: {user.alias}({user_id}), ThreadId: {thread_id}") logger.warning(
f"⚠️ 任务创建冲突 - User: {user.alias}({user_id}), ThreadId: {thread_id}"
)
raise ValueError(f"该接龙中已存在任务。ThreadId: {thread_id}") raise ValueError(f"该接龙中已存在任务。ThreadId: {thread_id}")
# 4. 记录日志 # 4. 记录日志
@@ -63,14 +63,16 @@ class TaskService:
user_id=user_id, user_id=user_id,
payload_config=task_data.payload_config, payload_config=task_data.payload_config,
name=task_data.name or task_name, name=task_data.name or task_name,
is_active=task_data.is_active if task_data.is_active is not None else True is_active=task_data.is_active if task_data.is_active is not None else True,
) )
try: try:
db.add(task) db.add(task)
db.commit() db.commit()
db.refresh(task) db.refresh(task)
logger.info(f"✅ 任务创建成功 - ID: {task.id}, Name: {task.name}, ThreadId: {thread_id}") logger.info(
f"✅ 任务创建成功 - ID: {task.id}, Name: {task.name}, ThreadId: {thread_id}"
)
# 如果任务启用且包含 cron_expression,立即添加到调度器 # 如果任务启用且包含 cron_expression,立即添加到调度器
if task.is_scheduled_enabled: if task.is_scheduled_enabled:
@@ -111,33 +113,38 @@ class TaskService:
from backend.utils.json_helpers import extract_thread_id from backend.utils.json_helpers import extract_thread_id
# 获取最后一次打卡记录 # 获取最后一次打卡记录
last_record = db.query(CheckInRecord).filter( last_record = (
CheckInRecord.task_id == task.id db.query(CheckInRecord)
).order_by(desc(CheckInRecord.check_in_time)).first() .filter(CheckInRecord.task_id == task.id)
.order_by(desc(CheckInRecord.check_in_time))
.first()
)
# 从 payload_config 提取 ThreadId # 从 payload_config 提取 ThreadId
thread_id = extract_thread_id(task.payload_config) # type: ignore thread_id = extract_thread_id(task.payload_config) # type: ignore
# 转换为字典并添加额外字段 # 转换为字典并添加额外字段
task_dict = { task_dict = {
'id': task.id, "id": task.id,
'user_id': task.user_id, "user_id": task.user_id,
'payload_config': task.payload_config, "payload_config": task.payload_config,
'name': task.name, "name": task.name,
'is_active': task.is_active, "is_active": task.is_active,
'cron_expression': task.cron_expression, "cron_expression": task.cron_expression,
'is_scheduled_enabled': task.is_scheduled_enabled, "is_scheduled_enabled": task.is_scheduled_enabled,
'created_at': task.created_at, "created_at": task.created_at,
'updated_at': task.updated_at, "updated_at": task.updated_at,
'thread_id': thread_id, "thread_id": thread_id,
'last_check_in_time': last_record.check_in_time if last_record else None, "last_check_in_time": last_record.check_in_time if last_record else None,
'last_check_in_status': last_record.status if last_record else None, "last_check_in_status": last_record.status if last_record else None,
} }
return task_dict return task_dict
@staticmethod @staticmethod
def get_user_tasks(user_id: int, db: Session, include_inactive: bool = True) -> List[CheckInTask]: def get_user_tasks(
user_id: int, db: Session, include_inactive: bool = True
) -> List[CheckInTask]:
""" """
获取用户的所有任务 获取用户的所有任务
@@ -191,8 +198,8 @@ class TaskService:
update_data = task_data.model_dump(exclude_unset=True) update_data = task_data.model_dump(exclude_unset=True)
# 检查是否更新了 cron_expression 或 is_active # 检查是否更新了 cron_expression 或 is_active
cron_changed = 'cron_expression' in update_data cron_changed = "cron_expression" in update_data
active_changed = 'is_active' in update_data active_changed = "is_active" in update_data
for field, value in update_data.items(): for field, value in update_data.items():
setattr(task, field, value) setattr(task, field, value)
@@ -297,10 +304,11 @@ class TaskService:
Returns: Returns:
是否属于该用户 是否属于该用户
""" """
task = db.query(CheckInTask).filter( task = (
CheckInTask.id == task_id, db.query(CheckInTask)
CheckInTask.user_id == user_id .filter(CheckInTask.id == task_id, CheckInTask.user_id == user_id)
).first() .first()
)
return task is not None return task is not None
@@ -342,7 +350,7 @@ class TaskService:
id=job_id, id=job_id,
name=f"CheckIn-Task-{task.id}", name=f"CheckIn-Task-{task.id}",
args=[task.id], args=[task.id],
replace_existing=True replace_existing=True,
) )
logger.info(f"✅ 任务 {task.id} 已重新加载到调度器: {cron_str}") logger.info(f"✅ 任务 {task.id} 已重新加载到调度器: {cron_str}")
else: else:
+52 -78
View File
@@ -132,15 +132,13 @@ class TemplateService:
except json.JSONDecodeError as e: except json.JSONDecodeError as e:
logger.error(f"模板字段配置 JSON 格式错误: {str(e)}") logger.error(f"模板字段配置 JSON 格式错误: {str(e)}")
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST, detail=f"字段配置 JSON 格式错误: {str(e)}"
detail=f"字段配置 JSON 格式错误: {str(e)}"
) )
except Exception as e: except Exception as e:
logger.error(f"创建模板失败: {str(e)}") logger.error(f"创建模板失败: {str(e)}")
db.rollback() db.rollback()
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"创建模板失败: {str(e)}"
detail=f"创建模板失败: {str(e)}"
) )
@staticmethod @staticmethod
@@ -159,10 +157,7 @@ class TemplateService:
@staticmethod @staticmethod
def get_all_templates( def get_all_templates(
db: Session, db: Session, skip: int = 0, limit: int = 100, is_active: Optional[bool] = None
skip: int = 0,
limit: int = 100,
is_active: Optional[bool] = None
) -> List[TaskTemplate]: ) -> List[TaskTemplate]:
""" """
获取所有模板列表 获取所有模板列表
@@ -185,9 +180,7 @@ class TemplateService:
@staticmethod @staticmethod
def update_template( def update_template(
template_id: int, template_id: int, template_data: TemplateUpdate, db: Session
template_data: TemplateUpdate,
db: Session
) -> TaskTemplate: ) -> TaskTemplate:
""" """
更新模板 更新模板
@@ -202,18 +195,15 @@ class TemplateService:
""" """
template = TemplateService.get_template(template_id, db) template = TemplateService.get_template(template_id, db)
if not template: if not template:
raise HTTPException( raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="模板不存在")
status_code=status.HTTP_404_NOT_FOUND,
detail="模板不存在"
)
try: try:
# 更新字段 # 更新字段
update_data = template_data.model_dump(exclude_unset=True) update_data = template_data.model_dump(exclude_unset=True)
# 验证 field_config 如果有更新 # 验证 field_config 如果有更新
if 'field_config' in update_data and update_data['field_config']: if "field_config" in update_data and update_data["field_config"]:
json.loads(update_data['field_config']) json.loads(update_data["field_config"])
for field, value in update_data.items(): for field, value in update_data.items():
setattr(template, field, value) setattr(template, field, value)
@@ -227,15 +217,13 @@ class TemplateService:
except json.JSONDecodeError as e: except json.JSONDecodeError as e:
logger.error(f"模板字段配置 JSON 格式错误: {str(e)}") logger.error(f"模板字段配置 JSON 格式错误: {str(e)}")
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST, detail=f"字段配置 JSON 格式错误: {str(e)}"
detail=f"字段配置 JSON 格式错误: {str(e)}"
) )
except Exception as e: except Exception as e:
logger.error(f"更新模板失败: {str(e)}") logger.error(f"更新模板失败: {str(e)}")
db.rollback() db.rollback()
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"更新模板失败: {str(e)}"
detail=f"更新模板失败: {str(e)}"
) )
@staticmethod @staticmethod
@@ -252,10 +240,7 @@ class TemplateService:
""" """
template = TemplateService.get_template(template_id, db) template = TemplateService.get_template(template_id, db)
if not template: if not template:
raise HTTPException( raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="模板不存在")
status_code=status.HTTP_404_NOT_FOUND,
detail="模板不存在"
)
try: try:
db.delete(template) db.delete(template)
@@ -266,28 +251,26 @@ class TemplateService:
logger.error(f"删除模板失败: {str(e)}") logger.error(f"删除模板失败: {str(e)}")
db.rollback() db.rollback()
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"删除模板失败: {str(e)}"
detail=f"删除模板失败: {str(e)}"
) )
@staticmethod @staticmethod
def _is_field_config(obj: Any) -> bool: def _is_field_config(obj: Any) -> bool:
"""判断是否为字段配置对象""" """判断是否为字段配置对象"""
return isinstance(obj, dict) and 'display_name' in obj return isinstance(obj, dict) and "display_name" in obj
@staticmethod @staticmethod
def _is_object_field(obj: Any) -> bool: def _is_object_field(obj: Any) -> bool:
"""判断是否为对象字段(包含多个子字段配置)""" """判断是否为对象字段(包含多个子字段配置)"""
if not isinstance(obj, dict): if not isinstance(obj, dict):
return False return False
if 'display_name' in obj: if "display_name" in obj:
return False return False
# 检查所有值是否都是字段配置对象 # 检查所有值是否都是字段配置对象
return all( return (
TemplateService._is_field_config(v) all(TemplateService._is_field_config(v) for v in obj.values() if isinstance(v, dict))
for v in obj.values() and len(obj) > 0
if isinstance(v, dict) )
) and len(obj) > 0
@staticmethod @staticmethod
def _process_field_value(key: str, config: Any, field_values: Dict[str, Any]) -> Any: def _process_field_value(key: str, config: Any, field_values: Dict[str, Any]) -> Any:
@@ -304,12 +287,12 @@ class TemplateService:
""" """
# 1. 普通字段配置 # 1. 普通字段配置
if TemplateService._is_field_config(config): if TemplateService._is_field_config(config):
if config.get('hidden', False): if config.get("hidden", False):
value = config.get('default_value', '') value = config.get("default_value", "")
else: else:
value = field_values.get(key, config.get('default_value', '')) value = field_values.get(key, config.get("default_value", ""))
value_type = config.get('value_type', 'string') value_type = config.get("value_type", "string")
return TemplateService._validate_and_convert_value(value, value_type, key) return TemplateService._validate_and_convert_value(value, value_type, key)
# 2. 数组字段 # 2. 数组字段
@@ -319,10 +302,10 @@ class TemplateService:
# 检查数组元素是否是字段配置对象 # 检查数组元素是否是字段配置对象
if TemplateService._is_field_config(item_config): if TemplateService._is_field_config(item_config):
# 数组元素是字段配置对象,需要序列化为 JSON 字符串 # 数组元素是字段配置对象,需要序列化为 JSON 字符串
value = item_config.get('default_value', '') value = item_config.get("default_value", "")
value_type = item_config.get('value_type', 'string') value_type = item_config.get("value_type", "string")
# 将对象序列化为 JSON 字符串 # 将对象序列化为 JSON 字符串
if value_type == 'json': if value_type == "json":
if isinstance(value, str): if isinstance(value, str):
# 如果是字符串,验证 JSON 格式 # 如果是字符串,验证 JSON 格式
try: try:
@@ -333,15 +316,16 @@ class TemplateService:
error_detail += f"JSON 解析错误: {str(e)}\n" error_detail += f"JSON 解析错误: {str(e)}\n"
error_detail += "常见问题: 数字不能有前导零(如 00.00 应改为 0.0)" error_detail += "常见问题: 数字不能有前导零(如 00.00 应改为 0.0)"
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST, detail=error_detail
detail=error_detail
) )
result.append(value) result.append(value)
else: else:
# 如果是对象,序列化为 JSON 字符串 # 如果是对象,序列化为 JSON 字符串
result.append(json.dumps(value, ensure_ascii=False)) result.append(json.dumps(value, ensure_ascii=False))
else: else:
result.append(TemplateService._validate_and_convert_value(value, value_type, key)) result.append(
TemplateService._validate_and_convert_value(value, value_type, key)
)
elif isinstance(item_config, dict): elif isinstance(item_config, dict):
# 数组元素是普通对象,递归处理 # 数组元素是普通对象,递归处理
item = {} item = {}
@@ -388,9 +372,7 @@ class TemplateService:
field_config = TemplateService.merge_parent_config(template, db) field_config = TemplateService.merge_parent_config(template, db)
# 初始化 payload,只包含 ThreadId(唯一必需,不在模板中配置) # 初始化 payload,只包含 ThreadId(唯一必需,不在模板中配置)
payload = { payload = {"ThreadId": "<接龙项目ID>"}
"ThreadId": "<接龙项目ID>"
}
# 递归处理所有字段,保持键名原样 # 递归处理所有字段,保持键名原样
for key, config in field_config.items(): for key, config in field_config.items():
@@ -402,15 +384,12 @@ class TemplateService:
logger.error(f"解析模板配置失败: {str(e)}") logger.error(f"解析模板配置失败: {str(e)}")
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"解析模板配置失败: {str(e)}" detail=f"解析模板配置失败: {str(e)}",
) )
@staticmethod @staticmethod
def assemble_payload_from_template( def assemble_payload_from_template(
template: TaskTemplate, template: TaskTemplate, thread_id: str, field_values: Dict[str, Any], db: Session
thread_id: str,
field_values: Dict[str, Any],
db: Session
) -> Dict[str, Any]: ) -> Dict[str, Any]:
""" """
根据模板和用户输入组装完整的 payload 根据模板和用户输入组装完整的 payload
@@ -432,9 +411,7 @@ class TemplateService:
field_config = TemplateService.merge_parent_config(template, db) field_config = TemplateService.merge_parent_config(template, db)
# 初始化 payload,只包含 ThreadId(唯一必需) # 初始化 payload,只包含 ThreadId(唯一必需)
payload = { payload = {"ThreadId": thread_id}
"ThreadId": thread_id
}
# 递归处理所有字段,保持键名原样 # 递归处理所有字段,保持键名原样
for key, config in field_config.items(): for key, config in field_config.items():
@@ -445,14 +422,13 @@ class TemplateService:
except json.JSONDecodeError as e: except json.JSONDecodeError as e:
logger.error(f"解析模板配置失败: {str(e)}") logger.error(f"解析模板配置失败: {str(e)}")
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"解析模板配置失败"
detail=f"解析模板配置失败"
) )
except Exception as e: except Exception as e:
logger.error(f"组装 payload 失败: {str(e)}") logger.error(f"组装 payload 失败: {str(e)}")
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"组装 payload 失败: {str(e)}" detail=f"组装 payload 失败: {str(e)}",
) )
@staticmethod @staticmethod
@@ -469,17 +445,17 @@ class TemplateService:
转换后的值 转换后的值
""" """
try: try:
if value_type == 'int': if value_type == "int":
return int(value) if value != '' else 0 return int(value) if value != "" else 0
elif value_type == 'double': elif value_type == "double":
return float(value) if value != '' else 0.0 return float(value) if value != "" else 0.0
elif value_type == 'bool': elif value_type == "bool":
if isinstance(value, bool): if isinstance(value, bool):
return value return value
if isinstance(value, str): if isinstance(value, str):
return value.lower() in ('true', '1', 'yes') return value.lower() in ("true", "1", "yes")
return bool(value) return bool(value)
elif value_type == 'json': elif value_type == "json":
# JSON 类型:如果是字符串,尝试解析后再序列化;如果是对象,直接序列化 # JSON 类型:如果是字符串,尝试解析后再序列化;如果是对象,直接序列化
if isinstance(value, str): if isinstance(value, str):
# 验证是否为有效 JSON # 验证是否为有效 JSON
@@ -493,7 +469,7 @@ class TemplateService:
except (ValueError, TypeError, json.JSONDecodeError) as e: except (ValueError, TypeError, json.JSONDecodeError) as e:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,
detail=f"字段 '{field_name}' 类型错误:期望 {value_type},实际值为 '{value}',错误: {str(e)}" detail=f"字段 '{field_name}' 类型错误:期望 {value_type},实际值为 '{value}',错误: {str(e)}",
) )
@staticmethod @staticmethod
@@ -504,7 +480,7 @@ class TemplateService:
user_id: int, user_id: int,
task_name: Optional[str], task_name: Optional[str],
db: Session, db: Session,
cron_expression: Optional[str] = "0 20 * * *" cron_expression: Optional[str] = "0 20 * * *",
) -> CheckInTask: ) -> CheckInTask:
""" """
从模板创建打卡任务 从模板创建打卡任务
@@ -524,16 +500,12 @@ class TemplateService:
# 获取模板 # 获取模板
template = TemplateService.get_template(template_id, db) template = TemplateService.get_template(template_id, db)
if not template: if not template:
raise HTTPException( raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="模板不存在")
status_code=status.HTTP_404_NOT_FOUND,
detail="模板不存在"
)
# 检查模板是否启用 # 检查模板是否启用
if template.is_active is not True: if template.is_active is not True:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, status_code=status.HTTP_403_FORBIDDEN, detail="该模板未启用,无法创建任务"
detail="该模板未启用,无法创建任务"
) )
# 组装 payload # 组装 payload
@@ -543,7 +515,7 @@ class TemplateService:
# 生成任务名称 # 生成任务名称
if not task_name: if not task_name:
signature = payload.get('Signature', 'Unknown') signature = payload.get("Signature", "Unknown")
task_name = f"{template.name} - {signature}" task_name = f"{template.name} - {signature}"
# 创建任务(包含 cron_expression # 创建任务(包含 cron_expression
@@ -553,17 +525,20 @@ class TemplateService:
payload_config=json.dumps(payload, ensure_ascii=False), payload_config=json.dumps(payload, ensure_ascii=False),
name=task_name, name=task_name,
is_active=True, is_active=True,
cron_expression=cron_expression or "0 20 * * *" cron_expression=cron_expression or "0 20 * * *",
) )
db.add(task) db.add(task)
db.commit() db.commit()
db.refresh(task) db.refresh(task)
logger.info(f"从模板创建任务成功: {task.name} (ID: {task.id}, 模板: {template.name}, ThreadId: {thread_id})") logger.info(
f"从模板创建任务成功: {task.name} (ID: {task.id}, 模板: {template.name}, ThreadId: {thread_id})"
)
# 如果任务启用且包含 cron_expression,立即添加到调度器 # 如果任务启用且包含 cron_expression,立即添加到调度器
if task.is_scheduled_enabled: if task.is_scheduled_enabled:
from backend.services.task_service import TaskService from backend.services.task_service import TaskService
TaskService._reload_scheduler_for_task(task, db) TaskService._reload_scheduler_for_task(task, db)
return task return task
@@ -572,6 +547,5 @@ class TemplateService:
logger.error(f"从模板创建任务失败: {str(e)}") logger.error(f"从模板创建任务失败: {str(e)}")
db.rollback() db.rollback()
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"创建任务失败: {str(e)}"
detail=f"创建任务失败: {str(e)}"
) )
+14 -7
View File
@@ -20,7 +20,7 @@ def escape_like_pattern(text: str) -> str:
Returns: Returns:
转义后的文本 转义后的文本
""" """
return text.replace('%', r'\%').replace('_', r'\_') return text.replace("%", r"\%").replace("_", r"\_")
class UserService: class UserService:
@@ -49,7 +49,9 @@ class UserService:
alias=user_data.alias, alias=user_data.alias,
email=user_data.email, email=user_data.email,
role=user_data.role or "user", role=user_data.role or "user",
is_approved=user_data.is_approved if user_data.is_approved is not None else True, # 使用请求中的值,默认已审批 is_approved=user_data.is_approved
if user_data.is_approved is not None
else True, # 使用请求中的值,默认已审批
jwt_exp="0", jwt_exp="0",
authorization=None, authorization=None,
) )
@@ -57,14 +59,17 @@ class UserService:
# 如果提供了密码,则设置密码 # 如果提供了密码,则设置密码
if user_data.password: if user_data.password:
import bcrypt import bcrypt
password_hash = bcrypt.hashpw(user_data.password.encode('utf-8'), bcrypt.gensalt())
setattr(user, 'password_hash', password_hash.decode('utf-8')) password_hash = bcrypt.hashpw(user_data.password.encode("utf-8"), bcrypt.gensalt())
setattr(user, "password_hash", password_hash.decode("utf-8"))
db.add(user) db.add(user)
db.commit() db.commit()
db.refresh(user) db.refresh(user)
logger.info(f"管理员创建用户成功: {user.alias} (ID: {user.id}, 角色: {user.role}, 密码: {'已设置' if user_data.password else '未设置'})") logger.info(
f"管理员创建用户成功: {user.alias} (ID: {user.id}, 角色: {user.role}, 密码: {'已设置' if user_data.password else '未设置'})"
)
return user return user
@staticmethod @staticmethod
@@ -115,7 +120,7 @@ class UserService:
skip: int = 0, skip: int = 0,
limit: int = 100, limit: int = 100,
search: Optional[str] = None, search: Optional[str] = None,
role: Optional[str] = None role: Optional[str] = None,
) -> List[User]: ) -> List[User]:
""" """
获取所有用户 获取所有用户
@@ -241,7 +246,9 @@ class UserService:
raise ValueError("修改密码时必须提供当前密码") raise ValueError("修改密码时必须提供当前密码")
# 验证当前密码 # 验证当前密码
if not AuthService.verify_password(update_data["current_password"], user.password_hash): if not AuthService.verify_password(
update_data["current_password"], user.password_hash
):
raise ValueError("当前密码错误") raise ValueError("当前密码错误")
# 设置新密码 # 设置新密码
+11 -26
View File
@@ -3,18 +3,16 @@
提供统一的资源查询、权限验证等通用功能 提供统一的资源查询、权限验证等通用功能
""" """
from typing import TypeVar, Type, Optional, Any from typing import TypeVar, Type, Optional, Any
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from fastapi import HTTPException, status from fastapi import HTTPException, status
T = TypeVar('T') T = TypeVar("T")
def get_or_404( def get_or_404(
model: Type[T], model: Type[T], model_id: int, db: Session, error_message: Optional[str] = None
model_id: int,
db: Session,
error_message: Optional[str] = None
) -> T: ) -> T:
""" """
查询资源,不存在则抛出 404 查询资源,不存在则抛出 404
@@ -35,18 +33,13 @@ def get_or_404(
if not obj: if not obj:
default_message = f"{model.__name__}不存在" default_message = f"{model.__name__}不存在"
raise HTTPException( raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, status_code=status.HTTP_404_NOT_FOUND, detail=error_message or default_message
detail=error_message or default_message
) )
return obj return obj
def get_owned_or_403( def get_owned_or_403(
model: Type[T], model: Type[T], model_id: int, user_id: int, db: Session, error_message: Optional[str] = None
model_id: int,
user_id: int,
db: Session,
error_message: Optional[str] = None
) -> T: ) -> T:
""" """
查询资源并验证归属,否则抛出 403 查询资源并验证归属,否则抛出 403
@@ -64,24 +57,19 @@ def get_owned_or_403(
Raises: Raises:
HTTPException: 403 无权访问此资源 HTTPException: 403 无权访问此资源
""" """
obj = db.query(model).filter( obj = db.query(model).filter(model.id == model_id, model.user_id == user_id).first()
model.id == model_id,
model.user_id == user_id
).first()
if not obj: if not obj:
# 先检查资源是否存在 # 先检查资源是否存在
exists = db.query(model).filter(model.id == model_id).first() exists = db.query(model).filter(model.id == model_id).first()
if not exists: if not exists:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, status_code=status.HTTP_404_NOT_FOUND, detail=f"{model.__name__}不存在"
detail=f"{model.__name__}不存在"
) )
# 资源存在但不属于当前用户 # 资源存在但不属于当前用户
default_message = f"无权访问此{model.__name__}" default_message = f"无权访问此{model.__name__}"
raise HTTPException( raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, status_code=status.HTTP_403_FORBIDDEN, detail=error_message or default_message
detail=error_message or default_message
) )
return obj return obj
@@ -92,7 +80,7 @@ def get_by_field_or_404(
field_name: str, field_name: str,
field_value: Any, field_value: Any,
db: Session, db: Session,
error_message: Optional[str] = None error_message: Optional[str] = None,
) -> T: ) -> T:
""" """
根据字段查询资源,不存在则抛出 404 根据字段查询资源,不存在则抛出 404
@@ -110,14 +98,11 @@ def get_by_field_or_404(
Raises: Raises:
HTTPException: 404 资源不存在 HTTPException: 404 资源不存在
""" """
obj = db.query(model).filter( obj = db.query(model).filter(getattr(model, field_name) == field_value).first()
getattr(model, field_name) == field_value
).first()
if not obj: if not obj:
default_message = f"{model.__name__}不存在" default_message = f"{model.__name__}不存在"
raise HTTPException( raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, status_code=status.HTTP_404_NOT_FOUND, detail=error_message or default_message
detail=error_message or default_message
) )
return obj return obj
+7 -13
View File
@@ -3,6 +3,7 @@ JSON 处理辅助函数
提供安全的 JSON 解析和数据提取功能 提供安全的 JSON 解析和数据提取功能
""" """
import json import json
import logging import logging
from typing import Optional, Any, Dict from typing import Optional, Any, Dict
@@ -10,11 +11,7 @@ from typing import Optional, Any, Dict
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def safe_parse_json( def safe_parse_json(json_str: Optional[str], default: Any = None, log_error: bool = True) -> Any:
json_str: Optional[str],
default: Any = None,
log_error: bool = True
) -> Any:
""" """
安全解析 JSON 字符串,失败时返回默认值 安全解析 JSON 字符串,失败时返回默认值
@@ -37,10 +34,7 @@ def safe_parse_json(
return default return default
def safe_parse_payload( def safe_parse_payload(payload_config: Optional[str], default: Optional[Dict] = None) -> Dict:
payload_config: Optional[str],
default: Optional[Dict] = None
) -> Dict:
""" """
安全解析 payload_config,失败时返回默认字典 安全解析 payload_config,失败时返回默认字典
@@ -70,7 +64,7 @@ def extract_thread_id(payload_config: Optional[str]) -> Optional[str]:
ThreadId 或 None ThreadId 或 None
""" """
payload = safe_parse_payload(payload_config) payload = safe_parse_payload(payload_config)
return payload.get('ThreadId') return payload.get("ThreadId")
def extract_signature(payload_config: Optional[str]) -> Optional[str]: def extract_signature(payload_config: Optional[str]) -> Optional[str]:
@@ -84,7 +78,7 @@ def extract_signature(payload_config: Optional[str]) -> Optional[str]:
Signature 或 None Signature 或 None
""" """
payload = safe_parse_payload(payload_config) payload = safe_parse_payload(payload_config)
return payload.get('Signature') return payload.get("Signature")
def build_task_info(task) -> Dict[str, str]: def build_task_info(task) -> Dict[str, str]:
@@ -98,6 +92,6 @@ def build_task_info(task) -> Dict[str, str]:
包含 thread_id 和 name 的字典 包含 thread_id 和 name 的字典
""" """
return { return {
'thread_id': extract_thread_id(getattr(task, 'payload_config', None)) or '未知', "thread_id": extract_thread_id(getattr(task, "payload_config", None)) or "未知",
'name': getattr(task, 'name', None) or f'Task-{getattr(task, "id", "Unknown")}' "name": getattr(task, "name", None) or f"Task-{getattr(task, 'id', 'Unknown')}",
} }
+2 -5
View File
@@ -42,7 +42,7 @@ class JWTManager:
"alias": user_alias, "alias": user_alias,
"iat": now, # Issued At - 签发时间 "iat": now, # Issued At - 签发时间
"exp": exp, # Expiration Time - 过期时间 "exp": exp, # Expiration Time - 过期时间
"type": "access" # Token 类型 "type": "access", # Token 类型
} }
token = jwt.encode(payload, JWT_SECRET_KEY, algorithm=JWT_ALGORITHM) token = jwt.encode(payload, JWT_SECRET_KEY, algorithm=JWT_ALGORITHM)
@@ -97,10 +97,7 @@ class JWTManager:
try: try:
# decode 时设置 verify=False 跳过过期验证 # decode 时设置 verify=False 跳过过期验证
payload = jwt.decode( payload = jwt.decode(
token, token, JWT_SECRET_KEY, algorithms=[JWT_ALGORITHM], options={"verify_exp": False}
JWT_SECRET_KEY,
algorithms=[JWT_ALGORITHM],
options={"verify_exp": False}
) )
return payload.get("user_id") return payload.get("user_id")
except Exception as e: except Exception as e:
+2 -1
View File
@@ -3,6 +3,7 @@
提供统一的时间戳处理和格式化功能 提供统一的时间戳处理和格式化功能
""" """
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import Optional from typing import Optional
@@ -85,7 +86,7 @@ def minutes_until_expiry(timestamp: int) -> int:
return seconds // 60 return seconds // 60
def format_timestamp(timestamp: int, format_str: str = '%Y-%m-%d %H:%M:%S') -> str: def format_timestamp(timestamp: int, format_str: str = "%Y-%m-%d %H:%M:%S") -> str:
""" """
格式化时间戳为人类可读格式 格式化时间戳为人类可读格式
+58 -53
View File
@@ -42,17 +42,17 @@ def get_live_x_api_payload(auth_token: str) -> str:
chrome_options.binary_location = CHROME_BINARY_PATH chrome_options.binary_location = CHROME_BINARY_PATH
# 开启性能日志记录功能 # 开启性能日志记录功能
logging_prefs = {'performance': 'ALL'} logging_prefs = {"performance": "ALL"}
chrome_options.set_capability('goog:loggingPrefs', logging_prefs) chrome_options.set_capability("goog:loggingPrefs", logging_prefs)
# Headless 模式配置 # Headless 模式配置
user_agent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/140.0.0.0 Safari/537.36" user_agent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/140.0.0.0 Safari/537.36"
chrome_options.add_argument(f'user-agent={user_agent}') chrome_options.add_argument(f"user-agent={user_agent}")
chrome_options.add_argument("--headless") chrome_options.add_argument("--headless")
chrome_options.add_argument("--no-sandbox") chrome_options.add_argument("--no-sandbox")
chrome_options.add_argument("--disable-dev-shm-usage") chrome_options.add_argument("--disable-dev-shm-usage")
chrome_options.add_argument("--window-size=1920,1080") chrome_options.add_argument("--window-size=1920,1080")
chrome_options.add_argument('--ignore-certificate-errors') chrome_options.add_argument("--ignore-certificate-errors")
chrome_options.add_experimental_option("excludeSwitches", ["enable-automation"]) chrome_options.add_experimental_option("excludeSwitches", ["enable-automation"])
driver = webdriver.Chrome(service=service, options=chrome_options) driver = webdriver.Chrome(service=service, options=chrome_options)
@@ -63,11 +63,7 @@ def get_live_x_api_payload(auth_token: str) -> str:
driver.get("https://i.jielong.com/my-class") driver.get("https://i.jielong.com/my-class")
# 注入长期 Token # 注入长期 Token
driver.add_cookie({ driver.add_cookie({"name": "token", "value": auth_token, "domain": ".jielong.com"})
'name': 'token',
'value': auth_token,
'domain': '.jielong.com'
})
# 导航到触发 API 的页面 # 导航到触发 API 的页面
driver.get("https://i.jielong.com/my-form") driver.get("https://i.jielong.com/my-form")
@@ -78,14 +74,14 @@ def get_live_x_api_payload(auth_token: str) -> str:
found = False found = False
while time.time() - start_time < max_wait_time: while time.time() - start_time < max_wait_time:
logs = driver.get_log('performance') logs = driver.get_log("performance")
for entry in logs: for entry in logs:
log = json.loads(entry['message'])['message'] log = json.loads(entry["message"])["message"]
if log['method'] == 'Network.requestWillBeSent': if log["method"] == "Network.requestWillBeSent":
headers = log.get('params', {}).get('request', {}).get('headers', {}) headers = log.get("params", {}).get("request", {}).get("headers", {})
headers_lower = {k.lower(): v for k, v in headers.items()} headers_lower = {k.lower(): v for k, v in headers.items()}
if 'x-api-request-payload' in headers_lower: if "x-api-request-payload" in headers_lower:
payload_signature = headers_lower['x-api-request-payload'] payload_signature = headers_lower["x-api-request-payload"]
logger.info("成功通过网络日志捕获到现场的 x-api-request-payload") logger.info("成功通过网络日志捕获到现场的 x-api-request-payload")
found = True found = True
break break
@@ -94,12 +90,14 @@ def get_live_x_api_payload(auth_token: str) -> str:
time.sleep(1) time.sleep(1)
if not payload_signature: if not payload_signature:
raise Exception(f"{max_wait_time} 秒内未能通过网络日志捕获到 x-api-request-payload。") raise Exception(
f"{max_wait_time} 秒内未能通过网络日志捕获到 x-api-request-payload。"
)
except Exception as e: except Exception as e:
logger.error(f"获取现场 x-api-request-payload 时失败: {e}") logger.error(f"获取现场 x-api-request-payload 时失败: {e}")
try: try:
debug_screenshot = os.path.join(settings.BASE_DIR, 'payload_debug.png') debug_screenshot = os.path.join(settings.BASE_DIR, "payload_debug.png")
driver.save_screenshot(debug_screenshot) driver.save_screenshot(debug_screenshot)
except Exception as screenshot_error: except Exception as screenshot_error:
logger.warning(f"保存调试截图失败: {screenshot_error}") logger.warning(f"保存调试截图失败: {screenshot_error}")
@@ -135,7 +133,7 @@ def perform_check_in(task, user_token: str) -> Dict[str, Any]:
from backend.utils.json_helpers import safe_parse_payload, extract_signature from backend.utils.json_helpers import safe_parse_payload, extract_signature
payload_dict = safe_parse_payload(task.payload_config) payload_dict = safe_parse_payload(task.payload_config)
signature = extract_signature(task.payload_config) or 'Unknown' signature = extract_signature(task.payload_config) or "Unknown"
logger.info(f"Selenium打卡: 正在为任务 ID: {task.id} (Signature: {signature}) 执行打卡...") logger.info(f"Selenium打卡: 正在为任务 ID: {task.id} (Signature: {signature}) 执行打卡...")
@@ -146,7 +144,7 @@ def perform_check_in(task, user_token: str) -> Dict[str, Any]:
"success": False, "success": False,
"status": "failure", "status": "failure",
"response_text": "", "response_text": "",
"error_message": error_msg "error_message": error_msg,
} }
# 获取 x-api-request-payload # 获取 x-api-request-payload
@@ -158,7 +156,7 @@ def perform_check_in(task, user_token: str) -> Dict[str, Any]:
"success": False, "success": False,
"status": "failure", "status": "failure",
"response_text": "", "response_text": "",
"error_message": error_msg "error_message": error_msg,
} }
try: try:
@@ -175,19 +173,19 @@ def perform_check_in(task, user_token: str) -> Dict[str, Any]:
"success": False, "success": False,
"status": "failure", "status": "failure",
"response_text": "", "response_text": "",
"error_message": error_msg "error_message": error_msg,
} }
headers = { headers = {
'User-Agent': "Mozilla%2f5.0+(Linux%3b+Android+16%3b+wv)+AppleWebKit%2f537.36+(KHTML%2c+like+Gecko)+Chrome%2f142.0.0.0+Safari%2f537.36+QQ%2f9.2.30.31620+QQ%2fMiniApp", "User-Agent": "Mozilla%2f5.0+(Linux%3b+Android+16%3b+wv)+AppleWebKit%2f537.36+(KHTML%2c+like+Gecko)+Chrome%2f142.0.0.0+Safari%2f537.36+QQ%2f9.2.30.31620+QQ%2fMiniApp",
'Accept-Encoding': "gzip", "Accept-Encoding": "gzip",
'Content-Type': "application/json", "Content-Type": "application/json",
'authorization': f"Bearer {user_token}", "authorization": f"Bearer {user_token}",
'x-api-request-referer': "https://appservice.qq.com/1110276759", "x-api-request-referer": "https://appservice.qq.com/1110276759",
'x-api-request-payload': payload_signature, "x-api-request-payload": payload_signature,
'referer': "https://appservice.qq.com/1110276759/8.10.1.7/page-frame.html", "referer": "https://appservice.qq.com/1110276759/8.10.1.7/page-frame.html",
'platform': "qq", "platform": "qq",
'x-api-request-mode': "cors", "x-api-request-mode": "cors",
} }
url = "https://api.jielong.com/api/CheckIn/EditRecord" url = "https://api.jielong.com/api/CheckIn/EditRecord"
@@ -203,7 +201,9 @@ def perform_check_in(task, user_token: str) -> Dict[str, Any]:
response.raise_for_status() response.raise_for_status()
response_text = response.text response_text = response.text
logger.info(f"✉️ 任务 ID: {task.id} (Signature: {signature}) 打卡请求完成!响应: {response_text}") logger.info(
f"✉️ 任务 ID: {task.id} (Signature: {signature}) 打卡请求完成!响应: {response_text}"
)
# 判断响应内容(参考 V1 实现逻辑) # 判断响应内容(参考 V1 实现逻辑)
# 情况1: 明确包含"打卡成功" → 成功 # 情况1: 明确包含"打卡成功" → 成功
@@ -213,9 +213,10 @@ def perform_check_in(task, user_token: str) -> Dict[str, Any]:
if task.user and task.user.email: if task.user and task.user.email:
try: try:
from backend.services.email_service import EmailService from backend.services.email_service import EmailService
task_info = { task_info = {
'thread_id': payload.get('ThreadId', '未知'), "thread_id": payload.get("ThreadId", "未知"),
'name': getattr(task, 'name', '打卡任务') "name": getattr(task, "name", "打卡任务"),
} }
EmailService.notify_check_in_result(task.user, task_info, True, "打卡成功") EmailService.notify_check_in_result(task.user, task_info, True, "打卡成功")
except Exception as e: except Exception as e:
@@ -225,38 +226,45 @@ def perform_check_in(task, user_token: str) -> Dict[str, Any]:
"success": True, "success": True,
"status": "success", "status": "success",
"response_text": response_text, "response_text": response_text,
"error_message": "" "error_message": "",
} }
# 情况2: 已经提交过了(重复提交)→ 视为成功,但不发送邮件 # 情况2: 已经提交过了(重复提交)→ 视为成功,但不发送邮件
# 匹配 "已被提交" 或 "已经打卡" # 匹配 "已被提交" 或 "已经打卡"
elif ("已被提交" in response_text or "已经打卡" in response_text or elif (
"重复提交" in response_text): "已被提交" in response_text
or "已经打卡" in response_text
or "重复提交" in response_text
):
logger.info(f"✅ 检测到'已被提交',本次打卡已完成(重复提交,不发送邮件)") logger.info(f"✅ 检测到'已被提交',本次打卡已完成(重复提交,不发送邮件)")
return { return {
"success": True, "success": True,
"status": "success", "status": "success",
"response_text": response_text, "response_text": response_text,
"error_message": "" "error_message": "",
} }
# 情况3: 不在打卡时间范围 → 标记为时间范围外 # 情况3: 不在打卡时间范围 → 标记为时间范围外
# 匹配 Data 或 Description 中的内容 # 匹配 Data 或 Description 中的内容
elif ("不在打卡时间范围" in response_text or elif "不在打卡时间范围" in response_text or "不在打卡时间" in response_text:
"不在打卡时间" in response_text):
logger.warning(f"⏰ 检测到'不在打卡时间范围',打卡时间不符") logger.warning(f"⏰ 检测到'不在打卡时间范围',打卡时间不符")
return { return {
"success": False, "success": False,
"status": "out_of_time", "status": "out_of_time",
"response_text": response_text, "response_text": response_text,
"error_message": "不在打卡时间范围内" "error_message": "不在打卡时间范围内",
} }
# 情况4: Token 失效的特征标识 → 失败 # 情况4: Token 失效的特征标识 → 失败
# 扩展检测条件:检测多种 Token 失效的响应特征 # 扩展检测条件:检测多种 Token 失效的响应特征
elif ("登录" in response_text or "授权" in response_text or elif (
"登录" in response_text or "token" in response_text.lower() or "登录" in response_text
"Unauthorized" in response_text or response.status_code == 401): or "授权" in response_text
or "未登录" in response_text
or "token" in response_text.lower()
or "Unauthorized" in response_text
or response.status_code == 401
):
logger.warning(f"⚠️ 检测到Token失效特征,Token 可能已失效") logger.warning(f"⚠️ 检测到Token失效特征,Token 可能已失效")
# 发送打卡失败邮件通知(邮件内容已包含Token失效提醒和刷新指引) # 发送打卡失败邮件通知(邮件内容已包含Token失效提醒和刷新指引)
if task.user and task.user.email: if task.user and task.user.email:
@@ -268,7 +276,9 @@ def perform_check_in(task, user_token: str) -> Dict[str, Any]:
task_info = build_task_info(task) task_info = build_task_info(task)
# 只发送打卡失败通知(内容已说明Token失效) # 只发送打卡失败通知(内容已说明Token失效)
EmailService.notify_check_in_result(task.user, task_info, False, "Token 已失效,需要重新授权") EmailService.notify_check_in_result(
task.user, task_info, False, "Token 已失效,需要重新授权"
)
except Exception as e: except Exception as e:
logger.error(f"发送打卡失败邮件失败: {e}") logger.error(f"发送打卡失败邮件失败: {e}")
@@ -276,7 +286,7 @@ def perform_check_in(task, user_token: str) -> Dict[str, Any]:
"success": False, "success": False,
"status": "token_expired", # 特殊状态,用于标识 Token 过期 "status": "token_expired", # 特殊状态,用于标识 Token 过期
"response_text": response_text, "response_text": response_text,
"error_message": "Token 已失效,需要重新授权" "error_message": "Token 已失效,需要重新授权",
} }
# 情况5: 其他响应 → 需要人工确认(标记为异常) # 情况5: 其他响应 → 需要人工确认(标记为异常)
@@ -287,7 +297,7 @@ def perform_check_in(task, user_token: str) -> Dict[str, Any]:
"success": False, "success": False,
"status": "unknown", "status": "unknown",
"response_text": response_text, "response_text": response_text,
"error_message": "未识别的响应,请人工确认" "error_message": "未识别的响应,请人工确认",
} }
except requests.exceptions.RequestException as e: except requests.exceptions.RequestException as e:
@@ -303,15 +313,10 @@ def perform_check_in(task, user_token: str) -> Dict[str, Any]:
"success": False, "success": False,
"status": "failure", "status": "failure",
"response_text": response_text, "response_text": response_text,
"error_message": str(e) "error_message": str(e),
} }
except Exception as e: except Exception as e:
error_msg = f"为任务 ID: {task.id} (Signature: {signature}) 打卡时发生未知错误: {e}" error_msg = f"为任务 ID: {task.id} (Signature: {signature}) 打卡时发生未知错误: {e}"
logger.error(error_msg) logger.error(error_msg)
return { return {"success": False, "status": "failure", "response_text": "", "error_message": str(e)}
"success": False,
"status": "failure",
"response_text": "",
"error_message": str(e)
}
+19 -25
View File
@@ -32,7 +32,9 @@ class EmailNotifier:
""" """
# 检查必要的邮件配置是否存在 # 检查必要的邮件配置是否存在
if not settings.SMTP_SERVER or not settings.SMTP_SENDER_EMAIL: if not settings.SMTP_SERVER or not settings.SMTP_SENDER_EMAIL:
logger.debug("邮件配置未完成(SMTP_SERVER 或 SMTP_SENDER_EMAIL 为空),邮件发送功能已禁用") logger.debug(
"邮件配置未完成(SMTP_SERVER 或 SMTP_SENDER_EMAIL 为空),邮件发送功能已禁用"
)
return None return None
if not settings.SMTP_PORT: if not settings.SMTP_PORT:
@@ -41,19 +43,16 @@ class EmailNotifier:
# 返回配置字典 # 返回配置字典
return { return {
'smtp_server': settings.SMTP_SERVER, "smtp_server": settings.SMTP_SERVER,
'smtp_port': settings.SMTP_PORT, "smtp_port": settings.SMTP_PORT,
'sender_email': settings.SMTP_SENDER_EMAIL, "sender_email": settings.SMTP_SENDER_EMAIL,
'sender_password': settings.SMTP_SENDER_PASSWORD, "sender_password": settings.SMTP_SENDER_PASSWORD,
'use_ssl': settings.SMTP_USE_SSL "use_ssl": settings.SMTP_USE_SSL,
} }
@staticmethod @staticmethod
def send_email( def send_email(
to_emails: List[str], to_emails: List[str], subject: str, html_content: str, from_email: Optional[str] = None
subject: str,
html_content: str,
from_email: Optional[str] = None
) -> bool: ) -> bool:
""" """
发送邮件(底层方法) 发送邮件(底层方法)
@@ -74,30 +73,26 @@ class EmailNotifier:
try: try:
# 创建邮件 # 创建邮件
msg = MIMEMultipart('alternative') msg = MIMEMultipart("alternative")
msg['From'] = from_email or email_config['sender_email'] msg["From"] = from_email or email_config["sender_email"]
msg['To'] = ', '.join(to_emails) msg["To"] = ", ".join(to_emails)
msg['Subject'] = subject msg["Subject"] = subject
# 添加 HTML 正文 # 添加 HTML 正文
html_part = MIMEText(html_content, 'html', 'utf-8') html_part = MIMEText(html_content, "html", "utf-8")
msg.attach(html_part) msg.attach(html_part)
# 连接 SMTP 服务器并发送 # 连接 SMTP 服务器并发送
if email_config.get('use_ssl', True): if email_config.get("use_ssl", True):
server = smtplib.SMTP_SSL( server = smtplib.SMTP_SSL(
email_config['smtp_server'], email_config["smtp_server"], int(email_config["smtp_port"])
int(email_config['smtp_port'])
) )
else: else:
server = smtplib.SMTP( server = smtplib.SMTP(email_config["smtp_server"], int(email_config["smtp_port"]))
email_config['smtp_server'],
int(email_config['smtp_port'])
)
server.starttls() server.starttls()
server.login(email_config['sender_email'], email_config['sender_password']) server.login(email_config["sender_email"], email_config["sender_password"])
server.sendmail(msg['From'], to_emails, msg.as_string()) server.sendmail(msg["From"], to_emails, msg.as_string())
server.quit() server.quit()
logger.info(f"邮件发送成功: {subject} -> {', '.join(to_emails)}") logger.info(f"邮件发送成功: {subject} -> {', '.join(to_emails)}")
@@ -116,4 +111,3 @@ class EmailNotifier:
邮件功能是否可用 邮件功能是否可用
""" """
return EmailNotifier.get_email_config() is not None return EmailNotifier.get_email_config() is not None
+68 -47
View File
@@ -27,11 +27,10 @@ def get_chrome_config():
"""获取 Chrome 配置(从 settings 读取)""" """获取 Chrome 配置(从 settings 读取)"""
return { return {
"chrome_binary": settings.CHROME_BINARY_PATH, "chrome_binary": settings.CHROME_BINARY_PATH,
"chromedriver": settings.CHROMEDRIVER_PATH "chromedriver": settings.CHROMEDRIVER_PATH,
} }
def update_session_file(session_id: str, data: dict) -> None: def update_session_file(session_id: str, data: dict) -> None:
"""线程安全地写入会话文件""" """线程安全地写入会话文件"""
filepath = settings.SESSION_DIR / f"{session_id}.json" filepath = settings.SESSION_DIR / f"{session_id}.json"
@@ -39,7 +38,7 @@ def update_session_file(session_id: str, data: dict) -> None:
try: try:
with FileLock(lock_path, timeout=5): with FileLock(lock_path, timeout=5):
with open(filepath, 'w', encoding='utf-8') as f: with open(filepath, "w", encoding="utf-8") as f:
json.dump(data, f, ensure_ascii=False, indent=2) json.dump(data, f, ensure_ascii=False, indent=2)
except Exception as e: except Exception as e:
logger.error(f"写入会话文件 {filepath} 失败: {e}") logger.error(f"写入会话文件 {filepath} 失败: {e}")
@@ -55,13 +54,14 @@ def get_session_status(session_id: str) -> str:
try: try:
with FileLock(lock_path, timeout=5): with FileLock(lock_path, timeout=5):
with open(filepath, 'r', encoding='utf-8') as f: with open(filepath, "r", encoding="utf-8") as f:
content = f.read() content = f.read()
if not content: if not content:
return None return None
from backend.utils.json_helpers import safe_parse_json from backend.utils.json_helpers import safe_parse_json
data = safe_parse_json(content, {}) data = safe_parse_json(content, {})
return data.get('status') return data.get("status")
except IOError as e: except IOError as e:
logger.error(f"读取会话文件 {filepath} 失败: {e}") logger.error(f"读取会话文件 {filepath} 失败: {e}")
return None return None
@@ -77,11 +77,12 @@ def get_session_data(session_id: str) -> dict:
try: try:
with FileLock(lock_path, timeout=5): with FileLock(lock_path, timeout=5):
with open(filepath, 'r', encoding='utf-8') as f: with open(filepath, "r", encoding="utf-8") as f:
content = f.read() content = f.read()
if not content: if not content:
return None return None
from backend.utils.json_helpers import safe_parse_json from backend.utils.json_helpers import safe_parse_json
return safe_parse_json(content, {}) return safe_parse_json(content, {})
except IOError as e: except IOError as e:
logger.error(f"读取会话文件 {filepath} 失败: {e}") logger.error(f"读取会话文件 {filepath} 失败: {e}")
@@ -110,23 +111,23 @@ def cancel_session(session_id: str) -> bool:
# 读取当前会话数据 # 读取当前会话数据
from backend.utils.json_helpers import safe_parse_json from backend.utils.json_helpers import safe_parse_json
with open(filepath, 'r', encoding='utf-8') as f: with open(filepath, "r", encoding="utf-8") as f:
content = f.read() content = f.read()
if not content: if not content:
return False return False
data = safe_parse_json(content, {}) data = safe_parse_json(content, {})
# 如果已经成功,不允许取消 # 如果已经成功,不允许取消
if data.get('status') == 'success': if data.get("status") == "success":
logger.info(f"会话 {session_id} 已成功,无法取消") logger.info(f"会话 {session_id} 已成功,无法取消")
return False return False
# 标记为已取消 # 标记为已取消
data['status'] = 'cancelled' data["status"] = "cancelled"
data['message'] = '用户取消登录' data["message"] = "用户取消登录"
# 写回文件 # 写回文件
with open(filepath, 'w', encoding='utf-8') as f: with open(filepath, "w", encoding="utf-8") as f:
json.dump(data, f, ensure_ascii=False, indent=2) json.dump(data, f, ensure_ascii=False, indent=2)
logger.info(f"✅ 会话 {session_id} 已取消") logger.info(f"✅ 会话 {session_id} 已取消")
@@ -137,7 +138,9 @@ def cancel_session(session_id: str) -> bool:
return False return False
def get_token_headless(session_id: str, jwt_sub: str = None, alias: str = None, client_ip: str = "") -> None: def get_token_headless(
session_id: str, jwt_sub: str = None, alias: str = None, client_ip: str = ""
) -> None:
""" """
使用 Selenium 获取 QQ 扫码登录的 Token 使用 Selenium 获取 QQ 扫码登录的 Token
@@ -171,12 +174,12 @@ def get_token_headless(session_id: str, jwt_sub: str = None, alias: str = None,
# Headless 模式配置 # Headless 模式配置
user_agent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/140.0.0.0 Safari/537.36" user_agent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/140.0.0.0 Safari/537.36"
chrome_options.add_argument(f'user-agent={user_agent}') chrome_options.add_argument(f"user-agent={user_agent}")
chrome_options.add_argument("--headless") chrome_options.add_argument("--headless")
chrome_options.add_argument("--no-sandbox") chrome_options.add_argument("--no-sandbox")
chrome_options.add_argument("--disable-dev-shm-usage") chrome_options.add_argument("--disable-dev-shm-usage")
chrome_options.add_argument("--window-size=1920,1080") chrome_options.add_argument("--window-size=1920,1080")
chrome_options.add_argument('--ignore-certificate-errors') chrome_options.add_argument("--ignore-certificate-errors")
chrome_options.add_experimental_option("excludeSwitches", ["enable-automation"]) chrome_options.add_experimental_option("excludeSwitches", ["enable-automation"])
# 启动浏览器 # 启动浏览器
@@ -203,7 +206,9 @@ def get_token_headless(session_id: str, jwt_sub: str = None, alias: str = None,
current_step = "查找并点击切换按钮" current_step = "查找并点击切换按钮"
toggle_button_selector = "div.login-wrap .toggle" toggle_button_selector = "div.login-wrap .toggle"
logger.info(f"Selenium ({session_id}): {current_step} ({toggle_button_selector})...") logger.info(f"Selenium ({session_id}): {current_step} ({toggle_button_selector})...")
toggle_button = wait.until(EC.element_to_be_clickable((By.CSS_SELECTOR, toggle_button_selector))) toggle_button = wait.until(
EC.element_to_be_clickable((By.CSS_SELECTOR, toggle_button_selector))
)
toggle_button.click() toggle_button.click()
# --- 步骤 2: 勾选同意服务协议 --- # --- 步骤 2: 勾选同意服务协议 ---
@@ -219,27 +224,35 @@ def get_token_headless(session_id: str, jwt_sub: str = None, alias: str = None,
current_step = "点击立即登录按钮" current_step = "点击立即登录按钮"
login_button_selector = "button.css-1wli0ry.ant-btn.ant-btn-default.login-btn" login_button_selector = "button.css-1wli0ry.ant-btn.ant-btn-default.login-btn"
logger.info(f"Selenium ({session_id}): {current_step} ({login_button_selector})...") logger.info(f"Selenium ({session_id}): {current_step} ({login_button_selector})...")
login_button = wait.until(EC.element_to_be_clickable((By.CSS_SELECTOR, login_button_selector))) login_button = wait.until(
EC.element_to_be_clickable((By.CSS_SELECTOR, login_button_selector))
)
login_button.click() login_button.click()
# --- 步骤 4: 等待二维码加载 --- # --- 步骤 4: 等待二维码加载 ---
import time import time
time.sleep(3) # 等待几秒让二维码刷新出来 time.sleep(3) # 等待几秒让二维码刷新出来
current_step = "等待QQ二维码图片加载" current_step = "等待QQ二维码图片加载"
qq_qr_image_selector = "#login_container img" qq_qr_image_selector = "#login_container img"
logger.info(f"Selenium ({session_id}): {current_step} ({qq_qr_image_selector})...") logger.info(f"Selenium ({session_id}): {current_step} ({qq_qr_image_selector})...")
qr_element = wait.until(EC.visibility_of_element_located((By.CSS_SELECTOR, qq_qr_image_selector))) qr_element = wait.until(
EC.visibility_of_element_located((By.CSS_SELECTOR, qq_qr_image_selector))
)
logger.info(f"Selenium ({session_id}): 成功找到QQ二维码元素,正在截图...") logger.info(f"Selenium ({session_id}): 成功找到QQ二维码元素,正在截图...")
qr_base64 = qr_element.screenshot_as_base64 qr_base64 = qr_element.screenshot_as_base64
update_session_file(session_id, { update_session_file(
'status': 'waiting_scan', session_id,
'qr_image_data': qr_base64, {
'jwt_sub': jwt_sub, "status": "waiting_scan",
'alias': alias, # 新增:保存 alias "qr_image_data": qr_base64,
'client_ip': client_ip # 新增:保存 IP "jwt_sub": jwt_sub,
}) "alias": alias, # 新增:保存 alias
"client_ip": client_ip, # 新增:保存 IP
},
)
current_step = "等待用户扫描登录 (Cookie 'token' 出现)" current_step = "等待用户扫描登录 (Cookie 'token' 出现)"
cookie_name_to_find = "token" cookie_name_to_find = "token"
@@ -248,10 +261,11 @@ def get_token_headless(session_id: str, jwt_sub: str = None, alias: str = None,
# 自定义等待逻辑:每秒检查cookie和session状态 # 自定义等待逻辑:每秒检查cookie和session状态
max_wait_seconds = 120 max_wait_seconds = 120
import time import time
for i in range(max_wait_seconds): for i in range(max_wait_seconds):
# 检查session是否被取消 # 检查session是否被取消
status = get_session_status(session_id) status = get_session_status(session_id)
if status == 'cancelled': if status == "cancelled":
logger.info(f"Selenium ({session_id}): 用户取消了登录,终止会话") logger.info(f"Selenium ({session_id}): 用户取消了登录,终止会话")
raise Exception("用户取消登录") raise Exception("用户取消登录")
@@ -268,22 +282,28 @@ def get_token_headless(session_id: str, jwt_sub: str = None, alias: str = None,
cookie = driver.get_cookie(cookie_name_to_find) cookie = driver.get_cookie(cookie_name_to_find)
if cookie: if cookie:
logger.info(f"Selenium ({session_id}): 成功在Cookie中捕获到Token") logger.info(f"Selenium ({session_id}): 成功在Cookie中捕获到Token")
update_session_file(session_id, { update_session_file(
'status': 'success', session_id,
'token': cookie['value'], {
'alias': alias, # 保存 alias "status": "success",
'client_ip': client_ip # 保存 IP "token": cookie["value"],
}) "alias": alias, # 保存 alias
"client_ip": client_ip, # 保存 IP
},
)
else: else:
raise Exception("等待Cookie成功但获取失败") raise Exception("等待Cookie成功但获取失败")
except TimeoutException: except TimeoutException:
if get_session_status(session_id) == 'success': if get_session_status(session_id) == "success":
logger.warning(f"Selenium ({session_id}): 一个并发线程超时,但会话已成功,将忽略此超时。") logger.warning(
f"Selenium ({session_id}): 一个并发线程超时,但会话已成功,将忽略此超时。"
)
else: else:
# 释放预占的用户名 # 释放预占的用户名
if alias: if alias:
from backend.services.registration_manager import registration_manager from backend.services.registration_manager import registration_manager
registration_manager.release_alias(alias, session_id) registration_manager.release_alias(alias, session_id)
logger.info(f"超时释放用户名预占: {alias}") logger.info(f"超时释放用户名预占: {alias}")
@@ -294,34 +314,35 @@ def get_token_headless(session_id: str, jwt_sub: str = None, alias: str = None,
if driver: if driver:
try: try:
driver.save_screenshot(DEBUG_SCREENSHOT_PATH) driver.save_screenshot(DEBUG_SCREENSHOT_PATH)
with open(DEBUG_PAGE_SOURCE_PATH, 'w', encoding='utf-8') as f: with open(DEBUG_PAGE_SOURCE_PATH, "w", encoding="utf-8") as f:
f.write(driver.page_source) f.write(driver.page_source)
logger.error(f"Selenium ({session_id}): 调试截图和源码已保存。当前URL: {driver.current_url}") logger.error(
f"Selenium ({session_id}): 调试截图和源码已保存。当前URL: {driver.current_url}"
)
except Exception as debug_error: except Exception as debug_error:
logger.error(f"Selenium ({session_id}): 保存调试信息失败: {debug_error}") logger.error(f"Selenium ({session_id}): 保存调试信息失败: {debug_error}")
update_session_file(session_id, { update_session_file(
'status': 'error', session_id, {"status": "error", "message": error_message, "jwt_sub": jwt_sub}
'message': error_message, )
'jwt_sub': jwt_sub
})
except Exception as e: except Exception as e:
if get_session_status(session_id) == 'success': if get_session_status(session_id) == "success":
logger.warning(f"Selenium ({session_id}): 一个并发线程出错 ({e}),但会话已成功,将忽略此错误。") logger.warning(
f"Selenium ({session_id}): 一个并发线程出错 ({e}),但会话已成功,将忽略此错误。"
)
else: else:
# 释放预占的用户名 # 释放预占的用户名
if alias: if alias:
from backend.services.registration_manager import registration_manager from backend.services.registration_manager import registration_manager
registration_manager.release_alias(alias, session_id) registration_manager.release_alias(alias, session_id)
logger.info(f"异常释放用户名预占: {alias}") logger.info(f"异常释放用户名预占: {alias}")
logger.error(f"Selenium ({session_id}): 发生未知错误: {e}", exc_info=True) logger.error(f"Selenium ({session_id}): 发生未知错误: {e}", exc_info=True)
update_session_file(session_id, { update_session_file(
'status': 'error', session_id, {"status": "error", "message": str(e), "jwt_sub": jwt_sub}
'message': str(e), )
'jwt_sub': jwt_sub
})
finally: finally:
if driver: if driver:
+18 -4
View File
@@ -173,7 +173,15 @@ def start_frontend_daemon(args: argparse.Namespace) -> int:
LOGS_DIR.mkdir(parents=True, exist_ok=True) LOGS_DIR.mkdir(parents=True, exist_ok=True)
log_file = FRONTEND_LOG.open("a", encoding="utf-8") log_file = FRONTEND_LOG.open("a", encoding="utf-8")
cmd = [get_python(), str(REPO_ROOT / "main.py"), "frontend", "--host", args.host, "--port", str(args.port)] cmd = [
get_python(),
str(REPO_ROOT / "main.py"),
"frontend",
"--host",
args.host,
"--port",
str(args.port),
]
proc = subprocess.Popen( proc = subprocess.Popen(
cmd, cmd,
cwd=REPO_ROOT, cwd=REPO_ROOT,
@@ -263,13 +271,17 @@ def build_parser() -> argparse.ArgumentParser:
frontend.add_argument("--port", type=int, default=FRONTEND_PORT) frontend.add_argument("--port", type=int, default=FRONTEND_PORT)
frontend.set_defaults(func=run_frontend) frontend.set_defaults(func=run_frontend)
frontend_daemon = sub.add_parser("frontend-daemon", help="start frontend dev server in the background") frontend_daemon = sub.add_parser(
"frontend-daemon", help="start frontend dev server in the background"
)
frontend_daemon.add_argument("--host", default="0.0.0.0") frontend_daemon.add_argument("--host", default="0.0.0.0")
frontend_daemon.add_argument("--port", type=int, default=FRONTEND_PORT) frontend_daemon.add_argument("--port", type=int, default=FRONTEND_PORT)
frontend_daemon.set_defaults(func=start_frontend_daemon) frontend_daemon.set_defaults(func=start_frontend_daemon)
frontend_build = sub.add_parser("frontend-build", help="build current frontend") frontend_build = sub.add_parser("frontend-build", help="build current frontend")
frontend_build.add_argument("--install", action="store_true", help="run npm install if node_modules is missing") frontend_build.add_argument(
"--install", action="store_true", help="run npm install if node_modules is missing"
)
frontend_build.set_defaults(func=build_frontend) frontend_build.set_defaults(func=build_frontend)
deploy = sub.add_parser("frontend-deploy", help="show frontend deployment output path") deploy = sub.add_parser("frontend-deploy", help="show frontend deployment output path")
@@ -279,7 +291,9 @@ def build_parser() -> argparse.ArgumentParser:
service_status.set_defaults(func=status) service_status.set_defaults(func=status)
service_stop = sub.add_parser("stop", help="stop managed daemon processes") service_stop = sub.add_parser("stop", help="stop managed daemon processes")
service_stop.add_argument("target", choices=["backend", "frontend", "all"], nargs="?", default="all") service_stop.add_argument(
"target", choices=["backend", "frontend", "all"], nargs="?", default="all"
)
service_stop.set_defaults(func=stop) service_stop.set_defaults(func=stop)
return parser return parser