feat(backend): harden task boundaries

This commit is contained in:
2026-05-05 00:55:29 +08:00
parent 817540f8a0
commit e243dccfd7
15 changed files with 694 additions and 147 deletions
+18 -2
View File
@@ -11,9 +11,11 @@ from backend.services.check_in_service import CheckInService
from backend.services.admin_service import AdminService from backend.services.admin_service import AdminService
from backend.dependencies import get_current_admin_user from backend.dependencies import get_current_admin_user
from backend.config import settings from backend.config import settings
from backend.exceptions import BaseAPIException
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
router = APIRouter() router = APIRouter()
EXPECTED_API_EXCEPTIONS = (BaseAPIException, HTTPException)
class BatchToggleTasksRequest(BaseModel): class BatchToggleTasksRequest(BaseModel):
@@ -43,13 +45,21 @@ async def batch_toggle_tasks(
task.is_active = request.is_active task.is_active = request.is_active
count += 1 count += 1
from backend.services.scheduler_service import sync_scheduled_task
db.commit() db.commit()
for task_id in request.task_ids:
task = db.query(CheckInTask).filter(CheckInTask.id == task_id).first()
if task:
sync_scheduled_task(task)
return { return {
"success": True, "success": True,
"message": f"{'启用' if request.is_active else '禁用'} {count} 个任务", "message": f"{'启用' if request.is_active else '禁用'} {count} 个任务",
"count": count, "count": count,
} }
except EXPECTED_API_EXCEPTIONS:
raise
except Exception as e: except Exception as e:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"批量操作失败: {str(e)}" status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"批量操作失败: {str(e)}"
@@ -72,6 +82,8 @@ async def batch_check_in(
try: try:
result = CheckInService.batch_check_in_tasks(request.task_ids, db) result = CheckInService.batch_check_in_tasks(request.task_ids, db)
return result return result
except EXPECTED_API_EXCEPTIONS:
raise
except Exception as e: except Exception as e:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"批量打卡失败: {str(e)}" status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"批量打卡失败: {str(e)}"
@@ -235,6 +247,8 @@ async def get_system_stats(
"tokens": {"expiring_soon": expiring_users}, "tokens": {"expiring_soon": expiring_users},
} }
except EXPECTED_API_EXCEPTIONS:
raise
except Exception as e: except Exception as e:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"获取统计失败: {str(e)}" status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"获取统计失败: {str(e)}"
@@ -251,6 +265,8 @@ async def get_pending_users(
try: try:
users = AdminService.get_pending_users(db) users = AdminService.get_pending_users(db)
return users return users
except EXPECTED_API_EXCEPTIONS:
raise
except Exception as e: except Exception as e:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
@@ -274,7 +290,7 @@ async def approve_user(
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=result["message"]) raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=result["message"])
return result return result
except HTTPException: except EXPECTED_API_EXCEPTIONS:
raise raise
except Exception as e: except Exception as e:
raise HTTPException( raise HTTPException(
@@ -298,7 +314,7 @@ async def reject_user(
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=result["message"]) raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=result["message"])
return result return result
except HTTPException: except EXPECTED_API_EXCEPTIONS:
raise raise
except Exception as e: except Exception as e:
raise HTTPException( raise HTTPException(
+12 -1
View File
@@ -12,10 +12,11 @@ from backend.schemas.auth import (
AliasLoginResponse, AliasLoginResponse,
) )
from backend.services.auth_service import AuthService from backend.services.auth_service import AuthService
from backend.exceptions import BusinessLogicError from backend.exceptions import BaseAPIException, BusinessLogicError
from backend.limiter import limiter from backend.limiter import limiter
router = APIRouter() router = APIRouter()
EXPECTED_API_EXCEPTIONS = (BaseAPIException, HTTPException)
@router.post("/request_qrcode", response_model=dict, summary="请求 QQ 扫码二维码") @router.post("/request_qrcode", response_model=dict, summary="请求 QQ 扫码二维码")
@@ -68,6 +69,8 @@ async def request_qrcode(
) )
return result return result
except EXPECTED_API_EXCEPTIONS:
raise
except Exception as e: except Exception as e:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"创建扫码会话失败: {str(e)}" status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"创建扫码会话失败: {str(e)}"
@@ -95,6 +98,8 @@ async def get_qrcode_status(session_id: str, db: Session = Depends(get_db)):
try: try:
result = AuthService.get_qrcode_status(session_id, db) result = AuthService.get_qrcode_status(session_id, db)
return result return result
except EXPECTED_API_EXCEPTIONS:
raise
except Exception as e: except Exception as e:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"查询扫码状态失败: {str(e)}" status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"查询扫码状态失败: {str(e)}"
@@ -113,6 +118,8 @@ async def cancel_qrcode_session(session_id: str):
try: try:
result = AuthService.cancel_qrcode_session(session_id) result = AuthService.cancel_qrcode_session(session_id)
return result return result
except EXPECTED_API_EXCEPTIONS:
raise
except Exception as e: except Exception as e:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"取消会话失败: {str(e)}" status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"取消会话失败: {str(e)}"
@@ -136,6 +143,8 @@ async def verify_token(request: TokenVerifyRequest, db: Session = Depends(get_db
try: try:
result = AuthService.verify_token(request.authorization, db) result = AuthService.verify_token(request.authorization, db)
return result return result
except EXPECTED_API_EXCEPTIONS:
raise
except Exception as e: except Exception as e:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"验证 Token 失败: {str(e)}" status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"验证 Token 失败: {str(e)}"
@@ -170,6 +179,8 @@ async def alias_login(
try: try:
result = AuthService.alias_login(login_data.alias, login_data.password, db) result = AuthService.alias_login(login_data.alias, login_data.password, db)
return result return result
except EXPECTED_API_EXCEPTIONS:
raise
except Exception as e: except Exception as e:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"别名登录失败: {str(e)}" status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"别名登录失败: {str(e)}"
+12
View File
@@ -12,8 +12,10 @@ from backend.schemas.check_in import (
from backend.services.check_in_service import CheckInService from backend.services.check_in_service import CheckInService
from backend.services.task_service import TaskService from backend.services.task_service import TaskService
from backend.dependencies import get_current_user, get_current_admin_user from backend.dependencies import get_current_user, get_current_admin_user
from backend.exceptions import BaseAPIException
router = APIRouter() router = APIRouter()
EXPECTED_API_EXCEPTIONS = (BaseAPIException, HTTPException)
@router.post("/manual/{task_id}", summary="手动触发打卡(异步)") @router.post("/manual/{task_id}", summary="手动触发打卡(异步)")
@@ -38,6 +40,8 @@ async def manual_check_in(
try: try:
result = CheckInService.start_async_check_in(task, "manual", db) result = CheckInService.start_async_check_in(task, "manual", db)
return result return result
except EXPECTED_API_EXCEPTIONS:
raise
except Exception as e: except Exception as e:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"启动打卡任务失败: {str(e)}" status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"启动打卡任务失败: {str(e)}"
@@ -111,6 +115,8 @@ async def get_task_check_in_records(
task_id, db, skip, limit, status_filter, trigger_type task_id, db, skip, limit, status_filter, trigger_type
) )
return PaginatedResponse(records=records, total=total, skip=skip, limit=limit) return PaginatedResponse(records=records, total=total, skip=skip, limit=limit)
except EXPECTED_API_EXCEPTIONS:
raise
except Exception as e: except Exception as e:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"获取打卡记录失败: {str(e)}" status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"获取打卡记录失败: {str(e)}"
@@ -145,6 +151,8 @@ async def get_my_check_in_records(
current_user.id, db, skip, limit, status_filter, trigger_type current_user.id, db, skip, limit, status_filter, trigger_type
) )
return PaginatedResponse(records=records, total=total, skip=skip, limit=limit) return PaginatedResponse(records=records, total=total, skip=skip, limit=limit)
except EXPECTED_API_EXCEPTIONS:
raise
except Exception as e: except Exception as e:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"获取打卡记录失败: {str(e)}" status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"获取打卡记录失败: {str(e)}"
@@ -181,6 +189,8 @@ async def get_all_check_in_records(
CheckInService.enrich_record_with_user_task_info(record, db) for record in records CheckInService.enrich_record_with_user_task_info(record, db) for record in records
] ]
return PaginatedResponse(records=enriched_records, total=total, skip=skip, limit=limit) return PaginatedResponse(records=enriched_records, total=total, skip=skip, limit=limit)
except EXPECTED_API_EXCEPTIONS:
raise
except Exception as e: except Exception as e:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"获取打卡记录失败: {str(e)}" status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"获取打卡记录失败: {str(e)}"
@@ -213,6 +223,8 @@ async def get_check_in_records_count(
total = query.count() total = query.count()
return {"total": total} return {"total": total}
except EXPECTED_API_EXCEPTIONS:
raise
except Exception as e: except Exception as e:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"获取统计失败: {str(e)}" status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"获取统计失败: {str(e)}"
+5
View File
@@ -4,6 +4,7 @@ from typing import List
from datetime import datetime, timedelta from datetime import datetime, timedelta
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from backend.exceptions import BaseAPIException
from backend.models import get_db, User from backend.models import get_db, User
from backend.schemas.task import TaskUpdate, TaskResponse from backend.schemas.task import TaskUpdate, TaskResponse
from backend.services.task_service import TaskService from backend.services.task_service import TaskService
@@ -37,6 +38,8 @@ async def get_tasks(
# 为每个任务添加额外信息 # 为每个任务添加额外信息
enriched_tasks = [TaskService.enrich_task_with_check_in_info(task, db) for task in tasks] enriched_tasks = [TaskService.enrich_task_with_check_in_info(task, db) for task in tasks]
return enriched_tasks return enriched_tasks
except (BaseAPIException, HTTPException):
raise
except Exception as e: except Exception as e:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"获取任务列表失败: {str(e)}" status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"获取任务列表失败: {str(e)}"
@@ -173,6 +176,8 @@ async def validate_cron_expression(request: CronValidateRequest):
"next_times": next_times, "next_times": next_times,
"description": generate_cron_description(cron_expr), "description": generate_cron_description(cron_expr),
} }
except HTTPException:
raise
except Exception as e: except Exception as e:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail=f"无效的 Crontab 表达式: {str(e)}" status_code=status.HTTP_400_BAD_REQUEST, detail=f"无效的 Crontab 表达式: {str(e)}"
+10
View File
@@ -4,6 +4,7 @@ from sqlalchemy.orm import Session
from backend.models import User from backend.models import User
from backend.dependencies import get_db, get_current_user, get_current_admin_user from backend.dependencies import get_db, get_current_user, get_current_admin_user
from backend.exceptions import BaseAPIException
from backend.schemas.template import ( from backend.schemas.template import (
TemplateCreate, TemplateCreate,
TemplateUpdate, TemplateUpdate,
@@ -17,6 +18,9 @@ from backend.services.template_service import TemplateService
router = APIRouter() router = APIRouter()
EXPECTED_API_EXCEPTIONS = (BaseAPIException, HTTPException)
@router.get("/", response_model=List[TemplateResponse], summary="获取所有模板列表") @router.get("/", response_model=List[TemplateResponse], summary="获取所有模板列表")
async def get_all_templates( async def get_all_templates(
skip: int = Query(0, ge=0, description="跳过记录数"), skip: int = Query(0, ge=0, description="跳过记录数"),
@@ -35,6 +39,8 @@ async def get_all_templates(
try: try:
templates = TemplateService.get_all_templates(db, skip, limit, is_active) templates = TemplateService.get_all_templates(db, skip, limit, is_active)
return templates return templates
except EXPECTED_API_EXCEPTIONS:
raise
except Exception as e: except Exception as e:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"获取模板列表失败: {str(e)}" status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"获取模板列表失败: {str(e)}"
@@ -57,6 +63,8 @@ async def get_active_templates(
try: try:
templates = TemplateService.get_all_templates(db, skip, limit, is_active=True) templates = TemplateService.get_all_templates(db, skip, limit, is_active=True)
return templates return templates
except EXPECTED_API_EXCEPTIONS:
raise
except Exception as e: except Exception as e:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"获取模板列表失败: {str(e)}" status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"获取模板列表失败: {str(e)}"
@@ -115,6 +123,8 @@ async def preview_template(
"preview_payload": preview_payload, "preview_payload": preview_payload,
"field_config": merged_config, "field_config": merged_config,
} }
except EXPECTED_API_EXCEPTIONS:
raise
except Exception as e: except Exception as e:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"生成预览失败: {str(e)}" status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"生成预览失败: {str(e)}"
+19 -1
View File
@@ -14,9 +14,15 @@ from backend.schemas.task import TaskResponse
from backend.services.user_service import UserService from backend.services.user_service import UserService
from backend.services.task_service import TaskService from backend.services.task_service import TaskService
from backend.dependencies import get_current_user, get_current_admin_user from backend.dependencies import get_current_user, get_current_admin_user
from backend.exceptions import ValidationError, AuthorizationError, ResourceNotFoundError from backend.exceptions import (
AuthorizationError,
BaseAPIException,
ResourceNotFoundError,
ValidationError,
)
router = APIRouter() router = APIRouter()
EXPECTED_API_EXCEPTIONS = (BaseAPIException, HTTPException)
@router.post( @router.post(
@@ -42,6 +48,8 @@ async def create_user(
return user return user
except ValueError as e: except ValueError as e:
raise ValidationError(str(e)) raise ValidationError(str(e))
except EXPECTED_API_EXCEPTIONS:
raise
except Exception as e: except Exception as e:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"创建用户失败: {str(e)}" status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"创建用户失败: {str(e)}"
@@ -103,6 +111,8 @@ async def update_current_user_profile(
return user return user
except ValueError as e: except ValueError as e:
raise ValidationError(str(e)) raise ValidationError(str(e))
except EXPECTED_API_EXCEPTIONS:
raise
except Exception as e: except Exception as e:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"更新个人信息失败: {str(e)}" status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"更新个人信息失败: {str(e)}"
@@ -144,6 +154,8 @@ async def get_current_user_tasks(
try: try:
tasks = TaskService.get_user_tasks(current_user.id, db, include_inactive) tasks = TaskService.get_user_tasks(current_user.id, db, include_inactive)
return tasks return tasks
except EXPECTED_API_EXCEPTIONS:
raise
except Exception as e: except Exception as e:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"获取任务列表失败: {str(e)}" status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"获取任务列表失败: {str(e)}"
@@ -170,6 +182,8 @@ async def get_all_users(
try: try:
users = UserService.get_all_users(db, skip, limit, search, role) users = UserService.get_all_users(db, skip, limit, search, role)
return users return users
except EXPECTED_API_EXCEPTIONS:
raise
except Exception as e: except Exception as e:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"获取用户列表失败: {str(e)}" status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"获取用户列表失败: {str(e)}"
@@ -252,6 +266,8 @@ async def update_user(
return user return user
except ValueError as e: except ValueError as e:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
except EXPECTED_API_EXCEPTIONS:
raise
except Exception as e: except Exception as e:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"更新用户失败: {str(e)}" status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"更新用户失败: {str(e)}"
@@ -272,6 +288,8 @@ async def delete_user(
return None return None
except ValueError as e: except ValueError as e:
raise ResourceNotFoundError(str(e)) raise ResourceNotFoundError(str(e))
except EXPECTED_API_EXCEPTIONS:
raise
except Exception as e: except Exception as e:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"删除用户失败: {str(e)}" status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"删除用户失败: {str(e)}"
+8 -1
View File
@@ -1,7 +1,7 @@
from datetime import datetime from datetime import datetime
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from sqlalchemy import Boolean, DateTime, ForeignKey, Index, String, Text from sqlalchemy import Boolean, DateTime, ForeignKey, Index, String, Text, UniqueConstraint
from sqlalchemy.orm import Mapped, mapped_column, relationship from sqlalchemy.orm import Mapped, mapped_column, relationship
from sqlalchemy.sql import func from sqlalchemy.sql import func
@@ -23,6 +23,12 @@ class CheckInTask(Base):
index=True, index=True,
comment="用户 ID", comment="用户 ID",
) )
thread_id: Mapped[str | None] = mapped_column(
String(100),
index=True,
nullable=True,
comment="接龙项目 ID",
)
payload_config: Mapped[str] = mapped_column( payload_config: Mapped[str] = mapped_column(
Text, Text,
default="{}", default="{}",
@@ -57,6 +63,7 @@ class CheckInTask(Base):
# 添加索引:加速查询 # 添加索引:加速查询
__table_args__ = ( __table_args__ = (
Index("ix_task_user_active", "user_id", "is_active"), Index("ix_task_user_active", "user_id", "is_active"),
UniqueConstraint("user_id", "thread_id", name="uq_task_user_thread_id"),
Index("ix_task_cron", "cron_expression"), # 加速查询启用了定时打卡的任务 Index("ix_task_cron", "cron_expression"), # 加速查询启用了定时打卡的任务
) )
+12 -10
View File
@@ -1,6 +1,6 @@
from datetime import datetime, timezone from datetime import datetime, timezone
from sqlalchemy import create_engine, event from sqlalchemy import DateTime, create_engine, event, inspect
from sqlalchemy.orm import DeclarativeBase, sessionmaker from sqlalchemy.orm import DeclarativeBase, sessionmaker
from backend.config import settings from backend.config import settings
@@ -24,21 +24,23 @@ class Base(DeclarativeBase):
@event.listens_for(Base, "load", propagate=True) @event.listens_for(Base, "load", propagate=True)
def receive_load(target, context): def receive_load(target, context):
"""在从数据库加载对象后,将所有 datetime 字段转换为 timezone-aware (UTC)""" """在从数据库加载对象后,将所有 datetime 字段转换为 timezone-aware (UTC)"""
for attr_name in dir(target): mapper = inspect(target).mapper
# 跳过私有属性和方法 for attr in mapper.column_attrs:
if attr_name.startswith("_"): column = attr.columns[0]
if not isinstance(column.type, DateTime):
continue continue
try: try:
attr_value = getattr(target, attr_name) attr_value = getattr(target, attr.key)
# 如果是 naive datetime,添加 UTC timezone
if isinstance(attr_value, datetime) and attr_value.tzinfo is None:
setattr(target, attr_name, attr_value.replace(tzinfo=timezone.utc))
except (AttributeError, TypeError): except (AttributeError, TypeError):
# 某些属性可能无法访问或设置,跳过
continue continue
if isinstance(attr_value, datetime) and attr_value.tzinfo is None:
try:
setattr(target, attr.key, attr_value.replace(tzinfo=timezone.utc))
except (AttributeError, TypeError):
continue
def get_db(): def get_db():
"""依赖注入:获取数据库会话""" """依赖注入:获取数据库会话"""
@@ -0,0 +1,113 @@
"""
数据库迁移脚本:添加打卡任务 thread_id 字段并回填。
运行方式:
uv run python -m backend.scripts.migrate_add_task_thread_id
"""
import json
import logging
import sys
from pathlib import Path
APPS_DIR = Path(__file__).resolve().parents[2]
sys.path.insert(0, str(APPS_DIR))
from sqlalchemy import text
from backend.models.database import engine
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def _extract_thread_id(payload_config: str | None) -> str | None:
if not payload_config:
return None
try:
payload = json.loads(payload_config)
except json.JSONDecodeError:
return None
if not isinstance(payload, dict):
return None
thread_id = payload.get("ThreadId")
value = str(thread_id).strip() if thread_id is not None else ""
return value or None
def migrate() -> None:
"""执行迁移。"""
logger.info("开始迁移:添加 check_in_tasks.thread_id 字段...")
with engine.connect() as conn:
result = conn.execute(text("PRAGMA table_info(check_in_tasks)"))
columns = [row[1] for row in result]
if "thread_id" not in columns:
logger.info("添加 thread_id 字段...")
conn.execute(text("ALTER TABLE check_in_tasks ADD COLUMN thread_id VARCHAR(100)"))
conn.commit()
logger.info("✓ thread_id 字段添加成功")
else:
logger.info("✓ thread_id 字段已存在,跳过")
rows = conn.execute(text("SELECT id, payload_config FROM check_in_tasks")).fetchall()
invalid_ids: list[int] = []
seen: dict[tuple[int, str], int] = {}
duplicate_ids: list[int] = []
full_rows = conn.execute(
text("SELECT id, user_id, payload_config FROM check_in_tasks")
).fetchall()
for row in full_rows:
thread_id = _extract_thread_id(row.payload_config)
if not thread_id:
invalid_ids.append(row.id)
continue
key = (row.user_id, thread_id)
if key in seen:
duplicate_ids.append(row.id)
else:
seen[key] = row.id
if invalid_ids or duplicate_ids:
messages = []
if invalid_ids:
messages.append(f"payload_config 缺少有效 ThreadId 的任务: {invalid_ids}")
if duplicate_ids:
messages.append(f"同用户 ThreadId 重复的任务: {duplicate_ids}")
raise RuntimeError("".join(messages))
for row in rows:
thread_id = _extract_thread_id(row.payload_config)
if thread_id:
conn.execute(
text("UPDATE check_in_tasks SET thread_id = :thread_id WHERE id = :id"),
{"thread_id": thread_id, "id": row.id},
)
conn.commit()
indexes = conn.execute(text("PRAGMA index_list(check_in_tasks)")).fetchall()
index_names = [row[1] for row in indexes]
if "ix_task_user_thread_id_unique" not in index_names:
logger.info("添加用户级 thread_id 唯一索引...")
conn.execute(
text(
"CREATE UNIQUE INDEX ix_task_user_thread_id_unique "
"ON check_in_tasks (user_id, thread_id)"
)
)
conn.commit()
logger.info("✓ 用户级 thread_id 唯一索引添加成功")
else:
logger.info("✓ 用户级 thread_id 唯一索引已存在,跳过")
logger.info("✅ 迁移完成!任务 thread_id 身份字段已启用")
if __name__ == "__main__":
try:
migrate()
except Exception as e:
logger.error(f"❌ 迁移失败: {e}")
sys.exit(1)
+3 -3
View File
@@ -541,10 +541,10 @@ class CheckInService:
) )
task_name = task.name task_name = task.name
# payload_config 提取 ThreadId # 优先从显式任务身份读取 ThreadId,兼容旧数据时回退到 payload_config
from backend.utils.json_helpers import extract_thread_id from backend.utils.json_helpers import resolve_task_thread_id
thread_id = extract_thread_id(task.payload_config) # type: ignore thread_id = resolve_task_thread_id(task)
# 转换为字典并添加额外字段 # 转换为字典并添加额外字段
record_dict = { record_dict = {
+67 -26
View File
@@ -20,6 +20,70 @@ scheduler = None
scheduler_lock = None scheduler_lock = None
def _task_job_id(task_id: int) -> str:
return f"task_{task_id}"
def _remove_task_job(job_id: str, scheduler_instance=None) -> bool:
active_scheduler = scheduler_instance or scheduler
if not active_scheduler:
return False
existing_job = active_scheduler.get_job(job_id)
if not existing_job:
return False
active_scheduler.remove_job(job_id)
logger.info(f"✅ 已移除调度任务: {job_id}")
return True
def remove_scheduled_task(task_id: int, scheduler_instance=None) -> bool:
"""从调度器移除指定任务。"""
active_scheduler = scheduler_instance or scheduler
if not active_scheduler:
logger.warning(f"调度器未启动,无法移除任务 {task_id}")
return False
return _remove_task_job(_task_job_id(task_id), active_scheduler)
def sync_scheduled_task(task: CheckInTask, scheduler_instance=None) -> bool:
"""
根据任务当前状态同步调度器中的对应 job。
返回 True 表示任务已成功安排到调度器,False 表示未安排或调度器不可用。
"""
active_scheduler = scheduler_instance or scheduler
if not active_scheduler:
logger.warning(f"调度器未启动,无法同步任务 {task.id}")
return False
job_id = _task_job_id(task.id)
_remove_task_job(job_id, active_scheduler)
if not task.is_scheduled_enabled:
logger.info(f"任务 {task.id} 未启用或无 cron 表达式,已从调度器移除")
return False
cron_str = str(task.cron_expression or "").strip()
if not cron_str or not croniter.is_valid(cron_str):
logger.warning(f"任务 {task.id} 的 cron 表达式无效: {cron_str}")
return False
active_scheduler.add_job(
func=scheduled_check_in_task,
trigger=CronTrigger.from_crontab(cron_str),
id=job_id,
name=f"CheckIn-Task-{task.id}",
args=[task.id],
replace_existing=True,
)
logger.info(f"✅ 任务 {task.id} 已同步到调度器: {cron_str}")
return True
def load_scheduled_tasks(db: Session, scheduler_instance): def load_scheduled_tasks(db: Session, scheduler_instance):
""" """
从数据库加载所有启用的定时任务并添加到 APScheduler 从数据库加载所有启用的定时任务并添加到 APScheduler
@@ -55,33 +119,10 @@ def load_scheduled_tasks(db: Session, scheduler_instance):
for task in tasks: for task in tasks:
try: try:
# 验证 cron 表达式 if sync_scheduled_task(task, scheduler_instance):
cron_str = str(task.cron_expression) if task.cron_expression else None loaded_count += 1
if not cron_str or not croniter.is_valid(cron_str): else:
logger.warning(f"跳过任务 {task.id}: 无效的 cron 表达式 '{task.cron_expression}'")
skipped_count += 1 skipped_count += 1
continue
# 创建任务 ID
job_id = f"task_{task.id}"
# 检查任务是否已存在
if scheduler_instance.get_job(job_id):
logger.debug(f"任务 {task.id} 已存在,跳过")
continue
# 添加任务到调度器
scheduler_instance.add_job(
func=scheduled_check_in_task,
trigger=CronTrigger.from_crontab(cron_str),
id=job_id,
name=f"CheckIn-Task-{task.id}",
args=[task.id],
replace_existing=True,
)
logger.info(f"✅ 加载任务 {task.id}: {task.name} (Cron: {task.cron_expression})")
loaded_count += 1
except Exception as e: except Exception as e:
logger.error(f"❌ 加载任务 {task.id} 时出错: {str(e)}") logger.error(f"❌ 加载任务 {task.id} 时出错: {str(e)}")
+124 -86
View File
@@ -3,6 +3,12 @@ from typing import List, Optional, Dict, Any
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from sqlalchemy import desc from sqlalchemy import desc
from backend.exceptions import (
InternalServerError,
ResourceConflictError,
ResourceNotFoundError,
ValidationError as APIValidationError,
)
from backend.models import User, CheckInTask, CheckInRecord from backend.models import User, CheckInTask, CheckInRecord
from backend.schemas.task import TaskCreate, TaskUpdate from backend.schemas.task import TaskCreate, TaskUpdate
@@ -12,6 +18,68 @@ logger = logging.getLogger(__name__)
class TaskService: class TaskService:
"""打卡任务服务""" """打卡任务服务"""
@staticmethod
def _normalize_thread_id(thread_id: Any) -> str:
value = str(thread_id).strip() if thread_id is not None else ""
if not value:
raise APIValidationError(
"payload_config 必须包含有效的 ThreadId 字段",
error_code="TASK_IDENTITY_INVALID",
)
return value
@staticmethod
def _extract_thread_id_from_payload(payload_config: str) -> str:
from backend.utils.json_helpers import safe_parse_payload
payload = safe_parse_payload(payload_config)
return TaskService._normalize_thread_id(payload.get("ThreadId"))
@staticmethod
def _resolve_thread_id_for_task(task: CheckInTask) -> Optional[str]:
thread_id = getattr(task, "thread_id", None)
if thread_id is not None:
value = str(thread_id).strip()
if value:
return value
from backend.utils.json_helpers import extract_thread_id
legacy_thread_id = extract_thread_id(task.payload_config)
if legacy_thread_id is None:
return None
value = str(legacy_thread_id).strip()
return value or None
@staticmethod
def _ensure_unique_thread_id(
db: Session, user_id: int, thread_id: str, exclude_task_id: int | None = None
) -> None:
query = db.query(CheckInTask.id).filter(
CheckInTask.user_id == user_id, CheckInTask.thread_id == thread_id
)
if exclude_task_id is not None:
query = query.filter(CheckInTask.id != exclude_task_id)
conflict = query.first()
if conflict:
raise ResourceConflictError(
f"该接龙中已存在任务。ThreadId: {thread_id}",
error_code="TASK_IDENTITY_CONFLICT",
)
@staticmethod
def _sync_scheduler_for_task(task: CheckInTask) -> None:
from backend.services.scheduler_service import sync_scheduled_task
sync_scheduled_task(task)
@staticmethod
def _remove_scheduler_for_task(task_id: int) -> None:
from backend.services.scheduler_service import remove_scheduled_task
remove_scheduled_task(task_id)
@staticmethod @staticmethod
def create_task(user_id: int, task_data: TaskCreate, db: Session) -> CheckInTask: def create_task(user_id: int, task_data: TaskCreate, db: Session) -> CheckInTask:
""" """
@@ -25,34 +93,16 @@ class TaskService:
Returns: Returns:
创建的任务对象 创建的任务对象
""" """
import json
# 1. 检查用户是否存在 # 1. 检查用户是否存在
user = db.query(User).filter(User.id == user_id).first() user = db.query(User).filter(User.id == user_id).first()
if not user: if not user:
raise ValueError(f"用户 ID {user_id} 不存在") raise ResourceNotFoundError(f"用户 ID {user_id} 不存在", error_code="USER_NOT_FOUND")
# 2. 从 payload_config 中提取 ThreadId 用于唯一性校验 # 2. 从 payload_config 中提取 ThreadId 用于唯一性校验
from backend.utils.json_helpers import safe_parse_payload, extract_thread_id thread_id = TaskService._extract_thread_id_from_payload(task_data.payload_config)
payload = safe_parse_payload(task_data.payload_config)
thread_id = payload.get("ThreadId")
if not thread_id:
raise ValueError("payload_config 中缺少 ThreadId")
# 3. 验证唯一性:同一用户在同一个接龙中不能有重复的任务 # 3. 验证唯一性:同一用户在同一个接龙中不能有重复的任务
existing_tasks = ( TaskService._ensure_unique_thread_id(db, user_id, thread_id)
db.query(CheckInTask.payload_config).filter(CheckInTask.user_id == user_id).all()
)
for (payload_config,) in existing_tasks:
existing_thread_id = extract_thread_id(payload_config)
# extract_thread_id 已处理异常,失败时返回 None
if existing_thread_id and existing_thread_id == thread_id:
logger.warning(
f"⚠️ 任务创建冲突 - User: {user.alias}({user_id}), ThreadId: {thread_id}"
)
raise ValueError(f"该接龙中已存在任务。ThreadId: {thread_id}")
# 4. 记录日志 # 4. 记录日志
task_name = task_data.name or f"接龙任务 {thread_id}" task_name = task_data.name or f"接龙任务 {thread_id}"
@@ -61,9 +111,11 @@ class TaskService:
# 5. 创建任务 # 5. 创建任务
task = CheckInTask( task = CheckInTask(
user_id=user_id, user_id=user_id,
thread_id=thread_id,
payload_config=task_data.payload_config, payload_config=task_data.payload_config,
name=task_data.name or task_name, name=task_data.name or task_name,
is_active=task_data.is_active if task_data.is_active is not None else True, is_active=task_data.is_active if task_data.is_active is not None else True,
cron_expression=task_data.cron_expression or "0 20 * * *",
) )
try: try:
@@ -75,14 +127,19 @@ class TaskService:
) )
# 如果任务启用且包含 cron_expression,立即添加到调度器 # 如果任务启用且包含 cron_expression,立即添加到调度器
if task.is_scheduled_enabled: TaskService._sync_scheduler_for_task(task)
TaskService._reload_scheduler_for_task(task, db)
return task return task
except APIValidationError:
db.rollback()
raise
except ResourceConflictError:
db.rollback()
raise
except Exception as e: except Exception as e:
db.rollback() db.rollback()
logger.error(f"❌ 任务创建失败: {str(e)}") logger.error(f"❌ 任务创建失败: {str(e)}")
raise ValueError(f"任务创建失败: {str(e)}") raise InternalServerError(f"任务创建失败: {str(e)}")
@staticmethod @staticmethod
def get_task(task_id: int, db: Session) -> Optional[CheckInTask]: def get_task(task_id: int, db: Session) -> Optional[CheckInTask]:
@@ -110,8 +167,6 @@ class TaskService:
Returns: Returns:
包含额外信息的任务字典 包含额外信息的任务字典
""" """
from backend.utils.json_helpers import extract_thread_id
# 获取最后一次打卡记录 # 获取最后一次打卡记录
last_record = ( last_record = (
db.query(CheckInRecord) db.query(CheckInRecord)
@@ -120,8 +175,8 @@ class TaskService:
.first() .first()
) )
# payload_config 提取 ThreadId # 优先使用持久化的 ThreadId,兼容旧数据时回退到 payload_config
thread_id = extract_thread_id(task.payload_config) # type: ignore thread_id = TaskService._resolve_thread_id_for_task(task)
# 转换为字典并添加额外字段 # 转换为字典并添加额外字段
task_dict = { task_dict = {
@@ -192,7 +247,7 @@ class TaskService:
task = db.query(CheckInTask).filter(CheckInTask.id == task_id).first() task = db.query(CheckInTask).filter(CheckInTask.id == task_id).first()
if not task: if not task:
return None raise ResourceNotFoundError("任务不存在", error_code="TASK_NOT_FOUND")
# 更新字段 # 更新字段
update_data = task_data.model_dump(exclude_unset=True) update_data = task_data.model_dump(exclude_unset=True)
@@ -201,17 +256,34 @@ class TaskService:
cron_changed = "cron_expression" in update_data cron_changed = "cron_expression" in update_data
active_changed = "is_active" in update_data active_changed = "is_active" in update_data
if "payload_config" in update_data:
new_thread_id = TaskService._extract_thread_id_from_payload(
update_data["payload_config"]
)
TaskService._ensure_unique_thread_id(
db, task.user_id, new_thread_id, exclude_task_id=task.id
)
task.thread_id = new_thread_id
for field, value in update_data.items(): for field, value in update_data.items():
setattr(task, field, value) setattr(task, field, value)
db.commit() try:
db.refresh(task) db.commit()
db.refresh(task)
except ResourceConflictError:
db.rollback()
raise
except Exception as e:
db.rollback()
logger.error(f"任务 {task_id} 更新失败: {str(e)}")
raise InternalServerError(f"任务更新失败: {str(e)}")
logger.info(f"任务 {task_id} 已更新") logger.info(f"任务 {task_id} 已更新")
# 如果 cron_expression 或 is_active 发生变化,重新加载调度器 # 如果 cron_expression 或 is_active 发生变化,重新加载调度器
if cron_changed or active_changed: if cron_changed or active_changed:
TaskService._reload_scheduler_for_task(task, db) TaskService._sync_scheduler_for_task(task)
return task return task
@@ -230,15 +302,20 @@ class TaskService:
task = db.query(CheckInTask).filter(CheckInTask.id == task_id).first() task = db.query(CheckInTask).filter(CheckInTask.id == task_id).first()
if not task: if not task:
return False raise ResourceNotFoundError("任务不存在", error_code="TASK_NOT_FOUND")
db.delete(task) try:
db.commit() db.delete(task)
db.commit()
except Exception as e:
db.rollback()
logger.error(f"任务 {task_id} 删除失败: {str(e)}")
raise InternalServerError(f"任务删除失败: {str(e)}")
logger.info(f"任务 {task_id} 已删除") logger.info(f"任务 {task_id} 已删除")
# 从调度器中移除该任务 # 从调度器中移除该任务
TaskService._remove_task_from_scheduler(task_id) TaskService._remove_scheduler_for_task(task_id)
return True return True
@@ -257,16 +334,21 @@ class TaskService:
task = db.query(CheckInTask).filter(CheckInTask.id == task_id).first() task = db.query(CheckInTask).filter(CheckInTask.id == task_id).first()
if not task: if not task:
return None raise ResourceNotFoundError("任务不存在", error_code="TASK_NOT_FOUND")
task.is_active = not task.is_active task.is_active = not task.is_active
db.commit() try:
db.refresh(task) db.commit()
db.refresh(task)
except Exception as e:
db.rollback()
logger.error(f"任务 {task_id} 状态切换失败: {str(e)}")
raise InternalServerError(f"任务状态切换失败: {str(e)}")
logger.info(f"任务 {task_id} 状态已切换为: {'启用' if task.is_active else '禁用'}") logger.info(f"任务 {task_id} 状态已切换为: {'启用' if task.is_active else '禁用'}")
# 重新加载调度器 # 重新加载调度器
TaskService._reload_scheduler_for_task(task, db) TaskService._sync_scheduler_for_task(task)
return task return task
@@ -322,42 +404,7 @@ class TaskService:
db: 数据库会话 db: 数据库会话
""" """
try: try:
from backend.services.scheduler_service import scheduler TaskService._sync_scheduler_for_task(task)
from apscheduler.triggers.cron import CronTrigger
from croniter import croniter
if not scheduler:
logger.warning(f"调度器未启动,无法加载任务 {task.id}")
return
job_id = f"task_{task.id}"
# 先移除旧的任务(如果存在)
existing_job = scheduler.get_job(job_id)
if existing_job:
scheduler.remove_job(job_id)
logger.info(f"从调度器移除旧任务: {job_id}")
# 如果任务启用且有有效的 cron 表达式,添加新任务
if task.is_scheduled_enabled:
cron_str = str(task.cron_expression)
if croniter.is_valid(cron_str):
from backend.services.scheduler_service import scheduled_check_in_task
scheduler.add_job(
func=scheduled_check_in_task,
trigger=CronTrigger.from_crontab(cron_str),
id=job_id,
name=f"CheckIn-Task-{task.id}",
args=[task.id],
replace_existing=True,
)
logger.info(f"✅ 任务 {task.id} 已重新加载到调度器: {cron_str}")
else:
logger.warning(f"任务 {task.id} 的 cron 表达式无效: {cron_str}")
else:
logger.info(f"任务 {task.id} 未启用或无 cron 表达式,已从调度器移除")
except Exception as e: except Exception as e:
logger.error(f"重新加载任务 {task.id} 到调度器失败: {str(e)}") logger.error(f"重新加载任务 {task.id} 到调度器失败: {str(e)}")
@@ -370,15 +417,6 @@ class TaskService:
task_id: 任务 ID task_id: 任务 ID
""" """
try: try:
from backend.services.scheduler_service import scheduler TaskService._remove_scheduler_for_task(task_id)
if not scheduler:
return
job_id = f"task_{task_id}"
if scheduler.get_job(job_id):
scheduler.remove_job(job_id)
logger.info(f"✅ 任务 {task_id} 已从调度器移除")
except Exception as e: except Exception as e:
logger.error(f"从调度器移除任务 {task_id} 失败: {str(e)}") logger.error(f"从调度器移除任务 {task_id} 失败: {str(e)}")
+32 -16
View File
@@ -4,7 +4,9 @@ from typing import List, Dict, Any, Optional
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from fastapi import HTTPException, status from fastapi import HTTPException, status
from backend.models import TaskTemplate, CheckInTask from backend.exceptions import BaseAPIException
from backend.models import CheckInTask, TaskTemplate
from backend.schemas.task import TaskCreate
from backend.schemas.template import TemplateCreate, TemplateUpdate from backend.schemas.template import TemplateCreate, TemplateUpdate
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -134,6 +136,9 @@ class TemplateService:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail=f"字段配置 JSON 格式错误: {str(e)}" status_code=status.HTTP_400_BAD_REQUEST, detail=f"字段配置 JSON 格式错误: {str(e)}"
) )
except (BaseAPIException, HTTPException):
db.rollback()
raise
except Exception as e: except Exception as e:
logger.error(f"创建模板失败: {str(e)}") logger.error(f"创建模板失败: {str(e)}")
db.rollback() db.rollback()
@@ -219,6 +224,9 @@ class TemplateService:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail=f"字段配置 JSON 格式错误: {str(e)}" status_code=status.HTTP_400_BAD_REQUEST, detail=f"字段配置 JSON 格式错误: {str(e)}"
) )
except (BaseAPIException, HTTPException):
db.rollback()
raise
except Exception as e: except Exception as e:
logger.error(f"更新模板失败: {str(e)}") logger.error(f"更新模板失败: {str(e)}")
db.rollback() db.rollback()
@@ -247,6 +255,9 @@ class TemplateService:
db.commit() db.commit()
logger.info(f"删除模板成功: {template.name} (ID: {template_id})") logger.info(f"删除模板成功: {template.name} (ID: {template_id})")
return True return True
except (BaseAPIException, HTTPException):
db.rollback()
raise
except Exception as e: except Exception as e:
logger.error(f"删除模板失败: {str(e)}") logger.error(f"删除模板失败: {str(e)}")
db.rollback() db.rollback()
@@ -424,6 +435,10 @@ class TemplateService:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"解析模板配置失败" status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"解析模板配置失败"
) )
except BaseAPIException:
raise
except HTTPException:
raise
except Exception as e: except Exception as e:
logger.error(f"组装 payload 失败: {str(e)}") logger.error(f"组装 payload 失败: {str(e)}")
raise HTTPException( raise HTTPException(
@@ -518,31 +533,32 @@ class TemplateService:
signature = payload.get("Signature", "Unknown") signature = payload.get("Signature", "Unknown")
task_name = f"{template.name} - {signature}" task_name = f"{template.name} - {signature}"
# 创建任务(包含 cron_expression
try: try:
task = CheckInTask( from backend.services.task_service import TaskService
task = TaskService.create_task(
user_id=user_id, user_id=user_id,
payload_config=json.dumps(payload, ensure_ascii=False), task_data=TaskCreate(
name=task_name, payload_config=json.dumps(payload, ensure_ascii=False),
is_active=True, name=task_name,
cron_expression=cron_expression or "0 20 * * *", is_active=True,
cron_expression=cron_expression or "0 20 * * *",
),
db=db,
) )
db.add(task)
db.commit()
db.refresh(task)
logger.info( logger.info(
f"从模板创建任务成功: {task.name} (ID: {task.id}, 模板: {template.name}, ThreadId: {thread_id})" f"从模板创建任务成功: {task.name} (ID: {task.id}, 模板: {template.name}, ThreadId: {thread_id})"
) )
# 如果任务启用且包含 cron_expression,立即添加到调度器
if task.is_scheduled_enabled:
from backend.services.task_service import TaskService
TaskService._reload_scheduler_for_task(task, db)
return task return task
except BaseAPIException:
db.rollback()
raise
except HTTPException:
db.rollback()
raise
except Exception as e: except Exception as e:
logger.error(f"从模板创建任务失败: {str(e)}") logger.error(f"从模板创建任务失败: {str(e)}")
db.rollback() db.rollback()
+24 -1
View File
@@ -92,6 +92,29 @@ def build_task_info(task) -> Dict[str, str]:
包含 thread_id 和 name 的字典 包含 thread_id 和 name 的字典
""" """
return { return {
"thread_id": extract_thread_id(getattr(task, "payload_config", None)) or "未知", "thread_id": resolve_task_thread_id(task) or "未知",
"name": getattr(task, "name", None) or f"Task-{getattr(task, 'id', 'Unknown')}", "name": getattr(task, "name", None) or f"Task-{getattr(task, 'id', 'Unknown')}",
} }
def resolve_task_thread_id(task) -> Optional[str]:
"""
优先从显式字段读取任务 ThreadId,兼容旧数据时回退到 payload_config。
Args:
task: CheckInTask 或相似对象
Returns:
ThreadId 或 None
"""
thread_id = getattr(task, "thread_id", None)
if thread_id is not None:
value = str(thread_id).strip()
if value:
return value
legacy_thread_id = extract_thread_id(getattr(task, "payload_config", None))
if legacy_thread_id is None:
return None
value = str(legacy_thread_id).strip()
return value or None
+235
View File
@@ -0,0 +1,235 @@
from __future__ import annotations
import json
from datetime import datetime
from typing import Any
import pytest
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from backend.exceptions import ResourceConflictError
from backend.models import Base, CheckInTask, User
from backend.schemas.task import TaskCreate, TaskUpdate
from backend.services import scheduler_service
from backend.services.task_service import TaskService
@pytest.fixture()
def db_session():
engine = create_engine("sqlite:///:memory:", connect_args={"check_same_thread": False})
Base.metadata.create_all(bind=engine)
Session = sessionmaker(autocommit=False, autoflush=False, bind=engine)
session = Session()
try:
yield session
finally:
session.close()
engine.dispose()
def add_user(db_session, alias: str) -> User:
user = User(alias=alias)
db_session.add(user)
db_session.commit()
db_session.refresh(user)
return user
def payload(thread_id: str, **extra: Any) -> str:
return json.dumps({"ThreadId": thread_id, **extra}, ensure_ascii=False)
def test_create_task_persists_thread_id_identity(db_session) -> None:
user = add_user(db_session, "alice")
task = TaskService.create_task(
user.id,
TaskCreate(
payload_config=payload("thread-1", Signature="sig"),
name="Morning",
cron_expression="0 8 * * *",
),
db_session,
)
assert task.thread_id == "thread-1"
assert task.payload_config == payload("thread-1", Signature="sig")
assert task.cron_expression == "0 8 * * *"
def test_same_user_duplicate_thread_id_is_structured_conflict(db_session) -> None:
user = add_user(db_session, "alice")
first = TaskCreate(payload_config=payload("thread-1"), name="First")
duplicate = TaskCreate(payload_config=payload("thread-1"), name="Duplicate")
TaskService.create_task(user.id, first, db_session)
with pytest.raises(ResourceConflictError) as exc_info:
TaskService.create_task(user.id, duplicate, db_session)
assert exc_info.value.status_code == 409
assert exc_info.value.error_code == "TASK_IDENTITY_CONFLICT"
def test_different_users_can_share_thread_id(db_session) -> None:
alice = add_user(db_session, "alice")
bob = add_user(db_session, "bob")
alice_task = TaskService.create_task(
alice.id, TaskCreate(payload_config=payload("shared-thread")), db_session
)
bob_task = TaskService.create_task(
bob.id, TaskCreate(payload_config=payload("shared-thread")), db_session
)
assert alice_task.thread_id == "shared-thread"
assert bob_task.thread_id == "shared-thread"
assert alice_task.user_id != bob_task.user_id
def test_update_task_rejects_duplicate_thread_id(db_session) -> None:
user = add_user(db_session, "alice")
first = TaskService.create_task(
user.id, TaskCreate(payload_config=payload("thread-1")), db_session
)
second = TaskService.create_task(
user.id, TaskCreate(payload_config=payload("thread-2")), db_session
)
with pytest.raises(ResourceConflictError):
TaskService.update_task(
second.id,
TaskUpdate(payload_config=payload("thread-1")),
db_session,
)
db_session.refresh(first)
db_session.refresh(second)
assert first.thread_id == "thread-1"
assert second.thread_id == "thread-2"
def test_task_enrichment_uses_stored_thread_id(db_session) -> None:
user = add_user(db_session, "alice")
task = CheckInTask(
user_id=user.id,
thread_id="stored-thread",
payload_config=payload("payload-thread"),
name="Task",
)
db_session.add(task)
db_session.commit()
db_session.refresh(task)
enriched = TaskService.enrich_task_with_check_in_info(task, db_session)
assert enriched["thread_id"] == "stored-thread"
def test_template_created_duplicate_task_preserves_structured_conflict(db_session) -> None:
from backend.models import TaskTemplate
from backend.services.template_service import TemplateService
user = add_user(db_session, "alice")
template = TaskTemplate(name="Daily", field_config="{}", is_active=True)
db_session.add(template)
db_session.commit()
db_session.refresh(template)
TemplateService.create_task_from_template(
template_id=template.id,
thread_id="thread-1",
field_values={},
user_id=user.id,
task_name="First",
db=db_session,
)
with pytest.raises(ResourceConflictError) as exc_info:
TemplateService.create_task_from_template(
template_id=template.id,
thread_id="thread-1",
field_values={},
user_id=user.id,
task_name="Duplicate",
db=db_session,
)
assert exc_info.value.status_code == 409
assert exc_info.value.error_code == "TASK_IDENTITY_CONFLICT"
class FakeScheduler:
def __init__(self) -> None:
self.jobs: dict[str, object] = {}
self.added: list[str] = []
self.removed: list[str] = []
def get_job(self, job_id: str):
return self.jobs.get(job_id)
def remove_job(self, job_id: str) -> None:
self.jobs.pop(job_id, None)
self.removed.append(job_id)
def add_job(self, **kwargs) -> None:
job_id = kwargs["id"]
self.jobs[job_id] = kwargs
self.added.append(job_id)
def test_scheduler_sync_skips_invalid_cron_and_removes_existing_job(monkeypatch) -> None:
fake_scheduler = FakeScheduler()
fake_scheduler.jobs["task_12"] = object()
monkeypatch.setattr(scheduler_service, "scheduler", fake_scheduler)
task = CheckInTask(id=12, thread_id="thread-1", payload_config=payload("thread-1"))
task.is_active = True
task.cron_expression = "invalid cron"
scheduled = scheduler_service.sync_scheduled_task(task)
assert scheduled is False
assert fake_scheduler.added == []
assert fake_scheduler.removed == ["task_12"]
def test_scheduler_sync_is_noop_when_scheduler_unavailable(monkeypatch) -> None:
monkeypatch.setattr(scheduler_service, "scheduler", None)
task = CheckInTask(id=12, thread_id="thread-1", payload_config=payload("thread-1"))
task.is_active = True
task.cron_expression = "0 8 * * *"
assert scheduler_service.sync_scheduled_task(task) is False
def test_datetime_normalization_does_not_access_unmapped_properties(
db_session, monkeypatch
) -> None:
user = add_user(db_session, "alice")
user_id = user.id
db_session.expunge_all()
def explode(self):
raise RuntimeError("unmapped property accessed")
monkeypatch.setattr(User, "explosive_property", property(explode), raising=False)
loaded = db_session.query(User).filter(User.id == user_id).one()
assert loaded.created_at.tzinfo is not None
@pytest.mark.asyncio
async def test_task_route_preserves_structured_service_errors(monkeypatch, db_session) -> None:
from backend.api.tasks import get_tasks
user = User(id=1, alias="alice")
def raise_conflict(*args, **kwargs):
raise ResourceConflictError("duplicate", error_code="TASK_IDENTITY_CONFLICT")
monkeypatch.setattr(TaskService, "get_user_tasks", raise_conflict)
with pytest.raises(ResourceConflictError):
await get_tasks(current_user=user, db=db_session)