mirror of
https://github.com/Cccc-owo/CheckInApp.git
synced 2026-06-17 05:56:29 +00:00
feat(backend): harden task boundaries
This commit is contained in:
@@ -11,9 +11,11 @@ from backend.services.check_in_service import CheckInService
|
|||||||
from backend.services.admin_service import AdminService
|
from backend.services.admin_service import AdminService
|
||||||
from backend.dependencies import get_current_admin_user
|
from backend.dependencies import get_current_admin_user
|
||||||
from backend.config import settings
|
from backend.config import settings
|
||||||
|
from backend.exceptions import BaseAPIException
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
EXPECTED_API_EXCEPTIONS = (BaseAPIException, HTTPException)
|
||||||
|
|
||||||
|
|
||||||
class BatchToggleTasksRequest(BaseModel):
|
class BatchToggleTasksRequest(BaseModel):
|
||||||
@@ -43,13 +45,21 @@ async def batch_toggle_tasks(
|
|||||||
task.is_active = request.is_active
|
task.is_active = request.is_active
|
||||||
count += 1
|
count += 1
|
||||||
|
|
||||||
|
from backend.services.scheduler_service import sync_scheduled_task
|
||||||
|
|
||||||
db.commit()
|
db.commit()
|
||||||
|
for task_id in request.task_ids:
|
||||||
|
task = db.query(CheckInTask).filter(CheckInTask.id == task_id).first()
|
||||||
|
if task:
|
||||||
|
sync_scheduled_task(task)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"success": True,
|
"success": True,
|
||||||
"message": f"已{'启用' if request.is_active else '禁用'} {count} 个任务",
|
"message": f"已{'启用' if request.is_active else '禁用'} {count} 个任务",
|
||||||
"count": count,
|
"count": count,
|
||||||
}
|
}
|
||||||
|
except EXPECTED_API_EXCEPTIONS:
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"批量操作失败: {str(e)}"
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"批量操作失败: {str(e)}"
|
||||||
@@ -72,6 +82,8 @@ async def batch_check_in(
|
|||||||
try:
|
try:
|
||||||
result = CheckInService.batch_check_in_tasks(request.task_ids, db)
|
result = CheckInService.batch_check_in_tasks(request.task_ids, db)
|
||||||
return result
|
return result
|
||||||
|
except EXPECTED_API_EXCEPTIONS:
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"批量打卡失败: {str(e)}"
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"批量打卡失败: {str(e)}"
|
||||||
@@ -235,6 +247,8 @@ async def get_system_stats(
|
|||||||
"tokens": {"expiring_soon": expiring_users},
|
"tokens": {"expiring_soon": expiring_users},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
except EXPECTED_API_EXCEPTIONS:
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"获取统计失败: {str(e)}"
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"获取统计失败: {str(e)}"
|
||||||
@@ -251,6 +265,8 @@ async def get_pending_users(
|
|||||||
try:
|
try:
|
||||||
users = AdminService.get_pending_users(db)
|
users = AdminService.get_pending_users(db)
|
||||||
return users
|
return users
|
||||||
|
except EXPECTED_API_EXCEPTIONS:
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
@@ -274,7 +290,7 @@ async def approve_user(
|
|||||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=result["message"])
|
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=result["message"])
|
||||||
|
|
||||||
return result
|
return result
|
||||||
except HTTPException:
|
except EXPECTED_API_EXCEPTIONS:
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
@@ -298,7 +314,7 @@ async def reject_user(
|
|||||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=result["message"])
|
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=result["message"])
|
||||||
|
|
||||||
return result
|
return result
|
||||||
except HTTPException:
|
except EXPECTED_API_EXCEPTIONS:
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
|
|||||||
@@ -12,10 +12,11 @@ from backend.schemas.auth import (
|
|||||||
AliasLoginResponse,
|
AliasLoginResponse,
|
||||||
)
|
)
|
||||||
from backend.services.auth_service import AuthService
|
from backend.services.auth_service import AuthService
|
||||||
from backend.exceptions import BusinessLogicError
|
from backend.exceptions import BaseAPIException, BusinessLogicError
|
||||||
from backend.limiter import limiter
|
from backend.limiter import limiter
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
EXPECTED_API_EXCEPTIONS = (BaseAPIException, HTTPException)
|
||||||
|
|
||||||
|
|
||||||
@router.post("/request_qrcode", response_model=dict, summary="请求 QQ 扫码二维码")
|
@router.post("/request_qrcode", response_model=dict, summary="请求 QQ 扫码二维码")
|
||||||
@@ -68,6 +69,8 @@ async def request_qrcode(
|
|||||||
)
|
)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
except EXPECTED_API_EXCEPTIONS:
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"创建扫码会话失败: {str(e)}"
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"创建扫码会话失败: {str(e)}"
|
||||||
@@ -95,6 +98,8 @@ async def get_qrcode_status(session_id: str, db: Session = Depends(get_db)):
|
|||||||
try:
|
try:
|
||||||
result = AuthService.get_qrcode_status(session_id, db)
|
result = AuthService.get_qrcode_status(session_id, db)
|
||||||
return result
|
return result
|
||||||
|
except EXPECTED_API_EXCEPTIONS:
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"查询扫码状态失败: {str(e)}"
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"查询扫码状态失败: {str(e)}"
|
||||||
@@ -113,6 +118,8 @@ async def cancel_qrcode_session(session_id: str):
|
|||||||
try:
|
try:
|
||||||
result = AuthService.cancel_qrcode_session(session_id)
|
result = AuthService.cancel_qrcode_session(session_id)
|
||||||
return result
|
return result
|
||||||
|
except EXPECTED_API_EXCEPTIONS:
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"取消会话失败: {str(e)}"
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"取消会话失败: {str(e)}"
|
||||||
@@ -136,6 +143,8 @@ async def verify_token(request: TokenVerifyRequest, db: Session = Depends(get_db
|
|||||||
try:
|
try:
|
||||||
result = AuthService.verify_token(request.authorization, db)
|
result = AuthService.verify_token(request.authorization, db)
|
||||||
return result
|
return result
|
||||||
|
except EXPECTED_API_EXCEPTIONS:
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"验证 Token 失败: {str(e)}"
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"验证 Token 失败: {str(e)}"
|
||||||
@@ -170,6 +179,8 @@ async def alias_login(
|
|||||||
try:
|
try:
|
||||||
result = AuthService.alias_login(login_data.alias, login_data.password, db)
|
result = AuthService.alias_login(login_data.alias, login_data.password, db)
|
||||||
return result
|
return result
|
||||||
|
except EXPECTED_API_EXCEPTIONS:
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"别名登录失败: {str(e)}"
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"别名登录失败: {str(e)}"
|
||||||
|
|||||||
@@ -12,8 +12,10 @@ from backend.schemas.check_in import (
|
|||||||
from backend.services.check_in_service import CheckInService
|
from backend.services.check_in_service import CheckInService
|
||||||
from backend.services.task_service import TaskService
|
from backend.services.task_service import TaskService
|
||||||
from backend.dependencies import get_current_user, get_current_admin_user
|
from backend.dependencies import get_current_user, get_current_admin_user
|
||||||
|
from backend.exceptions import BaseAPIException
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
EXPECTED_API_EXCEPTIONS = (BaseAPIException, HTTPException)
|
||||||
|
|
||||||
|
|
||||||
@router.post("/manual/{task_id}", summary="手动触发打卡(异步)")
|
@router.post("/manual/{task_id}", summary="手动触发打卡(异步)")
|
||||||
@@ -38,6 +40,8 @@ async def manual_check_in(
|
|||||||
try:
|
try:
|
||||||
result = CheckInService.start_async_check_in(task, "manual", db)
|
result = CheckInService.start_async_check_in(task, "manual", db)
|
||||||
return result
|
return result
|
||||||
|
except EXPECTED_API_EXCEPTIONS:
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"启动打卡任务失败: {str(e)}"
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"启动打卡任务失败: {str(e)}"
|
||||||
@@ -111,6 +115,8 @@ async def get_task_check_in_records(
|
|||||||
task_id, db, skip, limit, status_filter, trigger_type
|
task_id, db, skip, limit, status_filter, trigger_type
|
||||||
)
|
)
|
||||||
return PaginatedResponse(records=records, total=total, skip=skip, limit=limit)
|
return PaginatedResponse(records=records, total=total, skip=skip, limit=limit)
|
||||||
|
except EXPECTED_API_EXCEPTIONS:
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"获取打卡记录失败: {str(e)}"
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"获取打卡记录失败: {str(e)}"
|
||||||
@@ -145,6 +151,8 @@ async def get_my_check_in_records(
|
|||||||
current_user.id, db, skip, limit, status_filter, trigger_type
|
current_user.id, db, skip, limit, status_filter, trigger_type
|
||||||
)
|
)
|
||||||
return PaginatedResponse(records=records, total=total, skip=skip, limit=limit)
|
return PaginatedResponse(records=records, total=total, skip=skip, limit=limit)
|
||||||
|
except EXPECTED_API_EXCEPTIONS:
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"获取打卡记录失败: {str(e)}"
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"获取打卡记录失败: {str(e)}"
|
||||||
@@ -181,6 +189,8 @@ async def get_all_check_in_records(
|
|||||||
CheckInService.enrich_record_with_user_task_info(record, db) for record in records
|
CheckInService.enrich_record_with_user_task_info(record, db) for record in records
|
||||||
]
|
]
|
||||||
return PaginatedResponse(records=enriched_records, total=total, skip=skip, limit=limit)
|
return PaginatedResponse(records=enriched_records, total=total, skip=skip, limit=limit)
|
||||||
|
except EXPECTED_API_EXCEPTIONS:
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"获取打卡记录失败: {str(e)}"
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"获取打卡记录失败: {str(e)}"
|
||||||
@@ -213,6 +223,8 @@ async def get_check_in_records_count(
|
|||||||
total = query.count()
|
total = query.count()
|
||||||
|
|
||||||
return {"total": total}
|
return {"total": total}
|
||||||
|
except EXPECTED_API_EXCEPTIONS:
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"获取统计失败: {str(e)}"
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"获取统计失败: {str(e)}"
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ from typing import List
|
|||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from backend.exceptions import BaseAPIException
|
||||||
from backend.models import get_db, User
|
from backend.models import get_db, User
|
||||||
from backend.schemas.task import TaskUpdate, TaskResponse
|
from backend.schemas.task import TaskUpdate, TaskResponse
|
||||||
from backend.services.task_service import TaskService
|
from backend.services.task_service import TaskService
|
||||||
@@ -37,6 +38,8 @@ async def get_tasks(
|
|||||||
# 为每个任务添加额外信息
|
# 为每个任务添加额外信息
|
||||||
enriched_tasks = [TaskService.enrich_task_with_check_in_info(task, db) for task in tasks]
|
enriched_tasks = [TaskService.enrich_task_with_check_in_info(task, db) for task in tasks]
|
||||||
return enriched_tasks
|
return enriched_tasks
|
||||||
|
except (BaseAPIException, HTTPException):
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"获取任务列表失败: {str(e)}"
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"获取任务列表失败: {str(e)}"
|
||||||
@@ -173,6 +176,8 @@ async def validate_cron_expression(request: CronValidateRequest):
|
|||||||
"next_times": next_times,
|
"next_times": next_times,
|
||||||
"description": generate_cron_description(cron_expr),
|
"description": generate_cron_description(cron_expr),
|
||||||
}
|
}
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_400_BAD_REQUEST, detail=f"无效的 Crontab 表达式: {str(e)}"
|
status_code=status.HTTP_400_BAD_REQUEST, detail=f"无效的 Crontab 表达式: {str(e)}"
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ from sqlalchemy.orm import Session
|
|||||||
|
|
||||||
from backend.models import User
|
from backend.models import User
|
||||||
from backend.dependencies import get_db, get_current_user, get_current_admin_user
|
from backend.dependencies import get_db, get_current_user, get_current_admin_user
|
||||||
|
from backend.exceptions import BaseAPIException
|
||||||
from backend.schemas.template import (
|
from backend.schemas.template import (
|
||||||
TemplateCreate,
|
TemplateCreate,
|
||||||
TemplateUpdate,
|
TemplateUpdate,
|
||||||
@@ -17,6 +18,9 @@ from backend.services.template_service import TemplateService
|
|||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
|
EXPECTED_API_EXCEPTIONS = (BaseAPIException, HTTPException)
|
||||||
|
|
||||||
|
|
||||||
@router.get("/", response_model=List[TemplateResponse], summary="获取所有模板列表")
|
@router.get("/", response_model=List[TemplateResponse], summary="获取所有模板列表")
|
||||||
async def get_all_templates(
|
async def get_all_templates(
|
||||||
skip: int = Query(0, ge=0, description="跳过记录数"),
|
skip: int = Query(0, ge=0, description="跳过记录数"),
|
||||||
@@ -35,6 +39,8 @@ async def get_all_templates(
|
|||||||
try:
|
try:
|
||||||
templates = TemplateService.get_all_templates(db, skip, limit, is_active)
|
templates = TemplateService.get_all_templates(db, skip, limit, is_active)
|
||||||
return templates
|
return templates
|
||||||
|
except EXPECTED_API_EXCEPTIONS:
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"获取模板列表失败: {str(e)}"
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"获取模板列表失败: {str(e)}"
|
||||||
@@ -57,6 +63,8 @@ async def get_active_templates(
|
|||||||
try:
|
try:
|
||||||
templates = TemplateService.get_all_templates(db, skip, limit, is_active=True)
|
templates = TemplateService.get_all_templates(db, skip, limit, is_active=True)
|
||||||
return templates
|
return templates
|
||||||
|
except EXPECTED_API_EXCEPTIONS:
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"获取模板列表失败: {str(e)}"
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"获取模板列表失败: {str(e)}"
|
||||||
@@ -115,6 +123,8 @@ async def preview_template(
|
|||||||
"preview_payload": preview_payload,
|
"preview_payload": preview_payload,
|
||||||
"field_config": merged_config,
|
"field_config": merged_config,
|
||||||
}
|
}
|
||||||
|
except EXPECTED_API_EXCEPTIONS:
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"生成预览失败: {str(e)}"
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"生成预览失败: {str(e)}"
|
||||||
|
|||||||
@@ -14,9 +14,15 @@ from backend.schemas.task import TaskResponse
|
|||||||
from backend.services.user_service import UserService
|
from backend.services.user_service import UserService
|
||||||
from backend.services.task_service import TaskService
|
from backend.services.task_service import TaskService
|
||||||
from backend.dependencies import get_current_user, get_current_admin_user
|
from backend.dependencies import get_current_user, get_current_admin_user
|
||||||
from backend.exceptions import ValidationError, AuthorizationError, ResourceNotFoundError
|
from backend.exceptions import (
|
||||||
|
AuthorizationError,
|
||||||
|
BaseAPIException,
|
||||||
|
ResourceNotFoundError,
|
||||||
|
ValidationError,
|
||||||
|
)
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
EXPECTED_API_EXCEPTIONS = (BaseAPIException, HTTPException)
|
||||||
|
|
||||||
|
|
||||||
@router.post(
|
@router.post(
|
||||||
@@ -42,6 +48,8 @@ async def create_user(
|
|||||||
return user
|
return user
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
raise ValidationError(str(e))
|
raise ValidationError(str(e))
|
||||||
|
except EXPECTED_API_EXCEPTIONS:
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"创建用户失败: {str(e)}"
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"创建用户失败: {str(e)}"
|
||||||
@@ -103,6 +111,8 @@ async def update_current_user_profile(
|
|||||||
return user
|
return user
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
raise ValidationError(str(e))
|
raise ValidationError(str(e))
|
||||||
|
except EXPECTED_API_EXCEPTIONS:
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"更新个人信息失败: {str(e)}"
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"更新个人信息失败: {str(e)}"
|
||||||
@@ -144,6 +154,8 @@ async def get_current_user_tasks(
|
|||||||
try:
|
try:
|
||||||
tasks = TaskService.get_user_tasks(current_user.id, db, include_inactive)
|
tasks = TaskService.get_user_tasks(current_user.id, db, include_inactive)
|
||||||
return tasks
|
return tasks
|
||||||
|
except EXPECTED_API_EXCEPTIONS:
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"获取任务列表失败: {str(e)}"
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"获取任务列表失败: {str(e)}"
|
||||||
@@ -170,6 +182,8 @@ async def get_all_users(
|
|||||||
try:
|
try:
|
||||||
users = UserService.get_all_users(db, skip, limit, search, role)
|
users = UserService.get_all_users(db, skip, limit, search, role)
|
||||||
return users
|
return users
|
||||||
|
except EXPECTED_API_EXCEPTIONS:
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"获取用户列表失败: {str(e)}"
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"获取用户列表失败: {str(e)}"
|
||||||
@@ -252,6 +266,8 @@ async def update_user(
|
|||||||
return user
|
return user
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
|
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
|
||||||
|
except EXPECTED_API_EXCEPTIONS:
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"更新用户失败: {str(e)}"
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"更新用户失败: {str(e)}"
|
||||||
@@ -272,6 +288,8 @@ async def delete_user(
|
|||||||
return None
|
return None
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
raise ResourceNotFoundError(str(e))
|
raise ResourceNotFoundError(str(e))
|
||||||
|
except EXPECTED_API_EXCEPTIONS:
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"删除用户失败: {str(e)}"
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"删除用户失败: {str(e)}"
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from sqlalchemy import Boolean, DateTime, ForeignKey, Index, String, Text
|
from sqlalchemy import Boolean, DateTime, ForeignKey, Index, String, Text, UniqueConstraint
|
||||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||||
from sqlalchemy.sql import func
|
from sqlalchemy.sql import func
|
||||||
|
|
||||||
@@ -23,6 +23,12 @@ class CheckInTask(Base):
|
|||||||
index=True,
|
index=True,
|
||||||
comment="用户 ID",
|
comment="用户 ID",
|
||||||
)
|
)
|
||||||
|
thread_id: Mapped[str | None] = mapped_column(
|
||||||
|
String(100),
|
||||||
|
index=True,
|
||||||
|
nullable=True,
|
||||||
|
comment="接龙项目 ID",
|
||||||
|
)
|
||||||
payload_config: Mapped[str] = mapped_column(
|
payload_config: Mapped[str] = mapped_column(
|
||||||
Text,
|
Text,
|
||||||
default="{}",
|
default="{}",
|
||||||
@@ -57,6 +63,7 @@ class CheckInTask(Base):
|
|||||||
# 添加索引:加速查询
|
# 添加索引:加速查询
|
||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
Index("ix_task_user_active", "user_id", "is_active"),
|
Index("ix_task_user_active", "user_id", "is_active"),
|
||||||
|
UniqueConstraint("user_id", "thread_id", name="uq_task_user_thread_id"),
|
||||||
Index("ix_task_cron", "cron_expression"), # 加速查询启用了定时打卡的任务
|
Index("ix_task_cron", "cron_expression"), # 加速查询启用了定时打卡的任务
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
from sqlalchemy import create_engine, event
|
from sqlalchemy import DateTime, create_engine, event, inspect
|
||||||
from sqlalchemy.orm import DeclarativeBase, sessionmaker
|
from sqlalchemy.orm import DeclarativeBase, sessionmaker
|
||||||
|
|
||||||
from backend.config import settings
|
from backend.config import settings
|
||||||
@@ -24,21 +24,23 @@ class Base(DeclarativeBase):
|
|||||||
@event.listens_for(Base, "load", propagate=True)
|
@event.listens_for(Base, "load", propagate=True)
|
||||||
def receive_load(target, context):
|
def receive_load(target, context):
|
||||||
"""在从数据库加载对象后,将所有 datetime 字段转换为 timezone-aware (UTC)"""
|
"""在从数据库加载对象后,将所有 datetime 字段转换为 timezone-aware (UTC)"""
|
||||||
for attr_name in dir(target):
|
mapper = inspect(target).mapper
|
||||||
# 跳过私有属性和方法
|
for attr in mapper.column_attrs:
|
||||||
if attr_name.startswith("_"):
|
column = attr.columns[0]
|
||||||
|
if not isinstance(column.type, DateTime):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
try:
|
try:
|
||||||
attr_value = getattr(target, attr_name)
|
attr_value = getattr(target, attr.key)
|
||||||
|
|
||||||
# 如果是 naive datetime,添加 UTC timezone
|
|
||||||
if isinstance(attr_value, datetime) and attr_value.tzinfo is None:
|
|
||||||
setattr(target, attr_name, attr_value.replace(tzinfo=timezone.utc))
|
|
||||||
except (AttributeError, TypeError):
|
except (AttributeError, TypeError):
|
||||||
# 某些属性可能无法访问或设置,跳过
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
if isinstance(attr_value, datetime) and attr_value.tzinfo is None:
|
||||||
|
try:
|
||||||
|
setattr(target, attr.key, attr_value.replace(tzinfo=timezone.utc))
|
||||||
|
except (AttributeError, TypeError):
|
||||||
|
continue
|
||||||
|
|
||||||
|
|
||||||
def get_db():
|
def get_db():
|
||||||
"""依赖注入:获取数据库会话"""
|
"""依赖注入:获取数据库会话"""
|
||||||
|
|||||||
@@ -0,0 +1,113 @@
|
|||||||
|
"""
|
||||||
|
数据库迁移脚本:添加打卡任务 thread_id 字段并回填。
|
||||||
|
|
||||||
|
运行方式:
|
||||||
|
uv run python -m backend.scripts.migrate_add_task_thread_id
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
APPS_DIR = Path(__file__).resolve().parents[2]
|
||||||
|
sys.path.insert(0, str(APPS_DIR))
|
||||||
|
|
||||||
|
from sqlalchemy import text
|
||||||
|
|
||||||
|
from backend.models.database import engine
|
||||||
|
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_thread_id(payload_config: str | None) -> str | None:
|
||||||
|
if not payload_config:
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
payload = json.loads(payload_config)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
return None
|
||||||
|
if not isinstance(payload, dict):
|
||||||
|
return None
|
||||||
|
thread_id = payload.get("ThreadId")
|
||||||
|
value = str(thread_id).strip() if thread_id is not None else ""
|
||||||
|
return value or None
|
||||||
|
|
||||||
|
|
||||||
|
def migrate() -> None:
|
||||||
|
"""执行迁移。"""
|
||||||
|
logger.info("开始迁移:添加 check_in_tasks.thread_id 字段...")
|
||||||
|
|
||||||
|
with engine.connect() as conn:
|
||||||
|
result = conn.execute(text("PRAGMA table_info(check_in_tasks)"))
|
||||||
|
columns = [row[1] for row in result]
|
||||||
|
|
||||||
|
if "thread_id" not in columns:
|
||||||
|
logger.info("添加 thread_id 字段...")
|
||||||
|
conn.execute(text("ALTER TABLE check_in_tasks ADD COLUMN thread_id VARCHAR(100)"))
|
||||||
|
conn.commit()
|
||||||
|
logger.info("✓ thread_id 字段添加成功")
|
||||||
|
else:
|
||||||
|
logger.info("✓ thread_id 字段已存在,跳过")
|
||||||
|
|
||||||
|
rows = conn.execute(text("SELECT id, payload_config FROM check_in_tasks")).fetchall()
|
||||||
|
invalid_ids: list[int] = []
|
||||||
|
seen: dict[tuple[int, str], int] = {}
|
||||||
|
duplicate_ids: list[int] = []
|
||||||
|
|
||||||
|
full_rows = conn.execute(
|
||||||
|
text("SELECT id, user_id, payload_config FROM check_in_tasks")
|
||||||
|
).fetchall()
|
||||||
|
for row in full_rows:
|
||||||
|
thread_id = _extract_thread_id(row.payload_config)
|
||||||
|
if not thread_id:
|
||||||
|
invalid_ids.append(row.id)
|
||||||
|
continue
|
||||||
|
key = (row.user_id, thread_id)
|
||||||
|
if key in seen:
|
||||||
|
duplicate_ids.append(row.id)
|
||||||
|
else:
|
||||||
|
seen[key] = row.id
|
||||||
|
|
||||||
|
if invalid_ids or duplicate_ids:
|
||||||
|
messages = []
|
||||||
|
if invalid_ids:
|
||||||
|
messages.append(f"payload_config 缺少有效 ThreadId 的任务: {invalid_ids}")
|
||||||
|
if duplicate_ids:
|
||||||
|
messages.append(f"同用户 ThreadId 重复的任务: {duplicate_ids}")
|
||||||
|
raise RuntimeError(";".join(messages))
|
||||||
|
|
||||||
|
for row in rows:
|
||||||
|
thread_id = _extract_thread_id(row.payload_config)
|
||||||
|
if thread_id:
|
||||||
|
conn.execute(
|
||||||
|
text("UPDATE check_in_tasks SET thread_id = :thread_id WHERE id = :id"),
|
||||||
|
{"thread_id": thread_id, "id": row.id},
|
||||||
|
)
|
||||||
|
conn.commit()
|
||||||
|
|
||||||
|
indexes = conn.execute(text("PRAGMA index_list(check_in_tasks)")).fetchall()
|
||||||
|
index_names = [row[1] for row in indexes]
|
||||||
|
if "ix_task_user_thread_id_unique" not in index_names:
|
||||||
|
logger.info("添加用户级 thread_id 唯一索引...")
|
||||||
|
conn.execute(
|
||||||
|
text(
|
||||||
|
"CREATE UNIQUE INDEX ix_task_user_thread_id_unique "
|
||||||
|
"ON check_in_tasks (user_id, thread_id)"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
conn.commit()
|
||||||
|
logger.info("✓ 用户级 thread_id 唯一索引添加成功")
|
||||||
|
else:
|
||||||
|
logger.info("✓ 用户级 thread_id 唯一索引已存在,跳过")
|
||||||
|
|
||||||
|
logger.info("✅ 迁移完成!任务 thread_id 身份字段已启用")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
try:
|
||||||
|
migrate()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"❌ 迁移失败: {e}")
|
||||||
|
sys.exit(1)
|
||||||
@@ -541,10 +541,10 @@ class CheckInService:
|
|||||||
)
|
)
|
||||||
task_name = task.name
|
task_name = task.name
|
||||||
|
|
||||||
# 从 payload_config 提取 ThreadId
|
# 优先从显式任务身份读取 ThreadId,兼容旧数据时回退到 payload_config
|
||||||
from backend.utils.json_helpers import extract_thread_id
|
from backend.utils.json_helpers import resolve_task_thread_id
|
||||||
|
|
||||||
thread_id = extract_thread_id(task.payload_config) # type: ignore
|
thread_id = resolve_task_thread_id(task)
|
||||||
|
|
||||||
# 转换为字典并添加额外字段
|
# 转换为字典并添加额外字段
|
||||||
record_dict = {
|
record_dict = {
|
||||||
|
|||||||
@@ -20,6 +20,70 @@ scheduler = None
|
|||||||
scheduler_lock = None
|
scheduler_lock = None
|
||||||
|
|
||||||
|
|
||||||
|
def _task_job_id(task_id: int) -> str:
|
||||||
|
return f"task_{task_id}"
|
||||||
|
|
||||||
|
|
||||||
|
def _remove_task_job(job_id: str, scheduler_instance=None) -> bool:
|
||||||
|
active_scheduler = scheduler_instance or scheduler
|
||||||
|
if not active_scheduler:
|
||||||
|
return False
|
||||||
|
|
||||||
|
existing_job = active_scheduler.get_job(job_id)
|
||||||
|
if not existing_job:
|
||||||
|
return False
|
||||||
|
|
||||||
|
active_scheduler.remove_job(job_id)
|
||||||
|
logger.info(f"✅ 已移除调度任务: {job_id}")
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def remove_scheduled_task(task_id: int, scheduler_instance=None) -> bool:
|
||||||
|
"""从调度器移除指定任务。"""
|
||||||
|
active_scheduler = scheduler_instance or scheduler
|
||||||
|
if not active_scheduler:
|
||||||
|
logger.warning(f"调度器未启动,无法移除任务 {task_id}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
return _remove_task_job(_task_job_id(task_id), active_scheduler)
|
||||||
|
|
||||||
|
|
||||||
|
def sync_scheduled_task(task: CheckInTask, scheduler_instance=None) -> bool:
|
||||||
|
"""
|
||||||
|
根据任务当前状态同步调度器中的对应 job。
|
||||||
|
|
||||||
|
返回 True 表示任务已成功安排到调度器,False 表示未安排或调度器不可用。
|
||||||
|
"""
|
||||||
|
active_scheduler = scheduler_instance or scheduler
|
||||||
|
if not active_scheduler:
|
||||||
|
logger.warning(f"调度器未启动,无法同步任务 {task.id}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
job_id = _task_job_id(task.id)
|
||||||
|
_remove_task_job(job_id, active_scheduler)
|
||||||
|
|
||||||
|
if not task.is_scheduled_enabled:
|
||||||
|
logger.info(f"任务 {task.id} 未启用或无 cron 表达式,已从调度器移除")
|
||||||
|
return False
|
||||||
|
|
||||||
|
cron_str = str(task.cron_expression or "").strip()
|
||||||
|
if not cron_str or not croniter.is_valid(cron_str):
|
||||||
|
logger.warning(f"任务 {task.id} 的 cron 表达式无效: {cron_str}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
active_scheduler.add_job(
|
||||||
|
func=scheduled_check_in_task,
|
||||||
|
trigger=CronTrigger.from_crontab(cron_str),
|
||||||
|
id=job_id,
|
||||||
|
name=f"CheckIn-Task-{task.id}",
|
||||||
|
args=[task.id],
|
||||||
|
replace_existing=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"✅ 任务 {task.id} 已同步到调度器: {cron_str}")
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
def load_scheduled_tasks(db: Session, scheduler_instance):
|
def load_scheduled_tasks(db: Session, scheduler_instance):
|
||||||
"""
|
"""
|
||||||
从数据库加载所有启用的定时任务并添加到 APScheduler
|
从数据库加载所有启用的定时任务并添加到 APScheduler
|
||||||
@@ -55,33 +119,10 @@ def load_scheduled_tasks(db: Session, scheduler_instance):
|
|||||||
|
|
||||||
for task in tasks:
|
for task in tasks:
|
||||||
try:
|
try:
|
||||||
# 验证 cron 表达式
|
if sync_scheduled_task(task, scheduler_instance):
|
||||||
cron_str = str(task.cron_expression) if task.cron_expression else None
|
loaded_count += 1
|
||||||
if not cron_str or not croniter.is_valid(cron_str):
|
else:
|
||||||
logger.warning(f"跳过任务 {task.id}: 无效的 cron 表达式 '{task.cron_expression}'")
|
|
||||||
skipped_count += 1
|
skipped_count += 1
|
||||||
continue
|
|
||||||
|
|
||||||
# 创建任务 ID
|
|
||||||
job_id = f"task_{task.id}"
|
|
||||||
|
|
||||||
# 检查任务是否已存在
|
|
||||||
if scheduler_instance.get_job(job_id):
|
|
||||||
logger.debug(f"任务 {task.id} 已存在,跳过")
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 添加任务到调度器
|
|
||||||
scheduler_instance.add_job(
|
|
||||||
func=scheduled_check_in_task,
|
|
||||||
trigger=CronTrigger.from_crontab(cron_str),
|
|
||||||
id=job_id,
|
|
||||||
name=f"CheckIn-Task-{task.id}",
|
|
||||||
args=[task.id],
|
|
||||||
replace_existing=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(f"✅ 加载任务 {task.id}: {task.name} (Cron: {task.cron_expression})")
|
|
||||||
loaded_count += 1
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"❌ 加载任务 {task.id} 时出错: {str(e)}")
|
logger.error(f"❌ 加载任务 {task.id} 时出错: {str(e)}")
|
||||||
|
|||||||
@@ -3,6 +3,12 @@ from typing import List, Optional, Dict, Any
|
|||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
from sqlalchemy import desc
|
from sqlalchemy import desc
|
||||||
|
|
||||||
|
from backend.exceptions import (
|
||||||
|
InternalServerError,
|
||||||
|
ResourceConflictError,
|
||||||
|
ResourceNotFoundError,
|
||||||
|
ValidationError as APIValidationError,
|
||||||
|
)
|
||||||
from backend.models import User, CheckInTask, CheckInRecord
|
from backend.models import User, CheckInTask, CheckInRecord
|
||||||
from backend.schemas.task import TaskCreate, TaskUpdate
|
from backend.schemas.task import TaskCreate, TaskUpdate
|
||||||
|
|
||||||
@@ -12,6 +18,68 @@ logger = logging.getLogger(__name__)
|
|||||||
class TaskService:
|
class TaskService:
|
||||||
"""打卡任务服务"""
|
"""打卡任务服务"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _normalize_thread_id(thread_id: Any) -> str:
|
||||||
|
value = str(thread_id).strip() if thread_id is not None else ""
|
||||||
|
if not value:
|
||||||
|
raise APIValidationError(
|
||||||
|
"payload_config 必须包含有效的 ThreadId 字段",
|
||||||
|
error_code="TASK_IDENTITY_INVALID",
|
||||||
|
)
|
||||||
|
return value
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _extract_thread_id_from_payload(payload_config: str) -> str:
|
||||||
|
from backend.utils.json_helpers import safe_parse_payload
|
||||||
|
|
||||||
|
payload = safe_parse_payload(payload_config)
|
||||||
|
return TaskService._normalize_thread_id(payload.get("ThreadId"))
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _resolve_thread_id_for_task(task: CheckInTask) -> Optional[str]:
|
||||||
|
thread_id = getattr(task, "thread_id", None)
|
||||||
|
if thread_id is not None:
|
||||||
|
value = str(thread_id).strip()
|
||||||
|
if value:
|
||||||
|
return value
|
||||||
|
|
||||||
|
from backend.utils.json_helpers import extract_thread_id
|
||||||
|
|
||||||
|
legacy_thread_id = extract_thread_id(task.payload_config)
|
||||||
|
if legacy_thread_id is None:
|
||||||
|
return None
|
||||||
|
value = str(legacy_thread_id).strip()
|
||||||
|
return value or None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _ensure_unique_thread_id(
|
||||||
|
db: Session, user_id: int, thread_id: str, exclude_task_id: int | None = None
|
||||||
|
) -> None:
|
||||||
|
query = db.query(CheckInTask.id).filter(
|
||||||
|
CheckInTask.user_id == user_id, CheckInTask.thread_id == thread_id
|
||||||
|
)
|
||||||
|
if exclude_task_id is not None:
|
||||||
|
query = query.filter(CheckInTask.id != exclude_task_id)
|
||||||
|
|
||||||
|
conflict = query.first()
|
||||||
|
if conflict:
|
||||||
|
raise ResourceConflictError(
|
||||||
|
f"该接龙中已存在任务。ThreadId: {thread_id}",
|
||||||
|
error_code="TASK_IDENTITY_CONFLICT",
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _sync_scheduler_for_task(task: CheckInTask) -> None:
|
||||||
|
from backend.services.scheduler_service import sync_scheduled_task
|
||||||
|
|
||||||
|
sync_scheduled_task(task)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _remove_scheduler_for_task(task_id: int) -> None:
|
||||||
|
from backend.services.scheduler_service import remove_scheduled_task
|
||||||
|
|
||||||
|
remove_scheduled_task(task_id)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def create_task(user_id: int, task_data: TaskCreate, db: Session) -> CheckInTask:
|
def create_task(user_id: int, task_data: TaskCreate, db: Session) -> CheckInTask:
|
||||||
"""
|
"""
|
||||||
@@ -25,34 +93,16 @@ class TaskService:
|
|||||||
Returns:
|
Returns:
|
||||||
创建的任务对象
|
创建的任务对象
|
||||||
"""
|
"""
|
||||||
import json
|
|
||||||
|
|
||||||
# 1. 检查用户是否存在
|
# 1. 检查用户是否存在
|
||||||
user = db.query(User).filter(User.id == user_id).first()
|
user = db.query(User).filter(User.id == user_id).first()
|
||||||
if not user:
|
if not user:
|
||||||
raise ValueError(f"用户 ID {user_id} 不存在")
|
raise ResourceNotFoundError(f"用户 ID {user_id} 不存在", error_code="USER_NOT_FOUND")
|
||||||
|
|
||||||
# 2. 从 payload_config 中提取 ThreadId 用于唯一性校验
|
# 2. 从 payload_config 中提取 ThreadId 用于唯一性校验
|
||||||
from backend.utils.json_helpers import safe_parse_payload, extract_thread_id
|
thread_id = TaskService._extract_thread_id_from_payload(task_data.payload_config)
|
||||||
|
|
||||||
payload = safe_parse_payload(task_data.payload_config)
|
|
||||||
thread_id = payload.get("ThreadId")
|
|
||||||
if not thread_id:
|
|
||||||
raise ValueError("payload_config 中缺少 ThreadId")
|
|
||||||
|
|
||||||
# 3. 验证唯一性:同一用户在同一个接龙中不能有重复的任务
|
# 3. 验证唯一性:同一用户在同一个接龙中不能有重复的任务
|
||||||
existing_tasks = (
|
TaskService._ensure_unique_thread_id(db, user_id, thread_id)
|
||||||
db.query(CheckInTask.payload_config).filter(CheckInTask.user_id == user_id).all()
|
|
||||||
)
|
|
||||||
|
|
||||||
for (payload_config,) in existing_tasks:
|
|
||||||
existing_thread_id = extract_thread_id(payload_config)
|
|
||||||
# extract_thread_id 已处理异常,失败时返回 None
|
|
||||||
if existing_thread_id and existing_thread_id == thread_id:
|
|
||||||
logger.warning(
|
|
||||||
f"⚠️ 任务创建冲突 - User: {user.alias}({user_id}), ThreadId: {thread_id}"
|
|
||||||
)
|
|
||||||
raise ValueError(f"该接龙中已存在任务。ThreadId: {thread_id}")
|
|
||||||
|
|
||||||
# 4. 记录日志
|
# 4. 记录日志
|
||||||
task_name = task_data.name or f"接龙任务 {thread_id}"
|
task_name = task_data.name or f"接龙任务 {thread_id}"
|
||||||
@@ -61,9 +111,11 @@ class TaskService:
|
|||||||
# 5. 创建任务
|
# 5. 创建任务
|
||||||
task = CheckInTask(
|
task = CheckInTask(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
|
thread_id=thread_id,
|
||||||
payload_config=task_data.payload_config,
|
payload_config=task_data.payload_config,
|
||||||
name=task_data.name or task_name,
|
name=task_data.name or task_name,
|
||||||
is_active=task_data.is_active if task_data.is_active is not None else True,
|
is_active=task_data.is_active if task_data.is_active is not None else True,
|
||||||
|
cron_expression=task_data.cron_expression or "0 20 * * *",
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -75,14 +127,19 @@ class TaskService:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# 如果任务启用且包含 cron_expression,立即添加到调度器
|
# 如果任务启用且包含 cron_expression,立即添加到调度器
|
||||||
if task.is_scheduled_enabled:
|
TaskService._sync_scheduler_for_task(task)
|
||||||
TaskService._reload_scheduler_for_task(task, db)
|
|
||||||
|
|
||||||
return task
|
return task
|
||||||
|
except APIValidationError:
|
||||||
|
db.rollback()
|
||||||
|
raise
|
||||||
|
except ResourceConflictError:
|
||||||
|
db.rollback()
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
db.rollback()
|
db.rollback()
|
||||||
logger.error(f"❌ 任务创建失败: {str(e)}")
|
logger.error(f"❌ 任务创建失败: {str(e)}")
|
||||||
raise ValueError(f"任务创建失败: {str(e)}")
|
raise InternalServerError(f"任务创建失败: {str(e)}")
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_task(task_id: int, db: Session) -> Optional[CheckInTask]:
|
def get_task(task_id: int, db: Session) -> Optional[CheckInTask]:
|
||||||
@@ -110,8 +167,6 @@ class TaskService:
|
|||||||
Returns:
|
Returns:
|
||||||
包含额外信息的任务字典
|
包含额外信息的任务字典
|
||||||
"""
|
"""
|
||||||
from backend.utils.json_helpers import extract_thread_id
|
|
||||||
|
|
||||||
# 获取最后一次打卡记录
|
# 获取最后一次打卡记录
|
||||||
last_record = (
|
last_record = (
|
||||||
db.query(CheckInRecord)
|
db.query(CheckInRecord)
|
||||||
@@ -120,8 +175,8 @@ class TaskService:
|
|||||||
.first()
|
.first()
|
||||||
)
|
)
|
||||||
|
|
||||||
# 从 payload_config 提取 ThreadId
|
# 优先使用持久化的 ThreadId,兼容旧数据时回退到 payload_config
|
||||||
thread_id = extract_thread_id(task.payload_config) # type: ignore
|
thread_id = TaskService._resolve_thread_id_for_task(task)
|
||||||
|
|
||||||
# 转换为字典并添加额外字段
|
# 转换为字典并添加额外字段
|
||||||
task_dict = {
|
task_dict = {
|
||||||
@@ -192,7 +247,7 @@ class TaskService:
|
|||||||
task = db.query(CheckInTask).filter(CheckInTask.id == task_id).first()
|
task = db.query(CheckInTask).filter(CheckInTask.id == task_id).first()
|
||||||
|
|
||||||
if not task:
|
if not task:
|
||||||
return None
|
raise ResourceNotFoundError("任务不存在", error_code="TASK_NOT_FOUND")
|
||||||
|
|
||||||
# 更新字段
|
# 更新字段
|
||||||
update_data = task_data.model_dump(exclude_unset=True)
|
update_data = task_data.model_dump(exclude_unset=True)
|
||||||
@@ -201,17 +256,34 @@ class TaskService:
|
|||||||
cron_changed = "cron_expression" in update_data
|
cron_changed = "cron_expression" in update_data
|
||||||
active_changed = "is_active" in update_data
|
active_changed = "is_active" in update_data
|
||||||
|
|
||||||
|
if "payload_config" in update_data:
|
||||||
|
new_thread_id = TaskService._extract_thread_id_from_payload(
|
||||||
|
update_data["payload_config"]
|
||||||
|
)
|
||||||
|
TaskService._ensure_unique_thread_id(
|
||||||
|
db, task.user_id, new_thread_id, exclude_task_id=task.id
|
||||||
|
)
|
||||||
|
task.thread_id = new_thread_id
|
||||||
|
|
||||||
for field, value in update_data.items():
|
for field, value in update_data.items():
|
||||||
setattr(task, field, value)
|
setattr(task, field, value)
|
||||||
|
|
||||||
db.commit()
|
try:
|
||||||
db.refresh(task)
|
db.commit()
|
||||||
|
db.refresh(task)
|
||||||
|
except ResourceConflictError:
|
||||||
|
db.rollback()
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
db.rollback()
|
||||||
|
logger.error(f"任务 {task_id} 更新失败: {str(e)}")
|
||||||
|
raise InternalServerError(f"任务更新失败: {str(e)}")
|
||||||
|
|
||||||
logger.info(f"任务 {task_id} 已更新")
|
logger.info(f"任务 {task_id} 已更新")
|
||||||
|
|
||||||
# 如果 cron_expression 或 is_active 发生变化,重新加载调度器
|
# 如果 cron_expression 或 is_active 发生变化,重新加载调度器
|
||||||
if cron_changed or active_changed:
|
if cron_changed or active_changed:
|
||||||
TaskService._reload_scheduler_for_task(task, db)
|
TaskService._sync_scheduler_for_task(task)
|
||||||
|
|
||||||
return task
|
return task
|
||||||
|
|
||||||
@@ -230,15 +302,20 @@ class TaskService:
|
|||||||
task = db.query(CheckInTask).filter(CheckInTask.id == task_id).first()
|
task = db.query(CheckInTask).filter(CheckInTask.id == task_id).first()
|
||||||
|
|
||||||
if not task:
|
if not task:
|
||||||
return False
|
raise ResourceNotFoundError("任务不存在", error_code="TASK_NOT_FOUND")
|
||||||
|
|
||||||
db.delete(task)
|
try:
|
||||||
db.commit()
|
db.delete(task)
|
||||||
|
db.commit()
|
||||||
|
except Exception as e:
|
||||||
|
db.rollback()
|
||||||
|
logger.error(f"任务 {task_id} 删除失败: {str(e)}")
|
||||||
|
raise InternalServerError(f"任务删除失败: {str(e)}")
|
||||||
|
|
||||||
logger.info(f"任务 {task_id} 已删除")
|
logger.info(f"任务 {task_id} 已删除")
|
||||||
|
|
||||||
# 从调度器中移除该任务
|
# 从调度器中移除该任务
|
||||||
TaskService._remove_task_from_scheduler(task_id)
|
TaskService._remove_scheduler_for_task(task_id)
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@@ -257,16 +334,21 @@ class TaskService:
|
|||||||
task = db.query(CheckInTask).filter(CheckInTask.id == task_id).first()
|
task = db.query(CheckInTask).filter(CheckInTask.id == task_id).first()
|
||||||
|
|
||||||
if not task:
|
if not task:
|
||||||
return None
|
raise ResourceNotFoundError("任务不存在", error_code="TASK_NOT_FOUND")
|
||||||
|
|
||||||
task.is_active = not task.is_active
|
task.is_active = not task.is_active
|
||||||
db.commit()
|
try:
|
||||||
db.refresh(task)
|
db.commit()
|
||||||
|
db.refresh(task)
|
||||||
|
except Exception as e:
|
||||||
|
db.rollback()
|
||||||
|
logger.error(f"任务 {task_id} 状态切换失败: {str(e)}")
|
||||||
|
raise InternalServerError(f"任务状态切换失败: {str(e)}")
|
||||||
|
|
||||||
logger.info(f"任务 {task_id} 状态已切换为: {'启用' if task.is_active else '禁用'}")
|
logger.info(f"任务 {task_id} 状态已切换为: {'启用' if task.is_active else '禁用'}")
|
||||||
|
|
||||||
# 重新加载调度器
|
# 重新加载调度器
|
||||||
TaskService._reload_scheduler_for_task(task, db)
|
TaskService._sync_scheduler_for_task(task)
|
||||||
|
|
||||||
return task
|
return task
|
||||||
|
|
||||||
@@ -322,42 +404,7 @@ class TaskService:
|
|||||||
db: 数据库会话
|
db: 数据库会话
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
from backend.services.scheduler_service import scheduler
|
TaskService._sync_scheduler_for_task(task)
|
||||||
from apscheduler.triggers.cron import CronTrigger
|
|
||||||
from croniter import croniter
|
|
||||||
|
|
||||||
if not scheduler:
|
|
||||||
logger.warning(f"调度器未启动,无法加载任务 {task.id}")
|
|
||||||
return
|
|
||||||
|
|
||||||
job_id = f"task_{task.id}"
|
|
||||||
|
|
||||||
# 先移除旧的任务(如果存在)
|
|
||||||
existing_job = scheduler.get_job(job_id)
|
|
||||||
if existing_job:
|
|
||||||
scheduler.remove_job(job_id)
|
|
||||||
logger.info(f"从调度器移除旧任务: {job_id}")
|
|
||||||
|
|
||||||
# 如果任务启用且有有效的 cron 表达式,添加新任务
|
|
||||||
if task.is_scheduled_enabled:
|
|
||||||
cron_str = str(task.cron_expression)
|
|
||||||
if croniter.is_valid(cron_str):
|
|
||||||
from backend.services.scheduler_service import scheduled_check_in_task
|
|
||||||
|
|
||||||
scheduler.add_job(
|
|
||||||
func=scheduled_check_in_task,
|
|
||||||
trigger=CronTrigger.from_crontab(cron_str),
|
|
||||||
id=job_id,
|
|
||||||
name=f"CheckIn-Task-{task.id}",
|
|
||||||
args=[task.id],
|
|
||||||
replace_existing=True,
|
|
||||||
)
|
|
||||||
logger.info(f"✅ 任务 {task.id} 已重新加载到调度器: {cron_str}")
|
|
||||||
else:
|
|
||||||
logger.warning(f"任务 {task.id} 的 cron 表达式无效: {cron_str}")
|
|
||||||
else:
|
|
||||||
logger.info(f"任务 {task.id} 未启用或无 cron 表达式,已从调度器移除")
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"重新加载任务 {task.id} 到调度器失败: {str(e)}")
|
logger.error(f"重新加载任务 {task.id} 到调度器失败: {str(e)}")
|
||||||
|
|
||||||
@@ -370,15 +417,6 @@ class TaskService:
|
|||||||
task_id: 任务 ID
|
task_id: 任务 ID
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
from backend.services.scheduler_service import scheduler
|
TaskService._remove_scheduler_for_task(task_id)
|
||||||
|
|
||||||
if not scheduler:
|
|
||||||
return
|
|
||||||
|
|
||||||
job_id = f"task_{task_id}"
|
|
||||||
if scheduler.get_job(job_id):
|
|
||||||
scheduler.remove_job(job_id)
|
|
||||||
logger.info(f"✅ 任务 {task_id} 已从调度器移除")
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"从调度器移除任务 {task_id} 失败: {str(e)}")
|
logger.error(f"从调度器移除任务 {task_id} 失败: {str(e)}")
|
||||||
|
|||||||
@@ -4,7 +4,9 @@ from typing import List, Dict, Any, Optional
|
|||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
from fastapi import HTTPException, status
|
from fastapi import HTTPException, status
|
||||||
|
|
||||||
from backend.models import TaskTemplate, CheckInTask
|
from backend.exceptions import BaseAPIException
|
||||||
|
from backend.models import CheckInTask, TaskTemplate
|
||||||
|
from backend.schemas.task import TaskCreate
|
||||||
from backend.schemas.template import TemplateCreate, TemplateUpdate
|
from backend.schemas.template import TemplateCreate, TemplateUpdate
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -134,6 +136,9 @@ class TemplateService:
|
|||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_400_BAD_REQUEST, detail=f"字段配置 JSON 格式错误: {str(e)}"
|
status_code=status.HTTP_400_BAD_REQUEST, detail=f"字段配置 JSON 格式错误: {str(e)}"
|
||||||
)
|
)
|
||||||
|
except (BaseAPIException, HTTPException):
|
||||||
|
db.rollback()
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"创建模板失败: {str(e)}")
|
logger.error(f"创建模板失败: {str(e)}")
|
||||||
db.rollback()
|
db.rollback()
|
||||||
@@ -219,6 +224,9 @@ class TemplateService:
|
|||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_400_BAD_REQUEST, detail=f"字段配置 JSON 格式错误: {str(e)}"
|
status_code=status.HTTP_400_BAD_REQUEST, detail=f"字段配置 JSON 格式错误: {str(e)}"
|
||||||
)
|
)
|
||||||
|
except (BaseAPIException, HTTPException):
|
||||||
|
db.rollback()
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"更新模板失败: {str(e)}")
|
logger.error(f"更新模板失败: {str(e)}")
|
||||||
db.rollback()
|
db.rollback()
|
||||||
@@ -247,6 +255,9 @@ class TemplateService:
|
|||||||
db.commit()
|
db.commit()
|
||||||
logger.info(f"删除模板成功: {template.name} (ID: {template_id})")
|
logger.info(f"删除模板成功: {template.name} (ID: {template_id})")
|
||||||
return True
|
return True
|
||||||
|
except (BaseAPIException, HTTPException):
|
||||||
|
db.rollback()
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"删除模板失败: {str(e)}")
|
logger.error(f"删除模板失败: {str(e)}")
|
||||||
db.rollback()
|
db.rollback()
|
||||||
@@ -424,6 +435,10 @@ class TemplateService:
|
|||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"解析模板配置失败"
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"解析模板配置失败"
|
||||||
)
|
)
|
||||||
|
except BaseAPIException:
|
||||||
|
raise
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"组装 payload 失败: {str(e)}")
|
logger.error(f"组装 payload 失败: {str(e)}")
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
@@ -518,31 +533,32 @@ class TemplateService:
|
|||||||
signature = payload.get("Signature", "Unknown")
|
signature = payload.get("Signature", "Unknown")
|
||||||
task_name = f"{template.name} - {signature}"
|
task_name = f"{template.name} - {signature}"
|
||||||
|
|
||||||
# 创建任务(包含 cron_expression)
|
|
||||||
try:
|
try:
|
||||||
task = CheckInTask(
|
from backend.services.task_service import TaskService
|
||||||
|
|
||||||
|
task = TaskService.create_task(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
payload_config=json.dumps(payload, ensure_ascii=False),
|
task_data=TaskCreate(
|
||||||
name=task_name,
|
payload_config=json.dumps(payload, ensure_ascii=False),
|
||||||
is_active=True,
|
name=task_name,
|
||||||
cron_expression=cron_expression or "0 20 * * *",
|
is_active=True,
|
||||||
|
cron_expression=cron_expression or "0 20 * * *",
|
||||||
|
),
|
||||||
|
db=db,
|
||||||
)
|
)
|
||||||
db.add(task)
|
|
||||||
db.commit()
|
|
||||||
db.refresh(task)
|
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"从模板创建任务成功: {task.name} (ID: {task.id}, 模板: {template.name}, ThreadId: {thread_id})"
|
f"从模板创建任务成功: {task.name} (ID: {task.id}, 模板: {template.name}, ThreadId: {thread_id})"
|
||||||
)
|
)
|
||||||
|
|
||||||
# 如果任务启用且包含 cron_expression,立即添加到调度器
|
|
||||||
if task.is_scheduled_enabled:
|
|
||||||
from backend.services.task_service import TaskService
|
|
||||||
|
|
||||||
TaskService._reload_scheduler_for_task(task, db)
|
|
||||||
|
|
||||||
return task
|
return task
|
||||||
|
|
||||||
|
except BaseAPIException:
|
||||||
|
db.rollback()
|
||||||
|
raise
|
||||||
|
except HTTPException:
|
||||||
|
db.rollback()
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"从模板创建任务失败: {str(e)}")
|
logger.error(f"从模板创建任务失败: {str(e)}")
|
||||||
db.rollback()
|
db.rollback()
|
||||||
|
|||||||
@@ -92,6 +92,29 @@ def build_task_info(task) -> Dict[str, str]:
|
|||||||
包含 thread_id 和 name 的字典
|
包含 thread_id 和 name 的字典
|
||||||
"""
|
"""
|
||||||
return {
|
return {
|
||||||
"thread_id": extract_thread_id(getattr(task, "payload_config", None)) or "未知",
|
"thread_id": resolve_task_thread_id(task) or "未知",
|
||||||
"name": getattr(task, "name", None) or f"Task-{getattr(task, 'id', 'Unknown')}",
|
"name": getattr(task, "name", None) or f"Task-{getattr(task, 'id', 'Unknown')}",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_task_thread_id(task) -> Optional[str]:
|
||||||
|
"""
|
||||||
|
优先从显式字段读取任务 ThreadId,兼容旧数据时回退到 payload_config。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task: CheckInTask 或相似对象
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ThreadId 或 None
|
||||||
|
"""
|
||||||
|
thread_id = getattr(task, "thread_id", None)
|
||||||
|
if thread_id is not None:
|
||||||
|
value = str(thread_id).strip()
|
||||||
|
if value:
|
||||||
|
return value
|
||||||
|
|
||||||
|
legacy_thread_id = extract_thread_id(getattr(task, "payload_config", None))
|
||||||
|
if legacy_thread_id is None:
|
||||||
|
return None
|
||||||
|
value = str(legacy_thread_id).strip()
|
||||||
|
return value or None
|
||||||
|
|||||||
@@ -0,0 +1,235 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from sqlalchemy import create_engine
|
||||||
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
|
||||||
|
from backend.exceptions import ResourceConflictError
|
||||||
|
from backend.models import Base, CheckInTask, User
|
||||||
|
from backend.schemas.task import TaskCreate, TaskUpdate
|
||||||
|
from backend.services import scheduler_service
|
||||||
|
from backend.services.task_service import TaskService
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def db_session():
|
||||||
|
engine = create_engine("sqlite:///:memory:", connect_args={"check_same_thread": False})
|
||||||
|
Base.metadata.create_all(bind=engine)
|
||||||
|
Session = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||||
|
session = Session()
|
||||||
|
try:
|
||||||
|
yield session
|
||||||
|
finally:
|
||||||
|
session.close()
|
||||||
|
engine.dispose()
|
||||||
|
|
||||||
|
|
||||||
|
def add_user(db_session, alias: str) -> User:
|
||||||
|
user = User(alias=alias)
|
||||||
|
db_session.add(user)
|
||||||
|
db_session.commit()
|
||||||
|
db_session.refresh(user)
|
||||||
|
return user
|
||||||
|
|
||||||
|
|
||||||
|
def payload(thread_id: str, **extra: Any) -> str:
|
||||||
|
return json.dumps({"ThreadId": thread_id, **extra}, ensure_ascii=False)
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_task_persists_thread_id_identity(db_session) -> None:
|
||||||
|
user = add_user(db_session, "alice")
|
||||||
|
|
||||||
|
task = TaskService.create_task(
|
||||||
|
user.id,
|
||||||
|
TaskCreate(
|
||||||
|
payload_config=payload("thread-1", Signature="sig"),
|
||||||
|
name="Morning",
|
||||||
|
cron_expression="0 8 * * *",
|
||||||
|
),
|
||||||
|
db_session,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert task.thread_id == "thread-1"
|
||||||
|
assert task.payload_config == payload("thread-1", Signature="sig")
|
||||||
|
assert task.cron_expression == "0 8 * * *"
|
||||||
|
|
||||||
|
|
||||||
|
def test_same_user_duplicate_thread_id_is_structured_conflict(db_session) -> None:
|
||||||
|
user = add_user(db_session, "alice")
|
||||||
|
first = TaskCreate(payload_config=payload("thread-1"), name="First")
|
||||||
|
duplicate = TaskCreate(payload_config=payload("thread-1"), name="Duplicate")
|
||||||
|
|
||||||
|
TaskService.create_task(user.id, first, db_session)
|
||||||
|
|
||||||
|
with pytest.raises(ResourceConflictError) as exc_info:
|
||||||
|
TaskService.create_task(user.id, duplicate, db_session)
|
||||||
|
|
||||||
|
assert exc_info.value.status_code == 409
|
||||||
|
assert exc_info.value.error_code == "TASK_IDENTITY_CONFLICT"
|
||||||
|
|
||||||
|
|
||||||
|
def test_different_users_can_share_thread_id(db_session) -> None:
|
||||||
|
alice = add_user(db_session, "alice")
|
||||||
|
bob = add_user(db_session, "bob")
|
||||||
|
|
||||||
|
alice_task = TaskService.create_task(
|
||||||
|
alice.id, TaskCreate(payload_config=payload("shared-thread")), db_session
|
||||||
|
)
|
||||||
|
bob_task = TaskService.create_task(
|
||||||
|
bob.id, TaskCreate(payload_config=payload("shared-thread")), db_session
|
||||||
|
)
|
||||||
|
|
||||||
|
assert alice_task.thread_id == "shared-thread"
|
||||||
|
assert bob_task.thread_id == "shared-thread"
|
||||||
|
assert alice_task.user_id != bob_task.user_id
|
||||||
|
|
||||||
|
|
||||||
|
def test_update_task_rejects_duplicate_thread_id(db_session) -> None:
|
||||||
|
user = add_user(db_session, "alice")
|
||||||
|
first = TaskService.create_task(
|
||||||
|
user.id, TaskCreate(payload_config=payload("thread-1")), db_session
|
||||||
|
)
|
||||||
|
second = TaskService.create_task(
|
||||||
|
user.id, TaskCreate(payload_config=payload("thread-2")), db_session
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(ResourceConflictError):
|
||||||
|
TaskService.update_task(
|
||||||
|
second.id,
|
||||||
|
TaskUpdate(payload_config=payload("thread-1")),
|
||||||
|
db_session,
|
||||||
|
)
|
||||||
|
|
||||||
|
db_session.refresh(first)
|
||||||
|
db_session.refresh(second)
|
||||||
|
assert first.thread_id == "thread-1"
|
||||||
|
assert second.thread_id == "thread-2"
|
||||||
|
|
||||||
|
|
||||||
|
def test_task_enrichment_uses_stored_thread_id(db_session) -> None:
|
||||||
|
user = add_user(db_session, "alice")
|
||||||
|
task = CheckInTask(
|
||||||
|
user_id=user.id,
|
||||||
|
thread_id="stored-thread",
|
||||||
|
payload_config=payload("payload-thread"),
|
||||||
|
name="Task",
|
||||||
|
)
|
||||||
|
db_session.add(task)
|
||||||
|
db_session.commit()
|
||||||
|
db_session.refresh(task)
|
||||||
|
|
||||||
|
enriched = TaskService.enrich_task_with_check_in_info(task, db_session)
|
||||||
|
|
||||||
|
assert enriched["thread_id"] == "stored-thread"
|
||||||
|
|
||||||
|
|
||||||
|
def test_template_created_duplicate_task_preserves_structured_conflict(db_session) -> None:
|
||||||
|
from backend.models import TaskTemplate
|
||||||
|
from backend.services.template_service import TemplateService
|
||||||
|
|
||||||
|
user = add_user(db_session, "alice")
|
||||||
|
template = TaskTemplate(name="Daily", field_config="{}", is_active=True)
|
||||||
|
db_session.add(template)
|
||||||
|
db_session.commit()
|
||||||
|
db_session.refresh(template)
|
||||||
|
|
||||||
|
TemplateService.create_task_from_template(
|
||||||
|
template_id=template.id,
|
||||||
|
thread_id="thread-1",
|
||||||
|
field_values={},
|
||||||
|
user_id=user.id,
|
||||||
|
task_name="First",
|
||||||
|
db=db_session,
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(ResourceConflictError) as exc_info:
|
||||||
|
TemplateService.create_task_from_template(
|
||||||
|
template_id=template.id,
|
||||||
|
thread_id="thread-1",
|
||||||
|
field_values={},
|
||||||
|
user_id=user.id,
|
||||||
|
task_name="Duplicate",
|
||||||
|
db=db_session,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert exc_info.value.status_code == 409
|
||||||
|
assert exc_info.value.error_code == "TASK_IDENTITY_CONFLICT"
|
||||||
|
|
||||||
|
|
||||||
|
class FakeScheduler:
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.jobs: dict[str, object] = {}
|
||||||
|
self.added: list[str] = []
|
||||||
|
self.removed: list[str] = []
|
||||||
|
|
||||||
|
def get_job(self, job_id: str):
|
||||||
|
return self.jobs.get(job_id)
|
||||||
|
|
||||||
|
def remove_job(self, job_id: str) -> None:
|
||||||
|
self.jobs.pop(job_id, None)
|
||||||
|
self.removed.append(job_id)
|
||||||
|
|
||||||
|
def add_job(self, **kwargs) -> None:
|
||||||
|
job_id = kwargs["id"]
|
||||||
|
self.jobs[job_id] = kwargs
|
||||||
|
self.added.append(job_id)
|
||||||
|
|
||||||
|
|
||||||
|
def test_scheduler_sync_skips_invalid_cron_and_removes_existing_job(monkeypatch) -> None:
|
||||||
|
fake_scheduler = FakeScheduler()
|
||||||
|
fake_scheduler.jobs["task_12"] = object()
|
||||||
|
monkeypatch.setattr(scheduler_service, "scheduler", fake_scheduler)
|
||||||
|
task = CheckInTask(id=12, thread_id="thread-1", payload_config=payload("thread-1"))
|
||||||
|
task.is_active = True
|
||||||
|
task.cron_expression = "invalid cron"
|
||||||
|
|
||||||
|
scheduled = scheduler_service.sync_scheduled_task(task)
|
||||||
|
|
||||||
|
assert scheduled is False
|
||||||
|
assert fake_scheduler.added == []
|
||||||
|
assert fake_scheduler.removed == ["task_12"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_scheduler_sync_is_noop_when_scheduler_unavailable(monkeypatch) -> None:
|
||||||
|
monkeypatch.setattr(scheduler_service, "scheduler", None)
|
||||||
|
task = CheckInTask(id=12, thread_id="thread-1", payload_config=payload("thread-1"))
|
||||||
|
task.is_active = True
|
||||||
|
task.cron_expression = "0 8 * * *"
|
||||||
|
|
||||||
|
assert scheduler_service.sync_scheduled_task(task) is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_datetime_normalization_does_not_access_unmapped_properties(
|
||||||
|
db_session, monkeypatch
|
||||||
|
) -> None:
|
||||||
|
user = add_user(db_session, "alice")
|
||||||
|
user_id = user.id
|
||||||
|
db_session.expunge_all()
|
||||||
|
|
||||||
|
def explode(self):
|
||||||
|
raise RuntimeError("unmapped property accessed")
|
||||||
|
|
||||||
|
monkeypatch.setattr(User, "explosive_property", property(explode), raising=False)
|
||||||
|
|
||||||
|
loaded = db_session.query(User).filter(User.id == user_id).one()
|
||||||
|
|
||||||
|
assert loaded.created_at.tzinfo is not None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_task_route_preserves_structured_service_errors(monkeypatch, db_session) -> None:
|
||||||
|
from backend.api.tasks import get_tasks
|
||||||
|
|
||||||
|
user = User(id=1, alias="alice")
|
||||||
|
|
||||||
|
def raise_conflict(*args, **kwargs):
|
||||||
|
raise ResourceConflictError("duplicate", error_code="TASK_IDENTITY_CONFLICT")
|
||||||
|
|
||||||
|
monkeypatch.setattr(TaskService, "get_user_tasks", raise_conflict)
|
||||||
|
|
||||||
|
with pytest.raises(ResourceConflictError):
|
||||||
|
await get_tasks(current_user=user, db=db_session)
|
||||||
Reference in New Issue
Block a user