feat(api-ai): support astrbot config selection
This commit is contained in:
@@ -273,6 +273,8 @@ model AiProviderBinding {
|
|||||||
channel AiChannel
|
channel AiChannel
|
||||||
providerName String
|
providerName String
|
||||||
model String?
|
model String?
|
||||||
|
configId String?
|
||||||
|
configName String?
|
||||||
encryptedApiKey String?
|
encryptedApiKey String?
|
||||||
endpoint String?
|
endpoint String?
|
||||||
isDefault Boolean @default(false)
|
isDefault Boolean @default(false)
|
||||||
|
|||||||
@@ -22,6 +22,8 @@ type AiBindingSummary = {
|
|||||||
channel: AiChannel;
|
channel: AiChannel;
|
||||||
providerName: string;
|
providerName: string;
|
||||||
model: string | null;
|
model: string | null;
|
||||||
|
configId: string | null;
|
||||||
|
configName: string | null;
|
||||||
endpoint: string | null;
|
endpoint: string | null;
|
||||||
isDefault: boolean;
|
isDefault: boolean;
|
||||||
isEnabled: boolean;
|
isEnabled: boolean;
|
||||||
@@ -105,6 +107,8 @@ export class AiService {
|
|||||||
throw new BadRequestException("公共 AI 通道只能由管理员配置");
|
throw new BadRequestException("公共 AI 通道只能由管理员配置");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
this.validateBindingInput(dto);
|
||||||
|
|
||||||
const result = await this.prismaService.$transaction(async (tx) => {
|
const result = await this.prismaService.$transaction(async (tx) => {
|
||||||
if (dto.isDefault) {
|
if (dto.isDefault) {
|
||||||
const where: Prisma.AiProviderBindingWhereInput = {
|
const where: Prisma.AiProviderBindingWhereInput = {
|
||||||
@@ -131,8 +135,10 @@ export class AiService {
|
|||||||
data: {
|
data: {
|
||||||
userId,
|
userId,
|
||||||
channel: dto.channel,
|
channel: dto.channel,
|
||||||
providerName: dto.providerName.trim(),
|
providerName: this.normalizeProviderName(dto.providerName),
|
||||||
model: this.normalizeOptionalString(dto.model),
|
model: this.normalizeOptionalString(dto.model),
|
||||||
|
configId: this.normalizeOptionalString(dto.configId),
|
||||||
|
configName: this.normalizeOptionalString(dto.configName),
|
||||||
endpoint: this.normalizeOptionalString(dto.endpoint),
|
endpoint: this.normalizeOptionalString(dto.endpoint),
|
||||||
encryptedApiKey: this.normalizeOptionalString(dto.apiKey),
|
encryptedApiKey: this.normalizeOptionalString(dto.apiKey),
|
||||||
isDefault: dto.isDefault ?? false,
|
isDefault: dto.isDefault ?? false,
|
||||||
@@ -154,8 +160,10 @@ export class AiService {
|
|||||||
|
|
||||||
const updateData: Prisma.AiProviderBindingUpdateInput = {
|
const updateData: Prisma.AiProviderBindingUpdateInput = {
|
||||||
channel: dto.channel,
|
channel: dto.channel,
|
||||||
providerName: dto.providerName.trim(),
|
providerName: this.normalizeProviderName(dto.providerName),
|
||||||
model: this.normalizeOptionalString(dto.model),
|
model: this.normalizeOptionalString(dto.model),
|
||||||
|
configId: this.normalizeOptionalString(dto.configId),
|
||||||
|
configName: this.normalizeOptionalString(dto.configName),
|
||||||
isDefault: dto.isDefault ?? existingBinding.isDefault,
|
isDefault: dto.isDefault ?? existingBinding.isDefault,
|
||||||
isEnabled: dto.isEnabled ?? existingBinding.isEnabled
|
isEnabled: dto.isEnabled ?? existingBinding.isEnabled
|
||||||
};
|
};
|
||||||
@@ -342,6 +350,8 @@ export class AiService {
|
|||||||
sourceId: binding.id,
|
sourceId: binding.id,
|
||||||
providerName: binding.providerName,
|
providerName: binding.providerName,
|
||||||
model: binding.model,
|
model: binding.model,
|
||||||
|
configId: binding.configId,
|
||||||
|
configName: binding.configName,
|
||||||
endpoint: binding.endpoint,
|
endpoint: binding.endpoint,
|
||||||
apiKey: binding.encryptedApiKey
|
apiKey: binding.encryptedApiKey
|
||||||
};
|
};
|
||||||
@@ -354,6 +364,8 @@ export class AiService {
|
|||||||
sourceId: publicPool.id,
|
sourceId: publicPool.id,
|
||||||
providerName: publicPool.providerName ?? "public-pool",
|
providerName: publicPool.providerName ?? "public-pool",
|
||||||
model: publicPool.model,
|
model: publicPool.model,
|
||||||
|
configId: null,
|
||||||
|
configName: null,
|
||||||
endpoint: publicPool.endpoint,
|
endpoint: publicPool.endpoint,
|
||||||
apiKey: publicPool.encryptedApiKey
|
apiKey: publicPool.encryptedApiKey
|
||||||
};
|
};
|
||||||
@@ -365,6 +377,8 @@ export class AiService {
|
|||||||
channel: binding.channel,
|
channel: binding.channel,
|
||||||
providerName: binding.providerName,
|
providerName: binding.providerName,
|
||||||
model: binding.model,
|
model: binding.model,
|
||||||
|
configId: binding.configId,
|
||||||
|
configName: binding.configName,
|
||||||
endpoint: binding.endpoint,
|
endpoint: binding.endpoint,
|
||||||
isDefault: binding.isDefault,
|
isDefault: binding.isDefault,
|
||||||
isEnabled: binding.isEnabled,
|
isEnabled: binding.isEnabled,
|
||||||
@@ -416,6 +430,29 @@ export class AiService {
|
|||||||
return normalizedValue.length > 0 ? normalizedValue : null;
|
return normalizedValue.length > 0 ? normalizedValue : null;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private normalizeProviderName(value: string | undefined): string {
|
||||||
|
return this.normalizeOptionalString(value) ?? "";
|
||||||
|
}
|
||||||
|
|
||||||
|
private validateBindingInput(dto: UpsertAiProviderBindingDto): void {
|
||||||
|
const providerName = this.normalizeOptionalString(dto.providerName);
|
||||||
|
const configId = this.normalizeOptionalString(dto.configId);
|
||||||
|
const configName = this.normalizeOptionalString(dto.configName);
|
||||||
|
|
||||||
|
if (dto.channel === AiChannel.ASTRBOT) {
|
||||||
|
if (!providerName && !configId && !configName) {
|
||||||
|
throw new BadRequestException(
|
||||||
|
"AstrBot 通道至少需要 providerName、configId、configName 三者之一"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!providerName) {
|
||||||
|
throw new BadRequestException("当前通道必须提供 providerName");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
private maskSecret(secret: string | null): string | null {
|
private maskSecret(secret: string | null): string | null {
|
||||||
if (!secret) {
|
if (!secret) {
|
||||||
return null;
|
return null;
|
||||||
|
|||||||
@@ -6,6 +6,8 @@ export type AiResolvedRouteCandidate = {
|
|||||||
sourceId: string | null;
|
sourceId: string | null;
|
||||||
providerName: string;
|
providerName: string;
|
||||||
model: string | null;
|
model: string | null;
|
||||||
|
configId: string | null;
|
||||||
|
configName: string | null;
|
||||||
endpoint: string | null;
|
endpoint: string | null;
|
||||||
apiKey: string | null;
|
apiKey: string | null;
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -10,15 +10,26 @@ export class UpsertAiProviderBindingDto {
|
|||||||
@IsEnum(AiChannel)
|
@IsEnum(AiChannel)
|
||||||
channel!: AiChannel;
|
channel!: AiChannel;
|
||||||
|
|
||||||
|
@IsOptional()
|
||||||
@IsString()
|
@IsString()
|
||||||
@MinLength(1)
|
@MinLength(1)
|
||||||
providerName!: string;
|
providerName?: string;
|
||||||
|
|
||||||
@IsOptional()
|
@IsOptional()
|
||||||
@IsString()
|
@IsString()
|
||||||
@MinLength(1)
|
@MinLength(1)
|
||||||
model?: string;
|
model?: string;
|
||||||
|
|
||||||
|
@IsOptional()
|
||||||
|
@IsString()
|
||||||
|
@MinLength(1)
|
||||||
|
configId?: string;
|
||||||
|
|
||||||
|
@IsOptional()
|
||||||
|
@IsString()
|
||||||
|
@MinLength(1)
|
||||||
|
configName?: string;
|
||||||
|
|
||||||
@IsOptional()
|
@IsOptional()
|
||||||
@IsUrl(
|
@IsUrl(
|
||||||
{
|
{
|
||||||
|
|||||||
@@ -10,10 +10,13 @@ import {
|
|||||||
@Injectable()
|
@Injectable()
|
||||||
export class AstrbotProvider implements AiChannelExecutor {
|
export class AstrbotProvider implements AiChannelExecutor {
|
||||||
async execute(candidate: AiResolvedRouteCandidate, input: AiChatInput): Promise<AiChatResult> {
|
async execute(candidate: AiResolvedRouteCandidate, input: AiChatInput): Promise<AiChatResult> {
|
||||||
|
const routeLabel =
|
||||||
|
candidate.providerName || candidate.configName || candidate.configId || "astrbot";
|
||||||
|
|
||||||
if (!candidate.endpoint) {
|
if (!candidate.endpoint) {
|
||||||
throw new AiRouteFailureError(
|
throw new AiRouteFailureError(
|
||||||
candidate.channel,
|
candidate.channel,
|
||||||
candidate.providerName,
|
routeLabel,
|
||||||
"MISSING_ENDPOINT",
|
"MISSING_ENDPOINT",
|
||||||
"缺少 AstrBot 服务地址配置"
|
"缺少 AstrBot 服务地址配置"
|
||||||
);
|
);
|
||||||
@@ -22,7 +25,7 @@ export class AstrbotProvider implements AiChannelExecutor {
|
|||||||
if (!candidate.apiKey) {
|
if (!candidate.apiKey) {
|
||||||
throw new AiRouteFailureError(
|
throw new AiRouteFailureError(
|
||||||
candidate.channel,
|
candidate.channel,
|
||||||
candidate.providerName,
|
routeLabel,
|
||||||
"MISSING_API_KEY",
|
"MISSING_API_KEY",
|
||||||
"缺少 AstrBot API Key 配置"
|
"缺少 AstrBot API Key 配置"
|
||||||
);
|
);
|
||||||
@@ -43,6 +46,8 @@ export class AstrbotProvider implements AiChannelExecutor {
|
|||||||
session_id: input.sessionId ?? undefined,
|
session_id: input.sessionId ?? undefined,
|
||||||
message: input.message,
|
message: input.message,
|
||||||
enable_streaming: false,
|
enable_streaming: false,
|
||||||
|
config_id: candidate.configId ?? undefined,
|
||||||
|
config_name: candidate.configName ?? undefined,
|
||||||
selected_provider: candidate.providerName || undefined,
|
selected_provider: candidate.providerName || undefined,
|
||||||
selected_model: candidate.model ?? undefined
|
selected_model: candidate.model ?? undefined
|
||||||
}),
|
}),
|
||||||
@@ -51,7 +56,7 @@ export class AstrbotProvider implements AiChannelExecutor {
|
|||||||
} catch (error) {
|
} catch (error) {
|
||||||
throw new AiRouteFailureError(
|
throw new AiRouteFailureError(
|
||||||
candidate.channel,
|
candidate.channel,
|
||||||
candidate.providerName,
|
routeLabel,
|
||||||
"UPSTREAM_UNREACHABLE",
|
"UPSTREAM_UNREACHABLE",
|
||||||
this.toErrorMessage(error, "AstrBot 服务请求失败")
|
this.toErrorMessage(error, "AstrBot 服务请求失败")
|
||||||
);
|
);
|
||||||
@@ -61,7 +66,7 @@ export class AstrbotProvider implements AiChannelExecutor {
|
|||||||
const rawText = await response.text();
|
const rawText = await response.text();
|
||||||
throw new AiRouteFailureError(
|
throw new AiRouteFailureError(
|
||||||
candidate.channel,
|
candidate.channel,
|
||||||
candidate.providerName,
|
routeLabel,
|
||||||
`UPSTREAM_HTTP_${response.status}`,
|
`UPSTREAM_HTTP_${response.status}`,
|
||||||
this.extractHttpErrorMessage(rawText, response.status)
|
this.extractHttpErrorMessage(rawText, response.status)
|
||||||
);
|
);
|
||||||
@@ -81,7 +86,7 @@ export class AstrbotProvider implements AiChannelExecutor {
|
|||||||
if (type === "error") {
|
if (type === "error") {
|
||||||
throw new AiRouteFailureError(
|
throw new AiRouteFailureError(
|
||||||
candidate.channel,
|
candidate.channel,
|
||||||
candidate.providerName,
|
routeLabel,
|
||||||
this.readString(event["code"]) ?? "ASTRBOT_ERROR",
|
this.readString(event["code"]) ?? "ASTRBOT_ERROR",
|
||||||
this.readString(event["data"]) ?? "AstrBot 返回错误"
|
this.readString(event["data"]) ?? "AstrBot 返回错误"
|
||||||
);
|
);
|
||||||
@@ -116,7 +121,7 @@ export class AstrbotProvider implements AiChannelExecutor {
|
|||||||
if (!content.trim()) {
|
if (!content.trim()) {
|
||||||
throw new AiRouteFailureError(
|
throw new AiRouteFailureError(
|
||||||
candidate.channel,
|
candidate.channel,
|
||||||
candidate.providerName,
|
routeLabel,
|
||||||
"EMPTY_RESPONSE",
|
"EMPTY_RESPONSE",
|
||||||
"AstrBot 没有返回有效内容"
|
"AstrBot 没有返回有效内容"
|
||||||
);
|
);
|
||||||
@@ -124,7 +129,7 @@ export class AstrbotProvider implements AiChannelExecutor {
|
|||||||
|
|
||||||
return {
|
return {
|
||||||
channel: candidate.channel,
|
channel: candidate.channel,
|
||||||
providerName: candidate.providerName,
|
providerName: routeLabel,
|
||||||
model: candidate.model,
|
model: candidate.model,
|
||||||
content,
|
content,
|
||||||
sessionId,
|
sessionId,
|
||||||
|
|||||||
@@ -5,7 +5,11 @@ import { AiChannel, AiProviderBinding, AiPublicPoolConfig } from "../generated/p
|
|||||||
import { AiController } from "../src/ai/ai.controller";
|
import { AiController } from "../src/ai/ai.controller";
|
||||||
import { AiProviderRegistryService } from "../src/ai/ai-provider-registry.service";
|
import { AiProviderRegistryService } from "../src/ai/ai-provider-registry.service";
|
||||||
import { AiService } from "../src/ai/ai.service";
|
import { AiService } from "../src/ai/ai.service";
|
||||||
import { AiChannelExecutor, AiRouteFailureError } from "../src/ai/ai.types";
|
import {
|
||||||
|
AiChannelExecutor,
|
||||||
|
AiResolvedRouteCandidate,
|
||||||
|
AiRouteFailureError
|
||||||
|
} from "../src/ai/ai.types";
|
||||||
import { PrismaService } from "../src/prisma/prisma.service";
|
import { PrismaService } from "../src/prisma/prisma.service";
|
||||||
|
|
||||||
class InMemoryAiPrismaService {
|
class InMemoryAiPrismaService {
|
||||||
@@ -65,6 +69,8 @@ class InMemoryAiPrismaService {
|
|||||||
channel: AiChannel;
|
channel: AiChannel;
|
||||||
providerName: string;
|
providerName: string;
|
||||||
model: string | null;
|
model: string | null;
|
||||||
|
configId: string | null;
|
||||||
|
configName: string | null;
|
||||||
endpoint: string | null;
|
endpoint: string | null;
|
||||||
encryptedApiKey: string | null;
|
encryptedApiKey: string | null;
|
||||||
isDefault: boolean;
|
isDefault: boolean;
|
||||||
@@ -78,6 +84,8 @@ class InMemoryAiPrismaService {
|
|||||||
channel: args.data.channel,
|
channel: args.data.channel,
|
||||||
providerName: args.data.providerName,
|
providerName: args.data.providerName,
|
||||||
model: args.data.model,
|
model: args.data.model,
|
||||||
|
configId: args.data.configId,
|
||||||
|
configName: args.data.configName,
|
||||||
encryptedApiKey: args.data.encryptedApiKey,
|
encryptedApiKey: args.data.encryptedApiKey,
|
||||||
endpoint: args.data.endpoint,
|
endpoint: args.data.endpoint,
|
||||||
isDefault: args.data.isDefault,
|
isDefault: args.data.isDefault,
|
||||||
@@ -189,12 +197,12 @@ class StaticExecutor implements AiChannelExecutor {
|
|||||||
}
|
}
|
||||||
) {}
|
) {}
|
||||||
|
|
||||||
async execute(candidate: { channel: AiChannel; providerName: string; model: string | null }) {
|
async execute(candidate: AiResolvedRouteCandidate) {
|
||||||
const result = this.resolver(candidate.channel);
|
const result = this.resolver(candidate.channel);
|
||||||
if (result.code) {
|
if (result.code) {
|
||||||
throw new AiRouteFailureError(
|
throw new AiRouteFailureError(
|
||||||
candidate.channel,
|
candidate.channel,
|
||||||
candidate.providerName,
|
candidate.providerName || candidate.configName || candidate.configId || "unknown",
|
||||||
result.code,
|
result.code,
|
||||||
result.message ?? "执行失败"
|
result.message ?? "执行失败"
|
||||||
);
|
);
|
||||||
@@ -202,7 +210,7 @@ class StaticExecutor implements AiChannelExecutor {
|
|||||||
|
|
||||||
return {
|
return {
|
||||||
channel: candidate.channel,
|
channel: candidate.channel,
|
||||||
providerName: candidate.providerName,
|
providerName: candidate.providerName || candidate.configName || candidate.configId || "",
|
||||||
model: candidate.model,
|
model: candidate.model,
|
||||||
content: result.content ?? "",
|
content: result.content ?? "",
|
||||||
sessionId: "session_ai",
|
sessionId: "session_ai",
|
||||||
@@ -273,6 +281,7 @@ describe("AiController (integration)", () => {
|
|||||||
channel: AiChannel.ASTRBOT,
|
channel: AiChannel.ASTRBOT,
|
||||||
providerName: "astrbot-main",
|
providerName: "astrbot-main",
|
||||||
model: "deepseek-chat",
|
model: "deepseek-chat",
|
||||||
|
configId: "default",
|
||||||
endpoint: "http://127.0.0.1:6185",
|
endpoint: "http://127.0.0.1:6185",
|
||||||
apiKey: "abk_secret_1234",
|
apiKey: "abk_secret_1234",
|
||||||
isDefault: true,
|
isDefault: true,
|
||||||
@@ -295,6 +304,8 @@ describe("AiController (integration)", () => {
|
|||||||
channel: AiChannel.ASTRBOT,
|
channel: AiChannel.ASTRBOT,
|
||||||
providerName: "astrbot-main",
|
providerName: "astrbot-main",
|
||||||
model: "deepseek-chat",
|
model: "deepseek-chat",
|
||||||
|
configId: "default",
|
||||||
|
configName: null,
|
||||||
hasApiKey: true,
|
hasApiKey: true,
|
||||||
maskedApiKey: "abk_***34",
|
maskedApiKey: "abk_***34",
|
||||||
isDefault: true
|
isDefault: true
|
||||||
@@ -308,6 +319,8 @@ describe("AiController (integration)", () => {
|
|||||||
channel: AiChannel.USER_KEY,
|
channel: AiChannel.USER_KEY,
|
||||||
providerName: "openai",
|
providerName: "openai",
|
||||||
model: "gpt-4o-mini",
|
model: "gpt-4o-mini",
|
||||||
|
configId: null,
|
||||||
|
configName: null,
|
||||||
encryptedApiKey: "sk-user",
|
encryptedApiKey: "sk-user",
|
||||||
endpoint: "https://api.example.com",
|
endpoint: "https://api.example.com",
|
||||||
isDefault: true,
|
isDefault: true,
|
||||||
@@ -317,8 +330,10 @@ describe("AiController (integration)", () => {
|
|||||||
id: "binding_astrbot",
|
id: "binding_astrbot",
|
||||||
userId: "user_1",
|
userId: "user_1",
|
||||||
channel: AiChannel.ASTRBOT,
|
channel: AiChannel.ASTRBOT,
|
||||||
providerName: "astrbot-main",
|
providerName: "",
|
||||||
model: "deepseek-chat",
|
model: null,
|
||||||
|
configId: "default",
|
||||||
|
configName: null,
|
||||||
encryptedApiKey: "abk_astrbot",
|
encryptedApiKey: "abk_astrbot",
|
||||||
endpoint: "http://127.0.0.1:6185",
|
endpoint: "http://127.0.0.1:6185",
|
||||||
isDefault: true,
|
isDefault: true,
|
||||||
@@ -346,8 +361,8 @@ describe("AiController (integration)", () => {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
channel: AiChannel.ASTRBOT,
|
channel: AiChannel.ASTRBOT,
|
||||||
providerName: "astrbot-main",
|
providerName: "default",
|
||||||
model: "deepseek-chat",
|
model: null,
|
||||||
status: "success",
|
status: "success",
|
||||||
reasonCode: null,
|
reasonCode: null,
|
||||||
reasonMessage: null
|
reasonMessage: null
|
||||||
@@ -355,6 +370,28 @@ describe("AiController (integration)", () => {
|
|||||||
]);
|
]);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
it("should allow astrbot binding with config id only", async () => {
|
||||||
|
const response = await request(app.getHttpServer())
|
||||||
|
.post("/ai/bindings")
|
||||||
|
.set("x-user-id", "user_1")
|
||||||
|
.send({
|
||||||
|
channel: AiChannel.ASTRBOT,
|
||||||
|
configId: "default",
|
||||||
|
endpoint: "http://127.0.0.1:6185",
|
||||||
|
apiKey: "abk_secret_1234",
|
||||||
|
isDefault: true,
|
||||||
|
isEnabled: true
|
||||||
|
})
|
||||||
|
.expect(201);
|
||||||
|
|
||||||
|
expect(response.body).toMatchObject({
|
||||||
|
channel: AiChannel.ASTRBOT,
|
||||||
|
providerName: "",
|
||||||
|
configId: "default",
|
||||||
|
configName: null
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
it("should return skipped attempts when no channel is available", async () => {
|
it("should return skipped attempts when no channel is available", async () => {
|
||||||
const response = await request(app.getHttpServer())
|
const response = await request(app.getHttpServer())
|
||||||
.post("/ai/chat")
|
.post("/ai/chat")
|
||||||
|
|||||||
@@ -59,6 +59,8 @@ describe("AstrbotProvider", () => {
|
|||||||
sourceId: "binding_1",
|
sourceId: "binding_1",
|
||||||
providerName: "",
|
providerName: "",
|
||||||
model: null,
|
model: null,
|
||||||
|
configId: "default",
|
||||||
|
configName: null,
|
||||||
endpoint: "http://127.0.0.1:6185",
|
endpoint: "http://127.0.0.1:6185",
|
||||||
apiKey: "abk_test"
|
apiKey: "abk_test"
|
||||||
},
|
},
|
||||||
|
|||||||
Reference in New Issue
Block a user