mirror of
https://github.com/Cccc-owo/CheckInApp.git
synced 2026-06-17 05:56:29 +00:00
feat(backend): harden task boundaries
This commit is contained in:
@@ -541,10 +541,10 @@ class CheckInService:
|
||||
)
|
||||
task_name = task.name
|
||||
|
||||
# 从 payload_config 提取 ThreadId
|
||||
from backend.utils.json_helpers import extract_thread_id
|
||||
# 优先从显式任务身份读取 ThreadId,兼容旧数据时回退到 payload_config
|
||||
from backend.utils.json_helpers import resolve_task_thread_id
|
||||
|
||||
thread_id = extract_thread_id(task.payload_config) # type: ignore
|
||||
thread_id = resolve_task_thread_id(task)
|
||||
|
||||
# 转换为字典并添加额外字段
|
||||
record_dict = {
|
||||
|
||||
@@ -20,6 +20,70 @@ scheduler = None
|
||||
scheduler_lock = None
|
||||
|
||||
|
||||
def _task_job_id(task_id: int) -> str:
|
||||
return f"task_{task_id}"
|
||||
|
||||
|
||||
def _remove_task_job(job_id: str, scheduler_instance=None) -> bool:
|
||||
active_scheduler = scheduler_instance or scheduler
|
||||
if not active_scheduler:
|
||||
return False
|
||||
|
||||
existing_job = active_scheduler.get_job(job_id)
|
||||
if not existing_job:
|
||||
return False
|
||||
|
||||
active_scheduler.remove_job(job_id)
|
||||
logger.info(f"✅ 已移除调度任务: {job_id}")
|
||||
return True
|
||||
|
||||
|
||||
def remove_scheduled_task(task_id: int, scheduler_instance=None) -> bool:
|
||||
"""从调度器移除指定任务。"""
|
||||
active_scheduler = scheduler_instance or scheduler
|
||||
if not active_scheduler:
|
||||
logger.warning(f"调度器未启动,无法移除任务 {task_id}")
|
||||
return False
|
||||
|
||||
return _remove_task_job(_task_job_id(task_id), active_scheduler)
|
||||
|
||||
|
||||
def sync_scheduled_task(task: CheckInTask, scheduler_instance=None) -> bool:
|
||||
"""
|
||||
根据任务当前状态同步调度器中的对应 job。
|
||||
|
||||
返回 True 表示任务已成功安排到调度器,False 表示未安排或调度器不可用。
|
||||
"""
|
||||
active_scheduler = scheduler_instance or scheduler
|
||||
if not active_scheduler:
|
||||
logger.warning(f"调度器未启动,无法同步任务 {task.id}")
|
||||
return False
|
||||
|
||||
job_id = _task_job_id(task.id)
|
||||
_remove_task_job(job_id, active_scheduler)
|
||||
|
||||
if not task.is_scheduled_enabled:
|
||||
logger.info(f"任务 {task.id} 未启用或无 cron 表达式,已从调度器移除")
|
||||
return False
|
||||
|
||||
cron_str = str(task.cron_expression or "").strip()
|
||||
if not cron_str or not croniter.is_valid(cron_str):
|
||||
logger.warning(f"任务 {task.id} 的 cron 表达式无效: {cron_str}")
|
||||
return False
|
||||
|
||||
active_scheduler.add_job(
|
||||
func=scheduled_check_in_task,
|
||||
trigger=CronTrigger.from_crontab(cron_str),
|
||||
id=job_id,
|
||||
name=f"CheckIn-Task-{task.id}",
|
||||
args=[task.id],
|
||||
replace_existing=True,
|
||||
)
|
||||
|
||||
logger.info(f"✅ 任务 {task.id} 已同步到调度器: {cron_str}")
|
||||
return True
|
||||
|
||||
|
||||
def load_scheduled_tasks(db: Session, scheduler_instance):
|
||||
"""
|
||||
从数据库加载所有启用的定时任务并添加到 APScheduler
|
||||
@@ -55,33 +119,10 @@ def load_scheduled_tasks(db: Session, scheduler_instance):
|
||||
|
||||
for task in tasks:
|
||||
try:
|
||||
# 验证 cron 表达式
|
||||
cron_str = str(task.cron_expression) if task.cron_expression else None
|
||||
if not cron_str or not croniter.is_valid(cron_str):
|
||||
logger.warning(f"跳过任务 {task.id}: 无效的 cron 表达式 '{task.cron_expression}'")
|
||||
if sync_scheduled_task(task, scheduler_instance):
|
||||
loaded_count += 1
|
||||
else:
|
||||
skipped_count += 1
|
||||
continue
|
||||
|
||||
# 创建任务 ID
|
||||
job_id = f"task_{task.id}"
|
||||
|
||||
# 检查任务是否已存在
|
||||
if scheduler_instance.get_job(job_id):
|
||||
logger.debug(f"任务 {task.id} 已存在,跳过")
|
||||
continue
|
||||
|
||||
# 添加任务到调度器
|
||||
scheduler_instance.add_job(
|
||||
func=scheduled_check_in_task,
|
||||
trigger=CronTrigger.from_crontab(cron_str),
|
||||
id=job_id,
|
||||
name=f"CheckIn-Task-{task.id}",
|
||||
args=[task.id],
|
||||
replace_existing=True,
|
||||
)
|
||||
|
||||
logger.info(f"✅ 加载任务 {task.id}: {task.name} (Cron: {task.cron_expression})")
|
||||
loaded_count += 1
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 加载任务 {task.id} 时出错: {str(e)}")
|
||||
|
||||
@@ -3,6 +3,12 @@ from typing import List, Optional, Dict, Any
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import desc
|
||||
|
||||
from backend.exceptions import (
|
||||
InternalServerError,
|
||||
ResourceConflictError,
|
||||
ResourceNotFoundError,
|
||||
ValidationError as APIValidationError,
|
||||
)
|
||||
from backend.models import User, CheckInTask, CheckInRecord
|
||||
from backend.schemas.task import TaskCreate, TaskUpdate
|
||||
|
||||
@@ -12,6 +18,68 @@ logger = logging.getLogger(__name__)
|
||||
class TaskService:
|
||||
"""打卡任务服务"""
|
||||
|
||||
@staticmethod
|
||||
def _normalize_thread_id(thread_id: Any) -> str:
|
||||
value = str(thread_id).strip() if thread_id is not None else ""
|
||||
if not value:
|
||||
raise APIValidationError(
|
||||
"payload_config 必须包含有效的 ThreadId 字段",
|
||||
error_code="TASK_IDENTITY_INVALID",
|
||||
)
|
||||
return value
|
||||
|
||||
@staticmethod
|
||||
def _extract_thread_id_from_payload(payload_config: str) -> str:
|
||||
from backend.utils.json_helpers import safe_parse_payload
|
||||
|
||||
payload = safe_parse_payload(payload_config)
|
||||
return TaskService._normalize_thread_id(payload.get("ThreadId"))
|
||||
|
||||
@staticmethod
|
||||
def _resolve_thread_id_for_task(task: CheckInTask) -> Optional[str]:
|
||||
thread_id = getattr(task, "thread_id", None)
|
||||
if thread_id is not None:
|
||||
value = str(thread_id).strip()
|
||||
if value:
|
||||
return value
|
||||
|
||||
from backend.utils.json_helpers import extract_thread_id
|
||||
|
||||
legacy_thread_id = extract_thread_id(task.payload_config)
|
||||
if legacy_thread_id is None:
|
||||
return None
|
||||
value = str(legacy_thread_id).strip()
|
||||
return value or None
|
||||
|
||||
@staticmethod
|
||||
def _ensure_unique_thread_id(
|
||||
db: Session, user_id: int, thread_id: str, exclude_task_id: int | None = None
|
||||
) -> None:
|
||||
query = db.query(CheckInTask.id).filter(
|
||||
CheckInTask.user_id == user_id, CheckInTask.thread_id == thread_id
|
||||
)
|
||||
if exclude_task_id is not None:
|
||||
query = query.filter(CheckInTask.id != exclude_task_id)
|
||||
|
||||
conflict = query.first()
|
||||
if conflict:
|
||||
raise ResourceConflictError(
|
||||
f"该接龙中已存在任务。ThreadId: {thread_id}",
|
||||
error_code="TASK_IDENTITY_CONFLICT",
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _sync_scheduler_for_task(task: CheckInTask) -> None:
|
||||
from backend.services.scheduler_service import sync_scheduled_task
|
||||
|
||||
sync_scheduled_task(task)
|
||||
|
||||
@staticmethod
|
||||
def _remove_scheduler_for_task(task_id: int) -> None:
|
||||
from backend.services.scheduler_service import remove_scheduled_task
|
||||
|
||||
remove_scheduled_task(task_id)
|
||||
|
||||
@staticmethod
|
||||
def create_task(user_id: int, task_data: TaskCreate, db: Session) -> CheckInTask:
|
||||
"""
|
||||
@@ -25,34 +93,16 @@ class TaskService:
|
||||
Returns:
|
||||
创建的任务对象
|
||||
"""
|
||||
import json
|
||||
|
||||
# 1. 检查用户是否存在
|
||||
user = db.query(User).filter(User.id == user_id).first()
|
||||
if not user:
|
||||
raise ValueError(f"用户 ID {user_id} 不存在")
|
||||
raise ResourceNotFoundError(f"用户 ID {user_id} 不存在", error_code="USER_NOT_FOUND")
|
||||
|
||||
# 2. 从 payload_config 中提取 ThreadId 用于唯一性校验
|
||||
from backend.utils.json_helpers import safe_parse_payload, extract_thread_id
|
||||
|
||||
payload = safe_parse_payload(task_data.payload_config)
|
||||
thread_id = payload.get("ThreadId")
|
||||
if not thread_id:
|
||||
raise ValueError("payload_config 中缺少 ThreadId")
|
||||
thread_id = TaskService._extract_thread_id_from_payload(task_data.payload_config)
|
||||
|
||||
# 3. 验证唯一性:同一用户在同一个接龙中不能有重复的任务
|
||||
existing_tasks = (
|
||||
db.query(CheckInTask.payload_config).filter(CheckInTask.user_id == user_id).all()
|
||||
)
|
||||
|
||||
for (payload_config,) in existing_tasks:
|
||||
existing_thread_id = extract_thread_id(payload_config)
|
||||
# extract_thread_id 已处理异常,失败时返回 None
|
||||
if existing_thread_id and existing_thread_id == thread_id:
|
||||
logger.warning(
|
||||
f"⚠️ 任务创建冲突 - User: {user.alias}({user_id}), ThreadId: {thread_id}"
|
||||
)
|
||||
raise ValueError(f"该接龙中已存在任务。ThreadId: {thread_id}")
|
||||
TaskService._ensure_unique_thread_id(db, user_id, thread_id)
|
||||
|
||||
# 4. 记录日志
|
||||
task_name = task_data.name or f"接龙任务 {thread_id}"
|
||||
@@ -61,9 +111,11 @@ class TaskService:
|
||||
# 5. 创建任务
|
||||
task = CheckInTask(
|
||||
user_id=user_id,
|
||||
thread_id=thread_id,
|
||||
payload_config=task_data.payload_config,
|
||||
name=task_data.name or task_name,
|
||||
is_active=task_data.is_active if task_data.is_active is not None else True,
|
||||
cron_expression=task_data.cron_expression or "0 20 * * *",
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -75,14 +127,19 @@ class TaskService:
|
||||
)
|
||||
|
||||
# 如果任务启用且包含 cron_expression,立即添加到调度器
|
||||
if task.is_scheduled_enabled:
|
||||
TaskService._reload_scheduler_for_task(task, db)
|
||||
TaskService._sync_scheduler_for_task(task)
|
||||
|
||||
return task
|
||||
except APIValidationError:
|
||||
db.rollback()
|
||||
raise
|
||||
except ResourceConflictError:
|
||||
db.rollback()
|
||||
raise
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
logger.error(f"❌ 任务创建失败: {str(e)}")
|
||||
raise ValueError(f"任务创建失败: {str(e)}")
|
||||
raise InternalServerError(f"任务创建失败: {str(e)}")
|
||||
|
||||
@staticmethod
|
||||
def get_task(task_id: int, db: Session) -> Optional[CheckInTask]:
|
||||
@@ -110,8 +167,6 @@ class TaskService:
|
||||
Returns:
|
||||
包含额外信息的任务字典
|
||||
"""
|
||||
from backend.utils.json_helpers import extract_thread_id
|
||||
|
||||
# 获取最后一次打卡记录
|
||||
last_record = (
|
||||
db.query(CheckInRecord)
|
||||
@@ -120,8 +175,8 @@ class TaskService:
|
||||
.first()
|
||||
)
|
||||
|
||||
# 从 payload_config 提取 ThreadId
|
||||
thread_id = extract_thread_id(task.payload_config) # type: ignore
|
||||
# 优先使用持久化的 ThreadId,兼容旧数据时回退到 payload_config
|
||||
thread_id = TaskService._resolve_thread_id_for_task(task)
|
||||
|
||||
# 转换为字典并添加额外字段
|
||||
task_dict = {
|
||||
@@ -192,7 +247,7 @@ class TaskService:
|
||||
task = db.query(CheckInTask).filter(CheckInTask.id == task_id).first()
|
||||
|
||||
if not task:
|
||||
return None
|
||||
raise ResourceNotFoundError("任务不存在", error_code="TASK_NOT_FOUND")
|
||||
|
||||
# 更新字段
|
||||
update_data = task_data.model_dump(exclude_unset=True)
|
||||
@@ -201,17 +256,34 @@ class TaskService:
|
||||
cron_changed = "cron_expression" in update_data
|
||||
active_changed = "is_active" in update_data
|
||||
|
||||
if "payload_config" in update_data:
|
||||
new_thread_id = TaskService._extract_thread_id_from_payload(
|
||||
update_data["payload_config"]
|
||||
)
|
||||
TaskService._ensure_unique_thread_id(
|
||||
db, task.user_id, new_thread_id, exclude_task_id=task.id
|
||||
)
|
||||
task.thread_id = new_thread_id
|
||||
|
||||
for field, value in update_data.items():
|
||||
setattr(task, field, value)
|
||||
|
||||
db.commit()
|
||||
db.refresh(task)
|
||||
try:
|
||||
db.commit()
|
||||
db.refresh(task)
|
||||
except ResourceConflictError:
|
||||
db.rollback()
|
||||
raise
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
logger.error(f"任务 {task_id} 更新失败: {str(e)}")
|
||||
raise InternalServerError(f"任务更新失败: {str(e)}")
|
||||
|
||||
logger.info(f"任务 {task_id} 已更新")
|
||||
|
||||
# 如果 cron_expression 或 is_active 发生变化,重新加载调度器
|
||||
if cron_changed or active_changed:
|
||||
TaskService._reload_scheduler_for_task(task, db)
|
||||
TaskService._sync_scheduler_for_task(task)
|
||||
|
||||
return task
|
||||
|
||||
@@ -230,15 +302,20 @@ class TaskService:
|
||||
task = db.query(CheckInTask).filter(CheckInTask.id == task_id).first()
|
||||
|
||||
if not task:
|
||||
return False
|
||||
raise ResourceNotFoundError("任务不存在", error_code="TASK_NOT_FOUND")
|
||||
|
||||
db.delete(task)
|
||||
db.commit()
|
||||
try:
|
||||
db.delete(task)
|
||||
db.commit()
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
logger.error(f"任务 {task_id} 删除失败: {str(e)}")
|
||||
raise InternalServerError(f"任务删除失败: {str(e)}")
|
||||
|
||||
logger.info(f"任务 {task_id} 已删除")
|
||||
|
||||
# 从调度器中移除该任务
|
||||
TaskService._remove_task_from_scheduler(task_id)
|
||||
TaskService._remove_scheduler_for_task(task_id)
|
||||
|
||||
return True
|
||||
|
||||
@@ -257,16 +334,21 @@ class TaskService:
|
||||
task = db.query(CheckInTask).filter(CheckInTask.id == task_id).first()
|
||||
|
||||
if not task:
|
||||
return None
|
||||
raise ResourceNotFoundError("任务不存在", error_code="TASK_NOT_FOUND")
|
||||
|
||||
task.is_active = not task.is_active
|
||||
db.commit()
|
||||
db.refresh(task)
|
||||
try:
|
||||
db.commit()
|
||||
db.refresh(task)
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
logger.error(f"任务 {task_id} 状态切换失败: {str(e)}")
|
||||
raise InternalServerError(f"任务状态切换失败: {str(e)}")
|
||||
|
||||
logger.info(f"任务 {task_id} 状态已切换为: {'启用' if task.is_active else '禁用'}")
|
||||
|
||||
# 重新加载调度器
|
||||
TaskService._reload_scheduler_for_task(task, db)
|
||||
TaskService._sync_scheduler_for_task(task)
|
||||
|
||||
return task
|
||||
|
||||
@@ -322,42 +404,7 @@ class TaskService:
|
||||
db: 数据库会话
|
||||
"""
|
||||
try:
|
||||
from backend.services.scheduler_service import scheduler
|
||||
from apscheduler.triggers.cron import CronTrigger
|
||||
from croniter import croniter
|
||||
|
||||
if not scheduler:
|
||||
logger.warning(f"调度器未启动,无法加载任务 {task.id}")
|
||||
return
|
||||
|
||||
job_id = f"task_{task.id}"
|
||||
|
||||
# 先移除旧的任务(如果存在)
|
||||
existing_job = scheduler.get_job(job_id)
|
||||
if existing_job:
|
||||
scheduler.remove_job(job_id)
|
||||
logger.info(f"从调度器移除旧任务: {job_id}")
|
||||
|
||||
# 如果任务启用且有有效的 cron 表达式,添加新任务
|
||||
if task.is_scheduled_enabled:
|
||||
cron_str = str(task.cron_expression)
|
||||
if croniter.is_valid(cron_str):
|
||||
from backend.services.scheduler_service import scheduled_check_in_task
|
||||
|
||||
scheduler.add_job(
|
||||
func=scheduled_check_in_task,
|
||||
trigger=CronTrigger.from_crontab(cron_str),
|
||||
id=job_id,
|
||||
name=f"CheckIn-Task-{task.id}",
|
||||
args=[task.id],
|
||||
replace_existing=True,
|
||||
)
|
||||
logger.info(f"✅ 任务 {task.id} 已重新加载到调度器: {cron_str}")
|
||||
else:
|
||||
logger.warning(f"任务 {task.id} 的 cron 表达式无效: {cron_str}")
|
||||
else:
|
||||
logger.info(f"任务 {task.id} 未启用或无 cron 表达式,已从调度器移除")
|
||||
|
||||
TaskService._sync_scheduler_for_task(task)
|
||||
except Exception as e:
|
||||
logger.error(f"重新加载任务 {task.id} 到调度器失败: {str(e)}")
|
||||
|
||||
@@ -370,15 +417,6 @@ class TaskService:
|
||||
task_id: 任务 ID
|
||||
"""
|
||||
try:
|
||||
from backend.services.scheduler_service import scheduler
|
||||
|
||||
if not scheduler:
|
||||
return
|
||||
|
||||
job_id = f"task_{task_id}"
|
||||
if scheduler.get_job(job_id):
|
||||
scheduler.remove_job(job_id)
|
||||
logger.info(f"✅ 任务 {task_id} 已从调度器移除")
|
||||
|
||||
TaskService._remove_scheduler_for_task(task_id)
|
||||
except Exception as e:
|
||||
logger.error(f"从调度器移除任务 {task_id} 失败: {str(e)}")
|
||||
|
||||
@@ -4,7 +4,9 @@ from typing import List, Dict, Any, Optional
|
||||
from sqlalchemy.orm import Session
|
||||
from fastapi import HTTPException, status
|
||||
|
||||
from backend.models import TaskTemplate, CheckInTask
|
||||
from backend.exceptions import BaseAPIException
|
||||
from backend.models import CheckInTask, TaskTemplate
|
||||
from backend.schemas.task import TaskCreate
|
||||
from backend.schemas.template import TemplateCreate, TemplateUpdate
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -134,6 +136,9 @@ class TemplateService:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST, detail=f"字段配置 JSON 格式错误: {str(e)}"
|
||||
)
|
||||
except (BaseAPIException, HTTPException):
|
||||
db.rollback()
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"创建模板失败: {str(e)}")
|
||||
db.rollback()
|
||||
@@ -219,6 +224,9 @@ class TemplateService:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST, detail=f"字段配置 JSON 格式错误: {str(e)}"
|
||||
)
|
||||
except (BaseAPIException, HTTPException):
|
||||
db.rollback()
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"更新模板失败: {str(e)}")
|
||||
db.rollback()
|
||||
@@ -247,6 +255,9 @@ class TemplateService:
|
||||
db.commit()
|
||||
logger.info(f"删除模板成功: {template.name} (ID: {template_id})")
|
||||
return True
|
||||
except (BaseAPIException, HTTPException):
|
||||
db.rollback()
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"删除模板失败: {str(e)}")
|
||||
db.rollback()
|
||||
@@ -424,6 +435,10 @@ class TemplateService:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"解析模板配置失败"
|
||||
)
|
||||
except BaseAPIException:
|
||||
raise
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"组装 payload 失败: {str(e)}")
|
||||
raise HTTPException(
|
||||
@@ -518,31 +533,32 @@ class TemplateService:
|
||||
signature = payload.get("Signature", "Unknown")
|
||||
task_name = f"{template.name} - {signature}"
|
||||
|
||||
# 创建任务(包含 cron_expression)
|
||||
try:
|
||||
task = CheckInTask(
|
||||
from backend.services.task_service import TaskService
|
||||
|
||||
task = TaskService.create_task(
|
||||
user_id=user_id,
|
||||
payload_config=json.dumps(payload, ensure_ascii=False),
|
||||
name=task_name,
|
||||
is_active=True,
|
||||
cron_expression=cron_expression or "0 20 * * *",
|
||||
task_data=TaskCreate(
|
||||
payload_config=json.dumps(payload, ensure_ascii=False),
|
||||
name=task_name,
|
||||
is_active=True,
|
||||
cron_expression=cron_expression or "0 20 * * *",
|
||||
),
|
||||
db=db,
|
||||
)
|
||||
db.add(task)
|
||||
db.commit()
|
||||
db.refresh(task)
|
||||
|
||||
logger.info(
|
||||
f"从模板创建任务成功: {task.name} (ID: {task.id}, 模板: {template.name}, ThreadId: {thread_id})"
|
||||
)
|
||||
|
||||
# 如果任务启用且包含 cron_expression,立即添加到调度器
|
||||
if task.is_scheduled_enabled:
|
||||
from backend.services.task_service import TaskService
|
||||
|
||||
TaskService._reload_scheduler_for_task(task, db)
|
||||
|
||||
return task
|
||||
|
||||
except BaseAPIException:
|
||||
db.rollback()
|
||||
raise
|
||||
except HTTPException:
|
||||
db.rollback()
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"从模板创建任务失败: {str(e)}")
|
||||
db.rollback()
|
||||
|
||||
Reference in New Issue
Block a user