diff --git a/apps/api/.env.example b/apps/api/.env.example index 971b829..c8045bf 100644 --- a/apps/api/.env.example +++ b/apps/api/.env.example @@ -63,3 +63,11 @@ MAIL_SMTP_PASS="replace-with-smtp-password" # 发件人显示名称与地址 MAIL_FROM_NAME="TodoList" MAIL_FROM_ADDRESS="no-reply@example.com" + +# [数据加密] 服务端敏感数据加密主密钥 +# 用于加密 AI 配置、任务内容、同步 payload、附件元数据等数据库字段 +# 请使用高强度随机字符串,生产环境务必单独保管 +DATA_ENCRYPTION_SECRET="replace-with-a-long-random-secret" + +# [对象存储加密] 服务端对象加密策略,默认使用 AES256;如需关闭可填写 NONE +S3_SERVER_SIDE_ENCRYPTION="AES256" diff --git a/apps/api/package.json b/apps/api/package.json index f2d2dff..cac4457 100644 --- a/apps/api/package.json +++ b/apps/api/package.json @@ -3,12 +3,13 @@ "version": "0.1.0", "description": "TodoList API service", "scripts": { - "prisma:generate": "prisma generate", + "prisma:generate": "node -e \"require('node:fs').rmSync('generated/prisma', { recursive: true, force: true })\" && prisma generate", "prisma:format": "prisma format", "prisma:validate": "prisma validate", "prebuild": "pnpm run prisma:generate", "pretypecheck": "pnpm run prisma:generate", "pretest": "pnpm run prisma:generate", + "data:reencrypt": "node -e \"require('node:fs').rmSync('.tmp-compile', { recursive: true, force: true })\" && tsc -p tsconfig.json --outDir .tmp-compile --noEmit false && node .tmp-compile/scripts/reencrypt-sensitive-data.js && node -e \"require('node:fs').rmSync('.tmp-compile', { recursive: true, force: true })\"", "start": "node dist/main.js", "start:dev": "ts-node-dev --respawn --transpile-only src/main.ts", "build": "tsc -p tsconfig.build.json", diff --git a/apps/api/prisma/schema.prisma b/apps/api/prisma/schema.prisma index 838df48..b6cfa31 100644 --- a/apps/api/prisma/schema.prisma +++ b/apps/api/prisma/schema.prisma @@ -63,7 +63,8 @@ enum NotificationStatus { model User { id String @id @default(cuid()) - email String @unique + email String + emailHash String? @unique nickname String? avatarUrl String? status UserStatus @default(ACTIVE) @@ -97,11 +98,13 @@ model AuthIdentity { provider AuthProvider providerUserId String email String? + emailHash String? createdAt DateTime @default(now()) updatedAt DateTime @updatedAt user User @relation(fields: [userId], references: [id], onDelete: Cascade) @@unique([provider, providerUserId]) + @@index([emailHash]) @@index([userId]) @@map("auth_identities") } @@ -273,6 +276,8 @@ model AiProviderBinding { channel AiChannel providerName String model String? + configId String? + configName String? encryptedApiKey String? endpoint String? isDefault Boolean @default(false) diff --git a/apps/api/scripts/reencrypt-sensitive-data.ts b/apps/api/scripts/reencrypt-sensitive-data.ts new file mode 100644 index 0000000..af39197 --- /dev/null +++ b/apps/api/scripts/reencrypt-sensitive-data.ts @@ -0,0 +1,418 @@ +import "dotenv/config"; +import { PrismaPg } from "@prisma/adapter-pg"; +import { ConfigService } from "@nestjs/config"; +import { Prisma, PrismaClient } from "../generated/prisma/client"; +import { DataEncryptionService } from "../src/security/data-encryption.service"; + +type MigrationCounter = Record< + | "users" + | "authIdentities" + | "aiBindings" + | "publicPools" + | "aiUsageLogs" + | "tasks" + | "attachments" + | "syncOperations", + number +>; + +function createEncryptionService(): DataEncryptionService { + const configService = { + get: (key: string) => process.env[key] + } as ConfigService; + + return new DataEncryptionService(configService); +} + +function encryptStringIfNeeded( + value: string | null, + dataEncryptionService: DataEncryptionService +): string | null | undefined { + if (value === null || dataEncryptionService.isEncryptedString(value)) { + return undefined; + } + + return dataEncryptionService.encryptString(value) ?? null; +} + +function assignRequiredEncryptedString, K extends keyof T>( + target: T, + key: K, + value: string | null | undefined +): void { + if (typeof value === "string") { + target[key] = value as T[K]; + } +} + +function assignOptionalEncryptedString, K extends keyof T>( + target: T, + key: K, + value: string | null | undefined +): void { + if (value !== undefined) { + target[key] = value as T[K]; + } +} + +function encryptJsonIfNeeded( + value: Prisma.JsonValue | null, + dataEncryptionService: DataEncryptionService +): Prisma.InputJsonValue | Prisma.NullableJsonNullValueInput | undefined { + if (value === null) { + return undefined; + } + + if (typeof value === "string" && dataEncryptionService.isEncryptedString(value)) { + return undefined; + } + + return (dataEncryptionService.encryptJson(value as Prisma.InputJsonValue) ?? Prisma.JsonNull) as + | Prisma.InputJsonValue + | Prisma.NullableJsonNullValueInput; +} + +function resolvePlainString( + value: string | null, + dataEncryptionService: DataEncryptionService +): string | null { + if (value === null) { + return null; + } + + return dataEncryptionService.isEncryptedString(value) + ? (dataEncryptionService.decryptString(value) ?? null) + : value; +} + +async function main(): Promise { + if (!process.env["DATABASE_URL"]) { + throw new Error("缺少 DATABASE_URL,无法执行敏感数据迁移"); + } + + if (!process.env["DATA_ENCRYPTION_SECRET"]) { + throw new Error("缺少 DATA_ENCRYPTION_SECRET,无法执行敏感数据迁移"); + } + + const prisma = new PrismaClient({ + adapter: new PrismaPg({ + connectionString: process.env["DATABASE_URL"] + }) + }); + const dataEncryptionService = createEncryptionService(); + const counter: MigrationCounter = { + users: 0, + authIdentities: 0, + aiBindings: 0, + publicPools: 0, + aiUsageLogs: 0, + tasks: 0, + attachments: 0, + syncOperations: 0 + }; + + try { + const users = await prisma.user.findMany({ + select: { + id: true, + email: true, + emailHash: true, + nickname: true, + avatarUrl: true + } + }); + + for (const user of users) { + const normalizedEmail = resolvePlainString(user.email, dataEncryptionService)?.toLowerCase(); + if (!normalizedEmail) { + continue; + } + const nextEmailHash = dataEncryptionService.createLookupHash("user.email", normalizedEmail); + const data: Prisma.UserUpdateInput = {}; + const email = encryptStringIfNeeded(user.email, dataEncryptionService); + const nickname = encryptStringIfNeeded(user.nickname, dataEncryptionService); + const avatarUrl = encryptStringIfNeeded(user.avatarUrl, dataEncryptionService); + + assignRequiredEncryptedString(data, "email", email); + if (user.emailHash !== nextEmailHash) { + data.emailHash = nextEmailHash; + } + assignOptionalEncryptedString(data, "nickname", nickname); + assignOptionalEncryptedString(data, "avatarUrl", avatarUrl); + + if (Object.keys(data).length === 0) { + continue; + } + + await prisma.user.update({ + where: { + id: user.id + }, + data + }); + counter.users += 1; + } + + const authIdentities = await prisma.authIdentity.findMany({ + select: { + id: true, + email: true, + emailHash: true + } + }); + + for (const authIdentity of authIdentities) { + const data: Prisma.AuthIdentityUpdateInput = {}; + const email = encryptStringIfNeeded(authIdentity.email, dataEncryptionService); + const normalizedIdentityEmail = resolvePlainString(authIdentity.email, dataEncryptionService); + const nextEmailHash = + normalizedIdentityEmail === null + ? null + : dataEncryptionService.createLookupHash( + "auth_identity.email", + normalizedIdentityEmail.toLowerCase() + ); + + assignOptionalEncryptedString(data, "email", email); + if (authIdentity.emailHash !== nextEmailHash) { + data.emailHash = nextEmailHash; + } + + if (Object.keys(data).length === 0) { + continue; + } + + await prisma.authIdentity.update({ + where: { + id: authIdentity.id + }, + data + }); + counter.authIdentities += 1; + } + + const aiBindings = await prisma.aiProviderBinding.findMany({ + select: { + id: true, + providerName: true, + model: true, + configId: true, + configName: true, + endpoint: true, + encryptedApiKey: true + } + }); + + for (const binding of aiBindings) { + const data: Prisma.AiProviderBindingUpdateInput = {}; + const providerName = encryptStringIfNeeded(binding.providerName, dataEncryptionService); + const model = encryptStringIfNeeded(binding.model, dataEncryptionService); + const configId = encryptStringIfNeeded(binding.configId, dataEncryptionService); + const configName = encryptStringIfNeeded(binding.configName, dataEncryptionService); + const endpoint = encryptStringIfNeeded(binding.endpoint, dataEncryptionService); + const encryptedApiKey = encryptStringIfNeeded(binding.encryptedApiKey, dataEncryptionService); + + assignRequiredEncryptedString(data, "providerName", providerName); + assignOptionalEncryptedString(data, "model", model); + assignOptionalEncryptedString(data, "configId", configId); + assignOptionalEncryptedString(data, "configName", configName); + assignOptionalEncryptedString(data, "endpoint", endpoint); + assignOptionalEncryptedString(data, "encryptedApiKey", encryptedApiKey); + + if (Object.keys(data).length === 0) { + continue; + } + + await prisma.aiProviderBinding.update({ + where: { + id: binding.id + }, + data + }); + counter.aiBindings += 1; + } + + const publicPools = await prisma.aiPublicPoolConfig.findMany({ + select: { + id: true, + providerName: true, + model: true, + endpoint: true, + encryptedApiKey: true + } + }); + + for (const publicPool of publicPools) { + const data: Prisma.AiPublicPoolConfigUpdateInput = {}; + const providerName = encryptStringIfNeeded(publicPool.providerName, dataEncryptionService); + const model = encryptStringIfNeeded(publicPool.model, dataEncryptionService); + const endpoint = encryptStringIfNeeded(publicPool.endpoint, dataEncryptionService); + const encryptedApiKey = encryptStringIfNeeded( + publicPool.encryptedApiKey, + dataEncryptionService + ); + + assignOptionalEncryptedString(data, "providerName", providerName); + assignOptionalEncryptedString(data, "model", model); + assignOptionalEncryptedString(data, "endpoint", endpoint); + assignOptionalEncryptedString(data, "encryptedApiKey", encryptedApiKey); + + if (Object.keys(data).length === 0) { + continue; + } + + await prisma.aiPublicPoolConfig.update({ + where: { + id: publicPool.id + }, + data + }); + counter.publicPools += 1; + } + + const aiUsageLogs = await prisma.aiUsageLog.findMany({ + select: { + id: true, + providerName: true, + model: true + } + }); + + for (const aiUsageLog of aiUsageLogs) { + const data: Prisma.AiUsageLogUpdateInput = {}; + const providerName = encryptStringIfNeeded(aiUsageLog.providerName, dataEncryptionService); + const model = encryptStringIfNeeded(aiUsageLog.model, dataEncryptionService); + + assignOptionalEncryptedString(data, "providerName", providerName); + assignOptionalEncryptedString(data, "model", model); + + if (Object.keys(data).length === 0) { + continue; + } + + await prisma.aiUsageLog.update({ + where: { + id: aiUsageLog.id + }, + data + }); + counter.aiUsageLogs += 1; + } + + const tasks = await prisma.task.findMany({ + select: { + id: true, + title: true, + contentJson: true, + contentText: true + } + }); + + for (const task of tasks) { + const data: Prisma.TaskUpdateInput = {}; + const title = encryptStringIfNeeded(task.title, dataEncryptionService); + const contentJson = encryptJsonIfNeeded(task.contentJson, dataEncryptionService); + const contentText = encryptStringIfNeeded(task.contentText, dataEncryptionService); + + assignRequiredEncryptedString(data, "title", title); + if (contentJson !== undefined) { + data.contentJson = contentJson; + } + assignOptionalEncryptedString(data, "contentText", contentText); + + if (Object.keys(data).length === 0) { + continue; + } + + await prisma.task.update({ + where: { + id: task.id + }, + data + }); + counter.tasks += 1; + } + + const attachments = await prisma.attachment.findMany({ + select: { + id: true, + url: true, + fileName: true, + checksum: true + } + }); + + for (const attachment of attachments) { + const data: Prisma.AttachmentUpdateInput = {}; + const url = encryptStringIfNeeded(attachment.url, dataEncryptionService); + const fileName = encryptStringIfNeeded(attachment.fileName, dataEncryptionService); + const checksum = encryptStringIfNeeded(attachment.checksum, dataEncryptionService); + + assignRequiredEncryptedString(data, "url", url); + assignOptionalEncryptedString(data, "fileName", fileName); + assignOptionalEncryptedString(data, "checksum", checksum); + + if (Object.keys(data).length === 0) { + continue; + } + + await prisma.attachment.update({ + where: { + id: attachment.id + }, + data + }); + counter.attachments += 1; + } + + const syncOperations = await prisma.syncOperation.findMany({ + select: { + id: true, + payload: true + } + }); + + for (const operation of syncOperations) { + if (operation.payload === null) { + continue; + } + + let nextPayload: string | null = null; + if (typeof operation.payload === "string") { + if (dataEncryptionService.isEncryptedString(operation.payload)) { + continue; + } + + nextPayload = dataEncryptionService.encryptString(operation.payload) ?? null; + } else { + nextPayload = + dataEncryptionService.encryptString(JSON.stringify(operation.payload)) ?? null; + } + + if (nextPayload === null) { + continue; + } + + await prisma.syncOperation.update({ + where: { + id: operation.id + }, + data: { + payload: nextPayload + } + }); + counter.syncOperations += 1; + } + + console.log("敏感数据迁移完成"); + console.log(JSON.stringify(counter, null, 2)); + } finally { + await prisma.$disconnect(); + } +} + +void main().catch((error: unknown) => { + const message = error instanceof Error ? error.message : "未知错误"; + console.error(`敏感数据迁移失败:${message}`); + process.exitCode = 1; +}); diff --git a/apps/api/src/ai/ai-provider-registry.service.ts b/apps/api/src/ai/ai-provider-registry.service.ts new file mode 100644 index 0000000..54387a9 --- /dev/null +++ b/apps/api/src/ai/ai-provider-registry.service.ts @@ -0,0 +1,28 @@ +import { Injectable } from "@nestjs/common"; +import { AiChannel } from "../../generated/prisma/client"; +import { AstrbotProvider } from "./providers/astrbot.provider"; +import { OpenAiCompatibleProvider } from "./providers/openai-compatible.provider"; +import { AiChannelExecutor } from "./ai.types"; + +@Injectable() +export class AiProviderRegistryService { + private readonly executors = new Map(); + + constructor( + openAiCompatibleProvider: OpenAiCompatibleProvider, + astrbotProvider: AstrbotProvider + ) { + this.executors.set(AiChannel.USER_KEY, openAiCompatibleProvider); + this.executors.set(AiChannel.PUBLIC_POOL, openAiCompatibleProvider); + this.executors.set(AiChannel.ASTRBOT, astrbotProvider); + } + + getExecutor(channel: AiChannel): AiChannelExecutor { + const executor = this.executors.get(channel); + if (!executor) { + throw new Error(`未找到 ${channel} 对应的 AI 通道执行器`); + } + + return executor; + } +} diff --git a/apps/api/src/ai/ai-rate-limit.service.ts b/apps/api/src/ai/ai-rate-limit.service.ts new file mode 100644 index 0000000..e84b2de --- /dev/null +++ b/apps/api/src/ai/ai-rate-limit.service.ts @@ -0,0 +1,123 @@ +import { Injectable } from "@nestjs/common"; +import { ConfigService } from "@nestjs/config"; + +type AiRateLimitBucket = { + count: number; + resetAt: number; +}; + +export type AiRateLimitResult = + | { + allowed: true; + } + | { + allowed: false; + reason: "USER" | "IP"; + retryAfterMs: number; + limit: number; + windowMs: number; + }; + +@Injectable() +export class AiRateLimitService { + private readonly userBuckets = new Map(); + private readonly ipBuckets = new Map(); + private readonly windowMs: number; + private readonly userLimit: number; + private readonly ipLimit: number; + + constructor(private readonly configService: ConfigService) { + this.windowMs = this.readPositiveInt("AI_RATE_LIMIT_WINDOW_MS", 60_000); + this.userLimit = this.readPositiveInt("AI_RATE_LIMIT_USER_MAX", 20); + this.ipLimit = this.readPositiveInt("AI_RATE_LIMIT_IP_MAX", 60); + } + + consume(userId: string, clientIp: string | null): AiRateLimitResult { + const now = Date.now(); + const userBucket = this.getBucket(this.userBuckets, userId, now); + if (userBucket.count >= this.userLimit) { + return { + allowed: false, + reason: "USER", + retryAfterMs: Math.max(0, userBucket.resetAt - now), + limit: this.userLimit, + windowMs: this.windowMs + }; + } + + const normalizedIp = this.normalizeIp(clientIp); + const ipBucket = normalizedIp ? this.getBucket(this.ipBuckets, normalizedIp, now) : null; + if (ipBucket && ipBucket.count >= this.ipLimit) { + return { + allowed: false, + reason: "IP", + retryAfterMs: Math.max(0, ipBucket.resetAt - now), + limit: this.ipLimit, + windowMs: this.windowMs + }; + } + + userBucket.count += 1; + if (ipBucket) { + ipBucket.count += 1; + } + + this.cleanupExpiredBuckets(this.userBuckets, now); + this.cleanupExpiredBuckets(this.ipBuckets, now); + + return { + allowed: true + }; + } + + private getBucket( + buckets: Map, + key: string, + now: number + ): AiRateLimitBucket { + const currentBucket = buckets.get(key); + if (!currentBucket || now >= currentBucket.resetAt) { + const nextBucket: AiRateLimitBucket = { + count: 0, + resetAt: now + this.windowMs + }; + buckets.set(key, nextBucket); + return nextBucket; + } + + return currentBucket; + } + + private cleanupExpiredBuckets(buckets: Map, now: number): void { + if (buckets.size <= 256) { + return; + } + + for (const [key, bucket] of buckets.entries()) { + if (now >= bucket.resetAt) { + buckets.delete(key); + } + } + } + + private normalizeIp(clientIp: string | null): string | null { + if (!clientIp) { + return null; + } + + const normalizedIp = clientIp.trim(); + return normalizedIp.length > 0 ? normalizedIp : null; + } + + private readPositiveInt(key: string, fallbackValue: number): number { + const rawValue = this.configService.get(key); + const parsedValue = + typeof rawValue === "number" ? rawValue : Number.parseInt(String(rawValue ?? ""), 10); + + if (!Number.isFinite(parsedValue) || parsedValue <= 0) { + return fallbackValue; + } + + return parsedValue; + } +} diff --git a/apps/api/src/ai/ai.controller.ts b/apps/api/src/ai/ai.controller.ts new file mode 100644 index 0000000..8756e08 --- /dev/null +++ b/apps/api/src/ai/ai.controller.ts @@ -0,0 +1,74 @@ +import { + Body, + Controller, + Get, + Headers, + Ip, + Post, + Query, + UnauthorizedException +} from "@nestjs/common"; +import { AiChatDto } from "./dto/ai-chat.dto"; +import { ListAiUsageLogsQueryDto } from "./dto/list-ai-usage-logs-query.dto"; +import { UpsertAiProviderBindingDto } from "./dto/upsert-ai-provider-binding.dto"; +import { + AiChatResponse, + AiService, + ListAiBindingsResponse, + ListAiUsageLogsResponse, + TestAiBindingResponse +} from "./ai.service"; + +@Controller("ai") +export class AiController { + constructor(private readonly aiService: AiService) {} + + @Get("bindings") + async listBindings( + @Headers("x-user-id") userIdHeader: string | string[] | undefined + ): Promise { + return this.aiService.listBindings(this.resolveUserId(userIdHeader)); + } + + @Get("usage-logs") + async listUsageLogs( + @Headers("x-user-id") userIdHeader: string | string[] | undefined, + @Query() query: ListAiUsageLogsQueryDto + ): Promise { + return this.aiService.listUsageLogs(this.resolveUserId(userIdHeader), query); + } + + @Post("bindings") + async upsertBinding( + @Headers("x-user-id") userIdHeader: string | string[] | undefined, + @Body() body: UpsertAiProviderBindingDto + ) { + return this.aiService.upsertBinding(this.resolveUserId(userIdHeader), body); + } + + @Post("bindings/test") + async testBinding( + @Headers("x-user-id") userIdHeader: string | string[] | undefined, + @Body() body: UpsertAiProviderBindingDto + ): Promise { + return this.aiService.testBinding(this.resolveUserId(userIdHeader), body); + } + + @Post("chat") + async chat( + @Headers("x-user-id") userIdHeader: string | string[] | undefined, + @Ip() clientIp: string, + @Body() body: AiChatDto + ): Promise { + return this.aiService.chat(this.resolveUserId(userIdHeader), body, clientIp); + } + + private resolveUserId(userIdHeader: string | string[] | undefined): string { + const userId = Array.isArray(userIdHeader) ? userIdHeader[0] : userIdHeader; + if (!userId) { + throw new UnauthorizedException("缺少用户上下文"); + } + + return userId; + } +} diff --git a/apps/api/src/ai/ai.module.ts b/apps/api/src/ai/ai.module.ts new file mode 100644 index 0000000..a17544a --- /dev/null +++ b/apps/api/src/ai/ai.module.ts @@ -0,0 +1,21 @@ +import { Module } from "@nestjs/common"; +import { PrismaModule } from "../prisma/prisma.module"; +import { AiRateLimitService } from "./ai-rate-limit.service"; +import { AiController } from "./ai.controller"; +import { AiProviderRegistryService } from "./ai-provider-registry.service"; +import { AiService } from "./ai.service"; +import { AstrbotProvider } from "./providers/astrbot.provider"; +import { OpenAiCompatibleProvider } from "./providers/openai-compatible.provider"; + +@Module({ + imports: [PrismaModule], + controllers: [AiController], + providers: [ + AiService, + AiRateLimitService, + AiProviderRegistryService, + OpenAiCompatibleProvider, + AstrbotProvider + ] +}) +export class AiModule {} diff --git a/apps/api/src/ai/ai.service.ts b/apps/api/src/ai/ai.service.ts new file mode 100644 index 0000000..25cbf61 --- /dev/null +++ b/apps/api/src/ai/ai.service.ts @@ -0,0 +1,988 @@ +import { + BadGatewayException, + BadRequestException, + HttpException, + HttpStatus, + Injectable, + Logger +} from "@nestjs/common"; +import { + AiChannel, + AiUsageLog, + AiProviderBinding, + AiPublicPoolConfig, + Prisma, + TaskPriority, + TaskStatus +} from "../../generated/prisma/client"; +import { PrismaService } from "../prisma/prisma.service"; +import { DataEncryptionService } from "../security/data-encryption.service"; +import { AiRateLimitService } from "./ai-rate-limit.service"; +import { AiProviderRegistryService } from "./ai-provider-registry.service"; +import { AiChatDto } from "./dto/ai-chat.dto"; +import { ListAiUsageLogsQueryDto } from "./dto/list-ai-usage-logs-query.dto"; +import { UpsertAiProviderBindingDto } from "./dto/upsert-ai-provider-binding.dto"; +import { + AiResolvedRouteCandidate, + AiRouteAttempt, + AiRouteFailureError, + AiUsageMetrics +} from "./ai.types"; + +type AiBindingSummary = { + id: string; + channel: AiChannel; + providerName: string; + model: string | null; + configId: string | null; + configName: string | null; + endpoint: string | null; + isEnabled: boolean; + hasApiKey: boolean; + maskedApiKey: string | null; + updatedAt: string; +}; + +type AiRoutePlanEntry = + | { + kind: "candidate"; + candidate: AiResolvedRouteCandidate; + } + | { + kind: "skip"; + attempt: AiRouteAttempt; + }; + +export type ListAiBindingsResponse = { + routeOrder: AiChannel[]; + bindings: AiBindingSummary[]; + publicPool: { + enabled: boolean; + providerName: string | null; + model: string | null; + hasApiKey: boolean; + } | null; +}; + +type AiUsageLogSummary = { + id: string; + channel: AiChannel; + providerName: string | null; + model: string | null; + promptTokens: number; + completionTokens: number; + totalTokens: number; + latencyMs: number | null; + success: boolean; + errorCode: string | null; + createdAt: string; +}; + +type AiContextTaskItem = { + id: string; + title: string; + priority: TaskPriority; + status: TaskStatus; + ddl: Date | null; + contentText: string | null; + updatedAt: Date; +}; + +export type ListAiUsageLogsResponse = { + items: AiUsageLogSummary[]; + page: number; + pageSize: number; + total: number; +}; + +export type AiChatResponse = { + channel: AiChannel; + providerName: string; + model: string | null; + content: string; + sessionId: string | null; + attempts: AiRouteAttempt[]; +}; + +export type TestAiBindingResponse = + | { + success: true; + channel: AiChannel; + providerName: string; + model: string | null; + contentPreview: string; + } + | { + success: false; + channel: AiChannel; + providerName: string; + model: string | null; + code: string; + message: string; + }; + +@Injectable() +export class AiService { + private readonly logger = new Logger(AiService.name); + private readonly maxContextTasks = 6; + private readonly maxContextContentLength = 80; + + constructor( + private readonly prismaService: PrismaService, + private readonly aiProviderRegistryService: AiProviderRegistryService, + private readonly dataEncryptionService: DataEncryptionService, + private readonly aiRateLimitService: AiRateLimitService + ) {} + + async listBindings(userId: string): Promise { + const [bindings, publicPool] = await Promise.all([ + this.prismaService.aiProviderBinding.findMany({ + where: { + userId + }, + orderBy: [{ updatedAt: "desc" }] + }), + this.prismaService.aiPublicPoolConfig.findFirst({ + orderBy: { + updatedAt: "desc" + } + }) + ]); + + const latestBindings = this.pickLatestBindingsByChannel(bindings); + + return { + routeOrder: [AiChannel.USER_KEY, AiChannel.ASTRBOT, AiChannel.PUBLIC_POOL], + bindings: latestBindings.map((binding) => this.serializeBinding(binding)), + publicPool: publicPool + ? { + enabled: publicPool.enabled, + providerName: this.readDecryptedString(publicPool.providerName), + model: this.readDecryptedString(publicPool.model), + hasApiKey: Boolean(publicPool.encryptedApiKey) + } + : null + }; + } + + async listUsageLogs( + userId: string, + query: ListAiUsageLogsQueryDto + ): Promise { + const page = query.page ?? 1; + const pageSize = query.pageSize ?? 20; + const skip = (page - 1) * pageSize; + const where: Prisma.AiUsageLogWhereInput = { + userId + }; + + if (query.channel) { + where.channel = query.channel; + } + + if (query.success !== undefined) { + where.success = query.success; + } + + const [items, total] = await Promise.all([ + this.prismaService.aiUsageLog.findMany({ + where, + orderBy: { + createdAt: "desc" + }, + skip, + take: pageSize + }), + this.prismaService.aiUsageLog.count({ + where + }) + ]); + + return { + items: items.map((item) => this.serializeUsageLog(item)), + page, + pageSize, + total + }; + } + + async upsertBinding(userId: string, dto: UpsertAiProviderBindingDto): Promise { + if (dto.channel === AiChannel.PUBLIC_POOL) { + throw new BadRequestException("公共 AI 通道只能由管理员配置"); + } + + this.validateBindingInput(dto); + + const result = await this.prismaService.$transaction(async (tx) => { + const existingBinding = await tx.aiProviderBinding.findFirst({ + where: { + userId, + channel: dto.channel + }, + orderBy: { + updatedAt: "desc" + } + }); + + if (!existingBinding) { + return tx.aiProviderBinding.create({ + data: { + userId, + channel: dto.channel, + providerName: this.encryptRequiredString(this.normalizeProviderName(dto.providerName)), + model: this.encryptOptionalString(dto.model), + configId: this.encryptOptionalString(dto.configId), + configName: this.encryptOptionalString(dto.configName), + endpoint: this.encryptOptionalString(dto.endpoint), + encryptedApiKey: this.encryptOptionalString(dto.apiKey), + isEnabled: dto.isEnabled ?? true + } + }); + } + + const updateData: Prisma.AiProviderBindingUpdateInput = { + channel: dto.channel, + providerName: this.encryptRequiredString(this.normalizeProviderName(dto.providerName)), + model: this.encryptOptionalString(dto.model), + configId: this.encryptOptionalString(dto.configId), + configName: this.encryptOptionalString(dto.configName), + isEnabled: dto.isEnabled ?? existingBinding.isEnabled + }; + + if (dto.endpoint !== undefined) { + updateData.endpoint = this.encryptOptionalString(dto.endpoint); + } + + if (dto.apiKey !== undefined) { + updateData.encryptedApiKey = this.encryptOptionalString(dto.apiKey); + } + + return tx.aiProviderBinding.update({ + where: { + id: existingBinding.id + }, + data: updateData + }); + }); + + return this.serializeBinding(result); + } + + async testBinding( + userId: string, + dto: UpsertAiProviderBindingDto + ): Promise { + if (dto.channel === AiChannel.PUBLIC_POOL) { + throw new BadRequestException("公共 AI 通道不能由用户自行测试"); + } + + const candidate = await this.buildTestCandidate(userId, dto); + const executor = this.aiProviderRegistryService.getExecutor(candidate.channel); + + try { + const result = await executor.execute(candidate, { + userId, + message: "请只回复“连接成功”,不要添加其他内容。", + sessionId: null + }); + + return { + success: true, + channel: result.channel, + providerName: result.providerName, + model: result.model, + contentPreview: this.limitPreviewText(result.content) + }; + } catch (error) { + if (error instanceof AiRouteFailureError) { + return { + success: false, + channel: error.channel, + providerName: error.providerName, + model: candidate.model, + code: error.code, + message: error.message + }; + } + + if (error instanceof Error) { + return { + success: false, + channel: candidate.channel, + providerName: candidate.providerName, + model: candidate.model, + code: "UNKNOWN_ERROR", + message: error.message + }; + } + + return { + success: false, + channel: candidate.channel, + providerName: candidate.providerName, + model: candidate.model, + code: "UNKNOWN_ERROR", + message: "未知错误" + }; + } + } + + async chat( + userId: string, + dto: AiChatDto, + clientIp: string | null = null + ): Promise { + const rateLimitResult = this.aiRateLimitService.consume(userId, clientIp); + if (!rateLimitResult.allowed) { + throw new HttpException( + { + message: "AI 请求过于频繁,请稍后再试", + code: "AI_RATE_LIMITED", + dimension: rateLimitResult.reason === "USER" ? "user" : "ip", + retryAfterMs: rateLimitResult.retryAfterMs, + limit: rateLimitResult.limit, + windowMs: rateLimitResult.windowMs + }, + HttpStatus.TOO_MANY_REQUESTS + ); + } + + const attempts: AiRouteAttempt[] = []; + const plan = await this.buildRoutePlan(userId, dto.channel ?? null); + const promptMessage = await this.buildPromptMessage(userId, dto.message, dto.localTasks ?? []); + + for (const entry of plan) { + if (entry.kind === "skip") { + attempts.push(entry.attempt); + continue; + } + + const executor = this.aiProviderRegistryService.getExecutor(entry.candidate.channel); + const startedAt = Date.now(); + + try { + const result = await executor.execute(entry.candidate, { + userId, + message: promptMessage, + sessionId: dto.sessionId ?? null + }); + const latencyMs = Date.now() - startedAt; + + attempts.push({ + channel: result.channel, + providerName: result.providerName, + model: result.model, + status: "success", + reasonCode: null, + reasonMessage: null + }); + await this.recordUsageLog({ + userId, + channel: result.channel, + providerName: result.providerName, + model: result.model, + usage: result.usage, + latencyMs, + success: true, + errorCode: null + }); + + return { + channel: result.channel, + providerName: result.providerName, + model: result.model, + content: result.content, + sessionId: result.sessionId, + attempts + }; + } catch (error) { + const latencyMs = Date.now() - startedAt; + const failureAttempt = this.toFailureAttempt(entry.candidate, error); + attempts.push(failureAttempt); + await this.recordUsageLog({ + userId, + channel: failureAttempt.channel, + providerName: failureAttempt.providerName, + model: failureAttempt.model, + usage: null, + latencyMs, + success: false, + errorCode: failureAttempt.reasonCode + }); + this.logger.warn( + `AI 通道降级:channel=${failureAttempt.channel} provider=${failureAttempt.providerName ?? "unknown"} code=${failureAttempt.reasonCode ?? "UNKNOWN"} message=${failureAttempt.reasonMessage ?? "unknown"}` + ); + } + } + + throw new BadGatewayException({ + message: "当前没有可用的 AI 通道,请稍后重试", + attempts + }); + } + + private async buildRoutePlan( + userId: string, + selectedChannel: AiChannel | null + ): Promise { + const plan: AiRoutePlanEntry[] = []; + const targetChannels = selectedChannel + ? [selectedChannel] + : [AiChannel.USER_KEY, AiChannel.ASTRBOT, AiChannel.PUBLIC_POOL]; + + for (const channel of targetChannels) { + if (channel === AiChannel.PUBLIC_POOL) { + const publicPool = await this.findEnabledPublicPool(); + if (publicPool) { + plan.push({ + kind: "candidate", + candidate: this.toPublicPoolCandidate(publicPool) + }); + } else { + plan.push({ + kind: "skip", + attempt: { + channel: AiChannel.PUBLIC_POOL, + providerName: null, + model: null, + status: "skipped", + reasonCode: "PUBLIC_POOL_DISABLED", + reasonMessage: "公共 AI 通道未开启" + } + }); + } + continue; + } + + const binding = await this.findPreferredBinding(userId, channel); + if (binding) { + plan.push({ + kind: "candidate", + candidate: this.toBindingCandidate(binding) + }); + continue; + } + + plan.push({ + kind: "skip", + attempt: { + channel, + providerName: null, + model: null, + status: "skipped", + reasonCode: "CHANNEL_NOT_CONFIGURED", + reasonMessage: + channel === AiChannel.USER_KEY + ? "当前用户未配置可用的自备 Key 通道" + : "当前用户未配置可用的 AstrBot 通道" + } + }); + } + + return plan; + } + + private async findPreferredBinding( + userId: string, + channel: AiChannel + ): Promise { + return this.prismaService.aiProviderBinding.findFirst({ + where: { + userId, + channel, + isEnabled: true + }, + orderBy: { + updatedAt: "desc" + } + }); + } + + private async findEnabledPublicPool(): Promise { + return this.prismaService.aiPublicPoolConfig.findFirst({ + where: { + enabled: true + }, + orderBy: { + updatedAt: "desc" + } + }); + } + + private async buildTestCandidate( + userId: string, + dto: UpsertAiProviderBindingDto + ): Promise { + const existingBinding = await this.prismaService.aiProviderBinding.findFirst({ + where: { + userId, + channel: dto.channel + }, + orderBy: { + updatedAt: "desc" + } + }); + + const mergedDto: UpsertAiProviderBindingDto = { + channel: dto.channel, + providerName: + dto.providerName ?? this.readDecryptedString(existingBinding?.providerName ?? null) ?? "", + model: dto.model ?? this.readDecryptedString(existingBinding?.model ?? null) ?? undefined, + configId: + dto.configId ?? this.readDecryptedString(existingBinding?.configId ?? null) ?? undefined, + configName: + dto.configName ?? + this.readDecryptedString(existingBinding?.configName ?? null) ?? + undefined, + endpoint: + dto.endpoint ?? this.readDecryptedString(existingBinding?.endpoint ?? null) ?? undefined, + apiKey: + dto.apiKey ?? + this.readDecryptedString(existingBinding?.encryptedApiKey ?? null) ?? + undefined, + isEnabled: dto.isEnabled ?? existingBinding?.isEnabled ?? true + }; + + this.validateBindingInput(mergedDto); + + return { + channel: mergedDto.channel, + source: existingBinding ? "binding" : "binding", + sourceId: existingBinding?.id ?? null, + providerName: this.normalizeProviderName(mergedDto.providerName), + model: this.normalizeOptionalString(mergedDto.model), + configId: this.normalizeOptionalString(mergedDto.configId), + configName: this.normalizeOptionalString(mergedDto.configName), + endpoint: this.normalizeOptionalString(mergedDto.endpoint), + apiKey: this.normalizeOptionalString(mergedDto.apiKey) + }; + } + + private toBindingCandidate(binding: AiProviderBinding): AiResolvedRouteCandidate { + return { + channel: binding.channel, + source: "binding", + sourceId: binding.id, + providerName: this.readDecryptedString(binding.providerName) ?? "", + model: this.readDecryptedString(binding.model), + configId: this.readDecryptedString(binding.configId), + configName: this.readDecryptedString(binding.configName), + endpoint: this.readDecryptedString(binding.endpoint), + apiKey: this.readDecryptedString(binding.encryptedApiKey) + }; + } + + private toPublicPoolCandidate(publicPool: AiPublicPoolConfig): AiResolvedRouteCandidate { + return { + channel: AiChannel.PUBLIC_POOL, + source: "public_pool", + sourceId: publicPool.id, + providerName: this.readDecryptedString(publicPool.providerName) ?? "public-pool", + model: this.readDecryptedString(publicPool.model), + configId: null, + configName: null, + endpoint: this.readDecryptedString(publicPool.endpoint), + apiKey: this.readDecryptedString(publicPool.encryptedApiKey) + }; + } + + private serializeBinding(binding: AiProviderBinding): AiBindingSummary { + const decryptedProviderName = this.readDecryptedString(binding.providerName) ?? ""; + const decryptedModel = this.readDecryptedString(binding.model); + const decryptedConfigId = this.readDecryptedString(binding.configId); + const decryptedConfigName = this.readDecryptedString(binding.configName); + const decryptedEndpoint = this.readDecryptedString(binding.endpoint); + const decryptedApiKey = this.readDecryptedString(binding.encryptedApiKey); + + return { + id: binding.id, + channel: binding.channel, + providerName: decryptedProviderName, + model: decryptedModel, + configId: decryptedConfigId, + configName: decryptedConfigName, + endpoint: decryptedEndpoint, + isEnabled: binding.isEnabled, + hasApiKey: Boolean(binding.encryptedApiKey), + maskedApiKey: this.maskSecret(decryptedApiKey), + updatedAt: binding.updatedAt.toISOString() + }; + } + + private pickLatestBindingsByChannel(bindings: AiProviderBinding[]): AiProviderBinding[] { + const bindingMap = new Map(); + + for (const binding of bindings) { + if (!bindingMap.has(binding.channel)) { + bindingMap.set(binding.channel, binding); + } + } + + return [AiChannel.USER_KEY, AiChannel.ASTRBOT] + .map((channel) => bindingMap.get(channel) ?? null) + .filter((binding): binding is AiProviderBinding => binding !== null); + } + + private serializeUsageLog(log: AiUsageLog): AiUsageLogSummary { + return { + id: log.id, + channel: log.channel, + providerName: this.readDecryptedString(log.providerName), + model: this.readDecryptedString(log.model), + promptTokens: log.promptTokens, + completionTokens: log.completionTokens, + totalTokens: log.totalTokens, + latencyMs: log.latencyMs, + success: log.success, + errorCode: log.errorCode, + createdAt: log.createdAt.toISOString() + }; + } + + private async buildPromptMessage( + userId: string, + userMessage: string, + localTasks: NonNullable + ): Promise { + const taskSummary = await this.buildTaskContextSummary(userId, localTasks); + if (!taskSummary) { + return userMessage; + } + + return [ + "你是 TodoList 的 AI 助手,需要结合用户当前待办提供任务统筹建议。", + "以下是系统整理的未完成任务摘要:", + taskSummary, + "请优先根据这些任务的紧急度、截止时间和执行顺序回答,并给出明确可执行的建议。", + `用户当前问题:${userMessage}` + ].join("\n\n"); + } + + private async buildTaskContextSummary( + userId: string, + localTasks: NonNullable + ): Promise { + const tasks = await this.prismaService.task.findMany({ + where: { + userId, + status: { + in: [TaskStatus.TODO, TaskStatus.IN_PROGRESS] + } + }, + select: { + id: true, + title: true, + priority: true, + status: true, + ddl: true, + contentText: true, + updatedAt: true + }, + take: 20 + }); + + const sortedTasks = this.sortContextTasks(this.mergeContextTasks(tasks, localTasks)); + if (sortedTasks.length === 0) { + return null; + } + + const visibleTasks = sortedTasks.slice(0, this.maxContextTasks); + const lines = visibleTasks.map((task, index) => { + const parts = [ + `${index + 1}. ${task.title}`, + `优先级:${this.getPriorityLabel(task.priority)}`, + `状态:${this.getStatusLabel(task.status)}`, + `DDL:${task.ddl ? task.ddl.toISOString() : "未设置"}` + ]; + + const contentSnippet = this.getContentSnippet(task.contentText); + if (contentSnippet) { + parts.push(`内容摘要:${contentSnippet}`); + } + + return parts.join(" | "); + }); + + const omittedCount = sortedTasks.length - visibleTasks.length; + if (omittedCount > 0) { + lines.push(`另有 ${omittedCount} 条任务已省略。`); + } + + return [`共 ${sortedTasks.length} 条未完成任务。`, ...lines].join("\n"); + } + + private mergeContextTasks( + databaseTasks: Array<{ + id: string; + title: string; + priority: TaskPriority; + status: TaskStatus; + ddl: Date | null; + contentText: string | null; + updatedAt: Date; + }>, + localTasks: NonNullable + ): AiContextTaskItem[] { + const taskMap = new Map(); + + for (const task of databaseTasks) { + taskMap.set(task.id, { + id: task.id, + title: this.readDecryptedString(task.title) ?? "未命名任务", + priority: task.priority, + status: task.status, + ddl: task.ddl, + contentText: this.readDecryptedString(task.contentText), + updatedAt: task.updatedAt + }); + } + + for (const task of localTasks) { + if (task.status !== TaskStatus.TODO && task.status !== TaskStatus.IN_PROGRESS) { + continue; + } + + const currentTask = taskMap.get(task.id); + const nextTask: AiContextTaskItem = { + id: task.id, + title: task.title.trim().length > 0 ? task.title.trim() : "未命名任务", + priority: task.priority, + status: task.status, + ddl: typeof task.ddlAt === "number" ? new Date(task.ddlAt) : null, + contentText: + typeof task.contentText === "string" && task.contentText.trim().length > 0 + ? task.contentText + : null, + updatedAt: new Date(task.updatedAt) + }; + + if (!currentTask || nextTask.updatedAt.getTime() >= currentTask.updatedAt.getTime()) { + taskMap.set(task.id, nextTask); + } + } + + return [...taskMap.values()].filter( + (task) => task.status === TaskStatus.TODO || task.status === TaskStatus.IN_PROGRESS + ); + } + + private sortContextTasks(tasks: AiContextTaskItem[]): AiContextTaskItem[] { + return [...tasks].sort((left, right) => { + const priorityDiff = + this.getPriorityWeight(right.priority) - this.getPriorityWeight(left.priority); + if (priorityDiff !== 0) { + return priorityDiff; + } + + const leftDdl = left.ddl?.getTime() ?? Number.POSITIVE_INFINITY; + const rightDdl = right.ddl?.getTime() ?? Number.POSITIVE_INFINITY; + if (leftDdl !== rightDdl) { + return leftDdl - rightDdl; + } + + return right.updatedAt.getTime() - left.updatedAt.getTime(); + }); + } + + private toFailureAttempt(candidate: AiResolvedRouteCandidate, error: unknown): AiRouteAttempt { + if (error instanceof AiRouteFailureError) { + return { + channel: error.channel, + providerName: error.providerName, + model: candidate.model, + status: "failed", + reasonCode: error.code, + reasonMessage: error.message + }; + } + + if (error instanceof Error) { + return { + channel: candidate.channel, + providerName: candidate.providerName, + model: candidate.model, + status: "failed", + reasonCode: "UNKNOWN_ERROR", + reasonMessage: error.message + }; + } + + return { + channel: candidate.channel, + providerName: candidate.providerName, + model: candidate.model, + status: "failed", + reasonCode: "UNKNOWN_ERROR", + reasonMessage: "未知错误" + }; + } + + private normalizeOptionalString(value: string | undefined): string | null { + if (value === undefined) { + return null; + } + + const normalizedValue = value.trim(); + return normalizedValue.length > 0 ? normalizedValue : null; + } + + private normalizeProviderName(value: string | undefined): string { + return this.normalizeOptionalString(value) ?? ""; + } + + private encryptOptionalString(value: string | undefined): string | null | undefined { + const normalizedValue = this.normalizeOptionalString(value); + return this.dataEncryptionService.encryptString(normalizedValue); + } + + private encryptRequiredString(value: string): string { + const encryptedValue = this.dataEncryptionService.encryptString(value); + if (!encryptedValue) { + throw new BadRequestException("敏感配置加密失败"); + } + + return encryptedValue; + } + + private readDecryptedString(value: string | null): string | null { + const decryptedValue = this.dataEncryptionService.decryptString(value); + return typeof decryptedValue === "string" ? decryptedValue : null; + } + + private validateBindingInput(dto: UpsertAiProviderBindingDto): void { + const providerName = this.normalizeOptionalString(dto.providerName); + const configId = this.normalizeOptionalString(dto.configId); + const configName = this.normalizeOptionalString(dto.configName); + + if (dto.channel === AiChannel.ASTRBOT) { + if (!providerName && !configId && !configName) { + throw new BadRequestException( + "AstrBot 通道至少需要 providerName、configId、configName 三者之一" + ); + } + return; + } + + if (!providerName) { + throw new BadRequestException("当前通道必须提供 providerName"); + } + } + + private maskSecret(secret: string | null): string | null { + if (!secret) { + return null; + } + + if (secret.length <= 6) { + return "*".repeat(secret.length); + } + + return `${secret.slice(0, 4)}***${secret.slice(-2)}`; + } + + private limitPreviewText(content: string): string { + const normalizedContent = content.replace(/\s+/g, " ").trim(); + if (normalizedContent.length <= 60) { + return normalizedContent; + } + + return `${normalizedContent.slice(0, 60)}...`; + } + + private getPriorityWeight(priority: TaskPriority): number { + switch (priority) { + case TaskPriority.URGENT: + return 4; + case TaskPriority.HIGH: + return 3; + case TaskPriority.MEDIUM: + return 2; + case TaskPriority.LOW: + return 1; + default: + return 0; + } + } + + private getPriorityLabel(priority: TaskPriority): string { + switch (priority) { + case TaskPriority.URGENT: + return "紧急"; + case TaskPriority.HIGH: + return "高"; + case TaskPriority.MEDIUM: + return "中"; + case TaskPriority.LOW: + return "低"; + default: + return String(priority); + } + } + + private getStatusLabel(status: TaskStatus): string { + switch (status) { + case TaskStatus.TODO: + return "待开始"; + case TaskStatus.IN_PROGRESS: + return "进行中"; + case TaskStatus.DONE: + return "已完成"; + case TaskStatus.ARCHIVED: + return "已归档"; + default: + return String(status); + } + } + + private getContentSnippet(contentText: string | null): string | null { + if (!contentText) { + return null; + } + + const normalizedContent = contentText.replace(/\s+/g, " ").trim(); + if (normalizedContent.length === 0) { + return null; + } + + if (normalizedContent.length <= this.maxContextContentLength) { + return normalizedContent; + } + + return `${normalizedContent.slice(0, this.maxContextContentLength)}...`; + } + + private async recordUsageLog(input: { + userId: string; + channel: AiChannel; + providerName: string | null; + model: string | null; + usage: AiUsageMetrics | null; + latencyMs: number; + success: boolean; + errorCode: string | null; + }): Promise { + try { + await this.prismaService.aiUsageLog.create({ + data: { + userId: input.userId, + channel: input.channel, + providerName: + input.providerName === null + ? null + : this.dataEncryptionService.encryptString(input.providerName), + model: + input.model === null ? null : this.dataEncryptionService.encryptString(input.model), + promptTokens: input.usage?.promptTokens ?? 0, + completionTokens: input.usage?.completionTokens ?? 0, + totalTokens: input.usage?.totalTokens ?? 0, + latencyMs: input.latencyMs, + success: input.success, + errorCode: input.errorCode + } + }); + } catch (error) { + const message = error instanceof Error ? error.message : "未知错误"; + this.logger.warn(`写入 AI 使用日志失败:${message}`); + } + } +} diff --git a/apps/api/src/ai/ai.types.ts b/apps/api/src/ai/ai.types.ts new file mode 100644 index 0000000..ccb8088 --- /dev/null +++ b/apps/api/src/ai/ai.types.ts @@ -0,0 +1,61 @@ +import { AiChannel } from "../../generated/prisma/client"; + +export type AiResolvedRouteCandidate = { + channel: AiChannel; + source: "binding" | "public_pool"; + sourceId: string | null; + providerName: string; + model: string | null; + configId: string | null; + configName: string | null; + endpoint: string | null; + apiKey: string | null; +}; + +export type AiChatInput = { + userId: string; + message: string; + sessionId: string | null; +}; + +export type AiChatResult = { + channel: AiChannel; + providerName: string; + model: string | null; + content: string; + sessionId: string | null; + usage: AiUsageMetrics | null; + raw: unknown; +}; + +export type AiUsageMetrics = { + promptTokens: number; + completionTokens: number; + totalTokens: number; +}; + +export type AiRouteAttempt = { + channel: AiChannel; + providerName: string | null; + model: string | null; + status: "skipped" | "failed" | "success"; + reasonCode: string | null; + reasonMessage: string | null; +}; + +export class AiRouteFailureError extends Error { + constructor( + public readonly channel: AiChannel, + public readonly providerName: string, + public readonly code: string, + message: string + ) { + super(message); + this.name = "AiRouteFailureError"; + Object.setPrototypeOf(this, new.target.prototype); + } +} + +export interface AiChannelExecutor { + execute(candidate: AiResolvedRouteCandidate, input: AiChatInput): Promise; +} diff --git a/apps/api/src/ai/dto/ai-chat.dto.ts b/apps/api/src/ai/dto/ai-chat.dto.ts new file mode 100644 index 0000000..e2613ae --- /dev/null +++ b/apps/api/src/ai/dto/ai-chat.dto.ts @@ -0,0 +1,60 @@ +import { Type } from "class-transformer"; +import { + IsArray, + IsEnum, + IsInt, + IsOptional, + IsString, + MinLength, + ValidateNested +} from "class-validator"; +import { AiChannel } from "../../../generated/prisma/client"; +import { TaskPriority, TaskStatus } from "../../../generated/prisma/client"; + +export class LocalTaskContextItemDto { + @IsString() + @MinLength(1) + id!: string; + + @IsString() + @MinLength(1) + title!: string; + + @IsEnum(TaskPriority) + priority!: TaskPriority; + + @IsEnum(TaskStatus) + status!: TaskStatus; + + @IsOptional() + @IsInt() + ddlAt?: number | null; + + @IsOptional() + @IsString() + contentText?: string | null; + + @IsInt() + updatedAt!: number; +} + +export class AiChatDto { + @IsString() + @MinLength(1) + message!: string; + + @IsOptional() + @IsString() + @MinLength(1) + sessionId?: string; + + @IsOptional() + @IsEnum(AiChannel) + channel?: AiChannel; + + @IsOptional() + @IsArray() + @ValidateNested({ each: true }) + @Type(() => LocalTaskContextItemDto) + localTasks?: LocalTaskContextItemDto[]; +} diff --git a/apps/api/src/ai/dto/list-ai-usage-logs-query.dto.ts b/apps/api/src/ai/dto/list-ai-usage-logs-query.dto.ts new file mode 100644 index 0000000..49aa2e0 --- /dev/null +++ b/apps/api/src/ai/dto/list-ai-usage-logs-query.dto.ts @@ -0,0 +1,48 @@ +import { Transform, Type } from "class-transformer"; +import { IsBoolean, IsEnum, IsInt, IsOptional, Max, Min } from "class-validator"; +import { AiChannel } from "../../../generated/prisma/client"; + +function normalizeBoolean(value: unknown): boolean | undefined { + if (typeof value === "boolean") { + return value; + } + + if (typeof value !== "string") { + return undefined; + } + + const normalized = value.trim().toLowerCase(); + if (normalized === "true" || normalized === "1") { + return true; + } + + if (normalized === "false" || normalized === "0") { + return false; + } + + return undefined; +} + +export class ListAiUsageLogsQueryDto { + @Type(() => Number) + @IsOptional() + @IsInt() + @Min(1) + page?: number; + + @Type(() => Number) + @IsOptional() + @IsInt() + @Min(1) + @Max(100) + pageSize?: number; + + @IsOptional() + @IsEnum(AiChannel) + channel?: AiChannel; + + @Transform(({ value }) => normalizeBoolean(value)) + @IsOptional() + @IsBoolean() + success?: boolean; +} diff --git a/apps/api/src/ai/dto/upsert-ai-provider-binding.dto.ts b/apps/api/src/ai/dto/upsert-ai-provider-binding.dto.ts new file mode 100644 index 0000000..b821bcc --- /dev/null +++ b/apps/api/src/ai/dto/upsert-ai-provider-binding.dto.ts @@ -0,0 +1,47 @@ +import { AiChannel } from "../../../generated/prisma/client"; +import { IsBoolean, IsEnum, IsOptional, IsString, IsUrl, MinLength } from "class-validator"; + +export class UpsertAiProviderBindingDto { + @IsEnum(AiChannel) + channel!: AiChannel; + + @IsOptional() + @IsString() + @MinLength(1) + providerName?: string; + + @IsOptional() + @IsString() + @MinLength(1) + model?: string; + + @IsOptional() + @IsString() + @MinLength(1) + configId?: string; + + @IsOptional() + @IsString() + @MinLength(1) + configName?: string; + + @IsOptional() + @IsUrl( + { + require_tld: false + }, + { + message: "endpoint \u5fc5\u987b\u662f\u5408\u6cd5\u7684 URL" + } + ) + endpoint?: string; + + @IsOptional() + @IsString() + @MinLength(1) + apiKey?: string; + + @IsOptional() + @IsBoolean() + isEnabled?: boolean; +} diff --git a/apps/api/src/ai/providers/astrbot.provider.ts b/apps/api/src/ai/providers/astrbot.provider.ts new file mode 100644 index 0000000..3d28cbb --- /dev/null +++ b/apps/api/src/ai/providers/astrbot.provider.ts @@ -0,0 +1,284 @@ +import { Injectable } from "@nestjs/common"; +import { + AiChannelExecutor, + AiChatInput, + AiChatResult, + AiResolvedRouteCandidate, + AiRouteFailureError +} from "../ai.types"; + +@Injectable() +export class AstrbotProvider implements AiChannelExecutor { + async execute(candidate: AiResolvedRouteCandidate, input: AiChatInput): Promise { + const routeLabel = + candidate.providerName || candidate.configName || candidate.configId || "astrbot"; + + if (!candidate.endpoint) { + throw new AiRouteFailureError( + candidate.channel, + routeLabel, + "MISSING_ENDPOINT", + "缺少 AstrBot 服务地址配置" + ); + } + + if (!candidate.apiKey) { + throw new AiRouteFailureError( + candidate.channel, + routeLabel, + "MISSING_API_KEY", + "缺少 AstrBot API Key 配置" + ); + } + + const requestUrl = this.buildRequestUrl(candidate.endpoint); + + let response: Response; + try { + response = await fetch(requestUrl, { + method: "POST", + headers: { + "Content-Type": "application/json", + Authorization: `Bearer ${candidate.apiKey}` + }, + body: JSON.stringify({ + username: input.userId, + session_id: input.sessionId ?? undefined, + message: input.message, + enable_streaming: false, + selected_model: candidate.model ?? undefined + }), + signal: AbortSignal.timeout(30000) + }); + } catch (error) { + throw new AiRouteFailureError( + candidate.channel, + routeLabel, + "UPSTREAM_UNREACHABLE", + this.toErrorMessage(error, "AstrBot 服务请求失败") + ); + } + + if (!response.ok) { + const rawText = await response.text(); + throw new AiRouteFailureError( + candidate.channel, + routeLabel, + `UPSTREAM_HTTP_${response.status}`, + this.extractHttpErrorMessage(rawText, response.status) + ); + } + + const events = await this.readSseEvents(response); + let content = ""; + let sessionId = input.sessionId; + + for (const event of events) { + const type = this.readString(event["type"]); + if (type === "session_id") { + sessionId = this.readString(event["session_id"]) ?? sessionId; + continue; + } + + if (type === "error") { + throw new AiRouteFailureError( + candidate.channel, + routeLabel, + this.readString(event["code"]) ?? "ASTRBOT_ERROR", + this.readString(event["data"]) ?? "AstrBot 返回错误" + ); + } + + if (type !== "plain") { + continue; + } + + const chainType = this.readString(event["chain_type"]); + if ( + chainType === "reasoning" || + chainType === "tool_call" || + chainType === "tool_call_result" + ) { + continue; + } + + const data = this.readString(event["data"]); + if (!data) { + continue; + } + + if (event["streaming"] === true) { + content += data; + continue; + } + + content = data; + } + + if (!content.trim()) { + throw new AiRouteFailureError( + candidate.channel, + routeLabel, + "EMPTY_RESPONSE", + "AstrBot 没有返回有效内容" + ); + } + + return { + channel: candidate.channel, + providerName: routeLabel, + model: candidate.model, + content, + sessionId, + usage: this.extractUsage(events), + raw: events + }; + } + + private buildRequestUrl(endpoint: string): string { + const normalizedEndpoint = endpoint.replace(/\/+$/, ""); + if (normalizedEndpoint.endsWith("/api/v1/chat")) { + return normalizedEndpoint; + } + if (normalizedEndpoint.endsWith("/api/v1")) { + return `${normalizedEndpoint}/chat`; + } + if (normalizedEndpoint.endsWith("/api")) { + return `${normalizedEndpoint}/v1/chat`; + } + return `${normalizedEndpoint}/api/v1/chat`; + } + + private parseSseEvents(rawText: string): Array> { + return rawText + .split(/\r?\n\r?\n/) + .map((block) => + block + .split(/\r?\n/) + .filter((line) => line.startsWith("data:")) + .map((line) => line.slice(5).trim()) + .join("\n") + ) + .filter((payload) => payload.length > 0) + .map((payload) => { + try { + return JSON.parse(payload) as Record; + } catch { + return null; + } + }) + .filter((item): item is Record => item !== null); + } + + private async readSseEvents(response: Response): Promise>> { + if (!response.body) { + return this.parseSseEvents(await response.text()); + } + + const reader = response.body.getReader(); + const decoder = new TextDecoder(); + const events: Array> = []; + let buffer = ""; + let reachedEndEvent = false; + + try { + while (!reachedEndEvent) { + const { done, value } = await reader.read(); + if (done) { + break; + } + + buffer += decoder.decode(value, { stream: true }); + const segments = buffer.split(/\r?\n\r?\n/); + buffer = segments.pop() ?? ""; + + for (const segment of segments) { + const parsedEvents = this.parseSseEvents(segment); + for (const event of parsedEvents) { + events.push(event); + if (this.readString(event["type"]) === "end") { + reachedEndEvent = true; + break; + } + } + + if (reachedEndEvent) { + break; + } + } + } + + const tail = `${buffer}${decoder.decode()}`; + if (tail.trim().length > 0) { + events.push(...this.parseSseEvents(tail)); + } + } finally { + await reader.cancel(); + } + + return events; + } + + private extractHttpErrorMessage(rawText: string, statusCode: number): string { + try { + const payload = JSON.parse(rawText) as Record; + if (typeof payload["message"] === "string") { + return payload["message"]; + } + if (typeof payload["data"] === "string") { + return payload["data"]; + } + } catch { + return `AstrBot 服务调用失败,状态码 ${statusCode}`; + } + + return `AstrBot 服务调用失败,状态码 ${statusCode}`; + } + + private readString(value: unknown): string | null { + return typeof value === "string" ? value : null; + } + + private toErrorMessage(error: unknown, fallback: string): string { + if (error instanceof Error && error.message) { + return error.message; + } + + return fallback; + } + + private extractUsage(events: Array>): AiChatResult["usage"] { + for (const event of events) { + if (this.readString(event["type"]) !== "agent_stats") { + continue; + } + + const data = this.asRecord(event["data"]); + const tokenUsage = this.asRecord(data?.["token_usage"]); + if (!tokenUsage) { + continue; + } + + const promptTokens = + (this.readNumber(tokenUsage["input_other"]) ?? 0) + + (this.readNumber(tokenUsage["input_cached"]) ?? 0); + const completionTokens = this.readNumber(tokenUsage["output"]) ?? 0; + + return { + promptTokens, + completionTokens, + totalTokens: promptTokens + completionTokens + }; + } + + return null; + } + + private asRecord(value: unknown): Record | null { + return typeof value === "object" && value !== null ? (value as Record) : null; + } + + private readNumber(value: unknown): number | null { + return typeof value === "number" && Number.isFinite(value) ? value : null; + } +} diff --git a/apps/api/src/ai/providers/openai-compatible.provider.ts b/apps/api/src/ai/providers/openai-compatible.provider.ts new file mode 100644 index 0000000..1ba4eff --- /dev/null +++ b/apps/api/src/ai/providers/openai-compatible.provider.ts @@ -0,0 +1,300 @@ +import { Injectable } from "@nestjs/common"; +import { + AiChannelExecutor, + AiChatInput, + AiChatResult, + AiResolvedRouteCandidate, + AiRouteFailureError +} from "../ai.types"; + +@Injectable() +export class OpenAiCompatibleProvider implements AiChannelExecutor { + async execute(candidate: AiResolvedRouteCandidate, input: AiChatInput): Promise { + if (!candidate.endpoint) { + throw new AiRouteFailureError( + candidate.channel, + candidate.providerName, + "MISSING_ENDPOINT", + "缺少 AI 服务地址配置" + ); + } + + if (!candidate.apiKey) { + throw new AiRouteFailureError( + candidate.channel, + candidate.providerName, + "MISSING_API_KEY", + "缺少 AI 服务密钥配置" + ); + } + + if (!candidate.model) { + throw new AiRouteFailureError( + candidate.channel, + candidate.providerName, + "MISSING_MODEL", + "缺少 AI 模型配置" + ); + } + + const requestUrl = this.buildRequestUrl(candidate.endpoint); + + let response: Response; + try { + response = await fetch(requestUrl, { + method: "POST", + headers: { + "Content-Type": "application/json", + Authorization: `Bearer ${candidate.apiKey}` + }, + body: JSON.stringify({ + model: candidate.model, + messages: [ + { + role: "user", + content: input.message + } + ], + stream: false + }), + signal: AbortSignal.timeout(30000) + }); + } catch (error) { + throw new AiRouteFailureError( + candidate.channel, + candidate.providerName, + "UPSTREAM_UNREACHABLE", + this.toErrorMessage(error, "AI 服务请求失败") + ); + } + + let payload: unknown; + try { + payload = await response.json(); + } catch (error) { + throw new AiRouteFailureError( + candidate.channel, + candidate.providerName, + "INVALID_RESPONSE", + this.toErrorMessage(error, "AI 服务返回了无法解析的数据") + ); + } + + if (!response.ok) { + throw new AiRouteFailureError( + candidate.channel, + candidate.providerName, + `UPSTREAM_HTTP_${response.status}`, + this.extractErrorMessage(payload, `AI 服务调用失败,状态码 ${response.status}`) + ); + } + + const content = this.extractAssistantText(payload); + if (!content.trim()) { + throw new AiRouteFailureError( + candidate.channel, + candidate.providerName, + "EMPTY_RESPONSE", + "AI 服务没有返回有效内容" + ); + } + + return { + channel: candidate.channel, + providerName: candidate.providerName, + model: this.extractModel(payload) ?? candidate.model, + content, + sessionId: input.sessionId, + usage: this.extractUsage(payload), + raw: payload + }; + } + + private buildRequestUrl(endpoint: string): string { + const normalizedEndpoint = endpoint.replace(/\/+$/, ""); + if (normalizedEndpoint.endsWith("/chat/completions")) { + return normalizedEndpoint; + } + if (normalizedEndpoint.endsWith("/v1")) { + return `${normalizedEndpoint}/chat/completions`; + } + return `${normalizedEndpoint}/v1/chat/completions`; + } + + private extractAssistantText(payload: unknown): string { + const chatCompletionText = this.extractChatCompletionText(payload); + if (chatCompletionText) { + return chatCompletionText; + } + + const responsesText = this.extractResponsesApiText(payload); + if (responsesText) { + return responsesText; + } + + return ""; + } + + private extractChatCompletionText(payload: unknown): string { + if (!this.isRecord(payload)) { + return ""; + } + + const choices = payload["choices"]; + if (!Array.isArray(choices) || choices.length === 0) { + return ""; + } + + const firstChoice = choices[0]; + if (!this.isRecord(firstChoice)) { + return ""; + } + + const message = firstChoice["message"]; + if (this.isRecord(message)) { + const messageContent = this.extractMessageContent(message["content"]); + if (messageContent) { + return messageContent; + } + } + + if (typeof firstChoice["text"] === "string") { + return firstChoice["text"]; + } + + return ""; + } + + private extractResponsesApiText(payload: unknown): string { + if (!this.isRecord(payload)) { + return ""; + } + + if (typeof payload["output_text"] === "string") { + return payload["output_text"]; + } + + const output = payload["output"]; + if (!Array.isArray(output)) { + return ""; + } + + return output + .map((item) => { + if (!this.isRecord(item)) { + return ""; + } + + if (typeof item["text"] === "string") { + return item["text"]; + } + + return this.extractMessageContent(item["content"]); + }) + .filter((item) => item.length > 0) + .join("\n") + .trim(); + } + + private extractMessageContent(content: unknown): string { + if (typeof content === "string") { + return content; + } + + if (!Array.isArray(content)) { + return ""; + } + + return content + .map((item) => this.extractContentPartText(item)) + .filter((item) => item.length > 0) + .join("\n") + .trim(); + } + + private extractContentPartText(item: unknown): string { + if (!this.isRecord(item)) { + return ""; + } + + if (typeof item["text"] === "string") { + return item["text"]; + } + + if (this.isRecord(item["text"]) && typeof item["text"]["value"] === "string") { + return item["text"]["value"]; + } + + if (typeof item["content"] === "string") { + return item["content"]; + } + + if (this.isRecord(item["content"]) && typeof item["content"]["text"] === "string") { + return item["content"]["text"]; + } + + return ""; + } + + private extractModel(payload: unknown): string | null { + if (!this.isRecord(payload) || typeof payload["model"] !== "string") { + return null; + } + + return payload["model"]; + } + + private extractUsage(payload: unknown): AiChatResult["usage"] { + if (!this.isRecord(payload)) { + return null; + } + + const usage = payload["usage"]; + if (!this.isRecord(usage)) { + return null; + } + + const promptTokens = this.readNumber(usage["prompt_tokens"]); + const completionTokens = this.readNumber(usage["completion_tokens"]); + const totalTokens = this.readNumber(usage["total_tokens"]); + + if (promptTokens === null && completionTokens === null && totalTokens === null) { + return null; + } + + return { + promptTokens: promptTokens ?? 0, + completionTokens: completionTokens ?? 0, + totalTokens: totalTokens ?? (promptTokens ?? 0) + (completionTokens ?? 0) + }; + } + + private extractErrorMessage(payload: unknown, fallback: string): string { + if (!this.isRecord(payload)) { + return fallback; + } + + const error = payload["error"]; + if (!this.isRecord(error) || typeof error["message"] !== "string") { + return fallback; + } + + return error["message"]; + } + + private isRecord(value: unknown): value is Record { + return typeof value === "object" && value !== null; + } + + private toErrorMessage(error: unknown, fallback: string): string { + if (error instanceof Error && error.message) { + return error.message; + } + + return fallback; + } + + private readNumber(value: unknown): number | null { + return typeof value === "number" && Number.isFinite(value) ? value : null; + } +} diff --git a/apps/api/src/app.module.ts b/apps/api/src/app.module.ts index 8c2a78f..aae9297 100644 --- a/apps/api/src/app.module.ts +++ b/apps/api/src/app.module.ts @@ -1,8 +1,11 @@ import { Module } from "@nestjs/common"; import { ConfigModule } from "@nestjs/config"; +import { resolve } from "node:path"; +import { AiModule } from "./ai/ai.module"; import { AttachmentModule } from "./attachment/attachment.module"; import { AuthModule } from "./auth/auth.module"; import { PrismaModule } from "./prisma/prisma.module"; +import { SecurityModule } from "./security/security.module"; import { SyncModule } from "./sync/sync.module"; import { TaskModule } from "./task/task.module"; @@ -10,13 +13,15 @@ import { TaskModule } from "./task/task.module"; imports: [ ConfigModule.forRoot({ isGlobal: true, - envFilePath: ".env" + envFilePath: [resolve(__dirname, "../.env"), ".env"] }), PrismaModule, + SecurityModule, AuthModule, TaskModule, AttachmentModule, - SyncModule + SyncModule, + AiModule ] }) export class AppModule {} diff --git a/apps/api/src/attachment/attachment.service.ts b/apps/api/src/attachment/attachment.service.ts index 031cd03..cb15f7e 100644 --- a/apps/api/src/attachment/attachment.service.ts +++ b/apps/api/src/attachment/attachment.service.ts @@ -1,10 +1,16 @@ import { randomUUID } from "node:crypto"; -import { Injectable, NotFoundException, PayloadTooLargeException } from "@nestjs/common"; +import { + Injectable, + InternalServerErrorException, + NotFoundException, + PayloadTooLargeException +} from "@nestjs/common"; import { ConfigService } from "@nestjs/config"; import { PutObjectCommand, S3Client } from "@aws-sdk/client-s3"; import { getSignedUrl } from "@aws-sdk/s3-request-presigner"; import { AttachmentType } from "../../generated/prisma/client"; import { PrismaService } from "../prisma/prisma.service"; +import { DataEncryptionService } from "../security/data-encryption.service"; import { CompleteAttachmentDto } from "./dto/complete-attachment.dto"; import { PresignAttachmentDto } from "./dto/presign-attachment.dto"; @@ -25,9 +31,7 @@ export type PresignAttachmentResponse = { usedBytes: string; remainingBytes: string; }; - headers: { - "Content-Type": string; - }; + headers: Record; }; export type AttachmentResponse = { @@ -52,7 +56,8 @@ export class AttachmentService { constructor( private readonly configService: ConfigService, - private readonly prismaService: PrismaService + private readonly prismaService: PrismaService, + private readonly dataEncryptionService: DataEncryptionService ) {} async presignAttachment( @@ -67,15 +72,17 @@ export class AttachmentService { } const bucket = this.getDefaultBucket(); - const objectKey = this.generateObjectKey(userId, body.fileName); + const objectKey = this.generateObjectKey(body.fileName); const objectUrl = this.resolveObjectUrl(bucket, objectKey); const expiresInSeconds = this.getPresignExpiresInSeconds(); + const serverSideEncryption = this.getServerSideEncryptionMode(); const command = new PutObjectCommand({ Bucket: bucket, Key: objectKey, ContentType: body.mimeType, - ContentLength: body.fileSize + ContentLength: body.fileSize, + ServerSideEncryption: serverSideEncryption }); const uploadUrl = await getSignedUrl(this.getS3Client(), command, { @@ -94,9 +101,7 @@ export class AttachmentService { usedBytes: quotaInfo.usedBytes.toString(), remainingBytes: (quotaInfo.totalBytes - quotaInfo.usedBytes).toString() }, - headers: { - "Content-Type": body.mimeType - } + headers: this.buildUploadHeaders(body.mimeType, serverSideEncryption) }; } @@ -139,14 +144,14 @@ export class AttachmentService { userId, taskId: body.taskId ?? null, type: body.type ?? this.resolveAttachmentType(body.mimeType), - url: objectUrl, + url: this.encryptRequiredString(objectUrl), mimeType: body.mimeType, - fileName: body.fileName, + fileName: this.encryptNullableString(body.fileName), fileSize: body.fileSize, width: body.width ?? null, height: body.height ?? null, durationMs: body.durationMs ?? null, - checksum: body.checksum ?? null + checksum: this.encryptNullableString(body.checksum) } }); }); @@ -155,14 +160,14 @@ export class AttachmentService { id: attachment.id, taskId: attachment.taskId, type: attachment.type, - url: attachment.url, + url: this.readDecryptedString(attachment.url) ?? objectUrl, mimeType: attachment.mimeType, - fileName: attachment.fileName, + fileName: this.readDecryptedString(attachment.fileName), fileSize: attachment.fileSize, width: attachment.width, height: attachment.height, durationMs: attachment.durationMs, - checksum: attachment.checksum, + checksum: this.readDecryptedString(attachment.checksum), createdAt: attachment.createdAt.toISOString(), updatedAt: attachment.updatedAt.toISOString() }; @@ -204,10 +209,9 @@ export class AttachmentService { return Math.min(configValue, 604800); } - private generateObjectKey(userId: string, fileName: string): string { - const safeFileName = fileName.replace(/[^\w.-]+/g, "_"); + private generateObjectKey(fileName: string): string { const datePrefix = new Date().toISOString().slice(0, 10); - return `${userId}/${datePrefix}/${randomUUID()}-${safeFileName}`; + return `attachments/${datePrefix}/${randomUUID()}${this.extractFileExtension(fileName)}`; } private resolveObjectUrl(bucket: string, objectKey: string): string { @@ -232,6 +236,37 @@ export class AttachmentService { return AttachmentType.FILE; } + private buildUploadHeaders( + mimeType: string, + serverSideEncryption: "AES256" | undefined + ): Record { + const headers: Record = { + "Content-Type": mimeType + }; + + if (serverSideEncryption) { + headers["x-amz-server-side-encryption"] = serverSideEncryption; + } + + return headers; + } + + private getServerSideEncryptionMode(): "AES256" | undefined { + const configValue = + this.configService.get("S3_SERVER_SIDE_ENCRYPTION")?.trim().toUpperCase() ?? "AES256"; + + if (configValue === "NONE" || configValue === "DISABLED") { + return undefined; + } + + return "AES256"; + } + + private extractFileExtension(fileName: string): string { + const match = /\.[a-zA-Z0-9]{1,16}$/.exec(fileName); + return match?.[0]?.toLowerCase() ?? ""; + } + private async ensureTaskOwnership(userId: string, taskId: string): Promise { const task = await this.prismaService.task.findFirst({ where: { @@ -279,4 +314,22 @@ export class AttachmentService { throw new PayloadTooLargeException("存储配额不足"); } } + + private encryptRequiredString(value: string): string { + const encryptedValue = this.dataEncryptionService.encryptString(value); + if (!encryptedValue) { + throw new InternalServerErrorException("附件元数据加密失败"); + } + + return encryptedValue; + } + + private encryptNullableString(value: string | null | undefined): string | null | undefined { + return this.dataEncryptionService.encryptString(value); + } + + private readDecryptedString(value: string | null): string | null { + const decryptedValue = this.dataEncryptionService.decryptString(value); + return typeof decryptedValue === "string" ? decryptedValue : null; + } } diff --git a/apps/api/src/auth/auth.service.ts b/apps/api/src/auth/auth.service.ts index 52bac56..8554a5c 100644 --- a/apps/api/src/auth/auth.service.ts +++ b/apps/api/src/auth/auth.service.ts @@ -5,6 +5,7 @@ import { randomUUID } from "node:crypto"; import { authenticator } from "@otplib/preset-default"; import { AuthMailService } from "./auth-mail.service"; import { PrismaService } from "../prisma/prisma.service"; +import { DataEncryptionService } from "../security/data-encryption.service"; type EmailCodeEntry = { code: string; @@ -33,7 +34,8 @@ export class AuthService { private readonly configService: ConfigService, private readonly jwtService: JwtService, private readonly authMailService: AuthMailService, - private readonly prismaService: PrismaService + private readonly prismaService: PrismaService, + private readonly dataEncryptionService: DataEncryptionService ) {} async sendEmailCode(email: string): Promise<{ success: boolean; expiresInSeconds: number }> { @@ -118,7 +120,10 @@ export class AuthService { } }); - return this.issueTokens(entry.user); + return this.issueTokens({ + id: entry.user.id, + email: this.readRequiredEmail(entry.user.email) + }); } async revokeRefreshToken(refreshToken: string): Promise<{ success: boolean }> { @@ -205,19 +210,27 @@ export class AuthService { } private async getOrCreateUser(email: string): Promise { - return this.prismaService.user.upsert({ + const normalizedEmail = email.toLowerCase(); + const emailHash = this.dataEncryptionService.createLookupHash("user.email", normalizedEmail); + const user = await this.prismaService.user.upsert({ where: { - email + emailHash }, update: {}, create: { - email + email: this.encryptRequiredString(normalizedEmail), + emailHash }, select: { id: true, email: true } }); + + return { + id: user.id, + email: this.readRequiredEmail(user.email) + }; } private generateCode(): string { @@ -254,4 +267,22 @@ export class AuthService { user }; } + + private encryptRequiredString(value: string): string { + const encryptedValue = this.dataEncryptionService.encryptString(value); + if (!encryptedValue) { + throw new UnauthorizedException("用户敏感字段加密失败"); + } + + return encryptedValue; + } + + private readRequiredEmail(value: string): string { + const decryptedValue = this.dataEncryptionService.decryptString(value); + if (typeof decryptedValue !== "string" || decryptedValue.length === 0) { + throw new UnauthorizedException("用户邮箱解密失败"); + } + + return decryptedValue; + } } diff --git a/apps/api/src/security/data-encryption.service.ts b/apps/api/src/security/data-encryption.service.ts new file mode 100644 index 0000000..ece7ceb --- /dev/null +++ b/apps/api/src/security/data-encryption.service.ts @@ -0,0 +1,155 @@ +import { Injectable, InternalServerErrorException } from "@nestjs/common"; +import { ConfigService } from "@nestjs/config"; +import { Prisma } from "../../generated/prisma/client"; +import { createCipheriv, createDecipheriv, createHash, createHmac, randomBytes } from "node:crypto"; + +const ENCRYPTION_PREFIX = "encv1"; +const ENCRYPTION_ALGORITHM = "aes-256-gcm"; +const ENCRYPTION_IV_LENGTH = 12; + +@Injectable() +export class DataEncryptionService { + constructor(private readonly configService: ConfigService) {} + + isConfigured(): boolean { + return Boolean(this.configService.get("DATA_ENCRYPTION_SECRET")); + } + + isEncryptedString(value: string): boolean { + return value.startsWith(`${ENCRYPTION_PREFIX}:`); + } + + encryptString(value: string | null | undefined): string | null | undefined { + if (value === undefined) { + return undefined; + } + + if (value === null) { + return null; + } + + const key = this.resolveKey(); + const iv = randomBytes(ENCRYPTION_IV_LENGTH); + const cipher = createCipheriv(ENCRYPTION_ALGORITHM, key, iv); + const encrypted = Buffer.concat([cipher.update(value, "utf8"), cipher.final()]); + const authTag = cipher.getAuthTag(); + + return [ + ENCRYPTION_PREFIX, + iv.toString("base64url"), + authTag.toString("base64url"), + encrypted.toString("base64url") + ].join(":"); + } + + decryptString(value: string | null | undefined): string | null | undefined { + if (value === undefined) { + return undefined; + } + + if (value === null || !this.isEncryptedPayload(value)) { + return value; + } + + const [prefix, ivText, authTagText, encryptedText] = value.split(":"); + if (prefix !== ENCRYPTION_PREFIX || !ivText || !authTagText || encryptedText === undefined) { + throw new InternalServerErrorException("加密数据格式无效"); + } + + try { + const key = this.resolveKey(); + const decipher = createDecipheriv( + ENCRYPTION_ALGORITHM, + key, + Buffer.from(ivText, "base64url") + ); + decipher.setAuthTag(Buffer.from(authTagText, "base64url")); + const decrypted = Buffer.concat([ + decipher.update(Buffer.from(encryptedText, "base64url")), + decipher.final() + ]); + + return decrypted.toString("utf8"); + } catch { + throw new InternalServerErrorException("加密数据解密失败"); + } + } + + encryptJson( + value: Prisma.InputJsonValue | null | undefined + ): Prisma.InputJsonValue | null | undefined { + if (value === undefined) { + return undefined; + } + + if (value === null) { + return null; + } + + return this.encryptString(JSON.stringify(value)); + } + + decryptJson(value: Prisma.JsonValue | null): Prisma.JsonValue | null { + if (value === null) { + return null; + } + + if (typeof value !== "string" || !this.isEncryptedPayload(value)) { + return value; + } + + const decrypted = this.decryptString(value); + if (typeof decrypted !== "string") { + throw new InternalServerErrorException("加密数据解密失败"); + } + + try { + return JSON.parse(decrypted) as Prisma.JsonValue; + } catch { + throw new InternalServerErrorException("加密 JSON 数据损坏"); + } + } + + decryptPayload(value: Prisma.JsonValue | null): string | null { + if (value === null) { + return null; + } + + if (typeof value === "string") { + return this.decryptString(value) ?? null; + } + + return JSON.stringify(value); + } + + createLookupHash(scope: string, value: string): string { + const normalizedScope = scope.trim().toLowerCase(); + if (!normalizedScope) { + throw new InternalServerErrorException("缺少盲索引作用域"); + } + + const secret = this.configService.get("DATA_ENCRYPTION_SECRET"); + if (!secret) { + throw new InternalServerErrorException("服务端未配置 DATA_ENCRYPTION_SECRET,无法生成盲索引"); + } + + return createHmac("sha256", `lookup:${normalizedScope}:${secret}`) + .update(value, "utf8") + .digest("hex"); + } + + private isEncryptedPayload(value: string): boolean { + return this.isEncryptedString(value); + } + + private resolveKey(): Buffer { + const secret = this.configService.get("DATA_ENCRYPTION_SECRET"); + if (!secret) { + throw new InternalServerErrorException( + "服务端未配置 DATA_ENCRYPTION_SECRET,无法写入加密数据" + ); + } + + return createHash("sha256").update(secret, "utf8").digest(); + } +} diff --git a/apps/api/src/security/security.module.ts b/apps/api/src/security/security.module.ts new file mode 100644 index 0000000..8373141 --- /dev/null +++ b/apps/api/src/security/security.module.ts @@ -0,0 +1,9 @@ +import { Global, Module } from "@nestjs/common"; +import { DataEncryptionService } from "./data-encryption.service"; + +@Global() +@Module({ + providers: [DataEncryptionService], + exports: [DataEncryptionService] +}) +export class SecurityModule {} diff --git a/apps/api/src/sync/sync.service.ts b/apps/api/src/sync/sync.service.ts index 9bab5e2..cfd0f49 100644 --- a/apps/api/src/sync/sync.service.ts +++ b/apps/api/src/sync/sync.service.ts @@ -1,6 +1,7 @@ import { BadRequestException, Injectable } from "@nestjs/common"; import { Prisma } from "../../generated/prisma/client"; import { PrismaService } from "../prisma/prisma.service"; +import { DataEncryptionService } from "../security/data-encryption.service"; import { SyncPullQueryDto } from "./dto/sync-pull.dto"; import { SyncPushDto, SyncPushOperationDto } from "./dto/sync-push.dto"; @@ -60,7 +61,10 @@ export type SyncPullResponse = { @Injectable() export class SyncService { - constructor(private readonly prismaService: PrismaService) {} + constructor( + private readonly prismaService: PrismaService, + private readonly dataEncryptionService: DataEncryptionService + ) {} async pullOperations(userId: string, query: SyncPullQueryDto): Promise { const limit = query.limit ?? 100; @@ -137,7 +141,7 @@ export class SyncService { entityType: operation.entityType, entityId: operation.entityId, action: operation.action, - payload: operation.payload, + payload: this.dataEncryptionService.encryptString(operation.payload) ?? undefined, clientTs: new Date(operation.clientTs) }, select: { @@ -252,15 +256,7 @@ export class SyncService { } private serializePayload(payload: Prisma.JsonValue | null): string | null { - if (payload === null) { - return null; - } - - if (typeof payload === "string") { - return payload; - } - - return JSON.stringify(payload); + return this.dataEncryptionService.decryptPayload(payload); } private parseCursor(cursor: string | undefined): SyncPullCursorState | null { diff --git a/apps/api/src/task/task.service.ts b/apps/api/src/task/task.service.ts index f754003..deb1f76 100644 --- a/apps/api/src/task/task.service.ts +++ b/apps/api/src/task/task.service.ts @@ -1,6 +1,7 @@ -import { Injectable, NotFoundException } from "@nestjs/common"; +import { Injectable, InternalServerErrorException, NotFoundException } from "@nestjs/common"; import { Prisma, TaskPriority, TaskStatus } from "../../generated/prisma/client"; import { PrismaService } from "../prisma/prisma.service"; +import { DataEncryptionService } from "../security/data-encryption.service"; import { CreateTaskDto } from "./dto/create-task.dto"; import { ListTasksQueryDto, TaskSortBy, TaskSortOrder } from "./dto/list-tasks-query.dto"; import { UpdateTaskDto } from "./dto/update-task.dto"; @@ -43,16 +44,48 @@ export type ListTasksResponse = { @Injectable() export class TaskService { - constructor(private readonly prismaService: PrismaService) {} + constructor( + private readonly prismaService: PrismaService, + private readonly dataEncryptionService: DataEncryptionService + ) {} async listTasks(userId: string, query: ListTasksQueryDto): Promise { const page = query.page ?? 1; const pageSize = query.pageSize ?? 20; const skip = (page - 1) * pageSize; + const keyword = query.keyword?.trim() ?? ""; - const where = this.buildWhereInput(userId, query); + const where = this.buildWhereInput(userId, query, keyword.length === 0); const orderBy = this.buildOrderByInput(query); + if (keyword.length > 0) { + const items = await this.prismaService.task.findMany({ + where, + orderBy, + include: { + taskTags: { + include: { + tag: { + select: { + name: true + } + } + } + } + } + }); + + const serializedItems = items.map((item: TaskEntity) => this.serializeTask(item)); + const filteredItems = serializedItems.filter((item) => this.matchesKeyword(item, keyword)); + + return { + items: filteredItems.slice(skip, skip + pageSize), + page, + pageSize, + total: filteredItems.length + }; + } + const [items, total] = await Promise.all([ this.prismaService.task.findMany({ where, @@ -112,15 +145,18 @@ export class TaskService { const tagNames = this.normalizeTagNames(body.tagNames); const nextStatus = body.status ?? TaskStatus.TODO; const contentJson = - body.contentJson !== undefined ? (body.contentJson as Prisma.InputJsonValue) : undefined; + body.contentJson !== undefined + ? ((this.dataEncryptionService.encryptJson(body.contentJson as Prisma.InputJsonValue) ?? + Prisma.JsonNull) as Prisma.InputJsonValue | Prisma.NullableJsonNullValueInput) + : undefined; const task = await this.prismaService.$transaction(async (tx) => { const createdTask = await tx.task.create({ data: { userId, - title: body.title, + title: this.encryptRequiredString(body.title), contentJson, - contentText: body.contentText ?? null, + contentText: this.encryptNullableString(body.contentText), priority: body.priority ?? TaskPriority.MEDIUM, status: nextStatus, ddl: body.ddl ? new Date(body.ddl) : null, @@ -172,13 +208,15 @@ export class TaskService { }; if (body.title !== undefined) { - data.title = body.title; + data.title = this.encryptRequiredString(body.title); } if (body.contentJson !== undefined) { - data.contentJson = body.contentJson as Prisma.InputJsonValue; + data.contentJson = (this.dataEncryptionService.encryptJson( + body.contentJson as Prisma.InputJsonValue + ) ?? Prisma.JsonNull) as Prisma.InputJsonValue | Prisma.NullableJsonNullValueInput; } if (body.contentText !== undefined) { - data.contentText = body.contentText; + data.contentText = this.encryptNullableString(body.contentText); } if (body.priority !== undefined) { data.priority = body.priority; @@ -242,7 +280,11 @@ export class TaskService { return { success: true }; } - private buildWhereInput(userId: string, query: ListTasksQueryDto): Prisma.TaskWhereInput { + private buildWhereInput( + userId: string, + query: ListTasksQueryDto, + includeKeyword: boolean + ): Prisma.TaskWhereInput { const where: Prisma.TaskWhereInput = { userId }; @@ -267,7 +309,7 @@ export class TaskService { }; } - if (query.keyword !== undefined && query.keyword.length > 0) { + if (includeKeyword && query.keyword !== undefined && query.keyword.length > 0) { where.OR = [ { title: { @@ -374,9 +416,9 @@ export class TaskService { private serializeTask(task: TaskEntity): TaskResponse { return { id: task.id, - title: task.title, - contentJson: task.contentJson, - contentText: task.contentText, + title: this.readDecryptedString(task.title) ?? "未命名任务", + contentJson: this.dataEncryptionService.decryptJson(task.contentJson), + contentText: this.readDecryptedString(task.contentText), priority: task.priority, status: task.status, ddl: task.ddl?.toISOString() ?? null, @@ -387,4 +429,30 @@ export class TaskService { updatedAt: task.updatedAt.toISOString() }; } + + private encryptRequiredString(value: string): string { + const encryptedValue = this.dataEncryptionService.encryptString(value); + if (!encryptedValue) { + throw new InternalServerErrorException("任务字段加密失败"); + } + + return encryptedValue; + } + + private encryptNullableString(value: string | null | undefined): string | null | undefined { + return this.dataEncryptionService.encryptString(value); + } + + private readDecryptedString(value: string | null): string | null { + const decryptedValue = this.dataEncryptionService.decryptString(value); + return typeof decryptedValue === "string" ? decryptedValue : null; + } + + private matchesKeyword(task: TaskResponse, keyword: string): boolean { + const lowerKeyword = keyword.toLocaleLowerCase(); + return ( + task.title.toLocaleLowerCase().includes(lowerKeyword) || + task.contentText?.toLocaleLowerCase().includes(lowerKeyword) === true + ); + } } diff --git a/apps/api/test/ai.spec.ts b/apps/api/test/ai.spec.ts new file mode 100644 index 0000000..7dc70e1 --- /dev/null +++ b/apps/api/test/ai.spec.ts @@ -0,0 +1,1250 @@ +import request from "supertest"; +import { INestApplication, ValidationPipe } from "@nestjs/common"; +import { ConfigService } from "@nestjs/config"; +import { Test, TestingModule } from "@nestjs/testing"; +import { + AiChannel, + AiUsageLog, + AiProviderBinding, + AiPublicPoolConfig, + TaskPriority, + TaskStatus +} from "../generated/prisma/client"; +import { AiController } from "../src/ai/ai.controller"; +import { AiProviderRegistryService } from "../src/ai/ai-provider-registry.service"; +import { AiRateLimitService } from "../src/ai/ai-rate-limit.service"; +import { AiService } from "../src/ai/ai.service"; +import { + AiChatInput, + AiChannelExecutor, + AiResolvedRouteCandidate, + AiRouteFailureError +} from "../src/ai/ai.types"; +import { PrismaService } from "../src/prisma/prisma.service"; +import { DataEncryptionService } from "../src/security/data-encryption.service"; + +type AiUsageLogRecord = { + id: string; + userId: string | null; + channel: AiChannel; + providerName: string | null; + model: string | null; + promptTokens: number; + completionTokens: number; + totalTokens: number; + latencyMs: number | null; + success: boolean; + errorCode: string | null; + createdAt: Date; +}; + +type AiTaskRecord = { + id: string; + userId: string; + title: string; + priority: TaskPriority; + status: TaskStatus; + ddl: Date | null; + contentText: string | null; + updatedAt: Date; +}; + +class InMemoryAiPrismaService { + private bindingIdSequence = 1; + private publicPoolIdSequence = 1; + private usageLogIdSequence = 1; + private bindings: AiProviderBinding[] = []; + private publicPools: AiPublicPoolConfig[] = []; + private usageLogs: AiUsageLogRecord[] = []; + private tasks: AiTaskRecord[] = []; + + readonly aiProviderBinding = { + findMany: async (args: { + where: { + userId: string; + }; + }) => { + return this.bindings + .filter((binding) => binding.userId === args.where.userId) + .sort((left, right) => right.updatedAt.getTime() - left.updatedAt.getTime()); + }, + + findFirst: async (args: { + where: { + id?: string; + userId?: string; + channel?: AiChannel; + isEnabled?: boolean; + }; + }) => { + return ( + this.bindings + .filter((binding) => { + if (args.where.id !== undefined && binding.id !== args.where.id) { + return false; + } + if (args.where.userId !== undefined && binding.userId !== args.where.userId) { + return false; + } + if (args.where.channel !== undefined && binding.channel !== args.where.channel) { + return false; + } + if (args.where.isEnabled !== undefined && binding.isEnabled !== args.where.isEnabled) { + return false; + } + return true; + }) + .sort((left, right) => { + if (left.isDefault !== right.isDefault) { + return Number(right.isDefault) - Number(left.isDefault); + } + return right.updatedAt.getTime() - left.updatedAt.getTime(); + })[0] ?? null + ); + }, + + create: async (args: { + data: { + userId: string; + channel: AiChannel; + providerName: string; + model: string | null; + configId: string | null; + configName: string | null; + endpoint: string | null; + encryptedApiKey: string | null; + isDefault: boolean; + isEnabled: boolean; + }; + }) => { + const now = new Date(); + const binding: AiProviderBinding = { + id: `binding_${this.bindingIdSequence++}`, + userId: args.data.userId, + channel: args.data.channel, + providerName: args.data.providerName, + model: args.data.model, + configId: args.data.configId, + configName: args.data.configName, + encryptedApiKey: args.data.encryptedApiKey, + endpoint: args.data.endpoint, + isDefault: args.data.isDefault, + isEnabled: args.data.isEnabled, + createdAt: now, + updatedAt: now + }; + + this.bindings.push(binding); + return binding; + }, + + update: async (args: { + where: { + id: string; + }; + data: Partial; + }) => { + const binding = this.bindings.find((item) => item.id === args.where.id); + if (!binding) { + throw new Error("binding not found"); + } + + Object.assign(binding, args.data, { updatedAt: new Date() }); + return binding; + }, + + updateMany: async (args: { + where: { + userId?: string; + channel?: AiChannel; + id?: { + not: string; + }; + }; + data: { + isDefault?: boolean; + }; + }) => { + let count = 0; + for (const binding of this.bindings) { + if (args.where.userId !== undefined && binding.userId !== args.where.userId) { + continue; + } + if (args.where.channel !== undefined && binding.channel !== args.where.channel) { + continue; + } + if (args.where.id?.not !== undefined && binding.id === args.where.id.not) { + continue; + } + + if (args.data.isDefault !== undefined) { + binding.isDefault = args.data.isDefault; + binding.updatedAt = new Date(); + } + count += 1; + } + + return { count }; + } + }; + + readonly aiPublicPoolConfig = { + findFirst: async (args?: { + where?: { + enabled?: boolean; + }; + }) => { + const items = this.publicPools + .filter((item) => + args?.where?.enabled === undefined ? true : item.enabled === args.where.enabled + ) + .sort((left, right) => right.updatedAt.getTime() - left.updatedAt.getTime()); + + return items[0] ?? null; + } + }; + + readonly aiUsageLog = { + create: async (args: { data: Omit }) => { + const usageLog: AiUsageLogRecord = { + id: `usage_log_${this.usageLogIdSequence++}`, + createdAt: new Date(), + ...args.data + }; + + this.usageLogs.push(usageLog); + return usageLog; + }, + + findMany: async (args: { + where?: { + userId?: string; + channel?: AiChannel; + success?: boolean; + }; + orderBy?: { + createdAt: "asc" | "desc"; + }; + skip?: number; + take?: number; + }) => { + const filteredLogs = this.filterUsageLogs(args.where); + const sortedLogs = [...filteredLogs].sort((left, right) => { + const direction = args.orderBy?.createdAt === "asc" ? 1 : -1; + return (left.createdAt.getTime() - right.createdAt.getTime()) * direction; + }); + const start = args.skip ?? 0; + const end = args.take === undefined ? undefined : start + args.take; + return sortedLogs.slice(start, end); + }, + + count: async (args?: { + where?: { + userId?: string; + channel?: AiChannel; + success?: boolean; + }; + }) => { + return this.filterUsageLogs(args?.where).length; + } + }; + + readonly task = { + findMany: async (args: { + where: { + userId: string; + status: { + in: TaskStatus[]; + }; + }; + take?: number; + }) => { + const filteredTasks = this.tasks.filter( + (task) => task.userId === args.where.userId && args.where.status.in.includes(task.status) + ); + + return filteredTasks.slice(0, args.take ?? filteredTasks.length).map((task) => ({ + id: task.id, + title: task.title, + priority: task.priority, + status: task.status, + ddl: task.ddl, + contentText: task.contentText, + updatedAt: task.updatedAt + })); + } + }; + + async $transaction(callback: (tx: InMemoryAiPrismaService) => Promise): Promise { + return callback(this); + } + + seedBinding(binding: Omit): void { + const now = new Date(); + this.bindings.push({ + ...binding, + createdAt: now, + updatedAt: now + }); + } + + seedPublicPool(publicPool: Omit): void { + const now = new Date(); + this.publicPools.push({ + id: `pool_${this.publicPoolIdSequence++}`, + createdAt: now, + updatedAt: now, + ...publicPool + }); + } + + getUsageLogs(): AiUsageLogRecord[] { + return [...this.usageLogs]; + } + + getBindings(): AiProviderBinding[] { + return [...this.bindings]; + } + + seedTask(task: AiTaskRecord): void { + this.tasks.push(task); + } + + seedUsageLog(log: Omit & { id?: string }): void { + this.usageLogs.push({ + id: log.id ?? `usage_log_${this.usageLogIdSequence++}`, + ...log + }); + } + + private filterUsageLogs(where?: { + userId?: string; + channel?: AiChannel; + success?: boolean; + }): AiUsageLogRecord[] { + return this.usageLogs.filter((log) => { + if (where?.userId !== undefined && log.userId !== where.userId) { + return false; + } + if (where?.channel !== undefined && log.channel !== where.channel) { + return false; + } + if (where?.success !== undefined && log.success !== where.success) { + return false; + } + + return true; + }); + } +} + +class StaticExecutor implements AiChannelExecutor { + readonly inputs: Array<{ + candidate: AiResolvedRouteCandidate; + message: string; + }> = []; + + constructor( + private readonly resolver: (channel: AiChannel) => { + content?: string; + code?: string; + message?: string; + } + ) {} + + async execute(candidate: AiResolvedRouteCandidate, input: AiChatInput) { + this.inputs.push({ + candidate, + message: input.message + }); + + const result = this.resolver(candidate.channel); + if (result.code) { + throw new AiRouteFailureError( + candidate.channel, + candidate.providerName || candidate.configName || candidate.configId || "unknown", + result.code, + result.message ?? "执行失败" + ); + } + + return { + channel: candidate.channel, + providerName: candidate.providerName || candidate.configName || candidate.configId || "", + model: candidate.model, + content: result.content ?? "", + sessionId: "session_ai", + usage: { + promptTokens: 12, + completionTokens: 8, + totalTokens: 20 + }, + raw: null + }; + } +} + +describe("AiController (integration)", () => { + let app: INestApplication; + let prismaService: InMemoryAiPrismaService; + let astrbotExecutor: StaticExecutor; + let openAiExecutor: StaticExecutor; + + beforeEach(async () => { + prismaService = new InMemoryAiPrismaService(); + + openAiExecutor = new StaticExecutor((channel) => + channel === AiChannel.USER_KEY + ? { + code: "UPSTREAM_UNREACHABLE", + message: "用户自备 Key 渠道暂时不可用" + } + : { + content: "公共 AI 已接管" + } + ); + astrbotExecutor = new StaticExecutor(() => ({ + content: "AstrBot 已接管" + })); + + const moduleRef: TestingModule = await Test.createTestingModule({ + controllers: [AiController], + providers: [ + AiService, + AiRateLimitService, + DataEncryptionService, + { + provide: PrismaService, + useValue: prismaService + }, + { + provide: ConfigService, + useValue: { + get: (key: string) => { + if (key === "DATA_ENCRYPTION_SECRET") { + return "test-data-encryption-secret"; + } + if (key === "AI_RATE_LIMIT_WINDOW_MS") { + return 60_000; + } + if (key === "AI_RATE_LIMIT_USER_MAX") { + return 2; + } + if (key === "AI_RATE_LIMIT_IP_MAX") { + return 3; + } + + return undefined; + } + } + }, + { + provide: AiProviderRegistryService, + useValue: { + getExecutor: (channel: AiChannel) => + channel === AiChannel.ASTRBOT ? astrbotExecutor : openAiExecutor + } + } + ] + }).compile(); + + app = moduleRef.createNestApplication(); + app.useGlobalPipes( + new ValidationPipe({ + transform: true, + whitelist: true, + forbidNonWhitelisted: true + }) + ); + await app.init(); + }); + + afterEach(async () => { + await app.close(); + }); + + it("should create and list ai bindings", async () => { + await request(app.getHttpServer()) + .post("/ai/bindings") + .set("x-user-id", "user_1") + .send({ + channel: AiChannel.ASTRBOT, + providerName: "astrbot-main", + model: "deepseek-chat", + configId: "default", + endpoint: "http://127.0.0.1:6185", + apiKey: "abk_secret_1234", + isEnabled: true + }) + .expect(201); + + const response = await request(app.getHttpServer()) + .get("/ai/bindings") + .set("x-user-id", "user_1") + .expect(200); + + expect(response.body.routeOrder).toEqual([ + AiChannel.USER_KEY, + AiChannel.ASTRBOT, + AiChannel.PUBLIC_POOL + ]); + expect(response.body.bindings).toHaveLength(1); + expect(response.body.bindings[0]).toMatchObject({ + channel: AiChannel.ASTRBOT, + providerName: "astrbot-main", + model: "deepseek-chat", + configId: "default", + configName: null, + hasApiKey: true, + maskedApiKey: "abk_***34", + isEnabled: true + }); + + const storedBinding = prismaService.getBindings()[0]; + expect(storedBinding?.providerName).not.toBe("astrbot-main"); + expect(storedBinding?.endpoint).not.toBe("http://127.0.0.1:6185"); + expect(storedBinding?.encryptedApiKey).not.toBe("abk_secret_1234"); + }); + + it("should hide public pool endpoint from user bindings response", async () => { + prismaService.seedPublicPool({ + enabled: true, + providerName: "public-openai", + model: "gpt-4o-mini", + encryptedApiKey: "sk-public", + endpoint: "https://internal.example.com/v1", + rpmLimit: 60, + dailyTokenLimit: 100000 + }); + + const response = await request(app.getHttpServer()) + .get("/ai/bindings") + .set("x-user-id", "user_1") + .expect(200); + + expect(response.body.publicPool).toEqual({ + enabled: true, + providerName: "public-openai", + model: "gpt-4o-mini", + hasApiKey: true + }); + }); + + it("should upsert one binding per user channel", async () => { + await request(app.getHttpServer()) + .post("/ai/bindings") + .set("x-user-id", "user_1") + .send({ + channel: AiChannel.USER_KEY, + providerName: "openai", + model: "gpt-4o-mini", + endpoint: "https://api.example.com", + apiKey: "sk-first", + isEnabled: true + }) + .expect(201); + + await request(app.getHttpServer()) + .post("/ai/bindings") + .set("x-user-id", "user_1") + .send({ + channel: AiChannel.USER_KEY, + providerName: "google", + model: "gemini-2.5-flash", + endpoint: "https://generativelanguage.googleapis.com", + apiKey: "sk-second", + isEnabled: false + }) + .expect(201); + + const response = await request(app.getHttpServer()) + .get("/ai/bindings") + .set("x-user-id", "user_1") + .expect(200); + + expect(response.body.bindings).toEqual([ + expect.objectContaining({ + channel: AiChannel.USER_KEY, + providerName: "google", + model: "gemini-2.5-flash", + endpoint: "https://generativelanguage.googleapis.com", + isEnabled: false, + maskedApiKey: "sk-s***nd" + }) + ]); + }); + + it("should fallback from user key to astrbot", async () => { + prismaService.seedBinding({ + id: "binding_user_key", + userId: "user_1", + channel: AiChannel.USER_KEY, + providerName: "openai", + model: "gpt-4o-mini", + configId: null, + configName: null, + encryptedApiKey: "sk-user", + endpoint: "https://api.example.com", + isDefault: true, + isEnabled: true + }); + prismaService.seedBinding({ + id: "binding_astrbot", + userId: "user_1", + channel: AiChannel.ASTRBOT, + providerName: "", + model: null, + configId: "default", + configName: null, + encryptedApiKey: "abk_astrbot", + endpoint: "http://127.0.0.1:6185", + isDefault: true, + isEnabled: true + }); + + const response = await request(app.getHttpServer()) + .post("/ai/chat") + .set("x-user-id", "user_1") + .send({ + message: "帮我安排今天的任务" + }) + .expect(201); + + expect(response.body.channel).toBe(AiChannel.ASTRBOT); + expect(response.body.content).toBe("AstrBot 已接管"); + expect(response.body.attempts).toEqual([ + { + channel: AiChannel.USER_KEY, + providerName: "openai", + model: "gpt-4o-mini", + status: "failed", + reasonCode: "UPSTREAM_UNREACHABLE", + reasonMessage: "用户自备 Key 渠道暂时不可用" + }, + { + channel: AiChannel.ASTRBOT, + providerName: "default", + model: null, + status: "success", + reasonCode: null, + reasonMessage: null + } + ]); + expect(prismaService.getUsageLogs()).toEqual([ + expect.objectContaining({ + id: expect.any(String), + userId: "user_1", + channel: AiChannel.USER_KEY, + promptTokens: 0, + completionTokens: 0, + totalTokens: 0, + latencyMs: expect.any(Number), + success: false, + errorCode: "UPSTREAM_UNREACHABLE", + createdAt: expect.any(Date) + }), + expect.objectContaining({ + id: expect.any(String), + userId: "user_1", + channel: AiChannel.ASTRBOT, + promptTokens: 12, + completionTokens: 8, + totalTokens: 20, + latencyMs: expect.any(Number), + success: true, + errorCode: null, + createdAt: expect.any(Date) + }) + ]); + expect(prismaService.getUsageLogs()[0]?.providerName).not.toBe("openai"); + expect(prismaService.getUsageLogs()[0]?.model).not.toBe("gpt-4o-mini"); + }); + + it("should allow astrbot binding with config id only", async () => { + const response = await request(app.getHttpServer()) + .post("/ai/bindings") + .set("x-user-id", "user_1") + .send({ + channel: AiChannel.ASTRBOT, + configId: "default", + endpoint: "http://127.0.0.1:6185", + apiKey: "abk_secret_1234", + isEnabled: true + }) + .expect(201); + + expect(response.body).toMatchObject({ + channel: AiChannel.ASTRBOT, + providerName: "", + configId: "default", + configName: null, + isEnabled: true + }); + }); + + it("should test binding with stored secret when api key is omitted", async () => { + prismaService.seedBinding({ + id: "binding_user_key_test_existing_secret", + userId: "user_1", + channel: AiChannel.USER_KEY, + providerName: "airouter", + model: "gpt-4.1", + configId: null, + configName: null, + encryptedApiKey: "sk-existing", + endpoint: "https://api.example.com", + isDefault: false, + isEnabled: true + }); + + const executeSpy = jest.spyOn(openAiExecutor, "execute").mockResolvedValue({ + channel: AiChannel.USER_KEY, + providerName: "airouter", + model: "gpt-4.1", + content: "连接成功", + sessionId: "session_binding_test", + usage: { + promptTokens: 1, + completionTokens: 1, + totalTokens: 2 + }, + raw: null + }); + + const response = await request(app.getHttpServer()) + .post("/ai/bindings/test") + .set("x-user-id", "user_1") + .send({ + channel: AiChannel.USER_KEY, + providerName: "airouter", + model: "gpt-4.1", + endpoint: "https://api.example.com" + }) + .expect(201); + + expect(response.body).toEqual({ + success: true, + channel: AiChannel.USER_KEY, + providerName: "airouter", + model: "gpt-4.1", + contentPreview: "连接成功" + }); + expect(executeSpy).toHaveBeenCalledWith( + expect.objectContaining({ + channel: AiChannel.USER_KEY, + providerName: "airouter", + model: "gpt-4.1", + endpoint: "https://api.example.com", + apiKey: "sk-existing" + }), + expect.objectContaining({ + userId: "user_1" + }) + ); + }); + + it("should return structured failure result when binding test fails", async () => { + prismaService.seedBinding({ + id: "binding_user_key_test_failure", + userId: "user_1", + channel: AiChannel.USER_KEY, + providerName: "airouter", + model: "gpt-5.4", + configId: null, + configName: null, + encryptedApiKey: "sk-existing", + endpoint: "https://api.example.com", + isDefault: false, + isEnabled: true + }); + + const response = await request(app.getHttpServer()) + .post("/ai/bindings/test") + .set("x-user-id", "user_1") + .send({ + channel: AiChannel.USER_KEY, + providerName: "airouter", + model: "gpt-5.4", + endpoint: "https://api.example.com" + }) + .expect(201); + + expect(response.body).toEqual({ + success: false, + channel: AiChannel.USER_KEY, + providerName: "airouter", + model: "gpt-5.4", + code: "UPSTREAM_UNREACHABLE", + message: "用户自备 Key 渠道暂时不可用" + }); + }); + + it("should use selected channel without automatic fallback", async () => { + prismaService.seedBinding({ + id: "binding_user_key_selected", + userId: "user_1", + channel: AiChannel.USER_KEY, + providerName: "openai", + model: "gpt-4o-mini", + configId: null, + configName: null, + encryptedApiKey: "sk-user", + endpoint: "https://api.example.com", + isDefault: false, + isEnabled: true + }); + prismaService.seedBinding({ + id: "binding_astrbot_selected", + userId: "user_1", + channel: AiChannel.ASTRBOT, + providerName: "", + model: null, + configId: "default", + configName: null, + encryptedApiKey: "abk_astrbot", + endpoint: "http://127.0.0.1:6185", + isDefault: false, + isEnabled: true + }); + + const response = await request(app.getHttpServer()) + .post("/ai/chat") + .set("x-user-id", "user_1") + .send({ + message: "只使用自备渠道", + channel: AiChannel.USER_KEY + }) + .expect(502); + + expect(response.body.attempts).toEqual([ + { + channel: AiChannel.USER_KEY, + providerName: "openai", + model: "gpt-4o-mini", + status: "failed", + reasonCode: "UPSTREAM_UNREACHABLE", + reasonMessage: "用户自备 Key 渠道暂时不可用" + } + ]); + }); + + it("should inject unfinished task summary into ai prompt", async () => { + prismaService.seedBinding({ + id: "binding_astrbot_context", + userId: "user_1", + channel: AiChannel.ASTRBOT, + providerName: "", + model: null, + configId: "default", + configName: null, + encryptedApiKey: "abk_astrbot", + endpoint: "http://127.0.0.1:6185", + isDefault: true, + isEnabled: true + }); + prismaService.seedTask({ + id: "task_weekly_report", + userId: "user_1", + title: "今晚提交周报", + priority: TaskPriority.URGENT, + status: TaskStatus.IN_PROGRESS, + ddl: new Date("2026-04-06T12:00:00.000Z"), + contentText: "需要汇总 AI 路由、AstrBot 接入和同步模块进度", + updatedAt: new Date("2026-04-06T08:00:00.000Z") + }); + prismaService.seedTask({ + id: "task_done_item", + userId: "user_1", + title: "整理已完成事项", + priority: TaskPriority.LOW, + status: TaskStatus.DONE, + ddl: null, + contentText: "这条任务不应该出现在上下文里", + updatedAt: new Date("2026-04-06T07:00:00.000Z") + }); + + await request(app.getHttpServer()) + .post("/ai/chat") + .set("x-user-id", "user_1") + .send({ + message: "帮我安排今天剩余任务" + }) + .expect(201); + + expect(astrbotExecutor.inputs).toHaveLength(1); + expect(astrbotExecutor.inputs[0]?.message).toContain("以下是系统整理的未完成任务摘要"); + expect(astrbotExecutor.inputs[0]?.message).toContain("今晚提交周报"); + expect(astrbotExecutor.inputs[0]?.message).toContain("优先级:紧急"); + expect(astrbotExecutor.inputs[0]?.message).not.toContain("整理已完成事项"); + expect(astrbotExecutor.inputs[0]?.message).toContain("用户当前问题:帮我安排今天剩余任务"); + }); + + it("should inject local unfinished tasks into ai prompt when database is empty", async () => { + prismaService.seedBinding({ + id: "binding_user_key_local_context", + userId: "user_1", + channel: AiChannel.USER_KEY, + providerName: "openai", + model: "gpt-4o-mini", + configId: null, + configName: null, + encryptedApiKey: "sk-user", + endpoint: "https://api.example.com", + isDefault: true, + isEnabled: true + }); + + const response = await request(app.getHttpServer()) + .post("/ai/chat") + .set("x-user-id", "user_1") + .send({ + message: "结合我的 TodoList 帮我排优先级", + channel: AiChannel.USER_KEY, + localTasks: [ + { + id: "local_task_1", + title: "准备明天答辩材料", + priority: TaskPriority.URGENT, + status: TaskStatus.IN_PROGRESS, + ddlAt: new Date("2026-04-07T13:00:00.000Z").getTime(), + contentText: "需要补齐演示文稿和总结页", + updatedAt: new Date("2026-04-07T09:00:00.000Z").getTime() + } + ] + }) + .expect(502); + + expect(response.body.attempts).toEqual([ + { + channel: AiChannel.USER_KEY, + providerName: "openai", + model: "gpt-4o-mini", + status: "failed", + reasonCode: "UPSTREAM_UNREACHABLE", + reasonMessage: "用户自备 Key 渠道暂时不可用" + } + ]); + expect(openAiExecutor.inputs).toHaveLength(1); + expect(openAiExecutor.inputs[0]?.message).toContain("准备明天答辩材料"); + expect(openAiExecutor.inputs[0]?.message).toContain("优先级:紧急"); + expect(openAiExecutor.inputs[0]?.message).toContain("内容摘要:需要补齐演示文稿和总结页"); + expect(openAiExecutor.inputs[0]?.message).toContain( + "用户当前问题:结合我的 TodoList 帮我排优先级" + ); + expect(astrbotExecutor.inputs).toHaveLength(0); + }); + + it("should prefer newer local task snapshot over older database task", async () => { + prismaService.seedBinding({ + id: "binding_astrbot_local_override", + userId: "user_1", + channel: AiChannel.ASTRBOT, + providerName: "", + model: null, + configId: "default", + configName: null, + encryptedApiKey: "abk_astrbot", + endpoint: "http://127.0.0.1:6185", + isDefault: true, + isEnabled: true + }); + prismaService.seedTask({ + id: "task_same_id", + userId: "user_1", + title: "旧标题", + priority: TaskPriority.LOW, + status: TaskStatus.TODO, + ddl: new Date("2026-04-08T10:00:00.000Z"), + contentText: "旧内容", + updatedAt: new Date("2026-04-07T08:00:00.000Z") + }); + + await request(app.getHttpServer()) + .post("/ai/chat") + .set("x-user-id", "user_1") + .send({ + message: "看看我最新要做什么", + channel: AiChannel.ASTRBOT, + localTasks: [ + { + id: "task_same_id", + title: "新标题", + priority: TaskPriority.HIGH, + status: TaskStatus.IN_PROGRESS, + ddlAt: new Date("2026-04-07T15:00:00.000Z").getTime(), + contentText: "新内容", + updatedAt: new Date("2026-04-07T12:00:00.000Z").getTime() + } + ] + }) + .expect(201); + + expect(astrbotExecutor.inputs.at(-1)?.message).toContain("新标题"); + expect(astrbotExecutor.inputs.at(-1)?.message).toContain("优先级:高"); + expect(astrbotExecutor.inputs.at(-1)?.message).toContain("内容摘要:新内容"); + expect(astrbotExecutor.inputs.at(-1)?.message).not.toContain("旧标题"); + expect(astrbotExecutor.inputs.at(-1)?.message).not.toContain("旧内容"); + }); + + it("should return skipped attempts when no channel is available", async () => { + const response = await request(app.getHttpServer()) + .post("/ai/chat") + .set("x-user-id", "user_1") + .send({ + message: "帮我总结今天的安排" + }) + .expect(502); + + expect(response.body.message).toBe("当前没有可用的 AI 通道,请稍后重试"); + expect(response.body.attempts).toEqual([ + { + channel: AiChannel.USER_KEY, + providerName: null, + model: null, + status: "skipped", + reasonCode: "CHANNEL_NOT_CONFIGURED", + reasonMessage: "当前用户未配置可用的自备 Key 通道" + }, + { + channel: AiChannel.ASTRBOT, + providerName: null, + model: null, + status: "skipped", + reasonCode: "CHANNEL_NOT_CONFIGURED", + reasonMessage: "当前用户未配置可用的 AstrBot 通道" + }, + { + channel: AiChannel.PUBLIC_POOL, + providerName: null, + model: null, + status: "skipped", + reasonCode: "PUBLIC_POOL_DISABLED", + reasonMessage: "公共 AI 通道未开启" + } + ]); + expect(prismaService.getUsageLogs()).toEqual([]); + }); + + it("should rate limit ai chat by user in the same window", async () => { + prismaService.seedBinding({ + id: "binding_astrbot_rate_limit_user", + userId: "user_1", + channel: AiChannel.ASTRBOT, + providerName: "", + model: null, + configId: "default", + configName: null, + encryptedApiKey: "abk_astrbot", + endpoint: "http://127.0.0.1:6185", + isDefault: true, + isEnabled: true + }); + + await request(app.getHttpServer()) + .post("/ai/chat") + .set("x-user-id", "user_1") + .set("x-forwarded-for", "203.0.113.10") + .send({ + message: "第一条" + }) + .expect(201); + + await request(app.getHttpServer()) + .post("/ai/chat") + .set("x-user-id", "user_1") + .set("x-forwarded-for", "203.0.113.10") + .send({ + message: "第二条" + }) + .expect(201); + + const response = await request(app.getHttpServer()) + .post("/ai/chat") + .set("x-user-id", "user_1") + .set("x-forwarded-for", "203.0.113.10") + .send({ + message: "第三条" + }) + .expect(429); + + expect(response.body).toMatchObject({ + message: "AI 请求过于频繁,请稍后再试", + code: "AI_RATE_LIMITED", + dimension: "user", + limit: 2, + windowMs: 60000 + }); + expect(response.body.retryAfterMs).toEqual(expect.any(Number)); + expect(astrbotExecutor.inputs).toHaveLength(2); + expect(prismaService.getUsageLogs()).toHaveLength(2); + }); + + it("should rate limit ai chat by ip across different users", async () => { + prismaService.seedBinding({ + id: "binding_astrbot_rate_limit_ip_user_1", + userId: "user_1", + channel: AiChannel.ASTRBOT, + providerName: "", + model: null, + configId: "default", + configName: null, + encryptedApiKey: "abk_astrbot", + endpoint: "http://127.0.0.1:6185", + isDefault: true, + isEnabled: true + }); + prismaService.seedBinding({ + id: "binding_astrbot_rate_limit_ip_user_2", + userId: "user_2", + channel: AiChannel.ASTRBOT, + providerName: "", + model: null, + configId: "default", + configName: null, + encryptedApiKey: "abk_astrbot", + endpoint: "http://127.0.0.1:6185", + isDefault: true, + isEnabled: true + }); + + const sharedIp = "198.51.100.7"; + + await request(app.getHttpServer()) + .post("/ai/chat") + .set("x-user-id", "user_1") + .set("x-forwarded-for", sharedIp) + .send({ + message: "用户一第一条" + }) + .expect(201); + + await request(app.getHttpServer()) + .post("/ai/chat") + .set("x-user-id", "user_2") + .set("x-forwarded-for", sharedIp) + .send({ + message: "用户二第一条" + }) + .expect(201); + + await request(app.getHttpServer()) + .post("/ai/chat") + .set("x-user-id", "user_1") + .set("x-forwarded-for", sharedIp) + .send({ + message: "用户一第二条" + }) + .expect(201); + + const response = await request(app.getHttpServer()) + .post("/ai/chat") + .set("x-user-id", "user_2") + .set("x-forwarded-for", sharedIp) + .send({ + message: "用户二第二条" + }) + .expect(429); + + expect(response.body).toMatchObject({ + message: "AI 请求过于频繁,请稍后再试", + code: "AI_RATE_LIMITED", + dimension: "ip", + limit: 3, + windowMs: 60000 + }); + expect(response.body.retryAfterMs).toEqual(expect.any(Number)); + expect(astrbotExecutor.inputs).toHaveLength(3); + expect(prismaService.getUsageLogs()).toHaveLength(3); + }); + + it("should list usage logs with pagination and filters", async () => { + prismaService.seedUsageLog({ + id: "usage_log_1", + userId: "user_1", + channel: AiChannel.ASTRBOT, + providerName: "default", + model: "deepseek-chat", + promptTokens: 10, + completionTokens: 6, + totalTokens: 16, + latencyMs: 120, + success: true, + errorCode: null, + createdAt: new Date("2026-04-06T08:00:00.000Z") + }); + prismaService.seedUsageLog({ + id: "usage_log_2", + userId: "user_1", + channel: AiChannel.ASTRBOT, + providerName: "default", + model: "deepseek-chat", + promptTokens: 14, + completionTokens: 9, + totalTokens: 23, + latencyMs: 100, + success: true, + errorCode: null, + createdAt: new Date("2026-04-06T09:00:00.000Z") + }); + prismaService.seedUsageLog({ + id: "usage_log_3", + userId: "user_1", + channel: AiChannel.USER_KEY, + providerName: "openai", + model: "gpt-4o-mini", + promptTokens: 20, + completionTokens: 12, + totalTokens: 32, + latencyMs: 210, + success: false, + errorCode: "UPSTREAM_UNREACHABLE", + createdAt: new Date("2026-04-06T10:00:00.000Z") + }); + prismaService.seedUsageLog({ + id: "usage_log_4", + userId: "user_2", + channel: AiChannel.ASTRBOT, + providerName: "default", + model: "deepseek-chat", + promptTokens: 18, + completionTokens: 11, + totalTokens: 29, + latencyMs: 90, + success: true, + errorCode: null, + createdAt: new Date("2026-04-06T11:00:00.000Z") + }); + + const response = await request(app.getHttpServer()) + .get("/ai/usage-logs") + .set("x-user-id", "user_1") + .query({ + page: 2, + pageSize: 1, + channel: AiChannel.ASTRBOT, + success: true + }) + .expect(200); + + expect(response.body).toEqual({ + items: [ + { + id: "usage_log_1", + channel: AiChannel.ASTRBOT, + providerName: "default", + model: "deepseek-chat", + promptTokens: 10, + completionTokens: 6, + totalTokens: 16, + latencyMs: 120, + success: true, + errorCode: null, + createdAt: "2026-04-06T08:00:00.000Z" + } + ], + page: 2, + pageSize: 1, + total: 2 + }); + }); +}); diff --git a/apps/api/test/astrbot-provider.spec.ts b/apps/api/test/astrbot-provider.spec.ts new file mode 100644 index 0000000..3ccecc9 --- /dev/null +++ b/apps/api/test/astrbot-provider.spec.ts @@ -0,0 +1,73 @@ +import { AiChannel } from "../generated/prisma/client"; +import { AstrbotProvider } from "../src/ai/providers/astrbot.provider"; + +describe("AstrbotProvider", () => { + const originalFetch = global.fetch; + + afterEach(() => { + global.fetch = originalFetch; + jest.restoreAllMocks(); + }); + + it("should not forward binding label fields as astrbot selection parameters", async () => { + const provider = new AstrbotProvider(); + const fetchMock = jest.fn(async (_input: unknown, init?: RequestInit) => { + expect(init?.method).toBe("POST"); + const payload = JSON.parse(String(init?.body ?? "{}")) as Record; + + expect(payload).toMatchObject({ + username: "user_1", + session_id: "session_1", + message: "你好", + enable_streaming: false, + selected_model: "deepseek-chat" + }); + expect(payload).not.toHaveProperty("selected_provider"); + expect(payload).not.toHaveProperty("config_id"); + expect(payload).not.toHaveProperty("config_name"); + + return new Response( + [ + 'data: {"type":"session_id","session_id":"session_1"}', + "", + 'data: {"type":"plain","data":"收到","streaming":false,"chain_type":null}', + "", + 'data: {"type":"end","data":"","streaming":false}', + "" + ].join("\n"), + { + status: 200, + headers: { + "content-type": "text/event-stream" + } + } + ); + }); + + global.fetch = fetchMock as typeof global.fetch; + + const result = await provider.execute( + { + channel: AiChannel.ASTRBOT, + source: "binding", + sourceId: "binding_1", + providerName: "astrbot-main", + model: "deepseek-chat", + configId: "default", + configName: "默认配置", + endpoint: "http://127.0.0.1:6185", + apiKey: "abk_secret" + }, + { + userId: "user_1", + message: "你好", + sessionId: "session_1" + } + ); + + expect(fetchMock).toHaveBeenCalledTimes(1); + expect(result.content).toBe("收到"); + expect(result.sessionId).toBe("session_1"); + expect(result.providerName).toBe("astrbot-main"); + }); +}); diff --git a/apps/api/test/auth.spec.ts b/apps/api/test/auth.spec.ts new file mode 100644 index 0000000..8f1f1ae --- /dev/null +++ b/apps/api/test/auth.spec.ts @@ -0,0 +1,355 @@ +import { UnauthorizedException } from "@nestjs/common"; +import { ConfigService } from "@nestjs/config"; +import { JwtService } from "@nestjs/jwt"; +import { Test, TestingModule } from "@nestjs/testing"; +import { AuthMailService } from "../src/auth/auth-mail.service"; +import { AuthService } from "../src/auth/auth.service"; +import { PrismaService } from "../src/prisma/prisma.service"; +import { DataEncryptionService } from "../src/security/data-encryption.service"; + +type UserRecord = { + id: string; + email: string; + emailHash: string; + nickname: string | null; + avatarUrl: string | null; +}; + +type RefreshTokenRecord = { + id: string; + userId: string; + tokenHash: string; + expiresAt: Date; + revokedAt: Date | null; + createdAt: Date; +}; + +type UserSecurityRecord = { + userId: string; + twoFactorEnabled: boolean; + twoFactorSecret: string | null; +}; + +class InMemoryAuthPrismaService { + private userIdSequence = 1; + private refreshTokenIdSequence = 1; + private users: UserRecord[] = []; + private refreshTokens: RefreshTokenRecord[] = []; + private userSecurities: UserSecurityRecord[] = []; + + readonly user = { + upsert: async (args: { + where: { + emailHash: string; + }; + update: Record; + create: { + email: string; + emailHash: string; + }; + select: { + id: true; + email: true; + }; + }) => { + const existingUser = this.users.find((user) => user.emailHash === args.where.emailHash); + if (existingUser) { + return { + id: existingUser.id, + email: existingUser.email + }; + } + + const createdUser: UserRecord = { + id: `user_${this.userIdSequence++}`, + email: args.create.email, + emailHash: args.create.emailHash, + nickname: null, + avatarUrl: null + }; + this.users.push(createdUser); + + return { + id: createdUser.id, + email: createdUser.email + }; + } + }; + + readonly refreshToken = { + create: async (args: { + data: { + userId: string; + tokenHash: string; + expiresAt: Date; + }; + }) => { + const refreshToken: RefreshTokenRecord = { + id: `refresh_${this.refreshTokenIdSequence++}`, + userId: args.data.userId, + tokenHash: args.data.tokenHash, + expiresAt: args.data.expiresAt, + revokedAt: null, + createdAt: new Date() + }; + this.refreshTokens.push(refreshToken); + return refreshToken; + }, + + findUnique: async (args: { + where: { + tokenHash: string; + }; + include: { + user: { + select: { + id: true; + email: true; + }; + }; + }; + }) => { + const refreshToken = this.refreshTokens.find( + (item) => item.tokenHash === args.where.tokenHash + ); + if (!refreshToken) { + return null; + } + + const user = this.users.find((item) => item.id === refreshToken.userId); + if (!user) { + throw new Error("user not found"); + } + + return { + ...refreshToken, + user: { + id: user.id, + email: user.email + } + }; + }, + + update: async (args: { + where: { + id: string; + }; + data: { + revokedAt: Date; + }; + }) => { + const refreshToken = this.refreshTokens.find((item) => item.id === args.where.id); + if (!refreshToken) { + throw new Error("refresh token not found"); + } + + refreshToken.revokedAt = args.data.revokedAt; + return refreshToken; + }, + + updateMany: async (args: { + where: { + tokenHash: string; + revokedAt: null; + }; + data: { + revokedAt: Date; + }; + }) => { + let count = 0; + for (const refreshToken of this.refreshTokens) { + if (refreshToken.tokenHash !== args.where.tokenHash || refreshToken.revokedAt !== null) { + continue; + } + + refreshToken.revokedAt = args.data.revokedAt; + count += 1; + } + + return { count }; + } + }; + + readonly userSecurity = { + upsert: async (args: { + where: { + userId: string; + }; + update: { + twoFactorSecret: string; + twoFactorEnabled: boolean; + }; + create: { + userId: string; + twoFactorSecret: string; + twoFactorEnabled: boolean; + }; + }) => { + const existingSecurity = this.userSecurities.find( + (item) => item.userId === args.where.userId + ); + if (existingSecurity) { + existingSecurity.twoFactorSecret = args.update.twoFactorSecret; + existingSecurity.twoFactorEnabled = args.update.twoFactorEnabled; + return existingSecurity; + } + + const createdSecurity: UserSecurityRecord = { + userId: args.create.userId, + twoFactorSecret: args.create.twoFactorSecret, + twoFactorEnabled: args.create.twoFactorEnabled + }; + this.userSecurities.push(createdSecurity); + return createdSecurity; + }, + + findUnique: async (args: { + where: { + userId: string; + }; + select: { + twoFactorSecret: true; + }; + }) => { + const security = this.userSecurities.find((item) => item.userId === args.where.userId); + if (!security) { + return null; + } + + return { + twoFactorSecret: security.twoFactorSecret + }; + }, + + update: async (args: { + where: { + userId: string; + }; + data: { + twoFactorEnabled: boolean; + }; + }) => { + const security = this.userSecurities.find((item) => item.userId === args.where.userId); + if (!security) { + throw new Error("user security not found"); + } + + security.twoFactorEnabled = args.data.twoFactorEnabled; + return security; + } + }; + + getUsers(): UserRecord[] { + return [...this.users]; + } +} + +class MockAuthMailService { + readonly sentMessages: Array<{ + email: string; + code: string; + ttlSeconds: number; + }> = []; + + async sendLoginCode(email: string, code: string, ttlSeconds: number): Promise { + this.sentMessages.push({ + email, + code, + ttlSeconds + }); + } +} + +describe("AuthService", () => { + let authService: AuthService; + let prismaService: InMemoryAuthPrismaService; + let authMailService: MockAuthMailService; + + beforeEach(async () => { + prismaService = new InMemoryAuthPrismaService(); + authMailService = new MockAuthMailService(); + + const moduleRef: TestingModule = await Test.createTestingModule({ + providers: [ + AuthService, + DataEncryptionService, + { + provide: PrismaService, + useValue: prismaService + }, + { + provide: AuthMailService, + useValue: authMailService + }, + { + provide: JwtService, + useValue: { + signAsync: async (payload: Record) => + `signed-${String(payload["sub"])}-${String(payload["email"])}` + } + }, + { + provide: ConfigService, + useValue: { + get: (key: string) => { + switch (key) { + case "AUTH_EMAIL_CODE_TTL_SECONDS": + return "300"; + case "AUTH_ACCESS_EXPIRES_IN_SECONDS": + return "900"; + case "AUTH_REFRESH_EXPIRES_IN_SECONDS": + return "2592000"; + case "AUTH_TOTP_ISSUER": + return "TodoList"; + case "DATA_ENCRYPTION_SECRET": + return "test-data-encryption-secret"; + default: + return undefined; + } + } + } + } + ] + }).compile(); + + authService = moduleRef.get(AuthService); + }); + + it("should encrypt user email in database while keeping login flow available", async () => { + await authService.sendEmailCode("User@Example.com"); + expect(authMailService.sentMessages).toHaveLength(1); + expect(authMailService.sentMessages[0]?.email).toBe("user@example.com"); + + const loginResult = await authService.loginWithEmailCode( + "USER@example.com", + authMailService.sentMessages[0]?.code ?? "" + ); + + expect(loginResult.user.email).toBe("user@example.com"); + expect(loginResult.accessToken).toContain("user@example.com"); + + const storedUser = prismaService.getUsers()[0]; + expect(storedUser?.email).not.toBe("user@example.com"); + expect(storedUser?.emailHash).toMatch(/^[a-f0-9]{64}$/); + }); + + it("should decrypt user email when refreshing token", async () => { + await authService.sendEmailCode("refresh@example.com"); + const loginResult = await authService.loginWithEmailCode( + "refresh@example.com", + authMailService.sentMessages[0]?.code ?? "" + ); + + const refreshResult = await authService.refreshTokens(loginResult.refreshToken); + expect(refreshResult.user.email).toBe("refresh@example.com"); + expect(refreshResult.accessToken).toContain("refresh@example.com"); + }); + + it("should reject invalid verification code", async () => { + await authService.sendEmailCode("invalid@example.com"); + + await expect( + authService.loginWithEmailCode("invalid@example.com", "000000") + ).rejects.toBeInstanceOf(UnauthorizedException); + }); +}); diff --git a/apps/api/test/openai-compatible-provider.spec.ts b/apps/api/test/openai-compatible-provider.spec.ts new file mode 100644 index 0000000..7654669 --- /dev/null +++ b/apps/api/test/openai-compatible-provider.spec.ts @@ -0,0 +1,80 @@ +import { AiChannel } from "../generated/prisma/client"; +import { OpenAiCompatibleProvider } from "../src/ai/providers/openai-compatible.provider"; + +describe("OpenAiCompatibleProvider", () => { + const originalFetch = global.fetch; + + afterEach(() => { + global.fetch = originalFetch; + jest.restoreAllMocks(); + }); + + it("should read text from responses style payload when chat content is empty", async () => { + const provider = new OpenAiCompatibleProvider(); + const fetchMock = jest.fn(async (_input: unknown, init?: RequestInit) => { + expect(init?.method).toBe("POST"); + + return new Response( + JSON.stringify({ + id: "resp_123", + object: "response", + model: "gpt-5.4", + output: [ + { + id: "msg_123", + type: "message", + role: "assistant", + content: [ + { + type: "output_text", + text: "今天优先先完成截止时间最近的任务。" + } + ] + } + ], + usage: { + prompt_tokens: 15, + completion_tokens: 9, + total_tokens: 24 + } + }), + { + status: 200, + headers: { + "content-type": "application/json" + } + } + ); + }); + + global.fetch = fetchMock as typeof global.fetch; + + const result = await provider.execute( + { + channel: AiChannel.USER_KEY, + source: "binding", + sourceId: "binding_user_key_1", + providerName: "airouter", + model: "gpt-5.4", + configId: null, + configName: null, + endpoint: "https://api.airouter.io/v1", + apiKey: "sk_test" + }, + { + userId: "user_1", + message: "帮我安排今天的任务", + sessionId: null + } + ); + + expect(fetchMock).toHaveBeenCalledTimes(1); + expect(result.content).toBe("今天优先先完成截止时间最近的任务。"); + expect(result.model).toBe("gpt-5.4"); + expect(result.usage).toEqual({ + promptTokens: 15, + completionTokens: 9, + totalTokens: 24 + }); + }); +}); diff --git a/apps/api/test/sync-push.spec.ts b/apps/api/test/sync-push.spec.ts index dfbacba..3c75f9b 100644 --- a/apps/api/test/sync-push.spec.ts +++ b/apps/api/test/sync-push.spec.ts @@ -1,7 +1,9 @@ import request from "supertest"; import { INestApplication, ValidationPipe } from "@nestjs/common"; +import { ConfigService } from "@nestjs/config"; import { Test, TestingModule } from "@nestjs/testing"; import { PrismaService } from "../src/prisma/prisma.service"; +import { DataEncryptionService } from "../src/security/data-encryption.service"; import { SyncController } from "../src/sync/sync.controller"; import { SyncService } from "../src/sync/sync.service"; @@ -159,6 +161,10 @@ class InMemoryPrismaService { return this.syncOperations.length; } + getRawOperationById(opId: string): SyncOperationRecord | undefined { + return this.syncOperations.find((operation) => operation.opId === opId); + } + seedOperations(records: Array>): void { for (const record of records) { this.syncOperations.push({ @@ -196,7 +202,18 @@ describe("SyncController (integration)", () => { const moduleRef: TestingModule = await Test.createTestingModule({ controllers: [SyncController], - providers: [SyncService, { provide: PrismaService, useValue: prismaService }] + providers: [ + SyncService, + DataEncryptionService, + { provide: PrismaService, useValue: prismaService }, + { + provide: ConfigService, + useValue: { + get: (key: string) => + key === "DATA_ENCRYPTION_SECRET" ? "test-data-encryption-secret" : undefined + } + } + ] }).compile(); app = moduleRef.createNestApplication(); @@ -258,6 +275,9 @@ describe("SyncController (integration)", () => { }) ]); expect(prismaService.getOperationCount()).toBe(2); + expect(prismaService.getRawOperationById("op-create-1")?.payload).not.toBe( + '{"title":"浠诲姟涓€"}' + ); const secondResponse = await request(app.getHttpServer()) .post("/sync/push") diff --git a/apps/api/test/task.spec.ts b/apps/api/test/task.spec.ts index fcc6799..98b25cd 100644 --- a/apps/api/test/task.spec.ts +++ b/apps/api/test/task.spec.ts @@ -1,7 +1,9 @@ import request from "supertest"; import { INestApplication, ValidationPipe } from "@nestjs/common"; +import { ConfigService } from "@nestjs/config"; import { Test, TestingModule } from "@nestjs/testing"; import { PrismaService } from "../src/prisma/prisma.service"; +import { DataEncryptionService } from "../src/security/data-encryption.service"; import { TaskController } from "../src/task/task.controller"; import { TaskService } from "../src/task/task.service"; import { TaskPriority, TaskStatus } from "../generated/prisma/client"; @@ -355,6 +357,10 @@ class InMemoryPrismaService { return runner(this); } + getRawTaskById(taskId: string): TaskRecord | undefined { + return this.tasks.find((task) => task.id === taskId); + } + private toTaskWithTags( task: TaskRecord ): TaskRecord & { taskTags: Array<{ tag: { name: string } }> } { @@ -390,7 +396,15 @@ describe("TaskController (integration)", () => { controllers: [TaskController], providers: [ TaskService, - { provide: PrismaService, useValue: prismaService as unknown as PrismaService } + DataEncryptionService, + { provide: PrismaService, useValue: prismaService as unknown as PrismaService }, + { + provide: ConfigService, + useValue: { + get: (key: string) => + key === "DATA_ENCRYPTION_SECRET" ? "test-data-encryption-secret" : undefined + } + } ] }).compile(); @@ -425,6 +439,9 @@ describe("TaskController (integration)", () => { expect(createResponse.body.id).toBeDefined(); expect(createResponse.body.tags).toEqual(["工作", "会议"]); const taskId = createResponse.body.id as string; + const rawCreatedTask = prismaService.getRawTaskById(taskId); + expect(rawCreatedTask?.title).not.toBe("准备周会"); + expect(rawCreatedTask?.contentText).not.toBe("整理本周进度"); const listResponse = await request(app.getHttpServer()) .get("/tasks") diff --git a/apps/api/tsconfig.json b/apps/api/tsconfig.json index 17d29bc..f196337 100644 --- a/apps/api/tsconfig.json +++ b/apps/api/tsconfig.json @@ -5,6 +5,6 @@ "rootDir": ".", "outDir": "dist" }, - "include": ["src/**/*.ts", "generated/prisma/**/*.ts"], + "include": ["src/**/*.ts", "scripts/**/*.ts", "generated/prisma/**/*.ts"], "exclude": ["dist", "node_modules"] } diff --git a/apps/web/src/App.tsx b/apps/web/src/App.tsx index c77f482..dd51c3b 100644 --- a/apps/web/src/App.tsx +++ b/apps/web/src/App.tsx @@ -17,8 +17,11 @@ import { import { Navigate, Route, Routes, useLocation, useNavigate } from "react-router-dom"; import { Button } from "@/components/ui/button"; import { cn } from "@/lib/utils"; +import { AiChatPage } from "@/pages/ai-chat-page"; import { EmailLoginPage } from "@/pages/email-login-page"; import { OAuthCallbackPage } from "@/pages/oauth-callback-page"; +import { PlaceholderPage } from "@/pages/placeholder-page"; +import { SettingsPage } from "@/pages/settings-page"; import { TodoShellPage } from "@/pages/todo-shell-page"; import { revokeRefreshToken, type EmailLoginResult } from "@/services/auth-api"; import { @@ -38,16 +41,19 @@ type SidebarItem = { key: string; label: string; icon: LucideIcon; + path: string; }; const SIDEBAR_ITEMS: SidebarItem[] = [ - { key: "dashboard", label: "概览面板", icon: LayoutDashboard }, - { key: "todo", label: "待办事项", icon: ListTodo }, - { key: "ai", label: "AI 建议", icon: Sparkles }, - { key: "notice", label: "提醒中心", icon: Bell }, - { key: "settings", label: "系统设置", icon: Settings } + { key: "dashboard", label: "概览面板", icon: LayoutDashboard, path: "/dashboard" }, + { key: "todo", label: "待办事项", icon: ListTodo, path: "/todo" }, + { key: "ai", label: "AI 助手", icon: Sparkles, path: "/ai" }, + { key: "notice", label: "提醒中心", icon: Bell, path: "/notice" }, + { key: "settings", label: "系统设置", icon: Settings, path: "/settings" } ]; +const READY_SIDEBAR_KEYS = new Set(["todo", "ai", "settings"]); + function toWebSession(payload: EmailLoginResult): WebSession { return { accessToken: payload.accessToken, @@ -104,7 +110,7 @@ function App() { saveSession(nextSession); setSession(nextSession); setMobileSidebarOpen(false); - navigate("/", { replace: true }); + navigate("/todo", { replace: true }); } function handleBootstrapSession(nextSession: WebSession): void { @@ -136,14 +142,21 @@ function App() {