diff --git a/apps/api/src/ai/ai-provider-registry.service.ts b/apps/api/src/ai/ai-provider-registry.service.ts new file mode 100644 index 0000000..54387a9 --- /dev/null +++ b/apps/api/src/ai/ai-provider-registry.service.ts @@ -0,0 +1,28 @@ +import { Injectable } from "@nestjs/common"; +import { AiChannel } from "../../generated/prisma/client"; +import { AstrbotProvider } from "./providers/astrbot.provider"; +import { OpenAiCompatibleProvider } from "./providers/openai-compatible.provider"; +import { AiChannelExecutor } from "./ai.types"; + +@Injectable() +export class AiProviderRegistryService { + private readonly executors = new Map(); + + constructor( + openAiCompatibleProvider: OpenAiCompatibleProvider, + astrbotProvider: AstrbotProvider + ) { + this.executors.set(AiChannel.USER_KEY, openAiCompatibleProvider); + this.executors.set(AiChannel.PUBLIC_POOL, openAiCompatibleProvider); + this.executors.set(AiChannel.ASTRBOT, astrbotProvider); + } + + getExecutor(channel: AiChannel): AiChannelExecutor { + const executor = this.executors.get(channel); + if (!executor) { + throw new Error(`未找到 ${channel} 对应的 AI 通道执行器`); + } + + return executor; + } +} diff --git a/apps/api/src/ai/ai.controller.ts b/apps/api/src/ai/ai.controller.ts new file mode 100644 index 0000000..3ca3ff9 --- /dev/null +++ b/apps/api/src/ai/ai.controller.ts @@ -0,0 +1,41 @@ +import { Body, Controller, Get, Headers, Post, UnauthorizedException } from "@nestjs/common"; +import { AiChatDto } from "./dto/ai-chat.dto"; +import { UpsertAiProviderBindingDto } from "./dto/upsert-ai-provider-binding.dto"; +import { AiChatResponse, AiService, ListAiBindingsResponse } from "./ai.service"; + +@Controller("ai") +export class AiController { + constructor(private readonly aiService: AiService) {} + + @Get("bindings") + async listBindings( + @Headers("x-user-id") userIdHeader: string | string[] | undefined + ): Promise { + return this.aiService.listBindings(this.resolveUserId(userIdHeader)); + } + + @Post("bindings") + async upsertBinding( + @Headers("x-user-id") userIdHeader: string | string[] | undefined, + @Body() body: UpsertAiProviderBindingDto + ) { + return this.aiService.upsertBinding(this.resolveUserId(userIdHeader), body); + } + + @Post("chat") + async chat( + @Headers("x-user-id") userIdHeader: string | string[] | undefined, + @Body() body: AiChatDto + ): Promise { + return this.aiService.chat(this.resolveUserId(userIdHeader), body); + } + + private resolveUserId(userIdHeader: string | string[] | undefined): string { + const userId = Array.isArray(userIdHeader) ? userIdHeader[0] : userIdHeader; + if (!userId) { + throw new UnauthorizedException("缺少用户上下文"); + } + + return userId; + } +} diff --git a/apps/api/src/ai/ai.module.ts b/apps/api/src/ai/ai.module.ts new file mode 100644 index 0000000..2655525 --- /dev/null +++ b/apps/api/src/ai/ai.module.ts @@ -0,0 +1,14 @@ +import { Module } from "@nestjs/common"; +import { PrismaModule } from "../prisma/prisma.module"; +import { AiController } from "./ai.controller"; +import { AiProviderRegistryService } from "./ai-provider-registry.service"; +import { AiService } from "./ai.service"; +import { AstrbotProvider } from "./providers/astrbot.provider"; +import { OpenAiCompatibleProvider } from "./providers/openai-compatible.provider"; + +@Module({ + imports: [PrismaModule], + controllers: [AiController], + providers: [AiService, AiProviderRegistryService, OpenAiCompatibleProvider, AstrbotProvider] +}) +export class AiModule {} diff --git a/apps/api/src/ai/ai.service.ts b/apps/api/src/ai/ai.service.ts new file mode 100644 index 0000000..c4b4825 --- /dev/null +++ b/apps/api/src/ai/ai.service.ts @@ -0,0 +1,430 @@ +import { + BadGatewayException, + BadRequestException, + Injectable, + Logger, + NotFoundException +} from "@nestjs/common"; +import { + AiChannel, + AiProviderBinding, + AiPublicPoolConfig, + Prisma +} from "../../generated/prisma/client"; +import { PrismaService } from "../prisma/prisma.service"; +import { AiProviderRegistryService } from "./ai-provider-registry.service"; +import { AiChatDto } from "./dto/ai-chat.dto"; +import { UpsertAiProviderBindingDto } from "./dto/upsert-ai-provider-binding.dto"; +import { AiResolvedRouteCandidate, AiRouteAttempt, AiRouteFailureError } from "./ai.types"; + +type AiBindingSummary = { + id: string; + channel: AiChannel; + providerName: string; + model: string | null; + endpoint: string | null; + isDefault: boolean; + isEnabled: boolean; + hasApiKey: boolean; + maskedApiKey: string | null; + updatedAt: string; +}; + +type AiRoutePlanEntry = + | { + kind: "candidate"; + candidate: AiResolvedRouteCandidate; + } + | { + kind: "skip"; + attempt: AiRouteAttempt; + }; + +export type ListAiBindingsResponse = { + routeOrder: AiChannel[]; + bindings: AiBindingSummary[]; + publicPool: { + enabled: boolean; + providerName: string | null; + model: string | null; + endpoint: string | null; + hasApiKey: boolean; + } | null; +}; + +export type AiChatResponse = { + channel: AiChannel; + providerName: string; + model: string | null; + content: string; + sessionId: string | null; + attempts: AiRouteAttempt[]; +}; + +@Injectable() +export class AiService { + private readonly logger = new Logger(AiService.name); + + constructor( + private readonly prismaService: PrismaService, + private readonly aiProviderRegistryService: AiProviderRegistryService + ) {} + + async listBindings(userId: string): Promise { + const [bindings, publicPool] = await Promise.all([ + this.prismaService.aiProviderBinding.findMany({ + where: { + userId + }, + orderBy: [{ channel: "asc" }, { isDefault: "desc" }, { updatedAt: "desc" }] + }), + this.prismaService.aiPublicPoolConfig.findFirst({ + orderBy: { + updatedAt: "desc" + } + }) + ]); + + return { + routeOrder: [AiChannel.USER_KEY, AiChannel.ASTRBOT, AiChannel.PUBLIC_POOL], + bindings: bindings.map((binding) => this.serializeBinding(binding)), + publicPool: publicPool + ? { + enabled: publicPool.enabled, + providerName: publicPool.providerName, + model: publicPool.model, + endpoint: publicPool.endpoint, + hasApiKey: Boolean(publicPool.encryptedApiKey) + } + : null + }; + } + + async upsertBinding(userId: string, dto: UpsertAiProviderBindingDto): Promise { + if (dto.channel === AiChannel.PUBLIC_POOL) { + throw new BadRequestException("公共 AI 通道只能由管理员配置"); + } + + const result = await this.prismaService.$transaction(async (tx) => { + if (dto.isDefault) { + const where: Prisma.AiProviderBindingWhereInput = { + userId, + channel: dto.channel + }; + + if (dto.id) { + where.id = { + not: dto.id + }; + } + + await tx.aiProviderBinding.updateMany({ + where, + data: { + isDefault: false + } + }); + } + + if (!dto.id) { + return tx.aiProviderBinding.create({ + data: { + userId, + channel: dto.channel, + providerName: dto.providerName.trim(), + model: this.normalizeOptionalString(dto.model), + endpoint: this.normalizeOptionalString(dto.endpoint), + encryptedApiKey: this.normalizeOptionalString(dto.apiKey), + isDefault: dto.isDefault ?? false, + isEnabled: dto.isEnabled ?? true + } + }); + } + + const existingBinding = await tx.aiProviderBinding.findFirst({ + where: { + id: dto.id, + userId + } + }); + + if (!existingBinding) { + throw new NotFoundException("AI 通道配置不存在"); + } + + const updateData: Prisma.AiProviderBindingUpdateInput = { + channel: dto.channel, + providerName: dto.providerName.trim(), + model: this.normalizeOptionalString(dto.model), + isDefault: dto.isDefault ?? existingBinding.isDefault, + isEnabled: dto.isEnabled ?? existingBinding.isEnabled + }; + + if (dto.endpoint !== undefined) { + updateData.endpoint = this.normalizeOptionalString(dto.endpoint); + } + + if (dto.apiKey !== undefined) { + updateData.encryptedApiKey = this.normalizeOptionalString(dto.apiKey); + } + + return tx.aiProviderBinding.update({ + where: { + id: dto.id + }, + data: updateData + }); + }); + + return this.serializeBinding(result); + } + + async chat(userId: string, dto: AiChatDto): Promise { + const attempts: AiRouteAttempt[] = []; + const plan = await this.buildRoutePlan(userId, dto.bindingId ?? null); + + for (const entry of plan) { + if (entry.kind === "skip") { + attempts.push(entry.attempt); + continue; + } + + const executor = this.aiProviderRegistryService.getExecutor(entry.candidate.channel); + + try { + const result = await executor.execute(entry.candidate, { + userId, + message: dto.message, + sessionId: dto.sessionId ?? null + }); + + attempts.push({ + channel: result.channel, + providerName: result.providerName, + model: result.model, + status: "success", + reasonCode: null, + reasonMessage: null + }); + + return { + channel: result.channel, + providerName: result.providerName, + model: result.model, + content: result.content, + sessionId: result.sessionId, + attempts + }; + } catch (error) { + const failureAttempt = this.toFailureAttempt(entry.candidate, error); + attempts.push(failureAttempt); + this.logger.warn( + `AI 通道降级:channel=${failureAttempt.channel} provider=${failureAttempt.providerName ?? "unknown"} code=${failureAttempt.reasonCode ?? "UNKNOWN"} message=${failureAttempt.reasonMessage ?? "unknown"}` + ); + } + } + + throw new BadGatewayException({ + message: "当前没有可用的 AI 通道,请稍后重试", + attempts + }); + } + + private async buildRoutePlan( + userId: string, + bindingId: string | null + ): Promise { + const plan: AiRoutePlanEntry[] = []; + const consumedChannels = new Set(); + + if (bindingId) { + const pinnedBinding = await this.prismaService.aiProviderBinding.findFirst({ + where: { + id: bindingId, + userId, + isEnabled: true + } + }); + + if (!pinnedBinding) { + throw new NotFoundException("指定的 AI 通道配置不存在或已禁用"); + } + + plan.push({ + kind: "candidate", + candidate: this.toBindingCandidate(pinnedBinding) + }); + consumedChannels.add(pinnedBinding.channel); + } + + for (const channel of [AiChannel.USER_KEY, AiChannel.ASTRBOT]) { + if (consumedChannels.has(channel)) { + continue; + } + + const binding = await this.findPreferredBinding(userId, channel); + if (binding) { + plan.push({ + kind: "candidate", + candidate: this.toBindingCandidate(binding) + }); + continue; + } + + plan.push({ + kind: "skip", + attempt: { + channel, + providerName: null, + model: null, + status: "skipped", + reasonCode: "CHANNEL_NOT_CONFIGURED", + reasonMessage: + channel === AiChannel.USER_KEY + ? "当前用户未配置可用的自备 Key 通道" + : "当前用户未配置可用的 AstrBot 通道" + } + }); + } + + const publicPool = await this.findEnabledPublicPool(); + if (publicPool) { + plan.push({ + kind: "candidate", + candidate: this.toPublicPoolCandidate(publicPool) + }); + } else { + plan.push({ + kind: "skip", + attempt: { + channel: AiChannel.PUBLIC_POOL, + providerName: null, + model: null, + status: "skipped", + reasonCode: "PUBLIC_POOL_DISABLED", + reasonMessage: "公共 AI 通道未开启" + } + }); + } + + return plan; + } + + private async findPreferredBinding( + userId: string, + channel: AiChannel + ): Promise { + return this.prismaService.aiProviderBinding.findFirst({ + where: { + userId, + channel, + isEnabled: true + }, + orderBy: [{ isDefault: "desc" }, { updatedAt: "desc" }] + }); + } + + private async findEnabledPublicPool(): Promise { + return this.prismaService.aiPublicPoolConfig.findFirst({ + where: { + enabled: true + }, + orderBy: { + updatedAt: "desc" + } + }); + } + + private toBindingCandidate(binding: AiProviderBinding): AiResolvedRouteCandidate { + return { + channel: binding.channel, + source: "binding", + sourceId: binding.id, + providerName: binding.providerName, + model: binding.model, + endpoint: binding.endpoint, + apiKey: binding.encryptedApiKey + }; + } + + private toPublicPoolCandidate(publicPool: AiPublicPoolConfig): AiResolvedRouteCandidate { + return { + channel: AiChannel.PUBLIC_POOL, + source: "public_pool", + sourceId: publicPool.id, + providerName: publicPool.providerName ?? "public-pool", + model: publicPool.model, + endpoint: publicPool.endpoint, + apiKey: publicPool.encryptedApiKey + }; + } + + private serializeBinding(binding: AiProviderBinding): AiBindingSummary { + return { + id: binding.id, + channel: binding.channel, + providerName: binding.providerName, + model: binding.model, + endpoint: binding.endpoint, + isDefault: binding.isDefault, + isEnabled: binding.isEnabled, + hasApiKey: Boolean(binding.encryptedApiKey), + maskedApiKey: this.maskSecret(binding.encryptedApiKey), + updatedAt: binding.updatedAt.toISOString() + }; + } + + private toFailureAttempt(candidate: AiResolvedRouteCandidate, error: unknown): AiRouteAttempt { + if (error instanceof AiRouteFailureError) { + return { + channel: error.channel, + providerName: error.providerName, + model: candidate.model, + status: "failed", + reasonCode: error.code, + reasonMessage: error.message + }; + } + + if (error instanceof Error) { + return { + channel: candidate.channel, + providerName: candidate.providerName, + model: candidate.model, + status: "failed", + reasonCode: "UNKNOWN_ERROR", + reasonMessage: error.message + }; + } + + return { + channel: candidate.channel, + providerName: candidate.providerName, + model: candidate.model, + status: "failed", + reasonCode: "UNKNOWN_ERROR", + reasonMessage: "未知错误" + }; + } + + private normalizeOptionalString(value: string | undefined): string | null { + if (value === undefined) { + return null; + } + + const normalizedValue = value.trim(); + return normalizedValue.length > 0 ? normalizedValue : null; + } + + private maskSecret(secret: string | null): string | null { + if (!secret) { + return null; + } + + if (secret.length <= 6) { + return "*".repeat(secret.length); + } + + return `${secret.slice(0, 4)}***${secret.slice(-2)}`; + } +} diff --git a/apps/api/src/ai/ai.types.ts b/apps/api/src/ai/ai.types.ts new file mode 100644 index 0000000..e576c61 --- /dev/null +++ b/apps/api/src/ai/ai.types.ts @@ -0,0 +1,52 @@ +import { AiChannel } from "../../generated/prisma/client"; + +export type AiResolvedRouteCandidate = { + channel: AiChannel; + source: "binding" | "public_pool"; + sourceId: string | null; + providerName: string; + model: string | null; + endpoint: string | null; + apiKey: string | null; +}; + +export type AiChatInput = { + userId: string; + message: string; + sessionId: string | null; +}; + +export type AiChatResult = { + channel: AiChannel; + providerName: string; + model: string | null; + content: string; + sessionId: string | null; + raw: unknown; +}; + +export type AiRouteAttempt = { + channel: AiChannel; + providerName: string | null; + model: string | null; + status: "skipped" | "failed" | "success"; + reasonCode: string | null; + reasonMessage: string | null; +}; + +export class AiRouteFailureError extends Error { + constructor( + public readonly channel: AiChannel, + public readonly providerName: string, + public readonly code: string, + message: string + ) { + super(message); + this.name = "AiRouteFailureError"; + Object.setPrototypeOf(this, new.target.prototype); + } +} + +export interface AiChannelExecutor { + execute(candidate: AiResolvedRouteCandidate, input: AiChatInput): Promise; +} diff --git a/apps/api/src/ai/dto/ai-chat.dto.ts b/apps/api/src/ai/dto/ai-chat.dto.ts new file mode 100644 index 0000000..a89692a --- /dev/null +++ b/apps/api/src/ai/dto/ai-chat.dto.ts @@ -0,0 +1,17 @@ +import { IsOptional, IsString, MinLength } from "class-validator"; + +export class AiChatDto { + @IsString() + @MinLength(1) + message!: string; + + @IsOptional() + @IsString() + @MinLength(1) + sessionId?: string; + + @IsOptional() + @IsString() + @MinLength(1) + bindingId?: string; +} diff --git a/apps/api/src/ai/dto/upsert-ai-provider-binding.dto.ts b/apps/api/src/ai/dto/upsert-ai-provider-binding.dto.ts new file mode 100644 index 0000000..4bffff0 --- /dev/null +++ b/apps/api/src/ai/dto/upsert-ai-provider-binding.dto.ts @@ -0,0 +1,45 @@ +import { AiChannel } from "../../../generated/prisma/client"; +import { IsBoolean, IsEnum, IsOptional, IsString, IsUrl, MinLength } from "class-validator"; + +export class UpsertAiProviderBindingDto { + @IsOptional() + @IsString() + @MinLength(1) + id?: string; + + @IsEnum(AiChannel) + channel!: AiChannel; + + @IsString() + @MinLength(1) + providerName!: string; + + @IsOptional() + @IsString() + @MinLength(1) + model?: string; + + @IsOptional() + @IsUrl( + { + require_tld: false + }, + { + message: "endpoint 必须是合法的 URL" + } + ) + endpoint?: string; + + @IsOptional() + @IsString() + @MinLength(1) + apiKey?: string; + + @IsOptional() + @IsBoolean() + isDefault?: boolean; + + @IsOptional() + @IsBoolean() + isEnabled?: boolean; +} diff --git a/apps/api/src/ai/providers/astrbot.provider.ts b/apps/api/src/ai/providers/astrbot.provider.ts new file mode 100644 index 0000000..419139d --- /dev/null +++ b/apps/api/src/ai/providers/astrbot.provider.ts @@ -0,0 +1,197 @@ +import { Injectable } from "@nestjs/common"; +import { + AiChannelExecutor, + AiChatInput, + AiChatResult, + AiResolvedRouteCandidate, + AiRouteFailureError +} from "../ai.types"; + +@Injectable() +export class AstrbotProvider implements AiChannelExecutor { + async execute(candidate: AiResolvedRouteCandidate, input: AiChatInput): Promise { + if (!candidate.endpoint) { + throw new AiRouteFailureError( + candidate.channel, + candidate.providerName, + "MISSING_ENDPOINT", + "缺少 AstrBot 服务地址配置" + ); + } + + if (!candidate.apiKey) { + throw new AiRouteFailureError( + candidate.channel, + candidate.providerName, + "MISSING_API_KEY", + "缺少 AstrBot API Key 配置" + ); + } + + const requestUrl = this.buildRequestUrl(candidate.endpoint); + + let response: Response; + try { + response = await fetch(requestUrl, { + method: "POST", + headers: { + "Content-Type": "application/json", + Authorization: `Bearer ${candidate.apiKey}` + }, + body: JSON.stringify({ + username: input.userId, + session_id: input.sessionId ?? undefined, + message: input.message, + enable_streaming: false, + selected_provider: candidate.providerName || undefined, + selected_model: candidate.model ?? undefined + }), + signal: AbortSignal.timeout(30000) + }); + } catch (error) { + throw new AiRouteFailureError( + candidate.channel, + candidate.providerName, + "UPSTREAM_UNREACHABLE", + this.toErrorMessage(error, "AstrBot 服务请求失败") + ); + } + + const rawText = await response.text(); + if (!response.ok) { + throw new AiRouteFailureError( + candidate.channel, + candidate.providerName, + `UPSTREAM_HTTP_${response.status}`, + this.extractHttpErrorMessage(rawText, response.status) + ); + } + + const events = this.parseSseEvents(rawText); + let content = ""; + let sessionId = input.sessionId; + + for (const event of events) { + const type = this.readString(event["type"]); + if (type === "session_id") { + sessionId = this.readString(event["session_id"]) ?? sessionId; + continue; + } + + if (type === "error") { + throw new AiRouteFailureError( + candidate.channel, + candidate.providerName, + this.readString(event["code"]) ?? "ASTRBOT_ERROR", + this.readString(event["data"]) ?? "AstrBot 返回错误" + ); + } + + if (type !== "plain") { + continue; + } + + const chainType = this.readString(event["chain_type"]); + if ( + chainType === "reasoning" || + chainType === "tool_call" || + chainType === "tool_call_result" + ) { + continue; + } + + const data = this.readString(event["data"]); + if (!data) { + continue; + } + + if (event["streaming"] === true) { + content += data; + continue; + } + + content = data; + } + + if (!content.trim()) { + throw new AiRouteFailureError( + candidate.channel, + candidate.providerName, + "EMPTY_RESPONSE", + "AstrBot 没有返回有效内容" + ); + } + + return { + channel: candidate.channel, + providerName: candidate.providerName, + model: candidate.model, + content, + sessionId, + raw: events + }; + } + + private buildRequestUrl(endpoint: string): string { + const normalizedEndpoint = endpoint.replace(/\/+$/, ""); + if (normalizedEndpoint.endsWith("/api/v1/chat")) { + return normalizedEndpoint; + } + if (normalizedEndpoint.endsWith("/api/v1")) { + return `${normalizedEndpoint}/chat`; + } + if (normalizedEndpoint.endsWith("/api")) { + return `${normalizedEndpoint}/v1/chat`; + } + return `${normalizedEndpoint}/api/v1/chat`; + } + + private parseSseEvents(rawText: string): Array> { + return rawText + .split(/\r?\n\r?\n/) + .map((block) => + block + .split(/\r?\n/) + .filter((line) => line.startsWith("data:")) + .map((line) => line.slice(5).trim()) + .join("\n") + ) + .filter((payload) => payload.length > 0) + .map((payload) => { + try { + return JSON.parse(payload) as Record; + } catch { + return null; + } + }) + .filter((item): item is Record => item !== null); + } + + private extractHttpErrorMessage(rawText: string, statusCode: number): string { + try { + const payload = JSON.parse(rawText) as Record; + if (typeof payload["message"] === "string") { + return payload["message"]; + } + if (typeof payload["data"] === "string") { + return payload["data"]; + } + } catch { + return `AstrBot 服务调用失败,状态码 ${statusCode}`; + } + + return `AstrBot 服务调用失败,状态码 ${statusCode}`; + } + + private readString(value: unknown): string | null { + return typeof value === "string" ? value : null; + } + + private toErrorMessage(error: unknown, fallback: string): string { + if (error instanceof Error && error.message) { + return error.message; + } + + return fallback; + } +} diff --git a/apps/api/src/ai/providers/openai-compatible.provider.ts b/apps/api/src/ai/providers/openai-compatible.provider.ts new file mode 100644 index 0000000..2ca1723 --- /dev/null +++ b/apps/api/src/ai/providers/openai-compatible.provider.ts @@ -0,0 +1,203 @@ +import { Injectable } from "@nestjs/common"; +import { + AiChannelExecutor, + AiChatInput, + AiChatResult, + AiResolvedRouteCandidate, + AiRouteFailureError +} from "../ai.types"; + +@Injectable() +export class OpenAiCompatibleProvider implements AiChannelExecutor { + async execute(candidate: AiResolvedRouteCandidate, input: AiChatInput): Promise { + if (!candidate.endpoint) { + throw new AiRouteFailureError( + candidate.channel, + candidate.providerName, + "MISSING_ENDPOINT", + "缺少 AI 服务地址配置" + ); + } + + if (!candidate.apiKey) { + throw new AiRouteFailureError( + candidate.channel, + candidate.providerName, + "MISSING_API_KEY", + "缺少 AI 服务密钥配置" + ); + } + + if (!candidate.model) { + throw new AiRouteFailureError( + candidate.channel, + candidate.providerName, + "MISSING_MODEL", + "缺少 AI 模型配置" + ); + } + + const requestUrl = this.buildRequestUrl(candidate.endpoint); + + let response: Response; + try { + response = await fetch(requestUrl, { + method: "POST", + headers: { + "Content-Type": "application/json", + Authorization: `Bearer ${candidate.apiKey}` + }, + body: JSON.stringify({ + model: candidate.model, + messages: [ + { + role: "user", + content: input.message + } + ], + stream: false + }), + signal: AbortSignal.timeout(30000) + }); + } catch (error) { + throw new AiRouteFailureError( + candidate.channel, + candidate.providerName, + "UPSTREAM_UNREACHABLE", + this.toErrorMessage(error, "AI 服务请求失败") + ); + } + + let payload: unknown; + try { + payload = await response.json(); + } catch (error) { + throw new AiRouteFailureError( + candidate.channel, + candidate.providerName, + "INVALID_RESPONSE", + this.toErrorMessage(error, "AI 服务返回了无法解析的数据") + ); + } + + if (!response.ok) { + throw new AiRouteFailureError( + candidate.channel, + candidate.providerName, + `UPSTREAM_HTTP_${response.status}`, + this.extractErrorMessage(payload, `AI 服务调用失败,状态码 ${response.status}`) + ); + } + + const content = this.extractAssistantText(payload); + if (!content.trim()) { + throw new AiRouteFailureError( + candidate.channel, + candidate.providerName, + "EMPTY_RESPONSE", + "AI 服务没有返回有效内容" + ); + } + + return { + channel: candidate.channel, + providerName: candidate.providerName, + model: this.extractModel(payload) ?? candidate.model, + content, + sessionId: input.sessionId, + raw: payload + }; + } + + private buildRequestUrl(endpoint: string): string { + const normalizedEndpoint = endpoint.replace(/\/+$/, ""); + if (normalizedEndpoint.endsWith("/chat/completions")) { + return normalizedEndpoint; + } + if (normalizedEndpoint.endsWith("/v1")) { + return `${normalizedEndpoint}/chat/completions`; + } + return `${normalizedEndpoint}/v1/chat/completions`; + } + + private extractAssistantText(payload: unknown): string { + if (!this.isRecord(payload)) { + return ""; + } + + const choices = payload["choices"]; + if (!Array.isArray(choices) || choices.length === 0) { + return ""; + } + + const firstChoice = choices[0]; + if (!this.isRecord(firstChoice)) { + return ""; + } + + const message = firstChoice["message"]; + if (!this.isRecord(message)) { + return ""; + } + + return this.extractMessageContent(message["content"]); + } + + private extractMessageContent(content: unknown): string { + if (typeof content === "string") { + return content; + } + + if (!Array.isArray(content)) { + return ""; + } + + return content + .map((item) => { + if (!this.isRecord(item)) { + return ""; + } + + if (typeof item["text"] === "string") { + return item["text"]; + } + + return ""; + }) + .filter((item) => item.length > 0) + .join("\n"); + } + + private extractModel(payload: unknown): string | null { + if (!this.isRecord(payload) || typeof payload["model"] !== "string") { + return null; + } + + return payload["model"]; + } + + private extractErrorMessage(payload: unknown, fallback: string): string { + if (!this.isRecord(payload)) { + return fallback; + } + + const error = payload["error"]; + if (!this.isRecord(error) || typeof error["message"] !== "string") { + return fallback; + } + + return error["message"]; + } + + private isRecord(value: unknown): value is Record { + return typeof value === "object" && value !== null; + } + + private toErrorMessage(error: unknown, fallback: string): string { + if (error instanceof Error && error.message) { + return error.message; + } + + return fallback; + } +} diff --git a/apps/api/src/app.module.ts b/apps/api/src/app.module.ts index 8c2a78f..db0e7c4 100644 --- a/apps/api/src/app.module.ts +++ b/apps/api/src/app.module.ts @@ -1,5 +1,6 @@ import { Module } from "@nestjs/common"; import { ConfigModule } from "@nestjs/config"; +import { AiModule } from "./ai/ai.module"; import { AttachmentModule } from "./attachment/attachment.module"; import { AuthModule } from "./auth/auth.module"; import { PrismaModule } from "./prisma/prisma.module"; @@ -16,7 +17,8 @@ import { TaskModule } from "./task/task.module"; AuthModule, TaskModule, AttachmentModule, - SyncModule + SyncModule, + AiModule ] }) export class AppModule {} diff --git a/apps/api/test/ai.spec.ts b/apps/api/test/ai.spec.ts new file mode 100644 index 0000000..bbef854 --- /dev/null +++ b/apps/api/test/ai.spec.ts @@ -0,0 +1,395 @@ +import request from "supertest"; +import { INestApplication, ValidationPipe } from "@nestjs/common"; +import { Test, TestingModule } from "@nestjs/testing"; +import { AiChannel, AiProviderBinding, AiPublicPoolConfig } from "../generated/prisma/client"; +import { AiController } from "../src/ai/ai.controller"; +import { AiProviderRegistryService } from "../src/ai/ai-provider-registry.service"; +import { AiService } from "../src/ai/ai.service"; +import { AiChannelExecutor, AiRouteFailureError } from "../src/ai/ai.types"; +import { PrismaService } from "../src/prisma/prisma.service"; + +class InMemoryAiPrismaService { + private bindingIdSequence = 1; + private publicPoolIdSequence = 1; + private bindings: AiProviderBinding[] = []; + private publicPools: AiPublicPoolConfig[] = []; + + readonly aiProviderBinding = { + findMany: async (args: { + where: { + userId: string; + }; + }) => { + return this.bindings + .filter((binding) => binding.userId === args.where.userId) + .sort((left, right) => right.updatedAt.getTime() - left.updatedAt.getTime()); + }, + + findFirst: async (args: { + where: { + id?: string; + userId?: string; + channel?: AiChannel; + isEnabled?: boolean; + }; + }) => { + return ( + this.bindings + .filter((binding) => { + if (args.where.id !== undefined && binding.id !== args.where.id) { + return false; + } + if (args.where.userId !== undefined && binding.userId !== args.where.userId) { + return false; + } + if (args.where.channel !== undefined && binding.channel !== args.where.channel) { + return false; + } + if (args.where.isEnabled !== undefined && binding.isEnabled !== args.where.isEnabled) { + return false; + } + return true; + }) + .sort((left, right) => { + if (left.isDefault !== right.isDefault) { + return Number(right.isDefault) - Number(left.isDefault); + } + return right.updatedAt.getTime() - left.updatedAt.getTime(); + })[0] ?? null + ); + }, + + create: async (args: { + data: { + userId: string; + channel: AiChannel; + providerName: string; + model: string | null; + endpoint: string | null; + encryptedApiKey: string | null; + isDefault: boolean; + isEnabled: boolean; + }; + }) => { + const now = new Date(); + const binding: AiProviderBinding = { + id: `binding_${this.bindingIdSequence++}`, + userId: args.data.userId, + channel: args.data.channel, + providerName: args.data.providerName, + model: args.data.model, + encryptedApiKey: args.data.encryptedApiKey, + endpoint: args.data.endpoint, + isDefault: args.data.isDefault, + isEnabled: args.data.isEnabled, + createdAt: now, + updatedAt: now + }; + + this.bindings.push(binding); + return binding; + }, + + update: async (args: { + where: { + id: string; + }; + data: Partial; + }) => { + const binding = this.bindings.find((item) => item.id === args.where.id); + if (!binding) { + throw new Error("binding not found"); + } + + Object.assign(binding, args.data, { updatedAt: new Date() }); + return binding; + }, + + updateMany: async (args: { + where: { + userId?: string; + channel?: AiChannel; + id?: { + not: string; + }; + }; + data: { + isDefault?: boolean; + }; + }) => { + let count = 0; + for (const binding of this.bindings) { + if (args.where.userId !== undefined && binding.userId !== args.where.userId) { + continue; + } + if (args.where.channel !== undefined && binding.channel !== args.where.channel) { + continue; + } + if (args.where.id?.not !== undefined && binding.id === args.where.id.not) { + continue; + } + + if (args.data.isDefault !== undefined) { + binding.isDefault = args.data.isDefault; + binding.updatedAt = new Date(); + } + count += 1; + } + + return { count }; + } + }; + + readonly aiPublicPoolConfig = { + findFirst: async (args?: { + where?: { + enabled?: boolean; + }; + }) => { + const items = this.publicPools + .filter((item) => + args?.where?.enabled === undefined ? true : item.enabled === args.where.enabled + ) + .sort((left, right) => right.updatedAt.getTime() - left.updatedAt.getTime()); + + return items[0] ?? null; + } + }; + + async $transaction(callback: (tx: InMemoryAiPrismaService) => Promise): Promise { + return callback(this); + } + + seedBinding(binding: Omit): void { + const now = new Date(); + this.bindings.push({ + ...binding, + createdAt: now, + updatedAt: now + }); + } + + seedPublicPool(publicPool: Omit): void { + const now = new Date(); + this.publicPools.push({ + id: `pool_${this.publicPoolIdSequence++}`, + createdAt: now, + updatedAt: now, + ...publicPool + }); + } +} + +class StaticExecutor implements AiChannelExecutor { + constructor( + private readonly resolver: (channel: AiChannel) => { + content?: string; + code?: string; + message?: string; + } + ) {} + + async execute(candidate: { channel: AiChannel; providerName: string; model: string | null }) { + const result = this.resolver(candidate.channel); + if (result.code) { + throw new AiRouteFailureError( + candidate.channel, + candidate.providerName, + result.code, + result.message ?? "执行失败" + ); + } + + return { + channel: candidate.channel, + providerName: candidate.providerName, + model: candidate.model, + content: result.content ?? "", + sessionId: "session_ai", + raw: null + }; + } +} + +describe("AiController (integration)", () => { + let app: INestApplication; + let prismaService: InMemoryAiPrismaService; + + beforeEach(async () => { + prismaService = new InMemoryAiPrismaService(); + + const openAiExecutor = new StaticExecutor((channel) => + channel === AiChannel.USER_KEY + ? { + code: "UPSTREAM_UNREACHABLE", + message: "用户自备 Key 渠道暂时不可用" + } + : { + content: "公共 AI 已接管" + } + ); + const astrbotExecutor = new StaticExecutor(() => ({ + content: "AstrBot 已接管" + })); + + const moduleRef: TestingModule = await Test.createTestingModule({ + controllers: [AiController], + providers: [ + AiService, + { + provide: PrismaService, + useValue: prismaService + }, + { + provide: AiProviderRegistryService, + useValue: { + getExecutor: (channel: AiChannel) => + channel === AiChannel.ASTRBOT ? astrbotExecutor : openAiExecutor + } + } + ] + }).compile(); + + app = moduleRef.createNestApplication(); + app.useGlobalPipes( + new ValidationPipe({ + transform: true, + whitelist: true, + forbidNonWhitelisted: true + }) + ); + await app.init(); + }); + + afterEach(async () => { + await app.close(); + }); + + it("should create and list ai bindings", async () => { + await request(app.getHttpServer()) + .post("/ai/bindings") + .set("x-user-id", "user_1") + .send({ + channel: AiChannel.ASTRBOT, + providerName: "astrbot-main", + model: "deepseek-chat", + endpoint: "http://127.0.0.1:6185", + apiKey: "abk_secret_1234", + isDefault: true, + isEnabled: true + }) + .expect(201); + + const response = await request(app.getHttpServer()) + .get("/ai/bindings") + .set("x-user-id", "user_1") + .expect(200); + + expect(response.body.routeOrder).toEqual([ + AiChannel.USER_KEY, + AiChannel.ASTRBOT, + AiChannel.PUBLIC_POOL + ]); + expect(response.body.bindings).toHaveLength(1); + expect(response.body.bindings[0]).toMatchObject({ + channel: AiChannel.ASTRBOT, + providerName: "astrbot-main", + model: "deepseek-chat", + hasApiKey: true, + maskedApiKey: "abk_***34", + isDefault: true + }); + }); + + it("should fallback from user key to astrbot", async () => { + prismaService.seedBinding({ + id: "binding_user_key", + userId: "user_1", + channel: AiChannel.USER_KEY, + providerName: "openai", + model: "gpt-4o-mini", + encryptedApiKey: "sk-user", + endpoint: "https://api.example.com", + isDefault: true, + isEnabled: true + }); + prismaService.seedBinding({ + id: "binding_astrbot", + userId: "user_1", + channel: AiChannel.ASTRBOT, + providerName: "astrbot-main", + model: "deepseek-chat", + encryptedApiKey: "abk_astrbot", + endpoint: "http://127.0.0.1:6185", + isDefault: true, + isEnabled: true + }); + + const response = await request(app.getHttpServer()) + .post("/ai/chat") + .set("x-user-id", "user_1") + .send({ + message: "帮我安排今天的任务" + }) + .expect(201); + + expect(response.body.channel).toBe(AiChannel.ASTRBOT); + expect(response.body.content).toBe("AstrBot 已接管"); + expect(response.body.attempts).toEqual([ + { + channel: AiChannel.USER_KEY, + providerName: "openai", + model: "gpt-4o-mini", + status: "failed", + reasonCode: "UPSTREAM_UNREACHABLE", + reasonMessage: "用户自备 Key 渠道暂时不可用" + }, + { + channel: AiChannel.ASTRBOT, + providerName: "astrbot-main", + model: "deepseek-chat", + status: "success", + reasonCode: null, + reasonMessage: null + } + ]); + }); + + it("should return skipped attempts when no channel is available", async () => { + const response = await request(app.getHttpServer()) + .post("/ai/chat") + .set("x-user-id", "user_1") + .send({ + message: "帮我总结今天的安排" + }) + .expect(502); + + expect(response.body.message).toBe("当前没有可用的 AI 通道,请稍后重试"); + expect(response.body.attempts).toEqual([ + { + channel: AiChannel.USER_KEY, + providerName: null, + model: null, + status: "skipped", + reasonCode: "CHANNEL_NOT_CONFIGURED", + reasonMessage: "当前用户未配置可用的自备 Key 通道" + }, + { + channel: AiChannel.ASTRBOT, + providerName: null, + model: null, + status: "skipped", + reasonCode: "CHANNEL_NOT_CONFIGURED", + reasonMessage: "当前用户未配置可用的 AstrBot 通道" + }, + { + channel: AiChannel.PUBLIC_POOL, + providerName: null, + model: null, + status: "skipped", + reasonCode: "PUBLIC_POOL_DISABLED", + reasonMessage: "公共 AI 通道未开启" + } + ]); + }); +});