feat(api-ai): support astrbot config selection
This commit is contained in:
@@ -22,6 +22,8 @@ type AiBindingSummary = {
|
||||
channel: AiChannel;
|
||||
providerName: string;
|
||||
model: string | null;
|
||||
configId: string | null;
|
||||
configName: string | null;
|
||||
endpoint: string | null;
|
||||
isDefault: boolean;
|
||||
isEnabled: boolean;
|
||||
@@ -105,6 +107,8 @@ export class AiService {
|
||||
throw new BadRequestException("公共 AI 通道只能由管理员配置");
|
||||
}
|
||||
|
||||
this.validateBindingInput(dto);
|
||||
|
||||
const result = await this.prismaService.$transaction(async (tx) => {
|
||||
if (dto.isDefault) {
|
||||
const where: Prisma.AiProviderBindingWhereInput = {
|
||||
@@ -131,8 +135,10 @@ export class AiService {
|
||||
data: {
|
||||
userId,
|
||||
channel: dto.channel,
|
||||
providerName: dto.providerName.trim(),
|
||||
providerName: this.normalizeProviderName(dto.providerName),
|
||||
model: this.normalizeOptionalString(dto.model),
|
||||
configId: this.normalizeOptionalString(dto.configId),
|
||||
configName: this.normalizeOptionalString(dto.configName),
|
||||
endpoint: this.normalizeOptionalString(dto.endpoint),
|
||||
encryptedApiKey: this.normalizeOptionalString(dto.apiKey),
|
||||
isDefault: dto.isDefault ?? false,
|
||||
@@ -154,8 +160,10 @@ export class AiService {
|
||||
|
||||
const updateData: Prisma.AiProviderBindingUpdateInput = {
|
||||
channel: dto.channel,
|
||||
providerName: dto.providerName.trim(),
|
||||
providerName: this.normalizeProviderName(dto.providerName),
|
||||
model: this.normalizeOptionalString(dto.model),
|
||||
configId: this.normalizeOptionalString(dto.configId),
|
||||
configName: this.normalizeOptionalString(dto.configName),
|
||||
isDefault: dto.isDefault ?? existingBinding.isDefault,
|
||||
isEnabled: dto.isEnabled ?? existingBinding.isEnabled
|
||||
};
|
||||
@@ -342,6 +350,8 @@ export class AiService {
|
||||
sourceId: binding.id,
|
||||
providerName: binding.providerName,
|
||||
model: binding.model,
|
||||
configId: binding.configId,
|
||||
configName: binding.configName,
|
||||
endpoint: binding.endpoint,
|
||||
apiKey: binding.encryptedApiKey
|
||||
};
|
||||
@@ -354,6 +364,8 @@ export class AiService {
|
||||
sourceId: publicPool.id,
|
||||
providerName: publicPool.providerName ?? "public-pool",
|
||||
model: publicPool.model,
|
||||
configId: null,
|
||||
configName: null,
|
||||
endpoint: publicPool.endpoint,
|
||||
apiKey: publicPool.encryptedApiKey
|
||||
};
|
||||
@@ -365,6 +377,8 @@ export class AiService {
|
||||
channel: binding.channel,
|
||||
providerName: binding.providerName,
|
||||
model: binding.model,
|
||||
configId: binding.configId,
|
||||
configName: binding.configName,
|
||||
endpoint: binding.endpoint,
|
||||
isDefault: binding.isDefault,
|
||||
isEnabled: binding.isEnabled,
|
||||
@@ -416,6 +430,29 @@ export class AiService {
|
||||
return normalizedValue.length > 0 ? normalizedValue : null;
|
||||
}
|
||||
|
||||
private normalizeProviderName(value: string | undefined): string {
|
||||
return this.normalizeOptionalString(value) ?? "";
|
||||
}
|
||||
|
||||
private validateBindingInput(dto: UpsertAiProviderBindingDto): void {
|
||||
const providerName = this.normalizeOptionalString(dto.providerName);
|
||||
const configId = this.normalizeOptionalString(dto.configId);
|
||||
const configName = this.normalizeOptionalString(dto.configName);
|
||||
|
||||
if (dto.channel === AiChannel.ASTRBOT) {
|
||||
if (!providerName && !configId && !configName) {
|
||||
throw new BadRequestException(
|
||||
"AstrBot 通道至少需要 providerName、configId、configName 三者之一"
|
||||
);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if (!providerName) {
|
||||
throw new BadRequestException("当前通道必须提供 providerName");
|
||||
}
|
||||
}
|
||||
|
||||
private maskSecret(secret: string | null): string | null {
|
||||
if (!secret) {
|
||||
return null;
|
||||
|
||||
@@ -6,6 +6,8 @@ export type AiResolvedRouteCandidate = {
|
||||
sourceId: string | null;
|
||||
providerName: string;
|
||||
model: string | null;
|
||||
configId: string | null;
|
||||
configName: string | null;
|
||||
endpoint: string | null;
|
||||
apiKey: string | null;
|
||||
};
|
||||
|
||||
@@ -10,15 +10,26 @@ export class UpsertAiProviderBindingDto {
|
||||
@IsEnum(AiChannel)
|
||||
channel!: AiChannel;
|
||||
|
||||
@IsOptional()
|
||||
@IsString()
|
||||
@MinLength(1)
|
||||
providerName!: string;
|
||||
providerName?: string;
|
||||
|
||||
@IsOptional()
|
||||
@IsString()
|
||||
@MinLength(1)
|
||||
model?: string;
|
||||
|
||||
@IsOptional()
|
||||
@IsString()
|
||||
@MinLength(1)
|
||||
configId?: string;
|
||||
|
||||
@IsOptional()
|
||||
@IsString()
|
||||
@MinLength(1)
|
||||
configName?: string;
|
||||
|
||||
@IsOptional()
|
||||
@IsUrl(
|
||||
{
|
||||
|
||||
@@ -10,10 +10,13 @@ import {
|
||||
@Injectable()
|
||||
export class AstrbotProvider implements AiChannelExecutor {
|
||||
async execute(candidate: AiResolvedRouteCandidate, input: AiChatInput): Promise<AiChatResult> {
|
||||
const routeLabel =
|
||||
candidate.providerName || candidate.configName || candidate.configId || "astrbot";
|
||||
|
||||
if (!candidate.endpoint) {
|
||||
throw new AiRouteFailureError(
|
||||
candidate.channel,
|
||||
candidate.providerName,
|
||||
routeLabel,
|
||||
"MISSING_ENDPOINT",
|
||||
"缺少 AstrBot 服务地址配置"
|
||||
);
|
||||
@@ -22,7 +25,7 @@ export class AstrbotProvider implements AiChannelExecutor {
|
||||
if (!candidate.apiKey) {
|
||||
throw new AiRouteFailureError(
|
||||
candidate.channel,
|
||||
candidate.providerName,
|
||||
routeLabel,
|
||||
"MISSING_API_KEY",
|
||||
"缺少 AstrBot API Key 配置"
|
||||
);
|
||||
@@ -43,6 +46,8 @@ export class AstrbotProvider implements AiChannelExecutor {
|
||||
session_id: input.sessionId ?? undefined,
|
||||
message: input.message,
|
||||
enable_streaming: false,
|
||||
config_id: candidate.configId ?? undefined,
|
||||
config_name: candidate.configName ?? undefined,
|
||||
selected_provider: candidate.providerName || undefined,
|
||||
selected_model: candidate.model ?? undefined
|
||||
}),
|
||||
@@ -51,7 +56,7 @@ export class AstrbotProvider implements AiChannelExecutor {
|
||||
} catch (error) {
|
||||
throw new AiRouteFailureError(
|
||||
candidate.channel,
|
||||
candidate.providerName,
|
||||
routeLabel,
|
||||
"UPSTREAM_UNREACHABLE",
|
||||
this.toErrorMessage(error, "AstrBot 服务请求失败")
|
||||
);
|
||||
@@ -61,7 +66,7 @@ export class AstrbotProvider implements AiChannelExecutor {
|
||||
const rawText = await response.text();
|
||||
throw new AiRouteFailureError(
|
||||
candidate.channel,
|
||||
candidate.providerName,
|
||||
routeLabel,
|
||||
`UPSTREAM_HTTP_${response.status}`,
|
||||
this.extractHttpErrorMessage(rawText, response.status)
|
||||
);
|
||||
@@ -81,7 +86,7 @@ export class AstrbotProvider implements AiChannelExecutor {
|
||||
if (type === "error") {
|
||||
throw new AiRouteFailureError(
|
||||
candidate.channel,
|
||||
candidate.providerName,
|
||||
routeLabel,
|
||||
this.readString(event["code"]) ?? "ASTRBOT_ERROR",
|
||||
this.readString(event["data"]) ?? "AstrBot 返回错误"
|
||||
);
|
||||
@@ -116,7 +121,7 @@ export class AstrbotProvider implements AiChannelExecutor {
|
||||
if (!content.trim()) {
|
||||
throw new AiRouteFailureError(
|
||||
candidate.channel,
|
||||
candidate.providerName,
|
||||
routeLabel,
|
||||
"EMPTY_RESPONSE",
|
||||
"AstrBot 没有返回有效内容"
|
||||
);
|
||||
@@ -124,7 +129,7 @@ export class AstrbotProvider implements AiChannelExecutor {
|
||||
|
||||
return {
|
||||
channel: candidate.channel,
|
||||
providerName: candidate.providerName,
|
||||
providerName: routeLabel,
|
||||
model: candidate.model,
|
||||
content,
|
||||
sessionId,
|
||||
|
||||
Reference in New Issue
Block a user