From d0ba58118450467678fe2423f69db4bd5b2b7354 Mon Sep 17 00:00:00 2001 From: Yaosanqi137 Date: Mon, 6 Apr 2026 13:36:28 +0800 Subject: [PATCH] feat(api-ai): scope private bindings by user channel --- apps/api/src/ai/ai.service.ts | 142 +++++++----------- apps/api/src/ai/dto/ai-chat.dto.ts | 8 +- .../ai/dto/upsert-ai-provider-binding.dto.ts | 13 +- apps/api/test/ai.spec.ts | 100 +++++++++++- 4 files changed, 157 insertions(+), 106 deletions(-) diff --git a/apps/api/src/ai/ai.service.ts b/apps/api/src/ai/ai.service.ts index dfbbd9a..8dae914 100644 --- a/apps/api/src/ai/ai.service.ts +++ b/apps/api/src/ai/ai.service.ts @@ -1,10 +1,4 @@ -import { - BadGatewayException, - BadRequestException, - Injectable, - Logger, - NotFoundException -} from "@nestjs/common"; +import { BadGatewayException, BadRequestException, Injectable, Logger } from "@nestjs/common"; import { AiChannel, AiUsageLog, @@ -34,7 +28,6 @@ type AiBindingSummary = { configId: string | null; configName: string | null; endpoint: string | null; - isDefault: boolean; isEnabled: boolean; hasApiKey: boolean; maskedApiKey: string | null; @@ -110,7 +103,7 @@ export class AiService { where: { userId }, - orderBy: [{ channel: "asc" }, { isDefault: "desc" }, { updatedAt: "desc" }] + orderBy: [{ updatedAt: "desc" }] }), this.prismaService.aiPublicPoolConfig.findFirst({ orderBy: { @@ -119,9 +112,11 @@ export class AiService { }) ]); + const latestBindings = this.pickLatestBindingsByChannel(bindings); + return { routeOrder: [AiChannel.USER_KEY, AiChannel.ASTRBOT, AiChannel.PUBLIC_POOL], - bindings: bindings.map((binding) => this.serializeBinding(binding)), + bindings: latestBindings.map((binding) => this.serializeBinding(binding)), publicPool: publicPool ? { enabled: publicPool.enabled, @@ -183,27 +178,17 @@ export class AiService { this.validateBindingInput(dto); const result = await this.prismaService.$transaction(async (tx) => { - if (dto.isDefault) { - const where: Prisma.AiProviderBindingWhereInput = { + const existingBinding = await tx.aiProviderBinding.findFirst({ + where: { userId, channel: dto.channel - }; - - if (dto.id) { - where.id = { - not: dto.id - }; + }, + orderBy: { + updatedAt: "desc" } + }); - await tx.aiProviderBinding.updateMany({ - where, - data: { - isDefault: false - } - }); - } - - if (!dto.id) { + if (!existingBinding) { return tx.aiProviderBinding.create({ data: { userId, @@ -214,30 +199,17 @@ export class AiService { configName: this.normalizeOptionalString(dto.configName), endpoint: this.normalizeOptionalString(dto.endpoint), encryptedApiKey: this.normalizeOptionalString(dto.apiKey), - isDefault: dto.isDefault ?? false, isEnabled: dto.isEnabled ?? true } }); } - const existingBinding = await tx.aiProviderBinding.findFirst({ - where: { - id: dto.id, - userId - } - }); - - if (!existingBinding) { - throw new NotFoundException("AI 通道配置不存在"); - } - const updateData: Prisma.AiProviderBindingUpdateInput = { channel: dto.channel, providerName: this.normalizeProviderName(dto.providerName), model: this.normalizeOptionalString(dto.model), configId: this.normalizeOptionalString(dto.configId), configName: this.normalizeOptionalString(dto.configName), - isDefault: dto.isDefault ?? existingBinding.isDefault, isEnabled: dto.isEnabled ?? existingBinding.isEnabled }; @@ -251,7 +223,7 @@ export class AiService { return tx.aiProviderBinding.update({ where: { - id: dto.id + id: existingBinding.id }, data: updateData }); @@ -262,7 +234,7 @@ export class AiService { async chat(userId: string, dto: AiChatDto): Promise { const attempts: AiRouteAttempt[] = []; - const plan = await this.buildRoutePlan(userId, dto.bindingId ?? null); + const plan = await this.buildRoutePlan(userId, dto.channel ?? null); const promptMessage = await this.buildPromptMessage(userId, dto.message); for (const entry of plan) { @@ -337,33 +309,34 @@ export class AiService { private async buildRoutePlan( userId: string, - bindingId: string | null + selectedChannel: AiChannel | null ): Promise { const plan: AiRoutePlanEntry[] = []; - const consumedChannels = new Set(); + const targetChannels = selectedChannel + ? [selectedChannel] + : [AiChannel.USER_KEY, AiChannel.ASTRBOT, AiChannel.PUBLIC_POOL]; - if (bindingId) { - const pinnedBinding = await this.prismaService.aiProviderBinding.findFirst({ - where: { - id: bindingId, - userId, - isEnabled: true + for (const channel of targetChannels) { + if (channel === AiChannel.PUBLIC_POOL) { + const publicPool = await this.findEnabledPublicPool(); + if (publicPool) { + plan.push({ + kind: "candidate", + candidate: this.toPublicPoolCandidate(publicPool) + }); + } else { + plan.push({ + kind: "skip", + attempt: { + channel: AiChannel.PUBLIC_POOL, + providerName: null, + model: null, + status: "skipped", + reasonCode: "PUBLIC_POOL_DISABLED", + reasonMessage: "公共 AI 通道未开启" + } + }); } - }); - - if (!pinnedBinding) { - throw new NotFoundException("指定的 AI 通道配置不存在或已禁用"); - } - - plan.push({ - kind: "candidate", - candidate: this.toBindingCandidate(pinnedBinding) - }); - consumedChannels.add(pinnedBinding.channel); - } - - for (const channel of [AiChannel.USER_KEY, AiChannel.ASTRBOT]) { - if (consumedChannels.has(channel)) { continue; } @@ -392,26 +365,6 @@ export class AiService { }); } - const publicPool = await this.findEnabledPublicPool(); - if (publicPool) { - plan.push({ - kind: "candidate", - candidate: this.toPublicPoolCandidate(publicPool) - }); - } else { - plan.push({ - kind: "skip", - attempt: { - channel: AiChannel.PUBLIC_POOL, - providerName: null, - model: null, - status: "skipped", - reasonCode: "PUBLIC_POOL_DISABLED", - reasonMessage: "公共 AI 通道未开启" - } - }); - } - return plan; } @@ -425,7 +378,9 @@ export class AiService { channel, isEnabled: true }, - orderBy: [{ isDefault: "desc" }, { updatedAt: "desc" }] + orderBy: { + updatedAt: "desc" + } }); } @@ -477,7 +432,6 @@ export class AiService { configId: binding.configId, configName: binding.configName, endpoint: binding.endpoint, - isDefault: binding.isDefault, isEnabled: binding.isEnabled, hasApiKey: Boolean(binding.encryptedApiKey), maskedApiKey: this.maskSecret(binding.encryptedApiKey), @@ -485,6 +439,20 @@ export class AiService { }; } + private pickLatestBindingsByChannel(bindings: AiProviderBinding[]): AiProviderBinding[] { + const bindingMap = new Map(); + + for (const binding of bindings) { + if (!bindingMap.has(binding.channel)) { + bindingMap.set(binding.channel, binding); + } + } + + return [AiChannel.USER_KEY, AiChannel.ASTRBOT] + .map((channel) => bindingMap.get(channel) ?? null) + .filter((binding): binding is AiProviderBinding => binding !== null); + } + private serializeUsageLog(log: AiUsageLog): AiUsageLogSummary { return { id: log.id, diff --git a/apps/api/src/ai/dto/ai-chat.dto.ts b/apps/api/src/ai/dto/ai-chat.dto.ts index a89692a..013b697 100644 --- a/apps/api/src/ai/dto/ai-chat.dto.ts +++ b/apps/api/src/ai/dto/ai-chat.dto.ts @@ -1,4 +1,5 @@ -import { IsOptional, IsString, MinLength } from "class-validator"; +import { IsEnum, IsOptional, IsString, MinLength } from "class-validator"; +import { AiChannel } from "../../../generated/prisma/client"; export class AiChatDto { @IsString() @@ -11,7 +12,6 @@ export class AiChatDto { sessionId?: string; @IsOptional() - @IsString() - @MinLength(1) - bindingId?: string; + @IsEnum(AiChannel) + channel?: AiChannel; } diff --git a/apps/api/src/ai/dto/upsert-ai-provider-binding.dto.ts b/apps/api/src/ai/dto/upsert-ai-provider-binding.dto.ts index 144f86d..b821bcc 100644 --- a/apps/api/src/ai/dto/upsert-ai-provider-binding.dto.ts +++ b/apps/api/src/ai/dto/upsert-ai-provider-binding.dto.ts @@ -1,12 +1,7 @@ -import { AiChannel } from "../../../generated/prisma/client"; +import { AiChannel } from "../../../generated/prisma/client"; import { IsBoolean, IsEnum, IsOptional, IsString, IsUrl, MinLength } from "class-validator"; export class UpsertAiProviderBindingDto { - @IsOptional() - @IsString() - @MinLength(1) - id?: string; - @IsEnum(AiChannel) channel!: AiChannel; @@ -36,7 +31,7 @@ export class UpsertAiProviderBindingDto { require_tld: false }, { - message: "endpoint 必须是合法的 URL" + message: "endpoint \u5fc5\u987b\u662f\u5408\u6cd5\u7684 URL" } ) endpoint?: string; @@ -46,10 +41,6 @@ export class UpsertAiProviderBindingDto { @MinLength(1) apiKey?: string; - @IsOptional() - @IsBoolean() - isDefault?: boolean; - @IsOptional() @IsBoolean() isEnabled?: boolean; diff --git a/apps/api/test/ai.spec.ts b/apps/api/test/ai.spec.ts index 4fd9b2f..d344a9e 100644 --- a/apps/api/test/ai.spec.ts +++ b/apps/api/test/ai.spec.ts @@ -441,7 +441,6 @@ describe("AiController (integration)", () => { configId: "default", endpoint: "http://127.0.0.1:6185", apiKey: "abk_secret_1234", - isDefault: true, isEnabled: true }) .expect(201); @@ -465,10 +464,54 @@ describe("AiController (integration)", () => { configName: null, hasApiKey: true, maskedApiKey: "abk_***34", - isDefault: true + isEnabled: true }); }); + it("should upsert one binding per user channel", async () => { + await request(app.getHttpServer()) + .post("/ai/bindings") + .set("x-user-id", "user_1") + .send({ + channel: AiChannel.USER_KEY, + providerName: "openai", + model: "gpt-4o-mini", + endpoint: "https://api.example.com", + apiKey: "sk-first", + isEnabled: true + }) + .expect(201); + + await request(app.getHttpServer()) + .post("/ai/bindings") + .set("x-user-id", "user_1") + .send({ + channel: AiChannel.USER_KEY, + providerName: "google", + model: "gemini-2.5-flash", + endpoint: "https://generativelanguage.googleapis.com", + apiKey: "sk-second", + isEnabled: false + }) + .expect(201); + + const response = await request(app.getHttpServer()) + .get("/ai/bindings") + .set("x-user-id", "user_1") + .expect(200); + + expect(response.body.bindings).toEqual([ + expect.objectContaining({ + channel: AiChannel.USER_KEY, + providerName: "google", + model: "gemini-2.5-flash", + endpoint: "https://generativelanguage.googleapis.com", + isEnabled: false, + maskedApiKey: "sk-s***nd" + }) + ]); + }); + it("should fallback from user key to astrbot", async () => { prismaService.seedBinding({ id: "binding_user_key", @@ -566,7 +609,6 @@ describe("AiController (integration)", () => { configId: "default", endpoint: "http://127.0.0.1:6185", apiKey: "abk_secret_1234", - isDefault: true, isEnabled: true }) .expect(201); @@ -575,10 +617,60 @@ describe("AiController (integration)", () => { channel: AiChannel.ASTRBOT, providerName: "", configId: "default", - configName: null + configName: null, + isEnabled: true }); }); + it("should use selected channel without automatic fallback", async () => { + prismaService.seedBinding({ + id: "binding_user_key_selected", + userId: "user_1", + channel: AiChannel.USER_KEY, + providerName: "openai", + model: "gpt-4o-mini", + configId: null, + configName: null, + encryptedApiKey: "sk-user", + endpoint: "https://api.example.com", + isDefault: false, + isEnabled: true + }); + prismaService.seedBinding({ + id: "binding_astrbot_selected", + userId: "user_1", + channel: AiChannel.ASTRBOT, + providerName: "", + model: null, + configId: "default", + configName: null, + encryptedApiKey: "abk_astrbot", + endpoint: "http://127.0.0.1:6185", + isDefault: false, + isEnabled: true + }); + + const response = await request(app.getHttpServer()) + .post("/ai/chat") + .set("x-user-id", "user_1") + .send({ + message: "只使用自备渠道", + channel: AiChannel.USER_KEY + }) + .expect(502); + + expect(response.body.attempts).toEqual([ + { + channel: AiChannel.USER_KEY, + providerName: "openai", + model: "gpt-4o-mini", + status: "failed", + reasonCode: "UPSTREAM_UNREACHABLE", + reasonMessage: "用户自备 Key 渠道暂时不可用" + } + ]); + }); + it("should inject unfinished task summary into ai prompt", async () => { prismaService.seedBinding({ id: "binding_astrbot_context",