feat(api-ai): add provider registry and routing fallback
This commit is contained in:
@@ -0,0 +1,197 @@
|
||||
import { Injectable } from "@nestjs/common";
|
||||
import {
|
||||
AiChannelExecutor,
|
||||
AiChatInput,
|
||||
AiChatResult,
|
||||
AiResolvedRouteCandidate,
|
||||
AiRouteFailureError
|
||||
} from "../ai.types";
|
||||
|
||||
@Injectable()
|
||||
export class AstrbotProvider implements AiChannelExecutor {
|
||||
async execute(candidate: AiResolvedRouteCandidate, input: AiChatInput): Promise<AiChatResult> {
|
||||
if (!candidate.endpoint) {
|
||||
throw new AiRouteFailureError(
|
||||
candidate.channel,
|
||||
candidate.providerName,
|
||||
"MISSING_ENDPOINT",
|
||||
"缺少 AstrBot 服务地址配置"
|
||||
);
|
||||
}
|
||||
|
||||
if (!candidate.apiKey) {
|
||||
throw new AiRouteFailureError(
|
||||
candidate.channel,
|
||||
candidate.providerName,
|
||||
"MISSING_API_KEY",
|
||||
"缺少 AstrBot API Key 配置"
|
||||
);
|
||||
}
|
||||
|
||||
const requestUrl = this.buildRequestUrl(candidate.endpoint);
|
||||
|
||||
let response: Response;
|
||||
try {
|
||||
response = await fetch(requestUrl, {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
Authorization: `Bearer ${candidate.apiKey}`
|
||||
},
|
||||
body: JSON.stringify({
|
||||
username: input.userId,
|
||||
session_id: input.sessionId ?? undefined,
|
||||
message: input.message,
|
||||
enable_streaming: false,
|
||||
selected_provider: candidate.providerName || undefined,
|
||||
selected_model: candidate.model ?? undefined
|
||||
}),
|
||||
signal: AbortSignal.timeout(30000)
|
||||
});
|
||||
} catch (error) {
|
||||
throw new AiRouteFailureError(
|
||||
candidate.channel,
|
||||
candidate.providerName,
|
||||
"UPSTREAM_UNREACHABLE",
|
||||
this.toErrorMessage(error, "AstrBot 服务请求失败")
|
||||
);
|
||||
}
|
||||
|
||||
const rawText = await response.text();
|
||||
if (!response.ok) {
|
||||
throw new AiRouteFailureError(
|
||||
candidate.channel,
|
||||
candidate.providerName,
|
||||
`UPSTREAM_HTTP_${response.status}`,
|
||||
this.extractHttpErrorMessage(rawText, response.status)
|
||||
);
|
||||
}
|
||||
|
||||
const events = this.parseSseEvents(rawText);
|
||||
let content = "";
|
||||
let sessionId = input.sessionId;
|
||||
|
||||
for (const event of events) {
|
||||
const type = this.readString(event["type"]);
|
||||
if (type === "session_id") {
|
||||
sessionId = this.readString(event["session_id"]) ?? sessionId;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (type === "error") {
|
||||
throw new AiRouteFailureError(
|
||||
candidate.channel,
|
||||
candidate.providerName,
|
||||
this.readString(event["code"]) ?? "ASTRBOT_ERROR",
|
||||
this.readString(event["data"]) ?? "AstrBot 返回错误"
|
||||
);
|
||||
}
|
||||
|
||||
if (type !== "plain") {
|
||||
continue;
|
||||
}
|
||||
|
||||
const chainType = this.readString(event["chain_type"]);
|
||||
if (
|
||||
chainType === "reasoning" ||
|
||||
chainType === "tool_call" ||
|
||||
chainType === "tool_call_result"
|
||||
) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const data = this.readString(event["data"]);
|
||||
if (!data) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (event["streaming"] === true) {
|
||||
content += data;
|
||||
continue;
|
||||
}
|
||||
|
||||
content = data;
|
||||
}
|
||||
|
||||
if (!content.trim()) {
|
||||
throw new AiRouteFailureError(
|
||||
candidate.channel,
|
||||
candidate.providerName,
|
||||
"EMPTY_RESPONSE",
|
||||
"AstrBot 没有返回有效内容"
|
||||
);
|
||||
}
|
||||
|
||||
return {
|
||||
channel: candidate.channel,
|
||||
providerName: candidate.providerName,
|
||||
model: candidate.model,
|
||||
content,
|
||||
sessionId,
|
||||
raw: events
|
||||
};
|
||||
}
|
||||
|
||||
private buildRequestUrl(endpoint: string): string {
|
||||
const normalizedEndpoint = endpoint.replace(/\/+$/, "");
|
||||
if (normalizedEndpoint.endsWith("/api/v1/chat")) {
|
||||
return normalizedEndpoint;
|
||||
}
|
||||
if (normalizedEndpoint.endsWith("/api/v1")) {
|
||||
return `${normalizedEndpoint}/chat`;
|
||||
}
|
||||
if (normalizedEndpoint.endsWith("/api")) {
|
||||
return `${normalizedEndpoint}/v1/chat`;
|
||||
}
|
||||
return `${normalizedEndpoint}/api/v1/chat`;
|
||||
}
|
||||
|
||||
private parseSseEvents(rawText: string): Array<Record<string, unknown>> {
|
||||
return rawText
|
||||
.split(/\r?\n\r?\n/)
|
||||
.map((block) =>
|
||||
block
|
||||
.split(/\r?\n/)
|
||||
.filter((line) => line.startsWith("data:"))
|
||||
.map((line) => line.slice(5).trim())
|
||||
.join("\n")
|
||||
)
|
||||
.filter((payload) => payload.length > 0)
|
||||
.map((payload) => {
|
||||
try {
|
||||
return JSON.parse(payload) as Record<string, unknown>;
|
||||
} catch {
|
||||
return null;
|
||||
}
|
||||
})
|
||||
.filter((item): item is Record<string, unknown> => item !== null);
|
||||
}
|
||||
|
||||
private extractHttpErrorMessage(rawText: string, statusCode: number): string {
|
||||
try {
|
||||
const payload = JSON.parse(rawText) as Record<string, unknown>;
|
||||
if (typeof payload["message"] === "string") {
|
||||
return payload["message"];
|
||||
}
|
||||
if (typeof payload["data"] === "string") {
|
||||
return payload["data"];
|
||||
}
|
||||
} catch {
|
||||
return `AstrBot 服务调用失败,状态码 ${statusCode}`;
|
||||
}
|
||||
|
||||
return `AstrBot 服务调用失败,状态码 ${statusCode}`;
|
||||
}
|
||||
|
||||
private readString(value: unknown): string | null {
|
||||
return typeof value === "string" ? value : null;
|
||||
}
|
||||
|
||||
private toErrorMessage(error: unknown, fallback: string): string {
|
||||
if (error instanceof Error && error.message) {
|
||||
return error.message;
|
||||
}
|
||||
|
||||
return fallback;
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user