From 180f7a9baa128f5985dbf99b6cdd1f4c734a4d43 Mon Sep 17 00:00:00 2001 From: Yaosanqi137 Date: Mon, 6 Apr 2026 11:44:05 +0800 Subject: [PATCH 01/21] feat(api-ai): add provider registry and routing fallback --- .../src/ai/ai-provider-registry.service.ts | 28 ++ apps/api/src/ai/ai.controller.ts | 41 ++ apps/api/src/ai/ai.module.ts | 14 + apps/api/src/ai/ai.service.ts | 430 ++++++++++++++++++ apps/api/src/ai/ai.types.ts | 52 +++ apps/api/src/ai/dto/ai-chat.dto.ts | 17 + .../ai/dto/upsert-ai-provider-binding.dto.ts | 45 ++ apps/api/src/ai/providers/astrbot.provider.ts | 197 ++++++++ .../providers/openai-compatible.provider.ts | 203 +++++++++ apps/api/src/app.module.ts | 4 +- apps/api/test/ai.spec.ts | 395 ++++++++++++++++ 11 files changed, 1425 insertions(+), 1 deletion(-) create mode 100644 apps/api/src/ai/ai-provider-registry.service.ts create mode 100644 apps/api/src/ai/ai.controller.ts create mode 100644 apps/api/src/ai/ai.module.ts create mode 100644 apps/api/src/ai/ai.service.ts create mode 100644 apps/api/src/ai/ai.types.ts create mode 100644 apps/api/src/ai/dto/ai-chat.dto.ts create mode 100644 apps/api/src/ai/dto/upsert-ai-provider-binding.dto.ts create mode 100644 apps/api/src/ai/providers/astrbot.provider.ts create mode 100644 apps/api/src/ai/providers/openai-compatible.provider.ts create mode 100644 apps/api/test/ai.spec.ts diff --git a/apps/api/src/ai/ai-provider-registry.service.ts b/apps/api/src/ai/ai-provider-registry.service.ts new file mode 100644 index 0000000..54387a9 --- /dev/null +++ b/apps/api/src/ai/ai-provider-registry.service.ts @@ -0,0 +1,28 @@ +import { Injectable } from "@nestjs/common"; +import { AiChannel } from "../../generated/prisma/client"; +import { AstrbotProvider } from "./providers/astrbot.provider"; +import { OpenAiCompatibleProvider } from "./providers/openai-compatible.provider"; +import { AiChannelExecutor } from "./ai.types"; + +@Injectable() +export class AiProviderRegistryService { + private readonly executors = new Map(); + + constructor( + openAiCompatibleProvider: OpenAiCompatibleProvider, + astrbotProvider: AstrbotProvider + ) { + this.executors.set(AiChannel.USER_KEY, openAiCompatibleProvider); + this.executors.set(AiChannel.PUBLIC_POOL, openAiCompatibleProvider); + this.executors.set(AiChannel.ASTRBOT, astrbotProvider); + } + + getExecutor(channel: AiChannel): AiChannelExecutor { + const executor = this.executors.get(channel); + if (!executor) { + throw new Error(`未找到 ${channel} 对应的 AI 通道执行器`); + } + + return executor; + } +} diff --git a/apps/api/src/ai/ai.controller.ts b/apps/api/src/ai/ai.controller.ts new file mode 100644 index 0000000..3ca3ff9 --- /dev/null +++ b/apps/api/src/ai/ai.controller.ts @@ -0,0 +1,41 @@ +import { Body, Controller, Get, Headers, Post, UnauthorizedException } from "@nestjs/common"; +import { AiChatDto } from "./dto/ai-chat.dto"; +import { UpsertAiProviderBindingDto } from "./dto/upsert-ai-provider-binding.dto"; +import { AiChatResponse, AiService, ListAiBindingsResponse } from "./ai.service"; + +@Controller("ai") +export class AiController { + constructor(private readonly aiService: AiService) {} + + @Get("bindings") + async listBindings( + @Headers("x-user-id") userIdHeader: string | string[] | undefined + ): Promise { + return this.aiService.listBindings(this.resolveUserId(userIdHeader)); + } + + @Post("bindings") + async upsertBinding( + @Headers("x-user-id") userIdHeader: string | string[] | undefined, + @Body() body: UpsertAiProviderBindingDto + ) { + return this.aiService.upsertBinding(this.resolveUserId(userIdHeader), body); + } + + @Post("chat") + async chat( + @Headers("x-user-id") userIdHeader: string | string[] | undefined, + @Body() body: AiChatDto + ): Promise { + return this.aiService.chat(this.resolveUserId(userIdHeader), body); + } + + private resolveUserId(userIdHeader: string | string[] | undefined): string { + const userId = Array.isArray(userIdHeader) ? userIdHeader[0] : userIdHeader; + if (!userId) { + throw new UnauthorizedException("缺少用户上下文"); + } + + return userId; + } +} diff --git a/apps/api/src/ai/ai.module.ts b/apps/api/src/ai/ai.module.ts new file mode 100644 index 0000000..2655525 --- /dev/null +++ b/apps/api/src/ai/ai.module.ts @@ -0,0 +1,14 @@ +import { Module } from "@nestjs/common"; +import { PrismaModule } from "../prisma/prisma.module"; +import { AiController } from "./ai.controller"; +import { AiProviderRegistryService } from "./ai-provider-registry.service"; +import { AiService } from "./ai.service"; +import { AstrbotProvider } from "./providers/astrbot.provider"; +import { OpenAiCompatibleProvider } from "./providers/openai-compatible.provider"; + +@Module({ + imports: [PrismaModule], + controllers: [AiController], + providers: [AiService, AiProviderRegistryService, OpenAiCompatibleProvider, AstrbotProvider] +}) +export class AiModule {} diff --git a/apps/api/src/ai/ai.service.ts b/apps/api/src/ai/ai.service.ts new file mode 100644 index 0000000..c4b4825 --- /dev/null +++ b/apps/api/src/ai/ai.service.ts @@ -0,0 +1,430 @@ +import { + BadGatewayException, + BadRequestException, + Injectable, + Logger, + NotFoundException +} from "@nestjs/common"; +import { + AiChannel, + AiProviderBinding, + AiPublicPoolConfig, + Prisma +} from "../../generated/prisma/client"; +import { PrismaService } from "../prisma/prisma.service"; +import { AiProviderRegistryService } from "./ai-provider-registry.service"; +import { AiChatDto } from "./dto/ai-chat.dto"; +import { UpsertAiProviderBindingDto } from "./dto/upsert-ai-provider-binding.dto"; +import { AiResolvedRouteCandidate, AiRouteAttempt, AiRouteFailureError } from "./ai.types"; + +type AiBindingSummary = { + id: string; + channel: AiChannel; + providerName: string; + model: string | null; + endpoint: string | null; + isDefault: boolean; + isEnabled: boolean; + hasApiKey: boolean; + maskedApiKey: string | null; + updatedAt: string; +}; + +type AiRoutePlanEntry = + | { + kind: "candidate"; + candidate: AiResolvedRouteCandidate; + } + | { + kind: "skip"; + attempt: AiRouteAttempt; + }; + +export type ListAiBindingsResponse = { + routeOrder: AiChannel[]; + bindings: AiBindingSummary[]; + publicPool: { + enabled: boolean; + providerName: string | null; + model: string | null; + endpoint: string | null; + hasApiKey: boolean; + } | null; +}; + +export type AiChatResponse = { + channel: AiChannel; + providerName: string; + model: string | null; + content: string; + sessionId: string | null; + attempts: AiRouteAttempt[]; +}; + +@Injectable() +export class AiService { + private readonly logger = new Logger(AiService.name); + + constructor( + private readonly prismaService: PrismaService, + private readonly aiProviderRegistryService: AiProviderRegistryService + ) {} + + async listBindings(userId: string): Promise { + const [bindings, publicPool] = await Promise.all([ + this.prismaService.aiProviderBinding.findMany({ + where: { + userId + }, + orderBy: [{ channel: "asc" }, { isDefault: "desc" }, { updatedAt: "desc" }] + }), + this.prismaService.aiPublicPoolConfig.findFirst({ + orderBy: { + updatedAt: "desc" + } + }) + ]); + + return { + routeOrder: [AiChannel.USER_KEY, AiChannel.ASTRBOT, AiChannel.PUBLIC_POOL], + bindings: bindings.map((binding) => this.serializeBinding(binding)), + publicPool: publicPool + ? { + enabled: publicPool.enabled, + providerName: publicPool.providerName, + model: publicPool.model, + endpoint: publicPool.endpoint, + hasApiKey: Boolean(publicPool.encryptedApiKey) + } + : null + }; + } + + async upsertBinding(userId: string, dto: UpsertAiProviderBindingDto): Promise { + if (dto.channel === AiChannel.PUBLIC_POOL) { + throw new BadRequestException("公共 AI 通道只能由管理员配置"); + } + + const result = await this.prismaService.$transaction(async (tx) => { + if (dto.isDefault) { + const where: Prisma.AiProviderBindingWhereInput = { + userId, + channel: dto.channel + }; + + if (dto.id) { + where.id = { + not: dto.id + }; + } + + await tx.aiProviderBinding.updateMany({ + where, + data: { + isDefault: false + } + }); + } + + if (!dto.id) { + return tx.aiProviderBinding.create({ + data: { + userId, + channel: dto.channel, + providerName: dto.providerName.trim(), + model: this.normalizeOptionalString(dto.model), + endpoint: this.normalizeOptionalString(dto.endpoint), + encryptedApiKey: this.normalizeOptionalString(dto.apiKey), + isDefault: dto.isDefault ?? false, + isEnabled: dto.isEnabled ?? true + } + }); + } + + const existingBinding = await tx.aiProviderBinding.findFirst({ + where: { + id: dto.id, + userId + } + }); + + if (!existingBinding) { + throw new NotFoundException("AI 通道配置不存在"); + } + + const updateData: Prisma.AiProviderBindingUpdateInput = { + channel: dto.channel, + providerName: dto.providerName.trim(), + model: this.normalizeOptionalString(dto.model), + isDefault: dto.isDefault ?? existingBinding.isDefault, + isEnabled: dto.isEnabled ?? existingBinding.isEnabled + }; + + if (dto.endpoint !== undefined) { + updateData.endpoint = this.normalizeOptionalString(dto.endpoint); + } + + if (dto.apiKey !== undefined) { + updateData.encryptedApiKey = this.normalizeOptionalString(dto.apiKey); + } + + return tx.aiProviderBinding.update({ + where: { + id: dto.id + }, + data: updateData + }); + }); + + return this.serializeBinding(result); + } + + async chat(userId: string, dto: AiChatDto): Promise { + const attempts: AiRouteAttempt[] = []; + const plan = await this.buildRoutePlan(userId, dto.bindingId ?? null); + + for (const entry of plan) { + if (entry.kind === "skip") { + attempts.push(entry.attempt); + continue; + } + + const executor = this.aiProviderRegistryService.getExecutor(entry.candidate.channel); + + try { + const result = await executor.execute(entry.candidate, { + userId, + message: dto.message, + sessionId: dto.sessionId ?? null + }); + + attempts.push({ + channel: result.channel, + providerName: result.providerName, + model: result.model, + status: "success", + reasonCode: null, + reasonMessage: null + }); + + return { + channel: result.channel, + providerName: result.providerName, + model: result.model, + content: result.content, + sessionId: result.sessionId, + attempts + }; + } catch (error) { + const failureAttempt = this.toFailureAttempt(entry.candidate, error); + attempts.push(failureAttempt); + this.logger.warn( + `AI 通道降级:channel=${failureAttempt.channel} provider=${failureAttempt.providerName ?? "unknown"} code=${failureAttempt.reasonCode ?? "UNKNOWN"} message=${failureAttempt.reasonMessage ?? "unknown"}` + ); + } + } + + throw new BadGatewayException({ + message: "当前没有可用的 AI 通道,请稍后重试", + attempts + }); + } + + private async buildRoutePlan( + userId: string, + bindingId: string | null + ): Promise { + const plan: AiRoutePlanEntry[] = []; + const consumedChannels = new Set(); + + if (bindingId) { + const pinnedBinding = await this.prismaService.aiProviderBinding.findFirst({ + where: { + id: bindingId, + userId, + isEnabled: true + } + }); + + if (!pinnedBinding) { + throw new NotFoundException("指定的 AI 通道配置不存在或已禁用"); + } + + plan.push({ + kind: "candidate", + candidate: this.toBindingCandidate(pinnedBinding) + }); + consumedChannels.add(pinnedBinding.channel); + } + + for (const channel of [AiChannel.USER_KEY, AiChannel.ASTRBOT]) { + if (consumedChannels.has(channel)) { + continue; + } + + const binding = await this.findPreferredBinding(userId, channel); + if (binding) { + plan.push({ + kind: "candidate", + candidate: this.toBindingCandidate(binding) + }); + continue; + } + + plan.push({ + kind: "skip", + attempt: { + channel, + providerName: null, + model: null, + status: "skipped", + reasonCode: "CHANNEL_NOT_CONFIGURED", + reasonMessage: + channel === AiChannel.USER_KEY + ? "当前用户未配置可用的自备 Key 通道" + : "当前用户未配置可用的 AstrBot 通道" + } + }); + } + + const publicPool = await this.findEnabledPublicPool(); + if (publicPool) { + plan.push({ + kind: "candidate", + candidate: this.toPublicPoolCandidate(publicPool) + }); + } else { + plan.push({ + kind: "skip", + attempt: { + channel: AiChannel.PUBLIC_POOL, + providerName: null, + model: null, + status: "skipped", + reasonCode: "PUBLIC_POOL_DISABLED", + reasonMessage: "公共 AI 通道未开启" + } + }); + } + + return plan; + } + + private async findPreferredBinding( + userId: string, + channel: AiChannel + ): Promise { + return this.prismaService.aiProviderBinding.findFirst({ + where: { + userId, + channel, + isEnabled: true + }, + orderBy: [{ isDefault: "desc" }, { updatedAt: "desc" }] + }); + } + + private async findEnabledPublicPool(): Promise { + return this.prismaService.aiPublicPoolConfig.findFirst({ + where: { + enabled: true + }, + orderBy: { + updatedAt: "desc" + } + }); + } + + private toBindingCandidate(binding: AiProviderBinding): AiResolvedRouteCandidate { + return { + channel: binding.channel, + source: "binding", + sourceId: binding.id, + providerName: binding.providerName, + model: binding.model, + endpoint: binding.endpoint, + apiKey: binding.encryptedApiKey + }; + } + + private toPublicPoolCandidate(publicPool: AiPublicPoolConfig): AiResolvedRouteCandidate { + return { + channel: AiChannel.PUBLIC_POOL, + source: "public_pool", + sourceId: publicPool.id, + providerName: publicPool.providerName ?? "public-pool", + model: publicPool.model, + endpoint: publicPool.endpoint, + apiKey: publicPool.encryptedApiKey + }; + } + + private serializeBinding(binding: AiProviderBinding): AiBindingSummary { + return { + id: binding.id, + channel: binding.channel, + providerName: binding.providerName, + model: binding.model, + endpoint: binding.endpoint, + isDefault: binding.isDefault, + isEnabled: binding.isEnabled, + hasApiKey: Boolean(binding.encryptedApiKey), + maskedApiKey: this.maskSecret(binding.encryptedApiKey), + updatedAt: binding.updatedAt.toISOString() + }; + } + + private toFailureAttempt(candidate: AiResolvedRouteCandidate, error: unknown): AiRouteAttempt { + if (error instanceof AiRouteFailureError) { + return { + channel: error.channel, + providerName: error.providerName, + model: candidate.model, + status: "failed", + reasonCode: error.code, + reasonMessage: error.message + }; + } + + if (error instanceof Error) { + return { + channel: candidate.channel, + providerName: candidate.providerName, + model: candidate.model, + status: "failed", + reasonCode: "UNKNOWN_ERROR", + reasonMessage: error.message + }; + } + + return { + channel: candidate.channel, + providerName: candidate.providerName, + model: candidate.model, + status: "failed", + reasonCode: "UNKNOWN_ERROR", + reasonMessage: "未知错误" + }; + } + + private normalizeOptionalString(value: string | undefined): string | null { + if (value === undefined) { + return null; + } + + const normalizedValue = value.trim(); + return normalizedValue.length > 0 ? normalizedValue : null; + } + + private maskSecret(secret: string | null): string | null { + if (!secret) { + return null; + } + + if (secret.length <= 6) { + return "*".repeat(secret.length); + } + + return `${secret.slice(0, 4)}***${secret.slice(-2)}`; + } +} diff --git a/apps/api/src/ai/ai.types.ts b/apps/api/src/ai/ai.types.ts new file mode 100644 index 0000000..e576c61 --- /dev/null +++ b/apps/api/src/ai/ai.types.ts @@ -0,0 +1,52 @@ +import { AiChannel } from "../../generated/prisma/client"; + +export type AiResolvedRouteCandidate = { + channel: AiChannel; + source: "binding" | "public_pool"; + sourceId: string | null; + providerName: string; + model: string | null; + endpoint: string | null; + apiKey: string | null; +}; + +export type AiChatInput = { + userId: string; + message: string; + sessionId: string | null; +}; + +export type AiChatResult = { + channel: AiChannel; + providerName: string; + model: string | null; + content: string; + sessionId: string | null; + raw: unknown; +}; + +export type AiRouteAttempt = { + channel: AiChannel; + providerName: string | null; + model: string | null; + status: "skipped" | "failed" | "success"; + reasonCode: string | null; + reasonMessage: string | null; +}; + +export class AiRouteFailureError extends Error { + constructor( + public readonly channel: AiChannel, + public readonly providerName: string, + public readonly code: string, + message: string + ) { + super(message); + this.name = "AiRouteFailureError"; + Object.setPrototypeOf(this, new.target.prototype); + } +} + +export interface AiChannelExecutor { + execute(candidate: AiResolvedRouteCandidate, input: AiChatInput): Promise; +} diff --git a/apps/api/src/ai/dto/ai-chat.dto.ts b/apps/api/src/ai/dto/ai-chat.dto.ts new file mode 100644 index 0000000..a89692a --- /dev/null +++ b/apps/api/src/ai/dto/ai-chat.dto.ts @@ -0,0 +1,17 @@ +import { IsOptional, IsString, MinLength } from "class-validator"; + +export class AiChatDto { + @IsString() + @MinLength(1) + message!: string; + + @IsOptional() + @IsString() + @MinLength(1) + sessionId?: string; + + @IsOptional() + @IsString() + @MinLength(1) + bindingId?: string; +} diff --git a/apps/api/src/ai/dto/upsert-ai-provider-binding.dto.ts b/apps/api/src/ai/dto/upsert-ai-provider-binding.dto.ts new file mode 100644 index 0000000..4bffff0 --- /dev/null +++ b/apps/api/src/ai/dto/upsert-ai-provider-binding.dto.ts @@ -0,0 +1,45 @@ +import { AiChannel } from "../../../generated/prisma/client"; +import { IsBoolean, IsEnum, IsOptional, IsString, IsUrl, MinLength } from "class-validator"; + +export class UpsertAiProviderBindingDto { + @IsOptional() + @IsString() + @MinLength(1) + id?: string; + + @IsEnum(AiChannel) + channel!: AiChannel; + + @IsString() + @MinLength(1) + providerName!: string; + + @IsOptional() + @IsString() + @MinLength(1) + model?: string; + + @IsOptional() + @IsUrl( + { + require_tld: false + }, + { + message: "endpoint 必须是合法的 URL" + } + ) + endpoint?: string; + + @IsOptional() + @IsString() + @MinLength(1) + apiKey?: string; + + @IsOptional() + @IsBoolean() + isDefault?: boolean; + + @IsOptional() + @IsBoolean() + isEnabled?: boolean; +} diff --git a/apps/api/src/ai/providers/astrbot.provider.ts b/apps/api/src/ai/providers/astrbot.provider.ts new file mode 100644 index 0000000..419139d --- /dev/null +++ b/apps/api/src/ai/providers/astrbot.provider.ts @@ -0,0 +1,197 @@ +import { Injectable } from "@nestjs/common"; +import { + AiChannelExecutor, + AiChatInput, + AiChatResult, + AiResolvedRouteCandidate, + AiRouteFailureError +} from "../ai.types"; + +@Injectable() +export class AstrbotProvider implements AiChannelExecutor { + async execute(candidate: AiResolvedRouteCandidate, input: AiChatInput): Promise { + if (!candidate.endpoint) { + throw new AiRouteFailureError( + candidate.channel, + candidate.providerName, + "MISSING_ENDPOINT", + "缺少 AstrBot 服务地址配置" + ); + } + + if (!candidate.apiKey) { + throw new AiRouteFailureError( + candidate.channel, + candidate.providerName, + "MISSING_API_KEY", + "缺少 AstrBot API Key 配置" + ); + } + + const requestUrl = this.buildRequestUrl(candidate.endpoint); + + let response: Response; + try { + response = await fetch(requestUrl, { + method: "POST", + headers: { + "Content-Type": "application/json", + Authorization: `Bearer ${candidate.apiKey}` + }, + body: JSON.stringify({ + username: input.userId, + session_id: input.sessionId ?? undefined, + message: input.message, + enable_streaming: false, + selected_provider: candidate.providerName || undefined, + selected_model: candidate.model ?? undefined + }), + signal: AbortSignal.timeout(30000) + }); + } catch (error) { + throw new AiRouteFailureError( + candidate.channel, + candidate.providerName, + "UPSTREAM_UNREACHABLE", + this.toErrorMessage(error, "AstrBot 服务请求失败") + ); + } + + const rawText = await response.text(); + if (!response.ok) { + throw new AiRouteFailureError( + candidate.channel, + candidate.providerName, + `UPSTREAM_HTTP_${response.status}`, + this.extractHttpErrorMessage(rawText, response.status) + ); + } + + const events = this.parseSseEvents(rawText); + let content = ""; + let sessionId = input.sessionId; + + for (const event of events) { + const type = this.readString(event["type"]); + if (type === "session_id") { + sessionId = this.readString(event["session_id"]) ?? sessionId; + continue; + } + + if (type === "error") { + throw new AiRouteFailureError( + candidate.channel, + candidate.providerName, + this.readString(event["code"]) ?? "ASTRBOT_ERROR", + this.readString(event["data"]) ?? "AstrBot 返回错误" + ); + } + + if (type !== "plain") { + continue; + } + + const chainType = this.readString(event["chain_type"]); + if ( + chainType === "reasoning" || + chainType === "tool_call" || + chainType === "tool_call_result" + ) { + continue; + } + + const data = this.readString(event["data"]); + if (!data) { + continue; + } + + if (event["streaming"] === true) { + content += data; + continue; + } + + content = data; + } + + if (!content.trim()) { + throw new AiRouteFailureError( + candidate.channel, + candidate.providerName, + "EMPTY_RESPONSE", + "AstrBot 没有返回有效内容" + ); + } + + return { + channel: candidate.channel, + providerName: candidate.providerName, + model: candidate.model, + content, + sessionId, + raw: events + }; + } + + private buildRequestUrl(endpoint: string): string { + const normalizedEndpoint = endpoint.replace(/\/+$/, ""); + if (normalizedEndpoint.endsWith("/api/v1/chat")) { + return normalizedEndpoint; + } + if (normalizedEndpoint.endsWith("/api/v1")) { + return `${normalizedEndpoint}/chat`; + } + if (normalizedEndpoint.endsWith("/api")) { + return `${normalizedEndpoint}/v1/chat`; + } + return `${normalizedEndpoint}/api/v1/chat`; + } + + private parseSseEvents(rawText: string): Array> { + return rawText + .split(/\r?\n\r?\n/) + .map((block) => + block + .split(/\r?\n/) + .filter((line) => line.startsWith("data:")) + .map((line) => line.slice(5).trim()) + .join("\n") + ) + .filter((payload) => payload.length > 0) + .map((payload) => { + try { + return JSON.parse(payload) as Record; + } catch { + return null; + } + }) + .filter((item): item is Record => item !== null); + } + + private extractHttpErrorMessage(rawText: string, statusCode: number): string { + try { + const payload = JSON.parse(rawText) as Record; + if (typeof payload["message"] === "string") { + return payload["message"]; + } + if (typeof payload["data"] === "string") { + return payload["data"]; + } + } catch { + return `AstrBot 服务调用失败,状态码 ${statusCode}`; + } + + return `AstrBot 服务调用失败,状态码 ${statusCode}`; + } + + private readString(value: unknown): string | null { + return typeof value === "string" ? value : null; + } + + private toErrorMessage(error: unknown, fallback: string): string { + if (error instanceof Error && error.message) { + return error.message; + } + + return fallback; + } +} diff --git a/apps/api/src/ai/providers/openai-compatible.provider.ts b/apps/api/src/ai/providers/openai-compatible.provider.ts new file mode 100644 index 0000000..2ca1723 --- /dev/null +++ b/apps/api/src/ai/providers/openai-compatible.provider.ts @@ -0,0 +1,203 @@ +import { Injectable } from "@nestjs/common"; +import { + AiChannelExecutor, + AiChatInput, + AiChatResult, + AiResolvedRouteCandidate, + AiRouteFailureError +} from "../ai.types"; + +@Injectable() +export class OpenAiCompatibleProvider implements AiChannelExecutor { + async execute(candidate: AiResolvedRouteCandidate, input: AiChatInput): Promise { + if (!candidate.endpoint) { + throw new AiRouteFailureError( + candidate.channel, + candidate.providerName, + "MISSING_ENDPOINT", + "缺少 AI 服务地址配置" + ); + } + + if (!candidate.apiKey) { + throw new AiRouteFailureError( + candidate.channel, + candidate.providerName, + "MISSING_API_KEY", + "缺少 AI 服务密钥配置" + ); + } + + if (!candidate.model) { + throw new AiRouteFailureError( + candidate.channel, + candidate.providerName, + "MISSING_MODEL", + "缺少 AI 模型配置" + ); + } + + const requestUrl = this.buildRequestUrl(candidate.endpoint); + + let response: Response; + try { + response = await fetch(requestUrl, { + method: "POST", + headers: { + "Content-Type": "application/json", + Authorization: `Bearer ${candidate.apiKey}` + }, + body: JSON.stringify({ + model: candidate.model, + messages: [ + { + role: "user", + content: input.message + } + ], + stream: false + }), + signal: AbortSignal.timeout(30000) + }); + } catch (error) { + throw new AiRouteFailureError( + candidate.channel, + candidate.providerName, + "UPSTREAM_UNREACHABLE", + this.toErrorMessage(error, "AI 服务请求失败") + ); + } + + let payload: unknown; + try { + payload = await response.json(); + } catch (error) { + throw new AiRouteFailureError( + candidate.channel, + candidate.providerName, + "INVALID_RESPONSE", + this.toErrorMessage(error, "AI 服务返回了无法解析的数据") + ); + } + + if (!response.ok) { + throw new AiRouteFailureError( + candidate.channel, + candidate.providerName, + `UPSTREAM_HTTP_${response.status}`, + this.extractErrorMessage(payload, `AI 服务调用失败,状态码 ${response.status}`) + ); + } + + const content = this.extractAssistantText(payload); + if (!content.trim()) { + throw new AiRouteFailureError( + candidate.channel, + candidate.providerName, + "EMPTY_RESPONSE", + "AI 服务没有返回有效内容" + ); + } + + return { + channel: candidate.channel, + providerName: candidate.providerName, + model: this.extractModel(payload) ?? candidate.model, + content, + sessionId: input.sessionId, + raw: payload + }; + } + + private buildRequestUrl(endpoint: string): string { + const normalizedEndpoint = endpoint.replace(/\/+$/, ""); + if (normalizedEndpoint.endsWith("/chat/completions")) { + return normalizedEndpoint; + } + if (normalizedEndpoint.endsWith("/v1")) { + return `${normalizedEndpoint}/chat/completions`; + } + return `${normalizedEndpoint}/v1/chat/completions`; + } + + private extractAssistantText(payload: unknown): string { + if (!this.isRecord(payload)) { + return ""; + } + + const choices = payload["choices"]; + if (!Array.isArray(choices) || choices.length === 0) { + return ""; + } + + const firstChoice = choices[0]; + if (!this.isRecord(firstChoice)) { + return ""; + } + + const message = firstChoice["message"]; + if (!this.isRecord(message)) { + return ""; + } + + return this.extractMessageContent(message["content"]); + } + + private extractMessageContent(content: unknown): string { + if (typeof content === "string") { + return content; + } + + if (!Array.isArray(content)) { + return ""; + } + + return content + .map((item) => { + if (!this.isRecord(item)) { + return ""; + } + + if (typeof item["text"] === "string") { + return item["text"]; + } + + return ""; + }) + .filter((item) => item.length > 0) + .join("\n"); + } + + private extractModel(payload: unknown): string | null { + if (!this.isRecord(payload) || typeof payload["model"] !== "string") { + return null; + } + + return payload["model"]; + } + + private extractErrorMessage(payload: unknown, fallback: string): string { + if (!this.isRecord(payload)) { + return fallback; + } + + const error = payload["error"]; + if (!this.isRecord(error) || typeof error["message"] !== "string") { + return fallback; + } + + return error["message"]; + } + + private isRecord(value: unknown): value is Record { + return typeof value === "object" && value !== null; + } + + private toErrorMessage(error: unknown, fallback: string): string { + if (error instanceof Error && error.message) { + return error.message; + } + + return fallback; + } +} diff --git a/apps/api/src/app.module.ts b/apps/api/src/app.module.ts index 8c2a78f..db0e7c4 100644 --- a/apps/api/src/app.module.ts +++ b/apps/api/src/app.module.ts @@ -1,5 +1,6 @@ import { Module } from "@nestjs/common"; import { ConfigModule } from "@nestjs/config"; +import { AiModule } from "./ai/ai.module"; import { AttachmentModule } from "./attachment/attachment.module"; import { AuthModule } from "./auth/auth.module"; import { PrismaModule } from "./prisma/prisma.module"; @@ -16,7 +17,8 @@ import { TaskModule } from "./task/task.module"; AuthModule, TaskModule, AttachmentModule, - SyncModule + SyncModule, + AiModule ] }) export class AppModule {} diff --git a/apps/api/test/ai.spec.ts b/apps/api/test/ai.spec.ts new file mode 100644 index 0000000..bbef854 --- /dev/null +++ b/apps/api/test/ai.spec.ts @@ -0,0 +1,395 @@ +import request from "supertest"; +import { INestApplication, ValidationPipe } from "@nestjs/common"; +import { Test, TestingModule } from "@nestjs/testing"; +import { AiChannel, AiProviderBinding, AiPublicPoolConfig } from "../generated/prisma/client"; +import { AiController } from "../src/ai/ai.controller"; +import { AiProviderRegistryService } from "../src/ai/ai-provider-registry.service"; +import { AiService } from "../src/ai/ai.service"; +import { AiChannelExecutor, AiRouteFailureError } from "../src/ai/ai.types"; +import { PrismaService } from "../src/prisma/prisma.service"; + +class InMemoryAiPrismaService { + private bindingIdSequence = 1; + private publicPoolIdSequence = 1; + private bindings: AiProviderBinding[] = []; + private publicPools: AiPublicPoolConfig[] = []; + + readonly aiProviderBinding = { + findMany: async (args: { + where: { + userId: string; + }; + }) => { + return this.bindings + .filter((binding) => binding.userId === args.where.userId) + .sort((left, right) => right.updatedAt.getTime() - left.updatedAt.getTime()); + }, + + findFirst: async (args: { + where: { + id?: string; + userId?: string; + channel?: AiChannel; + isEnabled?: boolean; + }; + }) => { + return ( + this.bindings + .filter((binding) => { + if (args.where.id !== undefined && binding.id !== args.where.id) { + return false; + } + if (args.where.userId !== undefined && binding.userId !== args.where.userId) { + return false; + } + if (args.where.channel !== undefined && binding.channel !== args.where.channel) { + return false; + } + if (args.where.isEnabled !== undefined && binding.isEnabled !== args.where.isEnabled) { + return false; + } + return true; + }) + .sort((left, right) => { + if (left.isDefault !== right.isDefault) { + return Number(right.isDefault) - Number(left.isDefault); + } + return right.updatedAt.getTime() - left.updatedAt.getTime(); + })[0] ?? null + ); + }, + + create: async (args: { + data: { + userId: string; + channel: AiChannel; + providerName: string; + model: string | null; + endpoint: string | null; + encryptedApiKey: string | null; + isDefault: boolean; + isEnabled: boolean; + }; + }) => { + const now = new Date(); + const binding: AiProviderBinding = { + id: `binding_${this.bindingIdSequence++}`, + userId: args.data.userId, + channel: args.data.channel, + providerName: args.data.providerName, + model: args.data.model, + encryptedApiKey: args.data.encryptedApiKey, + endpoint: args.data.endpoint, + isDefault: args.data.isDefault, + isEnabled: args.data.isEnabled, + createdAt: now, + updatedAt: now + }; + + this.bindings.push(binding); + return binding; + }, + + update: async (args: { + where: { + id: string; + }; + data: Partial; + }) => { + const binding = this.bindings.find((item) => item.id === args.where.id); + if (!binding) { + throw new Error("binding not found"); + } + + Object.assign(binding, args.data, { updatedAt: new Date() }); + return binding; + }, + + updateMany: async (args: { + where: { + userId?: string; + channel?: AiChannel; + id?: { + not: string; + }; + }; + data: { + isDefault?: boolean; + }; + }) => { + let count = 0; + for (const binding of this.bindings) { + if (args.where.userId !== undefined && binding.userId !== args.where.userId) { + continue; + } + if (args.where.channel !== undefined && binding.channel !== args.where.channel) { + continue; + } + if (args.where.id?.not !== undefined && binding.id === args.where.id.not) { + continue; + } + + if (args.data.isDefault !== undefined) { + binding.isDefault = args.data.isDefault; + binding.updatedAt = new Date(); + } + count += 1; + } + + return { count }; + } + }; + + readonly aiPublicPoolConfig = { + findFirst: async (args?: { + where?: { + enabled?: boolean; + }; + }) => { + const items = this.publicPools + .filter((item) => + args?.where?.enabled === undefined ? true : item.enabled === args.where.enabled + ) + .sort((left, right) => right.updatedAt.getTime() - left.updatedAt.getTime()); + + return items[0] ?? null; + } + }; + + async $transaction(callback: (tx: InMemoryAiPrismaService) => Promise): Promise { + return callback(this); + } + + seedBinding(binding: Omit): void { + const now = new Date(); + this.bindings.push({ + ...binding, + createdAt: now, + updatedAt: now + }); + } + + seedPublicPool(publicPool: Omit): void { + const now = new Date(); + this.publicPools.push({ + id: `pool_${this.publicPoolIdSequence++}`, + createdAt: now, + updatedAt: now, + ...publicPool + }); + } +} + +class StaticExecutor implements AiChannelExecutor { + constructor( + private readonly resolver: (channel: AiChannel) => { + content?: string; + code?: string; + message?: string; + } + ) {} + + async execute(candidate: { channel: AiChannel; providerName: string; model: string | null }) { + const result = this.resolver(candidate.channel); + if (result.code) { + throw new AiRouteFailureError( + candidate.channel, + candidate.providerName, + result.code, + result.message ?? "执行失败" + ); + } + + return { + channel: candidate.channel, + providerName: candidate.providerName, + model: candidate.model, + content: result.content ?? "", + sessionId: "session_ai", + raw: null + }; + } +} + +describe("AiController (integration)", () => { + let app: INestApplication; + let prismaService: InMemoryAiPrismaService; + + beforeEach(async () => { + prismaService = new InMemoryAiPrismaService(); + + const openAiExecutor = new StaticExecutor((channel) => + channel === AiChannel.USER_KEY + ? { + code: "UPSTREAM_UNREACHABLE", + message: "用户自备 Key 渠道暂时不可用" + } + : { + content: "公共 AI 已接管" + } + ); + const astrbotExecutor = new StaticExecutor(() => ({ + content: "AstrBot 已接管" + })); + + const moduleRef: TestingModule = await Test.createTestingModule({ + controllers: [AiController], + providers: [ + AiService, + { + provide: PrismaService, + useValue: prismaService + }, + { + provide: AiProviderRegistryService, + useValue: { + getExecutor: (channel: AiChannel) => + channel === AiChannel.ASTRBOT ? astrbotExecutor : openAiExecutor + } + } + ] + }).compile(); + + app = moduleRef.createNestApplication(); + app.useGlobalPipes( + new ValidationPipe({ + transform: true, + whitelist: true, + forbidNonWhitelisted: true + }) + ); + await app.init(); + }); + + afterEach(async () => { + await app.close(); + }); + + it("should create and list ai bindings", async () => { + await request(app.getHttpServer()) + .post("/ai/bindings") + .set("x-user-id", "user_1") + .send({ + channel: AiChannel.ASTRBOT, + providerName: "astrbot-main", + model: "deepseek-chat", + endpoint: "http://127.0.0.1:6185", + apiKey: "abk_secret_1234", + isDefault: true, + isEnabled: true + }) + .expect(201); + + const response = await request(app.getHttpServer()) + .get("/ai/bindings") + .set("x-user-id", "user_1") + .expect(200); + + expect(response.body.routeOrder).toEqual([ + AiChannel.USER_KEY, + AiChannel.ASTRBOT, + AiChannel.PUBLIC_POOL + ]); + expect(response.body.bindings).toHaveLength(1); + expect(response.body.bindings[0]).toMatchObject({ + channel: AiChannel.ASTRBOT, + providerName: "astrbot-main", + model: "deepseek-chat", + hasApiKey: true, + maskedApiKey: "abk_***34", + isDefault: true + }); + }); + + it("should fallback from user key to astrbot", async () => { + prismaService.seedBinding({ + id: "binding_user_key", + userId: "user_1", + channel: AiChannel.USER_KEY, + providerName: "openai", + model: "gpt-4o-mini", + encryptedApiKey: "sk-user", + endpoint: "https://api.example.com", + isDefault: true, + isEnabled: true + }); + prismaService.seedBinding({ + id: "binding_astrbot", + userId: "user_1", + channel: AiChannel.ASTRBOT, + providerName: "astrbot-main", + model: "deepseek-chat", + encryptedApiKey: "abk_astrbot", + endpoint: "http://127.0.0.1:6185", + isDefault: true, + isEnabled: true + }); + + const response = await request(app.getHttpServer()) + .post("/ai/chat") + .set("x-user-id", "user_1") + .send({ + message: "帮我安排今天的任务" + }) + .expect(201); + + expect(response.body.channel).toBe(AiChannel.ASTRBOT); + expect(response.body.content).toBe("AstrBot 已接管"); + expect(response.body.attempts).toEqual([ + { + channel: AiChannel.USER_KEY, + providerName: "openai", + model: "gpt-4o-mini", + status: "failed", + reasonCode: "UPSTREAM_UNREACHABLE", + reasonMessage: "用户自备 Key 渠道暂时不可用" + }, + { + channel: AiChannel.ASTRBOT, + providerName: "astrbot-main", + model: "deepseek-chat", + status: "success", + reasonCode: null, + reasonMessage: null + } + ]); + }); + + it("should return skipped attempts when no channel is available", async () => { + const response = await request(app.getHttpServer()) + .post("/ai/chat") + .set("x-user-id", "user_1") + .send({ + message: "帮我总结今天的安排" + }) + .expect(502); + + expect(response.body.message).toBe("当前没有可用的 AI 通道,请稍后重试"); + expect(response.body.attempts).toEqual([ + { + channel: AiChannel.USER_KEY, + providerName: null, + model: null, + status: "skipped", + reasonCode: "CHANNEL_NOT_CONFIGURED", + reasonMessage: "当前用户未配置可用的自备 Key 通道" + }, + { + channel: AiChannel.ASTRBOT, + providerName: null, + model: null, + status: "skipped", + reasonCode: "CHANNEL_NOT_CONFIGURED", + reasonMessage: "当前用户未配置可用的 AstrBot 通道" + }, + { + channel: AiChannel.PUBLIC_POOL, + providerName: null, + model: null, + status: "skipped", + reasonCode: "PUBLIC_POOL_DISABLED", + reasonMessage: "公共 AI 通道未开启" + } + ]); + }); +}); From 2bce9a59c69b13e41bf21466c8d3a43e8969b0c5 Mon Sep 17 00:00:00 2001 From: Yaosanqi137 Date: Mon, 6 Apr 2026 12:19:13 +0800 Subject: [PATCH 02/21] fix(api-ai): stop astrbot stream on end event --- apps/api/src/ai/providers/astrbot.provider.ts | 53 +++++++++++- apps/api/test/astrbot-provider.spec.ts | 80 +++++++++++++++++++ 2 files changed, 131 insertions(+), 2 deletions(-) create mode 100644 apps/api/test/astrbot-provider.spec.ts diff --git a/apps/api/src/ai/providers/astrbot.provider.ts b/apps/api/src/ai/providers/astrbot.provider.ts index 419139d..a413a3b 100644 --- a/apps/api/src/ai/providers/astrbot.provider.ts +++ b/apps/api/src/ai/providers/astrbot.provider.ts @@ -57,8 +57,8 @@ export class AstrbotProvider implements AiChannelExecutor { ); } - const rawText = await response.text(); if (!response.ok) { + const rawText = await response.text(); throw new AiRouteFailureError( candidate.channel, candidate.providerName, @@ -67,7 +67,7 @@ export class AstrbotProvider implements AiChannelExecutor { ); } - const events = this.parseSseEvents(rawText); + const events = await this.readSseEvents(response); let content = ""; let sessionId = input.sessionId; @@ -167,6 +167,55 @@ export class AstrbotProvider implements AiChannelExecutor { .filter((item): item is Record => item !== null); } + private async readSseEvents(response: Response): Promise>> { + if (!response.body) { + return this.parseSseEvents(await response.text()); + } + + const reader = response.body.getReader(); + const decoder = new TextDecoder(); + const events: Array> = []; + let buffer = ""; + let reachedEndEvent = false; + + try { + while (!reachedEndEvent) { + const { done, value } = await reader.read(); + if (done) { + break; + } + + buffer += decoder.decode(value, { stream: true }); + const segments = buffer.split(/\r?\n\r?\n/); + buffer = segments.pop() ?? ""; + + for (const segment of segments) { + const parsedEvents = this.parseSseEvents(segment); + for (const event of parsedEvents) { + events.push(event); + if (this.readString(event["type"]) === "end") { + reachedEndEvent = true; + break; + } + } + + if (reachedEndEvent) { + break; + } + } + } + + const tail = `${buffer}${decoder.decode()}`; + if (tail.trim().length > 0) { + events.push(...this.parseSseEvents(tail)); + } + } finally { + await reader.cancel(); + } + + return events; + } + private extractHttpErrorMessage(rawText: string, statusCode: number): string { try { const payload = JSON.parse(rawText) as Record; diff --git a/apps/api/test/astrbot-provider.spec.ts b/apps/api/test/astrbot-provider.spec.ts new file mode 100644 index 0000000..6190c4a --- /dev/null +++ b/apps/api/test/astrbot-provider.spec.ts @@ -0,0 +1,80 @@ +import { AiChannel } from "../generated/prisma/client"; +import { AstrbotProvider } from "../src/ai/providers/astrbot.provider"; + +describe("AstrbotProvider", () => { + afterEach(() => { + jest.restoreAllMocks(); + }); + + it("should stop reading once the end event arrives", async () => { + const encoder = new TextEncoder(); + let pullCount = 0; + + const stream = new ReadableStream({ + pull(controller) { + pullCount += 1; + if (pullCount === 1) { + controller.enqueue( + encoder.encode('data: {"type":"session_id","data":null,"session_id":"session_1"}\n\n') + ); + return; + } + + if (pullCount === 2) { + controller.enqueue( + encoder.encode( + 'data: {"type":"plain","data":"TodoList AstrBot 已连接","streaming":false,"chain_type":null}\n\n' + ) + ); + return; + } + + if (pullCount === 3) { + controller.enqueue( + encoder.encode('data: {"type":"end","data":"","streaming":false}\n\n') + ); + return; + } + + return new Promise(() => undefined); + } + }); + + jest.spyOn(globalThis, "fetch").mockResolvedValue( + new Response(stream, { + status: 200, + headers: { + "Content-Type": "text/event-stream" + } + }) + ); + + const provider = new AstrbotProvider(); + + const result = await Promise.race([ + provider.execute( + { + channel: AiChannel.ASTRBOT, + source: "binding", + sourceId: "binding_1", + providerName: "", + model: null, + endpoint: "http://127.0.0.1:6185", + apiKey: "abk_test" + }, + { + userId: "user_1", + message: "ping", + sessionId: null + } + ), + new Promise((_, reject) => { + setTimeout(() => reject(new Error("provider timeout")), 1000); + }) + ]); + + expect(result.content).toBe("TodoList AstrBot 已连接"); + expect(result.sessionId).toBe("session_1"); + expect(pullCount).toBeGreaterThanOrEqual(3); + }); +}); From 2ca790abf90149e3462dba83f542c7dd50f89d5d Mon Sep 17 00:00:00 2001 From: Yaosanqi137 Date: Mon, 6 Apr 2026 12:33:10 +0800 Subject: [PATCH 03/21] feat(api-ai): support astrbot config selection --- apps/api/prisma/schema.prisma | 2 + apps/api/src/ai/ai.service.ts | 41 +++++++++++++- apps/api/src/ai/ai.types.ts | 2 + .../ai/dto/upsert-ai-provider-binding.dto.ts | 13 ++++- apps/api/src/ai/providers/astrbot.provider.ts | 19 ++++--- apps/api/test/ai.spec.ts | 53 ++++++++++++++++--- apps/api/test/astrbot-provider.spec.ts | 2 + 7 files changed, 114 insertions(+), 18 deletions(-) diff --git a/apps/api/prisma/schema.prisma b/apps/api/prisma/schema.prisma index 838df48..facef68 100644 --- a/apps/api/prisma/schema.prisma +++ b/apps/api/prisma/schema.prisma @@ -273,6 +273,8 @@ model AiProviderBinding { channel AiChannel providerName String model String? + configId String? + configName String? encryptedApiKey String? endpoint String? isDefault Boolean @default(false) diff --git a/apps/api/src/ai/ai.service.ts b/apps/api/src/ai/ai.service.ts index c4b4825..b5bf24e 100644 --- a/apps/api/src/ai/ai.service.ts +++ b/apps/api/src/ai/ai.service.ts @@ -22,6 +22,8 @@ type AiBindingSummary = { channel: AiChannel; providerName: string; model: string | null; + configId: string | null; + configName: string | null; endpoint: string | null; isDefault: boolean; isEnabled: boolean; @@ -105,6 +107,8 @@ export class AiService { throw new BadRequestException("公共 AI 通道只能由管理员配置"); } + this.validateBindingInput(dto); + const result = await this.prismaService.$transaction(async (tx) => { if (dto.isDefault) { const where: Prisma.AiProviderBindingWhereInput = { @@ -131,8 +135,10 @@ export class AiService { data: { userId, channel: dto.channel, - providerName: dto.providerName.trim(), + providerName: this.normalizeProviderName(dto.providerName), model: this.normalizeOptionalString(dto.model), + configId: this.normalizeOptionalString(dto.configId), + configName: this.normalizeOptionalString(dto.configName), endpoint: this.normalizeOptionalString(dto.endpoint), encryptedApiKey: this.normalizeOptionalString(dto.apiKey), isDefault: dto.isDefault ?? false, @@ -154,8 +160,10 @@ export class AiService { const updateData: Prisma.AiProviderBindingUpdateInput = { channel: dto.channel, - providerName: dto.providerName.trim(), + providerName: this.normalizeProviderName(dto.providerName), model: this.normalizeOptionalString(dto.model), + configId: this.normalizeOptionalString(dto.configId), + configName: this.normalizeOptionalString(dto.configName), isDefault: dto.isDefault ?? existingBinding.isDefault, isEnabled: dto.isEnabled ?? existingBinding.isEnabled }; @@ -342,6 +350,8 @@ export class AiService { sourceId: binding.id, providerName: binding.providerName, model: binding.model, + configId: binding.configId, + configName: binding.configName, endpoint: binding.endpoint, apiKey: binding.encryptedApiKey }; @@ -354,6 +364,8 @@ export class AiService { sourceId: publicPool.id, providerName: publicPool.providerName ?? "public-pool", model: publicPool.model, + configId: null, + configName: null, endpoint: publicPool.endpoint, apiKey: publicPool.encryptedApiKey }; @@ -365,6 +377,8 @@ export class AiService { channel: binding.channel, providerName: binding.providerName, model: binding.model, + configId: binding.configId, + configName: binding.configName, endpoint: binding.endpoint, isDefault: binding.isDefault, isEnabled: binding.isEnabled, @@ -416,6 +430,29 @@ export class AiService { return normalizedValue.length > 0 ? normalizedValue : null; } + private normalizeProviderName(value: string | undefined): string { + return this.normalizeOptionalString(value) ?? ""; + } + + private validateBindingInput(dto: UpsertAiProviderBindingDto): void { + const providerName = this.normalizeOptionalString(dto.providerName); + const configId = this.normalizeOptionalString(dto.configId); + const configName = this.normalizeOptionalString(dto.configName); + + if (dto.channel === AiChannel.ASTRBOT) { + if (!providerName && !configId && !configName) { + throw new BadRequestException( + "AstrBot 通道至少需要 providerName、configId、configName 三者之一" + ); + } + return; + } + + if (!providerName) { + throw new BadRequestException("当前通道必须提供 providerName"); + } + } + private maskSecret(secret: string | null): string | null { if (!secret) { return null; diff --git a/apps/api/src/ai/ai.types.ts b/apps/api/src/ai/ai.types.ts index e576c61..5c52915 100644 --- a/apps/api/src/ai/ai.types.ts +++ b/apps/api/src/ai/ai.types.ts @@ -6,6 +6,8 @@ export type AiResolvedRouteCandidate = { sourceId: string | null; providerName: string; model: string | null; + configId: string | null; + configName: string | null; endpoint: string | null; apiKey: string | null; }; diff --git a/apps/api/src/ai/dto/upsert-ai-provider-binding.dto.ts b/apps/api/src/ai/dto/upsert-ai-provider-binding.dto.ts index 4bffff0..144f86d 100644 --- a/apps/api/src/ai/dto/upsert-ai-provider-binding.dto.ts +++ b/apps/api/src/ai/dto/upsert-ai-provider-binding.dto.ts @@ -10,15 +10,26 @@ export class UpsertAiProviderBindingDto { @IsEnum(AiChannel) channel!: AiChannel; + @IsOptional() @IsString() @MinLength(1) - providerName!: string; + providerName?: string; @IsOptional() @IsString() @MinLength(1) model?: string; + @IsOptional() + @IsString() + @MinLength(1) + configId?: string; + + @IsOptional() + @IsString() + @MinLength(1) + configName?: string; + @IsOptional() @IsUrl( { diff --git a/apps/api/src/ai/providers/astrbot.provider.ts b/apps/api/src/ai/providers/astrbot.provider.ts index a413a3b..a82cb99 100644 --- a/apps/api/src/ai/providers/astrbot.provider.ts +++ b/apps/api/src/ai/providers/astrbot.provider.ts @@ -10,10 +10,13 @@ import { @Injectable() export class AstrbotProvider implements AiChannelExecutor { async execute(candidate: AiResolvedRouteCandidate, input: AiChatInput): Promise { + const routeLabel = + candidate.providerName || candidate.configName || candidate.configId || "astrbot"; + if (!candidate.endpoint) { throw new AiRouteFailureError( candidate.channel, - candidate.providerName, + routeLabel, "MISSING_ENDPOINT", "缺少 AstrBot 服务地址配置" ); @@ -22,7 +25,7 @@ export class AstrbotProvider implements AiChannelExecutor { if (!candidate.apiKey) { throw new AiRouteFailureError( candidate.channel, - candidate.providerName, + routeLabel, "MISSING_API_KEY", "缺少 AstrBot API Key 配置" ); @@ -43,6 +46,8 @@ export class AstrbotProvider implements AiChannelExecutor { session_id: input.sessionId ?? undefined, message: input.message, enable_streaming: false, + config_id: candidate.configId ?? undefined, + config_name: candidate.configName ?? undefined, selected_provider: candidate.providerName || undefined, selected_model: candidate.model ?? undefined }), @@ -51,7 +56,7 @@ export class AstrbotProvider implements AiChannelExecutor { } catch (error) { throw new AiRouteFailureError( candidate.channel, - candidate.providerName, + routeLabel, "UPSTREAM_UNREACHABLE", this.toErrorMessage(error, "AstrBot 服务请求失败") ); @@ -61,7 +66,7 @@ export class AstrbotProvider implements AiChannelExecutor { const rawText = await response.text(); throw new AiRouteFailureError( candidate.channel, - candidate.providerName, + routeLabel, `UPSTREAM_HTTP_${response.status}`, this.extractHttpErrorMessage(rawText, response.status) ); @@ -81,7 +86,7 @@ export class AstrbotProvider implements AiChannelExecutor { if (type === "error") { throw new AiRouteFailureError( candidate.channel, - candidate.providerName, + routeLabel, this.readString(event["code"]) ?? "ASTRBOT_ERROR", this.readString(event["data"]) ?? "AstrBot 返回错误" ); @@ -116,7 +121,7 @@ export class AstrbotProvider implements AiChannelExecutor { if (!content.trim()) { throw new AiRouteFailureError( candidate.channel, - candidate.providerName, + routeLabel, "EMPTY_RESPONSE", "AstrBot 没有返回有效内容" ); @@ -124,7 +129,7 @@ export class AstrbotProvider implements AiChannelExecutor { return { channel: candidate.channel, - providerName: candidate.providerName, + providerName: routeLabel, model: candidate.model, content, sessionId, diff --git a/apps/api/test/ai.spec.ts b/apps/api/test/ai.spec.ts index bbef854..113d037 100644 --- a/apps/api/test/ai.spec.ts +++ b/apps/api/test/ai.spec.ts @@ -5,7 +5,11 @@ import { AiChannel, AiProviderBinding, AiPublicPoolConfig } from "../generated/p import { AiController } from "../src/ai/ai.controller"; import { AiProviderRegistryService } from "../src/ai/ai-provider-registry.service"; import { AiService } from "../src/ai/ai.service"; -import { AiChannelExecutor, AiRouteFailureError } from "../src/ai/ai.types"; +import { + AiChannelExecutor, + AiResolvedRouteCandidate, + AiRouteFailureError +} from "../src/ai/ai.types"; import { PrismaService } from "../src/prisma/prisma.service"; class InMemoryAiPrismaService { @@ -65,6 +69,8 @@ class InMemoryAiPrismaService { channel: AiChannel; providerName: string; model: string | null; + configId: string | null; + configName: string | null; endpoint: string | null; encryptedApiKey: string | null; isDefault: boolean; @@ -78,6 +84,8 @@ class InMemoryAiPrismaService { channel: args.data.channel, providerName: args.data.providerName, model: args.data.model, + configId: args.data.configId, + configName: args.data.configName, encryptedApiKey: args.data.encryptedApiKey, endpoint: args.data.endpoint, isDefault: args.data.isDefault, @@ -189,12 +197,12 @@ class StaticExecutor implements AiChannelExecutor { } ) {} - async execute(candidate: { channel: AiChannel; providerName: string; model: string | null }) { + async execute(candidate: AiResolvedRouteCandidate) { const result = this.resolver(candidate.channel); if (result.code) { throw new AiRouteFailureError( candidate.channel, - candidate.providerName, + candidate.providerName || candidate.configName || candidate.configId || "unknown", result.code, result.message ?? "执行失败" ); @@ -202,7 +210,7 @@ class StaticExecutor implements AiChannelExecutor { return { channel: candidate.channel, - providerName: candidate.providerName, + providerName: candidate.providerName || candidate.configName || candidate.configId || "", model: candidate.model, content: result.content ?? "", sessionId: "session_ai", @@ -273,6 +281,7 @@ describe("AiController (integration)", () => { channel: AiChannel.ASTRBOT, providerName: "astrbot-main", model: "deepseek-chat", + configId: "default", endpoint: "http://127.0.0.1:6185", apiKey: "abk_secret_1234", isDefault: true, @@ -295,6 +304,8 @@ describe("AiController (integration)", () => { channel: AiChannel.ASTRBOT, providerName: "astrbot-main", model: "deepseek-chat", + configId: "default", + configName: null, hasApiKey: true, maskedApiKey: "abk_***34", isDefault: true @@ -308,6 +319,8 @@ describe("AiController (integration)", () => { channel: AiChannel.USER_KEY, providerName: "openai", model: "gpt-4o-mini", + configId: null, + configName: null, encryptedApiKey: "sk-user", endpoint: "https://api.example.com", isDefault: true, @@ -317,8 +330,10 @@ describe("AiController (integration)", () => { id: "binding_astrbot", userId: "user_1", channel: AiChannel.ASTRBOT, - providerName: "astrbot-main", - model: "deepseek-chat", + providerName: "", + model: null, + configId: "default", + configName: null, encryptedApiKey: "abk_astrbot", endpoint: "http://127.0.0.1:6185", isDefault: true, @@ -346,8 +361,8 @@ describe("AiController (integration)", () => { }, { channel: AiChannel.ASTRBOT, - providerName: "astrbot-main", - model: "deepseek-chat", + providerName: "default", + model: null, status: "success", reasonCode: null, reasonMessage: null @@ -355,6 +370,28 @@ describe("AiController (integration)", () => { ]); }); + it("should allow astrbot binding with config id only", async () => { + const response = await request(app.getHttpServer()) + .post("/ai/bindings") + .set("x-user-id", "user_1") + .send({ + channel: AiChannel.ASTRBOT, + configId: "default", + endpoint: "http://127.0.0.1:6185", + apiKey: "abk_secret_1234", + isDefault: true, + isEnabled: true + }) + .expect(201); + + expect(response.body).toMatchObject({ + channel: AiChannel.ASTRBOT, + providerName: "", + configId: "default", + configName: null + }); + }); + it("should return skipped attempts when no channel is available", async () => { const response = await request(app.getHttpServer()) .post("/ai/chat") diff --git a/apps/api/test/astrbot-provider.spec.ts b/apps/api/test/astrbot-provider.spec.ts index 6190c4a..a8fcee8 100644 --- a/apps/api/test/astrbot-provider.spec.ts +++ b/apps/api/test/astrbot-provider.spec.ts @@ -59,6 +59,8 @@ describe("AstrbotProvider", () => { sourceId: "binding_1", providerName: "", model: null, + configId: "default", + configName: null, endpoint: "http://127.0.0.1:6185", apiKey: "abk_test" }, From 45177e9fade6042a5b4cd62034c661467ec48a77 Mon Sep 17 00:00:00 2001 From: Yaosanqi137 Date: Mon, 6 Apr 2026 12:42:56 +0800 Subject: [PATCH 04/21] feat(api-ai): persist usage logs --- apps/api/src/ai/ai.service.ts | 61 ++++++++++++++++++- apps/api/src/ai/ai.types.ts | 7 +++ apps/api/src/ai/providers/astrbot.provider.ts | 36 +++++++++++ .../providers/openai-compatible.provider.ts | 30 +++++++++ apps/api/test/ai.spec.ts | 57 +++++++++++++++++ apps/api/test/astrbot-provider.spec.ts | 16 ++++- 6 files changed, 205 insertions(+), 2 deletions(-) diff --git a/apps/api/src/ai/ai.service.ts b/apps/api/src/ai/ai.service.ts index b5bf24e..2269d2d 100644 --- a/apps/api/src/ai/ai.service.ts +++ b/apps/api/src/ai/ai.service.ts @@ -15,7 +15,12 @@ import { PrismaService } from "../prisma/prisma.service"; import { AiProviderRegistryService } from "./ai-provider-registry.service"; import { AiChatDto } from "./dto/ai-chat.dto"; import { UpsertAiProviderBindingDto } from "./dto/upsert-ai-provider-binding.dto"; -import { AiResolvedRouteCandidate, AiRouteAttempt, AiRouteFailureError } from "./ai.types"; +import { + AiResolvedRouteCandidate, + AiRouteAttempt, + AiRouteFailureError, + AiUsageMetrics +} from "./ai.types"; type AiBindingSummary = { id: string; @@ -198,6 +203,7 @@ export class AiService { } const executor = this.aiProviderRegistryService.getExecutor(entry.candidate.channel); + const startedAt = Date.now(); try { const result = await executor.execute(entry.candidate, { @@ -205,6 +211,7 @@ export class AiService { message: dto.message, sessionId: dto.sessionId ?? null }); + const latencyMs = Date.now() - startedAt; attempts.push({ channel: result.channel, @@ -214,6 +221,16 @@ export class AiService { reasonCode: null, reasonMessage: null }); + await this.recordUsageLog({ + userId, + channel: result.channel, + providerName: result.providerName, + model: result.model, + usage: result.usage, + latencyMs, + success: true, + errorCode: null + }); return { channel: result.channel, @@ -224,8 +241,19 @@ export class AiService { attempts }; } catch (error) { + const latencyMs = Date.now() - startedAt; const failureAttempt = this.toFailureAttempt(entry.candidate, error); attempts.push(failureAttempt); + await this.recordUsageLog({ + userId, + channel: failureAttempt.channel, + providerName: failureAttempt.providerName, + model: failureAttempt.model, + usage: null, + latencyMs, + success: false, + errorCode: failureAttempt.reasonCode + }); this.logger.warn( `AI 通道降级:channel=${failureAttempt.channel} provider=${failureAttempt.providerName ?? "unknown"} code=${failureAttempt.reasonCode ?? "UNKNOWN"} message=${failureAttempt.reasonMessage ?? "unknown"}` ); @@ -464,4 +492,35 @@ export class AiService { return `${secret.slice(0, 4)}***${secret.slice(-2)}`; } + + private async recordUsageLog(input: { + userId: string; + channel: AiChannel; + providerName: string | null; + model: string | null; + usage: AiUsageMetrics | null; + latencyMs: number; + success: boolean; + errorCode: string | null; + }): Promise { + try { + await this.prismaService.aiUsageLog.create({ + data: { + userId: input.userId, + channel: input.channel, + providerName: input.providerName, + model: input.model, + promptTokens: input.usage?.promptTokens ?? 0, + completionTokens: input.usage?.completionTokens ?? 0, + totalTokens: input.usage?.totalTokens ?? 0, + latencyMs: input.latencyMs, + success: input.success, + errorCode: input.errorCode + } + }); + } catch (error) { + const message = error instanceof Error ? error.message : "未知错误"; + this.logger.warn(`写入 AI 使用日志失败:${message}`); + } + } } diff --git a/apps/api/src/ai/ai.types.ts b/apps/api/src/ai/ai.types.ts index 5c52915..ccb8088 100644 --- a/apps/api/src/ai/ai.types.ts +++ b/apps/api/src/ai/ai.types.ts @@ -24,9 +24,16 @@ export type AiChatResult = { model: string | null; content: string; sessionId: string | null; + usage: AiUsageMetrics | null; raw: unknown; }; +export type AiUsageMetrics = { + promptTokens: number; + completionTokens: number; + totalTokens: number; +}; + export type AiRouteAttempt = { channel: AiChannel; providerName: string | null; diff --git a/apps/api/src/ai/providers/astrbot.provider.ts b/apps/api/src/ai/providers/astrbot.provider.ts index a82cb99..4f9936e 100644 --- a/apps/api/src/ai/providers/astrbot.provider.ts +++ b/apps/api/src/ai/providers/astrbot.provider.ts @@ -133,6 +133,7 @@ export class AstrbotProvider implements AiChannelExecutor { model: candidate.model, content, sessionId, + usage: this.extractUsage(events), raw: events }; } @@ -248,4 +249,39 @@ export class AstrbotProvider implements AiChannelExecutor { return fallback; } + + private extractUsage(events: Array>): AiChatResult["usage"] { + for (const event of events) { + if (this.readString(event["type"]) !== "agent_stats") { + continue; + } + + const data = this.asRecord(event["data"]); + const tokenUsage = this.asRecord(data?.["token_usage"]); + if (!tokenUsage) { + continue; + } + + const promptTokens = + (this.readNumber(tokenUsage["input_other"]) ?? 0) + + (this.readNumber(tokenUsage["input_cached"]) ?? 0); + const completionTokens = this.readNumber(tokenUsage["output"]) ?? 0; + + return { + promptTokens, + completionTokens, + totalTokens: promptTokens + completionTokens + }; + } + + return null; + } + + private asRecord(value: unknown): Record | null { + return typeof value === "object" && value !== null ? (value as Record) : null; + } + + private readNumber(value: unknown): number | null { + return typeof value === "number" && Number.isFinite(value) ? value : null; + } } diff --git a/apps/api/src/ai/providers/openai-compatible.provider.ts b/apps/api/src/ai/providers/openai-compatible.provider.ts index 2ca1723..0c52099 100644 --- a/apps/api/src/ai/providers/openai-compatible.provider.ts +++ b/apps/api/src/ai/providers/openai-compatible.provider.ts @@ -105,6 +105,7 @@ export class OpenAiCompatibleProvider implements AiChannelExecutor { model: this.extractModel(payload) ?? candidate.model, content, sessionId: input.sessionId, + usage: this.extractUsage(payload), raw: payload }; } @@ -176,6 +177,31 @@ export class OpenAiCompatibleProvider implements AiChannelExecutor { return payload["model"]; } + private extractUsage(payload: unknown): AiChatResult["usage"] { + if (!this.isRecord(payload)) { + return null; + } + + const usage = payload["usage"]; + if (!this.isRecord(usage)) { + return null; + } + + const promptTokens = this.readNumber(usage["prompt_tokens"]); + const completionTokens = this.readNumber(usage["completion_tokens"]); + const totalTokens = this.readNumber(usage["total_tokens"]); + + if (promptTokens === null && completionTokens === null && totalTokens === null) { + return null; + } + + return { + promptTokens: promptTokens ?? 0, + completionTokens: completionTokens ?? 0, + totalTokens: totalTokens ?? (promptTokens ?? 0) + (completionTokens ?? 0) + }; + } + private extractErrorMessage(payload: unknown, fallback: string): string { if (!this.isRecord(payload)) { return fallback; @@ -200,4 +226,8 @@ export class OpenAiCompatibleProvider implements AiChannelExecutor { return fallback; } + + private readNumber(value: unknown): number | null { + return typeof value === "number" && Number.isFinite(value) ? value : null; + } } diff --git a/apps/api/test/ai.spec.ts b/apps/api/test/ai.spec.ts index 113d037..85705ef 100644 --- a/apps/api/test/ai.spec.ts +++ b/apps/api/test/ai.spec.ts @@ -12,11 +12,25 @@ import { } from "../src/ai/ai.types"; import { PrismaService } from "../src/prisma/prisma.service"; +type AiUsageLogRecord = { + userId: string | null; + channel: AiChannel; + providerName: string | null; + model: string | null; + promptTokens: number; + completionTokens: number; + totalTokens: number; + latencyMs: number | null; + success: boolean; + errorCode: string | null; +}; + class InMemoryAiPrismaService { private bindingIdSequence = 1; private publicPoolIdSequence = 1; private bindings: AiProviderBinding[] = []; private publicPools: AiPublicPoolConfig[] = []; + private usageLogs: AiUsageLogRecord[] = []; readonly aiProviderBinding = { findMany: async (args: { @@ -164,6 +178,13 @@ class InMemoryAiPrismaService { } }; + readonly aiUsageLog = { + create: async (args: { data: AiUsageLogRecord }) => { + this.usageLogs.push(args.data); + return args.data; + } + }; + async $transaction(callback: (tx: InMemoryAiPrismaService) => Promise): Promise { return callback(this); } @@ -186,6 +207,10 @@ class InMemoryAiPrismaService { ...publicPool }); } + + getUsageLogs(): AiUsageLogRecord[] { + return [...this.usageLogs]; + } } class StaticExecutor implements AiChannelExecutor { @@ -214,6 +239,11 @@ class StaticExecutor implements AiChannelExecutor { model: candidate.model, content: result.content ?? "", sessionId: "session_ai", + usage: { + promptTokens: 12, + completionTokens: 8, + totalTokens: 20 + }, raw: null }; } @@ -368,6 +398,32 @@ describe("AiController (integration)", () => { reasonMessage: null } ]); + expect(prismaService.getUsageLogs()).toEqual([ + { + userId: "user_1", + channel: AiChannel.USER_KEY, + providerName: "openai", + model: "gpt-4o-mini", + promptTokens: 0, + completionTokens: 0, + totalTokens: 0, + latencyMs: expect.any(Number), + success: false, + errorCode: "UPSTREAM_UNREACHABLE" + }, + { + userId: "user_1", + channel: AiChannel.ASTRBOT, + providerName: "default", + model: null, + promptTokens: 12, + completionTokens: 8, + totalTokens: 20, + latencyMs: expect.any(Number), + success: true, + errorCode: null + } + ]); }); it("should allow astrbot binding with config id only", async () => { @@ -428,5 +484,6 @@ describe("AiController (integration)", () => { reasonMessage: "公共 AI 通道未开启" } ]); + expect(prismaService.getUsageLogs()).toEqual([]); }); }); diff --git a/apps/api/test/astrbot-provider.spec.ts b/apps/api/test/astrbot-provider.spec.ts index a8fcee8..98f3862 100644 --- a/apps/api/test/astrbot-provider.spec.ts +++ b/apps/api/test/astrbot-provider.spec.ts @@ -30,6 +30,15 @@ describe("AstrbotProvider", () => { } if (pullCount === 3) { + controller.enqueue( + encoder.encode( + 'data: {"type":"agent_stats","data":{"token_usage":{"input_other":12,"input_cached":30,"output":8}}}\n\n' + ) + ); + return; + } + + if (pullCount === 4) { controller.enqueue( encoder.encode('data: {"type":"end","data":"","streaming":false}\n\n') ); @@ -77,6 +86,11 @@ describe("AstrbotProvider", () => { expect(result.content).toBe("TodoList AstrBot 已连接"); expect(result.sessionId).toBe("session_1"); - expect(pullCount).toBeGreaterThanOrEqual(3); + expect(result.usage).toEqual({ + promptTokens: 42, + completionTokens: 8, + totalTokens: 50 + }); + expect(pullCount).toBeGreaterThanOrEqual(4); }); }); From 4578116a30d3d9da6a5c57c2ebc3a572503a43ff Mon Sep 17 00:00:00 2001 From: Yaosanqi137 Date: Mon, 6 Apr 2026 12:57:14 +0800 Subject: [PATCH 05/21] feat(api-ai): inject unfinished task summary --- apps/api/src/ai/ai.service.ts | 150 +++++++++++++++++++++++++++++++++- apps/api/test/ai.spec.ts | 113 ++++++++++++++++++++++++- 2 files changed, 258 insertions(+), 5 deletions(-) diff --git a/apps/api/src/ai/ai.service.ts b/apps/api/src/ai/ai.service.ts index 2269d2d..efee87e 100644 --- a/apps/api/src/ai/ai.service.ts +++ b/apps/api/src/ai/ai.service.ts @@ -9,7 +9,9 @@ import { AiChannel, AiProviderBinding, AiPublicPoolConfig, - Prisma + Prisma, + TaskPriority, + TaskStatus } from "../../generated/prisma/client"; import { PrismaService } from "../prisma/prisma.service"; import { AiProviderRegistryService } from "./ai-provider-registry.service"; @@ -71,6 +73,8 @@ export type AiChatResponse = { @Injectable() export class AiService { private readonly logger = new Logger(AiService.name); + private readonly maxContextTasks = 6; + private readonly maxContextContentLength = 80; constructor( private readonly prismaService: PrismaService, @@ -195,6 +199,7 @@ export class AiService { async chat(userId: string, dto: AiChatDto): Promise { const attempts: AiRouteAttempt[] = []; const plan = await this.buildRoutePlan(userId, dto.bindingId ?? null); + const promptMessage = await this.buildPromptMessage(userId, dto.message); for (const entry of plan) { if (entry.kind === "skip") { @@ -208,7 +213,7 @@ export class AiService { try { const result = await executor.execute(entry.candidate, { userId, - message: dto.message, + message: promptMessage, sessionId: dto.sessionId ?? null }); const latencyMs = Date.now() - startedAt; @@ -416,6 +421,85 @@ export class AiService { }; } + private async buildPromptMessage(userId: string, userMessage: string): Promise { + const taskSummary = await this.buildTaskContextSummary(userId); + if (!taskSummary) { + return userMessage; + } + + return [ + "你是 TodoList 的 AI 助手,请优先结合用户当前未完成任务给出安排建议。", + "以下是系统整理的未完成任务摘要:", + taskSummary, + "如果用户的问题与任务无关,也可以正常回答;如果相关,请优先考虑优先级、截止时间与执行顺序。", + `用户当前问题:${userMessage}` + ].join("\n\n"); + } + + private async buildTaskContextSummary(userId: string): Promise { + const tasks = await this.prismaService.task.findMany({ + where: { + userId, + status: { + in: [TaskStatus.TODO, TaskStatus.IN_PROGRESS] + } + }, + select: { + title: true, + priority: true, + status: true, + ddl: true, + contentText: true, + updatedAt: true + }, + take: 20 + }); + + if (tasks.length === 0) { + return null; + } + + const sortedTasks = [...tasks].sort((left, right) => { + const priorityDiff = + this.getPriorityWeight(right.priority) - this.getPriorityWeight(left.priority); + if (priorityDiff !== 0) { + return priorityDiff; + } + + const leftDdl = left.ddl?.getTime() ?? Number.POSITIVE_INFINITY; + const rightDdl = right.ddl?.getTime() ?? Number.POSITIVE_INFINITY; + if (leftDdl !== rightDdl) { + return leftDdl - rightDdl; + } + + return right.updatedAt.getTime() - left.updatedAt.getTime(); + }); + + const visibleTasks = sortedTasks.slice(0, this.maxContextTasks); + const lines = visibleTasks.map((task, index) => { + const parts = [ + `${index + 1}. ${task.title}`, + `优先级:${this.getPriorityLabel(task.priority)}`, + `状态:${this.getStatusLabel(task.status)}`, + `DDL:${task.ddl ? task.ddl.toISOString() : "未设置"}` + ]; + + const contentSnippet = this.getContentSnippet(task.contentText); + if (contentSnippet) { + parts.push(`内容摘要:${contentSnippet}`); + } + + return parts.join(" | "); + }); + + const omittedCount = sortedTasks.length - visibleTasks.length; + if (omittedCount > 0) { + lines.push(`其余 ${omittedCount} 项未完成任务已省略。`); + } + + return [`共 ${sortedTasks.length} 项未完成任务。`, ...lines].join("\n"); + } + private toFailureAttempt(candidate: AiResolvedRouteCandidate, error: unknown): AiRouteAttempt { if (error instanceof AiRouteFailureError) { return { @@ -493,6 +577,68 @@ export class AiService { return `${secret.slice(0, 4)}***${secret.slice(-2)}`; } + private getPriorityWeight(priority: TaskPriority): number { + switch (priority) { + case TaskPriority.URGENT: + return 4; + case TaskPriority.HIGH: + return 3; + case TaskPriority.MEDIUM: + return 2; + case TaskPriority.LOW: + return 1; + default: + return 0; + } + } + + private getPriorityLabel(priority: TaskPriority): string { + switch (priority) { + case TaskPriority.URGENT: + return "紧急"; + case TaskPriority.HIGH: + return "高"; + case TaskPriority.MEDIUM: + return "中"; + case TaskPriority.LOW: + return "低"; + default: + return String(priority); + } + } + + private getStatusLabel(status: TaskStatus): string { + switch (status) { + case TaskStatus.TODO: + return "待开始"; + case TaskStatus.IN_PROGRESS: + return "进行中"; + case TaskStatus.DONE: + return "已完成"; + case TaskStatus.ARCHIVED: + return "已归档"; + default: + return String(status); + } + } + + private getContentSnippet(contentText: string | null): string | null { + if (!contentText) { + return null; + } + + const normalizedContent = contentText.replace(/\s+/g, " ").trim(); + if (normalizedContent.length === 0) { + return null; + } + + if (normalizedContent.length <= this.maxContextContentLength) { + return normalizedContent; + } + + return `${normalizedContent.slice(0, this.maxContextContentLength)}...`; + } + private async recordUsageLog(input: { userId: string; channel: AiChannel; diff --git a/apps/api/test/ai.spec.ts b/apps/api/test/ai.spec.ts index 85705ef..38f1b7f 100644 --- a/apps/api/test/ai.spec.ts +++ b/apps/api/test/ai.spec.ts @@ -1,11 +1,18 @@ import request from "supertest"; import { INestApplication, ValidationPipe } from "@nestjs/common"; import { Test, TestingModule } from "@nestjs/testing"; -import { AiChannel, AiProviderBinding, AiPublicPoolConfig } from "../generated/prisma/client"; +import { + AiChannel, + AiProviderBinding, + AiPublicPoolConfig, + TaskPriority, + TaskStatus +} from "../generated/prisma/client"; import { AiController } from "../src/ai/ai.controller"; import { AiProviderRegistryService } from "../src/ai/ai-provider-registry.service"; import { AiService } from "../src/ai/ai.service"; import { + AiChatInput, AiChannelExecutor, AiResolvedRouteCandidate, AiRouteFailureError @@ -25,12 +32,23 @@ type AiUsageLogRecord = { errorCode: string | null; }; +type AiTaskRecord = { + userId: string; + title: string; + priority: TaskPriority; + status: TaskStatus; + ddl: Date | null; + contentText: string | null; + updatedAt: Date; +}; + class InMemoryAiPrismaService { private bindingIdSequence = 1; private publicPoolIdSequence = 1; private bindings: AiProviderBinding[] = []; private publicPools: AiPublicPoolConfig[] = []; private usageLogs: AiUsageLogRecord[] = []; + private tasks: AiTaskRecord[] = []; readonly aiProviderBinding = { findMany: async (args: { @@ -185,6 +203,31 @@ class InMemoryAiPrismaService { } }; + readonly task = { + findMany: async (args: { + where: { + userId: string; + status: { + in: TaskStatus[]; + }; + }; + take?: number; + }) => { + const filteredTasks = this.tasks.filter( + (task) => task.userId === args.where.userId && args.where.status.in.includes(task.status) + ); + + return filteredTasks.slice(0, args.take ?? filteredTasks.length).map((task) => ({ + title: task.title, + priority: task.priority, + status: task.status, + ddl: task.ddl, + contentText: task.contentText, + updatedAt: task.updatedAt + })); + } + }; + async $transaction(callback: (tx: InMemoryAiPrismaService) => Promise): Promise { return callback(this); } @@ -211,9 +254,18 @@ class InMemoryAiPrismaService { getUsageLogs(): AiUsageLogRecord[] { return [...this.usageLogs]; } + + seedTask(task: AiTaskRecord): void { + this.tasks.push(task); + } } class StaticExecutor implements AiChannelExecutor { + readonly inputs: Array<{ + candidate: AiResolvedRouteCandidate; + message: string; + }> = []; + constructor( private readonly resolver: (channel: AiChannel) => { content?: string; @@ -222,7 +274,12 @@ class StaticExecutor implements AiChannelExecutor { } ) {} - async execute(candidate: AiResolvedRouteCandidate) { + async execute(candidate: AiResolvedRouteCandidate, input: AiChatInput) { + this.inputs.push({ + candidate, + message: input.message + }); + const result = this.resolver(candidate.channel); if (result.code) { throw new AiRouteFailureError( @@ -252,6 +309,7 @@ class StaticExecutor implements AiChannelExecutor { describe("AiController (integration)", () => { let app: INestApplication; let prismaService: InMemoryAiPrismaService; + let astrbotExecutor: StaticExecutor; beforeEach(async () => { prismaService = new InMemoryAiPrismaService(); @@ -266,7 +324,7 @@ describe("AiController (integration)", () => { content: "公共 AI 已接管" } ); - const astrbotExecutor = new StaticExecutor(() => ({ + astrbotExecutor = new StaticExecutor(() => ({ content: "AstrBot 已接管" })); @@ -448,6 +506,55 @@ describe("AiController (integration)", () => { }); }); + it("should inject unfinished task summary into ai prompt", async () => { + prismaService.seedBinding({ + id: "binding_astrbot_context", + userId: "user_1", + channel: AiChannel.ASTRBOT, + providerName: "", + model: null, + configId: "default", + configName: null, + encryptedApiKey: "abk_astrbot", + endpoint: "http://127.0.0.1:6185", + isDefault: true, + isEnabled: true + }); + prismaService.seedTask({ + userId: "user_1", + title: "今晚提交周报", + priority: TaskPriority.URGENT, + status: TaskStatus.IN_PROGRESS, + ddl: new Date("2026-04-06T12:00:00.000Z"), + contentText: "需要汇总 AI 路由、AstrBot 接入和同步模块进度", + updatedAt: new Date("2026-04-06T08:00:00.000Z") + }); + prismaService.seedTask({ + userId: "user_1", + title: "整理已完成事项", + priority: TaskPriority.LOW, + status: TaskStatus.DONE, + ddl: null, + contentText: "这条任务不应该出现在上下文里", + updatedAt: new Date("2026-04-06T07:00:00.000Z") + }); + + await request(app.getHttpServer()) + .post("/ai/chat") + .set("x-user-id", "user_1") + .send({ + message: "帮我安排今天剩余任务" + }) + .expect(201); + + expect(astrbotExecutor.inputs).toHaveLength(1); + expect(astrbotExecutor.inputs[0]?.message).toContain("以下是系统整理的未完成任务摘要"); + expect(astrbotExecutor.inputs[0]?.message).toContain("今晚提交周报"); + expect(astrbotExecutor.inputs[0]?.message).toContain("优先级:紧急"); + expect(astrbotExecutor.inputs[0]?.message).not.toContain("整理已完成事项"); + expect(astrbotExecutor.inputs[0]?.message).toContain("用户当前问题:帮我安排今天剩余任务"); + }); + it("should return skipped attempts when no channel is available", async () => { const response = await request(app.getHttpServer()) .post("/ai/chat") From 5c956c195be94cb7768e12b99a140a4bd96bf6d9 Mon Sep 17 00:00:00 2001 From: Yaosanqi137 Date: Mon, 6 Apr 2026 13:08:36 +0800 Subject: [PATCH 06/21] feat(api-ai): add usage log query endpoint --- apps/api/src/ai/ai.controller.ts | 18 +- apps/api/src/ai/ai.service.ts | 80 ++++++++ .../ai/dto/list-ai-usage-logs-query.dto.ts | 48 +++++ apps/api/test/ai.spec.ts | 173 +++++++++++++++++- 4 files changed, 312 insertions(+), 7 deletions(-) create mode 100644 apps/api/src/ai/dto/list-ai-usage-logs-query.dto.ts diff --git a/apps/api/src/ai/ai.controller.ts b/apps/api/src/ai/ai.controller.ts index 3ca3ff9..f9f0c16 100644 --- a/apps/api/src/ai/ai.controller.ts +++ b/apps/api/src/ai/ai.controller.ts @@ -1,7 +1,13 @@ -import { Body, Controller, Get, Headers, Post, UnauthorizedException } from "@nestjs/common"; +import { Body, Controller, Get, Headers, Post, Query, UnauthorizedException } from "@nestjs/common"; import { AiChatDto } from "./dto/ai-chat.dto"; +import { ListAiUsageLogsQueryDto } from "./dto/list-ai-usage-logs-query.dto"; import { UpsertAiProviderBindingDto } from "./dto/upsert-ai-provider-binding.dto"; -import { AiChatResponse, AiService, ListAiBindingsResponse } from "./ai.service"; +import { + AiChatResponse, + AiService, + ListAiBindingsResponse, + ListAiUsageLogsResponse +} from "./ai.service"; @Controller("ai") export class AiController { @@ -14,6 +20,14 @@ export class AiController { return this.aiService.listBindings(this.resolveUserId(userIdHeader)); } + @Get("usage-logs") + async listUsageLogs( + @Headers("x-user-id") userIdHeader: string | string[] | undefined, + @Query() query: ListAiUsageLogsQueryDto + ): Promise { + return this.aiService.listUsageLogs(this.resolveUserId(userIdHeader), query); + } + @Post("bindings") async upsertBinding( @Headers("x-user-id") userIdHeader: string | string[] | undefined, diff --git a/apps/api/src/ai/ai.service.ts b/apps/api/src/ai/ai.service.ts index efee87e..dfbbd9a 100644 --- a/apps/api/src/ai/ai.service.ts +++ b/apps/api/src/ai/ai.service.ts @@ -7,6 +7,7 @@ } from "@nestjs/common"; import { AiChannel, + AiUsageLog, AiProviderBinding, AiPublicPoolConfig, Prisma, @@ -16,6 +17,7 @@ import { import { PrismaService } from "../prisma/prisma.service"; import { AiProviderRegistryService } from "./ai-provider-registry.service"; import { AiChatDto } from "./dto/ai-chat.dto"; +import { ListAiUsageLogsQueryDto } from "./dto/list-ai-usage-logs-query.dto"; import { UpsertAiProviderBindingDto } from "./dto/upsert-ai-provider-binding.dto"; import { AiResolvedRouteCandidate, @@ -61,6 +63,27 @@ export type ListAiBindingsResponse = { } | null; }; +type AiUsageLogSummary = { + id: string; + channel: AiChannel; + providerName: string | null; + model: string | null; + promptTokens: number; + completionTokens: number; + totalTokens: number; + latencyMs: number | null; + success: boolean; + errorCode: string | null; + createdAt: string; +}; + +export type ListAiUsageLogsResponse = { + items: AiUsageLogSummary[]; + page: number; + pageSize: number; + total: number; +}; + export type AiChatResponse = { channel: AiChannel; providerName: string; @@ -111,6 +134,47 @@ export class AiService { }; } + async listUsageLogs( + userId: string, + query: ListAiUsageLogsQueryDto + ): Promise { + const page = query.page ?? 1; + const pageSize = query.pageSize ?? 20; + const skip = (page - 1) * pageSize; + const where: Prisma.AiUsageLogWhereInput = { + userId + }; + + if (query.channel) { + where.channel = query.channel; + } + + if (query.success !== undefined) { + where.success = query.success; + } + + const [items, total] = await Promise.all([ + this.prismaService.aiUsageLog.findMany({ + where, + orderBy: { + createdAt: "desc" + }, + skip, + take: pageSize + }), + this.prismaService.aiUsageLog.count({ + where + }) + ]); + + return { + items: items.map((item) => this.serializeUsageLog(item)), + page, + pageSize, + total + }; + } + async upsertBinding(userId: string, dto: UpsertAiProviderBindingDto): Promise { if (dto.channel === AiChannel.PUBLIC_POOL) { throw new BadRequestException("公共 AI 通道只能由管理员配置"); @@ -421,6 +485,22 @@ export class AiService { }; } + private serializeUsageLog(log: AiUsageLog): AiUsageLogSummary { + return { + id: log.id, + channel: log.channel, + providerName: log.providerName, + model: log.model, + promptTokens: log.promptTokens, + completionTokens: log.completionTokens, + totalTokens: log.totalTokens, + latencyMs: log.latencyMs, + success: log.success, + errorCode: log.errorCode, + createdAt: log.createdAt.toISOString() + }; + } + private async buildPromptMessage(userId: string, userMessage: string): Promise { const taskSummary = await this.buildTaskContextSummary(userId); if (!taskSummary) { diff --git a/apps/api/src/ai/dto/list-ai-usage-logs-query.dto.ts b/apps/api/src/ai/dto/list-ai-usage-logs-query.dto.ts new file mode 100644 index 0000000..49aa2e0 --- /dev/null +++ b/apps/api/src/ai/dto/list-ai-usage-logs-query.dto.ts @@ -0,0 +1,48 @@ +import { Transform, Type } from "class-transformer"; +import { IsBoolean, IsEnum, IsInt, IsOptional, Max, Min } from "class-validator"; +import { AiChannel } from "../../../generated/prisma/client"; + +function normalizeBoolean(value: unknown): boolean | undefined { + if (typeof value === "boolean") { + return value; + } + + if (typeof value !== "string") { + return undefined; + } + + const normalized = value.trim().toLowerCase(); + if (normalized === "true" || normalized === "1") { + return true; + } + + if (normalized === "false" || normalized === "0") { + return false; + } + + return undefined; +} + +export class ListAiUsageLogsQueryDto { + @Type(() => Number) + @IsOptional() + @IsInt() + @Min(1) + page?: number; + + @Type(() => Number) + @IsOptional() + @IsInt() + @Min(1) + @Max(100) + pageSize?: number; + + @IsOptional() + @IsEnum(AiChannel) + channel?: AiChannel; + + @Transform(({ value }) => normalizeBoolean(value)) + @IsOptional() + @IsBoolean() + success?: boolean; +} diff --git a/apps/api/test/ai.spec.ts b/apps/api/test/ai.spec.ts index 38f1b7f..4fd9b2f 100644 --- a/apps/api/test/ai.spec.ts +++ b/apps/api/test/ai.spec.ts @@ -3,6 +3,7 @@ import { INestApplication, ValidationPipe } from "@nestjs/common"; import { Test, TestingModule } from "@nestjs/testing"; import { AiChannel, + AiUsageLog, AiProviderBinding, AiPublicPoolConfig, TaskPriority, @@ -20,6 +21,7 @@ import { import { PrismaService } from "../src/prisma/prisma.service"; type AiUsageLogRecord = { + id: string; userId: string | null; channel: AiChannel; providerName: string | null; @@ -30,6 +32,7 @@ type AiUsageLogRecord = { latencyMs: number | null; success: boolean; errorCode: string | null; + createdAt: Date; }; type AiTaskRecord = { @@ -45,6 +48,7 @@ type AiTaskRecord = { class InMemoryAiPrismaService { private bindingIdSequence = 1; private publicPoolIdSequence = 1; + private usageLogIdSequence = 1; private bindings: AiProviderBinding[] = []; private publicPools: AiPublicPoolConfig[] = []; private usageLogs: AiUsageLogRecord[] = []; @@ -197,9 +201,47 @@ class InMemoryAiPrismaService { }; readonly aiUsageLog = { - create: async (args: { data: AiUsageLogRecord }) => { - this.usageLogs.push(args.data); - return args.data; + create: async (args: { data: Omit }) => { + const usageLog: AiUsageLogRecord = { + id: `usage_log_${this.usageLogIdSequence++}`, + createdAt: new Date(), + ...args.data + }; + + this.usageLogs.push(usageLog); + return usageLog; + }, + + findMany: async (args: { + where?: { + userId?: string; + channel?: AiChannel; + success?: boolean; + }; + orderBy?: { + createdAt: "asc" | "desc"; + }; + skip?: number; + take?: number; + }) => { + const filteredLogs = this.filterUsageLogs(args.where); + const sortedLogs = [...filteredLogs].sort((left, right) => { + const direction = args.orderBy?.createdAt === "asc" ? 1 : -1; + return (left.createdAt.getTime() - right.createdAt.getTime()) * direction; + }); + const start = args.skip ?? 0; + const end = args.take === undefined ? undefined : start + args.take; + return sortedLogs.slice(start, end); + }, + + count: async (args?: { + where?: { + userId?: string; + channel?: AiChannel; + success?: boolean; + }; + }) => { + return this.filterUsageLogs(args?.where).length; } }; @@ -258,6 +300,33 @@ class InMemoryAiPrismaService { seedTask(task: AiTaskRecord): void { this.tasks.push(task); } + + seedUsageLog(log: Omit & { id?: string }): void { + this.usageLogs.push({ + id: log.id ?? `usage_log_${this.usageLogIdSequence++}`, + ...log + }); + } + + private filterUsageLogs(where?: { + userId?: string; + channel?: AiChannel; + success?: boolean; + }): AiUsageLogRecord[] { + return this.usageLogs.filter((log) => { + if (where?.userId !== undefined && log.userId !== where.userId) { + return false; + } + if (where?.channel !== undefined && log.channel !== where.channel) { + return false; + } + if (where?.success !== undefined && log.success !== where.success) { + return false; + } + + return true; + }); + } } class StaticExecutor implements AiChannelExecutor { @@ -458,6 +527,7 @@ describe("AiController (integration)", () => { ]); expect(prismaService.getUsageLogs()).toEqual([ { + id: expect.any(String), userId: "user_1", channel: AiChannel.USER_KEY, providerName: "openai", @@ -467,9 +537,11 @@ describe("AiController (integration)", () => { totalTokens: 0, latencyMs: expect.any(Number), success: false, - errorCode: "UPSTREAM_UNREACHABLE" + errorCode: "UPSTREAM_UNREACHABLE", + createdAt: expect.any(Date) }, { + id: expect.any(String), userId: "user_1", channel: AiChannel.ASTRBOT, providerName: "default", @@ -479,7 +551,8 @@ describe("AiController (integration)", () => { totalTokens: 20, latencyMs: expect.any(Number), success: true, - errorCode: null + errorCode: null, + createdAt: expect.any(Date) } ]); }); @@ -593,4 +666,94 @@ describe("AiController (integration)", () => { ]); expect(prismaService.getUsageLogs()).toEqual([]); }); + it("should list usage logs with pagination and filters", async () => { + prismaService.seedUsageLog({ + id: "usage_log_1", + userId: "user_1", + channel: AiChannel.ASTRBOT, + providerName: "default", + model: "deepseek-chat", + promptTokens: 10, + completionTokens: 6, + totalTokens: 16, + latencyMs: 120, + success: true, + errorCode: null, + createdAt: new Date("2026-04-06T08:00:00.000Z") + }); + prismaService.seedUsageLog({ + id: "usage_log_2", + userId: "user_1", + channel: AiChannel.ASTRBOT, + providerName: "default", + model: "deepseek-chat", + promptTokens: 14, + completionTokens: 9, + totalTokens: 23, + latencyMs: 100, + success: true, + errorCode: null, + createdAt: new Date("2026-04-06T09:00:00.000Z") + }); + prismaService.seedUsageLog({ + id: "usage_log_3", + userId: "user_1", + channel: AiChannel.USER_KEY, + providerName: "openai", + model: "gpt-4o-mini", + promptTokens: 20, + completionTokens: 12, + totalTokens: 32, + latencyMs: 210, + success: false, + errorCode: "UPSTREAM_UNREACHABLE", + createdAt: new Date("2026-04-06T10:00:00.000Z") + }); + prismaService.seedUsageLog({ + id: "usage_log_4", + userId: "user_2", + channel: AiChannel.ASTRBOT, + providerName: "default", + model: "deepseek-chat", + promptTokens: 18, + completionTokens: 11, + totalTokens: 29, + latencyMs: 90, + success: true, + errorCode: null, + createdAt: new Date("2026-04-06T11:00:00.000Z") + }); + + const response = await request(app.getHttpServer()) + .get("/ai/usage-logs") + .set("x-user-id", "user_1") + .query({ + page: 2, + pageSize: 1, + channel: AiChannel.ASTRBOT, + success: true + }) + .expect(200); + + expect(response.body).toEqual({ + items: [ + { + id: "usage_log_1", + channel: AiChannel.ASTRBOT, + providerName: "default", + model: "deepseek-chat", + promptTokens: 10, + completionTokens: 6, + totalTokens: 16, + latencyMs: 120, + success: true, + errorCode: null, + createdAt: "2026-04-06T08:00:00.000Z" + } + ], + page: 2, + pageSize: 1, + total: 2 + }); + }); }); From d0ba58118450467678fe2423f69db4bd5b2b7354 Mon Sep 17 00:00:00 2001 From: Yaosanqi137 Date: Mon, 6 Apr 2026 13:36:28 +0800 Subject: [PATCH 07/21] feat(api-ai): scope private bindings by user channel --- apps/api/src/ai/ai.service.ts | 142 +++++++----------- apps/api/src/ai/dto/ai-chat.dto.ts | 8 +- .../ai/dto/upsert-ai-provider-binding.dto.ts | 13 +- apps/api/test/ai.spec.ts | 100 +++++++++++- 4 files changed, 157 insertions(+), 106 deletions(-) diff --git a/apps/api/src/ai/ai.service.ts b/apps/api/src/ai/ai.service.ts index dfbbd9a..8dae914 100644 --- a/apps/api/src/ai/ai.service.ts +++ b/apps/api/src/ai/ai.service.ts @@ -1,10 +1,4 @@ -import { - BadGatewayException, - BadRequestException, - Injectable, - Logger, - NotFoundException -} from "@nestjs/common"; +import { BadGatewayException, BadRequestException, Injectable, Logger } from "@nestjs/common"; import { AiChannel, AiUsageLog, @@ -34,7 +28,6 @@ type AiBindingSummary = { configId: string | null; configName: string | null; endpoint: string | null; - isDefault: boolean; isEnabled: boolean; hasApiKey: boolean; maskedApiKey: string | null; @@ -110,7 +103,7 @@ export class AiService { where: { userId }, - orderBy: [{ channel: "asc" }, { isDefault: "desc" }, { updatedAt: "desc" }] + orderBy: [{ updatedAt: "desc" }] }), this.prismaService.aiPublicPoolConfig.findFirst({ orderBy: { @@ -119,9 +112,11 @@ export class AiService { }) ]); + const latestBindings = this.pickLatestBindingsByChannel(bindings); + return { routeOrder: [AiChannel.USER_KEY, AiChannel.ASTRBOT, AiChannel.PUBLIC_POOL], - bindings: bindings.map((binding) => this.serializeBinding(binding)), + bindings: latestBindings.map((binding) => this.serializeBinding(binding)), publicPool: publicPool ? { enabled: publicPool.enabled, @@ -183,27 +178,17 @@ export class AiService { this.validateBindingInput(dto); const result = await this.prismaService.$transaction(async (tx) => { - if (dto.isDefault) { - const where: Prisma.AiProviderBindingWhereInput = { + const existingBinding = await tx.aiProviderBinding.findFirst({ + where: { userId, channel: dto.channel - }; - - if (dto.id) { - where.id = { - not: dto.id - }; + }, + orderBy: { + updatedAt: "desc" } + }); - await tx.aiProviderBinding.updateMany({ - where, - data: { - isDefault: false - } - }); - } - - if (!dto.id) { + if (!existingBinding) { return tx.aiProviderBinding.create({ data: { userId, @@ -214,30 +199,17 @@ export class AiService { configName: this.normalizeOptionalString(dto.configName), endpoint: this.normalizeOptionalString(dto.endpoint), encryptedApiKey: this.normalizeOptionalString(dto.apiKey), - isDefault: dto.isDefault ?? false, isEnabled: dto.isEnabled ?? true } }); } - const existingBinding = await tx.aiProviderBinding.findFirst({ - where: { - id: dto.id, - userId - } - }); - - if (!existingBinding) { - throw new NotFoundException("AI 通道配置不存在"); - } - const updateData: Prisma.AiProviderBindingUpdateInput = { channel: dto.channel, providerName: this.normalizeProviderName(dto.providerName), model: this.normalizeOptionalString(dto.model), configId: this.normalizeOptionalString(dto.configId), configName: this.normalizeOptionalString(dto.configName), - isDefault: dto.isDefault ?? existingBinding.isDefault, isEnabled: dto.isEnabled ?? existingBinding.isEnabled }; @@ -251,7 +223,7 @@ export class AiService { return tx.aiProviderBinding.update({ where: { - id: dto.id + id: existingBinding.id }, data: updateData }); @@ -262,7 +234,7 @@ export class AiService { async chat(userId: string, dto: AiChatDto): Promise { const attempts: AiRouteAttempt[] = []; - const plan = await this.buildRoutePlan(userId, dto.bindingId ?? null); + const plan = await this.buildRoutePlan(userId, dto.channel ?? null); const promptMessage = await this.buildPromptMessage(userId, dto.message); for (const entry of plan) { @@ -337,33 +309,34 @@ export class AiService { private async buildRoutePlan( userId: string, - bindingId: string | null + selectedChannel: AiChannel | null ): Promise { const plan: AiRoutePlanEntry[] = []; - const consumedChannels = new Set(); + const targetChannels = selectedChannel + ? [selectedChannel] + : [AiChannel.USER_KEY, AiChannel.ASTRBOT, AiChannel.PUBLIC_POOL]; - if (bindingId) { - const pinnedBinding = await this.prismaService.aiProviderBinding.findFirst({ - where: { - id: bindingId, - userId, - isEnabled: true + for (const channel of targetChannels) { + if (channel === AiChannel.PUBLIC_POOL) { + const publicPool = await this.findEnabledPublicPool(); + if (publicPool) { + plan.push({ + kind: "candidate", + candidate: this.toPublicPoolCandidate(publicPool) + }); + } else { + plan.push({ + kind: "skip", + attempt: { + channel: AiChannel.PUBLIC_POOL, + providerName: null, + model: null, + status: "skipped", + reasonCode: "PUBLIC_POOL_DISABLED", + reasonMessage: "公共 AI 通道未开启" + } + }); } - }); - - if (!pinnedBinding) { - throw new NotFoundException("指定的 AI 通道配置不存在或已禁用"); - } - - plan.push({ - kind: "candidate", - candidate: this.toBindingCandidate(pinnedBinding) - }); - consumedChannels.add(pinnedBinding.channel); - } - - for (const channel of [AiChannel.USER_KEY, AiChannel.ASTRBOT]) { - if (consumedChannels.has(channel)) { continue; } @@ -392,26 +365,6 @@ export class AiService { }); } - const publicPool = await this.findEnabledPublicPool(); - if (publicPool) { - plan.push({ - kind: "candidate", - candidate: this.toPublicPoolCandidate(publicPool) - }); - } else { - plan.push({ - kind: "skip", - attempt: { - channel: AiChannel.PUBLIC_POOL, - providerName: null, - model: null, - status: "skipped", - reasonCode: "PUBLIC_POOL_DISABLED", - reasonMessage: "公共 AI 通道未开启" - } - }); - } - return plan; } @@ -425,7 +378,9 @@ export class AiService { channel, isEnabled: true }, - orderBy: [{ isDefault: "desc" }, { updatedAt: "desc" }] + orderBy: { + updatedAt: "desc" + } }); } @@ -477,7 +432,6 @@ export class AiService { configId: binding.configId, configName: binding.configName, endpoint: binding.endpoint, - isDefault: binding.isDefault, isEnabled: binding.isEnabled, hasApiKey: Boolean(binding.encryptedApiKey), maskedApiKey: this.maskSecret(binding.encryptedApiKey), @@ -485,6 +439,20 @@ export class AiService { }; } + private pickLatestBindingsByChannel(bindings: AiProviderBinding[]): AiProviderBinding[] { + const bindingMap = new Map(); + + for (const binding of bindings) { + if (!bindingMap.has(binding.channel)) { + bindingMap.set(binding.channel, binding); + } + } + + return [AiChannel.USER_KEY, AiChannel.ASTRBOT] + .map((channel) => bindingMap.get(channel) ?? null) + .filter((binding): binding is AiProviderBinding => binding !== null); + } + private serializeUsageLog(log: AiUsageLog): AiUsageLogSummary { return { id: log.id, diff --git a/apps/api/src/ai/dto/ai-chat.dto.ts b/apps/api/src/ai/dto/ai-chat.dto.ts index a89692a..013b697 100644 --- a/apps/api/src/ai/dto/ai-chat.dto.ts +++ b/apps/api/src/ai/dto/ai-chat.dto.ts @@ -1,4 +1,5 @@ -import { IsOptional, IsString, MinLength } from "class-validator"; +import { IsEnum, IsOptional, IsString, MinLength } from "class-validator"; +import { AiChannel } from "../../../generated/prisma/client"; export class AiChatDto { @IsString() @@ -11,7 +12,6 @@ export class AiChatDto { sessionId?: string; @IsOptional() - @IsString() - @MinLength(1) - bindingId?: string; + @IsEnum(AiChannel) + channel?: AiChannel; } diff --git a/apps/api/src/ai/dto/upsert-ai-provider-binding.dto.ts b/apps/api/src/ai/dto/upsert-ai-provider-binding.dto.ts index 144f86d..b821bcc 100644 --- a/apps/api/src/ai/dto/upsert-ai-provider-binding.dto.ts +++ b/apps/api/src/ai/dto/upsert-ai-provider-binding.dto.ts @@ -1,12 +1,7 @@ -import { AiChannel } from "../../../generated/prisma/client"; +import { AiChannel } from "../../../generated/prisma/client"; import { IsBoolean, IsEnum, IsOptional, IsString, IsUrl, MinLength } from "class-validator"; export class UpsertAiProviderBindingDto { - @IsOptional() - @IsString() - @MinLength(1) - id?: string; - @IsEnum(AiChannel) channel!: AiChannel; @@ -36,7 +31,7 @@ export class UpsertAiProviderBindingDto { require_tld: false }, { - message: "endpoint 必须是合法的 URL" + message: "endpoint \u5fc5\u987b\u662f\u5408\u6cd5\u7684 URL" } ) endpoint?: string; @@ -46,10 +41,6 @@ export class UpsertAiProviderBindingDto { @MinLength(1) apiKey?: string; - @IsOptional() - @IsBoolean() - isDefault?: boolean; - @IsOptional() @IsBoolean() isEnabled?: boolean; diff --git a/apps/api/test/ai.spec.ts b/apps/api/test/ai.spec.ts index 4fd9b2f..d344a9e 100644 --- a/apps/api/test/ai.spec.ts +++ b/apps/api/test/ai.spec.ts @@ -441,7 +441,6 @@ describe("AiController (integration)", () => { configId: "default", endpoint: "http://127.0.0.1:6185", apiKey: "abk_secret_1234", - isDefault: true, isEnabled: true }) .expect(201); @@ -465,10 +464,54 @@ describe("AiController (integration)", () => { configName: null, hasApiKey: true, maskedApiKey: "abk_***34", - isDefault: true + isEnabled: true }); }); + it("should upsert one binding per user channel", async () => { + await request(app.getHttpServer()) + .post("/ai/bindings") + .set("x-user-id", "user_1") + .send({ + channel: AiChannel.USER_KEY, + providerName: "openai", + model: "gpt-4o-mini", + endpoint: "https://api.example.com", + apiKey: "sk-first", + isEnabled: true + }) + .expect(201); + + await request(app.getHttpServer()) + .post("/ai/bindings") + .set("x-user-id", "user_1") + .send({ + channel: AiChannel.USER_KEY, + providerName: "google", + model: "gemini-2.5-flash", + endpoint: "https://generativelanguage.googleapis.com", + apiKey: "sk-second", + isEnabled: false + }) + .expect(201); + + const response = await request(app.getHttpServer()) + .get("/ai/bindings") + .set("x-user-id", "user_1") + .expect(200); + + expect(response.body.bindings).toEqual([ + expect.objectContaining({ + channel: AiChannel.USER_KEY, + providerName: "google", + model: "gemini-2.5-flash", + endpoint: "https://generativelanguage.googleapis.com", + isEnabled: false, + maskedApiKey: "sk-s***nd" + }) + ]); + }); + it("should fallback from user key to astrbot", async () => { prismaService.seedBinding({ id: "binding_user_key", @@ -566,7 +609,6 @@ describe("AiController (integration)", () => { configId: "default", endpoint: "http://127.0.0.1:6185", apiKey: "abk_secret_1234", - isDefault: true, isEnabled: true }) .expect(201); @@ -575,10 +617,60 @@ describe("AiController (integration)", () => { channel: AiChannel.ASTRBOT, providerName: "", configId: "default", - configName: null + configName: null, + isEnabled: true }); }); + it("should use selected channel without automatic fallback", async () => { + prismaService.seedBinding({ + id: "binding_user_key_selected", + userId: "user_1", + channel: AiChannel.USER_KEY, + providerName: "openai", + model: "gpt-4o-mini", + configId: null, + configName: null, + encryptedApiKey: "sk-user", + endpoint: "https://api.example.com", + isDefault: false, + isEnabled: true + }); + prismaService.seedBinding({ + id: "binding_astrbot_selected", + userId: "user_1", + channel: AiChannel.ASTRBOT, + providerName: "", + model: null, + configId: "default", + configName: null, + encryptedApiKey: "abk_astrbot", + endpoint: "http://127.0.0.1:6185", + isDefault: false, + isEnabled: true + }); + + const response = await request(app.getHttpServer()) + .post("/ai/chat") + .set("x-user-id", "user_1") + .send({ + message: "只使用自备渠道", + channel: AiChannel.USER_KEY + }) + .expect(502); + + expect(response.body.attempts).toEqual([ + { + channel: AiChannel.USER_KEY, + providerName: "openai", + model: "gpt-4o-mini", + status: "failed", + reasonCode: "UPSTREAM_UNREACHABLE", + reasonMessage: "用户自备 Key 渠道暂时不可用" + } + ]); + }); + it("should inject unfinished task summary into ai prompt", async () => { prismaService.seedBinding({ id: "binding_astrbot_context", From ea23f6264cc6f64c7e39171f929edcd48fc63787 Mon Sep 17 00:00:00 2001 From: Yaosanqi137 Date: Mon, 6 Apr 2026 13:51:44 +0800 Subject: [PATCH 08/21] feat(web-ai): add channel-aware assistant panel --- apps/web/src/App.tsx | 10 +- .../src/components/ai/ai-assistant-panel.tsx | 811 ++++++++++++++++++ apps/web/src/pages/todo-shell-page.tsx | 5 +- apps/web/src/services/ai-api.ts | 155 ++++ 4 files changed, 977 insertions(+), 4 deletions(-) create mode 100644 apps/web/src/components/ai/ai-assistant-panel.tsx create mode 100644 apps/web/src/services/ai-api.ts diff --git a/apps/web/src/App.tsx b/apps/web/src/App.tsx index c77f482..4c6384e 100644 --- a/apps/web/src/App.tsx +++ b/apps/web/src/App.tsx @@ -48,6 +48,8 @@ const SIDEBAR_ITEMS: SidebarItem[] = [ { key: "settings", label: "系统设置", icon: Settings } ]; +const READY_SIDEBAR_KEYS = new Set(["todo", "ai"]); + function toWebSession(payload: EmailLoginResult): WebSession { return { accessToken: payload.accessToken, @@ -151,9 +153,11 @@ function App() { {item.label} - - 即将上线 - + {READY_SIDEBAR_KEYS.has(item.key) ? null : ( + + 即将上线 + + )} )} diff --git a/apps/web/src/components/ai/ai-assistant-panel.tsx b/apps/web/src/components/ai/ai-assistant-panel.tsx new file mode 100644 index 0000000..7939ab8 --- /dev/null +++ b/apps/web/src/components/ai/ai-assistant-panel.tsx @@ -0,0 +1,811 @@ +import { useCallback, useEffect, useMemo, useRef, useState } from "react"; +import { + Bot, + CheckCircle2, + CircleAlert, + Globe2, + KeyRound, + LoaderCircle, + PlugZap, + RefreshCw, + SendHorizontal, + Settings2, + Sparkles +} from "lucide-react"; +import { Button } from "@/components/ui/button"; +import { cn } from "@/lib/utils"; +import { + chatWithAi, + listAiBindings, + upsertAiBinding, + type UpsertWebAiBindingInput, + type WebAiBindingSummary, + type WebAiBindingsResponse, + type WebAiChannel, + WebAiApiError +} from "@/services/ai-api"; +import type { WebSession } from "@/services/session-storage"; + +type AiAssistantPanelProps = { + session: WebSession; +}; + +type AiBindingFormState = { + providerName: string; + model: string; + endpoint: string; + apiKey: string; + configId: string; + configName: string; + isEnabled: boolean; +}; + +type AiMessageRecord = { + id: string; + role: "user" | "assistant" | "system"; + content: string; + meta?: string; +}; + +type PanelNotice = { + tone: "success" | "error"; + message: string; +}; + +const CHANNEL_ORDER: WebAiChannel[] = ["USER_KEY", "ASTRBOT", "PUBLIC_POOL"]; + +const CHANNEL_META: Record< + WebAiChannel, + { + title: string; + description: string; + icon: typeof KeyRound; + accentClassName: string; + } +> = { + USER_KEY: { + title: "自备厂商", + description: "用户自己接入厂商接口", + icon: KeyRound, + accentClassName: "from-sky-500/15 via-transparent to-sky-500/5" + }, + ASTRBOT: { + title: "AstrBot", + description: "复用 AstrBot 内已接入模型", + icon: PlugZap, + accentClassName: "from-amber-500/15 via-transparent to-amber-500/5" + }, + PUBLIC_POOL: { + title: "公共 AI", + description: "使用站点管理员开放的公共通道", + icon: Globe2, + accentClassName: "from-emerald-500/15 via-transparent to-emerald-500/5" + } +}; + +function createFormState(binding?: WebAiBindingSummary | null): AiBindingFormState { + return { + providerName: binding?.providerName ?? "", + model: binding?.model ?? "", + endpoint: binding?.endpoint ?? "", + apiKey: "", + configId: binding?.configId ?? "", + configName: binding?.configName ?? "", + isEnabled: binding?.isEnabled ?? true + }; +} + +function createEmptyMessages(): Record { + return { + USER_KEY: [], + ASTRBOT: [], + PUBLIC_POOL: [] + }; +} + +function createEmptySessionIds(): Partial> { + return {}; +} + +function formatTimeLabel(date = new Date()): string { + return date.toLocaleTimeString("zh-CN", { + hour: "2-digit", + minute: "2-digit" + }); +} + +function trimOptionalValue(value: string): string | undefined { + const normalized = value.trim(); + return normalized.length > 0 ? normalized : undefined; +} + +function buildBindingPayload( + channel: Exclude, + formState: AiBindingFormState, + currentBinding: WebAiBindingSummary | null +): UpsertWebAiBindingInput { + return { + channel, + providerName: trimOptionalValue(formState.providerName), + model: trimOptionalValue(formState.model), + endpoint: trimOptionalValue(formState.endpoint), + configId: trimOptionalValue(formState.configId), + configName: trimOptionalValue(formState.configName), + apiKey: trimOptionalValue(formState.apiKey) ?? undefined, + isEnabled: formState.isEnabled ?? currentBinding?.isEnabled ?? true + }; +} + +function appendMessage( + records: Record, + channel: WebAiChannel, + message: AiMessageRecord +): Record { + return { + ...records, + [channel]: [...records[channel], message] + }; +} + +export function AiAssistantPanel({ session }: AiAssistantPanelProps) { + const [bindingsResponse, setBindingsResponse] = useState(null); + const [loadingBindings, setLoadingBindings] = useState(true); + const [refreshingBindings, setRefreshingBindings] = useState(false); + const [activeChannel, setActiveChannel] = useState("USER_KEY"); + const [settingsOpen, setSettingsOpen] = useState(true); + const [userKeyForm, setUserKeyForm] = useState(() => createFormState()); + const [astrbotForm, setAstrbotForm] = useState(() => createFormState()); + const [savingChannel, setSavingChannel] = useState(null); + const [panelNotice, setPanelNotice] = useState(null); + const [messagesByChannel, setMessagesByChannel] = useState< + Record + >(() => createEmptyMessages()); + const [sessionIds, setSessionIds] = useState>>(() => + createEmptySessionIds() + ); + const [draftMessage, setDraftMessage] = useState(""); + const [sending, setSending] = useState(false); + const messagesEndRef = useRef(null); + + const bindingMap = useMemo(() => { + const map = new Map(); + for (const binding of bindingsResponse?.bindings ?? []) { + map.set(binding.channel, binding); + } + return map; + }, [bindingsResponse]); + + const currentBinding = + activeChannel === "PUBLIC_POOL" ? null : (bindingMap.get(activeChannel) ?? null); + const currentMessages = messagesByChannel[activeChannel]; + const publicPool = bindingsResponse?.publicPool ?? null; + + const loadBindings = useCallback( + async (mode: "initial" | "refresh" = "refresh"): Promise => { + if (mode === "initial") { + setLoadingBindings(true); + } else { + setRefreshingBindings(true); + } + + try { + const response = await listAiBindings(session); + setBindingsResponse(response); + setUserKeyForm( + createFormState(response.bindings.find((item) => item.channel === "USER_KEY")) + ); + setAstrbotForm( + createFormState(response.bindings.find((item) => item.channel === "ASTRBOT")) + ); + } catch (error) { + setPanelNotice({ + tone: "error", + message: error instanceof Error ? error.message : "AI ??????" + }); + } finally { + setLoadingBindings(false); + setRefreshingBindings(false); + } + }, + [session] + ); + + useEffect(() => { + void loadBindings("initial"); + }, [loadBindings]); + + useEffect(() => { + messagesEndRef.current?.scrollIntoView({ + block: "end", + behavior: "smooth" + }); + }, [activeChannel, currentMessages.length]); + + useEffect(() => { + if (!panelNotice) { + return; + } + + const timer = window.setTimeout(() => { + setPanelNotice(null); + }, 2800); + + return () => { + window.clearTimeout(timer); + }; + }, [panelNotice]); + + const sendBlockedReason = useMemo(() => { + if (activeChannel === "PUBLIC_POOL") { + if (!publicPool?.enabled) { + return "管理员尚未开放公共 AI。"; + } + + return null; + } + + if (!currentBinding) { + return activeChannel === "USER_KEY" ? "请先保存自备厂商配置。" : "请先保存 AstrBot 配置。"; + } + + if (!currentBinding.isEnabled) { + return "当前渠道已关闭,请先启用后再发起对话。"; + } + + return null; + }, [activeChannel, currentBinding, publicPool]); + + const channelStatusText = useMemo(() => { + if (activeChannel === "PUBLIC_POOL") { + return publicPool?.enabled ? "管理员已开放" : "当前不可用"; + } + + if (!currentBinding) { + return "尚未配置"; + } + + return currentBinding.isEnabled ? "已配置并启用" : "已配置但停用"; + }, [activeChannel, currentBinding, publicPool]); + + async function handleSaveChannel(channel: Exclude): Promise { + const formState = channel === "USER_KEY" ? userKeyForm : astrbotForm; + const binding = bindingMap.get(channel) ?? null; + + try { + setSavingChannel(channel); + await upsertAiBinding(session, buildBindingPayload(channel, formState, binding)); + setPanelNotice({ + tone: "success", + message: channel === "USER_KEY" ? "自备厂商配置已保存。" : "AstrBot 配置已保存。" + }); + if (channel === "USER_KEY") { + setUserKeyForm((current) => ({ + ...current, + apiKey: "" + })); + } else { + setAstrbotForm((current) => ({ + ...current, + apiKey: "" + })); + } + await loadBindings("refresh"); + } catch (error) { + setPanelNotice({ + tone: "error", + message: error instanceof Error ? error.message : "AI 配置保存失败" + }); + } finally { + setSavingChannel(null); + } + } + + async function handleSendMessage(): Promise { + const message = draftMessage.trim(); + if (!message || sendBlockedReason || sending) { + return; + } + + const channel = activeChannel; + setSending(true); + setDraftMessage(""); + setMessagesByChannel((current) => + appendMessage(current, channel, { + id: crypto.randomUUID(), + role: "user", + content: message, + meta: formatTimeLabel() + }) + ); + + try { + const response = await chatWithAi(session, { + channel, + message, + sessionId: sessionIds[channel] + }); + + setSessionIds((current) => ({ + ...current, + [channel]: response.sessionId ?? current[channel] + })); + setMessagesByChannel((current) => + appendMessage(current, channel, { + id: crypto.randomUUID(), + role: "assistant", + content: response.content, + meta: `${CHANNEL_META[response.channel].title} · ${response.providerName}${response.model ? ` · ${response.model}` : ""}` + }) + ); + } catch (error) { + const apiError = + error instanceof WebAiApiError + ? error + : new WebAiApiError(error instanceof Error ? error.message : "AI 请求失败"); + const firstFailedAttempt = apiError.attempts?.find((item) => item.reasonMessage); + const content = + firstFailedAttempt?.reasonMessage && firstFailedAttempt.reasonMessage !== apiError.message + ? `${apiError.message}\n${firstFailedAttempt.reasonMessage}` + : apiError.message; + + setMessagesByChannel((current) => + appendMessage(current, channel, { + id: crypto.randomUUID(), + role: "system", + content, + meta: "调用失败" + }) + ); + } finally { + setSending(false); + } + } + + function renderChannelButton(channel: WebAiChannel) { + const channelMeta = CHANNEL_META[channel]; + const ChannelIcon = channelMeta.icon; + const selected = activeChannel === channel; + const binding = channel === "PUBLIC_POOL" ? null : (bindingMap.get(channel) ?? null); + const enabled = + channel === "PUBLIC_POOL" ? Boolean(publicPool?.enabled) : Boolean(binding?.isEnabled); + const statusLabel = + channel === "PUBLIC_POOL" + ? publicPool?.enabled + ? "可使用" + : "未开放" + : binding + ? enabled + ? "已启用" + : "已停用" + : "未配置"; + + return ( + + ); + } + + function renderNotice() { + if (!panelNotice) { + return null; + } + + return ( +
+ {panelNotice.tone === "success" ? ( + + ) : ( + + )} + {panelNotice.message} +
+ ); + } + + function renderPrivateConfigForm(channel: Exclude) { + const formState = channel === "USER_KEY" ? userKeyForm : astrbotForm; + const setFormState = channel === "USER_KEY" ? setUserKeyForm : setAstrbotForm; + const binding = bindingMap.get(channel) ?? null; + + return ( +
+
+ + + +
+ + + + {channel === "ASTRBOT" ? ( +
+ + + +
+ ) : null} + + + + + +
+

+ {channel === "USER_KEY" + ? "当前自备厂商通道按用户单独保存,适合个人独享密钥。" + : "AstrBot 通道按用户单独保存,可直接复用你在 AstrBot 中维护的模型配置。"} +

+ +
+
+ ); + } + + function renderPublicPoolCard() { + return ( +
+
+
+
+
+ {publicPool?.providerName || "公共 AI"} +
+
+ {publicPool?.model ? `默认模型:${publicPool.model}` : "管理员尚未设置默认模型"} +
+
+ + {publicPool?.enabled ? "可用" : "不可用"} + +
+
+ 公共 AI 由管理后台统一维护,普通用户仅可选择使用,不可查看或修改密钥。 +
+
+
+ ); + } + + return ( +
+
+
+
+
+ + AI 助手 +
+

+ 三路通道,按用户独立配置 +

+

+ 你可以随时切换 AstrBot、自备厂商与公共 AI 进行问答和任务统筹。 +

+
+ +
+
+ +
+ {renderNotice()} + +
+ {CHANNEL_ORDER.map((channel) => renderChannelButton(channel))} +
+ +
+
+
+ + + +
+
+ {CHANNEL_META[activeChannel].title} +
+
{channelStatusText}
+
+
+ +
+
+ + {settingsOpen ? ( +
+ {loadingBindings ? ( +
+ + 正在加载 AI 配置... +
+ ) : activeChannel === "PUBLIC_POOL" ? ( + renderPublicPoolCard() + ) : ( + renderPrivateConfigForm(activeChannel) + )} +
+ ) : null} + +
+
+
对话记录
+
+ 当前渠道:{CHANNEL_META[activeChannel].title} +
+
+ +
+ {currentMessages.length === 0 ? ( +
+
还没有对话记录。
+
+ 发送一句话试试看,例如“帮我根据当前未完成任务安排今天下午的执行顺序”。 +
+
+ ) : ( + currentMessages.map((message) => ( +
+
{message.content}
+ {message.meta ? ( +
+ {message.meta} +
+ ) : null} +
+ )) + )} +
+
+ +
+ {sendBlockedReason ? ( +
+ {sendBlockedReason} +
+ ) : null} + +