feat(api-ai): scope private bindings by user channel

This commit is contained in:
2026-04-06 13:36:28 +08:00
parent 5c956c195b
commit d0ba581184
4 changed files with 157 additions and 106 deletions
+55 -87
View File
@@ -1,10 +1,4 @@
import { import { BadGatewayException, BadRequestException, Injectable, Logger } from "@nestjs/common";
BadGatewayException,
BadRequestException,
Injectable,
Logger,
NotFoundException
} from "@nestjs/common";
import { import {
AiChannel, AiChannel,
AiUsageLog, AiUsageLog,
@@ -34,7 +28,6 @@ type AiBindingSummary = {
configId: string | null; configId: string | null;
configName: string | null; configName: string | null;
endpoint: string | null; endpoint: string | null;
isDefault: boolean;
isEnabled: boolean; isEnabled: boolean;
hasApiKey: boolean; hasApiKey: boolean;
maskedApiKey: string | null; maskedApiKey: string | null;
@@ -110,7 +103,7 @@ export class AiService {
where: { where: {
userId userId
}, },
orderBy: [{ channel: "asc" }, { isDefault: "desc" }, { updatedAt: "desc" }] orderBy: [{ updatedAt: "desc" }]
}), }),
this.prismaService.aiPublicPoolConfig.findFirst({ this.prismaService.aiPublicPoolConfig.findFirst({
orderBy: { orderBy: {
@@ -119,9 +112,11 @@ export class AiService {
}) })
]); ]);
const latestBindings = this.pickLatestBindingsByChannel(bindings);
return { return {
routeOrder: [AiChannel.USER_KEY, AiChannel.ASTRBOT, AiChannel.PUBLIC_POOL], routeOrder: [AiChannel.USER_KEY, AiChannel.ASTRBOT, AiChannel.PUBLIC_POOL],
bindings: bindings.map((binding) => this.serializeBinding(binding)), bindings: latestBindings.map((binding) => this.serializeBinding(binding)),
publicPool: publicPool publicPool: publicPool
? { ? {
enabled: publicPool.enabled, enabled: publicPool.enabled,
@@ -183,27 +178,17 @@ export class AiService {
this.validateBindingInput(dto); this.validateBindingInput(dto);
const result = await this.prismaService.$transaction(async (tx) => { const result = await this.prismaService.$transaction(async (tx) => {
if (dto.isDefault) { const existingBinding = await tx.aiProviderBinding.findFirst({
const where: Prisma.AiProviderBindingWhereInput = { where: {
userId, userId,
channel: dto.channel channel: dto.channel
}; },
orderBy: {
if (dto.id) { updatedAt: "desc"
where.id = {
not: dto.id
};
} }
});
await tx.aiProviderBinding.updateMany({ if (!existingBinding) {
where,
data: {
isDefault: false
}
});
}
if (!dto.id) {
return tx.aiProviderBinding.create({ return tx.aiProviderBinding.create({
data: { data: {
userId, userId,
@@ -214,30 +199,17 @@ export class AiService {
configName: this.normalizeOptionalString(dto.configName), configName: this.normalizeOptionalString(dto.configName),
endpoint: this.normalizeOptionalString(dto.endpoint), endpoint: this.normalizeOptionalString(dto.endpoint),
encryptedApiKey: this.normalizeOptionalString(dto.apiKey), encryptedApiKey: this.normalizeOptionalString(dto.apiKey),
isDefault: dto.isDefault ?? false,
isEnabled: dto.isEnabled ?? true isEnabled: dto.isEnabled ?? true
} }
}); });
} }
const existingBinding = await tx.aiProviderBinding.findFirst({
where: {
id: dto.id,
userId
}
});
if (!existingBinding) {
throw new NotFoundException("AI 通道配置不存在");
}
const updateData: Prisma.AiProviderBindingUpdateInput = { const updateData: Prisma.AiProviderBindingUpdateInput = {
channel: dto.channel, channel: dto.channel,
providerName: this.normalizeProviderName(dto.providerName), providerName: this.normalizeProviderName(dto.providerName),
model: this.normalizeOptionalString(dto.model), model: this.normalizeOptionalString(dto.model),
configId: this.normalizeOptionalString(dto.configId), configId: this.normalizeOptionalString(dto.configId),
configName: this.normalizeOptionalString(dto.configName), configName: this.normalizeOptionalString(dto.configName),
isDefault: dto.isDefault ?? existingBinding.isDefault,
isEnabled: dto.isEnabled ?? existingBinding.isEnabled isEnabled: dto.isEnabled ?? existingBinding.isEnabled
}; };
@@ -251,7 +223,7 @@ export class AiService {
return tx.aiProviderBinding.update({ return tx.aiProviderBinding.update({
where: { where: {
id: dto.id id: existingBinding.id
}, },
data: updateData data: updateData
}); });
@@ -262,7 +234,7 @@ export class AiService {
async chat(userId: string, dto: AiChatDto): Promise<AiChatResponse> { async chat(userId: string, dto: AiChatDto): Promise<AiChatResponse> {
const attempts: AiRouteAttempt[] = []; const attempts: AiRouteAttempt[] = [];
const plan = await this.buildRoutePlan(userId, dto.bindingId ?? null); const plan = await this.buildRoutePlan(userId, dto.channel ?? null);
const promptMessage = await this.buildPromptMessage(userId, dto.message); const promptMessage = await this.buildPromptMessage(userId, dto.message);
for (const entry of plan) { for (const entry of plan) {
@@ -337,33 +309,34 @@ export class AiService {
private async buildRoutePlan( private async buildRoutePlan(
userId: string, userId: string,
bindingId: string | null selectedChannel: AiChannel | null
): Promise<AiRoutePlanEntry[]> { ): Promise<AiRoutePlanEntry[]> {
const plan: AiRoutePlanEntry[] = []; const plan: AiRoutePlanEntry[] = [];
const consumedChannels = new Set<AiChannel>(); const targetChannels = selectedChannel
? [selectedChannel]
: [AiChannel.USER_KEY, AiChannel.ASTRBOT, AiChannel.PUBLIC_POOL];
if (bindingId) { for (const channel of targetChannels) {
const pinnedBinding = await this.prismaService.aiProviderBinding.findFirst({ if (channel === AiChannel.PUBLIC_POOL) {
where: { const publicPool = await this.findEnabledPublicPool();
id: bindingId, if (publicPool) {
userId, plan.push({
isEnabled: true kind: "candidate",
candidate: this.toPublicPoolCandidate(publicPool)
});
} else {
plan.push({
kind: "skip",
attempt: {
channel: AiChannel.PUBLIC_POOL,
providerName: null,
model: null,
status: "skipped",
reasonCode: "PUBLIC_POOL_DISABLED",
reasonMessage: "公共 AI 通道未开启"
}
});
} }
});
if (!pinnedBinding) {
throw new NotFoundException("指定的 AI 通道配置不存在或已禁用");
}
plan.push({
kind: "candidate",
candidate: this.toBindingCandidate(pinnedBinding)
});
consumedChannels.add(pinnedBinding.channel);
}
for (const channel of [AiChannel.USER_KEY, AiChannel.ASTRBOT]) {
if (consumedChannels.has(channel)) {
continue; continue;
} }
@@ -392,26 +365,6 @@ export class AiService {
}); });
} }
const publicPool = await this.findEnabledPublicPool();
if (publicPool) {
plan.push({
kind: "candidate",
candidate: this.toPublicPoolCandidate(publicPool)
});
} else {
plan.push({
kind: "skip",
attempt: {
channel: AiChannel.PUBLIC_POOL,
providerName: null,
model: null,
status: "skipped",
reasonCode: "PUBLIC_POOL_DISABLED",
reasonMessage: "公共 AI 通道未开启"
}
});
}
return plan; return plan;
} }
@@ -425,7 +378,9 @@ export class AiService {
channel, channel,
isEnabled: true isEnabled: true
}, },
orderBy: [{ isDefault: "desc" }, { updatedAt: "desc" }] orderBy: {
updatedAt: "desc"
}
}); });
} }
@@ -477,7 +432,6 @@ export class AiService {
configId: binding.configId, configId: binding.configId,
configName: binding.configName, configName: binding.configName,
endpoint: binding.endpoint, endpoint: binding.endpoint,
isDefault: binding.isDefault,
isEnabled: binding.isEnabled, isEnabled: binding.isEnabled,
hasApiKey: Boolean(binding.encryptedApiKey), hasApiKey: Boolean(binding.encryptedApiKey),
maskedApiKey: this.maskSecret(binding.encryptedApiKey), maskedApiKey: this.maskSecret(binding.encryptedApiKey),
@@ -485,6 +439,20 @@ export class AiService {
}; };
} }
private pickLatestBindingsByChannel(bindings: AiProviderBinding[]): AiProviderBinding[] {
const bindingMap = new Map<AiChannel, AiProviderBinding>();
for (const binding of bindings) {
if (!bindingMap.has(binding.channel)) {
bindingMap.set(binding.channel, binding);
}
}
return [AiChannel.USER_KEY, AiChannel.ASTRBOT]
.map((channel) => bindingMap.get(channel) ?? null)
.filter((binding): binding is AiProviderBinding => binding !== null);
}
private serializeUsageLog(log: AiUsageLog): AiUsageLogSummary { private serializeUsageLog(log: AiUsageLog): AiUsageLogSummary {
return { return {
id: log.id, id: log.id,
+4 -4
View File
@@ -1,4 +1,5 @@
import { IsOptional, IsString, MinLength } from "class-validator"; import { IsEnum, IsOptional, IsString, MinLength } from "class-validator";
import { AiChannel } from "../../../generated/prisma/client";
export class AiChatDto { export class AiChatDto {
@IsString() @IsString()
@@ -11,7 +12,6 @@ export class AiChatDto {
sessionId?: string; sessionId?: string;
@IsOptional() @IsOptional()
@IsString() @IsEnum(AiChannel)
@MinLength(1) channel?: AiChannel;
bindingId?: string;
} }
@@ -1,12 +1,7 @@
import { AiChannel } from "../../../generated/prisma/client"; import { AiChannel } from "../../../generated/prisma/client";
import { IsBoolean, IsEnum, IsOptional, IsString, IsUrl, MinLength } from "class-validator"; import { IsBoolean, IsEnum, IsOptional, IsString, IsUrl, MinLength } from "class-validator";
export class UpsertAiProviderBindingDto { export class UpsertAiProviderBindingDto {
@IsOptional()
@IsString()
@MinLength(1)
id?: string;
@IsEnum(AiChannel) @IsEnum(AiChannel)
channel!: AiChannel; channel!: AiChannel;
@@ -36,7 +31,7 @@ export class UpsertAiProviderBindingDto {
require_tld: false require_tld: false
}, },
{ {
message: "endpoint 必须是合法的 URL" message: "endpoint \u5fc5\u987b\u662f\u5408\u6cd5\u7684 URL"
} }
) )
endpoint?: string; endpoint?: string;
@@ -46,10 +41,6 @@ export class UpsertAiProviderBindingDto {
@MinLength(1) @MinLength(1)
apiKey?: string; apiKey?: string;
@IsOptional()
@IsBoolean()
isDefault?: boolean;
@IsOptional() @IsOptional()
@IsBoolean() @IsBoolean()
isEnabled?: boolean; isEnabled?: boolean;
+96 -4
View File
@@ -441,7 +441,6 @@ describe("AiController (integration)", () => {
configId: "default", configId: "default",
endpoint: "http://127.0.0.1:6185", endpoint: "http://127.0.0.1:6185",
apiKey: "abk_secret_1234", apiKey: "abk_secret_1234",
isDefault: true,
isEnabled: true isEnabled: true
}) })
.expect(201); .expect(201);
@@ -465,10 +464,54 @@ describe("AiController (integration)", () => {
configName: null, configName: null,
hasApiKey: true, hasApiKey: true,
maskedApiKey: "abk_***34", maskedApiKey: "abk_***34",
isDefault: true isEnabled: true
}); });
}); });
it("should upsert one binding per user channel", async () => {
await request(app.getHttpServer())
.post("/ai/bindings")
.set("x-user-id", "user_1")
.send({
channel: AiChannel.USER_KEY,
providerName: "openai",
model: "gpt-4o-mini",
endpoint: "https://api.example.com",
apiKey: "sk-first",
isEnabled: true
})
.expect(201);
await request(app.getHttpServer())
.post("/ai/bindings")
.set("x-user-id", "user_1")
.send({
channel: AiChannel.USER_KEY,
providerName: "google",
model: "gemini-2.5-flash",
endpoint: "https://generativelanguage.googleapis.com",
apiKey: "sk-second",
isEnabled: false
})
.expect(201);
const response = await request(app.getHttpServer())
.get("/ai/bindings")
.set("x-user-id", "user_1")
.expect(200);
expect(response.body.bindings).toEqual([
expect.objectContaining({
channel: AiChannel.USER_KEY,
providerName: "google",
model: "gemini-2.5-flash",
endpoint: "https://generativelanguage.googleapis.com",
isEnabled: false,
maskedApiKey: "sk-s***nd"
})
]);
});
it("should fallback from user key to astrbot", async () => { it("should fallback from user key to astrbot", async () => {
prismaService.seedBinding({ prismaService.seedBinding({
id: "binding_user_key", id: "binding_user_key",
@@ -566,7 +609,6 @@ describe("AiController (integration)", () => {
configId: "default", configId: "default",
endpoint: "http://127.0.0.1:6185", endpoint: "http://127.0.0.1:6185",
apiKey: "abk_secret_1234", apiKey: "abk_secret_1234",
isDefault: true,
isEnabled: true isEnabled: true
}) })
.expect(201); .expect(201);
@@ -575,10 +617,60 @@ describe("AiController (integration)", () => {
channel: AiChannel.ASTRBOT, channel: AiChannel.ASTRBOT,
providerName: "", providerName: "",
configId: "default", configId: "default",
configName: null configName: null,
isEnabled: true
}); });
}); });
it("should use selected channel without automatic fallback", async () => {
prismaService.seedBinding({
id: "binding_user_key_selected",
userId: "user_1",
channel: AiChannel.USER_KEY,
providerName: "openai",
model: "gpt-4o-mini",
configId: null,
configName: null,
encryptedApiKey: "sk-user",
endpoint: "https://api.example.com",
isDefault: false,
isEnabled: true
});
prismaService.seedBinding({
id: "binding_astrbot_selected",
userId: "user_1",
channel: AiChannel.ASTRBOT,
providerName: "",
model: null,
configId: "default",
configName: null,
encryptedApiKey: "abk_astrbot",
endpoint: "http://127.0.0.1:6185",
isDefault: false,
isEnabled: true
});
const response = await request(app.getHttpServer())
.post("/ai/chat")
.set("x-user-id", "user_1")
.send({
message: "只使用自备渠道",
channel: AiChannel.USER_KEY
})
.expect(502);
expect(response.body.attempts).toEqual([
{
channel: AiChannel.USER_KEY,
providerName: "openai",
model: "gpt-4o-mini",
status: "failed",
reasonCode: "UPSTREAM_UNREACHABLE",
reasonMessage: "用户自备 Key 渠道暂时不可用"
}
]);
});
it("should inject unfinished task summary into ai prompt", async () => { it("should inject unfinished task summary into ai prompt", async () => {
prismaService.seedBinding({ prismaService.seedBinding({
id: "binding_astrbot_context", id: "binding_astrbot_context",