diff --git a/apps/api/prisma/schema.prisma b/apps/api/prisma/schema.prisma index 838df48..facef68 100644 --- a/apps/api/prisma/schema.prisma +++ b/apps/api/prisma/schema.prisma @@ -273,6 +273,8 @@ model AiProviderBinding { channel AiChannel providerName String model String? + configId String? + configName String? encryptedApiKey String? endpoint String? isDefault Boolean @default(false) diff --git a/apps/api/src/ai/ai.service.ts b/apps/api/src/ai/ai.service.ts index c4b4825..b5bf24e 100644 --- a/apps/api/src/ai/ai.service.ts +++ b/apps/api/src/ai/ai.service.ts @@ -22,6 +22,8 @@ type AiBindingSummary = { channel: AiChannel; providerName: string; model: string | null; + configId: string | null; + configName: string | null; endpoint: string | null; isDefault: boolean; isEnabled: boolean; @@ -105,6 +107,8 @@ export class AiService { throw new BadRequestException("公共 AI 通道只能由管理员配置"); } + this.validateBindingInput(dto); + const result = await this.prismaService.$transaction(async (tx) => { if (dto.isDefault) { const where: Prisma.AiProviderBindingWhereInput = { @@ -131,8 +135,10 @@ export class AiService { data: { userId, channel: dto.channel, - providerName: dto.providerName.trim(), + providerName: this.normalizeProviderName(dto.providerName), model: this.normalizeOptionalString(dto.model), + configId: this.normalizeOptionalString(dto.configId), + configName: this.normalizeOptionalString(dto.configName), endpoint: this.normalizeOptionalString(dto.endpoint), encryptedApiKey: this.normalizeOptionalString(dto.apiKey), isDefault: dto.isDefault ?? false, @@ -154,8 +160,10 @@ export class AiService { const updateData: Prisma.AiProviderBindingUpdateInput = { channel: dto.channel, - providerName: dto.providerName.trim(), + providerName: this.normalizeProviderName(dto.providerName), model: this.normalizeOptionalString(dto.model), + configId: this.normalizeOptionalString(dto.configId), + configName: this.normalizeOptionalString(dto.configName), isDefault: dto.isDefault ?? existingBinding.isDefault, isEnabled: dto.isEnabled ?? existingBinding.isEnabled }; @@ -342,6 +350,8 @@ export class AiService { sourceId: binding.id, providerName: binding.providerName, model: binding.model, + configId: binding.configId, + configName: binding.configName, endpoint: binding.endpoint, apiKey: binding.encryptedApiKey }; @@ -354,6 +364,8 @@ export class AiService { sourceId: publicPool.id, providerName: publicPool.providerName ?? "public-pool", model: publicPool.model, + configId: null, + configName: null, endpoint: publicPool.endpoint, apiKey: publicPool.encryptedApiKey }; @@ -365,6 +377,8 @@ export class AiService { channel: binding.channel, providerName: binding.providerName, model: binding.model, + configId: binding.configId, + configName: binding.configName, endpoint: binding.endpoint, isDefault: binding.isDefault, isEnabled: binding.isEnabled, @@ -416,6 +430,29 @@ export class AiService { return normalizedValue.length > 0 ? normalizedValue : null; } + private normalizeProviderName(value: string | undefined): string { + return this.normalizeOptionalString(value) ?? ""; + } + + private validateBindingInput(dto: UpsertAiProviderBindingDto): void { + const providerName = this.normalizeOptionalString(dto.providerName); + const configId = this.normalizeOptionalString(dto.configId); + const configName = this.normalizeOptionalString(dto.configName); + + if (dto.channel === AiChannel.ASTRBOT) { + if (!providerName && !configId && !configName) { + throw new BadRequestException( + "AstrBot 通道至少需要 providerName、configId、configName 三者之一" + ); + } + return; + } + + if (!providerName) { + throw new BadRequestException("当前通道必须提供 providerName"); + } + } + private maskSecret(secret: string | null): string | null { if (!secret) { return null; diff --git a/apps/api/src/ai/ai.types.ts b/apps/api/src/ai/ai.types.ts index e576c61..5c52915 100644 --- a/apps/api/src/ai/ai.types.ts +++ b/apps/api/src/ai/ai.types.ts @@ -6,6 +6,8 @@ export type AiResolvedRouteCandidate = { sourceId: string | null; providerName: string; model: string | null; + configId: string | null; + configName: string | null; endpoint: string | null; apiKey: string | null; }; diff --git a/apps/api/src/ai/dto/upsert-ai-provider-binding.dto.ts b/apps/api/src/ai/dto/upsert-ai-provider-binding.dto.ts index 4bffff0..144f86d 100644 --- a/apps/api/src/ai/dto/upsert-ai-provider-binding.dto.ts +++ b/apps/api/src/ai/dto/upsert-ai-provider-binding.dto.ts @@ -10,15 +10,26 @@ export class UpsertAiProviderBindingDto { @IsEnum(AiChannel) channel!: AiChannel; + @IsOptional() @IsString() @MinLength(1) - providerName!: string; + providerName?: string; @IsOptional() @IsString() @MinLength(1) model?: string; + @IsOptional() + @IsString() + @MinLength(1) + configId?: string; + + @IsOptional() + @IsString() + @MinLength(1) + configName?: string; + @IsOptional() @IsUrl( { diff --git a/apps/api/src/ai/providers/astrbot.provider.ts b/apps/api/src/ai/providers/astrbot.provider.ts index a413a3b..a82cb99 100644 --- a/apps/api/src/ai/providers/astrbot.provider.ts +++ b/apps/api/src/ai/providers/astrbot.provider.ts @@ -10,10 +10,13 @@ import { @Injectable() export class AstrbotProvider implements AiChannelExecutor { async execute(candidate: AiResolvedRouteCandidate, input: AiChatInput): Promise { + const routeLabel = + candidate.providerName || candidate.configName || candidate.configId || "astrbot"; + if (!candidate.endpoint) { throw new AiRouteFailureError( candidate.channel, - candidate.providerName, + routeLabel, "MISSING_ENDPOINT", "缺少 AstrBot 服务地址配置" ); @@ -22,7 +25,7 @@ export class AstrbotProvider implements AiChannelExecutor { if (!candidate.apiKey) { throw new AiRouteFailureError( candidate.channel, - candidate.providerName, + routeLabel, "MISSING_API_KEY", "缺少 AstrBot API Key 配置" ); @@ -43,6 +46,8 @@ export class AstrbotProvider implements AiChannelExecutor { session_id: input.sessionId ?? undefined, message: input.message, enable_streaming: false, + config_id: candidate.configId ?? undefined, + config_name: candidate.configName ?? undefined, selected_provider: candidate.providerName || undefined, selected_model: candidate.model ?? undefined }), @@ -51,7 +56,7 @@ export class AstrbotProvider implements AiChannelExecutor { } catch (error) { throw new AiRouteFailureError( candidate.channel, - candidate.providerName, + routeLabel, "UPSTREAM_UNREACHABLE", this.toErrorMessage(error, "AstrBot 服务请求失败") ); @@ -61,7 +66,7 @@ export class AstrbotProvider implements AiChannelExecutor { const rawText = await response.text(); throw new AiRouteFailureError( candidate.channel, - candidate.providerName, + routeLabel, `UPSTREAM_HTTP_${response.status}`, this.extractHttpErrorMessage(rawText, response.status) ); @@ -81,7 +86,7 @@ export class AstrbotProvider implements AiChannelExecutor { if (type === "error") { throw new AiRouteFailureError( candidate.channel, - candidate.providerName, + routeLabel, this.readString(event["code"]) ?? "ASTRBOT_ERROR", this.readString(event["data"]) ?? "AstrBot 返回错误" ); @@ -116,7 +121,7 @@ export class AstrbotProvider implements AiChannelExecutor { if (!content.trim()) { throw new AiRouteFailureError( candidate.channel, - candidate.providerName, + routeLabel, "EMPTY_RESPONSE", "AstrBot 没有返回有效内容" ); @@ -124,7 +129,7 @@ export class AstrbotProvider implements AiChannelExecutor { return { channel: candidate.channel, - providerName: candidate.providerName, + providerName: routeLabel, model: candidate.model, content, sessionId, diff --git a/apps/api/test/ai.spec.ts b/apps/api/test/ai.spec.ts index bbef854..113d037 100644 --- a/apps/api/test/ai.spec.ts +++ b/apps/api/test/ai.spec.ts @@ -5,7 +5,11 @@ import { AiChannel, AiProviderBinding, AiPublicPoolConfig } from "../generated/p import { AiController } from "../src/ai/ai.controller"; import { AiProviderRegistryService } from "../src/ai/ai-provider-registry.service"; import { AiService } from "../src/ai/ai.service"; -import { AiChannelExecutor, AiRouteFailureError } from "../src/ai/ai.types"; +import { + AiChannelExecutor, + AiResolvedRouteCandidate, + AiRouteFailureError +} from "../src/ai/ai.types"; import { PrismaService } from "../src/prisma/prisma.service"; class InMemoryAiPrismaService { @@ -65,6 +69,8 @@ class InMemoryAiPrismaService { channel: AiChannel; providerName: string; model: string | null; + configId: string | null; + configName: string | null; endpoint: string | null; encryptedApiKey: string | null; isDefault: boolean; @@ -78,6 +84,8 @@ class InMemoryAiPrismaService { channel: args.data.channel, providerName: args.data.providerName, model: args.data.model, + configId: args.data.configId, + configName: args.data.configName, encryptedApiKey: args.data.encryptedApiKey, endpoint: args.data.endpoint, isDefault: args.data.isDefault, @@ -189,12 +197,12 @@ class StaticExecutor implements AiChannelExecutor { } ) {} - async execute(candidate: { channel: AiChannel; providerName: string; model: string | null }) { + async execute(candidate: AiResolvedRouteCandidate) { const result = this.resolver(candidate.channel); if (result.code) { throw new AiRouteFailureError( candidate.channel, - candidate.providerName, + candidate.providerName || candidate.configName || candidate.configId || "unknown", result.code, result.message ?? "执行失败" ); @@ -202,7 +210,7 @@ class StaticExecutor implements AiChannelExecutor { return { channel: candidate.channel, - providerName: candidate.providerName, + providerName: candidate.providerName || candidate.configName || candidate.configId || "", model: candidate.model, content: result.content ?? "", sessionId: "session_ai", @@ -273,6 +281,7 @@ describe("AiController (integration)", () => { channel: AiChannel.ASTRBOT, providerName: "astrbot-main", model: "deepseek-chat", + configId: "default", endpoint: "http://127.0.0.1:6185", apiKey: "abk_secret_1234", isDefault: true, @@ -295,6 +304,8 @@ describe("AiController (integration)", () => { channel: AiChannel.ASTRBOT, providerName: "astrbot-main", model: "deepseek-chat", + configId: "default", + configName: null, hasApiKey: true, maskedApiKey: "abk_***34", isDefault: true @@ -308,6 +319,8 @@ describe("AiController (integration)", () => { channel: AiChannel.USER_KEY, providerName: "openai", model: "gpt-4o-mini", + configId: null, + configName: null, encryptedApiKey: "sk-user", endpoint: "https://api.example.com", isDefault: true, @@ -317,8 +330,10 @@ describe("AiController (integration)", () => { id: "binding_astrbot", userId: "user_1", channel: AiChannel.ASTRBOT, - providerName: "astrbot-main", - model: "deepseek-chat", + providerName: "", + model: null, + configId: "default", + configName: null, encryptedApiKey: "abk_astrbot", endpoint: "http://127.0.0.1:6185", isDefault: true, @@ -346,8 +361,8 @@ describe("AiController (integration)", () => { }, { channel: AiChannel.ASTRBOT, - providerName: "astrbot-main", - model: "deepseek-chat", + providerName: "default", + model: null, status: "success", reasonCode: null, reasonMessage: null @@ -355,6 +370,28 @@ describe("AiController (integration)", () => { ]); }); + it("should allow astrbot binding with config id only", async () => { + const response = await request(app.getHttpServer()) + .post("/ai/bindings") + .set("x-user-id", "user_1") + .send({ + channel: AiChannel.ASTRBOT, + configId: "default", + endpoint: "http://127.0.0.1:6185", + apiKey: "abk_secret_1234", + isDefault: true, + isEnabled: true + }) + .expect(201); + + expect(response.body).toMatchObject({ + channel: AiChannel.ASTRBOT, + providerName: "", + configId: "default", + configName: null + }); + }); + it("should return skipped attempts when no channel is available", async () => { const response = await request(app.getHttpServer()) .post("/ai/chat") diff --git a/apps/api/test/astrbot-provider.spec.ts b/apps/api/test/astrbot-provider.spec.ts index 6190c4a..a8fcee8 100644 --- a/apps/api/test/astrbot-provider.spec.ts +++ b/apps/api/test/astrbot-provider.spec.ts @@ -59,6 +59,8 @@ describe("AstrbotProvider", () => { sourceId: "binding_1", providerName: "", model: null, + configId: "default", + configName: null, endpoint: "http://127.0.0.1:6185", apiKey: "abk_test" },