feat(backend): harden task boundaries

This commit is contained in:
2026-05-05 00:55:29 +08:00
parent 817540f8a0
commit e243dccfd7
15 changed files with 694 additions and 147 deletions
+18 -2
View File
@@ -11,9 +11,11 @@ from backend.services.check_in_service import CheckInService
from backend.services.admin_service import AdminService
from backend.dependencies import get_current_admin_user
from backend.config import settings
from backend.exceptions import BaseAPIException
logger = logging.getLogger(__name__)
router = APIRouter()
EXPECTED_API_EXCEPTIONS = (BaseAPIException, HTTPException)
class BatchToggleTasksRequest(BaseModel):
@@ -43,13 +45,21 @@ async def batch_toggle_tasks(
task.is_active = request.is_active
count += 1
from backend.services.scheduler_service import sync_scheduled_task
db.commit()
for task_id in request.task_ids:
task = db.query(CheckInTask).filter(CheckInTask.id == task_id).first()
if task:
sync_scheduled_task(task)
return {
"success": True,
"message": f"{'启用' if request.is_active else '禁用'} {count} 个任务",
"count": count,
}
except EXPECTED_API_EXCEPTIONS:
raise
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"批量操作失败: {str(e)}"
@@ -72,6 +82,8 @@ async def batch_check_in(
try:
result = CheckInService.batch_check_in_tasks(request.task_ids, db)
return result
except EXPECTED_API_EXCEPTIONS:
raise
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"批量打卡失败: {str(e)}"
@@ -235,6 +247,8 @@ async def get_system_stats(
"tokens": {"expiring_soon": expiring_users},
}
except EXPECTED_API_EXCEPTIONS:
raise
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"获取统计失败: {str(e)}"
@@ -251,6 +265,8 @@ async def get_pending_users(
try:
users = AdminService.get_pending_users(db)
return users
except EXPECTED_API_EXCEPTIONS:
raise
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
@@ -274,7 +290,7 @@ async def approve_user(
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=result["message"])
return result
except HTTPException:
except EXPECTED_API_EXCEPTIONS:
raise
except Exception as e:
raise HTTPException(
@@ -298,7 +314,7 @@ async def reject_user(
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=result["message"])
return result
except HTTPException:
except EXPECTED_API_EXCEPTIONS:
raise
except Exception as e:
raise HTTPException(
+12 -1
View File
@@ -12,10 +12,11 @@ from backend.schemas.auth import (
AliasLoginResponse,
)
from backend.services.auth_service import AuthService
from backend.exceptions import BusinessLogicError
from backend.exceptions import BaseAPIException, BusinessLogicError
from backend.limiter import limiter
router = APIRouter()
EXPECTED_API_EXCEPTIONS = (BaseAPIException, HTTPException)
@router.post("/request_qrcode", response_model=dict, summary="请求 QQ 扫码二维码")
@@ -68,6 +69,8 @@ async def request_qrcode(
)
return result
except EXPECTED_API_EXCEPTIONS:
raise
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"创建扫码会话失败: {str(e)}"
@@ -95,6 +98,8 @@ async def get_qrcode_status(session_id: str, db: Session = Depends(get_db)):
try:
result = AuthService.get_qrcode_status(session_id, db)
return result
except EXPECTED_API_EXCEPTIONS:
raise
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"查询扫码状态失败: {str(e)}"
@@ -113,6 +118,8 @@ async def cancel_qrcode_session(session_id: str):
try:
result = AuthService.cancel_qrcode_session(session_id)
return result
except EXPECTED_API_EXCEPTIONS:
raise
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"取消会话失败: {str(e)}"
@@ -136,6 +143,8 @@ async def verify_token(request: TokenVerifyRequest, db: Session = Depends(get_db
try:
result = AuthService.verify_token(request.authorization, db)
return result
except EXPECTED_API_EXCEPTIONS:
raise
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"验证 Token 失败: {str(e)}"
@@ -170,6 +179,8 @@ async def alias_login(
try:
result = AuthService.alias_login(login_data.alias, login_data.password, db)
return result
except EXPECTED_API_EXCEPTIONS:
raise
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"别名登录失败: {str(e)}"
+12
View File
@@ -12,8 +12,10 @@ from backend.schemas.check_in import (
from backend.services.check_in_service import CheckInService
from backend.services.task_service import TaskService
from backend.dependencies import get_current_user, get_current_admin_user
from backend.exceptions import BaseAPIException
router = APIRouter()
EXPECTED_API_EXCEPTIONS = (BaseAPIException, HTTPException)
@router.post("/manual/{task_id}", summary="手动触发打卡(异步)")
@@ -38,6 +40,8 @@ async def manual_check_in(
try:
result = CheckInService.start_async_check_in(task, "manual", db)
return result
except EXPECTED_API_EXCEPTIONS:
raise
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"启动打卡任务失败: {str(e)}"
@@ -111,6 +115,8 @@ async def get_task_check_in_records(
task_id, db, skip, limit, status_filter, trigger_type
)
return PaginatedResponse(records=records, total=total, skip=skip, limit=limit)
except EXPECTED_API_EXCEPTIONS:
raise
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"获取打卡记录失败: {str(e)}"
@@ -145,6 +151,8 @@ async def get_my_check_in_records(
current_user.id, db, skip, limit, status_filter, trigger_type
)
return PaginatedResponse(records=records, total=total, skip=skip, limit=limit)
except EXPECTED_API_EXCEPTIONS:
raise
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"获取打卡记录失败: {str(e)}"
@@ -181,6 +189,8 @@ async def get_all_check_in_records(
CheckInService.enrich_record_with_user_task_info(record, db) for record in records
]
return PaginatedResponse(records=enriched_records, total=total, skip=skip, limit=limit)
except EXPECTED_API_EXCEPTIONS:
raise
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"获取打卡记录失败: {str(e)}"
@@ -213,6 +223,8 @@ async def get_check_in_records_count(
total = query.count()
return {"total": total}
except EXPECTED_API_EXCEPTIONS:
raise
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"获取统计失败: {str(e)}"
+5
View File
@@ -4,6 +4,7 @@ from typing import List
from datetime import datetime, timedelta
from pydantic import BaseModel, Field
from backend.exceptions import BaseAPIException
from backend.models import get_db, User
from backend.schemas.task import TaskUpdate, TaskResponse
from backend.services.task_service import TaskService
@@ -37,6 +38,8 @@ async def get_tasks(
# 为每个任务添加额外信息
enriched_tasks = [TaskService.enrich_task_with_check_in_info(task, db) for task in tasks]
return enriched_tasks
except (BaseAPIException, HTTPException):
raise
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"获取任务列表失败: {str(e)}"
@@ -173,6 +176,8 @@ async def validate_cron_expression(request: CronValidateRequest):
"next_times": next_times,
"description": generate_cron_description(cron_expr),
}
except HTTPException:
raise
except Exception as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail=f"无效的 Crontab 表达式: {str(e)}"
+10
View File
@@ -4,6 +4,7 @@ from sqlalchemy.orm import Session
from backend.models import User
from backend.dependencies import get_db, get_current_user, get_current_admin_user
from backend.exceptions import BaseAPIException
from backend.schemas.template import (
TemplateCreate,
TemplateUpdate,
@@ -17,6 +18,9 @@ from backend.services.template_service import TemplateService
router = APIRouter()
EXPECTED_API_EXCEPTIONS = (BaseAPIException, HTTPException)
@router.get("/", response_model=List[TemplateResponse], summary="获取所有模板列表")
async def get_all_templates(
skip: int = Query(0, ge=0, description="跳过记录数"),
@@ -35,6 +39,8 @@ async def get_all_templates(
try:
templates = TemplateService.get_all_templates(db, skip, limit, is_active)
return templates
except EXPECTED_API_EXCEPTIONS:
raise
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"获取模板列表失败: {str(e)}"
@@ -57,6 +63,8 @@ async def get_active_templates(
try:
templates = TemplateService.get_all_templates(db, skip, limit, is_active=True)
return templates
except EXPECTED_API_EXCEPTIONS:
raise
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"获取模板列表失败: {str(e)}"
@@ -115,6 +123,8 @@ async def preview_template(
"preview_payload": preview_payload,
"field_config": merged_config,
}
except EXPECTED_API_EXCEPTIONS:
raise
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"生成预览失败: {str(e)}"
+19 -1
View File
@@ -14,9 +14,15 @@ from backend.schemas.task import TaskResponse
from backend.services.user_service import UserService
from backend.services.task_service import TaskService
from backend.dependencies import get_current_user, get_current_admin_user
from backend.exceptions import ValidationError, AuthorizationError, ResourceNotFoundError
from backend.exceptions import (
AuthorizationError,
BaseAPIException,
ResourceNotFoundError,
ValidationError,
)
router = APIRouter()
EXPECTED_API_EXCEPTIONS = (BaseAPIException, HTTPException)
@router.post(
@@ -42,6 +48,8 @@ async def create_user(
return user
except ValueError as e:
raise ValidationError(str(e))
except EXPECTED_API_EXCEPTIONS:
raise
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"创建用户失败: {str(e)}"
@@ -103,6 +111,8 @@ async def update_current_user_profile(
return user
except ValueError as e:
raise ValidationError(str(e))
except EXPECTED_API_EXCEPTIONS:
raise
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"更新个人信息失败: {str(e)}"
@@ -144,6 +154,8 @@ async def get_current_user_tasks(
try:
tasks = TaskService.get_user_tasks(current_user.id, db, include_inactive)
return tasks
except EXPECTED_API_EXCEPTIONS:
raise
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"获取任务列表失败: {str(e)}"
@@ -170,6 +182,8 @@ async def get_all_users(
try:
users = UserService.get_all_users(db, skip, limit, search, role)
return users
except EXPECTED_API_EXCEPTIONS:
raise
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"获取用户列表失败: {str(e)}"
@@ -252,6 +266,8 @@ async def update_user(
return user
except ValueError as e:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
except EXPECTED_API_EXCEPTIONS:
raise
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"更新用户失败: {str(e)}"
@@ -272,6 +288,8 @@ async def delete_user(
return None
except ValueError as e:
raise ResourceNotFoundError(str(e))
except EXPECTED_API_EXCEPTIONS:
raise
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"删除用户失败: {str(e)}"
+8 -1
View File
@@ -1,7 +1,7 @@
from datetime import datetime
from typing import TYPE_CHECKING
from sqlalchemy import Boolean, DateTime, ForeignKey, Index, String, Text
from sqlalchemy import Boolean, DateTime, ForeignKey, Index, String, Text, UniqueConstraint
from sqlalchemy.orm import Mapped, mapped_column, relationship
from sqlalchemy.sql import func
@@ -23,6 +23,12 @@ class CheckInTask(Base):
index=True,
comment="用户 ID",
)
thread_id: Mapped[str | None] = mapped_column(
String(100),
index=True,
nullable=True,
comment="接龙项目 ID",
)
payload_config: Mapped[str] = mapped_column(
Text,
default="{}",
@@ -57,6 +63,7 @@ class CheckInTask(Base):
# 添加索引:加速查询
__table_args__ = (
Index("ix_task_user_active", "user_id", "is_active"),
UniqueConstraint("user_id", "thread_id", name="uq_task_user_thread_id"),
Index("ix_task_cron", "cron_expression"), # 加速查询启用了定时打卡的任务
)
+12 -10
View File
@@ -1,6 +1,6 @@
from datetime import datetime, timezone
from sqlalchemy import create_engine, event
from sqlalchemy import DateTime, create_engine, event, inspect
from sqlalchemy.orm import DeclarativeBase, sessionmaker
from backend.config import settings
@@ -24,21 +24,23 @@ class Base(DeclarativeBase):
@event.listens_for(Base, "load", propagate=True)
def receive_load(target, context):
"""在从数据库加载对象后,将所有 datetime 字段转换为 timezone-aware (UTC)"""
for attr_name in dir(target):
# 跳过私有属性和方法
if attr_name.startswith("_"):
mapper = inspect(target).mapper
for attr in mapper.column_attrs:
column = attr.columns[0]
if not isinstance(column.type, DateTime):
continue
try:
attr_value = getattr(target, attr_name)
# 如果是 naive datetime,添加 UTC timezone
if isinstance(attr_value, datetime) and attr_value.tzinfo is None:
setattr(target, attr_name, attr_value.replace(tzinfo=timezone.utc))
attr_value = getattr(target, attr.key)
except (AttributeError, TypeError):
# 某些属性可能无法访问或设置,跳过
continue
if isinstance(attr_value, datetime) and attr_value.tzinfo is None:
try:
setattr(target, attr.key, attr_value.replace(tzinfo=timezone.utc))
except (AttributeError, TypeError):
continue
def get_db():
"""依赖注入:获取数据库会话"""
@@ -0,0 +1,113 @@
"""
数据库迁移脚本:添加打卡任务 thread_id 字段并回填。
运行方式:
uv run python -m backend.scripts.migrate_add_task_thread_id
"""
import json
import logging
import sys
from pathlib import Path
APPS_DIR = Path(__file__).resolve().parents[2]
sys.path.insert(0, str(APPS_DIR))
from sqlalchemy import text
from backend.models.database import engine
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def _extract_thread_id(payload_config: str | None) -> str | None:
if not payload_config:
return None
try:
payload = json.loads(payload_config)
except json.JSONDecodeError:
return None
if not isinstance(payload, dict):
return None
thread_id = payload.get("ThreadId")
value = str(thread_id).strip() if thread_id is not None else ""
return value or None
def migrate() -> None:
"""执行迁移。"""
logger.info("开始迁移:添加 check_in_tasks.thread_id 字段...")
with engine.connect() as conn:
result = conn.execute(text("PRAGMA table_info(check_in_tasks)"))
columns = [row[1] for row in result]
if "thread_id" not in columns:
logger.info("添加 thread_id 字段...")
conn.execute(text("ALTER TABLE check_in_tasks ADD COLUMN thread_id VARCHAR(100)"))
conn.commit()
logger.info("✓ thread_id 字段添加成功")
else:
logger.info("✓ thread_id 字段已存在,跳过")
rows = conn.execute(text("SELECT id, payload_config FROM check_in_tasks")).fetchall()
invalid_ids: list[int] = []
seen: dict[tuple[int, str], int] = {}
duplicate_ids: list[int] = []
full_rows = conn.execute(
text("SELECT id, user_id, payload_config FROM check_in_tasks")
).fetchall()
for row in full_rows:
thread_id = _extract_thread_id(row.payload_config)
if not thread_id:
invalid_ids.append(row.id)
continue
key = (row.user_id, thread_id)
if key in seen:
duplicate_ids.append(row.id)
else:
seen[key] = row.id
if invalid_ids or duplicate_ids:
messages = []
if invalid_ids:
messages.append(f"payload_config 缺少有效 ThreadId 的任务: {invalid_ids}")
if duplicate_ids:
messages.append(f"同用户 ThreadId 重复的任务: {duplicate_ids}")
raise RuntimeError("".join(messages))
for row in rows:
thread_id = _extract_thread_id(row.payload_config)
if thread_id:
conn.execute(
text("UPDATE check_in_tasks SET thread_id = :thread_id WHERE id = :id"),
{"thread_id": thread_id, "id": row.id},
)
conn.commit()
indexes = conn.execute(text("PRAGMA index_list(check_in_tasks)")).fetchall()
index_names = [row[1] for row in indexes]
if "ix_task_user_thread_id_unique" not in index_names:
logger.info("添加用户级 thread_id 唯一索引...")
conn.execute(
text(
"CREATE UNIQUE INDEX ix_task_user_thread_id_unique "
"ON check_in_tasks (user_id, thread_id)"
)
)
conn.commit()
logger.info("✓ 用户级 thread_id 唯一索引添加成功")
else:
logger.info("✓ 用户级 thread_id 唯一索引已存在,跳过")
logger.info("✅ 迁移完成!任务 thread_id 身份字段已启用")
if __name__ == "__main__":
try:
migrate()
except Exception as e:
logger.error(f"❌ 迁移失败: {e}")
sys.exit(1)
+3 -3
View File
@@ -541,10 +541,10 @@ class CheckInService:
)
task_name = task.name
# payload_config 提取 ThreadId
from backend.utils.json_helpers import extract_thread_id
# 优先从显式任务身份读取 ThreadId,兼容旧数据时回退到 payload_config
from backend.utils.json_helpers import resolve_task_thread_id
thread_id = extract_thread_id(task.payload_config) # type: ignore
thread_id = resolve_task_thread_id(task)
# 转换为字典并添加额外字段
record_dict = {
+67 -26
View File
@@ -20,6 +20,70 @@ scheduler = None
scheduler_lock = None
def _task_job_id(task_id: int) -> str:
return f"task_{task_id}"
def _remove_task_job(job_id: str, scheduler_instance=None) -> bool:
active_scheduler = scheduler_instance or scheduler
if not active_scheduler:
return False
existing_job = active_scheduler.get_job(job_id)
if not existing_job:
return False
active_scheduler.remove_job(job_id)
logger.info(f"✅ 已移除调度任务: {job_id}")
return True
def remove_scheduled_task(task_id: int, scheduler_instance=None) -> bool:
"""从调度器移除指定任务。"""
active_scheduler = scheduler_instance or scheduler
if not active_scheduler:
logger.warning(f"调度器未启动,无法移除任务 {task_id}")
return False
return _remove_task_job(_task_job_id(task_id), active_scheduler)
def sync_scheduled_task(task: CheckInTask, scheduler_instance=None) -> bool:
"""
根据任务当前状态同步调度器中的对应 job。
返回 True 表示任务已成功安排到调度器,False 表示未安排或调度器不可用。
"""
active_scheduler = scheduler_instance or scheduler
if not active_scheduler:
logger.warning(f"调度器未启动,无法同步任务 {task.id}")
return False
job_id = _task_job_id(task.id)
_remove_task_job(job_id, active_scheduler)
if not task.is_scheduled_enabled:
logger.info(f"任务 {task.id} 未启用或无 cron 表达式,已从调度器移除")
return False
cron_str = str(task.cron_expression or "").strip()
if not cron_str or not croniter.is_valid(cron_str):
logger.warning(f"任务 {task.id} 的 cron 表达式无效: {cron_str}")
return False
active_scheduler.add_job(
func=scheduled_check_in_task,
trigger=CronTrigger.from_crontab(cron_str),
id=job_id,
name=f"CheckIn-Task-{task.id}",
args=[task.id],
replace_existing=True,
)
logger.info(f"✅ 任务 {task.id} 已同步到调度器: {cron_str}")
return True
def load_scheduled_tasks(db: Session, scheduler_instance):
"""
从数据库加载所有启用的定时任务并添加到 APScheduler
@@ -55,33 +119,10 @@ def load_scheduled_tasks(db: Session, scheduler_instance):
for task in tasks:
try:
# 验证 cron 表达式
cron_str = str(task.cron_expression) if task.cron_expression else None
if not cron_str or not croniter.is_valid(cron_str):
logger.warning(f"跳过任务 {task.id}: 无效的 cron 表达式 '{task.cron_expression}'")
if sync_scheduled_task(task, scheduler_instance):
loaded_count += 1
else:
skipped_count += 1
continue
# 创建任务 ID
job_id = f"task_{task.id}"
# 检查任务是否已存在
if scheduler_instance.get_job(job_id):
logger.debug(f"任务 {task.id} 已存在,跳过")
continue
# 添加任务到调度器
scheduler_instance.add_job(
func=scheduled_check_in_task,
trigger=CronTrigger.from_crontab(cron_str),
id=job_id,
name=f"CheckIn-Task-{task.id}",
args=[task.id],
replace_existing=True,
)
logger.info(f"✅ 加载任务 {task.id}: {task.name} (Cron: {task.cron_expression})")
loaded_count += 1
except Exception as e:
logger.error(f"❌ 加载任务 {task.id} 时出错: {str(e)}")
+124 -86
View File
@@ -3,6 +3,12 @@ from typing import List, Optional, Dict, Any
from sqlalchemy.orm import Session
from sqlalchemy import desc
from backend.exceptions import (
InternalServerError,
ResourceConflictError,
ResourceNotFoundError,
ValidationError as APIValidationError,
)
from backend.models import User, CheckInTask, CheckInRecord
from backend.schemas.task import TaskCreate, TaskUpdate
@@ -12,6 +18,68 @@ logger = logging.getLogger(__name__)
class TaskService:
"""打卡任务服务"""
@staticmethod
def _normalize_thread_id(thread_id: Any) -> str:
value = str(thread_id).strip() if thread_id is not None else ""
if not value:
raise APIValidationError(
"payload_config 必须包含有效的 ThreadId 字段",
error_code="TASK_IDENTITY_INVALID",
)
return value
@staticmethod
def _extract_thread_id_from_payload(payload_config: str) -> str:
from backend.utils.json_helpers import safe_parse_payload
payload = safe_parse_payload(payload_config)
return TaskService._normalize_thread_id(payload.get("ThreadId"))
@staticmethod
def _resolve_thread_id_for_task(task: CheckInTask) -> Optional[str]:
thread_id = getattr(task, "thread_id", None)
if thread_id is not None:
value = str(thread_id).strip()
if value:
return value
from backend.utils.json_helpers import extract_thread_id
legacy_thread_id = extract_thread_id(task.payload_config)
if legacy_thread_id is None:
return None
value = str(legacy_thread_id).strip()
return value or None
@staticmethod
def _ensure_unique_thread_id(
db: Session, user_id: int, thread_id: str, exclude_task_id: int | None = None
) -> None:
query = db.query(CheckInTask.id).filter(
CheckInTask.user_id == user_id, CheckInTask.thread_id == thread_id
)
if exclude_task_id is not None:
query = query.filter(CheckInTask.id != exclude_task_id)
conflict = query.first()
if conflict:
raise ResourceConflictError(
f"该接龙中已存在任务。ThreadId: {thread_id}",
error_code="TASK_IDENTITY_CONFLICT",
)
@staticmethod
def _sync_scheduler_for_task(task: CheckInTask) -> None:
from backend.services.scheduler_service import sync_scheduled_task
sync_scheduled_task(task)
@staticmethod
def _remove_scheduler_for_task(task_id: int) -> None:
from backend.services.scheduler_service import remove_scheduled_task
remove_scheduled_task(task_id)
@staticmethod
def create_task(user_id: int, task_data: TaskCreate, db: Session) -> CheckInTask:
"""
@@ -25,34 +93,16 @@ class TaskService:
Returns:
创建的任务对象
"""
import json
# 1. 检查用户是否存在
user = db.query(User).filter(User.id == user_id).first()
if not user:
raise ValueError(f"用户 ID {user_id} 不存在")
raise ResourceNotFoundError(f"用户 ID {user_id} 不存在", error_code="USER_NOT_FOUND")
# 2. 从 payload_config 中提取 ThreadId 用于唯一性校验
from backend.utils.json_helpers import safe_parse_payload, extract_thread_id
payload = safe_parse_payload(task_data.payload_config)
thread_id = payload.get("ThreadId")
if not thread_id:
raise ValueError("payload_config 中缺少 ThreadId")
thread_id = TaskService._extract_thread_id_from_payload(task_data.payload_config)
# 3. 验证唯一性:同一用户在同一个接龙中不能有重复的任务
existing_tasks = (
db.query(CheckInTask.payload_config).filter(CheckInTask.user_id == user_id).all()
)
for (payload_config,) in existing_tasks:
existing_thread_id = extract_thread_id(payload_config)
# extract_thread_id 已处理异常,失败时返回 None
if existing_thread_id and existing_thread_id == thread_id:
logger.warning(
f"⚠️ 任务创建冲突 - User: {user.alias}({user_id}), ThreadId: {thread_id}"
)
raise ValueError(f"该接龙中已存在任务。ThreadId: {thread_id}")
TaskService._ensure_unique_thread_id(db, user_id, thread_id)
# 4. 记录日志
task_name = task_data.name or f"接龙任务 {thread_id}"
@@ -61,9 +111,11 @@ class TaskService:
# 5. 创建任务
task = CheckInTask(
user_id=user_id,
thread_id=thread_id,
payload_config=task_data.payload_config,
name=task_data.name or task_name,
is_active=task_data.is_active if task_data.is_active is not None else True,
cron_expression=task_data.cron_expression or "0 20 * * *",
)
try:
@@ -75,14 +127,19 @@ class TaskService:
)
# 如果任务启用且包含 cron_expression,立即添加到调度器
if task.is_scheduled_enabled:
TaskService._reload_scheduler_for_task(task, db)
TaskService._sync_scheduler_for_task(task)
return task
except APIValidationError:
db.rollback()
raise
except ResourceConflictError:
db.rollback()
raise
except Exception as e:
db.rollback()
logger.error(f"❌ 任务创建失败: {str(e)}")
raise ValueError(f"任务创建失败: {str(e)}")
raise InternalServerError(f"任务创建失败: {str(e)}")
@staticmethod
def get_task(task_id: int, db: Session) -> Optional[CheckInTask]:
@@ -110,8 +167,6 @@ class TaskService:
Returns:
包含额外信息的任务字典
"""
from backend.utils.json_helpers import extract_thread_id
# 获取最后一次打卡记录
last_record = (
db.query(CheckInRecord)
@@ -120,8 +175,8 @@ class TaskService:
.first()
)
# payload_config 提取 ThreadId
thread_id = extract_thread_id(task.payload_config) # type: ignore
# 优先使用持久化的 ThreadId,兼容旧数据时回退到 payload_config
thread_id = TaskService._resolve_thread_id_for_task(task)
# 转换为字典并添加额外字段
task_dict = {
@@ -192,7 +247,7 @@ class TaskService:
task = db.query(CheckInTask).filter(CheckInTask.id == task_id).first()
if not task:
return None
raise ResourceNotFoundError("任务不存在", error_code="TASK_NOT_FOUND")
# 更新字段
update_data = task_data.model_dump(exclude_unset=True)
@@ -201,17 +256,34 @@ class TaskService:
cron_changed = "cron_expression" in update_data
active_changed = "is_active" in update_data
if "payload_config" in update_data:
new_thread_id = TaskService._extract_thread_id_from_payload(
update_data["payload_config"]
)
TaskService._ensure_unique_thread_id(
db, task.user_id, new_thread_id, exclude_task_id=task.id
)
task.thread_id = new_thread_id
for field, value in update_data.items():
setattr(task, field, value)
db.commit()
db.refresh(task)
try:
db.commit()
db.refresh(task)
except ResourceConflictError:
db.rollback()
raise
except Exception as e:
db.rollback()
logger.error(f"任务 {task_id} 更新失败: {str(e)}")
raise InternalServerError(f"任务更新失败: {str(e)}")
logger.info(f"任务 {task_id} 已更新")
# 如果 cron_expression 或 is_active 发生变化,重新加载调度器
if cron_changed or active_changed:
TaskService._reload_scheduler_for_task(task, db)
TaskService._sync_scheduler_for_task(task)
return task
@@ -230,15 +302,20 @@ class TaskService:
task = db.query(CheckInTask).filter(CheckInTask.id == task_id).first()
if not task:
return False
raise ResourceNotFoundError("任务不存在", error_code="TASK_NOT_FOUND")
db.delete(task)
db.commit()
try:
db.delete(task)
db.commit()
except Exception as e:
db.rollback()
logger.error(f"任务 {task_id} 删除失败: {str(e)}")
raise InternalServerError(f"任务删除失败: {str(e)}")
logger.info(f"任务 {task_id} 已删除")
# 从调度器中移除该任务
TaskService._remove_task_from_scheduler(task_id)
TaskService._remove_scheduler_for_task(task_id)
return True
@@ -257,16 +334,21 @@ class TaskService:
task = db.query(CheckInTask).filter(CheckInTask.id == task_id).first()
if not task:
return None
raise ResourceNotFoundError("任务不存在", error_code="TASK_NOT_FOUND")
task.is_active = not task.is_active
db.commit()
db.refresh(task)
try:
db.commit()
db.refresh(task)
except Exception as e:
db.rollback()
logger.error(f"任务 {task_id} 状态切换失败: {str(e)}")
raise InternalServerError(f"任务状态切换失败: {str(e)}")
logger.info(f"任务 {task_id} 状态已切换为: {'启用' if task.is_active else '禁用'}")
# 重新加载调度器
TaskService._reload_scheduler_for_task(task, db)
TaskService._sync_scheduler_for_task(task)
return task
@@ -322,42 +404,7 @@ class TaskService:
db: 数据库会话
"""
try:
from backend.services.scheduler_service import scheduler
from apscheduler.triggers.cron import CronTrigger
from croniter import croniter
if not scheduler:
logger.warning(f"调度器未启动,无法加载任务 {task.id}")
return
job_id = f"task_{task.id}"
# 先移除旧的任务(如果存在)
existing_job = scheduler.get_job(job_id)
if existing_job:
scheduler.remove_job(job_id)
logger.info(f"从调度器移除旧任务: {job_id}")
# 如果任务启用且有有效的 cron 表达式,添加新任务
if task.is_scheduled_enabled:
cron_str = str(task.cron_expression)
if croniter.is_valid(cron_str):
from backend.services.scheduler_service import scheduled_check_in_task
scheduler.add_job(
func=scheduled_check_in_task,
trigger=CronTrigger.from_crontab(cron_str),
id=job_id,
name=f"CheckIn-Task-{task.id}",
args=[task.id],
replace_existing=True,
)
logger.info(f"✅ 任务 {task.id} 已重新加载到调度器: {cron_str}")
else:
logger.warning(f"任务 {task.id} 的 cron 表达式无效: {cron_str}")
else:
logger.info(f"任务 {task.id} 未启用或无 cron 表达式,已从调度器移除")
TaskService._sync_scheduler_for_task(task)
except Exception as e:
logger.error(f"重新加载任务 {task.id} 到调度器失败: {str(e)}")
@@ -370,15 +417,6 @@ class TaskService:
task_id: 任务 ID
"""
try:
from backend.services.scheduler_service import scheduler
if not scheduler:
return
job_id = f"task_{task_id}"
if scheduler.get_job(job_id):
scheduler.remove_job(job_id)
logger.info(f"✅ 任务 {task_id} 已从调度器移除")
TaskService._remove_scheduler_for_task(task_id)
except Exception as e:
logger.error(f"从调度器移除任务 {task_id} 失败: {str(e)}")
+32 -16
View File
@@ -4,7 +4,9 @@ from typing import List, Dict, Any, Optional
from sqlalchemy.orm import Session
from fastapi import HTTPException, status
from backend.models import TaskTemplate, CheckInTask
from backend.exceptions import BaseAPIException
from backend.models import CheckInTask, TaskTemplate
from backend.schemas.task import TaskCreate
from backend.schemas.template import TemplateCreate, TemplateUpdate
logger = logging.getLogger(__name__)
@@ -134,6 +136,9 @@ class TemplateService:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail=f"字段配置 JSON 格式错误: {str(e)}"
)
except (BaseAPIException, HTTPException):
db.rollback()
raise
except Exception as e:
logger.error(f"创建模板失败: {str(e)}")
db.rollback()
@@ -219,6 +224,9 @@ class TemplateService:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail=f"字段配置 JSON 格式错误: {str(e)}"
)
except (BaseAPIException, HTTPException):
db.rollback()
raise
except Exception as e:
logger.error(f"更新模板失败: {str(e)}")
db.rollback()
@@ -247,6 +255,9 @@ class TemplateService:
db.commit()
logger.info(f"删除模板成功: {template.name} (ID: {template_id})")
return True
except (BaseAPIException, HTTPException):
db.rollback()
raise
except Exception as e:
logger.error(f"删除模板失败: {str(e)}")
db.rollback()
@@ -424,6 +435,10 @@ class TemplateService:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"解析模板配置失败"
)
except BaseAPIException:
raise
except HTTPException:
raise
except Exception as e:
logger.error(f"组装 payload 失败: {str(e)}")
raise HTTPException(
@@ -518,31 +533,32 @@ class TemplateService:
signature = payload.get("Signature", "Unknown")
task_name = f"{template.name} - {signature}"
# 创建任务(包含 cron_expression
try:
task = CheckInTask(
from backend.services.task_service import TaskService
task = TaskService.create_task(
user_id=user_id,
payload_config=json.dumps(payload, ensure_ascii=False),
name=task_name,
is_active=True,
cron_expression=cron_expression or "0 20 * * *",
task_data=TaskCreate(
payload_config=json.dumps(payload, ensure_ascii=False),
name=task_name,
is_active=True,
cron_expression=cron_expression or "0 20 * * *",
),
db=db,
)
db.add(task)
db.commit()
db.refresh(task)
logger.info(
f"从模板创建任务成功: {task.name} (ID: {task.id}, 模板: {template.name}, ThreadId: {thread_id})"
)
# 如果任务启用且包含 cron_expression,立即添加到调度器
if task.is_scheduled_enabled:
from backend.services.task_service import TaskService
TaskService._reload_scheduler_for_task(task, db)
return task
except BaseAPIException:
db.rollback()
raise
except HTTPException:
db.rollback()
raise
except Exception as e:
logger.error(f"从模板创建任务失败: {str(e)}")
db.rollback()
+24 -1
View File
@@ -92,6 +92,29 @@ def build_task_info(task) -> Dict[str, str]:
包含 thread_id 和 name 的字典
"""
return {
"thread_id": extract_thread_id(getattr(task, "payload_config", None)) or "未知",
"thread_id": resolve_task_thread_id(task) or "未知",
"name": getattr(task, "name", None) or f"Task-{getattr(task, 'id', 'Unknown')}",
}
def resolve_task_thread_id(task) -> Optional[str]:
"""
优先从显式字段读取任务 ThreadId,兼容旧数据时回退到 payload_config。
Args:
task: CheckInTask 或相似对象
Returns:
ThreadId 或 None
"""
thread_id = getattr(task, "thread_id", None)
if thread_id is not None:
value = str(thread_id).strip()
if value:
return value
legacy_thread_id = extract_thread_id(getattr(task, "payload_config", None))
if legacy_thread_id is None:
return None
value = str(legacy_thread_id).strip()
return value or None