feat(api-ai): scope private bindings by user channel
This commit is contained in:
@@ -1,10 +1,4 @@
|
|||||||
import {
|
import { BadGatewayException, BadRequestException, Injectable, Logger } from "@nestjs/common";
|
||||||
BadGatewayException,
|
|
||||||
BadRequestException,
|
|
||||||
Injectable,
|
|
||||||
Logger,
|
|
||||||
NotFoundException
|
|
||||||
} from "@nestjs/common";
|
|
||||||
import {
|
import {
|
||||||
AiChannel,
|
AiChannel,
|
||||||
AiUsageLog,
|
AiUsageLog,
|
||||||
@@ -34,7 +28,6 @@ type AiBindingSummary = {
|
|||||||
configId: string | null;
|
configId: string | null;
|
||||||
configName: string | null;
|
configName: string | null;
|
||||||
endpoint: string | null;
|
endpoint: string | null;
|
||||||
isDefault: boolean;
|
|
||||||
isEnabled: boolean;
|
isEnabled: boolean;
|
||||||
hasApiKey: boolean;
|
hasApiKey: boolean;
|
||||||
maskedApiKey: string | null;
|
maskedApiKey: string | null;
|
||||||
@@ -110,7 +103,7 @@ export class AiService {
|
|||||||
where: {
|
where: {
|
||||||
userId
|
userId
|
||||||
},
|
},
|
||||||
orderBy: [{ channel: "asc" }, { isDefault: "desc" }, { updatedAt: "desc" }]
|
orderBy: [{ updatedAt: "desc" }]
|
||||||
}),
|
}),
|
||||||
this.prismaService.aiPublicPoolConfig.findFirst({
|
this.prismaService.aiPublicPoolConfig.findFirst({
|
||||||
orderBy: {
|
orderBy: {
|
||||||
@@ -119,9 +112,11 @@ export class AiService {
|
|||||||
})
|
})
|
||||||
]);
|
]);
|
||||||
|
|
||||||
|
const latestBindings = this.pickLatestBindingsByChannel(bindings);
|
||||||
|
|
||||||
return {
|
return {
|
||||||
routeOrder: [AiChannel.USER_KEY, AiChannel.ASTRBOT, AiChannel.PUBLIC_POOL],
|
routeOrder: [AiChannel.USER_KEY, AiChannel.ASTRBOT, AiChannel.PUBLIC_POOL],
|
||||||
bindings: bindings.map((binding) => this.serializeBinding(binding)),
|
bindings: latestBindings.map((binding) => this.serializeBinding(binding)),
|
||||||
publicPool: publicPool
|
publicPool: publicPool
|
||||||
? {
|
? {
|
||||||
enabled: publicPool.enabled,
|
enabled: publicPool.enabled,
|
||||||
@@ -183,27 +178,17 @@ export class AiService {
|
|||||||
this.validateBindingInput(dto);
|
this.validateBindingInput(dto);
|
||||||
|
|
||||||
const result = await this.prismaService.$transaction(async (tx) => {
|
const result = await this.prismaService.$transaction(async (tx) => {
|
||||||
if (dto.isDefault) {
|
const existingBinding = await tx.aiProviderBinding.findFirst({
|
||||||
const where: Prisma.AiProviderBindingWhereInput = {
|
where: {
|
||||||
userId,
|
userId,
|
||||||
channel: dto.channel
|
channel: dto.channel
|
||||||
};
|
},
|
||||||
|
orderBy: {
|
||||||
if (dto.id) {
|
updatedAt: "desc"
|
||||||
where.id = {
|
|
||||||
not: dto.id
|
|
||||||
};
|
|
||||||
}
|
}
|
||||||
|
});
|
||||||
|
|
||||||
await tx.aiProviderBinding.updateMany({
|
if (!existingBinding) {
|
||||||
where,
|
|
||||||
data: {
|
|
||||||
isDefault: false
|
|
||||||
}
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!dto.id) {
|
|
||||||
return tx.aiProviderBinding.create({
|
return tx.aiProviderBinding.create({
|
||||||
data: {
|
data: {
|
||||||
userId,
|
userId,
|
||||||
@@ -214,30 +199,17 @@ export class AiService {
|
|||||||
configName: this.normalizeOptionalString(dto.configName),
|
configName: this.normalizeOptionalString(dto.configName),
|
||||||
endpoint: this.normalizeOptionalString(dto.endpoint),
|
endpoint: this.normalizeOptionalString(dto.endpoint),
|
||||||
encryptedApiKey: this.normalizeOptionalString(dto.apiKey),
|
encryptedApiKey: this.normalizeOptionalString(dto.apiKey),
|
||||||
isDefault: dto.isDefault ?? false,
|
|
||||||
isEnabled: dto.isEnabled ?? true
|
isEnabled: dto.isEnabled ?? true
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
const existingBinding = await tx.aiProviderBinding.findFirst({
|
|
||||||
where: {
|
|
||||||
id: dto.id,
|
|
||||||
userId
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
if (!existingBinding) {
|
|
||||||
throw new NotFoundException("AI 通道配置不存在");
|
|
||||||
}
|
|
||||||
|
|
||||||
const updateData: Prisma.AiProviderBindingUpdateInput = {
|
const updateData: Prisma.AiProviderBindingUpdateInput = {
|
||||||
channel: dto.channel,
|
channel: dto.channel,
|
||||||
providerName: this.normalizeProviderName(dto.providerName),
|
providerName: this.normalizeProviderName(dto.providerName),
|
||||||
model: this.normalizeOptionalString(dto.model),
|
model: this.normalizeOptionalString(dto.model),
|
||||||
configId: this.normalizeOptionalString(dto.configId),
|
configId: this.normalizeOptionalString(dto.configId),
|
||||||
configName: this.normalizeOptionalString(dto.configName),
|
configName: this.normalizeOptionalString(dto.configName),
|
||||||
isDefault: dto.isDefault ?? existingBinding.isDefault,
|
|
||||||
isEnabled: dto.isEnabled ?? existingBinding.isEnabled
|
isEnabled: dto.isEnabled ?? existingBinding.isEnabled
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -251,7 +223,7 @@ export class AiService {
|
|||||||
|
|
||||||
return tx.aiProviderBinding.update({
|
return tx.aiProviderBinding.update({
|
||||||
where: {
|
where: {
|
||||||
id: dto.id
|
id: existingBinding.id
|
||||||
},
|
},
|
||||||
data: updateData
|
data: updateData
|
||||||
});
|
});
|
||||||
@@ -262,7 +234,7 @@ export class AiService {
|
|||||||
|
|
||||||
async chat(userId: string, dto: AiChatDto): Promise<AiChatResponse> {
|
async chat(userId: string, dto: AiChatDto): Promise<AiChatResponse> {
|
||||||
const attempts: AiRouteAttempt[] = [];
|
const attempts: AiRouteAttempt[] = [];
|
||||||
const plan = await this.buildRoutePlan(userId, dto.bindingId ?? null);
|
const plan = await this.buildRoutePlan(userId, dto.channel ?? null);
|
||||||
const promptMessage = await this.buildPromptMessage(userId, dto.message);
|
const promptMessage = await this.buildPromptMessage(userId, dto.message);
|
||||||
|
|
||||||
for (const entry of plan) {
|
for (const entry of plan) {
|
||||||
@@ -337,33 +309,34 @@ export class AiService {
|
|||||||
|
|
||||||
private async buildRoutePlan(
|
private async buildRoutePlan(
|
||||||
userId: string,
|
userId: string,
|
||||||
bindingId: string | null
|
selectedChannel: AiChannel | null
|
||||||
): Promise<AiRoutePlanEntry[]> {
|
): Promise<AiRoutePlanEntry[]> {
|
||||||
const plan: AiRoutePlanEntry[] = [];
|
const plan: AiRoutePlanEntry[] = [];
|
||||||
const consumedChannels = new Set<AiChannel>();
|
const targetChannels = selectedChannel
|
||||||
|
? [selectedChannel]
|
||||||
|
: [AiChannel.USER_KEY, AiChannel.ASTRBOT, AiChannel.PUBLIC_POOL];
|
||||||
|
|
||||||
if (bindingId) {
|
for (const channel of targetChannels) {
|
||||||
const pinnedBinding = await this.prismaService.aiProviderBinding.findFirst({
|
if (channel === AiChannel.PUBLIC_POOL) {
|
||||||
where: {
|
const publicPool = await this.findEnabledPublicPool();
|
||||||
id: bindingId,
|
if (publicPool) {
|
||||||
userId,
|
plan.push({
|
||||||
isEnabled: true
|
kind: "candidate",
|
||||||
|
candidate: this.toPublicPoolCandidate(publicPool)
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
plan.push({
|
||||||
|
kind: "skip",
|
||||||
|
attempt: {
|
||||||
|
channel: AiChannel.PUBLIC_POOL,
|
||||||
|
providerName: null,
|
||||||
|
model: null,
|
||||||
|
status: "skipped",
|
||||||
|
reasonCode: "PUBLIC_POOL_DISABLED",
|
||||||
|
reasonMessage: "公共 AI 通道未开启"
|
||||||
|
}
|
||||||
|
});
|
||||||
}
|
}
|
||||||
});
|
|
||||||
|
|
||||||
if (!pinnedBinding) {
|
|
||||||
throw new NotFoundException("指定的 AI 通道配置不存在或已禁用");
|
|
||||||
}
|
|
||||||
|
|
||||||
plan.push({
|
|
||||||
kind: "candidate",
|
|
||||||
candidate: this.toBindingCandidate(pinnedBinding)
|
|
||||||
});
|
|
||||||
consumedChannels.add(pinnedBinding.channel);
|
|
||||||
}
|
|
||||||
|
|
||||||
for (const channel of [AiChannel.USER_KEY, AiChannel.ASTRBOT]) {
|
|
||||||
if (consumedChannels.has(channel)) {
|
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -392,26 +365,6 @@ export class AiService {
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
const publicPool = await this.findEnabledPublicPool();
|
|
||||||
if (publicPool) {
|
|
||||||
plan.push({
|
|
||||||
kind: "candidate",
|
|
||||||
candidate: this.toPublicPoolCandidate(publicPool)
|
|
||||||
});
|
|
||||||
} else {
|
|
||||||
plan.push({
|
|
||||||
kind: "skip",
|
|
||||||
attempt: {
|
|
||||||
channel: AiChannel.PUBLIC_POOL,
|
|
||||||
providerName: null,
|
|
||||||
model: null,
|
|
||||||
status: "skipped",
|
|
||||||
reasonCode: "PUBLIC_POOL_DISABLED",
|
|
||||||
reasonMessage: "公共 AI 通道未开启"
|
|
||||||
}
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
return plan;
|
return plan;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -425,7 +378,9 @@ export class AiService {
|
|||||||
channel,
|
channel,
|
||||||
isEnabled: true
|
isEnabled: true
|
||||||
},
|
},
|
||||||
orderBy: [{ isDefault: "desc" }, { updatedAt: "desc" }]
|
orderBy: {
|
||||||
|
updatedAt: "desc"
|
||||||
|
}
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -477,7 +432,6 @@ export class AiService {
|
|||||||
configId: binding.configId,
|
configId: binding.configId,
|
||||||
configName: binding.configName,
|
configName: binding.configName,
|
||||||
endpoint: binding.endpoint,
|
endpoint: binding.endpoint,
|
||||||
isDefault: binding.isDefault,
|
|
||||||
isEnabled: binding.isEnabled,
|
isEnabled: binding.isEnabled,
|
||||||
hasApiKey: Boolean(binding.encryptedApiKey),
|
hasApiKey: Boolean(binding.encryptedApiKey),
|
||||||
maskedApiKey: this.maskSecret(binding.encryptedApiKey),
|
maskedApiKey: this.maskSecret(binding.encryptedApiKey),
|
||||||
@@ -485,6 +439,20 @@ export class AiService {
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private pickLatestBindingsByChannel(bindings: AiProviderBinding[]): AiProviderBinding[] {
|
||||||
|
const bindingMap = new Map<AiChannel, AiProviderBinding>();
|
||||||
|
|
||||||
|
for (const binding of bindings) {
|
||||||
|
if (!bindingMap.has(binding.channel)) {
|
||||||
|
bindingMap.set(binding.channel, binding);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return [AiChannel.USER_KEY, AiChannel.ASTRBOT]
|
||||||
|
.map((channel) => bindingMap.get(channel) ?? null)
|
||||||
|
.filter((binding): binding is AiProviderBinding => binding !== null);
|
||||||
|
}
|
||||||
|
|
||||||
private serializeUsageLog(log: AiUsageLog): AiUsageLogSummary {
|
private serializeUsageLog(log: AiUsageLog): AiUsageLogSummary {
|
||||||
return {
|
return {
|
||||||
id: log.id,
|
id: log.id,
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
import { IsOptional, IsString, MinLength } from "class-validator";
|
import { IsEnum, IsOptional, IsString, MinLength } from "class-validator";
|
||||||
|
import { AiChannel } from "../../../generated/prisma/client";
|
||||||
|
|
||||||
export class AiChatDto {
|
export class AiChatDto {
|
||||||
@IsString()
|
@IsString()
|
||||||
@@ -11,7 +12,6 @@ export class AiChatDto {
|
|||||||
sessionId?: string;
|
sessionId?: string;
|
||||||
|
|
||||||
@IsOptional()
|
@IsOptional()
|
||||||
@IsString()
|
@IsEnum(AiChannel)
|
||||||
@MinLength(1)
|
channel?: AiChannel;
|
||||||
bindingId?: string;
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,12 +1,7 @@
|
|||||||
import { AiChannel } from "../../../generated/prisma/client";
|
import { AiChannel } from "../../../generated/prisma/client";
|
||||||
import { IsBoolean, IsEnum, IsOptional, IsString, IsUrl, MinLength } from "class-validator";
|
import { IsBoolean, IsEnum, IsOptional, IsString, IsUrl, MinLength } from "class-validator";
|
||||||
|
|
||||||
export class UpsertAiProviderBindingDto {
|
export class UpsertAiProviderBindingDto {
|
||||||
@IsOptional()
|
|
||||||
@IsString()
|
|
||||||
@MinLength(1)
|
|
||||||
id?: string;
|
|
||||||
|
|
||||||
@IsEnum(AiChannel)
|
@IsEnum(AiChannel)
|
||||||
channel!: AiChannel;
|
channel!: AiChannel;
|
||||||
|
|
||||||
@@ -36,7 +31,7 @@ export class UpsertAiProviderBindingDto {
|
|||||||
require_tld: false
|
require_tld: false
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
message: "endpoint 必须是合法的 URL"
|
message: "endpoint \u5fc5\u987b\u662f\u5408\u6cd5\u7684 URL"
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
endpoint?: string;
|
endpoint?: string;
|
||||||
@@ -46,10 +41,6 @@ export class UpsertAiProviderBindingDto {
|
|||||||
@MinLength(1)
|
@MinLength(1)
|
||||||
apiKey?: string;
|
apiKey?: string;
|
||||||
|
|
||||||
@IsOptional()
|
|
||||||
@IsBoolean()
|
|
||||||
isDefault?: boolean;
|
|
||||||
|
|
||||||
@IsOptional()
|
@IsOptional()
|
||||||
@IsBoolean()
|
@IsBoolean()
|
||||||
isEnabled?: boolean;
|
isEnabled?: boolean;
|
||||||
|
|||||||
@@ -441,7 +441,6 @@ describe("AiController (integration)", () => {
|
|||||||
configId: "default",
|
configId: "default",
|
||||||
endpoint: "http://127.0.0.1:6185",
|
endpoint: "http://127.0.0.1:6185",
|
||||||
apiKey: "abk_secret_1234",
|
apiKey: "abk_secret_1234",
|
||||||
isDefault: true,
|
|
||||||
isEnabled: true
|
isEnabled: true
|
||||||
})
|
})
|
||||||
.expect(201);
|
.expect(201);
|
||||||
@@ -465,10 +464,54 @@ describe("AiController (integration)", () => {
|
|||||||
configName: null,
|
configName: null,
|
||||||
hasApiKey: true,
|
hasApiKey: true,
|
||||||
maskedApiKey: "abk_***34",
|
maskedApiKey: "abk_***34",
|
||||||
isDefault: true
|
isEnabled: true
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
|
it("should upsert one binding per user channel", async () => {
|
||||||
|
await request(app.getHttpServer())
|
||||||
|
.post("/ai/bindings")
|
||||||
|
.set("x-user-id", "user_1")
|
||||||
|
.send({
|
||||||
|
channel: AiChannel.USER_KEY,
|
||||||
|
providerName: "openai",
|
||||||
|
model: "gpt-4o-mini",
|
||||||
|
endpoint: "https://api.example.com",
|
||||||
|
apiKey: "sk-first",
|
||||||
|
isEnabled: true
|
||||||
|
})
|
||||||
|
.expect(201);
|
||||||
|
|
||||||
|
await request(app.getHttpServer())
|
||||||
|
.post("/ai/bindings")
|
||||||
|
.set("x-user-id", "user_1")
|
||||||
|
.send({
|
||||||
|
channel: AiChannel.USER_KEY,
|
||||||
|
providerName: "google",
|
||||||
|
model: "gemini-2.5-flash",
|
||||||
|
endpoint: "https://generativelanguage.googleapis.com",
|
||||||
|
apiKey: "sk-second",
|
||||||
|
isEnabled: false
|
||||||
|
})
|
||||||
|
.expect(201);
|
||||||
|
|
||||||
|
const response = await request(app.getHttpServer())
|
||||||
|
.get("/ai/bindings")
|
||||||
|
.set("x-user-id", "user_1")
|
||||||
|
.expect(200);
|
||||||
|
|
||||||
|
expect(response.body.bindings).toEqual([
|
||||||
|
expect.objectContaining({
|
||||||
|
channel: AiChannel.USER_KEY,
|
||||||
|
providerName: "google",
|
||||||
|
model: "gemini-2.5-flash",
|
||||||
|
endpoint: "https://generativelanguage.googleapis.com",
|
||||||
|
isEnabled: false,
|
||||||
|
maskedApiKey: "sk-s***nd"
|
||||||
|
})
|
||||||
|
]);
|
||||||
|
});
|
||||||
|
|
||||||
it("should fallback from user key to astrbot", async () => {
|
it("should fallback from user key to astrbot", async () => {
|
||||||
prismaService.seedBinding({
|
prismaService.seedBinding({
|
||||||
id: "binding_user_key",
|
id: "binding_user_key",
|
||||||
@@ -566,7 +609,6 @@ describe("AiController (integration)", () => {
|
|||||||
configId: "default",
|
configId: "default",
|
||||||
endpoint: "http://127.0.0.1:6185",
|
endpoint: "http://127.0.0.1:6185",
|
||||||
apiKey: "abk_secret_1234",
|
apiKey: "abk_secret_1234",
|
||||||
isDefault: true,
|
|
||||||
isEnabled: true
|
isEnabled: true
|
||||||
})
|
})
|
||||||
.expect(201);
|
.expect(201);
|
||||||
@@ -575,10 +617,60 @@ describe("AiController (integration)", () => {
|
|||||||
channel: AiChannel.ASTRBOT,
|
channel: AiChannel.ASTRBOT,
|
||||||
providerName: "",
|
providerName: "",
|
||||||
configId: "default",
|
configId: "default",
|
||||||
configName: null
|
configName: null,
|
||||||
|
isEnabled: true
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
|
it("should use selected channel without automatic fallback", async () => {
|
||||||
|
prismaService.seedBinding({
|
||||||
|
id: "binding_user_key_selected",
|
||||||
|
userId: "user_1",
|
||||||
|
channel: AiChannel.USER_KEY,
|
||||||
|
providerName: "openai",
|
||||||
|
model: "gpt-4o-mini",
|
||||||
|
configId: null,
|
||||||
|
configName: null,
|
||||||
|
encryptedApiKey: "sk-user",
|
||||||
|
endpoint: "https://api.example.com",
|
||||||
|
isDefault: false,
|
||||||
|
isEnabled: true
|
||||||
|
});
|
||||||
|
prismaService.seedBinding({
|
||||||
|
id: "binding_astrbot_selected",
|
||||||
|
userId: "user_1",
|
||||||
|
channel: AiChannel.ASTRBOT,
|
||||||
|
providerName: "",
|
||||||
|
model: null,
|
||||||
|
configId: "default",
|
||||||
|
configName: null,
|
||||||
|
encryptedApiKey: "abk_astrbot",
|
||||||
|
endpoint: "http://127.0.0.1:6185",
|
||||||
|
isDefault: false,
|
||||||
|
isEnabled: true
|
||||||
|
});
|
||||||
|
|
||||||
|
const response = await request(app.getHttpServer())
|
||||||
|
.post("/ai/chat")
|
||||||
|
.set("x-user-id", "user_1")
|
||||||
|
.send({
|
||||||
|
message: "只使用自备渠道",
|
||||||
|
channel: AiChannel.USER_KEY
|
||||||
|
})
|
||||||
|
.expect(502);
|
||||||
|
|
||||||
|
expect(response.body.attempts).toEqual([
|
||||||
|
{
|
||||||
|
channel: AiChannel.USER_KEY,
|
||||||
|
providerName: "openai",
|
||||||
|
model: "gpt-4o-mini",
|
||||||
|
status: "failed",
|
||||||
|
reasonCode: "UPSTREAM_UNREACHABLE",
|
||||||
|
reasonMessage: "用户自备 Key 渠道暂时不可用"
|
||||||
|
}
|
||||||
|
]);
|
||||||
|
});
|
||||||
|
|
||||||
it("should inject unfinished task summary into ai prompt", async () => {
|
it("should inject unfinished task summary into ai prompt", async () => {
|
||||||
prismaService.seedBinding({
|
prismaService.seedBinding({
|
||||||
id: "binding_astrbot_context",
|
id: "binding_astrbot_context",
|
||||||
|
|||||||
Reference in New Issue
Block a user