feat(api-ai): add provider registry and routing fallback
This commit is contained in:
@@ -0,0 +1,28 @@
|
|||||||
|
import { Injectable } from "@nestjs/common";
|
||||||
|
import { AiChannel } from "../../generated/prisma/client";
|
||||||
|
import { AstrbotProvider } from "./providers/astrbot.provider";
|
||||||
|
import { OpenAiCompatibleProvider } from "./providers/openai-compatible.provider";
|
||||||
|
import { AiChannelExecutor } from "./ai.types";
|
||||||
|
|
||||||
|
@Injectable()
|
||||||
|
export class AiProviderRegistryService {
|
||||||
|
private readonly executors = new Map<AiChannel, AiChannelExecutor>();
|
||||||
|
|
||||||
|
constructor(
|
||||||
|
openAiCompatibleProvider: OpenAiCompatibleProvider,
|
||||||
|
astrbotProvider: AstrbotProvider
|
||||||
|
) {
|
||||||
|
this.executors.set(AiChannel.USER_KEY, openAiCompatibleProvider);
|
||||||
|
this.executors.set(AiChannel.PUBLIC_POOL, openAiCompatibleProvider);
|
||||||
|
this.executors.set(AiChannel.ASTRBOT, astrbotProvider);
|
||||||
|
}
|
||||||
|
|
||||||
|
getExecutor(channel: AiChannel): AiChannelExecutor {
|
||||||
|
const executor = this.executors.get(channel);
|
||||||
|
if (!executor) {
|
||||||
|
throw new Error(`未找到 ${channel} 对应的 AI 通道执行器`);
|
||||||
|
}
|
||||||
|
|
||||||
|
return executor;
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,41 @@
|
|||||||
|
import { Body, Controller, Get, Headers, Post, UnauthorizedException } from "@nestjs/common";
|
||||||
|
import { AiChatDto } from "./dto/ai-chat.dto";
|
||||||
|
import { UpsertAiProviderBindingDto } from "./dto/upsert-ai-provider-binding.dto";
|
||||||
|
import { AiChatResponse, AiService, ListAiBindingsResponse } from "./ai.service";
|
||||||
|
|
||||||
|
@Controller("ai")
|
||||||
|
export class AiController {
|
||||||
|
constructor(private readonly aiService: AiService) {}
|
||||||
|
|
||||||
|
@Get("bindings")
|
||||||
|
async listBindings(
|
||||||
|
@Headers("x-user-id") userIdHeader: string | string[] | undefined
|
||||||
|
): Promise<ListAiBindingsResponse> {
|
||||||
|
return this.aiService.listBindings(this.resolveUserId(userIdHeader));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Post("bindings")
|
||||||
|
async upsertBinding(
|
||||||
|
@Headers("x-user-id") userIdHeader: string | string[] | undefined,
|
||||||
|
@Body() body: UpsertAiProviderBindingDto
|
||||||
|
) {
|
||||||
|
return this.aiService.upsertBinding(this.resolveUserId(userIdHeader), body);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Post("chat")
|
||||||
|
async chat(
|
||||||
|
@Headers("x-user-id") userIdHeader: string | string[] | undefined,
|
||||||
|
@Body() body: AiChatDto
|
||||||
|
): Promise<AiChatResponse> {
|
||||||
|
return this.aiService.chat(this.resolveUserId(userIdHeader), body);
|
||||||
|
}
|
||||||
|
|
||||||
|
private resolveUserId(userIdHeader: string | string[] | undefined): string {
|
||||||
|
const userId = Array.isArray(userIdHeader) ? userIdHeader[0] : userIdHeader;
|
||||||
|
if (!userId) {
|
||||||
|
throw new UnauthorizedException("缺少用户上下文");
|
||||||
|
}
|
||||||
|
|
||||||
|
return userId;
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,14 @@
|
|||||||
|
import { Module } from "@nestjs/common";
|
||||||
|
import { PrismaModule } from "../prisma/prisma.module";
|
||||||
|
import { AiController } from "./ai.controller";
|
||||||
|
import { AiProviderRegistryService } from "./ai-provider-registry.service";
|
||||||
|
import { AiService } from "./ai.service";
|
||||||
|
import { AstrbotProvider } from "./providers/astrbot.provider";
|
||||||
|
import { OpenAiCompatibleProvider } from "./providers/openai-compatible.provider";
|
||||||
|
|
||||||
|
@Module({
|
||||||
|
imports: [PrismaModule],
|
||||||
|
controllers: [AiController],
|
||||||
|
providers: [AiService, AiProviderRegistryService, OpenAiCompatibleProvider, AstrbotProvider]
|
||||||
|
})
|
||||||
|
export class AiModule {}
|
||||||
@@ -0,0 +1,430 @@
|
|||||||
|
import {
|
||||||
|
BadGatewayException,
|
||||||
|
BadRequestException,
|
||||||
|
Injectable,
|
||||||
|
Logger,
|
||||||
|
NotFoundException
|
||||||
|
} from "@nestjs/common";
|
||||||
|
import {
|
||||||
|
AiChannel,
|
||||||
|
AiProviderBinding,
|
||||||
|
AiPublicPoolConfig,
|
||||||
|
Prisma
|
||||||
|
} from "../../generated/prisma/client";
|
||||||
|
import { PrismaService } from "../prisma/prisma.service";
|
||||||
|
import { AiProviderRegistryService } from "./ai-provider-registry.service";
|
||||||
|
import { AiChatDto } from "./dto/ai-chat.dto";
|
||||||
|
import { UpsertAiProviderBindingDto } from "./dto/upsert-ai-provider-binding.dto";
|
||||||
|
import { AiResolvedRouteCandidate, AiRouteAttempt, AiRouteFailureError } from "./ai.types";
|
||||||
|
|
||||||
|
type AiBindingSummary = {
|
||||||
|
id: string;
|
||||||
|
channel: AiChannel;
|
||||||
|
providerName: string;
|
||||||
|
model: string | null;
|
||||||
|
endpoint: string | null;
|
||||||
|
isDefault: boolean;
|
||||||
|
isEnabled: boolean;
|
||||||
|
hasApiKey: boolean;
|
||||||
|
maskedApiKey: string | null;
|
||||||
|
updatedAt: string;
|
||||||
|
};
|
||||||
|
|
||||||
|
type AiRoutePlanEntry =
|
||||||
|
| {
|
||||||
|
kind: "candidate";
|
||||||
|
candidate: AiResolvedRouteCandidate;
|
||||||
|
}
|
||||||
|
| {
|
||||||
|
kind: "skip";
|
||||||
|
attempt: AiRouteAttempt;
|
||||||
|
};
|
||||||
|
|
||||||
|
export type ListAiBindingsResponse = {
|
||||||
|
routeOrder: AiChannel[];
|
||||||
|
bindings: AiBindingSummary[];
|
||||||
|
publicPool: {
|
||||||
|
enabled: boolean;
|
||||||
|
providerName: string | null;
|
||||||
|
model: string | null;
|
||||||
|
endpoint: string | null;
|
||||||
|
hasApiKey: boolean;
|
||||||
|
} | null;
|
||||||
|
};
|
||||||
|
|
||||||
|
export type AiChatResponse = {
|
||||||
|
channel: AiChannel;
|
||||||
|
providerName: string;
|
||||||
|
model: string | null;
|
||||||
|
content: string;
|
||||||
|
sessionId: string | null;
|
||||||
|
attempts: AiRouteAttempt[];
|
||||||
|
};
|
||||||
|
|
||||||
|
@Injectable()
|
||||||
|
export class AiService {
|
||||||
|
private readonly logger = new Logger(AiService.name);
|
||||||
|
|
||||||
|
constructor(
|
||||||
|
private readonly prismaService: PrismaService,
|
||||||
|
private readonly aiProviderRegistryService: AiProviderRegistryService
|
||||||
|
) {}
|
||||||
|
|
||||||
|
async listBindings(userId: string): Promise<ListAiBindingsResponse> {
|
||||||
|
const [bindings, publicPool] = await Promise.all([
|
||||||
|
this.prismaService.aiProviderBinding.findMany({
|
||||||
|
where: {
|
||||||
|
userId
|
||||||
|
},
|
||||||
|
orderBy: [{ channel: "asc" }, { isDefault: "desc" }, { updatedAt: "desc" }]
|
||||||
|
}),
|
||||||
|
this.prismaService.aiPublicPoolConfig.findFirst({
|
||||||
|
orderBy: {
|
||||||
|
updatedAt: "desc"
|
||||||
|
}
|
||||||
|
})
|
||||||
|
]);
|
||||||
|
|
||||||
|
return {
|
||||||
|
routeOrder: [AiChannel.USER_KEY, AiChannel.ASTRBOT, AiChannel.PUBLIC_POOL],
|
||||||
|
bindings: bindings.map((binding) => this.serializeBinding(binding)),
|
||||||
|
publicPool: publicPool
|
||||||
|
? {
|
||||||
|
enabled: publicPool.enabled,
|
||||||
|
providerName: publicPool.providerName,
|
||||||
|
model: publicPool.model,
|
||||||
|
endpoint: publicPool.endpoint,
|
||||||
|
hasApiKey: Boolean(publicPool.encryptedApiKey)
|
||||||
|
}
|
||||||
|
: null
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
async upsertBinding(userId: string, dto: UpsertAiProviderBindingDto): Promise<AiBindingSummary> {
|
||||||
|
if (dto.channel === AiChannel.PUBLIC_POOL) {
|
||||||
|
throw new BadRequestException("公共 AI 通道只能由管理员配置");
|
||||||
|
}
|
||||||
|
|
||||||
|
const result = await this.prismaService.$transaction(async (tx) => {
|
||||||
|
if (dto.isDefault) {
|
||||||
|
const where: Prisma.AiProviderBindingWhereInput = {
|
||||||
|
userId,
|
||||||
|
channel: dto.channel
|
||||||
|
};
|
||||||
|
|
||||||
|
if (dto.id) {
|
||||||
|
where.id = {
|
||||||
|
not: dto.id
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
await tx.aiProviderBinding.updateMany({
|
||||||
|
where,
|
||||||
|
data: {
|
||||||
|
isDefault: false
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!dto.id) {
|
||||||
|
return tx.aiProviderBinding.create({
|
||||||
|
data: {
|
||||||
|
userId,
|
||||||
|
channel: dto.channel,
|
||||||
|
providerName: dto.providerName.trim(),
|
||||||
|
model: this.normalizeOptionalString(dto.model),
|
||||||
|
endpoint: this.normalizeOptionalString(dto.endpoint),
|
||||||
|
encryptedApiKey: this.normalizeOptionalString(dto.apiKey),
|
||||||
|
isDefault: dto.isDefault ?? false,
|
||||||
|
isEnabled: dto.isEnabled ?? true
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
const existingBinding = await tx.aiProviderBinding.findFirst({
|
||||||
|
where: {
|
||||||
|
id: dto.id,
|
||||||
|
userId
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
if (!existingBinding) {
|
||||||
|
throw new NotFoundException("AI 通道配置不存在");
|
||||||
|
}
|
||||||
|
|
||||||
|
const updateData: Prisma.AiProviderBindingUpdateInput = {
|
||||||
|
channel: dto.channel,
|
||||||
|
providerName: dto.providerName.trim(),
|
||||||
|
model: this.normalizeOptionalString(dto.model),
|
||||||
|
isDefault: dto.isDefault ?? existingBinding.isDefault,
|
||||||
|
isEnabled: dto.isEnabled ?? existingBinding.isEnabled
|
||||||
|
};
|
||||||
|
|
||||||
|
if (dto.endpoint !== undefined) {
|
||||||
|
updateData.endpoint = this.normalizeOptionalString(dto.endpoint);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (dto.apiKey !== undefined) {
|
||||||
|
updateData.encryptedApiKey = this.normalizeOptionalString(dto.apiKey);
|
||||||
|
}
|
||||||
|
|
||||||
|
return tx.aiProviderBinding.update({
|
||||||
|
where: {
|
||||||
|
id: dto.id
|
||||||
|
},
|
||||||
|
data: updateData
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
return this.serializeBinding(result);
|
||||||
|
}
|
||||||
|
|
||||||
|
async chat(userId: string, dto: AiChatDto): Promise<AiChatResponse> {
|
||||||
|
const attempts: AiRouteAttempt[] = [];
|
||||||
|
const plan = await this.buildRoutePlan(userId, dto.bindingId ?? null);
|
||||||
|
|
||||||
|
for (const entry of plan) {
|
||||||
|
if (entry.kind === "skip") {
|
||||||
|
attempts.push(entry.attempt);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
const executor = this.aiProviderRegistryService.getExecutor(entry.candidate.channel);
|
||||||
|
|
||||||
|
try {
|
||||||
|
const result = await executor.execute(entry.candidate, {
|
||||||
|
userId,
|
||||||
|
message: dto.message,
|
||||||
|
sessionId: dto.sessionId ?? null
|
||||||
|
});
|
||||||
|
|
||||||
|
attempts.push({
|
||||||
|
channel: result.channel,
|
||||||
|
providerName: result.providerName,
|
||||||
|
model: result.model,
|
||||||
|
status: "success",
|
||||||
|
reasonCode: null,
|
||||||
|
reasonMessage: null
|
||||||
|
});
|
||||||
|
|
||||||
|
return {
|
||||||
|
channel: result.channel,
|
||||||
|
providerName: result.providerName,
|
||||||
|
model: result.model,
|
||||||
|
content: result.content,
|
||||||
|
sessionId: result.sessionId,
|
||||||
|
attempts
|
||||||
|
};
|
||||||
|
} catch (error) {
|
||||||
|
const failureAttempt = this.toFailureAttempt(entry.candidate, error);
|
||||||
|
attempts.push(failureAttempt);
|
||||||
|
this.logger.warn(
|
||||||
|
`AI 通道降级:channel=${failureAttempt.channel} provider=${failureAttempt.providerName ?? "unknown"} code=${failureAttempt.reasonCode ?? "UNKNOWN"} message=${failureAttempt.reasonMessage ?? "unknown"}`
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
throw new BadGatewayException({
|
||||||
|
message: "当前没有可用的 AI 通道,请稍后重试",
|
||||||
|
attempts
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
private async buildRoutePlan(
|
||||||
|
userId: string,
|
||||||
|
bindingId: string | null
|
||||||
|
): Promise<AiRoutePlanEntry[]> {
|
||||||
|
const plan: AiRoutePlanEntry[] = [];
|
||||||
|
const consumedChannels = new Set<AiChannel>();
|
||||||
|
|
||||||
|
if (bindingId) {
|
||||||
|
const pinnedBinding = await this.prismaService.aiProviderBinding.findFirst({
|
||||||
|
where: {
|
||||||
|
id: bindingId,
|
||||||
|
userId,
|
||||||
|
isEnabled: true
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
if (!pinnedBinding) {
|
||||||
|
throw new NotFoundException("指定的 AI 通道配置不存在或已禁用");
|
||||||
|
}
|
||||||
|
|
||||||
|
plan.push({
|
||||||
|
kind: "candidate",
|
||||||
|
candidate: this.toBindingCandidate(pinnedBinding)
|
||||||
|
});
|
||||||
|
consumedChannels.add(pinnedBinding.channel);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (const channel of [AiChannel.USER_KEY, AiChannel.ASTRBOT]) {
|
||||||
|
if (consumedChannels.has(channel)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
const binding = await this.findPreferredBinding(userId, channel);
|
||||||
|
if (binding) {
|
||||||
|
plan.push({
|
||||||
|
kind: "candidate",
|
||||||
|
candidate: this.toBindingCandidate(binding)
|
||||||
|
});
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
plan.push({
|
||||||
|
kind: "skip",
|
||||||
|
attempt: {
|
||||||
|
channel,
|
||||||
|
providerName: null,
|
||||||
|
model: null,
|
||||||
|
status: "skipped",
|
||||||
|
reasonCode: "CHANNEL_NOT_CONFIGURED",
|
||||||
|
reasonMessage:
|
||||||
|
channel === AiChannel.USER_KEY
|
||||||
|
? "当前用户未配置可用的自备 Key 通道"
|
||||||
|
: "当前用户未配置可用的 AstrBot 通道"
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
const publicPool = await this.findEnabledPublicPool();
|
||||||
|
if (publicPool) {
|
||||||
|
plan.push({
|
||||||
|
kind: "candidate",
|
||||||
|
candidate: this.toPublicPoolCandidate(publicPool)
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
plan.push({
|
||||||
|
kind: "skip",
|
||||||
|
attempt: {
|
||||||
|
channel: AiChannel.PUBLIC_POOL,
|
||||||
|
providerName: null,
|
||||||
|
model: null,
|
||||||
|
status: "skipped",
|
||||||
|
reasonCode: "PUBLIC_POOL_DISABLED",
|
||||||
|
reasonMessage: "公共 AI 通道未开启"
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
return plan;
|
||||||
|
}
|
||||||
|
|
||||||
|
private async findPreferredBinding(
|
||||||
|
userId: string,
|
||||||
|
channel: AiChannel
|
||||||
|
): Promise<AiProviderBinding | null> {
|
||||||
|
return this.prismaService.aiProviderBinding.findFirst({
|
||||||
|
where: {
|
||||||
|
userId,
|
||||||
|
channel,
|
||||||
|
isEnabled: true
|
||||||
|
},
|
||||||
|
orderBy: [{ isDefault: "desc" }, { updatedAt: "desc" }]
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
private async findEnabledPublicPool(): Promise<AiPublicPoolConfig | null> {
|
||||||
|
return this.prismaService.aiPublicPoolConfig.findFirst({
|
||||||
|
where: {
|
||||||
|
enabled: true
|
||||||
|
},
|
||||||
|
orderBy: {
|
||||||
|
updatedAt: "desc"
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
private toBindingCandidate(binding: AiProviderBinding): AiResolvedRouteCandidate {
|
||||||
|
return {
|
||||||
|
channel: binding.channel,
|
||||||
|
source: "binding",
|
||||||
|
sourceId: binding.id,
|
||||||
|
providerName: binding.providerName,
|
||||||
|
model: binding.model,
|
||||||
|
endpoint: binding.endpoint,
|
||||||
|
apiKey: binding.encryptedApiKey
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
private toPublicPoolCandidate(publicPool: AiPublicPoolConfig): AiResolvedRouteCandidate {
|
||||||
|
return {
|
||||||
|
channel: AiChannel.PUBLIC_POOL,
|
||||||
|
source: "public_pool",
|
||||||
|
sourceId: publicPool.id,
|
||||||
|
providerName: publicPool.providerName ?? "public-pool",
|
||||||
|
model: publicPool.model,
|
||||||
|
endpoint: publicPool.endpoint,
|
||||||
|
apiKey: publicPool.encryptedApiKey
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
private serializeBinding(binding: AiProviderBinding): AiBindingSummary {
|
||||||
|
return {
|
||||||
|
id: binding.id,
|
||||||
|
channel: binding.channel,
|
||||||
|
providerName: binding.providerName,
|
||||||
|
model: binding.model,
|
||||||
|
endpoint: binding.endpoint,
|
||||||
|
isDefault: binding.isDefault,
|
||||||
|
isEnabled: binding.isEnabled,
|
||||||
|
hasApiKey: Boolean(binding.encryptedApiKey),
|
||||||
|
maskedApiKey: this.maskSecret(binding.encryptedApiKey),
|
||||||
|
updatedAt: binding.updatedAt.toISOString()
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
private toFailureAttempt(candidate: AiResolvedRouteCandidate, error: unknown): AiRouteAttempt {
|
||||||
|
if (error instanceof AiRouteFailureError) {
|
||||||
|
return {
|
||||||
|
channel: error.channel,
|
||||||
|
providerName: error.providerName,
|
||||||
|
model: candidate.model,
|
||||||
|
status: "failed",
|
||||||
|
reasonCode: error.code,
|
||||||
|
reasonMessage: error.message
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
if (error instanceof Error) {
|
||||||
|
return {
|
||||||
|
channel: candidate.channel,
|
||||||
|
providerName: candidate.providerName,
|
||||||
|
model: candidate.model,
|
||||||
|
status: "failed",
|
||||||
|
reasonCode: "UNKNOWN_ERROR",
|
||||||
|
reasonMessage: error.message
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
channel: candidate.channel,
|
||||||
|
providerName: candidate.providerName,
|
||||||
|
model: candidate.model,
|
||||||
|
status: "failed",
|
||||||
|
reasonCode: "UNKNOWN_ERROR",
|
||||||
|
reasonMessage: "未知错误"
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
private normalizeOptionalString(value: string | undefined): string | null {
|
||||||
|
if (value === undefined) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
const normalizedValue = value.trim();
|
||||||
|
return normalizedValue.length > 0 ? normalizedValue : null;
|
||||||
|
}
|
||||||
|
|
||||||
|
private maskSecret(secret: string | null): string | null {
|
||||||
|
if (!secret) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (secret.length <= 6) {
|
||||||
|
return "*".repeat(secret.length);
|
||||||
|
}
|
||||||
|
|
||||||
|
return `${secret.slice(0, 4)}***${secret.slice(-2)}`;
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,52 @@
|
|||||||
|
import { AiChannel } from "../../generated/prisma/client";
|
||||||
|
|
||||||
|
export type AiResolvedRouteCandidate = {
|
||||||
|
channel: AiChannel;
|
||||||
|
source: "binding" | "public_pool";
|
||||||
|
sourceId: string | null;
|
||||||
|
providerName: string;
|
||||||
|
model: string | null;
|
||||||
|
endpoint: string | null;
|
||||||
|
apiKey: string | null;
|
||||||
|
};
|
||||||
|
|
||||||
|
export type AiChatInput = {
|
||||||
|
userId: string;
|
||||||
|
message: string;
|
||||||
|
sessionId: string | null;
|
||||||
|
};
|
||||||
|
|
||||||
|
export type AiChatResult = {
|
||||||
|
channel: AiChannel;
|
||||||
|
providerName: string;
|
||||||
|
model: string | null;
|
||||||
|
content: string;
|
||||||
|
sessionId: string | null;
|
||||||
|
raw: unknown;
|
||||||
|
};
|
||||||
|
|
||||||
|
export type AiRouteAttempt = {
|
||||||
|
channel: AiChannel;
|
||||||
|
providerName: string | null;
|
||||||
|
model: string | null;
|
||||||
|
status: "skipped" | "failed" | "success";
|
||||||
|
reasonCode: string | null;
|
||||||
|
reasonMessage: string | null;
|
||||||
|
};
|
||||||
|
|
||||||
|
export class AiRouteFailureError extends Error {
|
||||||
|
constructor(
|
||||||
|
public readonly channel: AiChannel,
|
||||||
|
public readonly providerName: string,
|
||||||
|
public readonly code: string,
|
||||||
|
message: string
|
||||||
|
) {
|
||||||
|
super(message);
|
||||||
|
this.name = "AiRouteFailureError";
|
||||||
|
Object.setPrototypeOf(this, new.target.prototype);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface AiChannelExecutor {
|
||||||
|
execute(candidate: AiResolvedRouteCandidate, input: AiChatInput): Promise<AiChatResult>;
|
||||||
|
}
|
||||||
@@ -0,0 +1,17 @@
|
|||||||
|
import { IsOptional, IsString, MinLength } from "class-validator";
|
||||||
|
|
||||||
|
export class AiChatDto {
|
||||||
|
@IsString()
|
||||||
|
@MinLength(1)
|
||||||
|
message!: string;
|
||||||
|
|
||||||
|
@IsOptional()
|
||||||
|
@IsString()
|
||||||
|
@MinLength(1)
|
||||||
|
sessionId?: string;
|
||||||
|
|
||||||
|
@IsOptional()
|
||||||
|
@IsString()
|
||||||
|
@MinLength(1)
|
||||||
|
bindingId?: string;
|
||||||
|
}
|
||||||
@@ -0,0 +1,45 @@
|
|||||||
|
import { AiChannel } from "../../../generated/prisma/client";
|
||||||
|
import { IsBoolean, IsEnum, IsOptional, IsString, IsUrl, MinLength } from "class-validator";
|
||||||
|
|
||||||
|
export class UpsertAiProviderBindingDto {
|
||||||
|
@IsOptional()
|
||||||
|
@IsString()
|
||||||
|
@MinLength(1)
|
||||||
|
id?: string;
|
||||||
|
|
||||||
|
@IsEnum(AiChannel)
|
||||||
|
channel!: AiChannel;
|
||||||
|
|
||||||
|
@IsString()
|
||||||
|
@MinLength(1)
|
||||||
|
providerName!: string;
|
||||||
|
|
||||||
|
@IsOptional()
|
||||||
|
@IsString()
|
||||||
|
@MinLength(1)
|
||||||
|
model?: string;
|
||||||
|
|
||||||
|
@IsOptional()
|
||||||
|
@IsUrl(
|
||||||
|
{
|
||||||
|
require_tld: false
|
||||||
|
},
|
||||||
|
{
|
||||||
|
message: "endpoint 必须是合法的 URL"
|
||||||
|
}
|
||||||
|
)
|
||||||
|
endpoint?: string;
|
||||||
|
|
||||||
|
@IsOptional()
|
||||||
|
@IsString()
|
||||||
|
@MinLength(1)
|
||||||
|
apiKey?: string;
|
||||||
|
|
||||||
|
@IsOptional()
|
||||||
|
@IsBoolean()
|
||||||
|
isDefault?: boolean;
|
||||||
|
|
||||||
|
@IsOptional()
|
||||||
|
@IsBoolean()
|
||||||
|
isEnabled?: boolean;
|
||||||
|
}
|
||||||
@@ -0,0 +1,197 @@
|
|||||||
|
import { Injectable } from "@nestjs/common";
|
||||||
|
import {
|
||||||
|
AiChannelExecutor,
|
||||||
|
AiChatInput,
|
||||||
|
AiChatResult,
|
||||||
|
AiResolvedRouteCandidate,
|
||||||
|
AiRouteFailureError
|
||||||
|
} from "../ai.types";
|
||||||
|
|
||||||
|
@Injectable()
|
||||||
|
export class AstrbotProvider implements AiChannelExecutor {
|
||||||
|
async execute(candidate: AiResolvedRouteCandidate, input: AiChatInput): Promise<AiChatResult> {
|
||||||
|
if (!candidate.endpoint) {
|
||||||
|
throw new AiRouteFailureError(
|
||||||
|
candidate.channel,
|
||||||
|
candidate.providerName,
|
||||||
|
"MISSING_ENDPOINT",
|
||||||
|
"缺少 AstrBot 服务地址配置"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!candidate.apiKey) {
|
||||||
|
throw new AiRouteFailureError(
|
||||||
|
candidate.channel,
|
||||||
|
candidate.providerName,
|
||||||
|
"MISSING_API_KEY",
|
||||||
|
"缺少 AstrBot API Key 配置"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
const requestUrl = this.buildRequestUrl(candidate.endpoint);
|
||||||
|
|
||||||
|
let response: Response;
|
||||||
|
try {
|
||||||
|
response = await fetch(requestUrl, {
|
||||||
|
method: "POST",
|
||||||
|
headers: {
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
Authorization: `Bearer ${candidate.apiKey}`
|
||||||
|
},
|
||||||
|
body: JSON.stringify({
|
||||||
|
username: input.userId,
|
||||||
|
session_id: input.sessionId ?? undefined,
|
||||||
|
message: input.message,
|
||||||
|
enable_streaming: false,
|
||||||
|
selected_provider: candidate.providerName || undefined,
|
||||||
|
selected_model: candidate.model ?? undefined
|
||||||
|
}),
|
||||||
|
signal: AbortSignal.timeout(30000)
|
||||||
|
});
|
||||||
|
} catch (error) {
|
||||||
|
throw new AiRouteFailureError(
|
||||||
|
candidate.channel,
|
||||||
|
candidate.providerName,
|
||||||
|
"UPSTREAM_UNREACHABLE",
|
||||||
|
this.toErrorMessage(error, "AstrBot 服务请求失败")
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
const rawText = await response.text();
|
||||||
|
if (!response.ok) {
|
||||||
|
throw new AiRouteFailureError(
|
||||||
|
candidate.channel,
|
||||||
|
candidate.providerName,
|
||||||
|
`UPSTREAM_HTTP_${response.status}`,
|
||||||
|
this.extractHttpErrorMessage(rawText, response.status)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
const events = this.parseSseEvents(rawText);
|
||||||
|
let content = "";
|
||||||
|
let sessionId = input.sessionId;
|
||||||
|
|
||||||
|
for (const event of events) {
|
||||||
|
const type = this.readString(event["type"]);
|
||||||
|
if (type === "session_id") {
|
||||||
|
sessionId = this.readString(event["session_id"]) ?? sessionId;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (type === "error") {
|
||||||
|
throw new AiRouteFailureError(
|
||||||
|
candidate.channel,
|
||||||
|
candidate.providerName,
|
||||||
|
this.readString(event["code"]) ?? "ASTRBOT_ERROR",
|
||||||
|
this.readString(event["data"]) ?? "AstrBot 返回错误"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (type !== "plain") {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
const chainType = this.readString(event["chain_type"]);
|
||||||
|
if (
|
||||||
|
chainType === "reasoning" ||
|
||||||
|
chainType === "tool_call" ||
|
||||||
|
chainType === "tool_call_result"
|
||||||
|
) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
const data = this.readString(event["data"]);
|
||||||
|
if (!data) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (event["streaming"] === true) {
|
||||||
|
content += data;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
content = data;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!content.trim()) {
|
||||||
|
throw new AiRouteFailureError(
|
||||||
|
candidate.channel,
|
||||||
|
candidate.providerName,
|
||||||
|
"EMPTY_RESPONSE",
|
||||||
|
"AstrBot 没有返回有效内容"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
channel: candidate.channel,
|
||||||
|
providerName: candidate.providerName,
|
||||||
|
model: candidate.model,
|
||||||
|
content,
|
||||||
|
sessionId,
|
||||||
|
raw: events
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
private buildRequestUrl(endpoint: string): string {
|
||||||
|
const normalizedEndpoint = endpoint.replace(/\/+$/, "");
|
||||||
|
if (normalizedEndpoint.endsWith("/api/v1/chat")) {
|
||||||
|
return normalizedEndpoint;
|
||||||
|
}
|
||||||
|
if (normalizedEndpoint.endsWith("/api/v1")) {
|
||||||
|
return `${normalizedEndpoint}/chat`;
|
||||||
|
}
|
||||||
|
if (normalizedEndpoint.endsWith("/api")) {
|
||||||
|
return `${normalizedEndpoint}/v1/chat`;
|
||||||
|
}
|
||||||
|
return `${normalizedEndpoint}/api/v1/chat`;
|
||||||
|
}
|
||||||
|
|
||||||
|
private parseSseEvents(rawText: string): Array<Record<string, unknown>> {
|
||||||
|
return rawText
|
||||||
|
.split(/\r?\n\r?\n/)
|
||||||
|
.map((block) =>
|
||||||
|
block
|
||||||
|
.split(/\r?\n/)
|
||||||
|
.filter((line) => line.startsWith("data:"))
|
||||||
|
.map((line) => line.slice(5).trim())
|
||||||
|
.join("\n")
|
||||||
|
)
|
||||||
|
.filter((payload) => payload.length > 0)
|
||||||
|
.map((payload) => {
|
||||||
|
try {
|
||||||
|
return JSON.parse(payload) as Record<string, unknown>;
|
||||||
|
} catch {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.filter((item): item is Record<string, unknown> => item !== null);
|
||||||
|
}
|
||||||
|
|
||||||
|
private extractHttpErrorMessage(rawText: string, statusCode: number): string {
|
||||||
|
try {
|
||||||
|
const payload = JSON.parse(rawText) as Record<string, unknown>;
|
||||||
|
if (typeof payload["message"] === "string") {
|
||||||
|
return payload["message"];
|
||||||
|
}
|
||||||
|
if (typeof payload["data"] === "string") {
|
||||||
|
return payload["data"];
|
||||||
|
}
|
||||||
|
} catch {
|
||||||
|
return `AstrBot 服务调用失败,状态码 ${statusCode}`;
|
||||||
|
}
|
||||||
|
|
||||||
|
return `AstrBot 服务调用失败,状态码 ${statusCode}`;
|
||||||
|
}
|
||||||
|
|
||||||
|
private readString(value: unknown): string | null {
|
||||||
|
return typeof value === "string" ? value : null;
|
||||||
|
}
|
||||||
|
|
||||||
|
private toErrorMessage(error: unknown, fallback: string): string {
|
||||||
|
if (error instanceof Error && error.message) {
|
||||||
|
return error.message;
|
||||||
|
}
|
||||||
|
|
||||||
|
return fallback;
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,203 @@
|
|||||||
|
import { Injectable } from "@nestjs/common";
|
||||||
|
import {
|
||||||
|
AiChannelExecutor,
|
||||||
|
AiChatInput,
|
||||||
|
AiChatResult,
|
||||||
|
AiResolvedRouteCandidate,
|
||||||
|
AiRouteFailureError
|
||||||
|
} from "../ai.types";
|
||||||
|
|
||||||
|
@Injectable()
|
||||||
|
export class OpenAiCompatibleProvider implements AiChannelExecutor {
|
||||||
|
async execute(candidate: AiResolvedRouteCandidate, input: AiChatInput): Promise<AiChatResult> {
|
||||||
|
if (!candidate.endpoint) {
|
||||||
|
throw new AiRouteFailureError(
|
||||||
|
candidate.channel,
|
||||||
|
candidate.providerName,
|
||||||
|
"MISSING_ENDPOINT",
|
||||||
|
"缺少 AI 服务地址配置"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!candidate.apiKey) {
|
||||||
|
throw new AiRouteFailureError(
|
||||||
|
candidate.channel,
|
||||||
|
candidate.providerName,
|
||||||
|
"MISSING_API_KEY",
|
||||||
|
"缺少 AI 服务密钥配置"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!candidate.model) {
|
||||||
|
throw new AiRouteFailureError(
|
||||||
|
candidate.channel,
|
||||||
|
candidate.providerName,
|
||||||
|
"MISSING_MODEL",
|
||||||
|
"缺少 AI 模型配置"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
const requestUrl = this.buildRequestUrl(candidate.endpoint);
|
||||||
|
|
||||||
|
let response: Response;
|
||||||
|
try {
|
||||||
|
response = await fetch(requestUrl, {
|
||||||
|
method: "POST",
|
||||||
|
headers: {
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
Authorization: `Bearer ${candidate.apiKey}`
|
||||||
|
},
|
||||||
|
body: JSON.stringify({
|
||||||
|
model: candidate.model,
|
||||||
|
messages: [
|
||||||
|
{
|
||||||
|
role: "user",
|
||||||
|
content: input.message
|
||||||
|
}
|
||||||
|
],
|
||||||
|
stream: false
|
||||||
|
}),
|
||||||
|
signal: AbortSignal.timeout(30000)
|
||||||
|
});
|
||||||
|
} catch (error) {
|
||||||
|
throw new AiRouteFailureError(
|
||||||
|
candidate.channel,
|
||||||
|
candidate.providerName,
|
||||||
|
"UPSTREAM_UNREACHABLE",
|
||||||
|
this.toErrorMessage(error, "AI 服务请求失败")
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
let payload: unknown;
|
||||||
|
try {
|
||||||
|
payload = await response.json();
|
||||||
|
} catch (error) {
|
||||||
|
throw new AiRouteFailureError(
|
||||||
|
candidate.channel,
|
||||||
|
candidate.providerName,
|
||||||
|
"INVALID_RESPONSE",
|
||||||
|
this.toErrorMessage(error, "AI 服务返回了无法解析的数据")
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!response.ok) {
|
||||||
|
throw new AiRouteFailureError(
|
||||||
|
candidate.channel,
|
||||||
|
candidate.providerName,
|
||||||
|
`UPSTREAM_HTTP_${response.status}`,
|
||||||
|
this.extractErrorMessage(payload, `AI 服务调用失败,状态码 ${response.status}`)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
const content = this.extractAssistantText(payload);
|
||||||
|
if (!content.trim()) {
|
||||||
|
throw new AiRouteFailureError(
|
||||||
|
candidate.channel,
|
||||||
|
candidate.providerName,
|
||||||
|
"EMPTY_RESPONSE",
|
||||||
|
"AI 服务没有返回有效内容"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
channel: candidate.channel,
|
||||||
|
providerName: candidate.providerName,
|
||||||
|
model: this.extractModel(payload) ?? candidate.model,
|
||||||
|
content,
|
||||||
|
sessionId: input.sessionId,
|
||||||
|
raw: payload
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
private buildRequestUrl(endpoint: string): string {
|
||||||
|
const normalizedEndpoint = endpoint.replace(/\/+$/, "");
|
||||||
|
if (normalizedEndpoint.endsWith("/chat/completions")) {
|
||||||
|
return normalizedEndpoint;
|
||||||
|
}
|
||||||
|
if (normalizedEndpoint.endsWith("/v1")) {
|
||||||
|
return `${normalizedEndpoint}/chat/completions`;
|
||||||
|
}
|
||||||
|
return `${normalizedEndpoint}/v1/chat/completions`;
|
||||||
|
}
|
||||||
|
|
||||||
|
private extractAssistantText(payload: unknown): string {
|
||||||
|
if (!this.isRecord(payload)) {
|
||||||
|
return "";
|
||||||
|
}
|
||||||
|
|
||||||
|
const choices = payload["choices"];
|
||||||
|
if (!Array.isArray(choices) || choices.length === 0) {
|
||||||
|
return "";
|
||||||
|
}
|
||||||
|
|
||||||
|
const firstChoice = choices[0];
|
||||||
|
if (!this.isRecord(firstChoice)) {
|
||||||
|
return "";
|
||||||
|
}
|
||||||
|
|
||||||
|
const message = firstChoice["message"];
|
||||||
|
if (!this.isRecord(message)) {
|
||||||
|
return "";
|
||||||
|
}
|
||||||
|
|
||||||
|
return this.extractMessageContent(message["content"]);
|
||||||
|
}
|
||||||
|
|
||||||
|
private extractMessageContent(content: unknown): string {
|
||||||
|
if (typeof content === "string") {
|
||||||
|
return content;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!Array.isArray(content)) {
|
||||||
|
return "";
|
||||||
|
}
|
||||||
|
|
||||||
|
return content
|
||||||
|
.map((item) => {
|
||||||
|
if (!this.isRecord(item)) {
|
||||||
|
return "";
|
||||||
|
}
|
||||||
|
|
||||||
|
if (typeof item["text"] === "string") {
|
||||||
|
return item["text"];
|
||||||
|
}
|
||||||
|
|
||||||
|
return "";
|
||||||
|
})
|
||||||
|
.filter((item) => item.length > 0)
|
||||||
|
.join("\n");
|
||||||
|
}
|
||||||
|
|
||||||
|
private extractModel(payload: unknown): string | null {
|
||||||
|
if (!this.isRecord(payload) || typeof payload["model"] !== "string") {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
return payload["model"];
|
||||||
|
}
|
||||||
|
|
||||||
|
private extractErrorMessage(payload: unknown, fallback: string): string {
|
||||||
|
if (!this.isRecord(payload)) {
|
||||||
|
return fallback;
|
||||||
|
}
|
||||||
|
|
||||||
|
const error = payload["error"];
|
||||||
|
if (!this.isRecord(error) || typeof error["message"] !== "string") {
|
||||||
|
return fallback;
|
||||||
|
}
|
||||||
|
|
||||||
|
return error["message"];
|
||||||
|
}
|
||||||
|
|
||||||
|
private isRecord(value: unknown): value is Record<string, unknown> {
|
||||||
|
return typeof value === "object" && value !== null;
|
||||||
|
}
|
||||||
|
|
||||||
|
private toErrorMessage(error: unknown, fallback: string): string {
|
||||||
|
if (error instanceof Error && error.message) {
|
||||||
|
return error.message;
|
||||||
|
}
|
||||||
|
|
||||||
|
return fallback;
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,5 +1,6 @@
|
|||||||
import { Module } from "@nestjs/common";
|
import { Module } from "@nestjs/common";
|
||||||
import { ConfigModule } from "@nestjs/config";
|
import { ConfigModule } from "@nestjs/config";
|
||||||
|
import { AiModule } from "./ai/ai.module";
|
||||||
import { AttachmentModule } from "./attachment/attachment.module";
|
import { AttachmentModule } from "./attachment/attachment.module";
|
||||||
import { AuthModule } from "./auth/auth.module";
|
import { AuthModule } from "./auth/auth.module";
|
||||||
import { PrismaModule } from "./prisma/prisma.module";
|
import { PrismaModule } from "./prisma/prisma.module";
|
||||||
@@ -16,7 +17,8 @@ import { TaskModule } from "./task/task.module";
|
|||||||
AuthModule,
|
AuthModule,
|
||||||
TaskModule,
|
TaskModule,
|
||||||
AttachmentModule,
|
AttachmentModule,
|
||||||
SyncModule
|
SyncModule,
|
||||||
|
AiModule
|
||||||
]
|
]
|
||||||
})
|
})
|
||||||
export class AppModule {}
|
export class AppModule {}
|
||||||
|
|||||||
@@ -0,0 +1,395 @@
|
|||||||
|
import request from "supertest";
|
||||||
|
import { INestApplication, ValidationPipe } from "@nestjs/common";
|
||||||
|
import { Test, TestingModule } from "@nestjs/testing";
|
||||||
|
import { AiChannel, AiProviderBinding, AiPublicPoolConfig } from "../generated/prisma/client";
|
||||||
|
import { AiController } from "../src/ai/ai.controller";
|
||||||
|
import { AiProviderRegistryService } from "../src/ai/ai-provider-registry.service";
|
||||||
|
import { AiService } from "../src/ai/ai.service";
|
||||||
|
import { AiChannelExecutor, AiRouteFailureError } from "../src/ai/ai.types";
|
||||||
|
import { PrismaService } from "../src/prisma/prisma.service";
|
||||||
|
|
||||||
|
class InMemoryAiPrismaService {
|
||||||
|
private bindingIdSequence = 1;
|
||||||
|
private publicPoolIdSequence = 1;
|
||||||
|
private bindings: AiProviderBinding[] = [];
|
||||||
|
private publicPools: AiPublicPoolConfig[] = [];
|
||||||
|
|
||||||
|
readonly aiProviderBinding = {
|
||||||
|
findMany: async (args: {
|
||||||
|
where: {
|
||||||
|
userId: string;
|
||||||
|
};
|
||||||
|
}) => {
|
||||||
|
return this.bindings
|
||||||
|
.filter((binding) => binding.userId === args.where.userId)
|
||||||
|
.sort((left, right) => right.updatedAt.getTime() - left.updatedAt.getTime());
|
||||||
|
},
|
||||||
|
|
||||||
|
findFirst: async (args: {
|
||||||
|
where: {
|
||||||
|
id?: string;
|
||||||
|
userId?: string;
|
||||||
|
channel?: AiChannel;
|
||||||
|
isEnabled?: boolean;
|
||||||
|
};
|
||||||
|
}) => {
|
||||||
|
return (
|
||||||
|
this.bindings
|
||||||
|
.filter((binding) => {
|
||||||
|
if (args.where.id !== undefined && binding.id !== args.where.id) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (args.where.userId !== undefined && binding.userId !== args.where.userId) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (args.where.channel !== undefined && binding.channel !== args.where.channel) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (args.where.isEnabled !== undefined && binding.isEnabled !== args.where.isEnabled) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
})
|
||||||
|
.sort((left, right) => {
|
||||||
|
if (left.isDefault !== right.isDefault) {
|
||||||
|
return Number(right.isDefault) - Number(left.isDefault);
|
||||||
|
}
|
||||||
|
return right.updatedAt.getTime() - left.updatedAt.getTime();
|
||||||
|
})[0] ?? null
|
||||||
|
);
|
||||||
|
},
|
||||||
|
|
||||||
|
create: async (args: {
|
||||||
|
data: {
|
||||||
|
userId: string;
|
||||||
|
channel: AiChannel;
|
||||||
|
providerName: string;
|
||||||
|
model: string | null;
|
||||||
|
endpoint: string | null;
|
||||||
|
encryptedApiKey: string | null;
|
||||||
|
isDefault: boolean;
|
||||||
|
isEnabled: boolean;
|
||||||
|
};
|
||||||
|
}) => {
|
||||||
|
const now = new Date();
|
||||||
|
const binding: AiProviderBinding = {
|
||||||
|
id: `binding_${this.bindingIdSequence++}`,
|
||||||
|
userId: args.data.userId,
|
||||||
|
channel: args.data.channel,
|
||||||
|
providerName: args.data.providerName,
|
||||||
|
model: args.data.model,
|
||||||
|
encryptedApiKey: args.data.encryptedApiKey,
|
||||||
|
endpoint: args.data.endpoint,
|
||||||
|
isDefault: args.data.isDefault,
|
||||||
|
isEnabled: args.data.isEnabled,
|
||||||
|
createdAt: now,
|
||||||
|
updatedAt: now
|
||||||
|
};
|
||||||
|
|
||||||
|
this.bindings.push(binding);
|
||||||
|
return binding;
|
||||||
|
},
|
||||||
|
|
||||||
|
update: async (args: {
|
||||||
|
where: {
|
||||||
|
id: string;
|
||||||
|
};
|
||||||
|
data: Partial<AiProviderBinding>;
|
||||||
|
}) => {
|
||||||
|
const binding = this.bindings.find((item) => item.id === args.where.id);
|
||||||
|
if (!binding) {
|
||||||
|
throw new Error("binding not found");
|
||||||
|
}
|
||||||
|
|
||||||
|
Object.assign(binding, args.data, { updatedAt: new Date() });
|
||||||
|
return binding;
|
||||||
|
},
|
||||||
|
|
||||||
|
updateMany: async (args: {
|
||||||
|
where: {
|
||||||
|
userId?: string;
|
||||||
|
channel?: AiChannel;
|
||||||
|
id?: {
|
||||||
|
not: string;
|
||||||
|
};
|
||||||
|
};
|
||||||
|
data: {
|
||||||
|
isDefault?: boolean;
|
||||||
|
};
|
||||||
|
}) => {
|
||||||
|
let count = 0;
|
||||||
|
for (const binding of this.bindings) {
|
||||||
|
if (args.where.userId !== undefined && binding.userId !== args.where.userId) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (args.where.channel !== undefined && binding.channel !== args.where.channel) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (args.where.id?.not !== undefined && binding.id === args.where.id.not) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (args.data.isDefault !== undefined) {
|
||||||
|
binding.isDefault = args.data.isDefault;
|
||||||
|
binding.updatedAt = new Date();
|
||||||
|
}
|
||||||
|
count += 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
return { count };
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
readonly aiPublicPoolConfig = {
|
||||||
|
findFirst: async (args?: {
|
||||||
|
where?: {
|
||||||
|
enabled?: boolean;
|
||||||
|
};
|
||||||
|
}) => {
|
||||||
|
const items = this.publicPools
|
||||||
|
.filter((item) =>
|
||||||
|
args?.where?.enabled === undefined ? true : item.enabled === args.where.enabled
|
||||||
|
)
|
||||||
|
.sort((left, right) => right.updatedAt.getTime() - left.updatedAt.getTime());
|
||||||
|
|
||||||
|
return items[0] ?? null;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
async $transaction<T>(callback: (tx: InMemoryAiPrismaService) => Promise<T>): Promise<T> {
|
||||||
|
return callback(this);
|
||||||
|
}
|
||||||
|
|
||||||
|
seedBinding(binding: Omit<AiProviderBinding, "createdAt" | "updatedAt">): void {
|
||||||
|
const now = new Date();
|
||||||
|
this.bindings.push({
|
||||||
|
...binding,
|
||||||
|
createdAt: now,
|
||||||
|
updatedAt: now
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
seedPublicPool(publicPool: Omit<AiPublicPoolConfig, "id" | "createdAt" | "updatedAt">): void {
|
||||||
|
const now = new Date();
|
||||||
|
this.publicPools.push({
|
||||||
|
id: `pool_${this.publicPoolIdSequence++}`,
|
||||||
|
createdAt: now,
|
||||||
|
updatedAt: now,
|
||||||
|
...publicPool
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
class StaticExecutor implements AiChannelExecutor {
|
||||||
|
constructor(
|
||||||
|
private readonly resolver: (channel: AiChannel) => {
|
||||||
|
content?: string;
|
||||||
|
code?: string;
|
||||||
|
message?: string;
|
||||||
|
}
|
||||||
|
) {}
|
||||||
|
|
||||||
|
async execute(candidate: { channel: AiChannel; providerName: string; model: string | null }) {
|
||||||
|
const result = this.resolver(candidate.channel);
|
||||||
|
if (result.code) {
|
||||||
|
throw new AiRouteFailureError(
|
||||||
|
candidate.channel,
|
||||||
|
candidate.providerName,
|
||||||
|
result.code,
|
||||||
|
result.message ?? "执行失败"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
channel: candidate.channel,
|
||||||
|
providerName: candidate.providerName,
|
||||||
|
model: candidate.model,
|
||||||
|
content: result.content ?? "",
|
||||||
|
sessionId: "session_ai",
|
||||||
|
raw: null
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
describe("AiController (integration)", () => {
|
||||||
|
let app: INestApplication;
|
||||||
|
let prismaService: InMemoryAiPrismaService;
|
||||||
|
|
||||||
|
beforeEach(async () => {
|
||||||
|
prismaService = new InMemoryAiPrismaService();
|
||||||
|
|
||||||
|
const openAiExecutor = new StaticExecutor((channel) =>
|
||||||
|
channel === AiChannel.USER_KEY
|
||||||
|
? {
|
||||||
|
code: "UPSTREAM_UNREACHABLE",
|
||||||
|
message: "用户自备 Key 渠道暂时不可用"
|
||||||
|
}
|
||||||
|
: {
|
||||||
|
content: "公共 AI 已接管"
|
||||||
|
}
|
||||||
|
);
|
||||||
|
const astrbotExecutor = new StaticExecutor(() => ({
|
||||||
|
content: "AstrBot 已接管"
|
||||||
|
}));
|
||||||
|
|
||||||
|
const moduleRef: TestingModule = await Test.createTestingModule({
|
||||||
|
controllers: [AiController],
|
||||||
|
providers: [
|
||||||
|
AiService,
|
||||||
|
{
|
||||||
|
provide: PrismaService,
|
||||||
|
useValue: prismaService
|
||||||
|
},
|
||||||
|
{
|
||||||
|
provide: AiProviderRegistryService,
|
||||||
|
useValue: {
|
||||||
|
getExecutor: (channel: AiChannel) =>
|
||||||
|
channel === AiChannel.ASTRBOT ? astrbotExecutor : openAiExecutor
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}).compile();
|
||||||
|
|
||||||
|
app = moduleRef.createNestApplication();
|
||||||
|
app.useGlobalPipes(
|
||||||
|
new ValidationPipe({
|
||||||
|
transform: true,
|
||||||
|
whitelist: true,
|
||||||
|
forbidNonWhitelisted: true
|
||||||
|
})
|
||||||
|
);
|
||||||
|
await app.init();
|
||||||
|
});
|
||||||
|
|
||||||
|
afterEach(async () => {
|
||||||
|
await app.close();
|
||||||
|
});
|
||||||
|
|
||||||
|
it("should create and list ai bindings", async () => {
|
||||||
|
await request(app.getHttpServer())
|
||||||
|
.post("/ai/bindings")
|
||||||
|
.set("x-user-id", "user_1")
|
||||||
|
.send({
|
||||||
|
channel: AiChannel.ASTRBOT,
|
||||||
|
providerName: "astrbot-main",
|
||||||
|
model: "deepseek-chat",
|
||||||
|
endpoint: "http://127.0.0.1:6185",
|
||||||
|
apiKey: "abk_secret_1234",
|
||||||
|
isDefault: true,
|
||||||
|
isEnabled: true
|
||||||
|
})
|
||||||
|
.expect(201);
|
||||||
|
|
||||||
|
const response = await request(app.getHttpServer())
|
||||||
|
.get("/ai/bindings")
|
||||||
|
.set("x-user-id", "user_1")
|
||||||
|
.expect(200);
|
||||||
|
|
||||||
|
expect(response.body.routeOrder).toEqual([
|
||||||
|
AiChannel.USER_KEY,
|
||||||
|
AiChannel.ASTRBOT,
|
||||||
|
AiChannel.PUBLIC_POOL
|
||||||
|
]);
|
||||||
|
expect(response.body.bindings).toHaveLength(1);
|
||||||
|
expect(response.body.bindings[0]).toMatchObject({
|
||||||
|
channel: AiChannel.ASTRBOT,
|
||||||
|
providerName: "astrbot-main",
|
||||||
|
model: "deepseek-chat",
|
||||||
|
hasApiKey: true,
|
||||||
|
maskedApiKey: "abk_***34",
|
||||||
|
isDefault: true
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
it("should fallback from user key to astrbot", async () => {
|
||||||
|
prismaService.seedBinding({
|
||||||
|
id: "binding_user_key",
|
||||||
|
userId: "user_1",
|
||||||
|
channel: AiChannel.USER_KEY,
|
||||||
|
providerName: "openai",
|
||||||
|
model: "gpt-4o-mini",
|
||||||
|
encryptedApiKey: "sk-user",
|
||||||
|
endpoint: "https://api.example.com",
|
||||||
|
isDefault: true,
|
||||||
|
isEnabled: true
|
||||||
|
});
|
||||||
|
prismaService.seedBinding({
|
||||||
|
id: "binding_astrbot",
|
||||||
|
userId: "user_1",
|
||||||
|
channel: AiChannel.ASTRBOT,
|
||||||
|
providerName: "astrbot-main",
|
||||||
|
model: "deepseek-chat",
|
||||||
|
encryptedApiKey: "abk_astrbot",
|
||||||
|
endpoint: "http://127.0.0.1:6185",
|
||||||
|
isDefault: true,
|
||||||
|
isEnabled: true
|
||||||
|
});
|
||||||
|
|
||||||
|
const response = await request(app.getHttpServer())
|
||||||
|
.post("/ai/chat")
|
||||||
|
.set("x-user-id", "user_1")
|
||||||
|
.send({
|
||||||
|
message: "帮我安排今天的任务"
|
||||||
|
})
|
||||||
|
.expect(201);
|
||||||
|
|
||||||
|
expect(response.body.channel).toBe(AiChannel.ASTRBOT);
|
||||||
|
expect(response.body.content).toBe("AstrBot 已接管");
|
||||||
|
expect(response.body.attempts).toEqual([
|
||||||
|
{
|
||||||
|
channel: AiChannel.USER_KEY,
|
||||||
|
providerName: "openai",
|
||||||
|
model: "gpt-4o-mini",
|
||||||
|
status: "failed",
|
||||||
|
reasonCode: "UPSTREAM_UNREACHABLE",
|
||||||
|
reasonMessage: "用户自备 Key 渠道暂时不可用"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
channel: AiChannel.ASTRBOT,
|
||||||
|
providerName: "astrbot-main",
|
||||||
|
model: "deepseek-chat",
|
||||||
|
status: "success",
|
||||||
|
reasonCode: null,
|
||||||
|
reasonMessage: null
|
||||||
|
}
|
||||||
|
]);
|
||||||
|
});
|
||||||
|
|
||||||
|
it("should return skipped attempts when no channel is available", async () => {
|
||||||
|
const response = await request(app.getHttpServer())
|
||||||
|
.post("/ai/chat")
|
||||||
|
.set("x-user-id", "user_1")
|
||||||
|
.send({
|
||||||
|
message: "帮我总结今天的安排"
|
||||||
|
})
|
||||||
|
.expect(502);
|
||||||
|
|
||||||
|
expect(response.body.message).toBe("当前没有可用的 AI 通道,请稍后重试");
|
||||||
|
expect(response.body.attempts).toEqual([
|
||||||
|
{
|
||||||
|
channel: AiChannel.USER_KEY,
|
||||||
|
providerName: null,
|
||||||
|
model: null,
|
||||||
|
status: "skipped",
|
||||||
|
reasonCode: "CHANNEL_NOT_CONFIGURED",
|
||||||
|
reasonMessage: "当前用户未配置可用的自备 Key 通道"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
channel: AiChannel.ASTRBOT,
|
||||||
|
providerName: null,
|
||||||
|
model: null,
|
||||||
|
status: "skipped",
|
||||||
|
reasonCode: "CHANNEL_NOT_CONFIGURED",
|
||||||
|
reasonMessage: "当前用户未配置可用的 AstrBot 通道"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
channel: AiChannel.PUBLIC_POOL,
|
||||||
|
providerName: null,
|
||||||
|
model: null,
|
||||||
|
status: "skipped",
|
||||||
|
reasonCode: "PUBLIC_POOL_DISABLED",
|
||||||
|
reasonMessage: "公共 AI 通道未开启"
|
||||||
|
}
|
||||||
|
]);
|
||||||
|
});
|
||||||
|
});
|
||||||
Reference in New Issue
Block a user