feat(api-ai): add provider registry and routing fallback
This commit is contained in:
@@ -0,0 +1,395 @@
|
||||
import request from "supertest";
|
||||
import { INestApplication, ValidationPipe } from "@nestjs/common";
|
||||
import { Test, TestingModule } from "@nestjs/testing";
|
||||
import { AiChannel, AiProviderBinding, AiPublicPoolConfig } from "../generated/prisma/client";
|
||||
import { AiController } from "../src/ai/ai.controller";
|
||||
import { AiProviderRegistryService } from "../src/ai/ai-provider-registry.service";
|
||||
import { AiService } from "../src/ai/ai.service";
|
||||
import { AiChannelExecutor, AiRouteFailureError } from "../src/ai/ai.types";
|
||||
import { PrismaService } from "../src/prisma/prisma.service";
|
||||
|
||||
class InMemoryAiPrismaService {
|
||||
private bindingIdSequence = 1;
|
||||
private publicPoolIdSequence = 1;
|
||||
private bindings: AiProviderBinding[] = [];
|
||||
private publicPools: AiPublicPoolConfig[] = [];
|
||||
|
||||
readonly aiProviderBinding = {
|
||||
findMany: async (args: {
|
||||
where: {
|
||||
userId: string;
|
||||
};
|
||||
}) => {
|
||||
return this.bindings
|
||||
.filter((binding) => binding.userId === args.where.userId)
|
||||
.sort((left, right) => right.updatedAt.getTime() - left.updatedAt.getTime());
|
||||
},
|
||||
|
||||
findFirst: async (args: {
|
||||
where: {
|
||||
id?: string;
|
||||
userId?: string;
|
||||
channel?: AiChannel;
|
||||
isEnabled?: boolean;
|
||||
};
|
||||
}) => {
|
||||
return (
|
||||
this.bindings
|
||||
.filter((binding) => {
|
||||
if (args.where.id !== undefined && binding.id !== args.where.id) {
|
||||
return false;
|
||||
}
|
||||
if (args.where.userId !== undefined && binding.userId !== args.where.userId) {
|
||||
return false;
|
||||
}
|
||||
if (args.where.channel !== undefined && binding.channel !== args.where.channel) {
|
||||
return false;
|
||||
}
|
||||
if (args.where.isEnabled !== undefined && binding.isEnabled !== args.where.isEnabled) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
})
|
||||
.sort((left, right) => {
|
||||
if (left.isDefault !== right.isDefault) {
|
||||
return Number(right.isDefault) - Number(left.isDefault);
|
||||
}
|
||||
return right.updatedAt.getTime() - left.updatedAt.getTime();
|
||||
})[0] ?? null
|
||||
);
|
||||
},
|
||||
|
||||
create: async (args: {
|
||||
data: {
|
||||
userId: string;
|
||||
channel: AiChannel;
|
||||
providerName: string;
|
||||
model: string | null;
|
||||
endpoint: string | null;
|
||||
encryptedApiKey: string | null;
|
||||
isDefault: boolean;
|
||||
isEnabled: boolean;
|
||||
};
|
||||
}) => {
|
||||
const now = new Date();
|
||||
const binding: AiProviderBinding = {
|
||||
id: `binding_${this.bindingIdSequence++}`,
|
||||
userId: args.data.userId,
|
||||
channel: args.data.channel,
|
||||
providerName: args.data.providerName,
|
||||
model: args.data.model,
|
||||
encryptedApiKey: args.data.encryptedApiKey,
|
||||
endpoint: args.data.endpoint,
|
||||
isDefault: args.data.isDefault,
|
||||
isEnabled: args.data.isEnabled,
|
||||
createdAt: now,
|
||||
updatedAt: now
|
||||
};
|
||||
|
||||
this.bindings.push(binding);
|
||||
return binding;
|
||||
},
|
||||
|
||||
update: async (args: {
|
||||
where: {
|
||||
id: string;
|
||||
};
|
||||
data: Partial<AiProviderBinding>;
|
||||
}) => {
|
||||
const binding = this.bindings.find((item) => item.id === args.where.id);
|
||||
if (!binding) {
|
||||
throw new Error("binding not found");
|
||||
}
|
||||
|
||||
Object.assign(binding, args.data, { updatedAt: new Date() });
|
||||
return binding;
|
||||
},
|
||||
|
||||
updateMany: async (args: {
|
||||
where: {
|
||||
userId?: string;
|
||||
channel?: AiChannel;
|
||||
id?: {
|
||||
not: string;
|
||||
};
|
||||
};
|
||||
data: {
|
||||
isDefault?: boolean;
|
||||
};
|
||||
}) => {
|
||||
let count = 0;
|
||||
for (const binding of this.bindings) {
|
||||
if (args.where.userId !== undefined && binding.userId !== args.where.userId) {
|
||||
continue;
|
||||
}
|
||||
if (args.where.channel !== undefined && binding.channel !== args.where.channel) {
|
||||
continue;
|
||||
}
|
||||
if (args.where.id?.not !== undefined && binding.id === args.where.id.not) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (args.data.isDefault !== undefined) {
|
||||
binding.isDefault = args.data.isDefault;
|
||||
binding.updatedAt = new Date();
|
||||
}
|
||||
count += 1;
|
||||
}
|
||||
|
||||
return { count };
|
||||
}
|
||||
};
|
||||
|
||||
readonly aiPublicPoolConfig = {
|
||||
findFirst: async (args?: {
|
||||
where?: {
|
||||
enabled?: boolean;
|
||||
};
|
||||
}) => {
|
||||
const items = this.publicPools
|
||||
.filter((item) =>
|
||||
args?.where?.enabled === undefined ? true : item.enabled === args.where.enabled
|
||||
)
|
||||
.sort((left, right) => right.updatedAt.getTime() - left.updatedAt.getTime());
|
||||
|
||||
return items[0] ?? null;
|
||||
}
|
||||
};
|
||||
|
||||
async $transaction<T>(callback: (tx: InMemoryAiPrismaService) => Promise<T>): Promise<T> {
|
||||
return callback(this);
|
||||
}
|
||||
|
||||
seedBinding(binding: Omit<AiProviderBinding, "createdAt" | "updatedAt">): void {
|
||||
const now = new Date();
|
||||
this.bindings.push({
|
||||
...binding,
|
||||
createdAt: now,
|
||||
updatedAt: now
|
||||
});
|
||||
}
|
||||
|
||||
seedPublicPool(publicPool: Omit<AiPublicPoolConfig, "id" | "createdAt" | "updatedAt">): void {
|
||||
const now = new Date();
|
||||
this.publicPools.push({
|
||||
id: `pool_${this.publicPoolIdSequence++}`,
|
||||
createdAt: now,
|
||||
updatedAt: now,
|
||||
...publicPool
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
class StaticExecutor implements AiChannelExecutor {
|
||||
constructor(
|
||||
private readonly resolver: (channel: AiChannel) => {
|
||||
content?: string;
|
||||
code?: string;
|
||||
message?: string;
|
||||
}
|
||||
) {}
|
||||
|
||||
async execute(candidate: { channel: AiChannel; providerName: string; model: string | null }) {
|
||||
const result = this.resolver(candidate.channel);
|
||||
if (result.code) {
|
||||
throw new AiRouteFailureError(
|
||||
candidate.channel,
|
||||
candidate.providerName,
|
||||
result.code,
|
||||
result.message ?? "执行失败"
|
||||
);
|
||||
}
|
||||
|
||||
return {
|
||||
channel: candidate.channel,
|
||||
providerName: candidate.providerName,
|
||||
model: candidate.model,
|
||||
content: result.content ?? "",
|
||||
sessionId: "session_ai",
|
||||
raw: null
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
describe("AiController (integration)", () => {
|
||||
let app: INestApplication;
|
||||
let prismaService: InMemoryAiPrismaService;
|
||||
|
||||
beforeEach(async () => {
|
||||
prismaService = new InMemoryAiPrismaService();
|
||||
|
||||
const openAiExecutor = new StaticExecutor((channel) =>
|
||||
channel === AiChannel.USER_KEY
|
||||
? {
|
||||
code: "UPSTREAM_UNREACHABLE",
|
||||
message: "用户自备 Key 渠道暂时不可用"
|
||||
}
|
||||
: {
|
||||
content: "公共 AI 已接管"
|
||||
}
|
||||
);
|
||||
const astrbotExecutor = new StaticExecutor(() => ({
|
||||
content: "AstrBot 已接管"
|
||||
}));
|
||||
|
||||
const moduleRef: TestingModule = await Test.createTestingModule({
|
||||
controllers: [AiController],
|
||||
providers: [
|
||||
AiService,
|
||||
{
|
||||
provide: PrismaService,
|
||||
useValue: prismaService
|
||||
},
|
||||
{
|
||||
provide: AiProviderRegistryService,
|
||||
useValue: {
|
||||
getExecutor: (channel: AiChannel) =>
|
||||
channel === AiChannel.ASTRBOT ? astrbotExecutor : openAiExecutor
|
||||
}
|
||||
}
|
||||
]
|
||||
}).compile();
|
||||
|
||||
app = moduleRef.createNestApplication();
|
||||
app.useGlobalPipes(
|
||||
new ValidationPipe({
|
||||
transform: true,
|
||||
whitelist: true,
|
||||
forbidNonWhitelisted: true
|
||||
})
|
||||
);
|
||||
await app.init();
|
||||
});
|
||||
|
||||
afterEach(async () => {
|
||||
await app.close();
|
||||
});
|
||||
|
||||
it("should create and list ai bindings", async () => {
|
||||
await request(app.getHttpServer())
|
||||
.post("/ai/bindings")
|
||||
.set("x-user-id", "user_1")
|
||||
.send({
|
||||
channel: AiChannel.ASTRBOT,
|
||||
providerName: "astrbot-main",
|
||||
model: "deepseek-chat",
|
||||
endpoint: "http://127.0.0.1:6185",
|
||||
apiKey: "abk_secret_1234",
|
||||
isDefault: true,
|
||||
isEnabled: true
|
||||
})
|
||||
.expect(201);
|
||||
|
||||
const response = await request(app.getHttpServer())
|
||||
.get("/ai/bindings")
|
||||
.set("x-user-id", "user_1")
|
||||
.expect(200);
|
||||
|
||||
expect(response.body.routeOrder).toEqual([
|
||||
AiChannel.USER_KEY,
|
||||
AiChannel.ASTRBOT,
|
||||
AiChannel.PUBLIC_POOL
|
||||
]);
|
||||
expect(response.body.bindings).toHaveLength(1);
|
||||
expect(response.body.bindings[0]).toMatchObject({
|
||||
channel: AiChannel.ASTRBOT,
|
||||
providerName: "astrbot-main",
|
||||
model: "deepseek-chat",
|
||||
hasApiKey: true,
|
||||
maskedApiKey: "abk_***34",
|
||||
isDefault: true
|
||||
});
|
||||
});
|
||||
|
||||
it("should fallback from user key to astrbot", async () => {
|
||||
prismaService.seedBinding({
|
||||
id: "binding_user_key",
|
||||
userId: "user_1",
|
||||
channel: AiChannel.USER_KEY,
|
||||
providerName: "openai",
|
||||
model: "gpt-4o-mini",
|
||||
encryptedApiKey: "sk-user",
|
||||
endpoint: "https://api.example.com",
|
||||
isDefault: true,
|
||||
isEnabled: true
|
||||
});
|
||||
prismaService.seedBinding({
|
||||
id: "binding_astrbot",
|
||||
userId: "user_1",
|
||||
channel: AiChannel.ASTRBOT,
|
||||
providerName: "astrbot-main",
|
||||
model: "deepseek-chat",
|
||||
encryptedApiKey: "abk_astrbot",
|
||||
endpoint: "http://127.0.0.1:6185",
|
||||
isDefault: true,
|
||||
isEnabled: true
|
||||
});
|
||||
|
||||
const response = await request(app.getHttpServer())
|
||||
.post("/ai/chat")
|
||||
.set("x-user-id", "user_1")
|
||||
.send({
|
||||
message: "帮我安排今天的任务"
|
||||
})
|
||||
.expect(201);
|
||||
|
||||
expect(response.body.channel).toBe(AiChannel.ASTRBOT);
|
||||
expect(response.body.content).toBe("AstrBot 已接管");
|
||||
expect(response.body.attempts).toEqual([
|
||||
{
|
||||
channel: AiChannel.USER_KEY,
|
||||
providerName: "openai",
|
||||
model: "gpt-4o-mini",
|
||||
status: "failed",
|
||||
reasonCode: "UPSTREAM_UNREACHABLE",
|
||||
reasonMessage: "用户自备 Key 渠道暂时不可用"
|
||||
},
|
||||
{
|
||||
channel: AiChannel.ASTRBOT,
|
||||
providerName: "astrbot-main",
|
||||
model: "deepseek-chat",
|
||||
status: "success",
|
||||
reasonCode: null,
|
||||
reasonMessage: null
|
||||
}
|
||||
]);
|
||||
});
|
||||
|
||||
it("should return skipped attempts when no channel is available", async () => {
|
||||
const response = await request(app.getHttpServer())
|
||||
.post("/ai/chat")
|
||||
.set("x-user-id", "user_1")
|
||||
.send({
|
||||
message: "帮我总结今天的安排"
|
||||
})
|
||||
.expect(502);
|
||||
|
||||
expect(response.body.message).toBe("当前没有可用的 AI 通道,请稍后重试");
|
||||
expect(response.body.attempts).toEqual([
|
||||
{
|
||||
channel: AiChannel.USER_KEY,
|
||||
providerName: null,
|
||||
model: null,
|
||||
status: "skipped",
|
||||
reasonCode: "CHANNEL_NOT_CONFIGURED",
|
||||
reasonMessage: "当前用户未配置可用的自备 Key 通道"
|
||||
},
|
||||
{
|
||||
channel: AiChannel.ASTRBOT,
|
||||
providerName: null,
|
||||
model: null,
|
||||
status: "skipped",
|
||||
reasonCode: "CHANNEL_NOT_CONFIGURED",
|
||||
reasonMessage: "当前用户未配置可用的 AstrBot 通道"
|
||||
},
|
||||
{
|
||||
channel: AiChannel.PUBLIC_POOL,
|
||||
providerName: null,
|
||||
model: null,
|
||||
status: "skipped",
|
||||
reasonCode: "PUBLIC_POOL_DISABLED",
|
||||
reasonMessage: "公共 AI 通道未开启"
|
||||
}
|
||||
]);
|
||||
});
|
||||
});
|
||||
Reference in New Issue
Block a user