diff --git a/apps/backend/api/admin.py b/apps/backend/api/admin.py index 72d6912..c291d8f 100644 --- a/apps/backend/api/admin.py +++ b/apps/backend/api/admin.py @@ -11,9 +11,11 @@ from backend.services.check_in_service import CheckInService from backend.services.admin_service import AdminService from backend.dependencies import get_current_admin_user from backend.config import settings +from backend.exceptions import BaseAPIException logger = logging.getLogger(__name__) router = APIRouter() +EXPECTED_API_EXCEPTIONS = (BaseAPIException, HTTPException) class BatchToggleTasksRequest(BaseModel): @@ -43,13 +45,21 @@ async def batch_toggle_tasks( task.is_active = request.is_active count += 1 + from backend.services.scheduler_service import sync_scheduled_task + db.commit() + for task_id in request.task_ids: + task = db.query(CheckInTask).filter(CheckInTask.id == task_id).first() + if task: + sync_scheduled_task(task) return { "success": True, "message": f"已{'启用' if request.is_active else '禁用'} {count} 个任务", "count": count, } + except EXPECTED_API_EXCEPTIONS: + raise except Exception as e: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"批量操作失败: {str(e)}" @@ -72,6 +82,8 @@ async def batch_check_in( try: result = CheckInService.batch_check_in_tasks(request.task_ids, db) return result + except EXPECTED_API_EXCEPTIONS: + raise except Exception as e: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"批量打卡失败: {str(e)}" @@ -235,6 +247,8 @@ async def get_system_stats( "tokens": {"expiring_soon": expiring_users}, } + except EXPECTED_API_EXCEPTIONS: + raise except Exception as e: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"获取统计失败: {str(e)}" @@ -251,6 +265,8 @@ async def get_pending_users( try: users = AdminService.get_pending_users(db) return users + except EXPECTED_API_EXCEPTIONS: + raise except Exception as e: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, @@ -274,7 +290,7 @@ async def approve_user( raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=result["message"]) return result - except HTTPException: + except EXPECTED_API_EXCEPTIONS: raise except Exception as e: raise HTTPException( @@ -298,7 +314,7 @@ async def reject_user( raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=result["message"]) return result - except HTTPException: + except EXPECTED_API_EXCEPTIONS: raise except Exception as e: raise HTTPException( diff --git a/apps/backend/api/auth.py b/apps/backend/api/auth.py index 6923c39..c19592e 100644 --- a/apps/backend/api/auth.py +++ b/apps/backend/api/auth.py @@ -12,10 +12,11 @@ from backend.schemas.auth import ( AliasLoginResponse, ) from backend.services.auth_service import AuthService -from backend.exceptions import BusinessLogicError +from backend.exceptions import BaseAPIException, BusinessLogicError from backend.limiter import limiter router = APIRouter() +EXPECTED_API_EXCEPTIONS = (BaseAPIException, HTTPException) @router.post("/request_qrcode", response_model=dict, summary="请求 QQ 扫码二维码") @@ -68,6 +69,8 @@ async def request_qrcode( ) return result + except EXPECTED_API_EXCEPTIONS: + raise except Exception as e: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"创建扫码会话失败: {str(e)}" @@ -95,6 +98,8 @@ async def get_qrcode_status(session_id: str, db: Session = Depends(get_db)): try: result = AuthService.get_qrcode_status(session_id, db) return result + except EXPECTED_API_EXCEPTIONS: + raise except Exception as e: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"查询扫码状态失败: {str(e)}" @@ -113,6 +118,8 @@ async def cancel_qrcode_session(session_id: str): try: result = AuthService.cancel_qrcode_session(session_id) return result + except EXPECTED_API_EXCEPTIONS: + raise except Exception as e: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"取消会话失败: {str(e)}" @@ -136,6 +143,8 @@ async def verify_token(request: TokenVerifyRequest, db: Session = Depends(get_db try: result = AuthService.verify_token(request.authorization, db) return result + except EXPECTED_API_EXCEPTIONS: + raise except Exception as e: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"验证 Token 失败: {str(e)}" @@ -170,6 +179,8 @@ async def alias_login( try: result = AuthService.alias_login(login_data.alias, login_data.password, db) return result + except EXPECTED_API_EXCEPTIONS: + raise except Exception as e: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"别名登录失败: {str(e)}" diff --git a/apps/backend/api/check_in.py b/apps/backend/api/check_in.py index 3c71843..1af13b4 100644 --- a/apps/backend/api/check_in.py +++ b/apps/backend/api/check_in.py @@ -12,8 +12,10 @@ from backend.schemas.check_in import ( from backend.services.check_in_service import CheckInService from backend.services.task_service import TaskService from backend.dependencies import get_current_user, get_current_admin_user +from backend.exceptions import BaseAPIException router = APIRouter() +EXPECTED_API_EXCEPTIONS = (BaseAPIException, HTTPException) @router.post("/manual/{task_id}", summary="手动触发打卡(异步)") @@ -38,6 +40,8 @@ async def manual_check_in( try: result = CheckInService.start_async_check_in(task, "manual", db) return result + except EXPECTED_API_EXCEPTIONS: + raise except Exception as e: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"启动打卡任务失败: {str(e)}" @@ -111,6 +115,8 @@ async def get_task_check_in_records( task_id, db, skip, limit, status_filter, trigger_type ) return PaginatedResponse(records=records, total=total, skip=skip, limit=limit) + except EXPECTED_API_EXCEPTIONS: + raise except Exception as e: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"获取打卡记录失败: {str(e)}" @@ -145,6 +151,8 @@ async def get_my_check_in_records( current_user.id, db, skip, limit, status_filter, trigger_type ) return PaginatedResponse(records=records, total=total, skip=skip, limit=limit) + except EXPECTED_API_EXCEPTIONS: + raise except Exception as e: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"获取打卡记录失败: {str(e)}" @@ -181,6 +189,8 @@ async def get_all_check_in_records( CheckInService.enrich_record_with_user_task_info(record, db) for record in records ] return PaginatedResponse(records=enriched_records, total=total, skip=skip, limit=limit) + except EXPECTED_API_EXCEPTIONS: + raise except Exception as e: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"获取打卡记录失败: {str(e)}" @@ -213,6 +223,8 @@ async def get_check_in_records_count( total = query.count() return {"total": total} + except EXPECTED_API_EXCEPTIONS: + raise except Exception as e: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"获取统计失败: {str(e)}" diff --git a/apps/backend/api/tasks.py b/apps/backend/api/tasks.py index b1fb18b..3c6a56e 100644 --- a/apps/backend/api/tasks.py +++ b/apps/backend/api/tasks.py @@ -4,6 +4,7 @@ from typing import List from datetime import datetime, timedelta from pydantic import BaseModel, Field +from backend.exceptions import BaseAPIException from backend.models import get_db, User from backend.schemas.task import TaskUpdate, TaskResponse from backend.services.task_service import TaskService @@ -37,6 +38,8 @@ async def get_tasks( # 为每个任务添加额外信息 enriched_tasks = [TaskService.enrich_task_with_check_in_info(task, db) for task in tasks] return enriched_tasks + except (BaseAPIException, HTTPException): + raise except Exception as e: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"获取任务列表失败: {str(e)}" @@ -173,6 +176,8 @@ async def validate_cron_expression(request: CronValidateRequest): "next_times": next_times, "description": generate_cron_description(cron_expr), } + except HTTPException: + raise except Exception as e: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail=f"无效的 Crontab 表达式: {str(e)}" diff --git a/apps/backend/api/templates.py b/apps/backend/api/templates.py index 4c19933..8d91322 100644 --- a/apps/backend/api/templates.py +++ b/apps/backend/api/templates.py @@ -4,6 +4,7 @@ from sqlalchemy.orm import Session from backend.models import User from backend.dependencies import get_db, get_current_user, get_current_admin_user +from backend.exceptions import BaseAPIException from backend.schemas.template import ( TemplateCreate, TemplateUpdate, @@ -17,6 +18,9 @@ from backend.services.template_service import TemplateService router = APIRouter() +EXPECTED_API_EXCEPTIONS = (BaseAPIException, HTTPException) + + @router.get("/", response_model=List[TemplateResponse], summary="获取所有模板列表") async def get_all_templates( skip: int = Query(0, ge=0, description="跳过记录数"), @@ -35,6 +39,8 @@ async def get_all_templates( try: templates = TemplateService.get_all_templates(db, skip, limit, is_active) return templates + except EXPECTED_API_EXCEPTIONS: + raise except Exception as e: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"获取模板列表失败: {str(e)}" @@ -57,6 +63,8 @@ async def get_active_templates( try: templates = TemplateService.get_all_templates(db, skip, limit, is_active=True) return templates + except EXPECTED_API_EXCEPTIONS: + raise except Exception as e: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"获取模板列表失败: {str(e)}" @@ -115,6 +123,8 @@ async def preview_template( "preview_payload": preview_payload, "field_config": merged_config, } + except EXPECTED_API_EXCEPTIONS: + raise except Exception as e: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"生成预览失败: {str(e)}" diff --git a/apps/backend/api/users.py b/apps/backend/api/users.py index 6590f07..b74c7ff 100644 --- a/apps/backend/api/users.py +++ b/apps/backend/api/users.py @@ -14,9 +14,15 @@ from backend.schemas.task import TaskResponse from backend.services.user_service import UserService from backend.services.task_service import TaskService from backend.dependencies import get_current_user, get_current_admin_user -from backend.exceptions import ValidationError, AuthorizationError, ResourceNotFoundError +from backend.exceptions import ( + AuthorizationError, + BaseAPIException, + ResourceNotFoundError, + ValidationError, +) router = APIRouter() +EXPECTED_API_EXCEPTIONS = (BaseAPIException, HTTPException) @router.post( @@ -42,6 +48,8 @@ async def create_user( return user except ValueError as e: raise ValidationError(str(e)) + except EXPECTED_API_EXCEPTIONS: + raise except Exception as e: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"创建用户失败: {str(e)}" @@ -103,6 +111,8 @@ async def update_current_user_profile( return user except ValueError as e: raise ValidationError(str(e)) + except EXPECTED_API_EXCEPTIONS: + raise except Exception as e: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"更新个人信息失败: {str(e)}" @@ -144,6 +154,8 @@ async def get_current_user_tasks( try: tasks = TaskService.get_user_tasks(current_user.id, db, include_inactive) return tasks + except EXPECTED_API_EXCEPTIONS: + raise except Exception as e: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"获取任务列表失败: {str(e)}" @@ -170,6 +182,8 @@ async def get_all_users( try: users = UserService.get_all_users(db, skip, limit, search, role) return users + except EXPECTED_API_EXCEPTIONS: + raise except Exception as e: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"获取用户列表失败: {str(e)}" @@ -252,6 +266,8 @@ async def update_user( return user except ValueError as e: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) + except EXPECTED_API_EXCEPTIONS: + raise except Exception as e: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"更新用户失败: {str(e)}" @@ -272,6 +288,8 @@ async def delete_user( return None except ValueError as e: raise ResourceNotFoundError(str(e)) + except EXPECTED_API_EXCEPTIONS: + raise except Exception as e: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"删除用户失败: {str(e)}" diff --git a/apps/backend/models/check_in_task.py b/apps/backend/models/check_in_task.py index 1532dd6..4af3d9a 100644 --- a/apps/backend/models/check_in_task.py +++ b/apps/backend/models/check_in_task.py @@ -1,7 +1,7 @@ from datetime import datetime from typing import TYPE_CHECKING -from sqlalchemy import Boolean, DateTime, ForeignKey, Index, String, Text +from sqlalchemy import Boolean, DateTime, ForeignKey, Index, String, Text, UniqueConstraint from sqlalchemy.orm import Mapped, mapped_column, relationship from sqlalchemy.sql import func @@ -23,6 +23,12 @@ class CheckInTask(Base): index=True, comment="用户 ID", ) + thread_id: Mapped[str | None] = mapped_column( + String(100), + index=True, + nullable=True, + comment="接龙项目 ID", + ) payload_config: Mapped[str] = mapped_column( Text, default="{}", @@ -57,6 +63,7 @@ class CheckInTask(Base): # 添加索引:加速查询 __table_args__ = ( Index("ix_task_user_active", "user_id", "is_active"), + UniqueConstraint("user_id", "thread_id", name="uq_task_user_thread_id"), Index("ix_task_cron", "cron_expression"), # 加速查询启用了定时打卡的任务 ) diff --git a/apps/backend/models/database.py b/apps/backend/models/database.py index 14dab14..ff1e36a 100644 --- a/apps/backend/models/database.py +++ b/apps/backend/models/database.py @@ -1,6 +1,6 @@ from datetime import datetime, timezone -from sqlalchemy import create_engine, event +from sqlalchemy import DateTime, create_engine, event, inspect from sqlalchemy.orm import DeclarativeBase, sessionmaker from backend.config import settings @@ -24,21 +24,23 @@ class Base(DeclarativeBase): @event.listens_for(Base, "load", propagate=True) def receive_load(target, context): """在从数据库加载对象后,将所有 datetime 字段转换为 timezone-aware (UTC)""" - for attr_name in dir(target): - # 跳过私有属性和方法 - if attr_name.startswith("_"): + mapper = inspect(target).mapper + for attr in mapper.column_attrs: + column = attr.columns[0] + if not isinstance(column.type, DateTime): continue try: - attr_value = getattr(target, attr_name) - - # 如果是 naive datetime,添加 UTC timezone - if isinstance(attr_value, datetime) and attr_value.tzinfo is None: - setattr(target, attr_name, attr_value.replace(tzinfo=timezone.utc)) + attr_value = getattr(target, attr.key) except (AttributeError, TypeError): - # 某些属性可能无法访问或设置,跳过 continue + if isinstance(attr_value, datetime) and attr_value.tzinfo is None: + try: + setattr(target, attr.key, attr_value.replace(tzinfo=timezone.utc)) + except (AttributeError, TypeError): + continue + def get_db(): """依赖注入:获取数据库会话""" diff --git a/apps/backend/scripts/migrate_add_task_thread_id.py b/apps/backend/scripts/migrate_add_task_thread_id.py new file mode 100644 index 0000000..8763ce9 --- /dev/null +++ b/apps/backend/scripts/migrate_add_task_thread_id.py @@ -0,0 +1,113 @@ +""" +数据库迁移脚本:添加打卡任务 thread_id 字段并回填。 + +运行方式: + uv run python -m backend.scripts.migrate_add_task_thread_id +""" + +import json +import logging +import sys +from pathlib import Path + +APPS_DIR = Path(__file__).resolve().parents[2] +sys.path.insert(0, str(APPS_DIR)) + +from sqlalchemy import text + +from backend.models.database import engine + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +def _extract_thread_id(payload_config: str | None) -> str | None: + if not payload_config: + return None + try: + payload = json.loads(payload_config) + except json.JSONDecodeError: + return None + if not isinstance(payload, dict): + return None + thread_id = payload.get("ThreadId") + value = str(thread_id).strip() if thread_id is not None else "" + return value or None + + +def migrate() -> None: + """执行迁移。""" + logger.info("开始迁移:添加 check_in_tasks.thread_id 字段...") + + with engine.connect() as conn: + result = conn.execute(text("PRAGMA table_info(check_in_tasks)")) + columns = [row[1] for row in result] + + if "thread_id" not in columns: + logger.info("添加 thread_id 字段...") + conn.execute(text("ALTER TABLE check_in_tasks ADD COLUMN thread_id VARCHAR(100)")) + conn.commit() + logger.info("✓ thread_id 字段添加成功") + else: + logger.info("✓ thread_id 字段已存在,跳过") + + rows = conn.execute(text("SELECT id, payload_config FROM check_in_tasks")).fetchall() + invalid_ids: list[int] = [] + seen: dict[tuple[int, str], int] = {} + duplicate_ids: list[int] = [] + + full_rows = conn.execute( + text("SELECT id, user_id, payload_config FROM check_in_tasks") + ).fetchall() + for row in full_rows: + thread_id = _extract_thread_id(row.payload_config) + if not thread_id: + invalid_ids.append(row.id) + continue + key = (row.user_id, thread_id) + if key in seen: + duplicate_ids.append(row.id) + else: + seen[key] = row.id + + if invalid_ids or duplicate_ids: + messages = [] + if invalid_ids: + messages.append(f"payload_config 缺少有效 ThreadId 的任务: {invalid_ids}") + if duplicate_ids: + messages.append(f"同用户 ThreadId 重复的任务: {duplicate_ids}") + raise RuntimeError(";".join(messages)) + + for row in rows: + thread_id = _extract_thread_id(row.payload_config) + if thread_id: + conn.execute( + text("UPDATE check_in_tasks SET thread_id = :thread_id WHERE id = :id"), + {"thread_id": thread_id, "id": row.id}, + ) + conn.commit() + + indexes = conn.execute(text("PRAGMA index_list(check_in_tasks)")).fetchall() + index_names = [row[1] for row in indexes] + if "ix_task_user_thread_id_unique" not in index_names: + logger.info("添加用户级 thread_id 唯一索引...") + conn.execute( + text( + "CREATE UNIQUE INDEX ix_task_user_thread_id_unique " + "ON check_in_tasks (user_id, thread_id)" + ) + ) + conn.commit() + logger.info("✓ 用户级 thread_id 唯一索引添加成功") + else: + logger.info("✓ 用户级 thread_id 唯一索引已存在,跳过") + + logger.info("✅ 迁移完成!任务 thread_id 身份字段已启用") + + +if __name__ == "__main__": + try: + migrate() + except Exception as e: + logger.error(f"❌ 迁移失败: {e}") + sys.exit(1) diff --git a/apps/backend/services/check_in_service.py b/apps/backend/services/check_in_service.py index 4983e9d..ff39260 100644 --- a/apps/backend/services/check_in_service.py +++ b/apps/backend/services/check_in_service.py @@ -541,10 +541,10 @@ class CheckInService: ) task_name = task.name - # 从 payload_config 提取 ThreadId - from backend.utils.json_helpers import extract_thread_id + # 优先从显式任务身份读取 ThreadId,兼容旧数据时回退到 payload_config + from backend.utils.json_helpers import resolve_task_thread_id - thread_id = extract_thread_id(task.payload_config) # type: ignore + thread_id = resolve_task_thread_id(task) # 转换为字典并添加额外字段 record_dict = { diff --git a/apps/backend/services/scheduler_service.py b/apps/backend/services/scheduler_service.py index ec27697..e8a6c0c 100644 --- a/apps/backend/services/scheduler_service.py +++ b/apps/backend/services/scheduler_service.py @@ -20,6 +20,70 @@ scheduler = None scheduler_lock = None +def _task_job_id(task_id: int) -> str: + return f"task_{task_id}" + + +def _remove_task_job(job_id: str, scheduler_instance=None) -> bool: + active_scheduler = scheduler_instance or scheduler + if not active_scheduler: + return False + + existing_job = active_scheduler.get_job(job_id) + if not existing_job: + return False + + active_scheduler.remove_job(job_id) + logger.info(f"✅ 已移除调度任务: {job_id}") + return True + + +def remove_scheduled_task(task_id: int, scheduler_instance=None) -> bool: + """从调度器移除指定任务。""" + active_scheduler = scheduler_instance or scheduler + if not active_scheduler: + logger.warning(f"调度器未启动,无法移除任务 {task_id}") + return False + + return _remove_task_job(_task_job_id(task_id), active_scheduler) + + +def sync_scheduled_task(task: CheckInTask, scheduler_instance=None) -> bool: + """ + 根据任务当前状态同步调度器中的对应 job。 + + 返回 True 表示任务已成功安排到调度器,False 表示未安排或调度器不可用。 + """ + active_scheduler = scheduler_instance or scheduler + if not active_scheduler: + logger.warning(f"调度器未启动,无法同步任务 {task.id}") + return False + + job_id = _task_job_id(task.id) + _remove_task_job(job_id, active_scheduler) + + if not task.is_scheduled_enabled: + logger.info(f"任务 {task.id} 未启用或无 cron 表达式,已从调度器移除") + return False + + cron_str = str(task.cron_expression or "").strip() + if not cron_str or not croniter.is_valid(cron_str): + logger.warning(f"任务 {task.id} 的 cron 表达式无效: {cron_str}") + return False + + active_scheduler.add_job( + func=scheduled_check_in_task, + trigger=CronTrigger.from_crontab(cron_str), + id=job_id, + name=f"CheckIn-Task-{task.id}", + args=[task.id], + replace_existing=True, + ) + + logger.info(f"✅ 任务 {task.id} 已同步到调度器: {cron_str}") + return True + + def load_scheduled_tasks(db: Session, scheduler_instance): """ 从数据库加载所有启用的定时任务并添加到 APScheduler @@ -55,33 +119,10 @@ def load_scheduled_tasks(db: Session, scheduler_instance): for task in tasks: try: - # 验证 cron 表达式 - cron_str = str(task.cron_expression) if task.cron_expression else None - if not cron_str or not croniter.is_valid(cron_str): - logger.warning(f"跳过任务 {task.id}: 无效的 cron 表达式 '{task.cron_expression}'") + if sync_scheduled_task(task, scheduler_instance): + loaded_count += 1 + else: skipped_count += 1 - continue - - # 创建任务 ID - job_id = f"task_{task.id}" - - # 检查任务是否已存在 - if scheduler_instance.get_job(job_id): - logger.debug(f"任务 {task.id} 已存在,跳过") - continue - - # 添加任务到调度器 - scheduler_instance.add_job( - func=scheduled_check_in_task, - trigger=CronTrigger.from_crontab(cron_str), - id=job_id, - name=f"CheckIn-Task-{task.id}", - args=[task.id], - replace_existing=True, - ) - - logger.info(f"✅ 加载任务 {task.id}: {task.name} (Cron: {task.cron_expression})") - loaded_count += 1 except Exception as e: logger.error(f"❌ 加载任务 {task.id} 时出错: {str(e)}") diff --git a/apps/backend/services/task_service.py b/apps/backend/services/task_service.py index d181215..5b2bda7 100644 --- a/apps/backend/services/task_service.py +++ b/apps/backend/services/task_service.py @@ -3,6 +3,12 @@ from typing import List, Optional, Dict, Any from sqlalchemy.orm import Session from sqlalchemy import desc +from backend.exceptions import ( + InternalServerError, + ResourceConflictError, + ResourceNotFoundError, + ValidationError as APIValidationError, +) from backend.models import User, CheckInTask, CheckInRecord from backend.schemas.task import TaskCreate, TaskUpdate @@ -12,6 +18,68 @@ logger = logging.getLogger(__name__) class TaskService: """打卡任务服务""" + @staticmethod + def _normalize_thread_id(thread_id: Any) -> str: + value = str(thread_id).strip() if thread_id is not None else "" + if not value: + raise APIValidationError( + "payload_config 必须包含有效的 ThreadId 字段", + error_code="TASK_IDENTITY_INVALID", + ) + return value + + @staticmethod + def _extract_thread_id_from_payload(payload_config: str) -> str: + from backend.utils.json_helpers import safe_parse_payload + + payload = safe_parse_payload(payload_config) + return TaskService._normalize_thread_id(payload.get("ThreadId")) + + @staticmethod + def _resolve_thread_id_for_task(task: CheckInTask) -> Optional[str]: + thread_id = getattr(task, "thread_id", None) + if thread_id is not None: + value = str(thread_id).strip() + if value: + return value + + from backend.utils.json_helpers import extract_thread_id + + legacy_thread_id = extract_thread_id(task.payload_config) + if legacy_thread_id is None: + return None + value = str(legacy_thread_id).strip() + return value or None + + @staticmethod + def _ensure_unique_thread_id( + db: Session, user_id: int, thread_id: str, exclude_task_id: int | None = None + ) -> None: + query = db.query(CheckInTask.id).filter( + CheckInTask.user_id == user_id, CheckInTask.thread_id == thread_id + ) + if exclude_task_id is not None: + query = query.filter(CheckInTask.id != exclude_task_id) + + conflict = query.first() + if conflict: + raise ResourceConflictError( + f"该接龙中已存在任务。ThreadId: {thread_id}", + error_code="TASK_IDENTITY_CONFLICT", + ) + + @staticmethod + def _sync_scheduler_for_task(task: CheckInTask) -> None: + from backend.services.scheduler_service import sync_scheduled_task + + sync_scheduled_task(task) + + @staticmethod + def _remove_scheduler_for_task(task_id: int) -> None: + from backend.services.scheduler_service import remove_scheduled_task + + remove_scheduled_task(task_id) + @staticmethod def create_task(user_id: int, task_data: TaskCreate, db: Session) -> CheckInTask: """ @@ -25,34 +93,16 @@ class TaskService: Returns: 创建的任务对象 """ - import json - # 1. 检查用户是否存在 user = db.query(User).filter(User.id == user_id).first() if not user: - raise ValueError(f"用户 ID {user_id} 不存在") + raise ResourceNotFoundError(f"用户 ID {user_id} 不存在", error_code="USER_NOT_FOUND") # 2. 从 payload_config 中提取 ThreadId 用于唯一性校验 - from backend.utils.json_helpers import safe_parse_payload, extract_thread_id - - payload = safe_parse_payload(task_data.payload_config) - thread_id = payload.get("ThreadId") - if not thread_id: - raise ValueError("payload_config 中缺少 ThreadId") + thread_id = TaskService._extract_thread_id_from_payload(task_data.payload_config) # 3. 验证唯一性:同一用户在同一个接龙中不能有重复的任务 - existing_tasks = ( - db.query(CheckInTask.payload_config).filter(CheckInTask.user_id == user_id).all() - ) - - for (payload_config,) in existing_tasks: - existing_thread_id = extract_thread_id(payload_config) - # extract_thread_id 已处理异常,失败时返回 None - if existing_thread_id and existing_thread_id == thread_id: - logger.warning( - f"⚠️ 任务创建冲突 - User: {user.alias}({user_id}), ThreadId: {thread_id}" - ) - raise ValueError(f"该接龙中已存在任务。ThreadId: {thread_id}") + TaskService._ensure_unique_thread_id(db, user_id, thread_id) # 4. 记录日志 task_name = task_data.name or f"接龙任务 {thread_id}" @@ -61,9 +111,11 @@ class TaskService: # 5. 创建任务 task = CheckInTask( user_id=user_id, + thread_id=thread_id, payload_config=task_data.payload_config, name=task_data.name or task_name, is_active=task_data.is_active if task_data.is_active is not None else True, + cron_expression=task_data.cron_expression or "0 20 * * *", ) try: @@ -75,14 +127,19 @@ class TaskService: ) # 如果任务启用且包含 cron_expression,立即添加到调度器 - if task.is_scheduled_enabled: - TaskService._reload_scheduler_for_task(task, db) + TaskService._sync_scheduler_for_task(task) return task + except APIValidationError: + db.rollback() + raise + except ResourceConflictError: + db.rollback() + raise except Exception as e: db.rollback() logger.error(f"❌ 任务创建失败: {str(e)}") - raise ValueError(f"任务创建失败: {str(e)}") + raise InternalServerError(f"任务创建失败: {str(e)}") @staticmethod def get_task(task_id: int, db: Session) -> Optional[CheckInTask]: @@ -110,8 +167,6 @@ class TaskService: Returns: 包含额外信息的任务字典 """ - from backend.utils.json_helpers import extract_thread_id - # 获取最后一次打卡记录 last_record = ( db.query(CheckInRecord) @@ -120,8 +175,8 @@ class TaskService: .first() ) - # 从 payload_config 提取 ThreadId - thread_id = extract_thread_id(task.payload_config) # type: ignore + # 优先使用持久化的 ThreadId,兼容旧数据时回退到 payload_config + thread_id = TaskService._resolve_thread_id_for_task(task) # 转换为字典并添加额外字段 task_dict = { @@ -192,7 +247,7 @@ class TaskService: task = db.query(CheckInTask).filter(CheckInTask.id == task_id).first() if not task: - return None + raise ResourceNotFoundError("任务不存在", error_code="TASK_NOT_FOUND") # 更新字段 update_data = task_data.model_dump(exclude_unset=True) @@ -201,17 +256,34 @@ class TaskService: cron_changed = "cron_expression" in update_data active_changed = "is_active" in update_data + if "payload_config" in update_data: + new_thread_id = TaskService._extract_thread_id_from_payload( + update_data["payload_config"] + ) + TaskService._ensure_unique_thread_id( + db, task.user_id, new_thread_id, exclude_task_id=task.id + ) + task.thread_id = new_thread_id + for field, value in update_data.items(): setattr(task, field, value) - db.commit() - db.refresh(task) + try: + db.commit() + db.refresh(task) + except ResourceConflictError: + db.rollback() + raise + except Exception as e: + db.rollback() + logger.error(f"任务 {task_id} 更新失败: {str(e)}") + raise InternalServerError(f"任务更新失败: {str(e)}") logger.info(f"任务 {task_id} 已更新") # 如果 cron_expression 或 is_active 发生变化,重新加载调度器 if cron_changed or active_changed: - TaskService._reload_scheduler_for_task(task, db) + TaskService._sync_scheduler_for_task(task) return task @@ -230,15 +302,20 @@ class TaskService: task = db.query(CheckInTask).filter(CheckInTask.id == task_id).first() if not task: - return False + raise ResourceNotFoundError("任务不存在", error_code="TASK_NOT_FOUND") - db.delete(task) - db.commit() + try: + db.delete(task) + db.commit() + except Exception as e: + db.rollback() + logger.error(f"任务 {task_id} 删除失败: {str(e)}") + raise InternalServerError(f"任务删除失败: {str(e)}") logger.info(f"任务 {task_id} 已删除") # 从调度器中移除该任务 - TaskService._remove_task_from_scheduler(task_id) + TaskService._remove_scheduler_for_task(task_id) return True @@ -257,16 +334,21 @@ class TaskService: task = db.query(CheckInTask).filter(CheckInTask.id == task_id).first() if not task: - return None + raise ResourceNotFoundError("任务不存在", error_code="TASK_NOT_FOUND") task.is_active = not task.is_active - db.commit() - db.refresh(task) + try: + db.commit() + db.refresh(task) + except Exception as e: + db.rollback() + logger.error(f"任务 {task_id} 状态切换失败: {str(e)}") + raise InternalServerError(f"任务状态切换失败: {str(e)}") logger.info(f"任务 {task_id} 状态已切换为: {'启用' if task.is_active else '禁用'}") # 重新加载调度器 - TaskService._reload_scheduler_for_task(task, db) + TaskService._sync_scheduler_for_task(task) return task @@ -322,42 +404,7 @@ class TaskService: db: 数据库会话 """ try: - from backend.services.scheduler_service import scheduler - from apscheduler.triggers.cron import CronTrigger - from croniter import croniter - - if not scheduler: - logger.warning(f"调度器未启动,无法加载任务 {task.id}") - return - - job_id = f"task_{task.id}" - - # 先移除旧的任务(如果存在) - existing_job = scheduler.get_job(job_id) - if existing_job: - scheduler.remove_job(job_id) - logger.info(f"从调度器移除旧任务: {job_id}") - - # 如果任务启用且有有效的 cron 表达式,添加新任务 - if task.is_scheduled_enabled: - cron_str = str(task.cron_expression) - if croniter.is_valid(cron_str): - from backend.services.scheduler_service import scheduled_check_in_task - - scheduler.add_job( - func=scheduled_check_in_task, - trigger=CronTrigger.from_crontab(cron_str), - id=job_id, - name=f"CheckIn-Task-{task.id}", - args=[task.id], - replace_existing=True, - ) - logger.info(f"✅ 任务 {task.id} 已重新加载到调度器: {cron_str}") - else: - logger.warning(f"任务 {task.id} 的 cron 表达式无效: {cron_str}") - else: - logger.info(f"任务 {task.id} 未启用或无 cron 表达式,已从调度器移除") - + TaskService._sync_scheduler_for_task(task) except Exception as e: logger.error(f"重新加载任务 {task.id} 到调度器失败: {str(e)}") @@ -370,15 +417,6 @@ class TaskService: task_id: 任务 ID """ try: - from backend.services.scheduler_service import scheduler - - if not scheduler: - return - - job_id = f"task_{task_id}" - if scheduler.get_job(job_id): - scheduler.remove_job(job_id) - logger.info(f"✅ 任务 {task_id} 已从调度器移除") - + TaskService._remove_scheduler_for_task(task_id) except Exception as e: logger.error(f"从调度器移除任务 {task_id} 失败: {str(e)}") diff --git a/apps/backend/services/template_service.py b/apps/backend/services/template_service.py index 2ff734d..c063894 100644 --- a/apps/backend/services/template_service.py +++ b/apps/backend/services/template_service.py @@ -4,7 +4,9 @@ from typing import List, Dict, Any, Optional from sqlalchemy.orm import Session from fastapi import HTTPException, status -from backend.models import TaskTemplate, CheckInTask +from backend.exceptions import BaseAPIException +from backend.models import CheckInTask, TaskTemplate +from backend.schemas.task import TaskCreate from backend.schemas.template import TemplateCreate, TemplateUpdate logger = logging.getLogger(__name__) @@ -134,6 +136,9 @@ class TemplateService: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail=f"字段配置 JSON 格式错误: {str(e)}" ) + except (BaseAPIException, HTTPException): + db.rollback() + raise except Exception as e: logger.error(f"创建模板失败: {str(e)}") db.rollback() @@ -219,6 +224,9 @@ class TemplateService: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail=f"字段配置 JSON 格式错误: {str(e)}" ) + except (BaseAPIException, HTTPException): + db.rollback() + raise except Exception as e: logger.error(f"更新模板失败: {str(e)}") db.rollback() @@ -247,6 +255,9 @@ class TemplateService: db.commit() logger.info(f"删除模板成功: {template.name} (ID: {template_id})") return True + except (BaseAPIException, HTTPException): + db.rollback() + raise except Exception as e: logger.error(f"删除模板失败: {str(e)}") db.rollback() @@ -424,6 +435,10 @@ class TemplateService: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"解析模板配置失败" ) + except BaseAPIException: + raise + except HTTPException: + raise except Exception as e: logger.error(f"组装 payload 失败: {str(e)}") raise HTTPException( @@ -518,31 +533,32 @@ class TemplateService: signature = payload.get("Signature", "Unknown") task_name = f"{template.name} - {signature}" - # 创建任务(包含 cron_expression) try: - task = CheckInTask( + from backend.services.task_service import TaskService + + task = TaskService.create_task( user_id=user_id, - payload_config=json.dumps(payload, ensure_ascii=False), - name=task_name, - is_active=True, - cron_expression=cron_expression or "0 20 * * *", + task_data=TaskCreate( + payload_config=json.dumps(payload, ensure_ascii=False), + name=task_name, + is_active=True, + cron_expression=cron_expression or "0 20 * * *", + ), + db=db, ) - db.add(task) - db.commit() - db.refresh(task) logger.info( f"从模板创建任务成功: {task.name} (ID: {task.id}, 模板: {template.name}, ThreadId: {thread_id})" ) - # 如果任务启用且包含 cron_expression,立即添加到调度器 - if task.is_scheduled_enabled: - from backend.services.task_service import TaskService - - TaskService._reload_scheduler_for_task(task, db) - return task + except BaseAPIException: + db.rollback() + raise + except HTTPException: + db.rollback() + raise except Exception as e: logger.error(f"从模板创建任务失败: {str(e)}") db.rollback() diff --git a/apps/backend/utils/json_helpers.py b/apps/backend/utils/json_helpers.py index 117d94c..cd40a02 100644 --- a/apps/backend/utils/json_helpers.py +++ b/apps/backend/utils/json_helpers.py @@ -92,6 +92,29 @@ def build_task_info(task) -> Dict[str, str]: 包含 thread_id 和 name 的字典 """ return { - "thread_id": extract_thread_id(getattr(task, "payload_config", None)) or "未知", + "thread_id": resolve_task_thread_id(task) or "未知", "name": getattr(task, "name", None) or f"Task-{getattr(task, 'id', 'Unknown')}", } + + +def resolve_task_thread_id(task) -> Optional[str]: + """ + 优先从显式字段读取任务 ThreadId,兼容旧数据时回退到 payload_config。 + + Args: + task: CheckInTask 或相似对象 + + Returns: + ThreadId 或 None + """ + thread_id = getattr(task, "thread_id", None) + if thread_id is not None: + value = str(thread_id).strip() + if value: + return value + + legacy_thread_id = extract_thread_id(getattr(task, "payload_config", None)) + if legacy_thread_id is None: + return None + value = str(legacy_thread_id).strip() + return value or None diff --git a/tests/test_backend_structure_boundaries.py b/tests/test_backend_structure_boundaries.py new file mode 100644 index 0000000..c5615cb --- /dev/null +++ b/tests/test_backend_structure_boundaries.py @@ -0,0 +1,235 @@ +from __future__ import annotations + +import json +from datetime import datetime +from typing import Any + +import pytest +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker + +from backend.exceptions import ResourceConflictError +from backend.models import Base, CheckInTask, User +from backend.schemas.task import TaskCreate, TaskUpdate +from backend.services import scheduler_service +from backend.services.task_service import TaskService + + +@pytest.fixture() +def db_session(): + engine = create_engine("sqlite:///:memory:", connect_args={"check_same_thread": False}) + Base.metadata.create_all(bind=engine) + Session = sessionmaker(autocommit=False, autoflush=False, bind=engine) + session = Session() + try: + yield session + finally: + session.close() + engine.dispose() + + +def add_user(db_session, alias: str) -> User: + user = User(alias=alias) + db_session.add(user) + db_session.commit() + db_session.refresh(user) + return user + + +def payload(thread_id: str, **extra: Any) -> str: + return json.dumps({"ThreadId": thread_id, **extra}, ensure_ascii=False) + + +def test_create_task_persists_thread_id_identity(db_session) -> None: + user = add_user(db_session, "alice") + + task = TaskService.create_task( + user.id, + TaskCreate( + payload_config=payload("thread-1", Signature="sig"), + name="Morning", + cron_expression="0 8 * * *", + ), + db_session, + ) + + assert task.thread_id == "thread-1" + assert task.payload_config == payload("thread-1", Signature="sig") + assert task.cron_expression == "0 8 * * *" + + +def test_same_user_duplicate_thread_id_is_structured_conflict(db_session) -> None: + user = add_user(db_session, "alice") + first = TaskCreate(payload_config=payload("thread-1"), name="First") + duplicate = TaskCreate(payload_config=payload("thread-1"), name="Duplicate") + + TaskService.create_task(user.id, first, db_session) + + with pytest.raises(ResourceConflictError) as exc_info: + TaskService.create_task(user.id, duplicate, db_session) + + assert exc_info.value.status_code == 409 + assert exc_info.value.error_code == "TASK_IDENTITY_CONFLICT" + + +def test_different_users_can_share_thread_id(db_session) -> None: + alice = add_user(db_session, "alice") + bob = add_user(db_session, "bob") + + alice_task = TaskService.create_task( + alice.id, TaskCreate(payload_config=payload("shared-thread")), db_session + ) + bob_task = TaskService.create_task( + bob.id, TaskCreate(payload_config=payload("shared-thread")), db_session + ) + + assert alice_task.thread_id == "shared-thread" + assert bob_task.thread_id == "shared-thread" + assert alice_task.user_id != bob_task.user_id + + +def test_update_task_rejects_duplicate_thread_id(db_session) -> None: + user = add_user(db_session, "alice") + first = TaskService.create_task( + user.id, TaskCreate(payload_config=payload("thread-1")), db_session + ) + second = TaskService.create_task( + user.id, TaskCreate(payload_config=payload("thread-2")), db_session + ) + + with pytest.raises(ResourceConflictError): + TaskService.update_task( + second.id, + TaskUpdate(payload_config=payload("thread-1")), + db_session, + ) + + db_session.refresh(first) + db_session.refresh(second) + assert first.thread_id == "thread-1" + assert second.thread_id == "thread-2" + + +def test_task_enrichment_uses_stored_thread_id(db_session) -> None: + user = add_user(db_session, "alice") + task = CheckInTask( + user_id=user.id, + thread_id="stored-thread", + payload_config=payload("payload-thread"), + name="Task", + ) + db_session.add(task) + db_session.commit() + db_session.refresh(task) + + enriched = TaskService.enrich_task_with_check_in_info(task, db_session) + + assert enriched["thread_id"] == "stored-thread" + + +def test_template_created_duplicate_task_preserves_structured_conflict(db_session) -> None: + from backend.models import TaskTemplate + from backend.services.template_service import TemplateService + + user = add_user(db_session, "alice") + template = TaskTemplate(name="Daily", field_config="{}", is_active=True) + db_session.add(template) + db_session.commit() + db_session.refresh(template) + + TemplateService.create_task_from_template( + template_id=template.id, + thread_id="thread-1", + field_values={}, + user_id=user.id, + task_name="First", + db=db_session, + ) + + with pytest.raises(ResourceConflictError) as exc_info: + TemplateService.create_task_from_template( + template_id=template.id, + thread_id="thread-1", + field_values={}, + user_id=user.id, + task_name="Duplicate", + db=db_session, + ) + + assert exc_info.value.status_code == 409 + assert exc_info.value.error_code == "TASK_IDENTITY_CONFLICT" + + +class FakeScheduler: + def __init__(self) -> None: + self.jobs: dict[str, object] = {} + self.added: list[str] = [] + self.removed: list[str] = [] + + def get_job(self, job_id: str): + return self.jobs.get(job_id) + + def remove_job(self, job_id: str) -> None: + self.jobs.pop(job_id, None) + self.removed.append(job_id) + + def add_job(self, **kwargs) -> None: + job_id = kwargs["id"] + self.jobs[job_id] = kwargs + self.added.append(job_id) + + +def test_scheduler_sync_skips_invalid_cron_and_removes_existing_job(monkeypatch) -> None: + fake_scheduler = FakeScheduler() + fake_scheduler.jobs["task_12"] = object() + monkeypatch.setattr(scheduler_service, "scheduler", fake_scheduler) + task = CheckInTask(id=12, thread_id="thread-1", payload_config=payload("thread-1")) + task.is_active = True + task.cron_expression = "invalid cron" + + scheduled = scheduler_service.sync_scheduled_task(task) + + assert scheduled is False + assert fake_scheduler.added == [] + assert fake_scheduler.removed == ["task_12"] + + +def test_scheduler_sync_is_noop_when_scheduler_unavailable(monkeypatch) -> None: + monkeypatch.setattr(scheduler_service, "scheduler", None) + task = CheckInTask(id=12, thread_id="thread-1", payload_config=payload("thread-1")) + task.is_active = True + task.cron_expression = "0 8 * * *" + + assert scheduler_service.sync_scheduled_task(task) is False + + +def test_datetime_normalization_does_not_access_unmapped_properties( + db_session, monkeypatch +) -> None: + user = add_user(db_session, "alice") + user_id = user.id + db_session.expunge_all() + + def explode(self): + raise RuntimeError("unmapped property accessed") + + monkeypatch.setattr(User, "explosive_property", property(explode), raising=False) + + loaded = db_session.query(User).filter(User.id == user_id).one() + + assert loaded.created_at.tzinfo is not None + + +@pytest.mark.asyncio +async def test_task_route_preserves_structured_service_errors(monkeypatch, db_session) -> None: + from backend.api.tasks import get_tasks + + user = User(id=1, alias="alice") + + def raise_conflict(*args, **kwargs): + raise ResourceConflictError("duplicate", error_code="TASK_IDENTITY_CONFLICT") + + monkeypatch.setattr(TaskService, "get_user_tasks", raise_conflict) + + with pytest.raises(ResourceConflictError): + await get_tasks(current_user=user, db=db_session)