feat(backend): harden task boundaries

This commit is contained in:
2026-05-05 00:55:29 +08:00
parent 817540f8a0
commit e243dccfd7
15 changed files with 694 additions and 147 deletions
+3 -3
View File
@@ -541,10 +541,10 @@ class CheckInService:
)
task_name = task.name
# payload_config 提取 ThreadId
from backend.utils.json_helpers import extract_thread_id
# 优先从显式任务身份读取 ThreadId,兼容旧数据时回退到 payload_config
from backend.utils.json_helpers import resolve_task_thread_id
thread_id = extract_thread_id(task.payload_config) # type: ignore
thread_id = resolve_task_thread_id(task)
# 转换为字典并添加额外字段
record_dict = {
+67 -26
View File
@@ -20,6 +20,70 @@ scheduler = None
scheduler_lock = None
def _task_job_id(task_id: int) -> str:
    """Build the scheduler job id for a check-in task.

    Args:
        task_id: Identifier of the task.

    Returns:
        The string ``task_<task_id>``.
    """
    return "task_{}".format(task_id)
def _remove_task_job(job_id: str, scheduler_instance=None) -> bool:
    """Remove a job from the scheduler when it is currently registered.

    Args:
        job_id: Scheduler job id to remove.
        scheduler_instance: Scheduler to operate on. When falsy, the
            module-level ``scheduler`` is used instead.

    Returns:
        True when the job was found and removed; False when no scheduler is
        available or the scheduler has no job with that id.
    """
    target = scheduler_instance or scheduler
    # Look the job up first so a missing job is reported as False rather
    # than being passed to remove_job().
    if target and target.get_job(job_id):
        target.remove_job(job_id)
        logger.info(f"✅ 已移除调度任务: {job_id}")
        return True
    return False
def remove_scheduled_task(task_id: int, scheduler_instance=None) -> bool:
    """Remove the scheduler job that belongs to the given task.

    Args:
        task_id: Identifier of the task whose job should be removed.
        scheduler_instance: Scheduler to operate on. When falsy, the
            module-level ``scheduler`` is used instead.

    Returns:
        The result of ``_remove_task_job`` for the task's job id, or False
        (after logging a warning) when no scheduler is available.
    """
    target = scheduler_instance or scheduler
    if target:
        return _remove_task_job(_task_job_id(task_id), target)

    logger.warning(f"调度器未启动,无法移除任务 {task_id}")
    return False
def sync_scheduled_task(task: CheckInTask, scheduler_instance=None) -> bool:
    """Synchronise the scheduler job for ``task`` with the task's current state.

    Any job already registered for the task is removed first. A new job is
    added only when ``task.is_scheduled_enabled`` is truthy and the cron
    expression can be turned into a trigger.

    Args:
        task: Task whose ``id``, ``is_scheduled_enabled`` and
            ``cron_expression`` are read.
        scheduler_instance: Scheduler to operate on. When falsy, the
            module-level ``scheduler`` is used instead.

    Returns:
        True when the task was added to the scheduler; False when the
        scheduler is unavailable, the task is not schedulable, or the cron
        expression is rejected.
    """
    active_scheduler = scheduler_instance or scheduler
    if not active_scheduler:
        logger.warning(f"调度器未启动,无法同步任务 {task.id}")
        return False

    job_id = _task_job_id(task.id)
    # Drop the old job up front so a disabled task or an unusable cron
    # expression never leaves a stale job firing.
    _remove_task_job(job_id, active_scheduler)

    if not task.is_scheduled_enabled:
        logger.info(f"任务 {task.id} 未启用或无 cron 表达式,已从调度器移除")
        return False

    cron_str = str(task.cron_expression or "").strip()
    if not cron_str or not croniter.is_valid(cron_str):
        logger.warning(f"任务 {task.id} 的 cron 表达式无效: {cron_str}")
        return False

    try:
        trigger = CronTrigger.from_crontab(cron_str)
    except ValueError:
        # croniter and CronTrigger.from_crontab do not accept exactly the
        # same syntax (from_crontab requires a 5-field expression), so a
        # string can pass the check above and still be rejected here.
        # Report it as "not scheduled" instead of raising to the caller.
        logger.warning(f"任务 {task.id} 的 cron 表达式无效: {cron_str}")
        return False

    active_scheduler.add_job(
        func=scheduled_check_in_task,
        trigger=trigger,
        id=job_id,
        name=f"CheckIn-Task-{task.id}",
        args=[task.id],
        replace_existing=True,
    )
    logger.info(f"✅ 任务 {task.id} 已同步到调度器: {cron_str}")
    return True
def load_scheduled_tasks(db: Session, scheduler_instance):
"""
从数据库加载所有启用的定时任务并添加到 APScheduler
@@ -55,33 +119,10 @@ def load_scheduled_tasks(db: Session, scheduler_instance):
for task in tasks:
try:
# 验证 cron 表达式
cron_str = str(task.cron_expression) if task.cron_expression else None
if not cron_str or not croniter.is_valid(cron_str):
logger.warning(f"跳过任务 {task.id}: 无效的 cron 表达式 '{task.cron_expression}'")
if sync_scheduled_task(task, scheduler_instance):
loaded_count += 1
else:
skipped_count += 1
continue
# 创建任务 ID
job_id = f"task_{task.id}"
# 检查任务是否已存在
if scheduler_instance.get_job(job_id):
logger.debug(f"任务 {task.id} 已存在,跳过")
continue
# 添加任务到调度器
scheduler_instance.add_job(
func=scheduled_check_in_task,
trigger=CronTrigger.from_crontab(cron_str),
id=job_id,
name=f"CheckIn-Task-{task.id}",
args=[task.id],
replace_existing=True,
)
logger.info(f"✅ 加载任务 {task.id}: {task.name} (Cron: {task.cron_expression})")
loaded_count += 1
except Exception as e:
logger.error(f"❌ 加载任务 {task.id} 时出错: {str(e)}")
+124 -86
View File
@@ -3,6 +3,12 @@ from typing import List, Optional, Dict, Any
from sqlalchemy.orm import Session
from sqlalchemy import desc
from backend.exceptions import (
InternalServerError,
ResourceConflictError,
ResourceNotFoundError,
ValidationError as APIValidationError,
)
from backend.models import User, CheckInTask, CheckInRecord
from backend.schemas.task import TaskCreate, TaskUpdate
@@ -12,6 +18,68 @@ logger = logging.getLogger(__name__)
class TaskService:
"""打卡任务服务"""
@staticmethod
def _normalize_thread_id(thread_id: Any) -> str:
value = str(thread_id).strip() if thread_id is not None else ""
if not value:
raise APIValidationError(
"payload_config 必须包含有效的 ThreadId 字段",
error_code="TASK_IDENTITY_INVALID",
)
return value
@staticmethod
def _extract_thread_id_from_payload(payload_config: str) -> str:
    """Read ``ThreadId`` from a payload config and normalize it.

    Args:
        payload_config: Payload configuration, parsed with
            ``safe_parse_payload``.

    Returns:
        The value produced by ``TaskService._normalize_thread_id`` for the
        payload's ``ThreadId`` entry.
    """
    from backend.utils.json_helpers import safe_parse_payload

    raw_thread_id = safe_parse_payload(payload_config).get("ThreadId")
    return TaskService._normalize_thread_id(raw_thread_id)
@staticmethod
def _resolve_thread_id_for_task(task: CheckInTask) -> Optional[str]:
thread_id = getattr(task, "thread_id", None)
if thread_id is not None:
value = str(thread_id).strip()
if value:
return value
from backend.utils.json_helpers import extract_thread_id
legacy_thread_id = extract_thread_id(task.payload_config)
if legacy_thread_id is None:
return None
value = str(legacy_thread_id).strip()
return value or None
@staticmethod
def _ensure_unique_thread_id(
    db: Session, user_id: int, thread_id: str, exclude_task_id: Optional[int] = None
) -> None:
    """Reject a thread id that another task of the same user already uses.

    Args:
        db: Database session used for the lookup.
        user_id: Owner whose tasks are checked.
        thread_id: Thread id that must be unique for this user.
        exclude_task_id: Task id left out of the check, so a task being
            updated does not conflict with itself.

    Raises:
        ResourceConflictError: When a matching task exists (error code
            ``TASK_IDENTITY_CONFLICT``).
    """
    # NOTE(review): this is a check-before-write; whether a database-level
    # unique constraint backs it up is not visible here - confirm.
    query = db.query(CheckInTask.id).filter(
        CheckInTask.user_id == user_id, CheckInTask.thread_id == thread_id
    )
    if exclude_task_id is not None:
        query = query.filter(CheckInTask.id != exclude_task_id)

    conflict = query.first()
    if conflict:
        raise ResourceConflictError(
            f"该接龙中已存在任务。ThreadId: {thread_id}",
            error_code="TASK_IDENTITY_CONFLICT",
        )
@staticmethod
def _sync_scheduler_for_task(task: CheckInTask) -> None:
    """Hand the task to the scheduler service's ``sync_scheduled_task``.

    The scheduler module is imported at call time; the value returned by
    ``sync_scheduled_task`` is discarded.

    Args:
        task: Task whose scheduler job should be synchronised.
    """
    from backend.services.scheduler_service import sync_scheduled_task as sync_job

    sync_job(task)
@staticmethod
def _remove_scheduler_for_task(task_id: int) -> None:
    """Ask the scheduler service to remove the job of the given task.

    The scheduler module is imported at call time; the value returned by
    ``remove_scheduled_task`` is discarded.

    Args:
        task_id: Identifier of the task whose job should be removed.
    """
    from backend.services.scheduler_service import remove_scheduled_task as remove_job

    remove_job(task_id)
@staticmethod
def create_task(user_id: int, task_data: TaskCreate, db: Session) -> CheckInTask:
"""
@@ -25,34 +93,16 @@ class TaskService:
Returns:
创建的任务对象
"""
import json
# 1. 检查用户是否存在
user = db.query(User).filter(User.id == user_id).first()
if not user:
raise ValueError(f"用户 ID {user_id} 不存在")
raise ResourceNotFoundError(f"用户 ID {user_id} 不存在", error_code="USER_NOT_FOUND")
# 2. 从 payload_config 中提取 ThreadId 用于唯一性校验
from backend.utils.json_helpers import safe_parse_payload, extract_thread_id
payload = safe_parse_payload(task_data.payload_config)
thread_id = payload.get("ThreadId")
if not thread_id:
raise ValueError("payload_config 中缺少 ThreadId")
thread_id = TaskService._extract_thread_id_from_payload(task_data.payload_config)
# 3. 验证唯一性:同一用户在同一个接龙中不能有重复的任务
existing_tasks = (
db.query(CheckInTask.payload_config).filter(CheckInTask.user_id == user_id).all()
)
for (payload_config,) in existing_tasks:
existing_thread_id = extract_thread_id(payload_config)
# extract_thread_id 已处理异常,失败时返回 None
if existing_thread_id and existing_thread_id == thread_id:
logger.warning(
f"⚠️ 任务创建冲突 - User: {user.alias}({user_id}), ThreadId: {thread_id}"
)
raise ValueError(f"该接龙中已存在任务。ThreadId: {thread_id}")
TaskService._ensure_unique_thread_id(db, user_id, thread_id)
# 4. 记录日志
task_name = task_data.name or f"接龙任务 {thread_id}"
@@ -61,9 +111,11 @@ class TaskService:
# 5. 创建任务
task = CheckInTask(
user_id=user_id,
thread_id=thread_id,
payload_config=task_data.payload_config,
name=task_data.name or task_name,
is_active=task_data.is_active if task_data.is_active is not None else True,
cron_expression=task_data.cron_expression or "0 20 * * *",
)
try:
@@ -75,14 +127,19 @@ class TaskService:
)
# 如果任务启用且包含 cron_expression,立即添加到调度器
if task.is_scheduled_enabled:
TaskService._reload_scheduler_for_task(task, db)
TaskService._sync_scheduler_for_task(task)
return task
except APIValidationError:
db.rollback()
raise
except ResourceConflictError:
db.rollback()
raise
except Exception as e:
db.rollback()
logger.error(f"❌ 任务创建失败: {str(e)}")
raise ValueError(f"任务创建失败: {str(e)}")
raise InternalServerError(f"任务创建失败: {str(e)}")
@staticmethod
def get_task(task_id: int, db: Session) -> Optional[CheckInTask]:
@@ -110,8 +167,6 @@ class TaskService:
Returns:
包含额外信息的任务字典
"""
from backend.utils.json_helpers import extract_thread_id
# 获取最后一次打卡记录
last_record = (
db.query(CheckInRecord)
@@ -120,8 +175,8 @@ class TaskService:
.first()
)
# payload_config 提取 ThreadId
thread_id = extract_thread_id(task.payload_config) # type: ignore
# 优先使用持久化的 ThreadId,兼容旧数据时回退到 payload_config
thread_id = TaskService._resolve_thread_id_for_task(task)
# 转换为字典并添加额外字段
task_dict = {
@@ -192,7 +247,7 @@ class TaskService:
task = db.query(CheckInTask).filter(CheckInTask.id == task_id).first()
if not task:
return None
raise ResourceNotFoundError("任务不存在", error_code="TASK_NOT_FOUND")
# 更新字段
update_data = task_data.model_dump(exclude_unset=True)
@@ -201,17 +256,34 @@ class TaskService:
cron_changed = "cron_expression" in update_data
active_changed = "is_active" in update_data
if "payload_config" in update_data:
new_thread_id = TaskService._extract_thread_id_from_payload(
update_data["payload_config"]
)
TaskService._ensure_unique_thread_id(
db, task.user_id, new_thread_id, exclude_task_id=task.id
)
task.thread_id = new_thread_id
for field, value in update_data.items():
setattr(task, field, value)
db.commit()
db.refresh(task)
try:
db.commit()
db.refresh(task)
except ResourceConflictError:
db.rollback()
raise
except Exception as e:
db.rollback()
logger.error(f"任务 {task_id} 更新失败: {str(e)}")
raise InternalServerError(f"任务更新失败: {str(e)}")
logger.info(f"任务 {task_id} 已更新")
# 如果 cron_expression 或 is_active 发生变化,重新加载调度器
if cron_changed or active_changed:
TaskService._reload_scheduler_for_task(task, db)
TaskService._sync_scheduler_for_task(task)
return task
@@ -230,15 +302,20 @@ class TaskService:
task = db.query(CheckInTask).filter(CheckInTask.id == task_id).first()
if not task:
return False
raise ResourceNotFoundError("任务不存在", error_code="TASK_NOT_FOUND")
db.delete(task)
db.commit()
try:
db.delete(task)
db.commit()
except Exception as e:
db.rollback()
logger.error(f"任务 {task_id} 删除失败: {str(e)}")
raise InternalServerError(f"任务删除失败: {str(e)}")
logger.info(f"任务 {task_id} 已删除")
# 从调度器中移除该任务
TaskService._remove_task_from_scheduler(task_id)
TaskService._remove_scheduler_for_task(task_id)
return True
@@ -257,16 +334,21 @@ class TaskService:
task = db.query(CheckInTask).filter(CheckInTask.id == task_id).first()
if not task:
return None
raise ResourceNotFoundError("任务不存在", error_code="TASK_NOT_FOUND")
task.is_active = not task.is_active
db.commit()
db.refresh(task)
try:
db.commit()
db.refresh(task)
except Exception as e:
db.rollback()
logger.error(f"任务 {task_id} 状态切换失败: {str(e)}")
raise InternalServerError(f"任务状态切换失败: {str(e)}")
logger.info(f"任务 {task_id} 状态已切换为: {'启用' if task.is_active else '禁用'}")
# 重新加载调度器
TaskService._reload_scheduler_for_task(task, db)
TaskService._sync_scheduler_for_task(task)
return task
@@ -322,42 +404,7 @@ class TaskService:
db: 数据库会话
"""
try:
from backend.services.scheduler_service import scheduler
from apscheduler.triggers.cron import CronTrigger
from croniter import croniter
if not scheduler:
logger.warning(f"调度器未启动,无法加载任务 {task.id}")
return
job_id = f"task_{task.id}"
# 先移除旧的任务(如果存在)
existing_job = scheduler.get_job(job_id)
if existing_job:
scheduler.remove_job(job_id)
logger.info(f"从调度器移除旧任务: {job_id}")
# 如果任务启用且有有效的 cron 表达式,添加新任务
if task.is_scheduled_enabled:
cron_str = str(task.cron_expression)
if croniter.is_valid(cron_str):
from backend.services.scheduler_service import scheduled_check_in_task
scheduler.add_job(
func=scheduled_check_in_task,
trigger=CronTrigger.from_crontab(cron_str),
id=job_id,
name=f"CheckIn-Task-{task.id}",
args=[task.id],
replace_existing=True,
)
logger.info(f"✅ 任务 {task.id} 已重新加载到调度器: {cron_str}")
else:
logger.warning(f"任务 {task.id} 的 cron 表达式无效: {cron_str}")
else:
logger.info(f"任务 {task.id} 未启用或无 cron 表达式,已从调度器移除")
TaskService._sync_scheduler_for_task(task)
except Exception as e:
logger.error(f"重新加载任务 {task.id} 到调度器失败: {str(e)}")
@@ -370,15 +417,6 @@ class TaskService:
task_id: 任务 ID
"""
try:
from backend.services.scheduler_service import scheduler
if not scheduler:
return
job_id = f"task_{task_id}"
if scheduler.get_job(job_id):
scheduler.remove_job(job_id)
logger.info(f"✅ 任务 {task_id} 已从调度器移除")
TaskService._remove_scheduler_for_task(task_id)
except Exception as e:
logger.error(f"从调度器移除任务 {task_id} 失败: {str(e)}")
+32 -16
View File
@@ -4,7 +4,9 @@ from typing import List, Dict, Any, Optional
from sqlalchemy.orm import Session
from fastapi import HTTPException, status
from backend.models import TaskTemplate, CheckInTask
from backend.exceptions import BaseAPIException
from backend.models import CheckInTask, TaskTemplate
from backend.schemas.task import TaskCreate
from backend.schemas.template import TemplateCreate, TemplateUpdate
logger = logging.getLogger(__name__)
@@ -134,6 +136,9 @@ class TemplateService:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail=f"字段配置 JSON 格式错误: {str(e)}"
)
except (BaseAPIException, HTTPException):
db.rollback()
raise
except Exception as e:
logger.error(f"创建模板失败: {str(e)}")
db.rollback()
@@ -219,6 +224,9 @@ class TemplateService:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail=f"字段配置 JSON 格式错误: {str(e)}"
)
except (BaseAPIException, HTTPException):
db.rollback()
raise
except Exception as e:
logger.error(f"更新模板失败: {str(e)}")
db.rollback()
@@ -247,6 +255,9 @@ class TemplateService:
db.commit()
logger.info(f"删除模板成功: {template.name} (ID: {template_id})")
return True
except (BaseAPIException, HTTPException):
db.rollback()
raise
except Exception as e:
logger.error(f"删除模板失败: {str(e)}")
db.rollback()
@@ -424,6 +435,10 @@ class TemplateService:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"解析模板配置失败"
)
except BaseAPIException:
raise
except HTTPException:
raise
except Exception as e:
logger.error(f"组装 payload 失败: {str(e)}")
raise HTTPException(
@@ -518,31 +533,32 @@ class TemplateService:
signature = payload.get("Signature", "Unknown")
task_name = f"{template.name} - {signature}"
# 创建任务(包含 cron_expression
try:
task = CheckInTask(
from backend.services.task_service import TaskService
task = TaskService.create_task(
user_id=user_id,
payload_config=json.dumps(payload, ensure_ascii=False),
name=task_name,
is_active=True,
cron_expression=cron_expression or "0 20 * * *",
task_data=TaskCreate(
payload_config=json.dumps(payload, ensure_ascii=False),
name=task_name,
is_active=True,
cron_expression=cron_expression or "0 20 * * *",
),
db=db,
)
db.add(task)
db.commit()
db.refresh(task)
logger.info(
f"从模板创建任务成功: {task.name} (ID: {task.id}, 模板: {template.name}, ThreadId: {thread_id})"
)
# 如果任务启用且包含 cron_expression,立即添加到调度器
if task.is_scheduled_enabled:
from backend.services.task_service import TaskService
TaskService._reload_scheduler_for_task(task, db)
return task
except BaseAPIException:
db.rollback()
raise
except HTTPException:
db.rollback()
raise
except Exception as e:
logger.error(f"从模板创建任务失败: {str(e)}")
db.rollback()