mirror of
https://github.com/Cccc-owo/CheckInApp.git
synced 2026-06-17 05:56:29 +00:00
feat(backend): add automatic DB migrations
Add a lightweight migration runner with schema_migrations tracking, run pending migrations during backend startup before the scheduler, and keep a manual backend-migrate entrypoint. The change also moves the existing lockout and task-thread-ID schema steps into shared migration modules, updates docs, and archives the OpenSpec change.
This commit is contained in:
@@ -8,6 +8,7 @@ import logging
|
||||
from pathlib import Path
|
||||
|
||||
from backend.config import settings
|
||||
from backend.migrations import run_pending_migrations
|
||||
from backend.models import init_db
|
||||
from backend.exceptions import BaseAPIException
|
||||
from backend.schemas.response import ErrorResponse, ErrorDetail
|
||||
@@ -37,6 +38,14 @@ async def lifespan(app: FastAPI):
|
||||
init_db()
|
||||
logger.info("数据库初始化完成")
|
||||
|
||||
logger.info("正在执行数据库迁移...")
|
||||
migration_result = run_pending_migrations()
|
||||
logger.info(
|
||||
"数据库迁移完成:applied=%s skipped=%s",
|
||||
len(migration_result.applied),
|
||||
len(migration_result.skipped),
|
||||
)
|
||||
|
||||
# 确保必要的目录存在
|
||||
settings.SESSION_DIR.mkdir(parents=True, exist_ok=True)
|
||||
(settings.BASE_DIR / "data").mkdir(parents=True, exist_ok=True)
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
"""Database migration step implementations."""
|
||||
@@ -0,0 +1,29 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from sqlalchemy.engine import Connection
|
||||
from sqlalchemy import text
|
||||
|
||||
|
||||
def _table_columns(conn: Connection, table_name: str) -> set[str]:
|
||||
rows = conn.execute(text(f"PRAGMA table_info({table_name})")).fetchall()
|
||||
return {str(row[1]) for row in rows}
|
||||
|
||||
|
||||
def apply(conn: Connection) -> None:
|
||||
columns = _table_columns(conn, "users")
|
||||
|
||||
if "failed_login_attempts" not in columns:
|
||||
conn.execute(
|
||||
text("ALTER TABLE users ADD COLUMN failed_login_attempts INTEGER DEFAULT 0 NOT NULL")
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
columns = _table_columns(conn, "users")
|
||||
if "locked_until" not in columns:
|
||||
conn.execute(text("ALTER TABLE users ADD COLUMN locked_until DATETIME"))
|
||||
conn.commit()
|
||||
|
||||
columns = _table_columns(conn, "users")
|
||||
if "last_failed_login" not in columns:
|
||||
conn.execute(text("ALTER TABLE users ADD COLUMN last_failed_login DATETIME"))
|
||||
conn.commit()
|
||||
@@ -0,0 +1,98 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.engine import Connection
|
||||
|
||||
|
||||
def _table_columns(conn: Connection, table_name: str) -> set[str]:
|
||||
rows = conn.execute(text(f"PRAGMA table_info({table_name})")).fetchall()
|
||||
return {str(row[1]) for row in rows}
|
||||
|
||||
|
||||
def _table_indexes(conn: Connection, table_name: str) -> set[str]:
|
||||
rows = conn.execute(text(f"PRAGMA index_list({table_name})")).fetchall()
|
||||
return {str(row[1]) for row in rows}
|
||||
|
||||
|
||||
def _has_thread_id_uniqueness(conn: Connection) -> bool:
|
||||
indexes = conn.execute(text("PRAGMA index_list(check_in_tasks)")).fetchall()
|
||||
for row in indexes:
|
||||
is_unique = bool(row[2])
|
||||
if not is_unique:
|
||||
continue
|
||||
index_name = str(row[1])
|
||||
columns = conn.execute(text(f"PRAGMA index_info({index_name})")).fetchall()
|
||||
column_names = [str(column[2]) for column in columns]
|
||||
if column_names == ["user_id", "thread_id"]:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _extract_thread_id(payload_config: str | None) -> str | None:
|
||||
if not payload_config:
|
||||
return None
|
||||
try:
|
||||
payload = json.loads(payload_config)
|
||||
except json.JSONDecodeError:
|
||||
return None
|
||||
if not isinstance(payload, dict):
|
||||
return None
|
||||
thread_id = payload.get("ThreadId")
|
||||
value = str(thread_id).strip() if thread_id is not None else ""
|
||||
return value or None
|
||||
|
||||
|
||||
def apply(conn: Connection) -> None:
|
||||
columns = _table_columns(conn, "check_in_tasks")
|
||||
|
||||
if "thread_id" not in columns:
|
||||
conn.execute(text("ALTER TABLE check_in_tasks ADD COLUMN thread_id VARCHAR(100)"))
|
||||
conn.commit()
|
||||
|
||||
full_rows = conn.execute(
|
||||
text("SELECT id, user_id, payload_config FROM check_in_tasks")
|
||||
).fetchall()
|
||||
invalid_ids: list[int] = []
|
||||
seen: dict[tuple[int, str], int] = {}
|
||||
duplicate_ids: list[int] = []
|
||||
|
||||
for row in full_rows:
|
||||
thread_id = _extract_thread_id(row.payload_config)
|
||||
if not thread_id:
|
||||
invalid_ids.append(row.id)
|
||||
continue
|
||||
key = (row.user_id, thread_id)
|
||||
if key in seen:
|
||||
duplicate_ids.append(row.id)
|
||||
else:
|
||||
seen[key] = row.id
|
||||
|
||||
if invalid_ids or duplicate_ids:
|
||||
messages = []
|
||||
if invalid_ids:
|
||||
messages.append(f"payload_config 缺少有效 ThreadId 的任务: {invalid_ids}")
|
||||
if duplicate_ids:
|
||||
messages.append(f"同用户 ThreadId 重复的任务: {duplicate_ids}")
|
||||
raise RuntimeError(";".join(messages))
|
||||
|
||||
rows = conn.execute(text("SELECT id, payload_config FROM check_in_tasks")).fetchall()
|
||||
for row in rows:
|
||||
thread_id = _extract_thread_id(row.payload_config)
|
||||
if thread_id:
|
||||
conn.execute(
|
||||
text("UPDATE check_in_tasks SET thread_id = :thread_id WHERE id = :id"),
|
||||
{"thread_id": thread_id, "id": row.id},
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
indexes = _table_indexes(conn, "check_in_tasks")
|
||||
if "ix_task_user_thread_id_unique" not in indexes and not _has_thread_id_uniqueness(conn):
|
||||
conn.execute(
|
||||
text(
|
||||
"CREATE UNIQUE INDEX ix_task_user_thread_id_unique "
|
||||
"ON check_in_tasks (user_id, thread_id)"
|
||||
)
|
||||
)
|
||||
conn.commit()
|
||||
@@ -0,0 +1,122 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections.abc import Callable, Sequence
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from sqlalchemy import Engine, text
|
||||
from sqlalchemy.engine import Connection
|
||||
|
||||
from backend.migration_steps.account_lockout import apply as apply_account_lockout
|
||||
from backend.migration_steps.task_thread_id import apply as apply_task_thread_id
|
||||
from backend.models.database import engine as default_engine
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
MIGRATION_TABLE_NAME = "schema_migrations"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Migration:
|
||||
id: str
|
||||
description: str
|
||||
apply: Callable[[Connection], None]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class MigrationRunResult:
|
||||
applied: tuple[str, ...]
|
||||
skipped: tuple[str, ...]
|
||||
|
||||
|
||||
class MigrationExecutionError(RuntimeError):
|
||||
def __init__(self, migration_id: str, original: Exception) -> None:
|
||||
self.migration_id = migration_id
|
||||
self.original = original
|
||||
super().__init__(f"Migration {migration_id} failed: {original}")
|
||||
|
||||
|
||||
def ensure_migration_table(conn: Connection) -> None:
|
||||
conn.execute(
|
||||
text(
|
||||
f"""
|
||||
CREATE TABLE IF NOT EXISTS {MIGRATION_TABLE_NAME} (
|
||||
id VARCHAR(200) PRIMARY KEY,
|
||||
description VARCHAR(500) NOT NULL,
|
||||
applied_at DATETIME NOT NULL
|
||||
)
|
||||
"""
|
||||
)
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
|
||||
def get_applied_migration_ids(conn: Connection) -> set[str]:
|
||||
ensure_migration_table(conn)
|
||||
rows = conn.execute(text(f"SELECT id FROM {MIGRATION_TABLE_NAME}")).fetchall()
|
||||
return {str(row.id) for row in rows}
|
||||
|
||||
|
||||
def mark_migration_applied(conn: Connection, migration: Migration) -> None:
|
||||
conn.execute(
|
||||
text(
|
||||
f"""
|
||||
INSERT INTO {MIGRATION_TABLE_NAME} (id, description, applied_at)
|
||||
VALUES (:id, :description, :applied_at)
|
||||
"""
|
||||
),
|
||||
{
|
||||
"id": migration.id,
|
||||
"description": migration.description,
|
||||
"applied_at": datetime.now(timezone.utc).isoformat(),
|
||||
},
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
|
||||
MIGRATIONS: tuple[Migration, ...] = (
|
||||
Migration(
|
||||
id="2026050401_add_account_lockout",
|
||||
description="Add account lockout columns to users.",
|
||||
apply=apply_account_lockout,
|
||||
),
|
||||
Migration(
|
||||
id="2026050402_add_task_thread_id",
|
||||
description="Add and backfill check-in task thread identity.",
|
||||
apply=apply_task_thread_id,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def run_pending_migrations(
|
||||
*,
|
||||
engine: Engine = default_engine,
|
||||
migrations: Sequence[Migration] = MIGRATIONS,
|
||||
) -> MigrationRunResult:
|
||||
applied: list[str] = []
|
||||
skipped: list[str] = []
|
||||
|
||||
with engine.connect() as conn:
|
||||
applied_ids = get_applied_migration_ids(conn)
|
||||
|
||||
for migration in migrations:
|
||||
if migration.id in applied_ids:
|
||||
logger.info("Skipping applied migration %s", migration.id)
|
||||
skipped.append(migration.id)
|
||||
continue
|
||||
|
||||
logger.info("Applying migration %s: %s", migration.id, migration.description)
|
||||
try:
|
||||
migration.apply(conn)
|
||||
mark_migration_applied(conn, migration)
|
||||
except Exception as exc:
|
||||
conn.rollback()
|
||||
logger.exception("Migration %s failed", migration.id)
|
||||
raise MigrationExecutionError(migration.id, exc) from exc
|
||||
|
||||
logger.info("Applied migration %s", migration.id)
|
||||
applied.append(migration.id)
|
||||
applied_ids.add(migration.id)
|
||||
|
||||
return MigrationRunResult(applied=tuple(applied), skipped=tuple(skipped))
|
||||
@@ -1,69 +1,26 @@
|
||||
"""
|
||||
数据库迁移脚本:添加账户锁定相关字段
|
||||
数据库迁移脚本:添加账户锁定相关字段。
|
||||
|
||||
添加字段:
|
||||
- failed_login_attempts: 连续登录失败次数
|
||||
- locked_until: 账户锁定到期时间
|
||||
- last_failed_login: 最后一次登录失败时间
|
||||
|
||||
运行方式:
|
||||
通常无需手动运行,后端启动时会自动执行待迁移项。需要单独执行时:
|
||||
uv run python -m backend.scripts.migrate_add_account_lockout
|
||||
"""
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from __future__ import annotations
|
||||
|
||||
APPS_DIR = Path(__file__).resolve().parents[2]
|
||||
sys.path.insert(0, str(APPS_DIR))
|
||||
|
||||
from sqlalchemy import text
|
||||
from backend.models.database import engine
|
||||
import logging
|
||||
import sys
|
||||
|
||||
from backend.migration_steps.account_lockout import apply as apply_account_lockout
|
||||
from backend.models.database import engine
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def migrate():
|
||||
"""执行迁移"""
|
||||
def migrate() -> None:
|
||||
logger.info("开始迁移:添加账户锁定相关字段...")
|
||||
|
||||
with engine.connect() as conn:
|
||||
# 检查字段是否已存在
|
||||
result = conn.execute(text("PRAGMA table_info(users)"))
|
||||
columns = [row[1] for row in result]
|
||||
|
||||
# 添加 failed_login_attempts 字段
|
||||
if "failed_login_attempts" not in columns:
|
||||
logger.info("添加 failed_login_attempts 字段...")
|
||||
conn.execute(
|
||||
text(
|
||||
"ALTER TABLE users ADD COLUMN failed_login_attempts INTEGER DEFAULT 0 NOT NULL"
|
||||
)
|
||||
)
|
||||
conn.commit()
|
||||
logger.info("✓ failed_login_attempts 字段添加成功")
|
||||
else:
|
||||
logger.info("✓ failed_login_attempts 字段已存在,跳过")
|
||||
|
||||
# 添加 locked_until 字段
|
||||
if "locked_until" not in columns:
|
||||
logger.info("添加 locked_until 字段...")
|
||||
conn.execute(text("ALTER TABLE users ADD COLUMN locked_until DATETIME"))
|
||||
conn.commit()
|
||||
logger.info("✓ locked_until 字段添加成功")
|
||||
else:
|
||||
logger.info("✓ locked_until 字段已存在,跳过")
|
||||
|
||||
# 添加 last_failed_login 字段
|
||||
if "last_failed_login" not in columns:
|
||||
logger.info("添加 last_failed_login 字段...")
|
||||
conn.execute(text("ALTER TABLE users ADD COLUMN last_failed_login DATETIME"))
|
||||
conn.commit()
|
||||
logger.info("✓ last_failed_login 字段添加成功")
|
||||
else:
|
||||
logger.info("✓ last_failed_login 字段已存在,跳过")
|
||||
|
||||
apply_account_lockout(conn)
|
||||
logger.info("✅ 迁移完成!账户锁定功能已启用")
|
||||
|
||||
|
||||
@@ -71,5 +28,5 @@ if __name__ == "__main__":
|
||||
try:
|
||||
migrate()
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 迁移失败: {e}")
|
||||
logger.error("❌ 迁移失败: %s", e)
|
||||
sys.exit(1)
|
||||
|
||||
@@ -1,107 +1,26 @@
|
||||
"""
|
||||
数据库迁移脚本:添加打卡任务 thread_id 字段并回填。
|
||||
|
||||
运行方式:
|
||||
通常无需手动运行,后端启动时会自动执行待迁移项。需要单独执行时:
|
||||
uv run python -m backend.scripts.migrate_add_task_thread_id
|
||||
"""
|
||||
|
||||
import json
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
APPS_DIR = Path(__file__).resolve().parents[2]
|
||||
sys.path.insert(0, str(APPS_DIR))
|
||||
|
||||
from sqlalchemy import text
|
||||
|
||||
from backend.migration_steps.task_thread_id import apply as apply_task_thread_id
|
||||
from backend.models.database import engine
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _extract_thread_id(payload_config: str | None) -> str | None:
|
||||
if not payload_config:
|
||||
return None
|
||||
try:
|
||||
payload = json.loads(payload_config)
|
||||
except json.JSONDecodeError:
|
||||
return None
|
||||
if not isinstance(payload, dict):
|
||||
return None
|
||||
thread_id = payload.get("ThreadId")
|
||||
value = str(thread_id).strip() if thread_id is not None else ""
|
||||
return value or None
|
||||
|
||||
|
||||
def migrate() -> None:
|
||||
"""执行迁移。"""
|
||||
logger.info("开始迁移:添加 check_in_tasks.thread_id 字段...")
|
||||
|
||||
with engine.connect() as conn:
|
||||
result = conn.execute(text("PRAGMA table_info(check_in_tasks)"))
|
||||
columns = [row[1] for row in result]
|
||||
|
||||
if "thread_id" not in columns:
|
||||
logger.info("添加 thread_id 字段...")
|
||||
conn.execute(text("ALTER TABLE check_in_tasks ADD COLUMN thread_id VARCHAR(100)"))
|
||||
conn.commit()
|
||||
logger.info("✓ thread_id 字段添加成功")
|
||||
else:
|
||||
logger.info("✓ thread_id 字段已存在,跳过")
|
||||
|
||||
rows = conn.execute(text("SELECT id, payload_config FROM check_in_tasks")).fetchall()
|
||||
invalid_ids: list[int] = []
|
||||
seen: dict[tuple[int, str], int] = {}
|
||||
duplicate_ids: list[int] = []
|
||||
|
||||
full_rows = conn.execute(
|
||||
text("SELECT id, user_id, payload_config FROM check_in_tasks")
|
||||
).fetchall()
|
||||
for row in full_rows:
|
||||
thread_id = _extract_thread_id(row.payload_config)
|
||||
if not thread_id:
|
||||
invalid_ids.append(row.id)
|
||||
continue
|
||||
key = (row.user_id, thread_id)
|
||||
if key in seen:
|
||||
duplicate_ids.append(row.id)
|
||||
else:
|
||||
seen[key] = row.id
|
||||
|
||||
if invalid_ids or duplicate_ids:
|
||||
messages = []
|
||||
if invalid_ids:
|
||||
messages.append(f"payload_config 缺少有效 ThreadId 的任务: {invalid_ids}")
|
||||
if duplicate_ids:
|
||||
messages.append(f"同用户 ThreadId 重复的任务: {duplicate_ids}")
|
||||
raise RuntimeError(";".join(messages))
|
||||
|
||||
for row in rows:
|
||||
thread_id = _extract_thread_id(row.payload_config)
|
||||
if thread_id:
|
||||
conn.execute(
|
||||
text("UPDATE check_in_tasks SET thread_id = :thread_id WHERE id = :id"),
|
||||
{"thread_id": thread_id, "id": row.id},
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
indexes = conn.execute(text("PRAGMA index_list(check_in_tasks)")).fetchall()
|
||||
index_names = [row[1] for row in indexes]
|
||||
if "ix_task_user_thread_id_unique" not in index_names:
|
||||
logger.info("添加用户级 thread_id 唯一索引...")
|
||||
conn.execute(
|
||||
text(
|
||||
"CREATE UNIQUE INDEX ix_task_user_thread_id_unique "
|
||||
"ON check_in_tasks (user_id, thread_id)"
|
||||
)
|
||||
)
|
||||
conn.commit()
|
||||
logger.info("✓ 用户级 thread_id 唯一索引添加成功")
|
||||
else:
|
||||
logger.info("✓ 用户级 thread_id 唯一索引已存在,跳过")
|
||||
|
||||
apply_task_thread_id(conn)
|
||||
logger.info("✅ 迁移完成!任务 thread_id 身份字段已启用")
|
||||
|
||||
|
||||
@@ -109,5 +28,5 @@ if __name__ == "__main__":
|
||||
try:
|
||||
migrate()
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 迁移失败: {e}")
|
||||
logger.error("❌ 迁移失败: %s", e)
|
||||
sys.exit(1)
|
||||
|
||||
@@ -0,0 +1,33 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
运行数据库迁移的脚本。
|
||||
|
||||
使用方法:
|
||||
uv run python -m backend.scripts.run_migrations
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from backend.migrations import MigrationExecutionError, run_pending_migrations
|
||||
from backend.models import init_db
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def main() -> int:
|
||||
try:
|
||||
init_db()
|
||||
result = run_pending_migrations()
|
||||
except MigrationExecutionError as exc:
|
||||
logger.error("❌ 迁移失败: %s", exc)
|
||||
return 1
|
||||
|
||||
logger.info("✅ 迁移完成:applied=%s skipped=%s", len(result.applied), len(result.skipped))
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
|
||||
Reference in New Issue
Block a user