diff --git a/.github/workflows/api-docker-image.yml b/.github/workflows/api-docker-image.yml new file mode 100644 index 0000000..685084a --- /dev/null +++ b/.github/workflows/api-docker-image.yml @@ -0,0 +1,66 @@ +name: API Docker Image + +on: + pull_request: + branches: [main, develop] + paths: + - "apps/api/**" + - ".github/workflows/api-docker-image.yml" + push: + branches: [main] + paths: + - "apps/api/**" + - ".github/workflows/api-docker-image.yml" + workflow_dispatch: + +concurrency: + group: api-docker-${{ github.ref }} + cancel-in-progress: true + +jobs: + build-and-publish: + name: Build API Docker Image + runs-on: ubuntu-latest + permissions: + contents: read + packages: write + + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Check Dockerfile + id: dockerfile + run: | + if [ -f apps/api/Dockerfile ]; then + echo "exists=true" >> "$GITHUB_OUTPUT" + else + echo "exists=false" >> "$GITHUB_OUTPUT" + fi + + - name: Setup Docker Buildx + if: steps.dockerfile.outputs.exists == 'true' + uses: docker/setup-buildx-action@v3 + + - name: Login to GitHub Container Registry + if: steps.dockerfile.outputs.exists == 'true' && github.event_name == 'push' && github.ref == 'refs/heads/main' + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + + - name: Build (PR/manual) or Build and Push (main) + if: steps.dockerfile.outputs.exists == 'true' + uses: docker/build-push-action@v6 + with: + context: ./apps/api + file: ./apps/api/Dockerfile + push: ${{ github.event_name == 'push' && github.ref == 'refs/heads/main' }} + tags: | + ghcr.io/${{ github.repository }}/api:${{ github.sha }} + ghcr.io/${{ github.repository }}/api:latest + + - name: Skip notice + if: steps.dockerfile.outputs.exists != 'true' + run: echo "apps/api/Dockerfile not found, skip docker build." diff --git a/.github/workflows/deploy-admin.yml b/.github/workflows/deploy-admin.yml new file mode 100644 index 0000000..4c1f691 --- /dev/null +++ b/.github/workflows/deploy-admin.yml @@ -0,0 +1,59 @@ +name: Deploy Admin + +on: + push: + branches: [main] + paths: + - "apps/admin/**" + - "packages/shared-types/**" + - "packages/ui/**" + - ".github/workflows/deploy-admin.yml" + workflow_dispatch: + +concurrency: + group: deploy-admin-${{ github.ref }} + cancel-in-progress: true + +jobs: + build: + name: Build Admin + runs-on: ubuntu-latest + + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Setup pnpm + uses: pnpm/action-setup@v4 + with: + version: 9.15.2 + + - name: Setup Node.js + uses: actions/setup-node@v4 + with: + node-version: 20 + cache: pnpm + + - name: Install dependencies + run: pnpm install --frozen-lockfile + + - name: Build workspace + run: pnpm run build + + deploy: + name: Deploy Admin (Template) + runs-on: ubuntu-latest + needs: build + + steps: + - name: Trigger deployment webhook + env: + ADMIN_DEPLOY_WEBHOOK_URL: ${{ secrets.ADMIN_DEPLOY_WEBHOOK_URL }} + run: | + if [ -z "$ADMIN_DEPLOY_WEBHOOK_URL" ]; then + echo "ADMIN_DEPLOY_WEBHOOK_URL is not configured. Skipping deploy." + exit 0 + fi + + curl -X POST "$ADMIN_DEPLOY_WEBHOOK_URL" + echo "Admin deployment webhook triggered." diff --git a/.github/workflows/deploy-web.yml b/.github/workflows/deploy-web.yml new file mode 100644 index 0000000..a55a8be --- /dev/null +++ b/.github/workflows/deploy-web.yml @@ -0,0 +1,59 @@ +name: Deploy Web + +on: + push: + branches: [main] + paths: + - "apps/web/**" + - "packages/shared-types/**" + - "packages/ui/**" + - ".github/workflows/deploy-web.yml" + workflow_dispatch: + +concurrency: + group: deploy-web-${{ github.ref }} + cancel-in-progress: true + +jobs: + build: + name: Build Web + runs-on: ubuntu-latest + + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Setup pnpm + uses: pnpm/action-setup@v4 + with: + version: 9.15.2 + + - name: Setup Node.js + uses: actions/setup-node@v4 + with: + node-version: 20 + cache: pnpm + + - name: Install dependencies + run: pnpm install --frozen-lockfile + + - name: Build workspace + run: pnpm run build + + deploy: + name: Deploy Web (Template) + runs-on: ubuntu-latest + needs: build + + steps: + - name: Trigger deployment webhook + env: + WEB_DEPLOY_WEBHOOK_URL: ${{ secrets.WEB_DEPLOY_WEBHOOK_URL }} + run: | + if [ -z "$WEB_DEPLOY_WEBHOOK_URL" ]; then + echo "WEB_DEPLOY_WEBHOOK_URL is not configured. Skipping deploy." + exit 0 + fi + + curl -X POST "$WEB_DEPLOY_WEBHOOK_URL" + echo "Web deployment webhook triggered." diff --git a/.github/workflows/pr-quality.yml b/.github/workflows/pr-quality.yml new file mode 100644 index 0000000..2dbdb43 --- /dev/null +++ b/.github/workflows/pr-quality.yml @@ -0,0 +1,46 @@ +name: PR Quality + +on: + pull_request: + types: [opened, synchronize, reopened, ready_for_review] + branches: [main, develop] + +concurrency: + group: pr-quality-${{ github.ref }} + cancel-in-progress: true + +jobs: + quality: + name: Lint, Typecheck, Test, Build + runs-on: ubuntu-latest + if: github.event.pull_request.draft == false + + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Setup pnpm + uses: pnpm/action-setup@v4 + with: + version: 9.15.2 + + - name: Setup Node.js + uses: actions/setup-node@v4 + with: + node-version: 20 + cache: pnpm + + - name: Install dependencies + run: pnpm install --frozen-lockfile + + - name: Lint + run: pnpm run lint + + - name: Typecheck + run: pnpm run typecheck + + - name: Test + run: pnpm run test + + - name: Build + run: pnpm run build diff --git a/.gitignore b/.gitignore index 5fcdc0a..ce34796 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,6 @@ develop.md -.idea/ \ No newline at end of file +node_modules/ +.turbo/ +.idea/ +.eslintcache +/.husky/_ diff --git a/.husky/pre-commit b/.husky/pre-commit new file mode 100755 index 0000000..f9241ee --- /dev/null +++ b/.husky/pre-commit @@ -0,0 +1 @@ +pnpm lint:staged diff --git a/.husky/pre-push b/.husky/pre-push new file mode 100755 index 0000000..4226722 --- /dev/null +++ b/.husky/pre-push @@ -0,0 +1,2 @@ +pnpm typecheck +pnpm test diff --git a/.lintstagedrc.cjs b/.lintstagedrc.cjs new file mode 100644 index 0000000..b7ddb53 --- /dev/null +++ b/.lintstagedrc.cjs @@ -0,0 +1,4 @@ +module.exports = { + "*.{js,mjs,cjs,ts,tsx}": ["eslint --fix", "prettier --write"], + "*.{json,md,yml,yaml}": ["prettier --write"] +}; diff --git a/.prettierignore b/.prettierignore new file mode 100644 index 0000000..1613f18 --- /dev/null +++ b/.prettierignore @@ -0,0 +1,7 @@ +node_modules +.turbo +.idea +dist +build +coverage +*.png diff --git a/.prettierrc.json b/.prettierrc.json new file mode 100644 index 0000000..8a0f27e --- /dev/null +++ b/.prettierrc.json @@ -0,0 +1,6 @@ +{ + "semi": true, + "singleQuote": false, + "trailingComma": "none", + "printWidth": 100 +} diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000..728d44f --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,69 @@ +# 贡献指南(Contributing) + +本文档定义 TodoList 仓库的协作规范,所有贡献者提交代码前请先阅读。 + +## 1. 分支模型 + +- 长期分支: + - `main`:生产稳定分支 + - `develop`:开发集成分支 +- 功能分支: + - 命名:`feature/-` + - 示例:`feature/p1-code-quality-hooks` +- 其他分支: + - `release/` + - `hotfix/-` + +## 2. 提交流程 + +1. 从目标基线分支切出功能分支。 +2. 每完成一个小功能,提交一个最小 commit。 +3. 完成后推送分支并创建 PR。 +4. 通过 Code Review 后再合并到目标分支。 + +## 3. Commit 规范 + +- 使用 Conventional Commits: + - `feat(scope): ...` + - `fix(scope): ...` + - `chore(scope): ...` + - `docs(scope): ...` + - `test(scope): ...` + - `ci(scope): ...` +- 要求: + - commit 粒度最小化,不要把多个不相关改动塞进一个提交。 + - commit 必须可回滚、可解释。 + - 默认使用 GPG 签名提交:`git commit -S`。 + +## 4. PR 规范 + +- PR 标题简明描述变更目标。 +- PR 描述至少包含: + - 变更概述 + - 具体改动 + - 测试结果 + - 风险评估 + - 回滚方案 +- 一个 PR 只解决一类问题,避免“超大 PR”。 + +## 5. 代码质量检查 + +提交前建议至少执行: + +```bash +pnpm install +pnpm run lint +pnpm run typecheck +pnpm run test +``` + +说明: + +- `pre-commit` 会自动执行 `lint-staged`。 +- `pre-push` 会自动执行 `typecheck + test`。 + +## 6. 变更边界要求 + +- 不要提交无关文件(例如本地 IDE 缓存、临时导出文件)。 +- 不要随意修改与当前任务无关的历史代码。 +- 如发现仓库出现非本人预期改动,先暂停并和维护者确认。 diff --git a/README.md b/README.md index 29367df..98c8577 100644 --- a/README.md +++ b/README.md @@ -71,23 +71,23 @@ > 状态说明:`[x]` 已完成,`[ ]` 进行中/未开始(请随开发进度更新) -| 顺序 | 功能实现项(用户视角) | 你会看到的效果 | 状态 | -|---|---|---|---| -| 1 | 明确产品能力与交互流程 | 确认 TodoList 的核心使用方式与页面路径 | [x] | -| 2 | 实现基础登录(邮箱验证码) | 可以注册/登录并进入主页面 | [ ] | -| 3 | 实现任务基础能力(增删改查) | 可以创建、编辑、删除、完成任务 | [ ] | -| 4 | 实现富文本与媒体内容 | 任务详情可插入图片、视频、链接等内容 | [ ] | -| 5 | 实现本地离线存储(Dexie) | 无网时仍可打开并编辑任务 | [ ] | -| 6 | 实现云端同步与冲突处理 | 恢复网络后自动同步,冲突按规则合并 | [ ] | -| 7 | 实现提醒系统(邮件) | DDL 临近时收到邮件提醒 | [ ] | -| 8 | 实现 AI 问答(用户自带 Key) | 可直接用自己的 AI API Key 获取建议 | [ ] | -| 9 | 实现 Astrbot Provider 接入 | 可复用 Astrbot 内配置的 AI 提供商 | [ ] | -| 10 | 实现公共 AI 通道(可开关) | 管理员开启后,用户可直接使用站点公共 AI | [ ] | -| 11 | 实现 Astrbot Skill 对接 | 可通过 QQ 机器人添加/修改任务与获取建议 | [ ] | -| 12 | 实现完整账号安全(2FA + OAuth) | 支持 2FA、QQ/微信/GitHub 登录 | [ ] | -| 13 | 实现 PWA 安装与离线体验优化 | 支持“添加到桌面”,像本地 App 一样使用 | [ ] | -| 14 | 实现管理后台(配额/日志/系统配置) | 管理员可管理用户配额、站点信息、日志 | [ ] | -| 15 | 上线前安全与性能收尾 | 使用更稳定、更安全,核心链路可观测 | [ ] | +| 顺序 | 功能实现项(用户视角) | 你会看到的效果 | 状态 | +| ---- | ---------------------------------- | --------------------------------------- | ---- | +| 1 | 明确产品能力与交互流程 | 确认 TodoList 的核心使用方式与页面路径 | [x] | +| 2 | 实现基础登录(邮箱验证码) | 可以注册/登录并进入主页面 | [x] | +| 3 | 实现任务基础能力(增删改查) | 可以创建、编辑、删除、完成任务 | [x] | +| 4 | 实现富文本与媒体内容 | 任务详情可插入图片、视频、链接等内容 | [x] | +| 5 | 实现本地离线存储(Dexie) | 无网时仍可打开并编辑任务 | [ ] | +| 6 | 实现云端同步与冲突处理 | 恢复网络后自动同步,冲突按规则合并 | [ ] | +| 7 | 实现提醒系统(邮件) | DDL 临近时收到邮件提醒 | [ ] | +| 8 | 实现 AI 问答(用户自带 Key) | 可直接用自己的 AI API Key 获取建议 | [ ] | +| 9 | 实现 Astrbot Provider 接入 | 可复用 Astrbot 内配置的 AI 提供商 | [ ] | +| 10 | 实现公共 AI 通道(可开关) | 管理员开启后,用户可直接使用站点公共 AI | [ ] | +| 11 | 实现 Astrbot Skill 对接 | 可通过 QQ 机器人添加/修改任务与获取建议 | [ ] | +| 12 | 实现完整账号安全(2FA + OAuth) | 支持 2FA、QQ/微信/GitHub 登录 | [ ] | +| 13 | 实现 PWA 安装与离线体验优化 | 支持“添加到桌面”,像本地 App 一样使用 | [ ] | +| 14 | 实现管理后台(配额/日志/系统配置) | 管理员可管理用户配额、站点信息、日志 | [ ] | +| 15 | 上线前安全与性能收尾 | 使用更稳定、更安全,核心链路可观测 | [ ] | --- @@ -151,6 +151,97 @@ TodoList/ --- +## 部署与使用 + +### 1. 环境要求 + +- Node.js `20.x` +- pnpm `9.15.2` +- PostgreSQL `14+`(本地或远程都可) +- 可选:MinIO / S3(附件上传功能使用) + +### 2. 安装依赖 + +```bash +pnpm install +``` + +### 3. 后端环境变量配置 + +1. 复制环境变量示例文件: + +```bash +cp apps/api/.env.example apps/api/.env +# PowerShell: +# Copy-Item apps/api/.env.example apps/api/.env +``` + +2. 至少修改以下配置: + +- `DATABASE_URL`:你的 PostgreSQL 连接串 +- `AUTH_ACCESS_SECRET`:生产环境请改为高强度随机值 +- `MAIL_SMTP_*`:邮件服务器配置(验证码/提醒邮件) +- `OAUTH_*`:第三方登录配置(未接入可先保留示例值) +- `S3_*`:对象存储配置(未启用附件可后续再配) + +### 4. 初始化数据库 + +```bash +pnpm --filter @todolist/api exec prisma db push +``` + +### 5. 本地开发启动 + +1. 启动后端(默认端口 `3000`): + +```bash +pnpm --filter @todolist/api start:dev +``` + +2. 启动前端(默认端口 `5173`): + +```bash +pnpm --filter web dev +``` + +3. 若前端需连接非默认后端地址,可设置: + +```bash +VITE_API_BASE_URL=http://localhost:3000 +``` + +### 6. 生产构建与运行 + +1. 构建: + +```bash +pnpm run build +``` + +2. 运行 API(需先构建): + +```bash +pnpm --filter @todolist/api start +``` + +3. 发布 Web: + +- `apps/web/dist` 为静态资源产物,建议使用 Nginx/静态托管服务发布。 + +### 7. CI/CD 说明(当前仓库) + +- PR 质量检查:`.github/workflows/pr-quality.yml` +- Web 部署模板:`.github/workflows/deploy-web.yml` +- Admin 部署模板:`.github/workflows/deploy-admin.yml` +- API 镜像构建:`.github/workflows/api-docker-image.yml` + +说明: + +- Web/Admin 工作流通过 Webhook 触发真实部署,需在仓库 Secrets 配置: + - `WEB_DEPLOY_WEBHOOK_URL` + - `ADMIN_DEPLOY_WEBHOOK_URL` +- API 镜像工作流仅在存在 `apps/api/Dockerfile` 时执行镜像构建与推送。 + ## License 本项目遵循 [GNUv3](./LICENSE)。 diff --git a/apps/admin/.gitkeep b/apps/admin/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/apps/api/.env.example b/apps/api/.env.example new file mode 100644 index 0000000..c8045bf --- /dev/null +++ b/apps/api/.env.example @@ -0,0 +1,73 @@ +# ----------------------------------------------------------------------------- +# TodoList API 环境变量示例 +# 用法: +# 1) 复制为 apps/api/.env +# 2) 按实际环境替换值(尤其是密钥、密码、令牌) +# ----------------------------------------------------------------------------- + +# [数据库] PostgreSQL 连接串 +# 格式:postgresql://:@:/?schema=public +DATABASE_URL="postgresql://postgres:postgres@localhost:5432/todolist?schema=public" + +# [鉴权] Access Token 签名密钥(生产环境必须使用高强度随机值) +AUTH_ACCESS_SECRET="dev-access-secret" +# [鉴权] Access Token 有效期(秒),默认 15 分钟 +AUTH_ACCESS_EXPIRES_IN_SECONDS="900" +# [鉴权] Refresh Token 有效期(秒),默认 30 天 +AUTH_REFRESH_EXPIRES_IN_SECONDS="2592000" +# [鉴权] 邮箱验证码有效期(秒),默认 5 分钟 +AUTH_EMAIL_CODE_TTL_SECONDS="300" +# [2FA] TOTP 签发方名称(会显示在验证器 App 中) +AUTH_TOTP_ISSUER="TodoList" + +# [OAuth - GitHub] 第三方登录配置 +OAUTH_GITHUB_CLIENT_ID="github-client-id" +OAUTH_GITHUB_CLIENT_SECRET="github-client-secret" +OAUTH_GITHUB_CALLBACK_URL="http://localhost:3000/auth/oauth/github/callback" + +# [OAuth - QQ] 第三方登录配置 +OAUTH_QQ_CLIENT_ID="qq-client-id" +OAUTH_QQ_CLIENT_SECRET="qq-client-secret" +OAUTH_QQ_CALLBACK_URL="http://localhost:3000/auth/oauth/qq/callback" +OAUTH_QQ_AUTH_URL="https://graph.qq.com/oauth2.0/authorize" +OAUTH_QQ_TOKEN_URL="https://graph.qq.com/oauth2.0/token" + +# [OAuth - 微信] 第三方登录配置 +OAUTH_WECHAT_CLIENT_ID="wechat-client-id" +OAUTH_WECHAT_CLIENT_SECRET="wechat-client-secret" +OAUTH_WECHAT_CALLBACK_URL="http://localhost:3000/auth/oauth/wechat/callback" +OAUTH_WECHAT_AUTH_URL="https://open.weixin.qq.com/connect/qrconnect" +OAUTH_WECHAT_TOKEN_URL="https://api.weixin.qq.com/sns/oauth2/access_token" + +# [对象存储] S3/MinIO 配置(附件上传) +# 本地开发可使用 MinIO,生产可切换到云厂商 S3 兼容服务 +S3_ENDPOINT="http://127.0.0.1:9000" +S3_REGION="us-east-1" +S3_BUCKET="todolist" +S3_ACCESS_KEY_ID="minioadmin" +S3_SECRET_ACCESS_KEY="minioadmin" +# MinIO 常用 true;AWS S3 常用 false +S3_FORCE_PATH_STYLE="true" +# 预签名上传 URL 的有效期(秒) +S3_PRESIGN_EXPIRES_SECONDS="900" +# 对外访问附件的基础地址(用于拼接公开 URL) +S3_PUBLIC_BASE_URL="http://127.0.0.1:9000" + +# [邮件] SMTP 配置(验证码/DDL 提醒邮件) +MAIL_SMTP_HOST="smtp.example.com" +MAIL_SMTP_PORT="465" +# 465 通常为 true(SSL),587 通常为 false(STARTTLS) +MAIL_SMTP_SECURE="true" +MAIL_SMTP_USER="no-reply@example.com" +MAIL_SMTP_PASS="replace-with-smtp-password" +# 发件人显示名称与地址 +MAIL_FROM_NAME="TodoList" +MAIL_FROM_ADDRESS="no-reply@example.com" + +# [数据加密] 服务端敏感数据加密主密钥 +# 用于加密 AI 配置、任务内容、同步 payload、附件元数据等数据库字段 +# 请使用高强度随机字符串,生产环境务必单独保管 +DATA_ENCRYPTION_SECRET="replace-with-a-long-random-secret" + +# [对象存储加密] 服务端对象加密策略,默认使用 AES256;如需关闭可填写 NONE +S3_SERVER_SIDE_ENCRYPTION="AES256" diff --git a/apps/api/.gitignore b/apps/api/.gitignore new file mode 100644 index 0000000..13b7b20 --- /dev/null +++ b/apps/api/.gitignore @@ -0,0 +1,8 @@ +node_modules +# 环境变量文件不纳入版本控制 +.env + +/generated/prisma +dist +prisma.config.js +prisma.config.js.map diff --git a/apps/api/jest.config.cjs b/apps/api/jest.config.cjs new file mode 100644 index 0000000..8194a11 --- /dev/null +++ b/apps/api/jest.config.cjs @@ -0,0 +1,11 @@ +/** @type {import('jest').Config} */ +module.exports = { + rootDir: ".", + testEnvironment: "node", + clearMocks: true, + testMatch: ["/test/**/*.spec.ts"], + moduleFileExtensions: ["ts", "js", "json"], + transform: { + "^.+\\.(t|j)s$": ["ts-jest", { tsconfig: "/tsconfig.spec.json" }] + } +}; diff --git a/apps/api/package.json b/apps/api/package.json new file mode 100644 index 0000000..cac4457 --- /dev/null +++ b/apps/api/package.json @@ -0,0 +1,61 @@ +{ + "name": "@todolist/api", + "version": "0.1.0", + "description": "TodoList API service", + "scripts": { + "prisma:generate": "node -e \"require('node:fs').rmSync('generated/prisma', { recursive: true, force: true })\" && prisma generate", + "prisma:format": "prisma format", + "prisma:validate": "prisma validate", + "prebuild": "pnpm run prisma:generate", + "pretypecheck": "pnpm run prisma:generate", + "pretest": "pnpm run prisma:generate", + "data:reencrypt": "node -e \"require('node:fs').rmSync('.tmp-compile', { recursive: true, force: true })\" && tsc -p tsconfig.json --outDir .tmp-compile --noEmit false && node .tmp-compile/scripts/reencrypt-sensitive-data.js && node -e \"require('node:fs').rmSync('.tmp-compile', { recursive: true, force: true })\"", + "start": "node dist/main.js", + "start:dev": "ts-node-dev --respawn --transpile-only src/main.ts", + "build": "tsc -p tsconfig.build.json", + "typecheck": "tsc --noEmit -p tsconfig.json", + "test": "jest --config jest.config.cjs --runInBand" + }, + "license": "GPL-3.0-or-later", + "devDependencies": { + "@nestjs/testing": "^11.1.18", + "@types/jest": "^30.0.0", + "@types/node": "^25.5.2", + "@types/nodemailer": "^8.0.0", + "@types/passport-github2": "^1.2.9", + "@types/passport-oauth2": "^1.8.0", + "@types/supertest": "^7.2.0", + "dotenv": "^16.6.1", + "jest": "^30.3.0", + "prisma": "^7.6.0", + "supertest": "^7.2.2", + "ts-jest": "^29.4.9", + "ts-node": "^10.9.2", + "ts-node-dev": "^2.0.0", + "typescript": "^5.9.3" + }, + "private": true, + "dependencies": { + "@aws-sdk/client-s3": "^3.1024.0", + "@aws-sdk/s3-request-presigner": "^3.1024.0", + "@nestjs/common": "^11.1.18", + "@nestjs/config": "^4.0.3", + "@nestjs/core": "^11.1.18", + "@nestjs/jwt": "^11.0.2", + "@nestjs/passport": "^11.0.5", + "@nestjs/platform-express": "^11.1.18", + "@otplib/preset-default": "^12.0.1", + "@prisma/adapter-pg": "^7.6.0", + "@prisma/client": "^7.6.0", + "class-transformer": "^0.5.1", + "class-validator": "^0.15.1", + "nodemailer": "^8.0.4", + "otplib": "^13.4.0", + "passport": "^0.7.0", + "passport-github2": "^0.1.12", + "passport-oauth2": "^1.8.0", + "pg": "^8.20.0", + "reflect-metadata": "^0.2.2", + "rxjs": "^7.8.2" + } +} diff --git a/apps/api/prisma.config.ts b/apps/api/prisma.config.ts new file mode 100644 index 0000000..d2c3b82 --- /dev/null +++ b/apps/api/prisma.config.ts @@ -0,0 +1,13 @@ +// Prisma CLI 配置(TodoList) +import "dotenv/config"; +import { defineConfig } from "prisma/config"; + +export default defineConfig({ + schema: "prisma/schema.prisma", + migrations: { + path: "prisma/migrations" + }, + datasource: { + url: process.env["DATABASE_URL"] + } +}); diff --git a/apps/api/prisma/schema.prisma b/apps/api/prisma/schema.prisma new file mode 100644 index 0000000..b6cfa31 --- /dev/null +++ b/apps/api/prisma/schema.prisma @@ -0,0 +1,408 @@ +// Prisma 数据模型定义(TodoList) + +generator client { + provider = "prisma-client" + output = "../generated/prisma" +} + +datasource db { + provider = "postgresql" +} + +enum UserStatus { + ACTIVE + DISABLED + BANNED +} + +enum AuthProvider { + EMAIL + GITHUB + QQ + WECHAT +} + +enum TaskPriority { + LOW + MEDIUM + HIGH + URGENT +} + +enum TaskStatus { + TODO + IN_PROGRESS + DONE + ARCHIVED +} + +enum AttachmentType { + IMAGE + VIDEO + FILE + LINK +} + +enum AiChannel { + USER_KEY + ASTRBOT + PUBLIC_POOL +} + +enum NotificationChannel { + EMAIL + WEB_PUSH +} + +enum NotificationStatus { + PENDING + SENT + FAILED + CANCELED +} + +model User { + id String @id @default(cuid()) + email String + emailHash String? @unique + nickname String? + avatarUrl String? + status UserStatus @default(ACTIVE) + defaultStorageQuotaMb Int @default(100) + usedStorageBytes BigInt @default(0) + createdAt DateTime @default(now()) + updatedAt DateTime @updatedAt + identities AuthIdentity[] + refreshTokens RefreshToken[] + security UserSecurity? + tasks Task[] + tags Tag[] + attachments Attachment[] + taskActivityLogs TaskActivityLog[] + syncOperations SyncOperation[] + syncCursors SyncCursor[] + taskTombstones TaskTombstone[] + aiProviderBindings AiProviderBinding[] + aiUsageLogs AiUsageLog[] + notificationRules NotificationRule[] + notificationJobs NotificationJob[] + createdAdminTokens AdminToken[] + auditLogs AuditLog[] + + @@map("users") +} + +model AuthIdentity { + id String @id @default(cuid()) + userId String + provider AuthProvider + providerUserId String + email String? + emailHash String? + createdAt DateTime @default(now()) + updatedAt DateTime @updatedAt + user User @relation(fields: [userId], references: [id], onDelete: Cascade) + + @@unique([provider, providerUserId]) + @@index([emailHash]) + @@index([userId]) + @@map("auth_identities") +} + +model UserSecurity { + id String @id @default(cuid()) + userId String @unique + twoFactorEnabled Boolean @default(false) + twoFactorSecret String? + recoveryCodes String[] @default([]) + createdAt DateTime @default(now()) + updatedAt DateTime @updatedAt + user User @relation(fields: [userId], references: [id], onDelete: Cascade) + + @@map("user_security") +} + +model RefreshToken { + id String @id @default(cuid()) + userId String + tokenHash String @unique + deviceId String? + expiresAt DateTime + revokedAt DateTime? + createdAt DateTime @default(now()) + user User @relation(fields: [userId], references: [id], onDelete: Cascade) + + @@index([userId, expiresAt]) + @@map("refresh_tokens") +} + +model Task { + id String @id @default(cuid()) + userId String + title String + contentJson Json? + contentText String? + priority TaskPriority @default(MEDIUM) + status TaskStatus @default(TODO) + ddl DateTime? + completedAt DateTime? + version Int @default(1) + createdAt DateTime @default(now()) + updatedAt DateTime @updatedAt + user User @relation(fields: [userId], references: [id], onDelete: Cascade) + taskTags TaskTag[] + attachments Attachment[] + activityLogs TaskActivityLog[] + notificationJobs NotificationJob[] + notificationRules NotificationRule[] + + @@index([userId, status]) + @@index([userId, ddl]) + @@map("tasks") +} + +model Tag { + id String @id @default(cuid()) + userId String + name String + color String? + createdAt DateTime @default(now()) + updatedAt DateTime @updatedAt + user User @relation(fields: [userId], references: [id], onDelete: Cascade) + taskTags TaskTag[] + + @@unique([userId, name]) + @@index([userId]) + @@map("tags") +} + +model TaskTag { + taskId String + tagId String + createdAt DateTime @default(now()) + task Task @relation(fields: [taskId], references: [id], onDelete: Cascade) + tag Tag @relation(fields: [tagId], references: [id], onDelete: Cascade) + + @@id([taskId, tagId]) + @@index([tagId]) + @@map("task_tags") +} + +model Attachment { + id String @id @default(cuid()) + userId String + taskId String? + type AttachmentType + url String + mimeType String? + fileName String? + fileSize Int + width Int? + height Int? + durationMs Int? + checksum String? + createdAt DateTime @default(now()) + updatedAt DateTime @updatedAt + user User @relation(fields: [userId], references: [id], onDelete: Cascade) + task Task? @relation(fields: [taskId], references: [id], onDelete: SetNull) + + @@index([userId]) + @@index([taskId]) + @@map("attachments") +} + +model TaskActivityLog { + id String @id @default(cuid()) + userId String + taskId String + action String + payload Json? + createdAt DateTime @default(now()) + user User @relation(fields: [userId], references: [id], onDelete: Cascade) + task Task @relation(fields: [taskId], references: [id], onDelete: Cascade) + + @@index([taskId, createdAt]) + @@index([userId, createdAt]) + @@map("task_activity_logs") +} + +model SyncOperation { + id String @id @default(cuid()) + opId String @unique + userId String + deviceId String + entityType String + entityId String + action String + payload Json? + clientTs DateTime + serverTs DateTime @default(now()) + user User @relation(fields: [userId], references: [id], onDelete: Cascade) + + @@index([userId, deviceId, serverTs]) + @@index([userId, entityType, entityId]) + @@map("sync_operations") +} + +model SyncCursor { + id String @id @default(cuid()) + userId String + deviceId String + lastPulledAt DateTime? + lastOperationServerTs DateTime? + createdAt DateTime @default(now()) + updatedAt DateTime @updatedAt + user User @relation(fields: [userId], references: [id], onDelete: Cascade) + + @@unique([userId, deviceId]) + @@map("sync_cursors") +} + +model TaskTombstone { + id String @id @default(cuid()) + taskId String @unique + userId String + deletedAt DateTime @default(now()) + deleteOpId String? + user User @relation(fields: [userId], references: [id], onDelete: Cascade) + + @@index([userId, deletedAt]) + @@map("task_tombstones") +} + +model AiProviderBinding { + id String @id @default(cuid()) + userId String + channel AiChannel + providerName String + model String? + configId String? + configName String? + encryptedApiKey String? + endpoint String? + isDefault Boolean @default(false) + isEnabled Boolean @default(true) + createdAt DateTime @default(now()) + updatedAt DateTime @updatedAt + user User @relation(fields: [userId], references: [id], onDelete: Cascade) + + @@index([userId, isEnabled]) + @@map("ai_provider_bindings") +} + +model AiPublicPoolConfig { + id String @id @default(cuid()) + enabled Boolean @default(false) + providerName String? + model String? + encryptedApiKey String? + endpoint String? + rpmLimit Int @default(60) + dailyTokenLimit Int @default(0) + createdAt DateTime @default(now()) + updatedAt DateTime @updatedAt + + @@map("ai_public_pool_config") +} + +model AiUsageLog { + id String @id @default(cuid()) + userId String? + channel AiChannel + providerName String? + model String? + promptTokens Int @default(0) + completionTokens Int @default(0) + totalTokens Int @default(0) + latencyMs Int? + success Boolean @default(true) + errorCode String? + createdAt DateTime @default(now()) + user User? @relation(fields: [userId], references: [id], onDelete: SetNull) + + @@index([userId, createdAt]) + @@index([channel, createdAt]) + @@map("ai_usage_logs") +} + +model NotificationRule { + id String @id @default(cuid()) + userId String + taskId String? + channel NotificationChannel @default(EMAIL) + advanceMinutes Int @default(60) + enabled Boolean @default(true) + createdAt DateTime @default(now()) + updatedAt DateTime @updatedAt + user User @relation(fields: [userId], references: [id], onDelete: Cascade) + task Task? @relation(fields: [taskId], references: [id], onDelete: SetNull) + jobs NotificationJob[] + + @@index([userId, enabled]) + @@index([taskId]) + @@map("notification_rules") +} + +model NotificationJob { + id String @id @default(cuid()) + userId String + taskId String? + ruleId String? + channel NotificationChannel + scheduledAt DateTime + sentAt DateTime? + status NotificationStatus @default(PENDING) + retryCount Int @default(0) + errorMessage String? + createdAt DateTime @default(now()) + updatedAt DateTime @updatedAt + user User @relation(fields: [userId], references: [id], onDelete: Cascade) + task Task? @relation(fields: [taskId], references: [id], onDelete: SetNull) + rule NotificationRule? @relation(fields: [ruleId], references: [id], onDelete: SetNull) + + @@index([status, scheduledAt]) + @@index([userId, createdAt]) + @@map("notification_jobs") +} + +model SystemSetting { + id String @id @default(cuid()) + key String @unique + value Json + createdAt DateTime @default(now()) + updatedAt DateTime @updatedAt + + @@map("system_settings") +} + +model AdminToken { + id String @id @default(cuid()) + tokenHash String @unique + name String + expiresAt DateTime + lastUsedAt DateTime? + revokedAt DateTime? + createdAt DateTime @default(now()) + createdByUserId String? + createdByUser User? @relation(fields: [createdByUserId], references: [id], onDelete: SetNull) + + @@index([expiresAt]) + @@map("admin_tokens") +} + +model AuditLog { + id String @id @default(cuid()) + actorUserId String? + action String + targetType String + targetId String? + meta Json? + ip String? + userAgent String? + createdAt DateTime @default(now()) + actorUser User? @relation(fields: [actorUserId], references: [id], onDelete: SetNull) + + @@index([action, createdAt]) + @@index([actorUserId, createdAt]) + @@map("audit_logs") +} diff --git a/apps/api/scripts/reencrypt-sensitive-data.ts b/apps/api/scripts/reencrypt-sensitive-data.ts new file mode 100644 index 0000000..af39197 --- /dev/null +++ b/apps/api/scripts/reencrypt-sensitive-data.ts @@ -0,0 +1,418 @@ +import "dotenv/config"; +import { PrismaPg } from "@prisma/adapter-pg"; +import { ConfigService } from "@nestjs/config"; +import { Prisma, PrismaClient } from "../generated/prisma/client"; +import { DataEncryptionService } from "../src/security/data-encryption.service"; + +type MigrationCounter = Record< + | "users" + | "authIdentities" + | "aiBindings" + | "publicPools" + | "aiUsageLogs" + | "tasks" + | "attachments" + | "syncOperations", + number +>; + +function createEncryptionService(): DataEncryptionService { + const configService = { + get: (key: string) => process.env[key] + } as ConfigService; + + return new DataEncryptionService(configService); +} + +function encryptStringIfNeeded( + value: string | null, + dataEncryptionService: DataEncryptionService +): string | null | undefined { + if (value === null || dataEncryptionService.isEncryptedString(value)) { + return undefined; + } + + return dataEncryptionService.encryptString(value) ?? null; +} + +function assignRequiredEncryptedString, K extends keyof T>( + target: T, + key: K, + value: string | null | undefined +): void { + if (typeof value === "string") { + target[key] = value as T[K]; + } +} + +function assignOptionalEncryptedString, K extends keyof T>( + target: T, + key: K, + value: string | null | undefined +): void { + if (value !== undefined) { + target[key] = value as T[K]; + } +} + +function encryptJsonIfNeeded( + value: Prisma.JsonValue | null, + dataEncryptionService: DataEncryptionService +): Prisma.InputJsonValue | Prisma.NullableJsonNullValueInput | undefined { + if (value === null) { + return undefined; + } + + if (typeof value === "string" && dataEncryptionService.isEncryptedString(value)) { + return undefined; + } + + return (dataEncryptionService.encryptJson(value as Prisma.InputJsonValue) ?? Prisma.JsonNull) as + | Prisma.InputJsonValue + | Prisma.NullableJsonNullValueInput; +} + +function resolvePlainString( + value: string | null, + dataEncryptionService: DataEncryptionService +): string | null { + if (value === null) { + return null; + } + + return dataEncryptionService.isEncryptedString(value) + ? (dataEncryptionService.decryptString(value) ?? null) + : value; +} + +async function main(): Promise { + if (!process.env["DATABASE_URL"]) { + throw new Error("缺少 DATABASE_URL,无法执行敏感数据迁移"); + } + + if (!process.env["DATA_ENCRYPTION_SECRET"]) { + throw new Error("缺少 DATA_ENCRYPTION_SECRET,无法执行敏感数据迁移"); + } + + const prisma = new PrismaClient({ + adapter: new PrismaPg({ + connectionString: process.env["DATABASE_URL"] + }) + }); + const dataEncryptionService = createEncryptionService(); + const counter: MigrationCounter = { + users: 0, + authIdentities: 0, + aiBindings: 0, + publicPools: 0, + aiUsageLogs: 0, + tasks: 0, + attachments: 0, + syncOperations: 0 + }; + + try { + const users = await prisma.user.findMany({ + select: { + id: true, + email: true, + emailHash: true, + nickname: true, + avatarUrl: true + } + }); + + for (const user of users) { + const normalizedEmail = resolvePlainString(user.email, dataEncryptionService)?.toLowerCase(); + if (!normalizedEmail) { + continue; + } + const nextEmailHash = dataEncryptionService.createLookupHash("user.email", normalizedEmail); + const data: Prisma.UserUpdateInput = {}; + const email = encryptStringIfNeeded(user.email, dataEncryptionService); + const nickname = encryptStringIfNeeded(user.nickname, dataEncryptionService); + const avatarUrl = encryptStringIfNeeded(user.avatarUrl, dataEncryptionService); + + assignRequiredEncryptedString(data, "email", email); + if (user.emailHash !== nextEmailHash) { + data.emailHash = nextEmailHash; + } + assignOptionalEncryptedString(data, "nickname", nickname); + assignOptionalEncryptedString(data, "avatarUrl", avatarUrl); + + if (Object.keys(data).length === 0) { + continue; + } + + await prisma.user.update({ + where: { + id: user.id + }, + data + }); + counter.users += 1; + } + + const authIdentities = await prisma.authIdentity.findMany({ + select: { + id: true, + email: true, + emailHash: true + } + }); + + for (const authIdentity of authIdentities) { + const data: Prisma.AuthIdentityUpdateInput = {}; + const email = encryptStringIfNeeded(authIdentity.email, dataEncryptionService); + const normalizedIdentityEmail = resolvePlainString(authIdentity.email, dataEncryptionService); + const nextEmailHash = + normalizedIdentityEmail === null + ? null + : dataEncryptionService.createLookupHash( + "auth_identity.email", + normalizedIdentityEmail.toLowerCase() + ); + + assignOptionalEncryptedString(data, "email", email); + if (authIdentity.emailHash !== nextEmailHash) { + data.emailHash = nextEmailHash; + } + + if (Object.keys(data).length === 0) { + continue; + } + + await prisma.authIdentity.update({ + where: { + id: authIdentity.id + }, + data + }); + counter.authIdentities += 1; + } + + const aiBindings = await prisma.aiProviderBinding.findMany({ + select: { + id: true, + providerName: true, + model: true, + configId: true, + configName: true, + endpoint: true, + encryptedApiKey: true + } + }); + + for (const binding of aiBindings) { + const data: Prisma.AiProviderBindingUpdateInput = {}; + const providerName = encryptStringIfNeeded(binding.providerName, dataEncryptionService); + const model = encryptStringIfNeeded(binding.model, dataEncryptionService); + const configId = encryptStringIfNeeded(binding.configId, dataEncryptionService); + const configName = encryptStringIfNeeded(binding.configName, dataEncryptionService); + const endpoint = encryptStringIfNeeded(binding.endpoint, dataEncryptionService); + const encryptedApiKey = encryptStringIfNeeded(binding.encryptedApiKey, dataEncryptionService); + + assignRequiredEncryptedString(data, "providerName", providerName); + assignOptionalEncryptedString(data, "model", model); + assignOptionalEncryptedString(data, "configId", configId); + assignOptionalEncryptedString(data, "configName", configName); + assignOptionalEncryptedString(data, "endpoint", endpoint); + assignOptionalEncryptedString(data, "encryptedApiKey", encryptedApiKey); + + if (Object.keys(data).length === 0) { + continue; + } + + await prisma.aiProviderBinding.update({ + where: { + id: binding.id + }, + data + }); + counter.aiBindings += 1; + } + + const publicPools = await prisma.aiPublicPoolConfig.findMany({ + select: { + id: true, + providerName: true, + model: true, + endpoint: true, + encryptedApiKey: true + } + }); + + for (const publicPool of publicPools) { + const data: Prisma.AiPublicPoolConfigUpdateInput = {}; + const providerName = encryptStringIfNeeded(publicPool.providerName, dataEncryptionService); + const model = encryptStringIfNeeded(publicPool.model, dataEncryptionService); + const endpoint = encryptStringIfNeeded(publicPool.endpoint, dataEncryptionService); + const encryptedApiKey = encryptStringIfNeeded( + publicPool.encryptedApiKey, + dataEncryptionService + ); + + assignOptionalEncryptedString(data, "providerName", providerName); + assignOptionalEncryptedString(data, "model", model); + assignOptionalEncryptedString(data, "endpoint", endpoint); + assignOptionalEncryptedString(data, "encryptedApiKey", encryptedApiKey); + + if (Object.keys(data).length === 0) { + continue; + } + + await prisma.aiPublicPoolConfig.update({ + where: { + id: publicPool.id + }, + data + }); + counter.publicPools += 1; + } + + const aiUsageLogs = await prisma.aiUsageLog.findMany({ + select: { + id: true, + providerName: true, + model: true + } + }); + + for (const aiUsageLog of aiUsageLogs) { + const data: Prisma.AiUsageLogUpdateInput = {}; + const providerName = encryptStringIfNeeded(aiUsageLog.providerName, dataEncryptionService); + const model = encryptStringIfNeeded(aiUsageLog.model, dataEncryptionService); + + assignOptionalEncryptedString(data, "providerName", providerName); + assignOptionalEncryptedString(data, "model", model); + + if (Object.keys(data).length === 0) { + continue; + } + + await prisma.aiUsageLog.update({ + where: { + id: aiUsageLog.id + }, + data + }); + counter.aiUsageLogs += 1; + } + + const tasks = await prisma.task.findMany({ + select: { + id: true, + title: true, + contentJson: true, + contentText: true + } + }); + + for (const task of tasks) { + const data: Prisma.TaskUpdateInput = {}; + const title = encryptStringIfNeeded(task.title, dataEncryptionService); + const contentJson = encryptJsonIfNeeded(task.contentJson, dataEncryptionService); + const contentText = encryptStringIfNeeded(task.contentText, dataEncryptionService); + + assignRequiredEncryptedString(data, "title", title); + if (contentJson !== undefined) { + data.contentJson = contentJson; + } + assignOptionalEncryptedString(data, "contentText", contentText); + + if (Object.keys(data).length === 0) { + continue; + } + + await prisma.task.update({ + where: { + id: task.id + }, + data + }); + counter.tasks += 1; + } + + const attachments = await prisma.attachment.findMany({ + select: { + id: true, + url: true, + fileName: true, + checksum: true + } + }); + + for (const attachment of attachments) { + const data: Prisma.AttachmentUpdateInput = {}; + const url = encryptStringIfNeeded(attachment.url, dataEncryptionService); + const fileName = encryptStringIfNeeded(attachment.fileName, dataEncryptionService); + const checksum = encryptStringIfNeeded(attachment.checksum, dataEncryptionService); + + assignRequiredEncryptedString(data, "url", url); + assignOptionalEncryptedString(data, "fileName", fileName); + assignOptionalEncryptedString(data, "checksum", checksum); + + if (Object.keys(data).length === 0) { + continue; + } + + await prisma.attachment.update({ + where: { + id: attachment.id + }, + data + }); + counter.attachments += 1; + } + + const syncOperations = await prisma.syncOperation.findMany({ + select: { + id: true, + payload: true + } + }); + + for (const operation of syncOperations) { + if (operation.payload === null) { + continue; + } + + let nextPayload: string | null = null; + if (typeof operation.payload === "string") { + if (dataEncryptionService.isEncryptedString(operation.payload)) { + continue; + } + + nextPayload = dataEncryptionService.encryptString(operation.payload) ?? null; + } else { + nextPayload = + dataEncryptionService.encryptString(JSON.stringify(operation.payload)) ?? null; + } + + if (nextPayload === null) { + continue; + } + + await prisma.syncOperation.update({ + where: { + id: operation.id + }, + data: { + payload: nextPayload + } + }); + counter.syncOperations += 1; + } + + console.log("敏感数据迁移完成"); + console.log(JSON.stringify(counter, null, 2)); + } finally { + await prisma.$disconnect(); + } +} + +void main().catch((error: unknown) => { + const message = error instanceof Error ? error.message : "未知错误"; + console.error(`敏感数据迁移失败:${message}`); + process.exitCode = 1; +}); diff --git a/apps/api/src/ai/ai-provider-registry.service.ts b/apps/api/src/ai/ai-provider-registry.service.ts new file mode 100644 index 0000000..54387a9 --- /dev/null +++ b/apps/api/src/ai/ai-provider-registry.service.ts @@ -0,0 +1,28 @@ +import { Injectable } from "@nestjs/common"; +import { AiChannel } from "../../generated/prisma/client"; +import { AstrbotProvider } from "./providers/astrbot.provider"; +import { OpenAiCompatibleProvider } from "./providers/openai-compatible.provider"; +import { AiChannelExecutor } from "./ai.types"; + +@Injectable() +export class AiProviderRegistryService { + private readonly executors = new Map(); + + constructor( + openAiCompatibleProvider: OpenAiCompatibleProvider, + astrbotProvider: AstrbotProvider + ) { + this.executors.set(AiChannel.USER_KEY, openAiCompatibleProvider); + this.executors.set(AiChannel.PUBLIC_POOL, openAiCompatibleProvider); + this.executors.set(AiChannel.ASTRBOT, astrbotProvider); + } + + getExecutor(channel: AiChannel): AiChannelExecutor { + const executor = this.executors.get(channel); + if (!executor) { + throw new Error(`未找到 ${channel} 对应的 AI 通道执行器`); + } + + return executor; + } +} diff --git a/apps/api/src/ai/ai-rate-limit.service.ts b/apps/api/src/ai/ai-rate-limit.service.ts new file mode 100644 index 0000000..e84b2de --- /dev/null +++ b/apps/api/src/ai/ai-rate-limit.service.ts @@ -0,0 +1,123 @@ +import { Injectable } from "@nestjs/common"; +import { ConfigService } from "@nestjs/config"; + +type AiRateLimitBucket = { + count: number; + resetAt: number; +}; + +export type AiRateLimitResult = + | { + allowed: true; + } + | { + allowed: false; + reason: "USER" | "IP"; + retryAfterMs: number; + limit: number; + windowMs: number; + }; + +@Injectable() +export class AiRateLimitService { + private readonly userBuckets = new Map(); + private readonly ipBuckets = new Map(); + private readonly windowMs: number; + private readonly userLimit: number; + private readonly ipLimit: number; + + constructor(private readonly configService: ConfigService) { + this.windowMs = this.readPositiveInt("AI_RATE_LIMIT_WINDOW_MS", 60_000); + this.userLimit = this.readPositiveInt("AI_RATE_LIMIT_USER_MAX", 20); + this.ipLimit = this.readPositiveInt("AI_RATE_LIMIT_IP_MAX", 60); + } + + consume(userId: string, clientIp: string | null): AiRateLimitResult { + const now = Date.now(); + const userBucket = this.getBucket(this.userBuckets, userId, now); + if (userBucket.count >= this.userLimit) { + return { + allowed: false, + reason: "USER", + retryAfterMs: Math.max(0, userBucket.resetAt - now), + limit: this.userLimit, + windowMs: this.windowMs + }; + } + + const normalizedIp = this.normalizeIp(clientIp); + const ipBucket = normalizedIp ? this.getBucket(this.ipBuckets, normalizedIp, now) : null; + if (ipBucket && ipBucket.count >= this.ipLimit) { + return { + allowed: false, + reason: "IP", + retryAfterMs: Math.max(0, ipBucket.resetAt - now), + limit: this.ipLimit, + windowMs: this.windowMs + }; + } + + userBucket.count += 1; + if (ipBucket) { + ipBucket.count += 1; + } + + this.cleanupExpiredBuckets(this.userBuckets, now); + this.cleanupExpiredBuckets(this.ipBuckets, now); + + return { + allowed: true + }; + } + + private getBucket( + buckets: Map, + key: string, + now: number + ): AiRateLimitBucket { + const currentBucket = buckets.get(key); + if (!currentBucket || now >= currentBucket.resetAt) { + const nextBucket: AiRateLimitBucket = { + count: 0, + resetAt: now + this.windowMs + }; + buckets.set(key, nextBucket); + return nextBucket; + } + + return currentBucket; + } + + private cleanupExpiredBuckets(buckets: Map, now: number): void { + if (buckets.size <= 256) { + return; + } + + for (const [key, bucket] of buckets.entries()) { + if (now >= bucket.resetAt) { + buckets.delete(key); + } + } + } + + private normalizeIp(clientIp: string | null): string | null { + if (!clientIp) { + return null; + } + + const normalizedIp = clientIp.trim(); + return normalizedIp.length > 0 ? normalizedIp : null; + } + + private readPositiveInt(key: string, fallbackValue: number): number { + const rawValue = this.configService.get(key); + const parsedValue = + typeof rawValue === "number" ? rawValue : Number.parseInt(String(rawValue ?? ""), 10); + + if (!Number.isFinite(parsedValue) || parsedValue <= 0) { + return fallbackValue; + } + + return parsedValue; + } +} diff --git a/apps/api/src/ai/ai.controller.ts b/apps/api/src/ai/ai.controller.ts new file mode 100644 index 0000000..8756e08 --- /dev/null +++ b/apps/api/src/ai/ai.controller.ts @@ -0,0 +1,74 @@ +import { + Body, + Controller, + Get, + Headers, + Ip, + Post, + Query, + UnauthorizedException +} from "@nestjs/common"; +import { AiChatDto } from "./dto/ai-chat.dto"; +import { ListAiUsageLogsQueryDto } from "./dto/list-ai-usage-logs-query.dto"; +import { UpsertAiProviderBindingDto } from "./dto/upsert-ai-provider-binding.dto"; +import { + AiChatResponse, + AiService, + ListAiBindingsResponse, + ListAiUsageLogsResponse, + TestAiBindingResponse +} from "./ai.service"; + +@Controller("ai") +export class AiController { + constructor(private readonly aiService: AiService) {} + + @Get("bindings") + async listBindings( + @Headers("x-user-id") userIdHeader: string | string[] | undefined + ): Promise { + return this.aiService.listBindings(this.resolveUserId(userIdHeader)); + } + + @Get("usage-logs") + async listUsageLogs( + @Headers("x-user-id") userIdHeader: string | string[] | undefined, + @Query() query: ListAiUsageLogsQueryDto + ): Promise { + return this.aiService.listUsageLogs(this.resolveUserId(userIdHeader), query); + } + + @Post("bindings") + async upsertBinding( + @Headers("x-user-id") userIdHeader: string | string[] | undefined, + @Body() body: UpsertAiProviderBindingDto + ) { + return this.aiService.upsertBinding(this.resolveUserId(userIdHeader), body); + } + + @Post("bindings/test") + async testBinding( + @Headers("x-user-id") userIdHeader: string | string[] | undefined, + @Body() body: UpsertAiProviderBindingDto + ): Promise { + return this.aiService.testBinding(this.resolveUserId(userIdHeader), body); + } + + @Post("chat") + async chat( + @Headers("x-user-id") userIdHeader: string | string[] | undefined, + @Ip() clientIp: string, + @Body() body: AiChatDto + ): Promise { + return this.aiService.chat(this.resolveUserId(userIdHeader), body, clientIp); + } + + private resolveUserId(userIdHeader: string | string[] | undefined): string { + const userId = Array.isArray(userIdHeader) ? userIdHeader[0] : userIdHeader; + if (!userId) { + throw new UnauthorizedException("缺少用户上下文"); + } + + return userId; + } +} diff --git a/apps/api/src/ai/ai.module.ts b/apps/api/src/ai/ai.module.ts new file mode 100644 index 0000000..a17544a --- /dev/null +++ b/apps/api/src/ai/ai.module.ts @@ -0,0 +1,21 @@ +import { Module } from "@nestjs/common"; +import { PrismaModule } from "../prisma/prisma.module"; +import { AiRateLimitService } from "./ai-rate-limit.service"; +import { AiController } from "./ai.controller"; +import { AiProviderRegistryService } from "./ai-provider-registry.service"; +import { AiService } from "./ai.service"; +import { AstrbotProvider } from "./providers/astrbot.provider"; +import { OpenAiCompatibleProvider } from "./providers/openai-compatible.provider"; + +@Module({ + imports: [PrismaModule], + controllers: [AiController], + providers: [ + AiService, + AiRateLimitService, + AiProviderRegistryService, + OpenAiCompatibleProvider, + AstrbotProvider + ] +}) +export class AiModule {} diff --git a/apps/api/src/ai/ai.service.ts b/apps/api/src/ai/ai.service.ts new file mode 100644 index 0000000..25cbf61 --- /dev/null +++ b/apps/api/src/ai/ai.service.ts @@ -0,0 +1,988 @@ +import { + BadGatewayException, + BadRequestException, + HttpException, + HttpStatus, + Injectable, + Logger +} from "@nestjs/common"; +import { + AiChannel, + AiUsageLog, + AiProviderBinding, + AiPublicPoolConfig, + Prisma, + TaskPriority, + TaskStatus +} from "../../generated/prisma/client"; +import { PrismaService } from "../prisma/prisma.service"; +import { DataEncryptionService } from "../security/data-encryption.service"; +import { AiRateLimitService } from "./ai-rate-limit.service"; +import { AiProviderRegistryService } from "./ai-provider-registry.service"; +import { AiChatDto } from "./dto/ai-chat.dto"; +import { ListAiUsageLogsQueryDto } from "./dto/list-ai-usage-logs-query.dto"; +import { UpsertAiProviderBindingDto } from "./dto/upsert-ai-provider-binding.dto"; +import { + AiResolvedRouteCandidate, + AiRouteAttempt, + AiRouteFailureError, + AiUsageMetrics +} from "./ai.types"; + +type AiBindingSummary = { + id: string; + channel: AiChannel; + providerName: string; + model: string | null; + configId: string | null; + configName: string | null; + endpoint: string | null; + isEnabled: boolean; + hasApiKey: boolean; + maskedApiKey: string | null; + updatedAt: string; +}; + +type AiRoutePlanEntry = + | { + kind: "candidate"; + candidate: AiResolvedRouteCandidate; + } + | { + kind: "skip"; + attempt: AiRouteAttempt; + }; + +export type ListAiBindingsResponse = { + routeOrder: AiChannel[]; + bindings: AiBindingSummary[]; + publicPool: { + enabled: boolean; + providerName: string | null; + model: string | null; + hasApiKey: boolean; + } | null; +}; + +type AiUsageLogSummary = { + id: string; + channel: AiChannel; + providerName: string | null; + model: string | null; + promptTokens: number; + completionTokens: number; + totalTokens: number; + latencyMs: number | null; + success: boolean; + errorCode: string | null; + createdAt: string; +}; + +type AiContextTaskItem = { + id: string; + title: string; + priority: TaskPriority; + status: TaskStatus; + ddl: Date | null; + contentText: string | null; + updatedAt: Date; +}; + +export type ListAiUsageLogsResponse = { + items: AiUsageLogSummary[]; + page: number; + pageSize: number; + total: number; +}; + +export type AiChatResponse = { + channel: AiChannel; + providerName: string; + model: string | null; + content: string; + sessionId: string | null; + attempts: AiRouteAttempt[]; +}; + +export type TestAiBindingResponse = + | { + success: true; + channel: AiChannel; + providerName: string; + model: string | null; + contentPreview: string; + } + | { + success: false; + channel: AiChannel; + providerName: string; + model: string | null; + code: string; + message: string; + }; + +@Injectable() +export class AiService { + private readonly logger = new Logger(AiService.name); + private readonly maxContextTasks = 6; + private readonly maxContextContentLength = 80; + + constructor( + private readonly prismaService: PrismaService, + private readonly aiProviderRegistryService: AiProviderRegistryService, + private readonly dataEncryptionService: DataEncryptionService, + private readonly aiRateLimitService: AiRateLimitService + ) {} + + async listBindings(userId: string): Promise { + const [bindings, publicPool] = await Promise.all([ + this.prismaService.aiProviderBinding.findMany({ + where: { + userId + }, + orderBy: [{ updatedAt: "desc" }] + }), + this.prismaService.aiPublicPoolConfig.findFirst({ + orderBy: { + updatedAt: "desc" + } + }) + ]); + + const latestBindings = this.pickLatestBindingsByChannel(bindings); + + return { + routeOrder: [AiChannel.USER_KEY, AiChannel.ASTRBOT, AiChannel.PUBLIC_POOL], + bindings: latestBindings.map((binding) => this.serializeBinding(binding)), + publicPool: publicPool + ? { + enabled: publicPool.enabled, + providerName: this.readDecryptedString(publicPool.providerName), + model: this.readDecryptedString(publicPool.model), + hasApiKey: Boolean(publicPool.encryptedApiKey) + } + : null + }; + } + + async listUsageLogs( + userId: string, + query: ListAiUsageLogsQueryDto + ): Promise { + const page = query.page ?? 1; + const pageSize = query.pageSize ?? 20; + const skip = (page - 1) * pageSize; + const where: Prisma.AiUsageLogWhereInput = { + userId + }; + + if (query.channel) { + where.channel = query.channel; + } + + if (query.success !== undefined) { + where.success = query.success; + } + + const [items, total] = await Promise.all([ + this.prismaService.aiUsageLog.findMany({ + where, + orderBy: { + createdAt: "desc" + }, + skip, + take: pageSize + }), + this.prismaService.aiUsageLog.count({ + where + }) + ]); + + return { + items: items.map((item) => this.serializeUsageLog(item)), + page, + pageSize, + total + }; + } + + async upsertBinding(userId: string, dto: UpsertAiProviderBindingDto): Promise { + if (dto.channel === AiChannel.PUBLIC_POOL) { + throw new BadRequestException("公共 AI 通道只能由管理员配置"); + } + + this.validateBindingInput(dto); + + const result = await this.prismaService.$transaction(async (tx) => { + const existingBinding = await tx.aiProviderBinding.findFirst({ + where: { + userId, + channel: dto.channel + }, + orderBy: { + updatedAt: "desc" + } + }); + + if (!existingBinding) { + return tx.aiProviderBinding.create({ + data: { + userId, + channel: dto.channel, + providerName: this.encryptRequiredString(this.normalizeProviderName(dto.providerName)), + model: this.encryptOptionalString(dto.model), + configId: this.encryptOptionalString(dto.configId), + configName: this.encryptOptionalString(dto.configName), + endpoint: this.encryptOptionalString(dto.endpoint), + encryptedApiKey: this.encryptOptionalString(dto.apiKey), + isEnabled: dto.isEnabled ?? true + } + }); + } + + const updateData: Prisma.AiProviderBindingUpdateInput = { + channel: dto.channel, + providerName: this.encryptRequiredString(this.normalizeProviderName(dto.providerName)), + model: this.encryptOptionalString(dto.model), + configId: this.encryptOptionalString(dto.configId), + configName: this.encryptOptionalString(dto.configName), + isEnabled: dto.isEnabled ?? existingBinding.isEnabled + }; + + if (dto.endpoint !== undefined) { + updateData.endpoint = this.encryptOptionalString(dto.endpoint); + } + + if (dto.apiKey !== undefined) { + updateData.encryptedApiKey = this.encryptOptionalString(dto.apiKey); + } + + return tx.aiProviderBinding.update({ + where: { + id: existingBinding.id + }, + data: updateData + }); + }); + + return this.serializeBinding(result); + } + + async testBinding( + userId: string, + dto: UpsertAiProviderBindingDto + ): Promise { + if (dto.channel === AiChannel.PUBLIC_POOL) { + throw new BadRequestException("公共 AI 通道不能由用户自行测试"); + } + + const candidate = await this.buildTestCandidate(userId, dto); + const executor = this.aiProviderRegistryService.getExecutor(candidate.channel); + + try { + const result = await executor.execute(candidate, { + userId, + message: "请只回复“连接成功”,不要添加其他内容。", + sessionId: null + }); + + return { + success: true, + channel: result.channel, + providerName: result.providerName, + model: result.model, + contentPreview: this.limitPreviewText(result.content) + }; + } catch (error) { + if (error instanceof AiRouteFailureError) { + return { + success: false, + channel: error.channel, + providerName: error.providerName, + model: candidate.model, + code: error.code, + message: error.message + }; + } + + if (error instanceof Error) { + return { + success: false, + channel: candidate.channel, + providerName: candidate.providerName, + model: candidate.model, + code: "UNKNOWN_ERROR", + message: error.message + }; + } + + return { + success: false, + channel: candidate.channel, + providerName: candidate.providerName, + model: candidate.model, + code: "UNKNOWN_ERROR", + message: "未知错误" + }; + } + } + + async chat( + userId: string, + dto: AiChatDto, + clientIp: string | null = null + ): Promise { + const rateLimitResult = this.aiRateLimitService.consume(userId, clientIp); + if (!rateLimitResult.allowed) { + throw new HttpException( + { + message: "AI 请求过于频繁,请稍后再试", + code: "AI_RATE_LIMITED", + dimension: rateLimitResult.reason === "USER" ? "user" : "ip", + retryAfterMs: rateLimitResult.retryAfterMs, + limit: rateLimitResult.limit, + windowMs: rateLimitResult.windowMs + }, + HttpStatus.TOO_MANY_REQUESTS + ); + } + + const attempts: AiRouteAttempt[] = []; + const plan = await this.buildRoutePlan(userId, dto.channel ?? null); + const promptMessage = await this.buildPromptMessage(userId, dto.message, dto.localTasks ?? []); + + for (const entry of plan) { + if (entry.kind === "skip") { + attempts.push(entry.attempt); + continue; + } + + const executor = this.aiProviderRegistryService.getExecutor(entry.candidate.channel); + const startedAt = Date.now(); + + try { + const result = await executor.execute(entry.candidate, { + userId, + message: promptMessage, + sessionId: dto.sessionId ?? null + }); + const latencyMs = Date.now() - startedAt; + + attempts.push({ + channel: result.channel, + providerName: result.providerName, + model: result.model, + status: "success", + reasonCode: null, + reasonMessage: null + }); + await this.recordUsageLog({ + userId, + channel: result.channel, + providerName: result.providerName, + model: result.model, + usage: result.usage, + latencyMs, + success: true, + errorCode: null + }); + + return { + channel: result.channel, + providerName: result.providerName, + model: result.model, + content: result.content, + sessionId: result.sessionId, + attempts + }; + } catch (error) { + const latencyMs = Date.now() - startedAt; + const failureAttempt = this.toFailureAttempt(entry.candidate, error); + attempts.push(failureAttempt); + await this.recordUsageLog({ + userId, + channel: failureAttempt.channel, + providerName: failureAttempt.providerName, + model: failureAttempt.model, + usage: null, + latencyMs, + success: false, + errorCode: failureAttempt.reasonCode + }); + this.logger.warn( + `AI 通道降级:channel=${failureAttempt.channel} provider=${failureAttempt.providerName ?? "unknown"} code=${failureAttempt.reasonCode ?? "UNKNOWN"} message=${failureAttempt.reasonMessage ?? "unknown"}` + ); + } + } + + throw new BadGatewayException({ + message: "当前没有可用的 AI 通道,请稍后重试", + attempts + }); + } + + private async buildRoutePlan( + userId: string, + selectedChannel: AiChannel | null + ): Promise { + const plan: AiRoutePlanEntry[] = []; + const targetChannels = selectedChannel + ? [selectedChannel] + : [AiChannel.USER_KEY, AiChannel.ASTRBOT, AiChannel.PUBLIC_POOL]; + + for (const channel of targetChannels) { + if (channel === AiChannel.PUBLIC_POOL) { + const publicPool = await this.findEnabledPublicPool(); + if (publicPool) { + plan.push({ + kind: "candidate", + candidate: this.toPublicPoolCandidate(publicPool) + }); + } else { + plan.push({ + kind: "skip", + attempt: { + channel: AiChannel.PUBLIC_POOL, + providerName: null, + model: null, + status: "skipped", + reasonCode: "PUBLIC_POOL_DISABLED", + reasonMessage: "公共 AI 通道未开启" + } + }); + } + continue; + } + + const binding = await this.findPreferredBinding(userId, channel); + if (binding) { + plan.push({ + kind: "candidate", + candidate: this.toBindingCandidate(binding) + }); + continue; + } + + plan.push({ + kind: "skip", + attempt: { + channel, + providerName: null, + model: null, + status: "skipped", + reasonCode: "CHANNEL_NOT_CONFIGURED", + reasonMessage: + channel === AiChannel.USER_KEY + ? "当前用户未配置可用的自备 Key 通道" + : "当前用户未配置可用的 AstrBot 通道" + } + }); + } + + return plan; + } + + private async findPreferredBinding( + userId: string, + channel: AiChannel + ): Promise { + return this.prismaService.aiProviderBinding.findFirst({ + where: { + userId, + channel, + isEnabled: true + }, + orderBy: { + updatedAt: "desc" + } + }); + } + + private async findEnabledPublicPool(): Promise { + return this.prismaService.aiPublicPoolConfig.findFirst({ + where: { + enabled: true + }, + orderBy: { + updatedAt: "desc" + } + }); + } + + private async buildTestCandidate( + userId: string, + dto: UpsertAiProviderBindingDto + ): Promise { + const existingBinding = await this.prismaService.aiProviderBinding.findFirst({ + where: { + userId, + channel: dto.channel + }, + orderBy: { + updatedAt: "desc" + } + }); + + const mergedDto: UpsertAiProviderBindingDto = { + channel: dto.channel, + providerName: + dto.providerName ?? this.readDecryptedString(existingBinding?.providerName ?? null) ?? "", + model: dto.model ?? this.readDecryptedString(existingBinding?.model ?? null) ?? undefined, + configId: + dto.configId ?? this.readDecryptedString(existingBinding?.configId ?? null) ?? undefined, + configName: + dto.configName ?? + this.readDecryptedString(existingBinding?.configName ?? null) ?? + undefined, + endpoint: + dto.endpoint ?? this.readDecryptedString(existingBinding?.endpoint ?? null) ?? undefined, + apiKey: + dto.apiKey ?? + this.readDecryptedString(existingBinding?.encryptedApiKey ?? null) ?? + undefined, + isEnabled: dto.isEnabled ?? existingBinding?.isEnabled ?? true + }; + + this.validateBindingInput(mergedDto); + + return { + channel: mergedDto.channel, + source: existingBinding ? "binding" : "binding", + sourceId: existingBinding?.id ?? null, + providerName: this.normalizeProviderName(mergedDto.providerName), + model: this.normalizeOptionalString(mergedDto.model), + configId: this.normalizeOptionalString(mergedDto.configId), + configName: this.normalizeOptionalString(mergedDto.configName), + endpoint: this.normalizeOptionalString(mergedDto.endpoint), + apiKey: this.normalizeOptionalString(mergedDto.apiKey) + }; + } + + private toBindingCandidate(binding: AiProviderBinding): AiResolvedRouteCandidate { + return { + channel: binding.channel, + source: "binding", + sourceId: binding.id, + providerName: this.readDecryptedString(binding.providerName) ?? "", + model: this.readDecryptedString(binding.model), + configId: this.readDecryptedString(binding.configId), + configName: this.readDecryptedString(binding.configName), + endpoint: this.readDecryptedString(binding.endpoint), + apiKey: this.readDecryptedString(binding.encryptedApiKey) + }; + } + + private toPublicPoolCandidate(publicPool: AiPublicPoolConfig): AiResolvedRouteCandidate { + return { + channel: AiChannel.PUBLIC_POOL, + source: "public_pool", + sourceId: publicPool.id, + providerName: this.readDecryptedString(publicPool.providerName) ?? "public-pool", + model: this.readDecryptedString(publicPool.model), + configId: null, + configName: null, + endpoint: this.readDecryptedString(publicPool.endpoint), + apiKey: this.readDecryptedString(publicPool.encryptedApiKey) + }; + } + + private serializeBinding(binding: AiProviderBinding): AiBindingSummary { + const decryptedProviderName = this.readDecryptedString(binding.providerName) ?? ""; + const decryptedModel = this.readDecryptedString(binding.model); + const decryptedConfigId = this.readDecryptedString(binding.configId); + const decryptedConfigName = this.readDecryptedString(binding.configName); + const decryptedEndpoint = this.readDecryptedString(binding.endpoint); + const decryptedApiKey = this.readDecryptedString(binding.encryptedApiKey); + + return { + id: binding.id, + channel: binding.channel, + providerName: decryptedProviderName, + model: decryptedModel, + configId: decryptedConfigId, + configName: decryptedConfigName, + endpoint: decryptedEndpoint, + isEnabled: binding.isEnabled, + hasApiKey: Boolean(binding.encryptedApiKey), + maskedApiKey: this.maskSecret(decryptedApiKey), + updatedAt: binding.updatedAt.toISOString() + }; + } + + private pickLatestBindingsByChannel(bindings: AiProviderBinding[]): AiProviderBinding[] { + const bindingMap = new Map(); + + for (const binding of bindings) { + if (!bindingMap.has(binding.channel)) { + bindingMap.set(binding.channel, binding); + } + } + + return [AiChannel.USER_KEY, AiChannel.ASTRBOT] + .map((channel) => bindingMap.get(channel) ?? null) + .filter((binding): binding is AiProviderBinding => binding !== null); + } + + private serializeUsageLog(log: AiUsageLog): AiUsageLogSummary { + return { + id: log.id, + channel: log.channel, + providerName: this.readDecryptedString(log.providerName), + model: this.readDecryptedString(log.model), + promptTokens: log.promptTokens, + completionTokens: log.completionTokens, + totalTokens: log.totalTokens, + latencyMs: log.latencyMs, + success: log.success, + errorCode: log.errorCode, + createdAt: log.createdAt.toISOString() + }; + } + + private async buildPromptMessage( + userId: string, + userMessage: string, + localTasks: NonNullable + ): Promise { + const taskSummary = await this.buildTaskContextSummary(userId, localTasks); + if (!taskSummary) { + return userMessage; + } + + return [ + "你是 TodoList 的 AI 助手,需要结合用户当前待办提供任务统筹建议。", + "以下是系统整理的未完成任务摘要:", + taskSummary, + "请优先根据这些任务的紧急度、截止时间和执行顺序回答,并给出明确可执行的建议。", + `用户当前问题:${userMessage}` + ].join("\n\n"); + } + + private async buildTaskContextSummary( + userId: string, + localTasks: NonNullable + ): Promise { + const tasks = await this.prismaService.task.findMany({ + where: { + userId, + status: { + in: [TaskStatus.TODO, TaskStatus.IN_PROGRESS] + } + }, + select: { + id: true, + title: true, + priority: true, + status: true, + ddl: true, + contentText: true, + updatedAt: true + }, + take: 20 + }); + + const sortedTasks = this.sortContextTasks(this.mergeContextTasks(tasks, localTasks)); + if (sortedTasks.length === 0) { + return null; + } + + const visibleTasks = sortedTasks.slice(0, this.maxContextTasks); + const lines = visibleTasks.map((task, index) => { + const parts = [ + `${index + 1}. ${task.title}`, + `优先级:${this.getPriorityLabel(task.priority)}`, + `状态:${this.getStatusLabel(task.status)}`, + `DDL:${task.ddl ? task.ddl.toISOString() : "未设置"}` + ]; + + const contentSnippet = this.getContentSnippet(task.contentText); + if (contentSnippet) { + parts.push(`内容摘要:${contentSnippet}`); + } + + return parts.join(" | "); + }); + + const omittedCount = sortedTasks.length - visibleTasks.length; + if (omittedCount > 0) { + lines.push(`另有 ${omittedCount} 条任务已省略。`); + } + + return [`共 ${sortedTasks.length} 条未完成任务。`, ...lines].join("\n"); + } + + private mergeContextTasks( + databaseTasks: Array<{ + id: string; + title: string; + priority: TaskPriority; + status: TaskStatus; + ddl: Date | null; + contentText: string | null; + updatedAt: Date; + }>, + localTasks: NonNullable + ): AiContextTaskItem[] { + const taskMap = new Map(); + + for (const task of databaseTasks) { + taskMap.set(task.id, { + id: task.id, + title: this.readDecryptedString(task.title) ?? "未命名任务", + priority: task.priority, + status: task.status, + ddl: task.ddl, + contentText: this.readDecryptedString(task.contentText), + updatedAt: task.updatedAt + }); + } + + for (const task of localTasks) { + if (task.status !== TaskStatus.TODO && task.status !== TaskStatus.IN_PROGRESS) { + continue; + } + + const currentTask = taskMap.get(task.id); + const nextTask: AiContextTaskItem = { + id: task.id, + title: task.title.trim().length > 0 ? task.title.trim() : "未命名任务", + priority: task.priority, + status: task.status, + ddl: typeof task.ddlAt === "number" ? new Date(task.ddlAt) : null, + contentText: + typeof task.contentText === "string" && task.contentText.trim().length > 0 + ? task.contentText + : null, + updatedAt: new Date(task.updatedAt) + }; + + if (!currentTask || nextTask.updatedAt.getTime() >= currentTask.updatedAt.getTime()) { + taskMap.set(task.id, nextTask); + } + } + + return [...taskMap.values()].filter( + (task) => task.status === TaskStatus.TODO || task.status === TaskStatus.IN_PROGRESS + ); + } + + private sortContextTasks(tasks: AiContextTaskItem[]): AiContextTaskItem[] { + return [...tasks].sort((left, right) => { + const priorityDiff = + this.getPriorityWeight(right.priority) - this.getPriorityWeight(left.priority); + if (priorityDiff !== 0) { + return priorityDiff; + } + + const leftDdl = left.ddl?.getTime() ?? Number.POSITIVE_INFINITY; + const rightDdl = right.ddl?.getTime() ?? Number.POSITIVE_INFINITY; + if (leftDdl !== rightDdl) { + return leftDdl - rightDdl; + } + + return right.updatedAt.getTime() - left.updatedAt.getTime(); + }); + } + + private toFailureAttempt(candidate: AiResolvedRouteCandidate, error: unknown): AiRouteAttempt { + if (error instanceof AiRouteFailureError) { + return { + channel: error.channel, + providerName: error.providerName, + model: candidate.model, + status: "failed", + reasonCode: error.code, + reasonMessage: error.message + }; + } + + if (error instanceof Error) { + return { + channel: candidate.channel, + providerName: candidate.providerName, + model: candidate.model, + status: "failed", + reasonCode: "UNKNOWN_ERROR", + reasonMessage: error.message + }; + } + + return { + channel: candidate.channel, + providerName: candidate.providerName, + model: candidate.model, + status: "failed", + reasonCode: "UNKNOWN_ERROR", + reasonMessage: "未知错误" + }; + } + + private normalizeOptionalString(value: string | undefined): string | null { + if (value === undefined) { + return null; + } + + const normalizedValue = value.trim(); + return normalizedValue.length > 0 ? normalizedValue : null; + } + + private normalizeProviderName(value: string | undefined): string { + return this.normalizeOptionalString(value) ?? ""; + } + + private encryptOptionalString(value: string | undefined): string | null | undefined { + const normalizedValue = this.normalizeOptionalString(value); + return this.dataEncryptionService.encryptString(normalizedValue); + } + + private encryptRequiredString(value: string): string { + const encryptedValue = this.dataEncryptionService.encryptString(value); + if (!encryptedValue) { + throw new BadRequestException("敏感配置加密失败"); + } + + return encryptedValue; + } + + private readDecryptedString(value: string | null): string | null { + const decryptedValue = this.dataEncryptionService.decryptString(value); + return typeof decryptedValue === "string" ? decryptedValue : null; + } + + private validateBindingInput(dto: UpsertAiProviderBindingDto): void { + const providerName = this.normalizeOptionalString(dto.providerName); + const configId = this.normalizeOptionalString(dto.configId); + const configName = this.normalizeOptionalString(dto.configName); + + if (dto.channel === AiChannel.ASTRBOT) { + if (!providerName && !configId && !configName) { + throw new BadRequestException( + "AstrBot 通道至少需要 providerName、configId、configName 三者之一" + ); + } + return; + } + + if (!providerName) { + throw new BadRequestException("当前通道必须提供 providerName"); + } + } + + private maskSecret(secret: string | null): string | null { + if (!secret) { + return null; + } + + if (secret.length <= 6) { + return "*".repeat(secret.length); + } + + return `${secret.slice(0, 4)}***${secret.slice(-2)}`; + } + + private limitPreviewText(content: string): string { + const normalizedContent = content.replace(/\s+/g, " ").trim(); + if (normalizedContent.length <= 60) { + return normalizedContent; + } + + return `${normalizedContent.slice(0, 60)}...`; + } + + private getPriorityWeight(priority: TaskPriority): number { + switch (priority) { + case TaskPriority.URGENT: + return 4; + case TaskPriority.HIGH: + return 3; + case TaskPriority.MEDIUM: + return 2; + case TaskPriority.LOW: + return 1; + default: + return 0; + } + } + + private getPriorityLabel(priority: TaskPriority): string { + switch (priority) { + case TaskPriority.URGENT: + return "紧急"; + case TaskPriority.HIGH: + return "高"; + case TaskPriority.MEDIUM: + return "中"; + case TaskPriority.LOW: + return "低"; + default: + return String(priority); + } + } + + private getStatusLabel(status: TaskStatus): string { + switch (status) { + case TaskStatus.TODO: + return "待开始"; + case TaskStatus.IN_PROGRESS: + return "进行中"; + case TaskStatus.DONE: + return "已完成"; + case TaskStatus.ARCHIVED: + return "已归档"; + default: + return String(status); + } + } + + private getContentSnippet(contentText: string | null): string | null { + if (!contentText) { + return null; + } + + const normalizedContent = contentText.replace(/\s+/g, " ").trim(); + if (normalizedContent.length === 0) { + return null; + } + + if (normalizedContent.length <= this.maxContextContentLength) { + return normalizedContent; + } + + return `${normalizedContent.slice(0, this.maxContextContentLength)}...`; + } + + private async recordUsageLog(input: { + userId: string; + channel: AiChannel; + providerName: string | null; + model: string | null; + usage: AiUsageMetrics | null; + latencyMs: number; + success: boolean; + errorCode: string | null; + }): Promise { + try { + await this.prismaService.aiUsageLog.create({ + data: { + userId: input.userId, + channel: input.channel, + providerName: + input.providerName === null + ? null + : this.dataEncryptionService.encryptString(input.providerName), + model: + input.model === null ? null : this.dataEncryptionService.encryptString(input.model), + promptTokens: input.usage?.promptTokens ?? 0, + completionTokens: input.usage?.completionTokens ?? 0, + totalTokens: input.usage?.totalTokens ?? 0, + latencyMs: input.latencyMs, + success: input.success, + errorCode: input.errorCode + } + }); + } catch (error) { + const message = error instanceof Error ? error.message : "未知错误"; + this.logger.warn(`写入 AI 使用日志失败:${message}`); + } + } +} diff --git a/apps/api/src/ai/ai.types.ts b/apps/api/src/ai/ai.types.ts new file mode 100644 index 0000000..ccb8088 --- /dev/null +++ b/apps/api/src/ai/ai.types.ts @@ -0,0 +1,61 @@ +import { AiChannel } from "../../generated/prisma/client"; + +export type AiResolvedRouteCandidate = { + channel: AiChannel; + source: "binding" | "public_pool"; + sourceId: string | null; + providerName: string; + model: string | null; + configId: string | null; + configName: string | null; + endpoint: string | null; + apiKey: string | null; +}; + +export type AiChatInput = { + userId: string; + message: string; + sessionId: string | null; +}; + +export type AiChatResult = { + channel: AiChannel; + providerName: string; + model: string | null; + content: string; + sessionId: string | null; + usage: AiUsageMetrics | null; + raw: unknown; +}; + +export type AiUsageMetrics = { + promptTokens: number; + completionTokens: number; + totalTokens: number; +}; + +export type AiRouteAttempt = { + channel: AiChannel; + providerName: string | null; + model: string | null; + status: "skipped" | "failed" | "success"; + reasonCode: string | null; + reasonMessage: string | null; +}; + +export class AiRouteFailureError extends Error { + constructor( + public readonly channel: AiChannel, + public readonly providerName: string, + public readonly code: string, + message: string + ) { + super(message); + this.name = "AiRouteFailureError"; + Object.setPrototypeOf(this, new.target.prototype); + } +} + +export interface AiChannelExecutor { + execute(candidate: AiResolvedRouteCandidate, input: AiChatInput): Promise; +} diff --git a/apps/api/src/ai/dto/ai-chat.dto.ts b/apps/api/src/ai/dto/ai-chat.dto.ts new file mode 100644 index 0000000..e2613ae --- /dev/null +++ b/apps/api/src/ai/dto/ai-chat.dto.ts @@ -0,0 +1,60 @@ +import { Type } from "class-transformer"; +import { + IsArray, + IsEnum, + IsInt, + IsOptional, + IsString, + MinLength, + ValidateNested +} from "class-validator"; +import { AiChannel } from "../../../generated/prisma/client"; +import { TaskPriority, TaskStatus } from "../../../generated/prisma/client"; + +export class LocalTaskContextItemDto { + @IsString() + @MinLength(1) + id!: string; + + @IsString() + @MinLength(1) + title!: string; + + @IsEnum(TaskPriority) + priority!: TaskPriority; + + @IsEnum(TaskStatus) + status!: TaskStatus; + + @IsOptional() + @IsInt() + ddlAt?: number | null; + + @IsOptional() + @IsString() + contentText?: string | null; + + @IsInt() + updatedAt!: number; +} + +export class AiChatDto { + @IsString() + @MinLength(1) + message!: string; + + @IsOptional() + @IsString() + @MinLength(1) + sessionId?: string; + + @IsOptional() + @IsEnum(AiChannel) + channel?: AiChannel; + + @IsOptional() + @IsArray() + @ValidateNested({ each: true }) + @Type(() => LocalTaskContextItemDto) + localTasks?: LocalTaskContextItemDto[]; +} diff --git a/apps/api/src/ai/dto/list-ai-usage-logs-query.dto.ts b/apps/api/src/ai/dto/list-ai-usage-logs-query.dto.ts new file mode 100644 index 0000000..49aa2e0 --- /dev/null +++ b/apps/api/src/ai/dto/list-ai-usage-logs-query.dto.ts @@ -0,0 +1,48 @@ +import { Transform, Type } from "class-transformer"; +import { IsBoolean, IsEnum, IsInt, IsOptional, Max, Min } from "class-validator"; +import { AiChannel } from "../../../generated/prisma/client"; + +function normalizeBoolean(value: unknown): boolean | undefined { + if (typeof value === "boolean") { + return value; + } + + if (typeof value !== "string") { + return undefined; + } + + const normalized = value.trim().toLowerCase(); + if (normalized === "true" || normalized === "1") { + return true; + } + + if (normalized === "false" || normalized === "0") { + return false; + } + + return undefined; +} + +export class ListAiUsageLogsQueryDto { + @Type(() => Number) + @IsOptional() + @IsInt() + @Min(1) + page?: number; + + @Type(() => Number) + @IsOptional() + @IsInt() + @Min(1) + @Max(100) + pageSize?: number; + + @IsOptional() + @IsEnum(AiChannel) + channel?: AiChannel; + + @Transform(({ value }) => normalizeBoolean(value)) + @IsOptional() + @IsBoolean() + success?: boolean; +} diff --git a/apps/api/src/ai/dto/upsert-ai-provider-binding.dto.ts b/apps/api/src/ai/dto/upsert-ai-provider-binding.dto.ts new file mode 100644 index 0000000..b821bcc --- /dev/null +++ b/apps/api/src/ai/dto/upsert-ai-provider-binding.dto.ts @@ -0,0 +1,47 @@ +import { AiChannel } from "../../../generated/prisma/client"; +import { IsBoolean, IsEnum, IsOptional, IsString, IsUrl, MinLength } from "class-validator"; + +export class UpsertAiProviderBindingDto { + @IsEnum(AiChannel) + channel!: AiChannel; + + @IsOptional() + @IsString() + @MinLength(1) + providerName?: string; + + @IsOptional() + @IsString() + @MinLength(1) + model?: string; + + @IsOptional() + @IsString() + @MinLength(1) + configId?: string; + + @IsOptional() + @IsString() + @MinLength(1) + configName?: string; + + @IsOptional() + @IsUrl( + { + require_tld: false + }, + { + message: "endpoint \u5fc5\u987b\u662f\u5408\u6cd5\u7684 URL" + } + ) + endpoint?: string; + + @IsOptional() + @IsString() + @MinLength(1) + apiKey?: string; + + @IsOptional() + @IsBoolean() + isEnabled?: boolean; +} diff --git a/apps/api/src/ai/providers/astrbot.provider.ts b/apps/api/src/ai/providers/astrbot.provider.ts new file mode 100644 index 0000000..3d28cbb --- /dev/null +++ b/apps/api/src/ai/providers/astrbot.provider.ts @@ -0,0 +1,284 @@ +import { Injectable } from "@nestjs/common"; +import { + AiChannelExecutor, + AiChatInput, + AiChatResult, + AiResolvedRouteCandidate, + AiRouteFailureError +} from "../ai.types"; + +@Injectable() +export class AstrbotProvider implements AiChannelExecutor { + async execute(candidate: AiResolvedRouteCandidate, input: AiChatInput): Promise { + const routeLabel = + candidate.providerName || candidate.configName || candidate.configId || "astrbot"; + + if (!candidate.endpoint) { + throw new AiRouteFailureError( + candidate.channel, + routeLabel, + "MISSING_ENDPOINT", + "缺少 AstrBot 服务地址配置" + ); + } + + if (!candidate.apiKey) { + throw new AiRouteFailureError( + candidate.channel, + routeLabel, + "MISSING_API_KEY", + "缺少 AstrBot API Key 配置" + ); + } + + const requestUrl = this.buildRequestUrl(candidate.endpoint); + + let response: Response; + try { + response = await fetch(requestUrl, { + method: "POST", + headers: { + "Content-Type": "application/json", + Authorization: `Bearer ${candidate.apiKey}` + }, + body: JSON.stringify({ + username: input.userId, + session_id: input.sessionId ?? undefined, + message: input.message, + enable_streaming: false, + selected_model: candidate.model ?? undefined + }), + signal: AbortSignal.timeout(30000) + }); + } catch (error) { + throw new AiRouteFailureError( + candidate.channel, + routeLabel, + "UPSTREAM_UNREACHABLE", + this.toErrorMessage(error, "AstrBot 服务请求失败") + ); + } + + if (!response.ok) { + const rawText = await response.text(); + throw new AiRouteFailureError( + candidate.channel, + routeLabel, + `UPSTREAM_HTTP_${response.status}`, + this.extractHttpErrorMessage(rawText, response.status) + ); + } + + const events = await this.readSseEvents(response); + let content = ""; + let sessionId = input.sessionId; + + for (const event of events) { + const type = this.readString(event["type"]); + if (type === "session_id") { + sessionId = this.readString(event["session_id"]) ?? sessionId; + continue; + } + + if (type === "error") { + throw new AiRouteFailureError( + candidate.channel, + routeLabel, + this.readString(event["code"]) ?? "ASTRBOT_ERROR", + this.readString(event["data"]) ?? "AstrBot 返回错误" + ); + } + + if (type !== "plain") { + continue; + } + + const chainType = this.readString(event["chain_type"]); + if ( + chainType === "reasoning" || + chainType === "tool_call" || + chainType === "tool_call_result" + ) { + continue; + } + + const data = this.readString(event["data"]); + if (!data) { + continue; + } + + if (event["streaming"] === true) { + content += data; + continue; + } + + content = data; + } + + if (!content.trim()) { + throw new AiRouteFailureError( + candidate.channel, + routeLabel, + "EMPTY_RESPONSE", + "AstrBot 没有返回有效内容" + ); + } + + return { + channel: candidate.channel, + providerName: routeLabel, + model: candidate.model, + content, + sessionId, + usage: this.extractUsage(events), + raw: events + }; + } + + private buildRequestUrl(endpoint: string): string { + const normalizedEndpoint = endpoint.replace(/\/+$/, ""); + if (normalizedEndpoint.endsWith("/api/v1/chat")) { + return normalizedEndpoint; + } + if (normalizedEndpoint.endsWith("/api/v1")) { + return `${normalizedEndpoint}/chat`; + } + if (normalizedEndpoint.endsWith("/api")) { + return `${normalizedEndpoint}/v1/chat`; + } + return `${normalizedEndpoint}/api/v1/chat`; + } + + private parseSseEvents(rawText: string): Array> { + return rawText + .split(/\r?\n\r?\n/) + .map((block) => + block + .split(/\r?\n/) + .filter((line) => line.startsWith("data:")) + .map((line) => line.slice(5).trim()) + .join("\n") + ) + .filter((payload) => payload.length > 0) + .map((payload) => { + try { + return JSON.parse(payload) as Record; + } catch { + return null; + } + }) + .filter((item): item is Record => item !== null); + } + + private async readSseEvents(response: Response): Promise>> { + if (!response.body) { + return this.parseSseEvents(await response.text()); + } + + const reader = response.body.getReader(); + const decoder = new TextDecoder(); + const events: Array> = []; + let buffer = ""; + let reachedEndEvent = false; + + try { + while (!reachedEndEvent) { + const { done, value } = await reader.read(); + if (done) { + break; + } + + buffer += decoder.decode(value, { stream: true }); + const segments = buffer.split(/\r?\n\r?\n/); + buffer = segments.pop() ?? ""; + + for (const segment of segments) { + const parsedEvents = this.parseSseEvents(segment); + for (const event of parsedEvents) { + events.push(event); + if (this.readString(event["type"]) === "end") { + reachedEndEvent = true; + break; + } + } + + if (reachedEndEvent) { + break; + } + } + } + + const tail = `${buffer}${decoder.decode()}`; + if (tail.trim().length > 0) { + events.push(...this.parseSseEvents(tail)); + } + } finally { + await reader.cancel(); + } + + return events; + } + + private extractHttpErrorMessage(rawText: string, statusCode: number): string { + try { + const payload = JSON.parse(rawText) as Record; + if (typeof payload["message"] === "string") { + return payload["message"]; + } + if (typeof payload["data"] === "string") { + return payload["data"]; + } + } catch { + return `AstrBot 服务调用失败,状态码 ${statusCode}`; + } + + return `AstrBot 服务调用失败,状态码 ${statusCode}`; + } + + private readString(value: unknown): string | null { + return typeof value === "string" ? value : null; + } + + private toErrorMessage(error: unknown, fallback: string): string { + if (error instanceof Error && error.message) { + return error.message; + } + + return fallback; + } + + private extractUsage(events: Array>): AiChatResult["usage"] { + for (const event of events) { + if (this.readString(event["type"]) !== "agent_stats") { + continue; + } + + const data = this.asRecord(event["data"]); + const tokenUsage = this.asRecord(data?.["token_usage"]); + if (!tokenUsage) { + continue; + } + + const promptTokens = + (this.readNumber(tokenUsage["input_other"]) ?? 0) + + (this.readNumber(tokenUsage["input_cached"]) ?? 0); + const completionTokens = this.readNumber(tokenUsage["output"]) ?? 0; + + return { + promptTokens, + completionTokens, + totalTokens: promptTokens + completionTokens + }; + } + + return null; + } + + private asRecord(value: unknown): Record | null { + return typeof value === "object" && value !== null ? (value as Record) : null; + } + + private readNumber(value: unknown): number | null { + return typeof value === "number" && Number.isFinite(value) ? value : null; + } +} diff --git a/apps/api/src/ai/providers/openai-compatible.provider.ts b/apps/api/src/ai/providers/openai-compatible.provider.ts new file mode 100644 index 0000000..1ba4eff --- /dev/null +++ b/apps/api/src/ai/providers/openai-compatible.provider.ts @@ -0,0 +1,300 @@ +import { Injectable } from "@nestjs/common"; +import { + AiChannelExecutor, + AiChatInput, + AiChatResult, + AiResolvedRouteCandidate, + AiRouteFailureError +} from "../ai.types"; + +@Injectable() +export class OpenAiCompatibleProvider implements AiChannelExecutor { + async execute(candidate: AiResolvedRouteCandidate, input: AiChatInput): Promise { + if (!candidate.endpoint) { + throw new AiRouteFailureError( + candidate.channel, + candidate.providerName, + "MISSING_ENDPOINT", + "缺少 AI 服务地址配置" + ); + } + + if (!candidate.apiKey) { + throw new AiRouteFailureError( + candidate.channel, + candidate.providerName, + "MISSING_API_KEY", + "缺少 AI 服务密钥配置" + ); + } + + if (!candidate.model) { + throw new AiRouteFailureError( + candidate.channel, + candidate.providerName, + "MISSING_MODEL", + "缺少 AI 模型配置" + ); + } + + const requestUrl = this.buildRequestUrl(candidate.endpoint); + + let response: Response; + try { + response = await fetch(requestUrl, { + method: "POST", + headers: { + "Content-Type": "application/json", + Authorization: `Bearer ${candidate.apiKey}` + }, + body: JSON.stringify({ + model: candidate.model, + messages: [ + { + role: "user", + content: input.message + } + ], + stream: false + }), + signal: AbortSignal.timeout(30000) + }); + } catch (error) { + throw new AiRouteFailureError( + candidate.channel, + candidate.providerName, + "UPSTREAM_UNREACHABLE", + this.toErrorMessage(error, "AI 服务请求失败") + ); + } + + let payload: unknown; + try { + payload = await response.json(); + } catch (error) { + throw new AiRouteFailureError( + candidate.channel, + candidate.providerName, + "INVALID_RESPONSE", + this.toErrorMessage(error, "AI 服务返回了无法解析的数据") + ); + } + + if (!response.ok) { + throw new AiRouteFailureError( + candidate.channel, + candidate.providerName, + `UPSTREAM_HTTP_${response.status}`, + this.extractErrorMessage(payload, `AI 服务调用失败,状态码 ${response.status}`) + ); + } + + const content = this.extractAssistantText(payload); + if (!content.trim()) { + throw new AiRouteFailureError( + candidate.channel, + candidate.providerName, + "EMPTY_RESPONSE", + "AI 服务没有返回有效内容" + ); + } + + return { + channel: candidate.channel, + providerName: candidate.providerName, + model: this.extractModel(payload) ?? candidate.model, + content, + sessionId: input.sessionId, + usage: this.extractUsage(payload), + raw: payload + }; + } + + private buildRequestUrl(endpoint: string): string { + const normalizedEndpoint = endpoint.replace(/\/+$/, ""); + if (normalizedEndpoint.endsWith("/chat/completions")) { + return normalizedEndpoint; + } + if (normalizedEndpoint.endsWith("/v1")) { + return `${normalizedEndpoint}/chat/completions`; + } + return `${normalizedEndpoint}/v1/chat/completions`; + } + + private extractAssistantText(payload: unknown): string { + const chatCompletionText = this.extractChatCompletionText(payload); + if (chatCompletionText) { + return chatCompletionText; + } + + const responsesText = this.extractResponsesApiText(payload); + if (responsesText) { + return responsesText; + } + + return ""; + } + + private extractChatCompletionText(payload: unknown): string { + if (!this.isRecord(payload)) { + return ""; + } + + const choices = payload["choices"]; + if (!Array.isArray(choices) || choices.length === 0) { + return ""; + } + + const firstChoice = choices[0]; + if (!this.isRecord(firstChoice)) { + return ""; + } + + const message = firstChoice["message"]; + if (this.isRecord(message)) { + const messageContent = this.extractMessageContent(message["content"]); + if (messageContent) { + return messageContent; + } + } + + if (typeof firstChoice["text"] === "string") { + return firstChoice["text"]; + } + + return ""; + } + + private extractResponsesApiText(payload: unknown): string { + if (!this.isRecord(payload)) { + return ""; + } + + if (typeof payload["output_text"] === "string") { + return payload["output_text"]; + } + + const output = payload["output"]; + if (!Array.isArray(output)) { + return ""; + } + + return output + .map((item) => { + if (!this.isRecord(item)) { + return ""; + } + + if (typeof item["text"] === "string") { + return item["text"]; + } + + return this.extractMessageContent(item["content"]); + }) + .filter((item) => item.length > 0) + .join("\n") + .trim(); + } + + private extractMessageContent(content: unknown): string { + if (typeof content === "string") { + return content; + } + + if (!Array.isArray(content)) { + return ""; + } + + return content + .map((item) => this.extractContentPartText(item)) + .filter((item) => item.length > 0) + .join("\n") + .trim(); + } + + private extractContentPartText(item: unknown): string { + if (!this.isRecord(item)) { + return ""; + } + + if (typeof item["text"] === "string") { + return item["text"]; + } + + if (this.isRecord(item["text"]) && typeof item["text"]["value"] === "string") { + return item["text"]["value"]; + } + + if (typeof item["content"] === "string") { + return item["content"]; + } + + if (this.isRecord(item["content"]) && typeof item["content"]["text"] === "string") { + return item["content"]["text"]; + } + + return ""; + } + + private extractModel(payload: unknown): string | null { + if (!this.isRecord(payload) || typeof payload["model"] !== "string") { + return null; + } + + return payload["model"]; + } + + private extractUsage(payload: unknown): AiChatResult["usage"] { + if (!this.isRecord(payload)) { + return null; + } + + const usage = payload["usage"]; + if (!this.isRecord(usage)) { + return null; + } + + const promptTokens = this.readNumber(usage["prompt_tokens"]); + const completionTokens = this.readNumber(usage["completion_tokens"]); + const totalTokens = this.readNumber(usage["total_tokens"]); + + if (promptTokens === null && completionTokens === null && totalTokens === null) { + return null; + } + + return { + promptTokens: promptTokens ?? 0, + completionTokens: completionTokens ?? 0, + totalTokens: totalTokens ?? (promptTokens ?? 0) + (completionTokens ?? 0) + }; + } + + private extractErrorMessage(payload: unknown, fallback: string): string { + if (!this.isRecord(payload)) { + return fallback; + } + + const error = payload["error"]; + if (!this.isRecord(error) || typeof error["message"] !== "string") { + return fallback; + } + + return error["message"]; + } + + private isRecord(value: unknown): value is Record { + return typeof value === "object" && value !== null; + } + + private toErrorMessage(error: unknown, fallback: string): string { + if (error instanceof Error && error.message) { + return error.message; + } + + return fallback; + } + + private readNumber(value: unknown): number | null { + return typeof value === "number" && Number.isFinite(value) ? value : null; + } +} diff --git a/apps/api/src/app.module.ts b/apps/api/src/app.module.ts new file mode 100644 index 0000000..aae9297 --- /dev/null +++ b/apps/api/src/app.module.ts @@ -0,0 +1,27 @@ +import { Module } from "@nestjs/common"; +import { ConfigModule } from "@nestjs/config"; +import { resolve } from "node:path"; +import { AiModule } from "./ai/ai.module"; +import { AttachmentModule } from "./attachment/attachment.module"; +import { AuthModule } from "./auth/auth.module"; +import { PrismaModule } from "./prisma/prisma.module"; +import { SecurityModule } from "./security/security.module"; +import { SyncModule } from "./sync/sync.module"; +import { TaskModule } from "./task/task.module"; + +@Module({ + imports: [ + ConfigModule.forRoot({ + isGlobal: true, + envFilePath: [resolve(__dirname, "../.env"), ".env"] + }), + PrismaModule, + SecurityModule, + AuthModule, + TaskModule, + AttachmentModule, + SyncModule, + AiModule + ] +}) +export class AppModule {} diff --git a/apps/api/src/attachment/attachment.controller.ts b/apps/api/src/attachment/attachment.controller.ts new file mode 100644 index 0000000..6baec6e --- /dev/null +++ b/apps/api/src/attachment/attachment.controller.ts @@ -0,0 +1,38 @@ +import { Body, Controller, Headers, Post, UnauthorizedException } from "@nestjs/common"; +import { + AttachmentResponse, + AttachmentService, + PresignAttachmentResponse +} from "./attachment.service"; +import { CompleteAttachmentDto } from "./dto/complete-attachment.dto"; +import { PresignAttachmentDto } from "./dto/presign-attachment.dto"; + +@Controller("attachments") +export class AttachmentController { + constructor(private readonly attachmentService: AttachmentService) {} + + @Post("presign") + async presignAttachment( + @Headers("x-user-id") userIdHeader: string | string[] | undefined, + @Body() body: PresignAttachmentDto + ): Promise { + return this.attachmentService.presignAttachment(this.resolveUserId(userIdHeader), body); + } + + @Post("complete") + async completeAttachment( + @Headers("x-user-id") userIdHeader: string | string[] | undefined, + @Body() body: CompleteAttachmentDto + ): Promise { + return this.attachmentService.completeAttachment(this.resolveUserId(userIdHeader), body); + } + + private resolveUserId(userIdHeader: string | string[] | undefined): string { + const userId = Array.isArray(userIdHeader) ? userIdHeader[0] : userIdHeader; + if (!userId) { + throw new UnauthorizedException("缺少用户上下文"); + } + + return userId; + } +} diff --git a/apps/api/src/attachment/attachment.module.ts b/apps/api/src/attachment/attachment.module.ts new file mode 100644 index 0000000..cd8dfe3 --- /dev/null +++ b/apps/api/src/attachment/attachment.module.ts @@ -0,0 +1,11 @@ +import { Module } from "@nestjs/common"; +import { PrismaModule } from "../prisma/prisma.module"; +import { AttachmentController } from "./attachment.controller"; +import { AttachmentService } from "./attachment.service"; + +@Module({ + imports: [PrismaModule], + controllers: [AttachmentController], + providers: [AttachmentService] +}) +export class AttachmentModule {} diff --git a/apps/api/src/attachment/attachment.service.ts b/apps/api/src/attachment/attachment.service.ts new file mode 100644 index 0000000..cb15f7e --- /dev/null +++ b/apps/api/src/attachment/attachment.service.ts @@ -0,0 +1,335 @@ +import { randomUUID } from "node:crypto"; +import { + Injectable, + InternalServerErrorException, + NotFoundException, + PayloadTooLargeException +} from "@nestjs/common"; +import { ConfigService } from "@nestjs/config"; +import { PutObjectCommand, S3Client } from "@aws-sdk/client-s3"; +import { getSignedUrl } from "@aws-sdk/s3-request-presigner"; +import { AttachmentType } from "../../generated/prisma/client"; +import { PrismaService } from "../prisma/prisma.service"; +import { DataEncryptionService } from "../security/data-encryption.service"; +import { CompleteAttachmentDto } from "./dto/complete-attachment.dto"; +import { PresignAttachmentDto } from "./dto/presign-attachment.dto"; + +type QuotaInfo = { + totalBytes: bigint; + usedBytes: bigint; +}; + +export type PresignAttachmentResponse = { + method: "PUT"; + uploadUrl: string; + bucket: string; + objectKey: string; + objectUrl: string; + expiresInSeconds: number; + quota: { + totalBytes: string; + usedBytes: string; + remainingBytes: string; + }; + headers: Record; +}; + +export type AttachmentResponse = { + id: string; + taskId: string | null; + type: AttachmentType; + url: string; + mimeType: string | null; + fileName: string | null; + fileSize: number; + width: number | null; + height: number | null; + durationMs: number | null; + checksum: string | null; + createdAt: string; + updatedAt: string; +}; + +@Injectable() +export class AttachmentService { + private s3Client: S3Client | null = null; + + constructor( + private readonly configService: ConfigService, + private readonly prismaService: PrismaService, + private readonly dataEncryptionService: DataEncryptionService + ) {} + + async presignAttachment( + userId: string, + body: PresignAttachmentDto + ): Promise { + const quotaInfo = await this.getQuotaSnapshot(userId); + this.assertQuotaAvailable(quotaInfo.totalBytes, quotaInfo.usedBytes, body.fileSize); + + if (body.taskId) { + await this.ensureTaskOwnership(userId, body.taskId); + } + + const bucket = this.getDefaultBucket(); + const objectKey = this.generateObjectKey(body.fileName); + const objectUrl = this.resolveObjectUrl(bucket, objectKey); + const expiresInSeconds = this.getPresignExpiresInSeconds(); + const serverSideEncryption = this.getServerSideEncryptionMode(); + + const command = new PutObjectCommand({ + Bucket: bucket, + Key: objectKey, + ContentType: body.mimeType, + ContentLength: body.fileSize, + ServerSideEncryption: serverSideEncryption + }); + + const uploadUrl = await getSignedUrl(this.getS3Client(), command, { + expiresIn: expiresInSeconds + }); + + return { + method: "PUT", + uploadUrl, + bucket, + objectKey, + objectUrl, + expiresInSeconds, + quota: { + totalBytes: quotaInfo.totalBytes.toString(), + usedBytes: quotaInfo.usedBytes.toString(), + remainingBytes: (quotaInfo.totalBytes - quotaInfo.usedBytes).toString() + }, + headers: this.buildUploadHeaders(body.mimeType, serverSideEncryption) + }; + } + + async completeAttachment( + userId: string, + body: CompleteAttachmentDto + ): Promise { + if (body.taskId) { + await this.ensureTaskOwnership(userId, body.taskId); + } + + const bucket = body.bucket ?? this.getDefaultBucket(); + const objectUrl = this.resolveObjectUrl(bucket, body.objectKey); + + const attachment = await this.prismaService.$transaction(async (tx) => { + const quotaInfo = await this.getQuotaSnapshot(userId, tx); + this.assertQuotaAvailable(quotaInfo.totalBytes, quotaInfo.usedBytes, body.fileSize); + + const uploadBytes = BigInt(body.fileSize); + const maxUsedBeforeUpload = quotaInfo.totalBytes - uploadBytes; + const updatedUser = await tx.user.updateMany({ + where: { + id: userId, + usedStorageBytes: { + lte: maxUsedBeforeUpload + } + }, + data: { + usedStorageBytes: { + increment: uploadBytes + } + } + }); + if (updatedUser.count === 0) { + throw new PayloadTooLargeException("存储配额不足"); + } + + return tx.attachment.create({ + data: { + userId, + taskId: body.taskId ?? null, + type: body.type ?? this.resolveAttachmentType(body.mimeType), + url: this.encryptRequiredString(objectUrl), + mimeType: body.mimeType, + fileName: this.encryptNullableString(body.fileName), + fileSize: body.fileSize, + width: body.width ?? null, + height: body.height ?? null, + durationMs: body.durationMs ?? null, + checksum: this.encryptNullableString(body.checksum) + } + }); + }); + + return { + id: attachment.id, + taskId: attachment.taskId, + type: attachment.type, + url: this.readDecryptedString(attachment.url) ?? objectUrl, + mimeType: attachment.mimeType, + fileName: this.readDecryptedString(attachment.fileName), + fileSize: attachment.fileSize, + width: attachment.width, + height: attachment.height, + durationMs: attachment.durationMs, + checksum: this.readDecryptedString(attachment.checksum), + createdAt: attachment.createdAt.toISOString(), + updatedAt: attachment.updatedAt.toISOString() + }; + } + + private getS3Client(): S3Client { + if (this.s3Client) { + return this.s3Client; + } + + const endpoint = this.configService.get("S3_ENDPOINT") ?? "http://127.0.0.1:9000"; + const region = this.configService.get("S3_REGION") ?? "us-east-1"; + const forcePathStyle = + this.configService.get("S3_FORCE_PATH_STYLE")?.toLowerCase() !== "false"; + + this.s3Client = new S3Client({ + endpoint, + region, + forcePathStyle, + credentials: { + accessKeyId: this.configService.get("S3_ACCESS_KEY_ID") ?? "minioadmin", + secretAccessKey: this.configService.get("S3_SECRET_ACCESS_KEY") ?? "minioadmin" + } + }); + + return this.s3Client; + } + + private getDefaultBucket(): string { + return this.configService.get("S3_BUCKET") ?? "todolist"; + } + + private getPresignExpiresInSeconds(): number { + const configValue = Number(this.configService.get("S3_PRESIGN_EXPIRES_SECONDS") ?? 900); + if (!Number.isFinite(configValue) || configValue <= 0) { + return 900; + } + + return Math.min(configValue, 604800); + } + + private generateObjectKey(fileName: string): string { + const datePrefix = new Date().toISOString().slice(0, 10); + return `attachments/${datePrefix}/${randomUUID()}${this.extractFileExtension(fileName)}`; + } + + private resolveObjectUrl(bucket: string, objectKey: string): string { + const publicBaseUrl = this.configService.get("S3_PUBLIC_BASE_URL"); + if (publicBaseUrl) { + return `${publicBaseUrl.replace(/\/+$/, "")}/${bucket}/${objectKey}`; + } + + const endpoint = this.configService.get("S3_ENDPOINT") ?? "http://127.0.0.1:9000"; + return `${endpoint.replace(/\/+$/, "")}/${bucket}/${objectKey}`; + } + + private resolveAttachmentType(mimeType: string): AttachmentType { + if (mimeType.startsWith("image/")) { + return AttachmentType.IMAGE; + } + + if (mimeType.startsWith("video/")) { + return AttachmentType.VIDEO; + } + + return AttachmentType.FILE; + } + + private buildUploadHeaders( + mimeType: string, + serverSideEncryption: "AES256" | undefined + ): Record { + const headers: Record = { + "Content-Type": mimeType + }; + + if (serverSideEncryption) { + headers["x-amz-server-side-encryption"] = serverSideEncryption; + } + + return headers; + } + + private getServerSideEncryptionMode(): "AES256" | undefined { + const configValue = + this.configService.get("S3_SERVER_SIDE_ENCRYPTION")?.trim().toUpperCase() ?? "AES256"; + + if (configValue === "NONE" || configValue === "DISABLED") { + return undefined; + } + + return "AES256"; + } + + private extractFileExtension(fileName: string): string { + const match = /\.[a-zA-Z0-9]{1,16}$/.exec(fileName); + return match?.[0]?.toLowerCase() ?? ""; + } + + private async ensureTaskOwnership(userId: string, taskId: string): Promise { + const task = await this.prismaService.task.findFirst({ + where: { + id: taskId, + userId + }, + select: { + id: true + } + }); + + if (!task) { + throw new NotFoundException("任务不存在"); + } + } + + private async getQuotaSnapshot( + userId: string, + tx: Pick = this.prismaService + ): Promise { + const user = await tx.user.findUnique({ + where: { + id: userId + }, + select: { + id: true, + defaultStorageQuotaMb: true, + usedStorageBytes: true + } + }); + + if (!user) { + throw new NotFoundException("用户不存在"); + } + + return { + totalBytes: BigInt(user.defaultStorageQuotaMb) * 1024n * 1024n, + usedBytes: user.usedStorageBytes + }; + } + + private assertQuotaAvailable(totalBytes: bigint, usedBytes: bigint, fileSize: number): void { + const uploadBytes = BigInt(fileSize); + if (uploadBytes > totalBytes || usedBytes + uploadBytes > totalBytes) { + throw new PayloadTooLargeException("存储配额不足"); + } + } + + private encryptRequiredString(value: string): string { + const encryptedValue = this.dataEncryptionService.encryptString(value); + if (!encryptedValue) { + throw new InternalServerErrorException("附件元数据加密失败"); + } + + return encryptedValue; + } + + private encryptNullableString(value: string | null | undefined): string | null | undefined { + return this.dataEncryptionService.encryptString(value); + } + + private readDecryptedString(value: string | null): string | null { + const decryptedValue = this.dataEncryptionService.decryptString(value); + return typeof decryptedValue === "string" ? decryptedValue : null; + } +} diff --git a/apps/api/src/attachment/dto/complete-attachment.dto.ts b/apps/api/src/attachment/dto/complete-attachment.dto.ts new file mode 100644 index 0000000..2318e3b --- /dev/null +++ b/apps/api/src/attachment/dto/complete-attachment.dto.ts @@ -0,0 +1,89 @@ +import { Transform, Type } from "class-transformer"; +import { + IsEnum, + IsInt, + IsOptional, + IsString, + Max, + MaxLength, + Min, + MinLength +} from "class-validator"; +import { AttachmentType } from "../../../generated/prisma/client"; + +function normalizeString(value: unknown): unknown { + if (typeof value !== "string") { + return value; + } + + return value.trim(); +} + +export class CompleteAttachmentDto { + @Transform(({ value }) => normalizeString(value)) + @IsString() + @MinLength(1) + @MaxLength(255) + objectKey!: string; + + @Transform(({ value }) => normalizeString(value)) + @IsOptional() + @IsString() + @MaxLength(100) + bucket?: string; + + @Transform(({ value }) => normalizeString(value)) + @IsString() + @MinLength(1) + @MaxLength(255) + fileName!: string; + + @Transform(({ value }) => normalizeString(value)) + @IsString() + @MinLength(1) + @MaxLength(255) + mimeType!: string; + + @Type(() => Number) + @IsInt() + @Min(1) + @Max(1073741824) + fileSize!: number; + + @IsOptional() + @IsEnum(AttachmentType) + type?: AttachmentType; + + @Transform(({ value }) => normalizeString(value)) + @IsOptional() + @IsString() + @MaxLength(255) + taskId?: string; + + @Transform(({ value }) => normalizeString(value)) + @IsOptional() + @IsString() + @MaxLength(128) + checksum?: string; + + @Type(() => Number) + @IsOptional() + @IsInt() + @Min(1) + @Max(100000) + width?: number; + + @Type(() => Number) + @IsOptional() + @IsInt() + @Min(1) + @Max(100000) + height?: number; + + @Type(() => Number) + @IsOptional() + @IsInt() + @Min(1) + @Max(86400000) + durationMs?: number; +} diff --git a/apps/api/src/attachment/dto/presign-attachment.dto.ts b/apps/api/src/attachment/dto/presign-attachment.dto.ts new file mode 100644 index 0000000..9eda1a9 --- /dev/null +++ b/apps/api/src/attachment/dto/presign-attachment.dto.ts @@ -0,0 +1,35 @@ +import { Transform } from "class-transformer"; +import { IsInt, IsOptional, IsString, Max, MaxLength, Min, MinLength } from "class-validator"; + +function normalizeString(value: unknown): unknown { + if (typeof value !== "string") { + return value; + } + + return value.trim(); +} + +export class PresignAttachmentDto { + @Transform(({ value }) => normalizeString(value)) + @IsString() + @MinLength(1) + @MaxLength(255) + fileName!: string; + + @Transform(({ value }) => normalizeString(value)) + @IsString() + @MinLength(1) + @MaxLength(255) + mimeType!: string; + + @IsInt() + @Min(1) + @Max(1073741824) + fileSize!: number; + + @Transform(({ value }) => normalizeString(value)) + @IsOptional() + @IsString() + @MaxLength(255) + taskId?: string; +} diff --git a/apps/api/src/auth/auth-mail.service.ts b/apps/api/src/auth/auth-mail.service.ts new file mode 100644 index 0000000..b0944bb --- /dev/null +++ b/apps/api/src/auth/auth-mail.service.ts @@ -0,0 +1,131 @@ +import { + Injectable, + InternalServerErrorException, + Logger, + ServiceUnavailableException +} from "@nestjs/common"; +import { ConfigService } from "@nestjs/config"; +import { createTransport, type Transporter } from "nodemailer"; + +type MailRuntimeConfig = { + host: string; + port: number; + secure: boolean; + user: string; + pass: string; + fromName: string; + fromAddress: string; +}; + +@Injectable() +export class AuthMailService { + private readonly logger = new Logger(AuthMailService.name); + private cachedConfig: MailRuntimeConfig | null = null; + private transporter: Transporter | null = null; + + constructor(private readonly configService: ConfigService) {} + + async sendLoginCode(email: string, code: string, ttlSeconds: number): Promise { + const config = this.getRuntimeConfig(); + const transporter = this.getTransporter(config); + + try { + await transporter.sendMail({ + from: this.resolveFromField(config), + to: email, + subject: "TodoList 登录验证码", + text: `你的验证码是 ${code},${ttlSeconds} 秒内有效。`, + html: `

你的验证码是 ${code},${ttlSeconds} 秒内有效。

` + }); + } catch (error) { + this.logger.error( + `验证码邮件发送失败: ${email}`, + error instanceof Error ? error.stack : undefined + ); + throw new ServiceUnavailableException("验证码邮件发送失败,请稍后重试"); + } + } + + private getTransporter(config: MailRuntimeConfig): Transporter { + if (this.transporter) { + return this.transporter; + } + + this.transporter = createTransport({ + host: config.host, + port: config.port, + secure: config.secure, + auth: { + user: config.user, + pass: config.pass + } + }); + + return this.transporter; + } + + private getRuntimeConfig(): MailRuntimeConfig { + if (this.cachedConfig) { + return this.cachedConfig; + } + + const host = this.getRequiredString("MAIL_SMTP_HOST"); + const port = this.getRequiredNumber("MAIL_SMTP_PORT"); + const secure = this.getBoolean("MAIL_SMTP_SECURE", port === 465); + const user = this.getRequiredString("MAIL_SMTP_USER"); + const pass = this.getRequiredString("MAIL_SMTP_PASS"); + const fromName = this.configService.get("MAIL_FROM_NAME")?.trim() || "TodoList"; + const fromAddress = this.configService.get("MAIL_FROM_ADDRESS")?.trim() || user; + + const config: MailRuntimeConfig = { + host, + port, + secure, + user, + pass, + fromName, + fromAddress + }; + + this.cachedConfig = config; + return config; + } + + private getRequiredString(key: string): string { + const value = this.configService.get(key)?.trim(); + if (!value) { + throw new InternalServerErrorException(`邮件配置缺失: ${key}`); + } + + return value; + } + + private getRequiredNumber(key: string): number { + const rawValue = this.configService.get(key)?.trim(); + if (!rawValue) { + throw new InternalServerErrorException(`邮件配置缺失: ${key}`); + } + + const parsedValue = Number(rawValue); + if (!Number.isFinite(parsedValue)) { + throw new InternalServerErrorException(`邮件配置格式错误: ${key}`); + } + + return parsedValue; + } + + private getBoolean(key: string, fallback: boolean): boolean { + const rawValue = this.configService.get(key); + if (!rawValue) { + return fallback; + } + + const normalizedValue = rawValue.trim().toLowerCase(); + return normalizedValue === "true" || normalizedValue === "1"; + } + + private resolveFromField(config: MailRuntimeConfig): string { + const sanitizedName = config.fromName.replace(/"/g, ""); + return `"${sanitizedName}" <${config.fromAddress}>`; + } +} diff --git a/apps/api/src/auth/auth.controller.ts b/apps/api/src/auth/auth.controller.ts new file mode 100644 index 0000000..707399a --- /dev/null +++ b/apps/api/src/auth/auth.controller.ts @@ -0,0 +1,120 @@ +import { Body, Controller, Get, Post, Req, UseGuards } from "@nestjs/common"; +import { AuthGuard } from "@nestjs/passport"; +import { AuthService } from "./auth.service"; +import { EmailLoginDto } from "./dto/email-login.dto"; +import { RefreshTokenDto } from "./dto/refresh-token.dto"; +import { SendEmailCodeDto } from "./dto/send-email-code.dto"; +import { TwoFactorEnrollDto } from "./dto/two-factor-enroll.dto"; +import { TwoFactorVerifyDto } from "./dto/two-factor-verify.dto"; + +@Controller("auth") +export class AuthController { + constructor(private readonly authService: AuthService) {} + + @Post("email/send-code") + async sendEmailCode( + @Body() body: SendEmailCodeDto + ): Promise<{ success: boolean; expiresInSeconds: number }> { + return this.authService.sendEmailCode(body.email); + } + + @Post("email/login") + async loginWithEmailCode(@Body() body: EmailLoginDto): Promise<{ + accessToken: string; + tokenType: "Bearer"; + expiresInSeconds: number; + refreshToken: string; + refreshExpiresInSeconds: number; + user: { id: string; email: string }; + }> { + return this.authService.loginWithEmailCode(body.email, body.code); + } + + @Post("token/refresh") + async refreshTokens(@Body() body: RefreshTokenDto): Promise<{ + accessToken: string; + tokenType: "Bearer"; + expiresInSeconds: number; + refreshToken: string; + refreshExpiresInSeconds: number; + user: { id: string; email: string }; + }> { + return this.authService.refreshTokens(body.refreshToken); + } + + @Post("token/revoke") + async revokeRefreshToken(@Body() body: RefreshTokenDto): Promise<{ success: boolean }> { + return this.authService.revokeRefreshToken(body.refreshToken); + } + + @Post("2fa/enroll") + async enrollTwoFactor(@Body() body: TwoFactorEnrollDto): Promise<{ + userId: string; + secret: string; + otpauthUrl: string; + enabled: boolean; + }> { + return this.authService.enrollTwoFactor(body.email); + } + + @Post("2fa/verify") + async verifyTwoFactor( + @Body() body: TwoFactorVerifyDto + ): Promise<{ success: boolean; enabled: boolean }> { + return this.authService.verifyTwoFactor(body.email, body.token); + } + + @Get("oauth/github") + @UseGuards(AuthGuard("github")) + githubLogin(): void {} + + @Get("oauth/github/callback") + @UseGuards(AuthGuard("github")) + githubCallback(@Req() req: { user: unknown }): { + success: boolean; + provider: "github"; + profile: unknown; + } { + return { + success: true, + provider: "github", + profile: req.user + }; + } + + @Get("oauth/qq") + @UseGuards(AuthGuard("qq")) + qqLogin(): void {} + + @Get("oauth/qq/callback") + @UseGuards(AuthGuard("qq")) + qqCallback(@Req() req: { user: unknown }): { + success: boolean; + provider: "qq"; + profile: unknown; + } { + return { + success: true, + provider: "qq", + profile: req.user + }; + } + + @Get("oauth/wechat") + @UseGuards(AuthGuard("wechat")) + wechatLogin(): void {} + + @Get("oauth/wechat/callback") + @UseGuards(AuthGuard("wechat")) + wechatCallback(@Req() req: { user: unknown }): { + success: boolean; + provider: "wechat"; + profile: unknown; + } { + return { + success: true, + provider: "wechat", + profile: req.user + }; + } +} diff --git a/apps/api/src/auth/auth.module.ts b/apps/api/src/auth/auth.module.ts new file mode 100644 index 0000000..ede59b7 --- /dev/null +++ b/apps/api/src/auth/auth.module.ts @@ -0,0 +1,33 @@ +import { Module } from "@nestjs/common"; +import { ConfigModule, ConfigService } from "@nestjs/config"; +import { JwtModule } from "@nestjs/jwt"; +import { PassportModule } from "@nestjs/passport"; +import { AuthController } from "./auth.controller"; +import { AuthMailService } from "./auth-mail.service"; +import { AuthService } from "./auth.service"; +import { GithubStrategy } from "./strategies/github.strategy"; +import { QqStrategy } from "./strategies/qq.strategy"; +import { WechatStrategy } from "./strategies/wechat.strategy"; + +@Module({ + imports: [ + ConfigModule, + PassportModule.register({ session: false }), + JwtModule.registerAsync({ + inject: [ConfigService], + useFactory: (configService: ConfigService) => { + const expiresInSeconds = Number(configService.get("AUTH_ACCESS_EXPIRES_IN_SECONDS") ?? 900); + + return { + secret: configService.get("AUTH_ACCESS_SECRET") ?? "dev-access-secret", + signOptions: { + expiresIn: expiresInSeconds + } + }; + } + }) + ], + controllers: [AuthController], + providers: [AuthService, AuthMailService, GithubStrategy, QqStrategy, WechatStrategy] +}) +export class AuthModule {} diff --git a/apps/api/src/auth/auth.service.ts b/apps/api/src/auth/auth.service.ts new file mode 100644 index 0000000..8554a5c --- /dev/null +++ b/apps/api/src/auth/auth.service.ts @@ -0,0 +1,288 @@ +import { Injectable, UnauthorizedException } from "@nestjs/common"; +import { ConfigService } from "@nestjs/config"; +import { JwtService } from "@nestjs/jwt"; +import { randomUUID } from "node:crypto"; +import { authenticator } from "@otplib/preset-default"; +import { AuthMailService } from "./auth-mail.service"; +import { PrismaService } from "../prisma/prisma.service"; +import { DataEncryptionService } from "../security/data-encryption.service"; + +type EmailCodeEntry = { + code: string; + expiresAt: number; +}; + +type AuthUser = { + id: string; + email: string; +}; + +type AuthTokenResult = { + accessToken: string; + tokenType: "Bearer"; + expiresInSeconds: number; + refreshToken: string; + refreshExpiresInSeconds: number; + user: AuthUser; +}; + +@Injectable() +export class AuthService { + private readonly emailCodeStore = new Map(); + + constructor( + private readonly configService: ConfigService, + private readonly jwtService: JwtService, + private readonly authMailService: AuthMailService, + private readonly prismaService: PrismaService, + private readonly dataEncryptionService: DataEncryptionService + ) {} + + async sendEmailCode(email: string): Promise<{ success: boolean; expiresInSeconds: number }> { + const ttlSeconds = Number(this.configService.get("AUTH_EMAIL_CODE_TTL_SECONDS") ?? 300); + const code = this.generateCode(); + const expiresAt = Date.now() + ttlSeconds * 1000; + const normalizedEmail = email.toLowerCase(); + + await this.authMailService.sendLoginCode(normalizedEmail, code, ttlSeconds); + this.emailCodeStore.set(normalizedEmail, { code, expiresAt }); + + return { + success: true, + expiresInSeconds: ttlSeconds + }; + } + + async loginWithEmailCode(email: string, code: string): Promise { + const lowerEmail = email.toLowerCase(); + const codeEntry = this.emailCodeStore.get(lowerEmail); + + if (!codeEntry) { + throw new UnauthorizedException("验证码不存在或已失效"); + } + + if (codeEntry.expiresAt < Date.now()) { + this.emailCodeStore.delete(lowerEmail); + throw new UnauthorizedException("验证码已过期"); + } + + if (codeEntry.code !== code) { + throw new UnauthorizedException("验证码错误"); + } + + this.emailCodeStore.delete(lowerEmail); + + const user = await this.getOrCreateUser(lowerEmail); + return this.issueTokens(user); + } + + async refreshTokens(refreshToken: string): Promise { + const entry = await this.prismaService.refreshToken.findUnique({ + where: { + tokenHash: refreshToken + }, + include: { + user: { + select: { + id: true, + email: true + } + } + } + }); + + if (!entry) { + throw new UnauthorizedException("刷新令牌不存在"); + } + + if (entry.revokedAt) { + throw new UnauthorizedException("刷新令牌已注销"); + } + + if (entry.expiresAt.getTime() < Date.now()) { + await this.prismaService.refreshToken.update({ + where: { + id: entry.id + }, + data: { + revokedAt: new Date() + } + }); + throw new UnauthorizedException("刷新令牌已过期"); + } + + await this.prismaService.refreshToken.update({ + where: { + id: entry.id + }, + data: { + revokedAt: new Date() + } + }); + + return this.issueTokens({ + id: entry.user.id, + email: this.readRequiredEmail(entry.user.email) + }); + } + + async revokeRefreshToken(refreshToken: string): Promise<{ success: boolean }> { + await this.prismaService.refreshToken.updateMany({ + where: { + tokenHash: refreshToken, + revokedAt: null + }, + data: { + revokedAt: new Date() + } + }); + + return { success: true }; + } + + async enrollTwoFactor( + email: string + ): Promise<{ userId: string; secret: string; otpauthUrl: string; enabled: boolean }> { + const user = await this.getOrCreateUser(email.toLowerCase()); + const secret = authenticator.generateSecret(); + const issuer = this.configService.get("AUTH_TOTP_ISSUER") ?? "TodoList"; + const otpauthUrl = authenticator.keyuri(user.email, issuer, secret); + + await this.prismaService.userSecurity.upsert({ + where: { + userId: user.id + }, + update: { + twoFactorSecret: secret, + twoFactorEnabled: false + }, + create: { + userId: user.id, + twoFactorSecret: secret, + twoFactorEnabled: false + } + }); + + return { + userId: user.id, + secret, + otpauthUrl, + enabled: false + }; + } + + async verifyTwoFactor( + email: string, + token: string + ): Promise<{ success: boolean; enabled: boolean }> { + const user = await this.getOrCreateUser(email.toLowerCase()); + const security = await this.prismaService.userSecurity.findUnique({ + where: { + userId: user.id + }, + select: { + twoFactorSecret: true + } + }); + + if (!security?.twoFactorSecret) { + throw new UnauthorizedException("尚未启用两步验证"); + } + + const valid = authenticator.check(token, security.twoFactorSecret); + if (!valid) { + throw new UnauthorizedException("两步验证码错误"); + } + + await this.prismaService.userSecurity.update({ + where: { + userId: user.id + }, + data: { + twoFactorEnabled: true + } + }); + + return { + success: true, + enabled: true + }; + } + + private async getOrCreateUser(email: string): Promise { + const normalizedEmail = email.toLowerCase(); + const emailHash = this.dataEncryptionService.createLookupHash("user.email", normalizedEmail); + const user = await this.prismaService.user.upsert({ + where: { + emailHash + }, + update: {}, + create: { + email: this.encryptRequiredString(normalizedEmail), + emailHash + }, + select: { + id: true, + email: true + } + }); + + return { + id: user.id, + email: this.readRequiredEmail(user.email) + }; + } + + private generateCode(): string { + return String(Math.floor(100000 + Math.random() * 900000)); + } + + private async issueTokens(user: AuthUser): Promise { + const accessExpiresInSeconds = Number( + this.configService.get("AUTH_ACCESS_EXPIRES_IN_SECONDS") ?? 900 + ); + const refreshExpiresInSeconds = Number( + this.configService.get("AUTH_REFRESH_EXPIRES_IN_SECONDS") ?? 2592000 + ); + const accessToken = await this.jwtService.signAsync({ + sub: user.id, + email: user.email + }); + const refreshToken = `${randomUUID()}${randomUUID()}`; + + await this.prismaService.refreshToken.create({ + data: { + userId: user.id, + tokenHash: refreshToken, + expiresAt: new Date(Date.now() + refreshExpiresInSeconds * 1000) + } + }); + + return { + accessToken, + tokenType: "Bearer", + expiresInSeconds: accessExpiresInSeconds, + refreshToken, + refreshExpiresInSeconds, + user + }; + } + + private encryptRequiredString(value: string): string { + const encryptedValue = this.dataEncryptionService.encryptString(value); + if (!encryptedValue) { + throw new UnauthorizedException("用户敏感字段加密失败"); + } + + return encryptedValue; + } + + private readRequiredEmail(value: string): string { + const decryptedValue = this.dataEncryptionService.decryptString(value); + if (typeof decryptedValue !== "string" || decryptedValue.length === 0) { + throw new UnauthorizedException("用户邮箱解密失败"); + } + + return decryptedValue; + } +} diff --git a/apps/api/src/auth/dto/email-login.dto.ts b/apps/api/src/auth/dto/email-login.dto.ts new file mode 100644 index 0000000..32b1511 --- /dev/null +++ b/apps/api/src/auth/dto/email-login.dto.ts @@ -0,0 +1,11 @@ +import { IsEmail, IsString, Length, Matches } from "class-validator"; + +export class EmailLoginDto { + @IsEmail() + email!: string; + + @IsString() + @Length(6, 6) + @Matches(/^\d{6}$/) + code!: string; +} diff --git a/apps/api/src/auth/dto/refresh-token.dto.ts b/apps/api/src/auth/dto/refresh-token.dto.ts new file mode 100644 index 0000000..2d91b3c --- /dev/null +++ b/apps/api/src/auth/dto/refresh-token.dto.ts @@ -0,0 +1,7 @@ +import { IsString, MinLength } from "class-validator"; + +export class RefreshTokenDto { + @IsString() + @MinLength(20) + refreshToken!: string; +} diff --git a/apps/api/src/auth/dto/send-email-code.dto.ts b/apps/api/src/auth/dto/send-email-code.dto.ts new file mode 100644 index 0000000..998b618 --- /dev/null +++ b/apps/api/src/auth/dto/send-email-code.dto.ts @@ -0,0 +1,6 @@ +import { IsEmail } from "class-validator"; + +export class SendEmailCodeDto { + @IsEmail() + email!: string; +} diff --git a/apps/api/src/auth/dto/two-factor-enroll.dto.ts b/apps/api/src/auth/dto/two-factor-enroll.dto.ts new file mode 100644 index 0000000..2dea4ef --- /dev/null +++ b/apps/api/src/auth/dto/two-factor-enroll.dto.ts @@ -0,0 +1,6 @@ +import { IsEmail } from "class-validator"; + +export class TwoFactorEnrollDto { + @IsEmail() + email!: string; +} diff --git a/apps/api/src/auth/dto/two-factor-verify.dto.ts b/apps/api/src/auth/dto/two-factor-verify.dto.ts new file mode 100644 index 0000000..f9317bd --- /dev/null +++ b/apps/api/src/auth/dto/two-factor-verify.dto.ts @@ -0,0 +1,11 @@ +import { IsEmail, IsString, Length, Matches } from "class-validator"; + +export class TwoFactorVerifyDto { + @IsEmail() + email!: string; + + @IsString() + @Length(6, 6) + @Matches(/^\d{6}$/) + token!: string; +} diff --git a/apps/api/src/auth/strategies/github.strategy.ts b/apps/api/src/auth/strategies/github.strategy.ts new file mode 100644 index 0000000..3b3133d --- /dev/null +++ b/apps/api/src/auth/strategies/github.strategy.ts @@ -0,0 +1,32 @@ +import { Injectable } from "@nestjs/common"; +import { ConfigService } from "@nestjs/config"; +import { PassportStrategy } from "@nestjs/passport"; +import { Profile, Strategy } from "passport-github2"; + +@Injectable() +export class GithubStrategy extends PassportStrategy(Strategy, "github") { + constructor(configService: ConfigService) { + super({ + clientID: configService.get("OAUTH_GITHUB_CLIENT_ID") ?? "github-client-id", + clientSecret: + configService.get("OAUTH_GITHUB_CLIENT_SECRET") ?? "github-client-secret", + callbackURL: + configService.get("OAUTH_GITHUB_CALLBACK_URL") ?? + "http://localhost:3000/auth/oauth/github/callback", + scope: ["user:email"] + }); + } + + async validate( + accessToken: string, + refreshToken: string, + profile: Profile + ): Promise<{ provider: "github"; accessToken: string; refreshToken: string; profile: Profile }> { + return { + provider: "github", + accessToken, + refreshToken, + profile + }; + } +} diff --git a/apps/api/src/auth/strategies/qq.strategy.ts b/apps/api/src/auth/strategies/qq.strategy.ts new file mode 100644 index 0000000..5191151 --- /dev/null +++ b/apps/api/src/auth/strategies/qq.strategy.ts @@ -0,0 +1,33 @@ +import { Injectable } from "@nestjs/common"; +import { ConfigService } from "@nestjs/config"; +import { PassportStrategy } from "@nestjs/passport"; +import { Strategy } from "passport-oauth2"; + +@Injectable() +export class QqStrategy extends PassportStrategy(Strategy, "qq") { + constructor(configService: ConfigService) { + super({ + authorizationURL: + configService.get("OAUTH_QQ_AUTH_URL") ?? "https://graph.qq.com/oauth2.0/authorize", + tokenURL: + configService.get("OAUTH_QQ_TOKEN_URL") ?? "https://graph.qq.com/oauth2.0/token", + clientID: configService.get("OAUTH_QQ_CLIENT_ID") ?? "qq-client-id", + clientSecret: configService.get("OAUTH_QQ_CLIENT_SECRET") ?? "qq-client-secret", + callbackURL: + configService.get("OAUTH_QQ_CALLBACK_URL") ?? + "http://localhost:3000/auth/oauth/qq/callback", + scope: ["get_user_info"] + }); + } + + async validate( + accessToken: string, + refreshToken: string + ): Promise<{ provider: "qq"; accessToken: string; refreshToken: string }> { + return { + provider: "qq", + accessToken, + refreshToken + }; + } +} diff --git a/apps/api/src/auth/strategies/wechat.strategy.ts b/apps/api/src/auth/strategies/wechat.strategy.ts new file mode 100644 index 0000000..1e4343b --- /dev/null +++ b/apps/api/src/auth/strategies/wechat.strategy.ts @@ -0,0 +1,36 @@ +import { Injectable } from "@nestjs/common"; +import { ConfigService } from "@nestjs/config"; +import { PassportStrategy } from "@nestjs/passport"; +import { Strategy } from "passport-oauth2"; + +@Injectable() +export class WechatStrategy extends PassportStrategy(Strategy, "wechat") { + constructor(configService: ConfigService) { + super({ + authorizationURL: + configService.get("OAUTH_WECHAT_AUTH_URL") ?? + "https://open.weixin.qq.com/connect/qrconnect", + tokenURL: + configService.get("OAUTH_WECHAT_TOKEN_URL") ?? + "https://api.weixin.qq.com/sns/oauth2/access_token", + clientID: configService.get("OAUTH_WECHAT_CLIENT_ID") ?? "wechat-client-id", + clientSecret: + configService.get("OAUTH_WECHAT_CLIENT_SECRET") ?? "wechat-client-secret", + callbackURL: + configService.get("OAUTH_WECHAT_CALLBACK_URL") ?? + "http://localhost:3000/auth/oauth/wechat/callback", + scope: ["snsapi_login"] + }); + } + + async validate( + accessToken: string, + refreshToken: string + ): Promise<{ provider: "wechat"; accessToken: string; refreshToken: string }> { + return { + provider: "wechat", + accessToken, + refreshToken + }; + } +} diff --git a/apps/api/src/main.ts b/apps/api/src/main.ts new file mode 100644 index 0000000..6a4f0ca --- /dev/null +++ b/apps/api/src/main.ts @@ -0,0 +1,31 @@ +import "reflect-metadata"; +import { ValidationPipe } from "@nestjs/common"; +import { NestFactory } from "@nestjs/core"; +import type { NestExpressApplication } from "@nestjs/platform-express"; +import { AppModule } from "./app.module"; + +async function bootstrap(): Promise { + const app = await NestFactory.create(AppModule); + const bodyLimit = process.env.API_BODY_LIMIT ?? "8mb"; + + app.useBodyParser("json", { limit: bodyLimit }); + app.useBodyParser("urlencoded", { + extended: true, + limit: bodyLimit + }); + app.enableCors({ + origin: true, + credentials: true + }); + app.useGlobalPipes( + new ValidationPipe({ + transform: true, + whitelist: true, + forbidNonWhitelisted: true + }) + ); + + await app.listen(3000); +} + +void bootstrap(); diff --git a/apps/api/src/prisma/prisma.module.ts b/apps/api/src/prisma/prisma.module.ts new file mode 100644 index 0000000..7a94e73 --- /dev/null +++ b/apps/api/src/prisma/prisma.module.ts @@ -0,0 +1,9 @@ +import { Global, Module } from "@nestjs/common"; +import { PrismaService } from "./prisma.service"; + +@Global() +@Module({ + providers: [PrismaService], + exports: [PrismaService] +}) +export class PrismaModule {} diff --git a/apps/api/src/prisma/prisma.service.ts b/apps/api/src/prisma/prisma.service.ts new file mode 100644 index 0000000..27e6013 --- /dev/null +++ b/apps/api/src/prisma/prisma.service.ts @@ -0,0 +1,28 @@ +import { Injectable, OnModuleDestroy, OnModuleInit } from "@nestjs/common"; +import { ConfigService } from "@nestjs/config"; +import { PrismaPg } from "@prisma/adapter-pg"; +import { PrismaClient } from "../../generated/prisma/client"; + +@Injectable() +export class PrismaService extends PrismaClient implements OnModuleInit, OnModuleDestroy { + constructor(configService: ConfigService) { + const connectionString = configService.get("DATABASE_URL"); + if (!connectionString) { + throw new Error("缺少数据库连接配置 DATABASE_URL"); + } + + super({ + adapter: new PrismaPg({ + connectionString + }) + }); + } + + async onModuleInit(): Promise { + await this.$connect(); + } + + async onModuleDestroy(): Promise { + await this.$disconnect(); + } +} diff --git a/apps/api/src/security/data-encryption.service.ts b/apps/api/src/security/data-encryption.service.ts new file mode 100644 index 0000000..ece7ceb --- /dev/null +++ b/apps/api/src/security/data-encryption.service.ts @@ -0,0 +1,155 @@ +import { Injectable, InternalServerErrorException } from "@nestjs/common"; +import { ConfigService } from "@nestjs/config"; +import { Prisma } from "../../generated/prisma/client"; +import { createCipheriv, createDecipheriv, createHash, createHmac, randomBytes } from "node:crypto"; + +const ENCRYPTION_PREFIX = "encv1"; +const ENCRYPTION_ALGORITHM = "aes-256-gcm"; +const ENCRYPTION_IV_LENGTH = 12; + +@Injectable() +export class DataEncryptionService { + constructor(private readonly configService: ConfigService) {} + + isConfigured(): boolean { + return Boolean(this.configService.get("DATA_ENCRYPTION_SECRET")); + } + + isEncryptedString(value: string): boolean { + return value.startsWith(`${ENCRYPTION_PREFIX}:`); + } + + encryptString(value: string | null | undefined): string | null | undefined { + if (value === undefined) { + return undefined; + } + + if (value === null) { + return null; + } + + const key = this.resolveKey(); + const iv = randomBytes(ENCRYPTION_IV_LENGTH); + const cipher = createCipheriv(ENCRYPTION_ALGORITHM, key, iv); + const encrypted = Buffer.concat([cipher.update(value, "utf8"), cipher.final()]); + const authTag = cipher.getAuthTag(); + + return [ + ENCRYPTION_PREFIX, + iv.toString("base64url"), + authTag.toString("base64url"), + encrypted.toString("base64url") + ].join(":"); + } + + decryptString(value: string | null | undefined): string | null | undefined { + if (value === undefined) { + return undefined; + } + + if (value === null || !this.isEncryptedPayload(value)) { + return value; + } + + const [prefix, ivText, authTagText, encryptedText] = value.split(":"); + if (prefix !== ENCRYPTION_PREFIX || !ivText || !authTagText || encryptedText === undefined) { + throw new InternalServerErrorException("加密数据格式无效"); + } + + try { + const key = this.resolveKey(); + const decipher = createDecipheriv( + ENCRYPTION_ALGORITHM, + key, + Buffer.from(ivText, "base64url") + ); + decipher.setAuthTag(Buffer.from(authTagText, "base64url")); + const decrypted = Buffer.concat([ + decipher.update(Buffer.from(encryptedText, "base64url")), + decipher.final() + ]); + + return decrypted.toString("utf8"); + } catch { + throw new InternalServerErrorException("加密数据解密失败"); + } + } + + encryptJson( + value: Prisma.InputJsonValue | null | undefined + ): Prisma.InputJsonValue | null | undefined { + if (value === undefined) { + return undefined; + } + + if (value === null) { + return null; + } + + return this.encryptString(JSON.stringify(value)); + } + + decryptJson(value: Prisma.JsonValue | null): Prisma.JsonValue | null { + if (value === null) { + return null; + } + + if (typeof value !== "string" || !this.isEncryptedPayload(value)) { + return value; + } + + const decrypted = this.decryptString(value); + if (typeof decrypted !== "string") { + throw new InternalServerErrorException("加密数据解密失败"); + } + + try { + return JSON.parse(decrypted) as Prisma.JsonValue; + } catch { + throw new InternalServerErrorException("加密 JSON 数据损坏"); + } + } + + decryptPayload(value: Prisma.JsonValue | null): string | null { + if (value === null) { + return null; + } + + if (typeof value === "string") { + return this.decryptString(value) ?? null; + } + + return JSON.stringify(value); + } + + createLookupHash(scope: string, value: string): string { + const normalizedScope = scope.trim().toLowerCase(); + if (!normalizedScope) { + throw new InternalServerErrorException("缺少盲索引作用域"); + } + + const secret = this.configService.get("DATA_ENCRYPTION_SECRET"); + if (!secret) { + throw new InternalServerErrorException("服务端未配置 DATA_ENCRYPTION_SECRET,无法生成盲索引"); + } + + return createHmac("sha256", `lookup:${normalizedScope}:${secret}`) + .update(value, "utf8") + .digest("hex"); + } + + private isEncryptedPayload(value: string): boolean { + return this.isEncryptedString(value); + } + + private resolveKey(): Buffer { + const secret = this.configService.get("DATA_ENCRYPTION_SECRET"); + if (!secret) { + throw new InternalServerErrorException( + "服务端未配置 DATA_ENCRYPTION_SECRET,无法写入加密数据" + ); + } + + return createHash("sha256").update(secret, "utf8").digest(); + } +} diff --git a/apps/api/src/security/security.module.ts b/apps/api/src/security/security.module.ts new file mode 100644 index 0000000..8373141 --- /dev/null +++ b/apps/api/src/security/security.module.ts @@ -0,0 +1,9 @@ +import { Global, Module } from "@nestjs/common"; +import { DataEncryptionService } from "./data-encryption.service"; + +@Global() +@Module({ + providers: [DataEncryptionService], + exports: [DataEncryptionService] +}) +export class SecurityModule {} diff --git a/apps/api/src/sync/dto/sync-pull.dto.ts b/apps/api/src/sync/dto/sync-pull.dto.ts new file mode 100644 index 0000000..c5c99fb --- /dev/null +++ b/apps/api/src/sync/dto/sync-pull.dto.ts @@ -0,0 +1,16 @@ +import { Type } from "class-transformer"; +import { IsInt, IsOptional, IsString, Max, MaxLength, Min } from "class-validator"; + +export class SyncPullQueryDto { + @IsOptional() + @IsString() + @MaxLength(512) + cursor?: string; + + @Type(() => Number) + @IsOptional() + @IsInt() + @Min(1) + @Max(200) + limit?: number; +} diff --git a/apps/api/src/sync/dto/sync-push.dto.ts b/apps/api/src/sync/dto/sync-push.dto.ts new file mode 100644 index 0000000..2e43da4 --- /dev/null +++ b/apps/api/src/sync/dto/sync-push.dto.ts @@ -0,0 +1,62 @@ +import { Type } from "class-transformer"; +import { + ArrayMaxSize, + ArrayMinSize, + IsArray, + IsEnum, + IsInt, + IsOptional, + IsString, + MaxLength, + Min, + ValidateNested +} from "class-validator"; + +export enum SyncEntityTypeDto { + TASK = "TASK" +} + +export enum SyncActionTypeDto { + CREATE = "CREATE", + UPDATE = "UPDATE", + DELETE = "DELETE" +} + +export class SyncPushOperationDto { + @IsString() + @MaxLength(64) + opId!: string; + + @IsString() + @MaxLength(64) + entityId!: string; + + @IsEnum(SyncEntityTypeDto) + entityType!: SyncEntityTypeDto; + + @IsEnum(SyncActionTypeDto) + action!: SyncActionTypeDto; + + @IsOptional() + @IsString() + @MaxLength(5000000) + payload?: string; + + @Type(() => Number) + @IsInt() + @Min(0) + clientTs!: number; + + @IsString() + @MaxLength(128) + deviceId!: string; +} + +export class SyncPushDto { + @IsArray() + @ArrayMinSize(1) + @ArrayMaxSize(200) + @ValidateNested({ each: true }) + @Type(() => SyncPushOperationDto) + operations!: SyncPushOperationDto[]; +} diff --git a/apps/api/src/sync/sync.controller.ts b/apps/api/src/sync/sync.controller.ts new file mode 100644 index 0000000..72ccfac --- /dev/null +++ b/apps/api/src/sync/sync.controller.ts @@ -0,0 +1,34 @@ +import { Body, Controller, Get, Headers, Post, Query, UnauthorizedException } from "@nestjs/common"; +import { SyncPullQueryDto } from "./dto/sync-pull.dto"; +import { SyncPushDto } from "./dto/sync-push.dto"; +import { SyncPullResponse, SyncPushResponse, SyncService } from "./sync.service"; + +@Controller("sync") +export class SyncController { + constructor(private readonly syncService: SyncService) {} + + @Get("pull") + async pullOperations( + @Headers("x-user-id") userIdHeader: string | string[] | undefined, + @Query() query: SyncPullQueryDto + ): Promise { + return this.syncService.pullOperations(this.resolveUserId(userIdHeader), query); + } + + @Post("push") + async pushOperations( + @Headers("x-user-id") userIdHeader: string | string[] | undefined, + @Body() body: SyncPushDto + ): Promise { + return this.syncService.pushOperations(this.resolveUserId(userIdHeader), body); + } + + private resolveUserId(userIdHeader: string | string[] | undefined): string { + const userId = Array.isArray(userIdHeader) ? userIdHeader[0] : userIdHeader; + if (!userId) { + throw new UnauthorizedException("缺少用户上下文"); + } + + return userId; + } +} diff --git a/apps/api/src/sync/sync.module.ts b/apps/api/src/sync/sync.module.ts new file mode 100644 index 0000000..65f1492 --- /dev/null +++ b/apps/api/src/sync/sync.module.ts @@ -0,0 +1,11 @@ +import { Module } from "@nestjs/common"; +import { PrismaModule } from "../prisma/prisma.module"; +import { SyncController } from "./sync.controller"; +import { SyncService } from "./sync.service"; + +@Module({ + imports: [PrismaModule], + controllers: [SyncController], + providers: [SyncService] +}) +export class SyncModule {} diff --git a/apps/api/src/sync/sync.service.ts b/apps/api/src/sync/sync.service.ts new file mode 100644 index 0000000..cfd0f49 --- /dev/null +++ b/apps/api/src/sync/sync.service.ts @@ -0,0 +1,309 @@ +import { BadRequestException, Injectable } from "@nestjs/common"; +import { Prisma } from "../../generated/prisma/client"; +import { PrismaService } from "../prisma/prisma.service"; +import { DataEncryptionService } from "../security/data-encryption.service"; +import { SyncPullQueryDto } from "./dto/sync-pull.dto"; +import { SyncPushDto, SyncPushOperationDto } from "./dto/sync-push.dto"; + +export type SyncPushItemStatus = "accepted" | "duplicate" | "failed"; + +export type SyncPushItemResult = { + opId: string; + status: SyncPushItemStatus; + serverTs: string | null; + reason: string | null; +}; + +export type SyncPushResponse = { + acceptedCount: number; + duplicateCount: number; + failedCount: number; + results: SyncPushItemResult[]; +}; + +type ExistingOperationRecord = { + opId: string; + serverTs: Date; +}; + +type SyncPullCursorState = { + serverTs: string; + opId: string; +}; + +type SyncPullOperationRecord = { + opId: string; + entityId: string; + entityType: string; + action: string; + payload: Prisma.JsonValue | null; + clientTs: Date; + deviceId: string; + serverTs: Date; +}; + +export type SyncPullItem = { + opId: string; + entityId: string; + entityType: string; + action: string; + payload: string | null; + clientTs: number; + deviceId: string; + serverTs: string; +}; + +export type SyncPullResponse = { + items: SyncPullItem[]; + nextCursor: string | null; + hasMore: boolean; +}; + +@Injectable() +export class SyncService { + constructor( + private readonly prismaService: PrismaService, + private readonly dataEncryptionService: DataEncryptionService + ) {} + + async pullOperations(userId: string, query: SyncPullQueryDto): Promise { + const limit = query.limit ?? 100; + const cursor = this.parseCursor(query.cursor); + + const operations = (await this.prismaService.syncOperation.findMany({ + where: this.buildPullWhereInput(userId, cursor), + orderBy: [{ serverTs: "asc" }, { opId: "asc" }], + take: limit + 1, + select: { + opId: true, + entityId: true, + entityType: true, + action: true, + payload: true, + clientTs: true, + deviceId: true, + serverTs: true + } + })) as SyncPullOperationRecord[]; + + const hasMore = operations.length > limit; + const pageItems = hasMore ? operations.slice(0, limit) : operations; + const lastOperation = pageItems.at(-1); + + return { + items: pageItems.map((operation) => this.serializePullItem(operation)), + nextCursor: lastOperation + ? this.encodeCursor({ + serverTs: lastOperation.serverTs.toISOString(), + opId: lastOperation.opId + }) + : (query.cursor ?? null), + hasMore + }; + } + + async pushOperations(userId: string, body: SyncPushDto): Promise { + const existingOperations = await this.loadExistingOperations(userId, body.operations); + const results: SyncPushItemResult[] = []; + const seenOperationIds = new Set(); + const acceptedOperationServerTs = new Map(); + + for (const operation of body.operations) { + if (seenOperationIds.has(operation.opId)) { + results.push({ + opId: operation.opId, + status: "duplicate", + serverTs: acceptedOperationServerTs.get(operation.opId) ?? null, + reason: "same_batch_duplicate" + }); + continue; + } + + seenOperationIds.add(operation.opId); + + const existingOperation = existingOperations.get(operation.opId); + if (existingOperation) { + results.push({ + opId: operation.opId, + status: "duplicate", + serverTs: existingOperation.serverTs.toISOString(), + reason: "already_synced" + }); + continue; + } + + try { + const createdOperation = await this.prismaService.syncOperation.create({ + data: { + opId: operation.opId, + userId, + deviceId: operation.deviceId, + entityType: operation.entityType, + entityId: operation.entityId, + action: operation.action, + payload: this.dataEncryptionService.encryptString(operation.payload) ?? undefined, + clientTs: new Date(operation.clientTs) + }, + select: { + opId: true, + serverTs: true + } + }); + + const serverTs = createdOperation.serverTs.toISOString(); + acceptedOperationServerTs.set(createdOperation.opId, serverTs); + results.push({ + opId: createdOperation.opId, + status: "accepted", + serverTs, + reason: null + }); + } catch (error) { + if (this.isDuplicateOpIdError(error)) { + results.push({ + opId: operation.opId, + status: "duplicate", + serverTs: null, + reason: "already_synced" + }); + continue; + } + + results.push({ + opId: operation.opId, + status: "failed", + serverTs: null, + reason: "persist_failed" + }); + } + } + + return { + acceptedCount: results.filter((item) => item.status === "accepted").length, + duplicateCount: results.filter((item) => item.status === "duplicate").length, + failedCount: results.filter((item) => item.status === "failed").length, + results + }; + } + + private async loadExistingOperations( + userId: string, + operations: SyncPushOperationDto[] + ): Promise> { + const opIds = Array.from(new Set(operations.map((operation) => operation.opId))); + + const existingOperations = (await this.prismaService.syncOperation.findMany({ + where: { + userId, + opId: { + in: opIds + } + }, + select: { + opId: true, + serverTs: true + } + })) as ExistingOperationRecord[]; + + return new Map( + existingOperations.map((operation): [string, ExistingOperationRecord] => [ + operation.opId, + operation + ]) + ); + } + + private buildPullWhereInput( + userId: string, + cursor: SyncPullCursorState | null + ): Prisma.SyncOperationWhereInput { + if (!cursor) { + return { userId }; + } + + const cursorDate = new Date(cursor.serverTs); + + return { + userId, + // 同一毫秒内可能有多条操作,必须使用 opId 作为二级游标来保证稳定分页。 + OR: [ + { + serverTs: { + gt: cursorDate + } + }, + { + serverTs: cursorDate, + opId: { + gt: cursor.opId + } + } + ] + }; + } + + private serializePullItem(operation: SyncPullOperationRecord): SyncPullItem { + return { + opId: operation.opId, + entityId: operation.entityId, + entityType: operation.entityType, + action: operation.action, + payload: this.serializePayload(operation.payload), + clientTs: operation.clientTs.getTime(), + deviceId: operation.deviceId, + serverTs: operation.serverTs.toISOString() + }; + } + + private serializePayload(payload: Prisma.JsonValue | null): string | null { + return this.dataEncryptionService.decryptPayload(payload); + } + + private parseCursor(cursor: string | undefined): SyncPullCursorState | null { + if (!cursor) { + return null; + } + + let decodedCursor: unknown; + try { + decodedCursor = JSON.parse(Buffer.from(cursor, "base64url").toString("utf8")); + } catch { + throw new BadRequestException("Invalid sync cursor"); + } + + if (typeof decodedCursor !== "object" || decodedCursor === null) { + throw new BadRequestException("Invalid sync cursor"); + } + + const cursorRecord = decodedCursor as { + serverTs?: unknown; + opId?: unknown; + }; + + if ( + typeof cursorRecord.serverTs !== "string" || + typeof cursorRecord.opId !== "string" || + Number.isNaN(Date.parse(cursorRecord.serverTs)) || + cursorRecord.opId.trim().length === 0 + ) { + throw new BadRequestException("Invalid sync cursor"); + } + + return { + serverTs: cursorRecord.serverTs, + opId: cursorRecord.opId + }; + } + + private encodeCursor(cursor: SyncPullCursorState): string { + return Buffer.from(JSON.stringify(cursor), "utf8").toString("base64url"); + } + + private isDuplicateOpIdError(error: unknown): boolean { + if (!(error instanceof Prisma.PrismaClientKnownRequestError)) { + return false; + } + + return error.code === "P2002"; + } +} diff --git a/apps/api/src/task/dto/create-task.dto.ts b/apps/api/src/task/dto/create-task.dto.ts new file mode 100644 index 0000000..69c2000 --- /dev/null +++ b/apps/api/src/task/dto/create-task.dto.ts @@ -0,0 +1,64 @@ +import { Transform } from "class-transformer"; +import { + IsArray, + IsDateString, + IsEnum, + IsObject, + IsOptional, + IsString, + MaxLength, + MinLength +} from "class-validator"; +import { TaskPriority, TaskStatus } from "../../../generated/prisma/client"; + +function normalizeString(value: unknown): unknown { + if (typeof value !== "string") { + return value; + } + + return value.trim(); +} + +export class CreateTaskDto { + @Transform(({ value }) => normalizeString(value)) + @IsString() + @MinLength(1) + @MaxLength(120) + title!: string; + + @IsOptional() + @IsObject() + contentJson?: Record; + + @Transform(({ value }) => normalizeString(value)) + @IsOptional() + @IsString() + @MaxLength(20000) + contentText?: string; + + @IsOptional() + @IsEnum(TaskPriority) + priority?: TaskPriority; + + @IsOptional() + @IsEnum(TaskStatus) + status?: TaskStatus; + + @IsOptional() + @IsDateString() + ddl?: string; + + @Transform(({ value }) => { + if (!Array.isArray(value)) { + return value; + } + + return value.map((item) => normalizeString(item)); + }) + @IsOptional() + @IsArray() + @IsString({ each: true }) + @MinLength(1, { each: true }) + @MaxLength(30, { each: true }) + tagNames?: string[]; +} diff --git a/apps/api/src/task/dto/list-tasks-query.dto.ts b/apps/api/src/task/dto/list-tasks-query.dto.ts new file mode 100644 index 0000000..baa5afb --- /dev/null +++ b/apps/api/src/task/dto/list-tasks-query.dto.ts @@ -0,0 +1,92 @@ +import { Transform, Type } from "class-transformer"; +import { IsArray, IsEnum, IsInt, IsOptional, IsString, Max, MaxLength, Min } from "class-validator"; +import { TaskPriority, TaskStatus } from "../../../generated/prisma/client"; + +export enum TaskSortBy { + CREATED_AT = "createdAt", + UPDATED_AT = "updatedAt", + DDL = "ddl" +} + +export enum TaskSortOrder { + ASC = "asc", + DESC = "desc" +} + +function normalizeString(value: unknown): string | undefined { + if (typeof value !== "string") { + return undefined; + } + + const normalized = value.trim(); + if (!normalized) { + return undefined; + } + + return normalized; +} + +export class ListTasksQueryDto { + @IsOptional() + @IsEnum(TaskStatus) + status?: TaskStatus; + + @IsOptional() + @IsEnum(TaskPriority) + priority?: TaskPriority; + + @Transform(({ value }) => { + if (value === undefined || value === null || value === "") { + return undefined; + } + + if (Array.isArray(value)) { + const normalized = value + .map((item) => normalizeString(item)) + .filter((item): item is string => item !== undefined); + return normalized.length > 0 ? normalized : undefined; + } + + if (typeof value === "string") { + const normalized = value + .split(",") + .map((item) => normalizeString(item)) + .filter((item): item is string => item !== undefined); + return normalized.length > 0 ? normalized : undefined; + } + + return undefined; + }) + @IsOptional() + @IsArray() + @IsString({ each: true }) + @MaxLength(30, { each: true }) + tags?: string[]; + + @Transform(({ value }) => normalizeString(value)) + @IsOptional() + @IsString() + @MaxLength(120) + keyword?: string; + + @Type(() => Number) + @IsOptional() + @IsInt() + @Min(1) + page?: number; + + @Type(() => Number) + @IsOptional() + @IsInt() + @Min(1) + @Max(100) + pageSize?: number; + + @IsOptional() + @IsEnum(TaskSortBy) + sortBy?: TaskSortBy; + + @IsOptional() + @IsEnum(TaskSortOrder) + sortOrder?: TaskSortOrder; +} diff --git a/apps/api/src/task/dto/update-task.dto.ts b/apps/api/src/task/dto/update-task.dto.ts new file mode 100644 index 0000000..23b8894 --- /dev/null +++ b/apps/api/src/task/dto/update-task.dto.ts @@ -0,0 +1,65 @@ +import { Transform } from "class-transformer"; +import { + IsArray, + IsDateString, + IsEnum, + IsObject, + IsOptional, + IsString, + MaxLength, + MinLength +} from "class-validator"; +import { TaskPriority, TaskStatus } from "../../../generated/prisma/client"; + +function normalizeString(value: unknown): unknown { + if (typeof value !== "string") { + return value; + } + + return value.trim(); +} + +export class UpdateTaskDto { + @Transform(({ value }) => normalizeString(value)) + @IsOptional() + @IsString() + @MinLength(1) + @MaxLength(120) + title?: string; + + @IsOptional() + @IsObject() + contentJson?: Record; + + @Transform(({ value }) => normalizeString(value)) + @IsOptional() + @IsString() + @MaxLength(20000) + contentText?: string; + + @IsOptional() + @IsEnum(TaskPriority) + priority?: TaskPriority; + + @IsOptional() + @IsEnum(TaskStatus) + status?: TaskStatus; + + @IsOptional() + @IsDateString() + ddl?: string; + + @Transform(({ value }) => { + if (!Array.isArray(value)) { + return value; + } + + return value.map((item) => normalizeString(item)); + }) + @IsOptional() + @IsArray() + @IsString({ each: true }) + @MinLength(1, { each: true }) + @MaxLength(30, { each: true }) + tagNames?: string[]; +} diff --git a/apps/api/src/task/task.controller.ts b/apps/api/src/task/task.controller.ts new file mode 100644 index 0000000..33b9710 --- /dev/null +++ b/apps/api/src/task/task.controller.ts @@ -0,0 +1,71 @@ +import { + Body, + Controller, + Delete, + Get, + Headers, + Param, + Patch, + Post, + Query, + UnauthorizedException +} from "@nestjs/common"; +import { CreateTaskDto } from "./dto/create-task.dto"; +import { ListTasksQueryDto } from "./dto/list-tasks-query.dto"; +import { UpdateTaskDto } from "./dto/update-task.dto"; +import { ListTasksResponse, TaskResponse, TaskService } from "./task.service"; + +@Controller("tasks") +export class TaskController { + constructor(private readonly taskService: TaskService) {} + + @Get() + async listTasks( + @Headers("x-user-id") userIdHeader: string | string[] | undefined, + @Query() query: ListTasksQueryDto + ): Promise { + return this.taskService.listTasks(this.resolveUserId(userIdHeader), query); + } + + @Get(":taskId") + async getTaskById( + @Headers("x-user-id") userIdHeader: string | string[] | undefined, + @Param("taskId") taskId: string + ): Promise { + return this.taskService.getTaskById(this.resolveUserId(userIdHeader), taskId); + } + + @Post() + async createTask( + @Headers("x-user-id") userIdHeader: string | string[] | undefined, + @Body() body: CreateTaskDto + ): Promise { + return this.taskService.createTask(this.resolveUserId(userIdHeader), body); + } + + @Patch(":taskId") + async updateTask( + @Headers("x-user-id") userIdHeader: string | string[] | undefined, + @Param("taskId") taskId: string, + @Body() body: UpdateTaskDto + ): Promise { + return this.taskService.updateTask(this.resolveUserId(userIdHeader), taskId, body); + } + + @Delete(":taskId") + async deleteTask( + @Headers("x-user-id") userIdHeader: string | string[] | undefined, + @Param("taskId") taskId: string + ): Promise<{ success: boolean }> { + return this.taskService.deleteTask(this.resolveUserId(userIdHeader), taskId); + } + + private resolveUserId(userIdHeader: string | string[] | undefined): string { + const userId = Array.isArray(userIdHeader) ? userIdHeader[0] : userIdHeader; + if (!userId) { + throw new UnauthorizedException("缺少用户上下文"); + } + + return userId; + } +} diff --git a/apps/api/src/task/task.module.ts b/apps/api/src/task/task.module.ts new file mode 100644 index 0000000..8226093 --- /dev/null +++ b/apps/api/src/task/task.module.ts @@ -0,0 +1,11 @@ +import { Module } from "@nestjs/common"; +import { PrismaModule } from "../prisma/prisma.module"; +import { TaskController } from "./task.controller"; +import { TaskService } from "./task.service"; + +@Module({ + imports: [PrismaModule], + controllers: [TaskController], + providers: [TaskService] +}) +export class TaskModule {} diff --git a/apps/api/src/task/task.service.ts b/apps/api/src/task/task.service.ts new file mode 100644 index 0000000..deb1f76 --- /dev/null +++ b/apps/api/src/task/task.service.ts @@ -0,0 +1,458 @@ +import { Injectable, InternalServerErrorException, NotFoundException } from "@nestjs/common"; +import { Prisma, TaskPriority, TaskStatus } from "../../generated/prisma/client"; +import { PrismaService } from "../prisma/prisma.service"; +import { DataEncryptionService } from "../security/data-encryption.service"; +import { CreateTaskDto } from "./dto/create-task.dto"; +import { ListTasksQueryDto, TaskSortBy, TaskSortOrder } from "./dto/list-tasks-query.dto"; +import { UpdateTaskDto } from "./dto/update-task.dto"; + +type TaskEntity = Prisma.TaskGetPayload<{ + include: { + taskTags: { + include: { + tag: { + select: { + name: true; + }; + }; + }; + }; + }; +}>; + +export type TaskResponse = { + id: string; + title: string; + contentJson: unknown | null; + contentText: string | null; + priority: TaskPriority; + status: TaskStatus; + ddl: string | null; + completedAt: string | null; + version: number; + tags: string[]; + createdAt: string; + updatedAt: string; +}; + +export type ListTasksResponse = { + items: TaskResponse[]; + page: number; + pageSize: number; + total: number; +}; + +@Injectable() +export class TaskService { + constructor( + private readonly prismaService: PrismaService, + private readonly dataEncryptionService: DataEncryptionService + ) {} + + async listTasks(userId: string, query: ListTasksQueryDto): Promise { + const page = query.page ?? 1; + const pageSize = query.pageSize ?? 20; + const skip = (page - 1) * pageSize; + const keyword = query.keyword?.trim() ?? ""; + + const where = this.buildWhereInput(userId, query, keyword.length === 0); + const orderBy = this.buildOrderByInput(query); + + if (keyword.length > 0) { + const items = await this.prismaService.task.findMany({ + where, + orderBy, + include: { + taskTags: { + include: { + tag: { + select: { + name: true + } + } + } + } + } + }); + + const serializedItems = items.map((item: TaskEntity) => this.serializeTask(item)); + const filteredItems = serializedItems.filter((item) => this.matchesKeyword(item, keyword)); + + return { + items: filteredItems.slice(skip, skip + pageSize), + page, + pageSize, + total: filteredItems.length + }; + } + + const [items, total] = await Promise.all([ + this.prismaService.task.findMany({ + where, + orderBy, + skip, + take: pageSize, + include: { + taskTags: { + include: { + tag: { + select: { + name: true + } + } + } + } + } + }), + this.prismaService.task.count({ where }) + ]); + + return { + items: items.map((item: TaskEntity) => this.serializeTask(item)), + page, + pageSize, + total + }; + } + + async getTaskById(userId: string, taskId: string): Promise { + const task = await this.prismaService.task.findFirst({ + where: { + id: taskId, + userId + }, + include: { + taskTags: { + include: { + tag: { + select: { + name: true + } + } + } + } + } + }); + + if (!task) { + throw new NotFoundException("任务不存在"); + } + + return this.serializeTask(task); + } + + async createTask(userId: string, body: CreateTaskDto): Promise { + const tagNames = this.normalizeTagNames(body.tagNames); + const nextStatus = body.status ?? TaskStatus.TODO; + const contentJson = + body.contentJson !== undefined + ? ((this.dataEncryptionService.encryptJson(body.contentJson as Prisma.InputJsonValue) ?? + Prisma.JsonNull) as Prisma.InputJsonValue | Prisma.NullableJsonNullValueInput) + : undefined; + + const task = await this.prismaService.$transaction(async (tx) => { + const createdTask = await tx.task.create({ + data: { + userId, + title: this.encryptRequiredString(body.title), + contentJson, + contentText: this.encryptNullableString(body.contentText), + priority: body.priority ?? TaskPriority.MEDIUM, + status: nextStatus, + ddl: body.ddl ? new Date(body.ddl) : null, + completedAt: nextStatus === TaskStatus.DONE ? new Date() : null + } + }); + + await this.replaceTaskTags(tx, userId, createdTask.id, tagNames); + + return tx.task.findUniqueOrThrow({ + where: { id: createdTask.id }, + include: { + taskTags: { + include: { + tag: { + select: { + name: true + } + } + } + } + } + }); + }); + + return this.serializeTask(task); + } + + async updateTask(userId: string, taskId: string, body: UpdateTaskDto): Promise { + const currentTask = await this.prismaService.task.findFirst({ + where: { + id: taskId, + userId + }, + select: { + id: true, + status: true + } + }); + + if (!currentTask) { + throw new NotFoundException("任务不存在"); + } + + const data: Prisma.TaskUpdateInput = { + version: { + increment: 1 + } + }; + + if (body.title !== undefined) { + data.title = this.encryptRequiredString(body.title); + } + if (body.contentJson !== undefined) { + data.contentJson = (this.dataEncryptionService.encryptJson( + body.contentJson as Prisma.InputJsonValue + ) ?? Prisma.JsonNull) as Prisma.InputJsonValue | Prisma.NullableJsonNullValueInput; + } + if (body.contentText !== undefined) { + data.contentText = this.encryptNullableString(body.contentText); + } + if (body.priority !== undefined) { + data.priority = body.priority; + } + if (body.status !== undefined) { + data.status = body.status; + if (body.status === TaskStatus.DONE && currentTask.status !== TaskStatus.DONE) { + data.completedAt = new Date(); + } else if (body.status !== TaskStatus.DONE) { + data.completedAt = null; + } + } + if (body.ddl !== undefined) { + data.ddl = body.ddl ? new Date(body.ddl) : null; + } + + const shouldReplaceTags = body.tagNames !== undefined; + const nextTagNames = this.normalizeTagNames(body.tagNames); + + const task = await this.prismaService.$transaction(async (tx) => { + await tx.task.update({ + where: { id: taskId }, + data + }); + + if (shouldReplaceTags) { + await this.replaceTaskTags(tx, userId, taskId, nextTagNames); + } + + return tx.task.findUniqueOrThrow({ + where: { id: taskId }, + include: { + taskTags: { + include: { + tag: { + select: { + name: true + } + } + } + } + } + }); + }); + + return this.serializeTask(task); + } + + async deleteTask(userId: string, taskId: string): Promise<{ success: boolean }> { + const deleted = await this.prismaService.task.deleteMany({ + where: { + id: taskId, + userId + } + }); + + if (deleted.count === 0) { + throw new NotFoundException("任务不存在"); + } + + return { success: true }; + } + + private buildWhereInput( + userId: string, + query: ListTasksQueryDto, + includeKeyword: boolean + ): Prisma.TaskWhereInput { + const where: Prisma.TaskWhereInput = { + userId + }; + + if (query.status !== undefined) { + where.status = query.status; + } + + if (query.priority !== undefined) { + where.priority = query.priority; + } + + if (query.tags !== undefined && query.tags.length > 0) { + where.taskTags = { + some: { + tag: { + name: { + in: query.tags + } + } + } + }; + } + + if (includeKeyword && query.keyword !== undefined && query.keyword.length > 0) { + where.OR = [ + { + title: { + contains: query.keyword, + mode: "insensitive" + } + }, + { + contentText: { + contains: query.keyword, + mode: "insensitive" + } + } + ]; + } + + return where; + } + + private buildOrderByInput(query: ListTasksQueryDto): Prisma.TaskOrderByWithRelationInput { + const order: Prisma.SortOrder = + query.sortOrder === TaskSortOrder.ASC ? Prisma.SortOrder.asc : Prisma.SortOrder.desc; + + if (query.sortBy === TaskSortBy.CREATED_AT) { + return { createdAt: order }; + } + + if (query.sortBy === TaskSortBy.DDL) { + return { ddl: order }; + } + + return { updatedAt: order }; + } + + private normalizeTagNames(tagNames: string[] | undefined): string[] { + if (!tagNames) { + return []; + } + + const result: string[] = []; + const uniqueNames = new Set(); + + for (const rawTagName of tagNames) { + const normalized = rawTagName.trim(); + if (!normalized) { + continue; + } + + const uniqueKey = normalized.toLocaleLowerCase(); + if (uniqueNames.has(uniqueKey)) { + continue; + } + + uniqueNames.add(uniqueKey); + result.push(normalized); + } + + return result; + } + + private async replaceTaskTags( + tx: Prisma.TransactionClient, + userId: string, + taskId: string, + tagNames: string[] + ): Promise { + await tx.taskTag.deleteMany({ + where: { + taskId + } + }); + + if (tagNames.length === 0) { + return; + } + + const tags = await Promise.all( + tagNames.map((name) => + tx.tag.upsert({ + where: { + userId_name: { + userId, + name + } + }, + update: {}, + create: { + userId, + name + } + }) + ) + ); + + await tx.taskTag.createMany({ + data: tags.map((tag: { id: string }) => ({ + taskId, + tagId: tag.id + })), + skipDuplicates: true + }); + } + + private serializeTask(task: TaskEntity): TaskResponse { + return { + id: task.id, + title: this.readDecryptedString(task.title) ?? "未命名任务", + contentJson: this.dataEncryptionService.decryptJson(task.contentJson), + contentText: this.readDecryptedString(task.contentText), + priority: task.priority, + status: task.status, + ddl: task.ddl?.toISOString() ?? null, + completedAt: task.completedAt?.toISOString() ?? null, + version: task.version, + tags: task.taskTags.map((taskTag: { tag: { name: string } }) => taskTag.tag.name), + createdAt: task.createdAt.toISOString(), + updatedAt: task.updatedAt.toISOString() + }; + } + + private encryptRequiredString(value: string): string { + const encryptedValue = this.dataEncryptionService.encryptString(value); + if (!encryptedValue) { + throw new InternalServerErrorException("任务字段加密失败"); + } + + return encryptedValue; + } + + private encryptNullableString(value: string | null | undefined): string | null | undefined { + return this.dataEncryptionService.encryptString(value); + } + + private readDecryptedString(value: string | null): string | null { + const decryptedValue = this.dataEncryptionService.decryptString(value); + return typeof decryptedValue === "string" ? decryptedValue : null; + } + + private matchesKeyword(task: TaskResponse, keyword: string): boolean { + const lowerKeyword = keyword.toLocaleLowerCase(); + return ( + task.title.toLocaleLowerCase().includes(lowerKeyword) || + task.contentText?.toLocaleLowerCase().includes(lowerKeyword) === true + ); + } +} diff --git a/apps/api/test/ai.spec.ts b/apps/api/test/ai.spec.ts new file mode 100644 index 0000000..7dc70e1 --- /dev/null +++ b/apps/api/test/ai.spec.ts @@ -0,0 +1,1250 @@ +import request from "supertest"; +import { INestApplication, ValidationPipe } from "@nestjs/common"; +import { ConfigService } from "@nestjs/config"; +import { Test, TestingModule } from "@nestjs/testing"; +import { + AiChannel, + AiUsageLog, + AiProviderBinding, + AiPublicPoolConfig, + TaskPriority, + TaskStatus +} from "../generated/prisma/client"; +import { AiController } from "../src/ai/ai.controller"; +import { AiProviderRegistryService } from "../src/ai/ai-provider-registry.service"; +import { AiRateLimitService } from "../src/ai/ai-rate-limit.service"; +import { AiService } from "../src/ai/ai.service"; +import { + AiChatInput, + AiChannelExecutor, + AiResolvedRouteCandidate, + AiRouteFailureError +} from "../src/ai/ai.types"; +import { PrismaService } from "../src/prisma/prisma.service"; +import { DataEncryptionService } from "../src/security/data-encryption.service"; + +type AiUsageLogRecord = { + id: string; + userId: string | null; + channel: AiChannel; + providerName: string | null; + model: string | null; + promptTokens: number; + completionTokens: number; + totalTokens: number; + latencyMs: number | null; + success: boolean; + errorCode: string | null; + createdAt: Date; +}; + +type AiTaskRecord = { + id: string; + userId: string; + title: string; + priority: TaskPriority; + status: TaskStatus; + ddl: Date | null; + contentText: string | null; + updatedAt: Date; +}; + +class InMemoryAiPrismaService { + private bindingIdSequence = 1; + private publicPoolIdSequence = 1; + private usageLogIdSequence = 1; + private bindings: AiProviderBinding[] = []; + private publicPools: AiPublicPoolConfig[] = []; + private usageLogs: AiUsageLogRecord[] = []; + private tasks: AiTaskRecord[] = []; + + readonly aiProviderBinding = { + findMany: async (args: { + where: { + userId: string; + }; + }) => { + return this.bindings + .filter((binding) => binding.userId === args.where.userId) + .sort((left, right) => right.updatedAt.getTime() - left.updatedAt.getTime()); + }, + + findFirst: async (args: { + where: { + id?: string; + userId?: string; + channel?: AiChannel; + isEnabled?: boolean; + }; + }) => { + return ( + this.bindings + .filter((binding) => { + if (args.where.id !== undefined && binding.id !== args.where.id) { + return false; + } + if (args.where.userId !== undefined && binding.userId !== args.where.userId) { + return false; + } + if (args.where.channel !== undefined && binding.channel !== args.where.channel) { + return false; + } + if (args.where.isEnabled !== undefined && binding.isEnabled !== args.where.isEnabled) { + return false; + } + return true; + }) + .sort((left, right) => { + if (left.isDefault !== right.isDefault) { + return Number(right.isDefault) - Number(left.isDefault); + } + return right.updatedAt.getTime() - left.updatedAt.getTime(); + })[0] ?? null + ); + }, + + create: async (args: { + data: { + userId: string; + channel: AiChannel; + providerName: string; + model: string | null; + configId: string | null; + configName: string | null; + endpoint: string | null; + encryptedApiKey: string | null; + isDefault: boolean; + isEnabled: boolean; + }; + }) => { + const now = new Date(); + const binding: AiProviderBinding = { + id: `binding_${this.bindingIdSequence++}`, + userId: args.data.userId, + channel: args.data.channel, + providerName: args.data.providerName, + model: args.data.model, + configId: args.data.configId, + configName: args.data.configName, + encryptedApiKey: args.data.encryptedApiKey, + endpoint: args.data.endpoint, + isDefault: args.data.isDefault, + isEnabled: args.data.isEnabled, + createdAt: now, + updatedAt: now + }; + + this.bindings.push(binding); + return binding; + }, + + update: async (args: { + where: { + id: string; + }; + data: Partial; + }) => { + const binding = this.bindings.find((item) => item.id === args.where.id); + if (!binding) { + throw new Error("binding not found"); + } + + Object.assign(binding, args.data, { updatedAt: new Date() }); + return binding; + }, + + updateMany: async (args: { + where: { + userId?: string; + channel?: AiChannel; + id?: { + not: string; + }; + }; + data: { + isDefault?: boolean; + }; + }) => { + let count = 0; + for (const binding of this.bindings) { + if (args.where.userId !== undefined && binding.userId !== args.where.userId) { + continue; + } + if (args.where.channel !== undefined && binding.channel !== args.where.channel) { + continue; + } + if (args.where.id?.not !== undefined && binding.id === args.where.id.not) { + continue; + } + + if (args.data.isDefault !== undefined) { + binding.isDefault = args.data.isDefault; + binding.updatedAt = new Date(); + } + count += 1; + } + + return { count }; + } + }; + + readonly aiPublicPoolConfig = { + findFirst: async (args?: { + where?: { + enabled?: boolean; + }; + }) => { + const items = this.publicPools + .filter((item) => + args?.where?.enabled === undefined ? true : item.enabled === args.where.enabled + ) + .sort((left, right) => right.updatedAt.getTime() - left.updatedAt.getTime()); + + return items[0] ?? null; + } + }; + + readonly aiUsageLog = { + create: async (args: { data: Omit }) => { + const usageLog: AiUsageLogRecord = { + id: `usage_log_${this.usageLogIdSequence++}`, + createdAt: new Date(), + ...args.data + }; + + this.usageLogs.push(usageLog); + return usageLog; + }, + + findMany: async (args: { + where?: { + userId?: string; + channel?: AiChannel; + success?: boolean; + }; + orderBy?: { + createdAt: "asc" | "desc"; + }; + skip?: number; + take?: number; + }) => { + const filteredLogs = this.filterUsageLogs(args.where); + const sortedLogs = [...filteredLogs].sort((left, right) => { + const direction = args.orderBy?.createdAt === "asc" ? 1 : -1; + return (left.createdAt.getTime() - right.createdAt.getTime()) * direction; + }); + const start = args.skip ?? 0; + const end = args.take === undefined ? undefined : start + args.take; + return sortedLogs.slice(start, end); + }, + + count: async (args?: { + where?: { + userId?: string; + channel?: AiChannel; + success?: boolean; + }; + }) => { + return this.filterUsageLogs(args?.where).length; + } + }; + + readonly task = { + findMany: async (args: { + where: { + userId: string; + status: { + in: TaskStatus[]; + }; + }; + take?: number; + }) => { + const filteredTasks = this.tasks.filter( + (task) => task.userId === args.where.userId && args.where.status.in.includes(task.status) + ); + + return filteredTasks.slice(0, args.take ?? filteredTasks.length).map((task) => ({ + id: task.id, + title: task.title, + priority: task.priority, + status: task.status, + ddl: task.ddl, + contentText: task.contentText, + updatedAt: task.updatedAt + })); + } + }; + + async $transaction(callback: (tx: InMemoryAiPrismaService) => Promise): Promise { + return callback(this); + } + + seedBinding(binding: Omit): void { + const now = new Date(); + this.bindings.push({ + ...binding, + createdAt: now, + updatedAt: now + }); + } + + seedPublicPool(publicPool: Omit): void { + const now = new Date(); + this.publicPools.push({ + id: `pool_${this.publicPoolIdSequence++}`, + createdAt: now, + updatedAt: now, + ...publicPool + }); + } + + getUsageLogs(): AiUsageLogRecord[] { + return [...this.usageLogs]; + } + + getBindings(): AiProviderBinding[] { + return [...this.bindings]; + } + + seedTask(task: AiTaskRecord): void { + this.tasks.push(task); + } + + seedUsageLog(log: Omit & { id?: string }): void { + this.usageLogs.push({ + id: log.id ?? `usage_log_${this.usageLogIdSequence++}`, + ...log + }); + } + + private filterUsageLogs(where?: { + userId?: string; + channel?: AiChannel; + success?: boolean; + }): AiUsageLogRecord[] { + return this.usageLogs.filter((log) => { + if (where?.userId !== undefined && log.userId !== where.userId) { + return false; + } + if (where?.channel !== undefined && log.channel !== where.channel) { + return false; + } + if (where?.success !== undefined && log.success !== where.success) { + return false; + } + + return true; + }); + } +} + +class StaticExecutor implements AiChannelExecutor { + readonly inputs: Array<{ + candidate: AiResolvedRouteCandidate; + message: string; + }> = []; + + constructor( + private readonly resolver: (channel: AiChannel) => { + content?: string; + code?: string; + message?: string; + } + ) {} + + async execute(candidate: AiResolvedRouteCandidate, input: AiChatInput) { + this.inputs.push({ + candidate, + message: input.message + }); + + const result = this.resolver(candidate.channel); + if (result.code) { + throw new AiRouteFailureError( + candidate.channel, + candidate.providerName || candidate.configName || candidate.configId || "unknown", + result.code, + result.message ?? "执行失败" + ); + } + + return { + channel: candidate.channel, + providerName: candidate.providerName || candidate.configName || candidate.configId || "", + model: candidate.model, + content: result.content ?? "", + sessionId: "session_ai", + usage: { + promptTokens: 12, + completionTokens: 8, + totalTokens: 20 + }, + raw: null + }; + } +} + +describe("AiController (integration)", () => { + let app: INestApplication; + let prismaService: InMemoryAiPrismaService; + let astrbotExecutor: StaticExecutor; + let openAiExecutor: StaticExecutor; + + beforeEach(async () => { + prismaService = new InMemoryAiPrismaService(); + + openAiExecutor = new StaticExecutor((channel) => + channel === AiChannel.USER_KEY + ? { + code: "UPSTREAM_UNREACHABLE", + message: "用户自备 Key 渠道暂时不可用" + } + : { + content: "公共 AI 已接管" + } + ); + astrbotExecutor = new StaticExecutor(() => ({ + content: "AstrBot 已接管" + })); + + const moduleRef: TestingModule = await Test.createTestingModule({ + controllers: [AiController], + providers: [ + AiService, + AiRateLimitService, + DataEncryptionService, + { + provide: PrismaService, + useValue: prismaService + }, + { + provide: ConfigService, + useValue: { + get: (key: string) => { + if (key === "DATA_ENCRYPTION_SECRET") { + return "test-data-encryption-secret"; + } + if (key === "AI_RATE_LIMIT_WINDOW_MS") { + return 60_000; + } + if (key === "AI_RATE_LIMIT_USER_MAX") { + return 2; + } + if (key === "AI_RATE_LIMIT_IP_MAX") { + return 3; + } + + return undefined; + } + } + }, + { + provide: AiProviderRegistryService, + useValue: { + getExecutor: (channel: AiChannel) => + channel === AiChannel.ASTRBOT ? astrbotExecutor : openAiExecutor + } + } + ] + }).compile(); + + app = moduleRef.createNestApplication(); + app.useGlobalPipes( + new ValidationPipe({ + transform: true, + whitelist: true, + forbidNonWhitelisted: true + }) + ); + await app.init(); + }); + + afterEach(async () => { + await app.close(); + }); + + it("should create and list ai bindings", async () => { + await request(app.getHttpServer()) + .post("/ai/bindings") + .set("x-user-id", "user_1") + .send({ + channel: AiChannel.ASTRBOT, + providerName: "astrbot-main", + model: "deepseek-chat", + configId: "default", + endpoint: "http://127.0.0.1:6185", + apiKey: "abk_secret_1234", + isEnabled: true + }) + .expect(201); + + const response = await request(app.getHttpServer()) + .get("/ai/bindings") + .set("x-user-id", "user_1") + .expect(200); + + expect(response.body.routeOrder).toEqual([ + AiChannel.USER_KEY, + AiChannel.ASTRBOT, + AiChannel.PUBLIC_POOL + ]); + expect(response.body.bindings).toHaveLength(1); + expect(response.body.bindings[0]).toMatchObject({ + channel: AiChannel.ASTRBOT, + providerName: "astrbot-main", + model: "deepseek-chat", + configId: "default", + configName: null, + hasApiKey: true, + maskedApiKey: "abk_***34", + isEnabled: true + }); + + const storedBinding = prismaService.getBindings()[0]; + expect(storedBinding?.providerName).not.toBe("astrbot-main"); + expect(storedBinding?.endpoint).not.toBe("http://127.0.0.1:6185"); + expect(storedBinding?.encryptedApiKey).not.toBe("abk_secret_1234"); + }); + + it("should hide public pool endpoint from user bindings response", async () => { + prismaService.seedPublicPool({ + enabled: true, + providerName: "public-openai", + model: "gpt-4o-mini", + encryptedApiKey: "sk-public", + endpoint: "https://internal.example.com/v1", + rpmLimit: 60, + dailyTokenLimit: 100000 + }); + + const response = await request(app.getHttpServer()) + .get("/ai/bindings") + .set("x-user-id", "user_1") + .expect(200); + + expect(response.body.publicPool).toEqual({ + enabled: true, + providerName: "public-openai", + model: "gpt-4o-mini", + hasApiKey: true + }); + }); + + it("should upsert one binding per user channel", async () => { + await request(app.getHttpServer()) + .post("/ai/bindings") + .set("x-user-id", "user_1") + .send({ + channel: AiChannel.USER_KEY, + providerName: "openai", + model: "gpt-4o-mini", + endpoint: "https://api.example.com", + apiKey: "sk-first", + isEnabled: true + }) + .expect(201); + + await request(app.getHttpServer()) + .post("/ai/bindings") + .set("x-user-id", "user_1") + .send({ + channel: AiChannel.USER_KEY, + providerName: "google", + model: "gemini-2.5-flash", + endpoint: "https://generativelanguage.googleapis.com", + apiKey: "sk-second", + isEnabled: false + }) + .expect(201); + + const response = await request(app.getHttpServer()) + .get("/ai/bindings") + .set("x-user-id", "user_1") + .expect(200); + + expect(response.body.bindings).toEqual([ + expect.objectContaining({ + channel: AiChannel.USER_KEY, + providerName: "google", + model: "gemini-2.5-flash", + endpoint: "https://generativelanguage.googleapis.com", + isEnabled: false, + maskedApiKey: "sk-s***nd" + }) + ]); + }); + + it("should fallback from user key to astrbot", async () => { + prismaService.seedBinding({ + id: "binding_user_key", + userId: "user_1", + channel: AiChannel.USER_KEY, + providerName: "openai", + model: "gpt-4o-mini", + configId: null, + configName: null, + encryptedApiKey: "sk-user", + endpoint: "https://api.example.com", + isDefault: true, + isEnabled: true + }); + prismaService.seedBinding({ + id: "binding_astrbot", + userId: "user_1", + channel: AiChannel.ASTRBOT, + providerName: "", + model: null, + configId: "default", + configName: null, + encryptedApiKey: "abk_astrbot", + endpoint: "http://127.0.0.1:6185", + isDefault: true, + isEnabled: true + }); + + const response = await request(app.getHttpServer()) + .post("/ai/chat") + .set("x-user-id", "user_1") + .send({ + message: "帮我安排今天的任务" + }) + .expect(201); + + expect(response.body.channel).toBe(AiChannel.ASTRBOT); + expect(response.body.content).toBe("AstrBot 已接管"); + expect(response.body.attempts).toEqual([ + { + channel: AiChannel.USER_KEY, + providerName: "openai", + model: "gpt-4o-mini", + status: "failed", + reasonCode: "UPSTREAM_UNREACHABLE", + reasonMessage: "用户自备 Key 渠道暂时不可用" + }, + { + channel: AiChannel.ASTRBOT, + providerName: "default", + model: null, + status: "success", + reasonCode: null, + reasonMessage: null + } + ]); + expect(prismaService.getUsageLogs()).toEqual([ + expect.objectContaining({ + id: expect.any(String), + userId: "user_1", + channel: AiChannel.USER_KEY, + promptTokens: 0, + completionTokens: 0, + totalTokens: 0, + latencyMs: expect.any(Number), + success: false, + errorCode: "UPSTREAM_UNREACHABLE", + createdAt: expect.any(Date) + }), + expect.objectContaining({ + id: expect.any(String), + userId: "user_1", + channel: AiChannel.ASTRBOT, + promptTokens: 12, + completionTokens: 8, + totalTokens: 20, + latencyMs: expect.any(Number), + success: true, + errorCode: null, + createdAt: expect.any(Date) + }) + ]); + expect(prismaService.getUsageLogs()[0]?.providerName).not.toBe("openai"); + expect(prismaService.getUsageLogs()[0]?.model).not.toBe("gpt-4o-mini"); + }); + + it("should allow astrbot binding with config id only", async () => { + const response = await request(app.getHttpServer()) + .post("/ai/bindings") + .set("x-user-id", "user_1") + .send({ + channel: AiChannel.ASTRBOT, + configId: "default", + endpoint: "http://127.0.0.1:6185", + apiKey: "abk_secret_1234", + isEnabled: true + }) + .expect(201); + + expect(response.body).toMatchObject({ + channel: AiChannel.ASTRBOT, + providerName: "", + configId: "default", + configName: null, + isEnabled: true + }); + }); + + it("should test binding with stored secret when api key is omitted", async () => { + prismaService.seedBinding({ + id: "binding_user_key_test_existing_secret", + userId: "user_1", + channel: AiChannel.USER_KEY, + providerName: "airouter", + model: "gpt-4.1", + configId: null, + configName: null, + encryptedApiKey: "sk-existing", + endpoint: "https://api.example.com", + isDefault: false, + isEnabled: true + }); + + const executeSpy = jest.spyOn(openAiExecutor, "execute").mockResolvedValue({ + channel: AiChannel.USER_KEY, + providerName: "airouter", + model: "gpt-4.1", + content: "连接成功", + sessionId: "session_binding_test", + usage: { + promptTokens: 1, + completionTokens: 1, + totalTokens: 2 + }, + raw: null + }); + + const response = await request(app.getHttpServer()) + .post("/ai/bindings/test") + .set("x-user-id", "user_1") + .send({ + channel: AiChannel.USER_KEY, + providerName: "airouter", + model: "gpt-4.1", + endpoint: "https://api.example.com" + }) + .expect(201); + + expect(response.body).toEqual({ + success: true, + channel: AiChannel.USER_KEY, + providerName: "airouter", + model: "gpt-4.1", + contentPreview: "连接成功" + }); + expect(executeSpy).toHaveBeenCalledWith( + expect.objectContaining({ + channel: AiChannel.USER_KEY, + providerName: "airouter", + model: "gpt-4.1", + endpoint: "https://api.example.com", + apiKey: "sk-existing" + }), + expect.objectContaining({ + userId: "user_1" + }) + ); + }); + + it("should return structured failure result when binding test fails", async () => { + prismaService.seedBinding({ + id: "binding_user_key_test_failure", + userId: "user_1", + channel: AiChannel.USER_KEY, + providerName: "airouter", + model: "gpt-5.4", + configId: null, + configName: null, + encryptedApiKey: "sk-existing", + endpoint: "https://api.example.com", + isDefault: false, + isEnabled: true + }); + + const response = await request(app.getHttpServer()) + .post("/ai/bindings/test") + .set("x-user-id", "user_1") + .send({ + channel: AiChannel.USER_KEY, + providerName: "airouter", + model: "gpt-5.4", + endpoint: "https://api.example.com" + }) + .expect(201); + + expect(response.body).toEqual({ + success: false, + channel: AiChannel.USER_KEY, + providerName: "airouter", + model: "gpt-5.4", + code: "UPSTREAM_UNREACHABLE", + message: "用户自备 Key 渠道暂时不可用" + }); + }); + + it("should use selected channel without automatic fallback", async () => { + prismaService.seedBinding({ + id: "binding_user_key_selected", + userId: "user_1", + channel: AiChannel.USER_KEY, + providerName: "openai", + model: "gpt-4o-mini", + configId: null, + configName: null, + encryptedApiKey: "sk-user", + endpoint: "https://api.example.com", + isDefault: false, + isEnabled: true + }); + prismaService.seedBinding({ + id: "binding_astrbot_selected", + userId: "user_1", + channel: AiChannel.ASTRBOT, + providerName: "", + model: null, + configId: "default", + configName: null, + encryptedApiKey: "abk_astrbot", + endpoint: "http://127.0.0.1:6185", + isDefault: false, + isEnabled: true + }); + + const response = await request(app.getHttpServer()) + .post("/ai/chat") + .set("x-user-id", "user_1") + .send({ + message: "只使用自备渠道", + channel: AiChannel.USER_KEY + }) + .expect(502); + + expect(response.body.attempts).toEqual([ + { + channel: AiChannel.USER_KEY, + providerName: "openai", + model: "gpt-4o-mini", + status: "failed", + reasonCode: "UPSTREAM_UNREACHABLE", + reasonMessage: "用户自备 Key 渠道暂时不可用" + } + ]); + }); + + it("should inject unfinished task summary into ai prompt", async () => { + prismaService.seedBinding({ + id: "binding_astrbot_context", + userId: "user_1", + channel: AiChannel.ASTRBOT, + providerName: "", + model: null, + configId: "default", + configName: null, + encryptedApiKey: "abk_astrbot", + endpoint: "http://127.0.0.1:6185", + isDefault: true, + isEnabled: true + }); + prismaService.seedTask({ + id: "task_weekly_report", + userId: "user_1", + title: "今晚提交周报", + priority: TaskPriority.URGENT, + status: TaskStatus.IN_PROGRESS, + ddl: new Date("2026-04-06T12:00:00.000Z"), + contentText: "需要汇总 AI 路由、AstrBot 接入和同步模块进度", + updatedAt: new Date("2026-04-06T08:00:00.000Z") + }); + prismaService.seedTask({ + id: "task_done_item", + userId: "user_1", + title: "整理已完成事项", + priority: TaskPriority.LOW, + status: TaskStatus.DONE, + ddl: null, + contentText: "这条任务不应该出现在上下文里", + updatedAt: new Date("2026-04-06T07:00:00.000Z") + }); + + await request(app.getHttpServer()) + .post("/ai/chat") + .set("x-user-id", "user_1") + .send({ + message: "帮我安排今天剩余任务" + }) + .expect(201); + + expect(astrbotExecutor.inputs).toHaveLength(1); + expect(astrbotExecutor.inputs[0]?.message).toContain("以下是系统整理的未完成任务摘要"); + expect(astrbotExecutor.inputs[0]?.message).toContain("今晚提交周报"); + expect(astrbotExecutor.inputs[0]?.message).toContain("优先级:紧急"); + expect(astrbotExecutor.inputs[0]?.message).not.toContain("整理已完成事项"); + expect(astrbotExecutor.inputs[0]?.message).toContain("用户当前问题:帮我安排今天剩余任务"); + }); + + it("should inject local unfinished tasks into ai prompt when database is empty", async () => { + prismaService.seedBinding({ + id: "binding_user_key_local_context", + userId: "user_1", + channel: AiChannel.USER_KEY, + providerName: "openai", + model: "gpt-4o-mini", + configId: null, + configName: null, + encryptedApiKey: "sk-user", + endpoint: "https://api.example.com", + isDefault: true, + isEnabled: true + }); + + const response = await request(app.getHttpServer()) + .post("/ai/chat") + .set("x-user-id", "user_1") + .send({ + message: "结合我的 TodoList 帮我排优先级", + channel: AiChannel.USER_KEY, + localTasks: [ + { + id: "local_task_1", + title: "准备明天答辩材料", + priority: TaskPriority.URGENT, + status: TaskStatus.IN_PROGRESS, + ddlAt: new Date("2026-04-07T13:00:00.000Z").getTime(), + contentText: "需要补齐演示文稿和总结页", + updatedAt: new Date("2026-04-07T09:00:00.000Z").getTime() + } + ] + }) + .expect(502); + + expect(response.body.attempts).toEqual([ + { + channel: AiChannel.USER_KEY, + providerName: "openai", + model: "gpt-4o-mini", + status: "failed", + reasonCode: "UPSTREAM_UNREACHABLE", + reasonMessage: "用户自备 Key 渠道暂时不可用" + } + ]); + expect(openAiExecutor.inputs).toHaveLength(1); + expect(openAiExecutor.inputs[0]?.message).toContain("准备明天答辩材料"); + expect(openAiExecutor.inputs[0]?.message).toContain("优先级:紧急"); + expect(openAiExecutor.inputs[0]?.message).toContain("内容摘要:需要补齐演示文稿和总结页"); + expect(openAiExecutor.inputs[0]?.message).toContain( + "用户当前问题:结合我的 TodoList 帮我排优先级" + ); + expect(astrbotExecutor.inputs).toHaveLength(0); + }); + + it("should prefer newer local task snapshot over older database task", async () => { + prismaService.seedBinding({ + id: "binding_astrbot_local_override", + userId: "user_1", + channel: AiChannel.ASTRBOT, + providerName: "", + model: null, + configId: "default", + configName: null, + encryptedApiKey: "abk_astrbot", + endpoint: "http://127.0.0.1:6185", + isDefault: true, + isEnabled: true + }); + prismaService.seedTask({ + id: "task_same_id", + userId: "user_1", + title: "旧标题", + priority: TaskPriority.LOW, + status: TaskStatus.TODO, + ddl: new Date("2026-04-08T10:00:00.000Z"), + contentText: "旧内容", + updatedAt: new Date("2026-04-07T08:00:00.000Z") + }); + + await request(app.getHttpServer()) + .post("/ai/chat") + .set("x-user-id", "user_1") + .send({ + message: "看看我最新要做什么", + channel: AiChannel.ASTRBOT, + localTasks: [ + { + id: "task_same_id", + title: "新标题", + priority: TaskPriority.HIGH, + status: TaskStatus.IN_PROGRESS, + ddlAt: new Date("2026-04-07T15:00:00.000Z").getTime(), + contentText: "新内容", + updatedAt: new Date("2026-04-07T12:00:00.000Z").getTime() + } + ] + }) + .expect(201); + + expect(astrbotExecutor.inputs.at(-1)?.message).toContain("新标题"); + expect(astrbotExecutor.inputs.at(-1)?.message).toContain("优先级:高"); + expect(astrbotExecutor.inputs.at(-1)?.message).toContain("内容摘要:新内容"); + expect(astrbotExecutor.inputs.at(-1)?.message).not.toContain("旧标题"); + expect(astrbotExecutor.inputs.at(-1)?.message).not.toContain("旧内容"); + }); + + it("should return skipped attempts when no channel is available", async () => { + const response = await request(app.getHttpServer()) + .post("/ai/chat") + .set("x-user-id", "user_1") + .send({ + message: "帮我总结今天的安排" + }) + .expect(502); + + expect(response.body.message).toBe("当前没有可用的 AI 通道,请稍后重试"); + expect(response.body.attempts).toEqual([ + { + channel: AiChannel.USER_KEY, + providerName: null, + model: null, + status: "skipped", + reasonCode: "CHANNEL_NOT_CONFIGURED", + reasonMessage: "当前用户未配置可用的自备 Key 通道" + }, + { + channel: AiChannel.ASTRBOT, + providerName: null, + model: null, + status: "skipped", + reasonCode: "CHANNEL_NOT_CONFIGURED", + reasonMessage: "当前用户未配置可用的 AstrBot 通道" + }, + { + channel: AiChannel.PUBLIC_POOL, + providerName: null, + model: null, + status: "skipped", + reasonCode: "PUBLIC_POOL_DISABLED", + reasonMessage: "公共 AI 通道未开启" + } + ]); + expect(prismaService.getUsageLogs()).toEqual([]); + }); + + it("should rate limit ai chat by user in the same window", async () => { + prismaService.seedBinding({ + id: "binding_astrbot_rate_limit_user", + userId: "user_1", + channel: AiChannel.ASTRBOT, + providerName: "", + model: null, + configId: "default", + configName: null, + encryptedApiKey: "abk_astrbot", + endpoint: "http://127.0.0.1:6185", + isDefault: true, + isEnabled: true + }); + + await request(app.getHttpServer()) + .post("/ai/chat") + .set("x-user-id", "user_1") + .set("x-forwarded-for", "203.0.113.10") + .send({ + message: "第一条" + }) + .expect(201); + + await request(app.getHttpServer()) + .post("/ai/chat") + .set("x-user-id", "user_1") + .set("x-forwarded-for", "203.0.113.10") + .send({ + message: "第二条" + }) + .expect(201); + + const response = await request(app.getHttpServer()) + .post("/ai/chat") + .set("x-user-id", "user_1") + .set("x-forwarded-for", "203.0.113.10") + .send({ + message: "第三条" + }) + .expect(429); + + expect(response.body).toMatchObject({ + message: "AI 请求过于频繁,请稍后再试", + code: "AI_RATE_LIMITED", + dimension: "user", + limit: 2, + windowMs: 60000 + }); + expect(response.body.retryAfterMs).toEqual(expect.any(Number)); + expect(astrbotExecutor.inputs).toHaveLength(2); + expect(prismaService.getUsageLogs()).toHaveLength(2); + }); + + it("should rate limit ai chat by ip across different users", async () => { + prismaService.seedBinding({ + id: "binding_astrbot_rate_limit_ip_user_1", + userId: "user_1", + channel: AiChannel.ASTRBOT, + providerName: "", + model: null, + configId: "default", + configName: null, + encryptedApiKey: "abk_astrbot", + endpoint: "http://127.0.0.1:6185", + isDefault: true, + isEnabled: true + }); + prismaService.seedBinding({ + id: "binding_astrbot_rate_limit_ip_user_2", + userId: "user_2", + channel: AiChannel.ASTRBOT, + providerName: "", + model: null, + configId: "default", + configName: null, + encryptedApiKey: "abk_astrbot", + endpoint: "http://127.0.0.1:6185", + isDefault: true, + isEnabled: true + }); + + const sharedIp = "198.51.100.7"; + + await request(app.getHttpServer()) + .post("/ai/chat") + .set("x-user-id", "user_1") + .set("x-forwarded-for", sharedIp) + .send({ + message: "用户一第一条" + }) + .expect(201); + + await request(app.getHttpServer()) + .post("/ai/chat") + .set("x-user-id", "user_2") + .set("x-forwarded-for", sharedIp) + .send({ + message: "用户二第一条" + }) + .expect(201); + + await request(app.getHttpServer()) + .post("/ai/chat") + .set("x-user-id", "user_1") + .set("x-forwarded-for", sharedIp) + .send({ + message: "用户一第二条" + }) + .expect(201); + + const response = await request(app.getHttpServer()) + .post("/ai/chat") + .set("x-user-id", "user_2") + .set("x-forwarded-for", sharedIp) + .send({ + message: "用户二第二条" + }) + .expect(429); + + expect(response.body).toMatchObject({ + message: "AI 请求过于频繁,请稍后再试", + code: "AI_RATE_LIMITED", + dimension: "ip", + limit: 3, + windowMs: 60000 + }); + expect(response.body.retryAfterMs).toEqual(expect.any(Number)); + expect(astrbotExecutor.inputs).toHaveLength(3); + expect(prismaService.getUsageLogs()).toHaveLength(3); + }); + + it("should list usage logs with pagination and filters", async () => { + prismaService.seedUsageLog({ + id: "usage_log_1", + userId: "user_1", + channel: AiChannel.ASTRBOT, + providerName: "default", + model: "deepseek-chat", + promptTokens: 10, + completionTokens: 6, + totalTokens: 16, + latencyMs: 120, + success: true, + errorCode: null, + createdAt: new Date("2026-04-06T08:00:00.000Z") + }); + prismaService.seedUsageLog({ + id: "usage_log_2", + userId: "user_1", + channel: AiChannel.ASTRBOT, + providerName: "default", + model: "deepseek-chat", + promptTokens: 14, + completionTokens: 9, + totalTokens: 23, + latencyMs: 100, + success: true, + errorCode: null, + createdAt: new Date("2026-04-06T09:00:00.000Z") + }); + prismaService.seedUsageLog({ + id: "usage_log_3", + userId: "user_1", + channel: AiChannel.USER_KEY, + providerName: "openai", + model: "gpt-4o-mini", + promptTokens: 20, + completionTokens: 12, + totalTokens: 32, + latencyMs: 210, + success: false, + errorCode: "UPSTREAM_UNREACHABLE", + createdAt: new Date("2026-04-06T10:00:00.000Z") + }); + prismaService.seedUsageLog({ + id: "usage_log_4", + userId: "user_2", + channel: AiChannel.ASTRBOT, + providerName: "default", + model: "deepseek-chat", + promptTokens: 18, + completionTokens: 11, + totalTokens: 29, + latencyMs: 90, + success: true, + errorCode: null, + createdAt: new Date("2026-04-06T11:00:00.000Z") + }); + + const response = await request(app.getHttpServer()) + .get("/ai/usage-logs") + .set("x-user-id", "user_1") + .query({ + page: 2, + pageSize: 1, + channel: AiChannel.ASTRBOT, + success: true + }) + .expect(200); + + expect(response.body).toEqual({ + items: [ + { + id: "usage_log_1", + channel: AiChannel.ASTRBOT, + providerName: "default", + model: "deepseek-chat", + promptTokens: 10, + completionTokens: 6, + totalTokens: 16, + latencyMs: 120, + success: true, + errorCode: null, + createdAt: "2026-04-06T08:00:00.000Z" + } + ], + page: 2, + pageSize: 1, + total: 2 + }); + }); +}); diff --git a/apps/api/test/astrbot-provider.spec.ts b/apps/api/test/astrbot-provider.spec.ts new file mode 100644 index 0000000..3ccecc9 --- /dev/null +++ b/apps/api/test/astrbot-provider.spec.ts @@ -0,0 +1,73 @@ +import { AiChannel } from "../generated/prisma/client"; +import { AstrbotProvider } from "../src/ai/providers/astrbot.provider"; + +describe("AstrbotProvider", () => { + const originalFetch = global.fetch; + + afterEach(() => { + global.fetch = originalFetch; + jest.restoreAllMocks(); + }); + + it("should not forward binding label fields as astrbot selection parameters", async () => { + const provider = new AstrbotProvider(); + const fetchMock = jest.fn(async (_input: unknown, init?: RequestInit) => { + expect(init?.method).toBe("POST"); + const payload = JSON.parse(String(init?.body ?? "{}")) as Record; + + expect(payload).toMatchObject({ + username: "user_1", + session_id: "session_1", + message: "你好", + enable_streaming: false, + selected_model: "deepseek-chat" + }); + expect(payload).not.toHaveProperty("selected_provider"); + expect(payload).not.toHaveProperty("config_id"); + expect(payload).not.toHaveProperty("config_name"); + + return new Response( + [ + 'data: {"type":"session_id","session_id":"session_1"}', + "", + 'data: {"type":"plain","data":"收到","streaming":false,"chain_type":null}', + "", + 'data: {"type":"end","data":"","streaming":false}', + "" + ].join("\n"), + { + status: 200, + headers: { + "content-type": "text/event-stream" + } + } + ); + }); + + global.fetch = fetchMock as typeof global.fetch; + + const result = await provider.execute( + { + channel: AiChannel.ASTRBOT, + source: "binding", + sourceId: "binding_1", + providerName: "astrbot-main", + model: "deepseek-chat", + configId: "default", + configName: "默认配置", + endpoint: "http://127.0.0.1:6185", + apiKey: "abk_secret" + }, + { + userId: "user_1", + message: "你好", + sessionId: "session_1" + } + ); + + expect(fetchMock).toHaveBeenCalledTimes(1); + expect(result.content).toBe("收到"); + expect(result.sessionId).toBe("session_1"); + expect(result.providerName).toBe("astrbot-main"); + }); +}); diff --git a/apps/api/test/auth.spec.ts b/apps/api/test/auth.spec.ts new file mode 100644 index 0000000..8f1f1ae --- /dev/null +++ b/apps/api/test/auth.spec.ts @@ -0,0 +1,355 @@ +import { UnauthorizedException } from "@nestjs/common"; +import { ConfigService } from "@nestjs/config"; +import { JwtService } from "@nestjs/jwt"; +import { Test, TestingModule } from "@nestjs/testing"; +import { AuthMailService } from "../src/auth/auth-mail.service"; +import { AuthService } from "../src/auth/auth.service"; +import { PrismaService } from "../src/prisma/prisma.service"; +import { DataEncryptionService } from "../src/security/data-encryption.service"; + +type UserRecord = { + id: string; + email: string; + emailHash: string; + nickname: string | null; + avatarUrl: string | null; +}; + +type RefreshTokenRecord = { + id: string; + userId: string; + tokenHash: string; + expiresAt: Date; + revokedAt: Date | null; + createdAt: Date; +}; + +type UserSecurityRecord = { + userId: string; + twoFactorEnabled: boolean; + twoFactorSecret: string | null; +}; + +class InMemoryAuthPrismaService { + private userIdSequence = 1; + private refreshTokenIdSequence = 1; + private users: UserRecord[] = []; + private refreshTokens: RefreshTokenRecord[] = []; + private userSecurities: UserSecurityRecord[] = []; + + readonly user = { + upsert: async (args: { + where: { + emailHash: string; + }; + update: Record; + create: { + email: string; + emailHash: string; + }; + select: { + id: true; + email: true; + }; + }) => { + const existingUser = this.users.find((user) => user.emailHash === args.where.emailHash); + if (existingUser) { + return { + id: existingUser.id, + email: existingUser.email + }; + } + + const createdUser: UserRecord = { + id: `user_${this.userIdSequence++}`, + email: args.create.email, + emailHash: args.create.emailHash, + nickname: null, + avatarUrl: null + }; + this.users.push(createdUser); + + return { + id: createdUser.id, + email: createdUser.email + }; + } + }; + + readonly refreshToken = { + create: async (args: { + data: { + userId: string; + tokenHash: string; + expiresAt: Date; + }; + }) => { + const refreshToken: RefreshTokenRecord = { + id: `refresh_${this.refreshTokenIdSequence++}`, + userId: args.data.userId, + tokenHash: args.data.tokenHash, + expiresAt: args.data.expiresAt, + revokedAt: null, + createdAt: new Date() + }; + this.refreshTokens.push(refreshToken); + return refreshToken; + }, + + findUnique: async (args: { + where: { + tokenHash: string; + }; + include: { + user: { + select: { + id: true; + email: true; + }; + }; + }; + }) => { + const refreshToken = this.refreshTokens.find( + (item) => item.tokenHash === args.where.tokenHash + ); + if (!refreshToken) { + return null; + } + + const user = this.users.find((item) => item.id === refreshToken.userId); + if (!user) { + throw new Error("user not found"); + } + + return { + ...refreshToken, + user: { + id: user.id, + email: user.email + } + }; + }, + + update: async (args: { + where: { + id: string; + }; + data: { + revokedAt: Date; + }; + }) => { + const refreshToken = this.refreshTokens.find((item) => item.id === args.where.id); + if (!refreshToken) { + throw new Error("refresh token not found"); + } + + refreshToken.revokedAt = args.data.revokedAt; + return refreshToken; + }, + + updateMany: async (args: { + where: { + tokenHash: string; + revokedAt: null; + }; + data: { + revokedAt: Date; + }; + }) => { + let count = 0; + for (const refreshToken of this.refreshTokens) { + if (refreshToken.tokenHash !== args.where.tokenHash || refreshToken.revokedAt !== null) { + continue; + } + + refreshToken.revokedAt = args.data.revokedAt; + count += 1; + } + + return { count }; + } + }; + + readonly userSecurity = { + upsert: async (args: { + where: { + userId: string; + }; + update: { + twoFactorSecret: string; + twoFactorEnabled: boolean; + }; + create: { + userId: string; + twoFactorSecret: string; + twoFactorEnabled: boolean; + }; + }) => { + const existingSecurity = this.userSecurities.find( + (item) => item.userId === args.where.userId + ); + if (existingSecurity) { + existingSecurity.twoFactorSecret = args.update.twoFactorSecret; + existingSecurity.twoFactorEnabled = args.update.twoFactorEnabled; + return existingSecurity; + } + + const createdSecurity: UserSecurityRecord = { + userId: args.create.userId, + twoFactorSecret: args.create.twoFactorSecret, + twoFactorEnabled: args.create.twoFactorEnabled + }; + this.userSecurities.push(createdSecurity); + return createdSecurity; + }, + + findUnique: async (args: { + where: { + userId: string; + }; + select: { + twoFactorSecret: true; + }; + }) => { + const security = this.userSecurities.find((item) => item.userId === args.where.userId); + if (!security) { + return null; + } + + return { + twoFactorSecret: security.twoFactorSecret + }; + }, + + update: async (args: { + where: { + userId: string; + }; + data: { + twoFactorEnabled: boolean; + }; + }) => { + const security = this.userSecurities.find((item) => item.userId === args.where.userId); + if (!security) { + throw new Error("user security not found"); + } + + security.twoFactorEnabled = args.data.twoFactorEnabled; + return security; + } + }; + + getUsers(): UserRecord[] { + return [...this.users]; + } +} + +class MockAuthMailService { + readonly sentMessages: Array<{ + email: string; + code: string; + ttlSeconds: number; + }> = []; + + async sendLoginCode(email: string, code: string, ttlSeconds: number): Promise { + this.sentMessages.push({ + email, + code, + ttlSeconds + }); + } +} + +describe("AuthService", () => { + let authService: AuthService; + let prismaService: InMemoryAuthPrismaService; + let authMailService: MockAuthMailService; + + beforeEach(async () => { + prismaService = new InMemoryAuthPrismaService(); + authMailService = new MockAuthMailService(); + + const moduleRef: TestingModule = await Test.createTestingModule({ + providers: [ + AuthService, + DataEncryptionService, + { + provide: PrismaService, + useValue: prismaService + }, + { + provide: AuthMailService, + useValue: authMailService + }, + { + provide: JwtService, + useValue: { + signAsync: async (payload: Record) => + `signed-${String(payload["sub"])}-${String(payload["email"])}` + } + }, + { + provide: ConfigService, + useValue: { + get: (key: string) => { + switch (key) { + case "AUTH_EMAIL_CODE_TTL_SECONDS": + return "300"; + case "AUTH_ACCESS_EXPIRES_IN_SECONDS": + return "900"; + case "AUTH_REFRESH_EXPIRES_IN_SECONDS": + return "2592000"; + case "AUTH_TOTP_ISSUER": + return "TodoList"; + case "DATA_ENCRYPTION_SECRET": + return "test-data-encryption-secret"; + default: + return undefined; + } + } + } + } + ] + }).compile(); + + authService = moduleRef.get(AuthService); + }); + + it("should encrypt user email in database while keeping login flow available", async () => { + await authService.sendEmailCode("User@Example.com"); + expect(authMailService.sentMessages).toHaveLength(1); + expect(authMailService.sentMessages[0]?.email).toBe("user@example.com"); + + const loginResult = await authService.loginWithEmailCode( + "USER@example.com", + authMailService.sentMessages[0]?.code ?? "" + ); + + expect(loginResult.user.email).toBe("user@example.com"); + expect(loginResult.accessToken).toContain("user@example.com"); + + const storedUser = prismaService.getUsers()[0]; + expect(storedUser?.email).not.toBe("user@example.com"); + expect(storedUser?.emailHash).toMatch(/^[a-f0-9]{64}$/); + }); + + it("should decrypt user email when refreshing token", async () => { + await authService.sendEmailCode("refresh@example.com"); + const loginResult = await authService.loginWithEmailCode( + "refresh@example.com", + authMailService.sentMessages[0]?.code ?? "" + ); + + const refreshResult = await authService.refreshTokens(loginResult.refreshToken); + expect(refreshResult.user.email).toBe("refresh@example.com"); + expect(refreshResult.accessToken).toContain("refresh@example.com"); + }); + + it("should reject invalid verification code", async () => { + await authService.sendEmailCode("invalid@example.com"); + + await expect( + authService.loginWithEmailCode("invalid@example.com", "000000") + ).rejects.toBeInstanceOf(UnauthorizedException); + }); +}); diff --git a/apps/api/test/openai-compatible-provider.spec.ts b/apps/api/test/openai-compatible-provider.spec.ts new file mode 100644 index 0000000..7654669 --- /dev/null +++ b/apps/api/test/openai-compatible-provider.spec.ts @@ -0,0 +1,80 @@ +import { AiChannel } from "../generated/prisma/client"; +import { OpenAiCompatibleProvider } from "../src/ai/providers/openai-compatible.provider"; + +describe("OpenAiCompatibleProvider", () => { + const originalFetch = global.fetch; + + afterEach(() => { + global.fetch = originalFetch; + jest.restoreAllMocks(); + }); + + it("should read text from responses style payload when chat content is empty", async () => { + const provider = new OpenAiCompatibleProvider(); + const fetchMock = jest.fn(async (_input: unknown, init?: RequestInit) => { + expect(init?.method).toBe("POST"); + + return new Response( + JSON.stringify({ + id: "resp_123", + object: "response", + model: "gpt-5.4", + output: [ + { + id: "msg_123", + type: "message", + role: "assistant", + content: [ + { + type: "output_text", + text: "今天优先先完成截止时间最近的任务。" + } + ] + } + ], + usage: { + prompt_tokens: 15, + completion_tokens: 9, + total_tokens: 24 + } + }), + { + status: 200, + headers: { + "content-type": "application/json" + } + } + ); + }); + + global.fetch = fetchMock as typeof global.fetch; + + const result = await provider.execute( + { + channel: AiChannel.USER_KEY, + source: "binding", + sourceId: "binding_user_key_1", + providerName: "airouter", + model: "gpt-5.4", + configId: null, + configName: null, + endpoint: "https://api.airouter.io/v1", + apiKey: "sk_test" + }, + { + userId: "user_1", + message: "帮我安排今天的任务", + sessionId: null + } + ); + + expect(fetchMock).toHaveBeenCalledTimes(1); + expect(result.content).toBe("今天优先先完成截止时间最近的任务。"); + expect(result.model).toBe("gpt-5.4"); + expect(result.usage).toEqual({ + promptTokens: 15, + completionTokens: 9, + totalTokens: 24 + }); + }); +}); diff --git a/apps/api/test/sync-push.spec.ts b/apps/api/test/sync-push.spec.ts new file mode 100644 index 0000000..3c75f9b --- /dev/null +++ b/apps/api/test/sync-push.spec.ts @@ -0,0 +1,439 @@ +import request from "supertest"; +import { INestApplication, ValidationPipe } from "@nestjs/common"; +import { ConfigService } from "@nestjs/config"; +import { Test, TestingModule } from "@nestjs/testing"; +import { PrismaService } from "../src/prisma/prisma.service"; +import { DataEncryptionService } from "../src/security/data-encryption.service"; +import { SyncController } from "../src/sync/sync.controller"; +import { SyncService } from "../src/sync/sync.service"; + +type SyncOperationRecord = { + id: string; + opId: string; + userId: string; + deviceId: string; + entityType: string; + entityId: string; + action: string; + payload: string | null; + clientTs: Date; + serverTs: Date; +}; + +type SyncOperationSelect = { + opId?: true; + entityId?: true; + entityType?: true; + action?: true; + payload?: true; + clientTs?: true; + deviceId?: true; + serverTs?: true; +}; + +type SyncOperationFindManyArgs = { + where: { + userId: string; + opId?: { + in: string[]; + }; + OR?: Array< + | { + serverTs: { + gt: Date; + }; + } + | { + serverTs: Date; + opId: { + gt: string; + }; + } + >; + }; + select: SyncOperationSelect; + orderBy?: Array<{ + serverTs?: "asc" | "desc"; + opId?: "asc" | "desc"; + }>; + take?: number; +}; + +type SyncOperationCreateArgs = { + data: { + opId: string; + userId: string; + deviceId: string; + entityType: string; + entityId: string; + action: string; + payload?: string; + clientTs: Date; + }; + select: { + opId: true; + serverTs: true; + }; +}; + +class InMemoryPrismaService { + private syncOperationIdSequence = 1; + private syncOperations: SyncOperationRecord[] = []; + + readonly syncOperation = { + findMany: async (args: SyncOperationFindManyArgs) => { + let items = this.syncOperations.filter((item) => item.userId === args.where.userId); + + if (args.where.opId?.in) { + items = items.filter((item) => args.where.opId?.in.includes(item.opId)); + } + + if (args.where.OR && args.where.OR.length > 0) { + items = items.filter((item) => + args.where.OR?.some((condition) => { + if ("gt" in condition.serverTs) { + return item.serverTs.getTime() > condition.serverTs.gt.getTime(); + } + + if ("opId" in condition) { + return ( + item.serverTs.getTime() === condition.serverTs.getTime() && + item.opId > condition.opId.gt + ); + } + + return false; + }) + ); + } + + if (args.orderBy && args.orderBy.length > 0) { + items = [...items].sort((left, right) => { + for (const orderRule of args.orderBy ?? []) { + if (orderRule.serverTs) { + const diff = left.serverTs.getTime() - right.serverTs.getTime(); + if (diff !== 0) { + return orderRule.serverTs === "asc" ? diff : -diff; + } + } + + if (orderRule.opId) { + const diff = left.opId.localeCompare(right.opId); + if (diff !== 0) { + return orderRule.opId === "asc" ? diff : -diff; + } + } + } + + return 0; + }); + } + + const limitedItems = args.take ? items.slice(0, args.take) : items; + + return limitedItems.map((item) => this.pickSelectedFields(item, args.select)); + }, + + create: async (args: SyncOperationCreateArgs) => { + const createdOperation: SyncOperationRecord = { + id: `sync_${this.syncOperationIdSequence++}`, + opId: args.data.opId, + userId: args.data.userId, + deviceId: args.data.deviceId, + entityType: args.data.entityType, + entityId: args.data.entityId, + action: args.data.action, + payload: args.data.payload ?? null, + clientTs: args.data.clientTs, + serverTs: new Date() + }; + + this.syncOperations.push(createdOperation); + + return { + opId: createdOperation.opId, + serverTs: createdOperation.serverTs + }; + } + }; + + getOperationCount(): number { + return this.syncOperations.length; + } + + getRawOperationById(opId: string): SyncOperationRecord | undefined { + return this.syncOperations.find((operation) => operation.opId === opId); + } + + seedOperations(records: Array>): void { + for (const record of records) { + this.syncOperations.push({ + ...record, + id: `sync_${this.syncOperationIdSequence++}` + }); + } + } + + private pickSelectedFields( + item: SyncOperationRecord, + select: SyncOperationSelect + ): Partial { + const result: Record = {}; + + for (const key of Object.keys(select) as Array) { + if (!select[key]) { + continue; + } + + const recordKey = key as keyof SyncOperationRecord; + result[recordKey] = item[recordKey]; + } + + return result as Partial; + } +} + +describe("SyncController (integration)", () => { + let app: INestApplication; + let prismaService: InMemoryPrismaService; + + beforeAll(async () => { + prismaService = new InMemoryPrismaService(); + + const moduleRef: TestingModule = await Test.createTestingModule({ + controllers: [SyncController], + providers: [ + SyncService, + DataEncryptionService, + { provide: PrismaService, useValue: prismaService }, + { + provide: ConfigService, + useValue: { + get: (key: string) => + key === "DATA_ENCRYPTION_SECRET" ? "test-data-encryption-secret" : undefined + } + } + ] + }).compile(); + + app = moduleRef.createNestApplication(); + app.useGlobalPipes( + new ValidationPipe({ + transform: true, + whitelist: true, + forbidNonWhitelisted: true + }) + ); + await app.init(); + }); + + afterAll(async () => { + await app.close(); + }); + + it("should accept operations once and mark repeated push as duplicate", async () => { + const payload = { + operations: [ + { + opId: "op-create-1", + entityType: "TASK", + entityId: "task-1", + action: "CREATE", + payload: '{"title":"任务一"}', + clientTs: 1712419200000, + deviceId: "device-a" + }, + { + opId: "op-update-1", + entityType: "TASK", + entityId: "task-1", + action: "UPDATE", + payload: '{"title":"任务一-更新"}', + clientTs: 1712419201000, + deviceId: "device-a" + } + ] + }; + + const firstResponse = await request(app.getHttpServer()) + .post("/sync/push") + .set("x-user-id", "user-1") + .send(payload) + .expect(201); + + expect(firstResponse.body.acceptedCount).toBe(2); + expect(firstResponse.body.duplicateCount).toBe(0); + expect(firstResponse.body.failedCount).toBe(0); + expect(firstResponse.body.results).toEqual([ + expect.objectContaining({ + opId: "op-create-1", + status: "accepted" + }), + expect.objectContaining({ + opId: "op-update-1", + status: "accepted" + }) + ]); + expect(prismaService.getOperationCount()).toBe(2); + expect(prismaService.getRawOperationById("op-create-1")?.payload).not.toBe( + '{"title":"浠诲姟涓€"}' + ); + + const secondResponse = await request(app.getHttpServer()) + .post("/sync/push") + .set("x-user-id", "user-1") + .send(payload) + .expect(201); + + expect(secondResponse.body.acceptedCount).toBe(0); + expect(secondResponse.body.duplicateCount).toBe(2); + expect(secondResponse.body.failedCount).toBe(0); + expect(secondResponse.body.results).toEqual([ + expect.objectContaining({ + opId: "op-create-1", + status: "duplicate", + reason: "already_synced" + }), + expect.objectContaining({ + opId: "op-update-1", + status: "duplicate", + reason: "already_synced" + }) + ]); + expect(prismaService.getOperationCount()).toBe(2); + }); + + it("should mark duplicated op ids in the same batch as duplicate", async () => { + const response = await request(app.getHttpServer()) + .post("/sync/push") + .set("x-user-id", "user-2") + .send({ + operations: [ + { + opId: "op-dup-1", + entityType: "TASK", + entityId: "task-2", + action: "CREATE", + payload: '{"title":"任务二"}', + clientTs: 1712419300000, + deviceId: "device-b" + }, + { + opId: "op-dup-1", + entityType: "TASK", + entityId: "task-2", + action: "UPDATE", + payload: '{"title":"任务二-重复"}', + clientTs: 1712419301000, + deviceId: "device-b" + } + ] + }) + .expect(201); + + expect(response.body.acceptedCount).toBe(1); + expect(response.body.duplicateCount).toBe(1); + expect(response.body.failedCount).toBe(0); + expect(response.body.results[0]).toEqual( + expect.objectContaining({ + opId: "op-dup-1", + status: "accepted" + }) + ); + expect(response.body.results[1]).toEqual( + expect.objectContaining({ + opId: "op-dup-1", + status: "duplicate", + reason: "same_batch_duplicate" + }) + ); + expect(prismaService.getOperationCount()).toBe(3); + }); + + it("should pull operations incrementally with a stable cursor", async () => { + prismaService.seedOperations([ + { + opId: "pull-op-1", + userId: "user-pull", + deviceId: "device-c", + entityType: "TASK", + entityId: "task-10", + action: "CREATE", + payload: '{"title":"任务甲"}', + clientTs: new Date("2026-04-06T10:00:00.000Z"), + serverTs: new Date("2026-04-06T10:10:00.000Z") + }, + { + opId: "pull-op-2", + userId: "user-pull", + deviceId: "device-c", + entityType: "TASK", + entityId: "task-10", + action: "UPDATE", + payload: '{"title":"任务甲-更新"}', + clientTs: new Date("2026-04-06T10:01:00.000Z"), + serverTs: new Date("2026-04-06T10:10:00.000Z") + }, + { + opId: "pull-op-3", + userId: "user-pull", + deviceId: "device-c", + entityType: "TASK", + entityId: "task-11", + action: "CREATE", + payload: '{"title":"任务乙"}', + clientTs: new Date("2026-04-06T10:02:00.000Z"), + serverTs: new Date("2026-04-06T10:11:00.000Z") + }, + { + opId: "pull-op-other-user", + userId: "user-other", + deviceId: "device-z", + entityType: "TASK", + entityId: "task-99", + action: "CREATE", + payload: '{"title":"其他用户任务"}', + clientTs: new Date("2026-04-06T10:03:00.000Z"), + serverTs: new Date("2026-04-06T10:12:00.000Z") + } + ]); + + const firstResponse = await request(app.getHttpServer()) + .get("/sync/pull") + .set("x-user-id", "user-pull") + .query({ limit: 2 }) + .expect(200); + + expect(firstResponse.body.items.map((item: { opId: string }) => item.opId)).toEqual([ + "pull-op-1", + "pull-op-2" + ]); + expect(firstResponse.body.hasMore).toBe(true); + expect(firstResponse.body.nextCursor).toEqual(expect.any(String)); + + const secondResponse = await request(app.getHttpServer()) + .get("/sync/pull") + .set("x-user-id", "user-pull") + .query({ + limit: 2, + cursor: firstResponse.body.nextCursor + }) + .expect(200); + + expect(secondResponse.body.items.map((item: { opId: string }) => item.opId)).toEqual([ + "pull-op-3" + ]); + expect(secondResponse.body.hasMore).toBe(false); + expect(secondResponse.body.nextCursor).toEqual(expect.any(String)); + }); + + it("should reject invalid cursor payload", async () => { + await request(app.getHttpServer()) + .get("/sync/pull") + .set("x-user-id", "user-invalid-cursor") + .query({ + cursor: "not-a-valid-cursor" + }) + .expect(400); + }); +}); diff --git a/apps/api/test/task.spec.ts b/apps/api/test/task.spec.ts new file mode 100644 index 0000000..98b25cd --- /dev/null +++ b/apps/api/test/task.spec.ts @@ -0,0 +1,481 @@ +import request from "supertest"; +import { INestApplication, ValidationPipe } from "@nestjs/common"; +import { ConfigService } from "@nestjs/config"; +import { Test, TestingModule } from "@nestjs/testing"; +import { PrismaService } from "../src/prisma/prisma.service"; +import { DataEncryptionService } from "../src/security/data-encryption.service"; +import { TaskController } from "../src/task/task.controller"; +import { TaskService } from "../src/task/task.service"; +import { TaskPriority, TaskStatus } from "../generated/prisma/client"; + +type TaskRecord = { + id: string; + userId: string; + title: string; + contentJson: unknown | null; + contentText: string | null; + priority: TaskPriority; + status: TaskStatus; + ddl: Date | null; + completedAt: Date | null; + version: number; + createdAt: Date; + updatedAt: Date; +}; + +type TagRecord = { + id: string; + userId: string; + name: string; +}; + +type TaskTagRecord = { + taskId: string; + tagId: string; +}; + +type ListWhereInput = { + userId?: string; + status?: TaskStatus; + priority?: TaskPriority; + taskTags?: { + some: { + tag: { + name: { + in: string[]; + }; + }; + }; + }; + OR?: Array<{ + title?: { + contains: string; + mode?: "insensitive"; + }; + contentText?: { + contains: string; + mode?: "insensitive"; + }; + }>; +}; + +class InMemoryPrismaService { + private taskIdSequence = 1; + private tagIdSequence = 1; + private tasks: TaskRecord[] = []; + private tags: TagRecord[] = []; + private taskTags: TaskTagRecord[] = []; + + readonly task = { + findMany: async (args: { + where?: ListWhereInput; + orderBy?: { createdAt?: "asc" | "desc"; updatedAt?: "asc" | "desc"; ddl?: "asc" | "desc" }; + skip?: number; + take?: number; + }) => { + const where = args.where; + const skip = args.skip ?? 0; + const take = args.take ?? 20; + let filtered = [...this.tasks]; + + if (where?.userId) { + filtered = filtered.filter((task) => task.userId === where.userId); + } + if (where?.status) { + filtered = filtered.filter((task) => task.status === where.status); + } + if (where?.priority) { + filtered = filtered.filter((task) => task.priority === where.priority); + } + if (where?.taskTags?.some.tag.name.in) { + const expectedTags = new Set(where.taskTags.some.tag.name.in); + filtered = filtered.filter((task) => { + const taskTagNames = this.getTaskTagNames(task.id); + return taskTagNames.some((tagName) => expectedTags.has(tagName)); + }); + } + if (where?.OR && where.OR.length > 0) { + filtered = filtered.filter((task) => + where.OR!.some((orCondition) => { + if (orCondition.title?.contains) { + return task.title.toLowerCase().includes(orCondition.title.contains.toLowerCase()); + } + if (orCondition.contentText?.contains) { + return ( + task.contentText + ?.toLowerCase() + .includes(orCondition.contentText.contains.toLowerCase()) ?? false + ); + } + return false; + }) + ); + } + + if (args.orderBy) { + const [orderField, orderDirection] = Object.entries(args.orderBy)[0] as [ + "createdAt" | "updatedAt" | "ddl", + "asc" | "desc" + ]; + filtered.sort((left, right) => { + const leftValue = left[orderField]; + const rightValue = right[orderField]; + + if (leftValue === null && rightValue === null) { + return 0; + } + if (leftValue === null) { + return 1; + } + if (rightValue === null) { + return -1; + } + + const diff = leftValue.getTime() - rightValue.getTime(); + return orderDirection === "asc" ? diff : -diff; + }); + } + + return filtered.slice(skip, skip + take).map((task) => this.toTaskWithTags(task)); + }, + + count: async (args: { where?: ListWhereInput }) => { + const results = await this.task.findMany({ + where: args.where, + skip: 0, + take: Number.MAX_SAFE_INTEGER + }); + return results.length; + }, + + findFirst: async (args: { + where: { + id?: string; + userId?: string; + }; + select?: { + id?: boolean; + status?: boolean; + }; + }) => { + const task = this.tasks.find( + (item) => + (args.where.id === undefined || item.id === args.where.id) && + (args.where.userId === undefined || item.userId === args.where.userId) + ); + if (!task) { + return null; + } + + if (args.select) { + return { + id: args.select.id ? task.id : undefined, + status: args.select.status ? task.status : undefined + }; + } + + return this.toTaskWithTags(task); + }, + + create: async (args: { + data: { + userId: string; + title: string; + contentJson?: unknown; + contentText: string | null; + priority: TaskPriority; + status: TaskStatus; + ddl: Date | null; + completedAt: Date | null; + }; + }) => { + const now = new Date(); + const task: TaskRecord = { + id: `task_${this.taskIdSequence++}`, + userId: args.data.userId, + title: args.data.title, + contentJson: args.data.contentJson ?? null, + contentText: args.data.contentText, + priority: args.data.priority, + status: args.data.status, + ddl: args.data.ddl, + completedAt: args.data.completedAt, + version: 1, + createdAt: now, + updatedAt: now + }; + this.tasks.push(task); + return task; + }, + + update: async (args: { + where: { + id: string; + }; + data: { + title?: string; + contentJson?: unknown; + contentText?: string | null; + priority?: TaskPriority; + status?: TaskStatus; + ddl?: Date | null; + completedAt?: Date | null; + version?: { + increment: number; + }; + }; + }) => { + const task = this.tasks.find((item) => item.id === args.where.id); + if (!task) { + throw new Error("task not found"); + } + + if (args.data.title !== undefined) { + task.title = args.data.title; + } + if (args.data.contentJson !== undefined) { + task.contentJson = args.data.contentJson; + } + if (args.data.contentText !== undefined) { + task.contentText = args.data.contentText; + } + if (args.data.priority !== undefined) { + task.priority = args.data.priority; + } + if (args.data.status !== undefined) { + task.status = args.data.status; + } + if (args.data.ddl !== undefined) { + task.ddl = args.data.ddl; + } + if (args.data.completedAt !== undefined) { + task.completedAt = args.data.completedAt; + } + if (args.data.version !== undefined) { + task.version += args.data.version.increment; + } + task.updatedAt = new Date(); + + return task; + }, + + deleteMany: async (args: { + where: { + id: string; + userId: string; + }; + }) => { + const beforeCount = this.tasks.length; + this.tasks = this.tasks.filter( + (task) => !(task.id === args.where.id && task.userId === args.where.userId) + ); + this.taskTags = this.taskTags.filter((taskTag) => taskTag.taskId !== args.where.id); + return { + count: beforeCount - this.tasks.length + }; + }, + + findUniqueOrThrow: async (args: { + where: { + id: string; + }; + }) => { + const task = this.tasks.find((item) => item.id === args.where.id); + if (!task) { + throw new Error("task not found"); + } + + return this.toTaskWithTags(task); + } + }; + + readonly tag = { + upsert: async (args: { + where: { + userId_name: { + userId: string; + name: string; + }; + }; + create: { + userId: string; + name: string; + }; + }) => { + const existing = this.tags.find( + (tag) => + tag.userId === args.where.userId_name.userId && tag.name === args.where.userId_name.name + ); + if (existing) { + return existing; + } + + const createdTag: TagRecord = { + id: `tag_${this.tagIdSequence++}`, + userId: args.create.userId, + name: args.create.name + }; + this.tags.push(createdTag); + return createdTag; + } + }; + + readonly taskTag = { + deleteMany: async (args: { + where: { + taskId: string; + }; + }) => { + const beforeCount = this.taskTags.length; + this.taskTags = this.taskTags.filter((taskTag) => taskTag.taskId !== args.where.taskId); + return { + count: beforeCount - this.taskTags.length + }; + }, + + createMany: async (args: { + data: Array<{ + taskId: string; + tagId: string; + }>; + }) => { + for (const row of args.data) { + const existing = this.taskTags.find( + (taskTag) => taskTag.taskId === row.taskId && taskTag.tagId === row.tagId + ); + if (!existing) { + this.taskTags.push(row); + } + } + return { + count: args.data.length + }; + } + }; + + async $transaction(runner: (tx: InMemoryPrismaService) => Promise): Promise { + return runner(this); + } + + getRawTaskById(taskId: string): TaskRecord | undefined { + return this.tasks.find((task) => task.id === taskId); + } + + private toTaskWithTags( + task: TaskRecord + ): TaskRecord & { taskTags: Array<{ tag: { name: string } }> } { + return { + ...task, + taskTags: this.taskTags + .filter((taskTag) => taskTag.taskId === task.id) + .map((taskTag) => this.tags.find((tag) => tag.id === taskTag.tagId)) + .filter((tag): tag is TagRecord => tag !== undefined) + .map((tag) => ({ + tag: { + name: tag.name + } + })) + }; + } + + private getTaskTagNames(taskId: string): string[] { + return this.taskTags + .filter((taskTag) => taskTag.taskId === taskId) + .map((taskTag) => this.tags.find((tag) => tag.id === taskTag.tagId)) + .filter((tag): tag is TagRecord => tag !== undefined) + .map((tag) => tag.name); + } +} + +describe("TaskController (integration)", () => { + let app: INestApplication; + const prismaService = new InMemoryPrismaService(); + + beforeAll(async () => { + const moduleRef: TestingModule = await Test.createTestingModule({ + controllers: [TaskController], + providers: [ + TaskService, + DataEncryptionService, + { provide: PrismaService, useValue: prismaService as unknown as PrismaService }, + { + provide: ConfigService, + useValue: { + get: (key: string) => + key === "DATA_ENCRYPTION_SECRET" ? "test-data-encryption-secret" : undefined + } + } + ] + }).compile(); + + app = moduleRef.createNestApplication(); + app.useGlobalPipes( + new ValidationPipe({ + transform: true, + whitelist: true, + forbidNonWhitelisted: true + }) + ); + + await app.init(); + }); + + afterAll(async () => { + await app.close(); + }); + + it("should create, query, update and delete a task", async () => { + const createResponse = await request(app.getHttpServer()) + .post("/tasks") + .set("x-user-id", "user_1") + .send({ + title: "准备周会", + contentText: "整理本周进度", + priority: "HIGH", + tagNames: ["工作", "会议"] + }) + .expect(201); + + expect(createResponse.body.id).toBeDefined(); + expect(createResponse.body.tags).toEqual(["工作", "会议"]); + const taskId = createResponse.body.id as string; + const rawCreatedTask = prismaService.getRawTaskById(taskId); + expect(rawCreatedTask?.title).not.toBe("准备周会"); + expect(rawCreatedTask?.contentText).not.toBe("整理本周进度"); + + const listResponse = await request(app.getHttpServer()) + .get("/tasks") + .set("x-user-id", "user_1") + .query({ tags: "会议" }) + .expect(200); + + expect(listResponse.body.total).toBe(1); + expect(listResponse.body.items[0].id).toBe(taskId); + + const updateResponse = await request(app.getHttpServer()) + .patch(`/tasks/${taskId}`) + .set("x-user-id", "user_1") + .send({ + status: "DONE" + }) + .expect(200); + + expect(updateResponse.body.status).toBe("DONE"); + expect(updateResponse.body.completedAt).toBeTruthy(); + expect(updateResponse.body.version).toBe(2); + + await request(app.getHttpServer()) + .delete(`/tasks/${taskId}`) + .set("x-user-id", "user_1") + .expect(200) + .expect({ + success: true + }); + + const listAfterDeleteResponse = await request(app.getHttpServer()) + .get("/tasks") + .set("x-user-id", "user_1") + .expect(200); + expect(listAfterDeleteResponse.body.total).toBe(0); + }); +}); diff --git a/apps/api/tsconfig.build.json b/apps/api/tsconfig.build.json new file mode 100644 index 0000000..10b2c04 --- /dev/null +++ b/apps/api/tsconfig.build.json @@ -0,0 +1,5 @@ +{ + "$schema": "https://json.schemastore.org/tsconfig", + "extends": "./tsconfig.json", + "exclude": ["node_modules", "dist", "**/*.spec.ts"] +} diff --git a/apps/api/tsconfig.json b/apps/api/tsconfig.json new file mode 100644 index 0000000..f196337 --- /dev/null +++ b/apps/api/tsconfig.json @@ -0,0 +1,10 @@ +{ + "$schema": "https://json.schemastore.org/tsconfig", + "extends": "../../packages/tsconfig/nest-app.json", + "compilerOptions": { + "rootDir": ".", + "outDir": "dist" + }, + "include": ["src/**/*.ts", "scripts/**/*.ts", "generated/prisma/**/*.ts"], + "exclude": ["dist", "node_modules"] +} diff --git a/apps/api/tsconfig.spec.json b/apps/api/tsconfig.spec.json new file mode 100644 index 0000000..e83c972 --- /dev/null +++ b/apps/api/tsconfig.spec.json @@ -0,0 +1,9 @@ +{ + "$schema": "https://json.schemastore.org/tsconfig", + "extends": "./tsconfig.json", + "compilerOptions": { + "types": ["node", "jest"] + }, + "include": ["src/**/*.ts", "generated/prisma/**/*.ts", "test/**/*.ts"], + "exclude": ["dist", "node_modules"] +} diff --git a/apps/web/.gitignore b/apps/web/.gitignore new file mode 100644 index 0000000..a547bf3 --- /dev/null +++ b/apps/web/.gitignore @@ -0,0 +1,24 @@ +# Logs +logs +*.log +npm-debug.log* +yarn-debug.log* +yarn-error.log* +pnpm-debug.log* +lerna-debug.log* + +node_modules +dist +dist-ssr +*.local + +# Editor directories and files +.vscode/* +!.vscode/extensions.json +.idea +.DS_Store +*.suo +*.ntvs* +*.njsproj +*.sln +*.sw? diff --git a/apps/web/README.md b/apps/web/README.md new file mode 100644 index 0000000..15e3b1d --- /dev/null +++ b/apps/web/README.md @@ -0,0 +1,57 @@ +# TodoList Web 前端 + +这是 TodoList 的用户端前端应用(SPA + PWA),基于 `React + TypeScript + Vite`。 + +## 技术栈 + +- React +- TypeScript +- Vite +- Tailwind CSS +- shadcn/ui + +## 本地开发 + +在仓库根目录执行: + +```bash +pnpm install +pnpm --filter web dev +``` + +默认开发地址: + +- `http://localhost:5173` + +## 后端接口地址 + +前端默认请求: + +- `http://localhost:3000` + +如需自定义,请在运行前设置环境变量: + +```bash +VITE_API_BASE_URL=http://localhost:3000 +``` + +## 构建与预览 + +```bash +pnpm --filter web build +pnpm --filter web preview +``` + +## 当前功能进度(阶段性) + +- 邮箱验证码登录页面 +- OAuth 回调页面 +- 会话本地缓存与启动恢复 +- 基础工作台页面骨架 + +## 目录说明 + +- `src/pages`:页面组件 +- `src/components`:通用 UI 组件 +- `src/services`:接口请求与会话处理 +- `src/lib`:工具函数 diff --git a/apps/web/components.json b/apps/web/components.json new file mode 100644 index 0000000..f5ebd86 --- /dev/null +++ b/apps/web/components.json @@ -0,0 +1,25 @@ +{ + "$schema": "https://ui.shadcn.com/schema.json", + "style": "base-nova", + "rsc": false, + "tsx": true, + "tailwind": { + "config": "tailwind.config.js", + "css": "src/index.css", + "baseColor": "neutral", + "cssVariables": true, + "prefix": "" + }, + "iconLibrary": "lucide", + "rtl": false, + "aliases": { + "components": "@/components", + "utils": "@/lib/utils", + "ui": "@/components/ui", + "lib": "@/lib", + "hooks": "@/hooks" + }, + "menuColor": "default", + "menuAccent": "subtle", + "registries": {} +} diff --git a/apps/web/eslint.config.js b/apps/web/eslint.config.js new file mode 100644 index 0000000..99368a6 --- /dev/null +++ b/apps/web/eslint.config.js @@ -0,0 +1,23 @@ +import js from "@eslint/js"; +import globals from "globals"; +import reactHooks from "eslint-plugin-react-hooks"; +import reactRefresh from "eslint-plugin-react-refresh"; +import tseslint from "typescript-eslint"; +import { defineConfig, globalIgnores } from "eslint/config"; + +export default defineConfig([ + globalIgnores(["dist"]), + { + files: ["**/*.{ts,tsx}"], + extends: [ + js.configs.recommended, + tseslint.configs.recommended, + reactHooks.configs.flat.recommended, + reactRefresh.configs.vite + ], + languageOptions: { + ecmaVersion: 2020, + globals: globals.browser + } + } +]); diff --git a/apps/web/index.html b/apps/web/index.html new file mode 100644 index 0000000..da4e35e --- /dev/null +++ b/apps/web/index.html @@ -0,0 +1,14 @@ + + + + + + + + TodoList + + +
+ + + diff --git a/apps/web/package.json b/apps/web/package.json new file mode 100644 index 0000000..efbfcf0 --- /dev/null +++ b/apps/web/package.json @@ -0,0 +1,51 @@ +{ + "name": "web", + "private": true, + "version": "0.0.0", + "type": "module", + "scripts": { + "dev": "vite", + "build": "tsc -b && vite build", + "lint": "eslint .", + "preview": "vite preview" + }, + "dependencies": { + "@base-ui/react": "^1.3.0", + "@fontsource-variable/geist": "^5.2.8", + "@tiptap/core": "^3.22.2", + "@tiptap/extension-image": "^3.22.2", + "@tiptap/extension-link": "^3.22.2", + "@tiptap/extension-youtube": "^3.22.2", + "@tiptap/react": "^3.22.2", + "@tiptap/starter-kit": "^3.22.2", + "browser-image-compression": "^2.0.2", + "class-variance-authority": "^0.7.1", + "clsx": "^2.1.1", + "dexie": "^4.4.2", + "dexie-react-hooks": "^4.4.0", + "lucide-react": "^1.7.0", + "react": "^19.2.4", + "react-dom": "^19.2.4", + "react-router-dom": "^7.14.0", + "shadcn": "^4.1.2", + "tailwind-merge": "^3.5.0", + "tw-animate-css": "^1.4.0" + }, + "devDependencies": { + "@eslint/js": "^9.39.4", + "@types/node": "^24.12.0", + "@types/react": "^19.2.14", + "@types/react-dom": "^19.2.3", + "@vitejs/plugin-react": "^6.0.1", + "autoprefixer": "^10.4.27", + "eslint": "^9.39.4", + "eslint-plugin-react-hooks": "^7.0.1", + "eslint-plugin-react-refresh": "^0.5.2", + "globals": "^17.4.0", + "postcss": "^8.5.8", + "tailwindcss": "^3.4.17", + "typescript": "~5.9.3", + "typescript-eslint": "^8.57.0", + "vite": "^8.0.1" + } +} diff --git a/apps/web/postcss.config.js b/apps/web/postcss.config.js new file mode 100644 index 0000000..ba80730 --- /dev/null +++ b/apps/web/postcss.config.js @@ -0,0 +1,6 @@ +export default { + plugins: { + tailwindcss: {}, + autoprefixer: {} + } +}; diff --git a/apps/web/public/favicon.png b/apps/web/public/favicon.png new file mode 100644 index 0000000..0f8c608 Binary files /dev/null and b/apps/web/public/favicon.png differ diff --git a/apps/web/public/favicon.svg b/apps/web/public/favicon.svg new file mode 100644 index 0000000..6893eb1 --- /dev/null +++ b/apps/web/public/favicon.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/apps/web/public/icons.svg b/apps/web/public/icons.svg new file mode 100644 index 0000000..e952219 --- /dev/null +++ b/apps/web/public/icons.svg @@ -0,0 +1,24 @@ + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/apps/web/src/App.css b/apps/web/src/App.css new file mode 100644 index 0000000..f90339d --- /dev/null +++ b/apps/web/src/App.css @@ -0,0 +1,184 @@ +.counter { + font-size: 16px; + padding: 5px 10px; + border-radius: 5px; + color: var(--accent); + background: var(--accent-bg); + border: 2px solid transparent; + transition: border-color 0.3s; + margin-bottom: 24px; + + &:hover { + border-color: var(--accent-border); + } + &:focus-visible { + outline: 2px solid var(--accent); + outline-offset: 2px; + } +} + +.hero { + position: relative; + + .base, + .framework, + .vite { + inset-inline: 0; + margin: 0 auto; + } + + .base { + width: 170px; + position: relative; + z-index: 0; + } + + .framework, + .vite { + position: absolute; + } + + .framework { + z-index: 1; + top: 34px; + height: 28px; + transform: perspective(2000px) rotateZ(300deg) rotateX(44deg) rotateY(39deg) + scale(1.4); + } + + .vite { + z-index: 0; + top: 107px; + height: 26px; + width: auto; + transform: perspective(2000px) rotateZ(300deg) rotateX(40deg) rotateY(39deg) + scale(0.8); + } +} + +#center { + display: flex; + flex-direction: column; + gap: 25px; + place-content: center; + place-items: center; + flex-grow: 1; + + @media (max-width: 1024px) { + padding: 32px 20px 24px; + gap: 18px; + } +} + +#next-steps { + display: flex; + border-top: 1px solid var(--border); + text-align: left; + + & > div { + flex: 1 1 0; + padding: 32px; + @media (max-width: 1024px) { + padding: 24px 20px; + } + } + + .icon { + margin-bottom: 16px; + width: 22px; + height: 22px; + } + + @media (max-width: 1024px) { + flex-direction: column; + text-align: center; + } +} + +#docs { + border-right: 1px solid var(--border); + + @media (max-width: 1024px) { + border-right: none; + border-bottom: 1px solid var(--border); + } +} + +#next-steps ul { + list-style: none; + padding: 0; + display: flex; + gap: 8px; + margin: 32px 0 0; + + .logo { + height: 18px; + } + + a { + color: var(--text-h); + font-size: 16px; + border-radius: 6px; + background: var(--social-bg); + display: flex; + padding: 6px 12px; + align-items: center; + gap: 8px; + text-decoration: none; + transition: box-shadow 0.3s; + + &:hover { + box-shadow: var(--shadow); + } + .button-icon { + height: 18px; + width: 18px; + } + } + + @media (max-width: 1024px) { + margin-top: 20px; + flex-wrap: wrap; + justify-content: center; + + li { + flex: 1 1 calc(50% - 8px); + } + + a { + width: 100%; + justify-content: center; + box-sizing: border-box; + } + } +} + +#spacer { + height: 88px; + border-top: 1px solid var(--border); + @media (max-width: 1024px) { + height: 48px; + } +} + +.ticks { + position: relative; + width: 100%; + + &::before, + &::after { + content: ''; + position: absolute; + top: -4.5px; + border: 5px solid transparent; + } + + &::before { + left: 0; + border-left-color: var(--border); + } + &::after { + right: 0; + border-right-color: var(--border); + } +} diff --git a/apps/web/src/App.tsx b/apps/web/src/App.tsx new file mode 100644 index 0000000..dd51c3b --- /dev/null +++ b/apps/web/src/App.tsx @@ -0,0 +1,386 @@ +import { useEffect, useState } from "react"; +import type { LucideIcon } from "lucide-react"; +import { + Bell, + ChevronLeft, + ChevronRight, + LayoutDashboard, + ListTodo, + LogOut, + Menu, + Moon, + Settings, + Sparkles, + Sun, + X +} from "lucide-react"; +import { Navigate, Route, Routes, useLocation, useNavigate } from "react-router-dom"; +import { Button } from "@/components/ui/button"; +import { cn } from "@/lib/utils"; +import { AiChatPage } from "@/pages/ai-chat-page"; +import { EmailLoginPage } from "@/pages/email-login-page"; +import { OAuthCallbackPage } from "@/pages/oauth-callback-page"; +import { PlaceholderPage } from "@/pages/placeholder-page"; +import { SettingsPage } from "@/pages/settings-page"; +import { TodoShellPage } from "@/pages/todo-shell-page"; +import { revokeRefreshToken, type EmailLoginResult } from "@/services/auth-api"; +import { + clearSession, + loadSession, + saveSession, + type WebSession +} from "@/services/session-storage"; +import { + applyThemeMode, + loadThemeMode, + saveThemeMode, + type ThemeMode +} from "@/services/theme-storage"; + +type SidebarItem = { + key: string; + label: string; + icon: LucideIcon; + path: string; +}; + +const SIDEBAR_ITEMS: SidebarItem[] = [ + { key: "dashboard", label: "概览面板", icon: LayoutDashboard, path: "/dashboard" }, + { key: "todo", label: "待办事项", icon: ListTodo, path: "/todo" }, + { key: "ai", label: "AI 助手", icon: Sparkles, path: "/ai" }, + { key: "notice", label: "提醒中心", icon: Bell, path: "/notice" }, + { key: "settings", label: "系统设置", icon: Settings, path: "/settings" } +]; + +const READY_SIDEBAR_KEYS = new Set(["todo", "ai", "settings"]); + +function toWebSession(payload: EmailLoginResult): WebSession { + return { + accessToken: payload.accessToken, + refreshToken: payload.refreshToken, + user: { + id: payload.user.id, + email: payload.user.email + } + }; +} + +function App() { + const [session, setSession] = useState(() => loadSession()); + const [loggingOut, setLoggingOut] = useState(false); + const [themeMode, setThemeMode] = useState(() => loadThemeMode()); + const [sidebarCollapsed, setSidebarCollapsed] = useState(false); + const [mobileSidebarOpen, setMobileSidebarOpen] = useState(false); + const navigate = useNavigate(); + const location = useLocation(); + + const isAuthPage = + location.pathname === "/login/email" || location.pathname.startsWith("/auth/callback/"); + + useEffect(() => { + applyThemeMode(themeMode); + saveThemeMode(themeMode); + }, [themeMode]); + + async function handleLogout(): Promise { + if (!session || loggingOut) { + return; + } + + try { + setLoggingOut(true); + await revokeRefreshToken(session.refreshToken); + } catch { + // 无论接口成功与否,都要清理本地会话,避免页面卡在登录态。 + } finally { + clearSession(); + setSession(null); + setLoggingOut(false); + setMobileSidebarOpen(false); + navigate("/login/email", { replace: true }); + } + } + + function handleToggleTheme(): void { + setThemeMode((currentTheme) => (currentTheme === "dark" ? "light" : "dark")); + } + + function handleLoginSuccess(payload: EmailLoginResult): void { + const nextSession = toWebSession(payload); + saveSession(nextSession); + setSession(nextSession); + setMobileSidebarOpen(false); + navigate("/todo", { replace: true }); + } + + function handleBootstrapSession(nextSession: WebSession): void { + setSession(nextSession); + setMobileSidebarOpen(false); + } + + function renderSidebarContent(options: { collapsed: boolean; mobile: boolean }) { + const { collapsed, mobile } = options; + + return ( +
+ {mobile ? ( +
+ +
+ ) : null} + +
+ +
+ +
+ + + +
+
+ ); + } + + if (isAuthPage) { + return ( +
+
+
+ + } + /> + } + /> + } + /> + +
+
+
+ ); + } + + return ( +
+
+
+
+ + TodoList + TodoList +
+ + {session ? session.user.email : "未登录"} + +
+
+ + {mobileSidebarOpen ? ( + + + +
+
+
+ + } + /> + + ) : ( + + ) + } + /> + + ) : ( + + ) + } + /> + + ) : ( + + ) + } + /> + + ) : ( + + ) + } + /> + + ) : ( + + ) + } + /> + } + /> + +
+
+
+
+ + ); +} + +export default App; diff --git a/apps/web/src/assets/hero.png b/apps/web/src/assets/hero.png new file mode 100644 index 0000000..cc51a3d Binary files /dev/null and b/apps/web/src/assets/hero.png differ diff --git a/apps/web/src/assets/react.svg b/apps/web/src/assets/react.svg new file mode 100644 index 0000000..6c87de9 --- /dev/null +++ b/apps/web/src/assets/react.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/apps/web/src/assets/vite.svg b/apps/web/src/assets/vite.svg new file mode 100644 index 0000000..5101b67 --- /dev/null +++ b/apps/web/src/assets/vite.svg @@ -0,0 +1 @@ +Vite diff --git a/apps/web/src/components/ai/ai-shared.ts b/apps/web/src/components/ai/ai-shared.ts new file mode 100644 index 0000000..c0fa0e0 --- /dev/null +++ b/apps/web/src/components/ai/ai-shared.ts @@ -0,0 +1,72 @@ +import type { UpsertWebAiBindingInput, WebAiBindingSummary, WebAiChannel } from "@/services/ai-api"; + +export type AiBindingFormState = { + providerName: string; + model: string; + endpoint: string; + apiKey: string; + configId: string; + configName: string; + isEnabled: boolean; +}; + +export const CHANNEL_ORDER: WebAiChannel[] = ["USER_KEY", "ASTRBOT", "PUBLIC_POOL"]; + +export const CHANNEL_META: Record< + WebAiChannel, + { + title: string; + description: string; + accentClassName: string; + } +> = { + USER_KEY: { + title: "自备厂商", + description: "用户自行接入 OpenAI-Compatible 服务", + accentClassName: "from-sky-500/15 via-transparent to-sky-500/5" + }, + ASTRBOT: { + title: "AstrBot", + description: "复用你在 AstrBot 中维护的模型配置", + accentClassName: "from-amber-500/15 via-transparent to-amber-500/5" + }, + PUBLIC_POOL: { + title: "公共 AI", + description: "使用管理员开放的站点公共通道", + accentClassName: "from-emerald-500/15 via-transparent to-emerald-500/5" + } +}; + +export function createAiBindingFormState(binding?: WebAiBindingSummary | null): AiBindingFormState { + return { + providerName: binding?.providerName ?? "", + model: binding?.model ?? "", + endpoint: binding?.endpoint ?? "", + apiKey: "", + configId: binding?.configId ?? "", + configName: binding?.configName ?? "", + isEnabled: binding?.isEnabled ?? true + }; +} + +export function trimAiOptionalValue(value: string): string | undefined { + const normalized = value.trim(); + return normalized.length > 0 ? normalized : undefined; +} + +export function buildAiBindingPayload( + channel: Exclude, + formState: AiBindingFormState, + currentBinding: WebAiBindingSummary | null +): UpsertWebAiBindingInput { + return { + channel, + providerName: trimAiOptionalValue(formState.providerName), + model: trimAiOptionalValue(formState.model), + endpoint: trimAiOptionalValue(formState.endpoint), + configId: trimAiOptionalValue(formState.configId), + configName: trimAiOptionalValue(formState.configName), + apiKey: trimAiOptionalValue(formState.apiKey) ?? undefined, + isEnabled: formState.isEnabled ?? currentBinding?.isEnabled ?? true + }; +} diff --git a/apps/web/src/components/editor/resizable-media-node-view.tsx b/apps/web/src/components/editor/resizable-media-node-view.tsx new file mode 100644 index 0000000..e0c0874 --- /dev/null +++ b/apps/web/src/components/editor/resizable-media-node-view.tsx @@ -0,0 +1,283 @@ +import { useEffect, useRef, useState } from "react"; +import { NodeViewWrapper, type NodeViewProps } from "@tiptap/react"; +import { cn } from "@/lib/utils"; + +type MediaAlign = "left" | "center" | "right"; +type MediaKind = "image" | "video" | "youtube"; +type ResizeSide = "left" | "right"; + +type ResizableMediaNodeViewProps = NodeViewProps & { + mediaKind: MediaKind; +}; + +type HandleDescriptor = { + key: string; + side: ResizeSide; + className: string; +}; + +const HANDLE_DESCRIPTORS: HandleDescriptor[] = [ + { + key: "top-left", + side: "left", + className: "-left-1.5 -top-1.5 cursor-ew-resize" + }, + { + key: "bottom-left", + side: "left", + className: "-bottom-1.5 -left-1.5 cursor-ew-resize" + }, + { + key: "top-right", + side: "right", + className: "-right-1.5 -top-1.5 cursor-ew-resize" + }, + { + key: "bottom-right", + side: "right", + className: "-bottom-1.5 -right-1.5 cursor-ew-resize" + } +]; + +function clamp(value: number, min: number, max: number): number { + return Math.min(Math.max(value, min), max); +} + +function readWidthPercent(value: unknown): number { + const numericValue = typeof value === "number" ? value : Number(value); + + if (Number.isNaN(numericValue)) { + return 100; + } + + return clamp(numericValue, 25, 100); +} + +function readAlign(value: unknown): MediaAlign { + if (value === "left" || value === "right" || value === "center") { + return value; + } + + return "center"; +} + +function resolveAlignClass(align: MediaAlign): string { + if (align === "left") { + return "mr-auto"; + } + + if (align === "right") { + return "ml-auto"; + } + + return "mx-auto"; +} + +function isStringValue(value: unknown): value is string { + return typeof value === "string" && value.trim().length > 0; +} + +export function ResizableMediaNodeView({ + editor, + getPos, + mediaKind, + node, + selected, + updateAttributes +}: ResizableMediaNodeViewProps) { + const [isResizing, setIsResizing] = useState(false); + const mediaFrameRef = useRef(null); + const cleanupResizeRef = useRef<(() => void) | null>(null); + + const widthPercent = readWidthPercent(node.attrs.widthPercent); + const align = readAlign(node.attrs.align); + const src = isStringValue(node.attrs.src) ? node.attrs.src : ""; + const alt = isStringValue(node.attrs.alt) ? node.attrs.alt : ""; + const title = isStringValue(node.attrs.title) ? node.attrs.title : ""; + const showControls = selected || isResizing; + + useEffect(() => { + return () => { + cleanupResizeRef.current?.(); + }; + }, []); + + function selectCurrentNode(): void { + const position = getPos(); + + if (typeof position !== "number") { + return; + } + + editor.chain().focus().setNodeSelection(position).run(); + } + + function applyAlign(nextAlign: MediaAlign): void { + selectCurrentNode(); + updateAttributes({ align: nextAlign }); + } + + function startResize(side: ResizeSide) { + return (event: React.PointerEvent): void => { + event.preventDefault(); + event.stopPropagation(); + + selectCurrentNode(); + + const mediaFrame = mediaFrameRef.current; + const editorRoot = mediaFrame?.closest(".ProseMirror") as HTMLElement | null; + + if (!mediaFrame || !editorRoot) { + return; + } + + const startX = event.clientX; + const startWidth = mediaFrame.getBoundingClientRect().width; + const maxWidth = Math.max(editorRoot.clientWidth - 24, 240); + + const handlePointerMove = (moveEvent: PointerEvent): void => { + const delta = moveEvent.clientX - startX; + const resizedWidth = side === "right" ? startWidth + delta : startWidth - delta; + const nextWidth = clamp(resizedWidth, 180, maxWidth); + const nextWidthPercent = clamp((nextWidth / maxWidth) * 100, 25, 100); + + updateAttributes({ + widthPercent: Math.round(nextWidthPercent) + }); + }; + + const handlePointerUp = (): void => { + cleanupResizeRef.current?.(); + cleanupResizeRef.current = null; + setIsResizing(false); + }; + + cleanupResizeRef.current = () => { + window.removeEventListener("pointermove", handlePointerMove); + window.removeEventListener("pointerup", handlePointerUp); + }; + + window.addEventListener("pointermove", handlePointerMove); + window.addEventListener("pointerup", handlePointerUp, { once: true }); + setIsResizing(true); + }; + } + + function renderMediaContent() { + if (mediaKind === "image") { + return ( + {alt} + ); + } + + if (mediaKind === "youtube") { + return ( +
+