fix(api-ai): stop astrbot stream on end event

This commit is contained in:
2026-04-06 12:19:13 +08:00
parent 180f7a9baa
commit 2bce9a59c6
2 changed files with 131 additions and 2 deletions
+51 -2
View File
@@ -57,8 +57,8 @@ export class AstrbotProvider implements AiChannelExecutor {
); );
} }
const rawText = await response.text();
if (!response.ok) { if (!response.ok) {
const rawText = await response.text();
throw new AiRouteFailureError( throw new AiRouteFailureError(
candidate.channel, candidate.channel,
candidate.providerName, candidate.providerName,
@@ -67,7 +67,7 @@ export class AstrbotProvider implements AiChannelExecutor {
); );
} }
const events = this.parseSseEvents(rawText); const events = await this.readSseEvents(response);
let content = ""; let content = "";
let sessionId = input.sessionId; let sessionId = input.sessionId;
@@ -167,6 +167,55 @@ export class AstrbotProvider implements AiChannelExecutor {
.filter((item): item is Record<string, unknown> => item !== null); .filter((item): item is Record<string, unknown> => item !== null);
} }
private async readSseEvents(response: Response): Promise<Array<Record<string, unknown>>> {
if (!response.body) {
return this.parseSseEvents(await response.text());
}
const reader = response.body.getReader();
const decoder = new TextDecoder();
const events: Array<Record<string, unknown>> = [];
let buffer = "";
let reachedEndEvent = false;
try {
while (!reachedEndEvent) {
const { done, value } = await reader.read();
if (done) {
break;
}
buffer += decoder.decode(value, { stream: true });
const segments = buffer.split(/\r?\n\r?\n/);
buffer = segments.pop() ?? "";
for (const segment of segments) {
const parsedEvents = this.parseSseEvents(segment);
for (const event of parsedEvents) {
events.push(event);
if (this.readString(event["type"]) === "end") {
reachedEndEvent = true;
break;
}
}
if (reachedEndEvent) {
break;
}
}
}
const tail = `${buffer}${decoder.decode()}`;
if (tail.trim().length > 0) {
events.push(...this.parseSseEvents(tail));
}
} finally {
await reader.cancel();
}
return events;
}
private extractHttpErrorMessage(rawText: string, statusCode: number): string { private extractHttpErrorMessage(rawText: string, statusCode: number): string {
try { try {
const payload = JSON.parse(rawText) as Record<string, unknown>; const payload = JSON.parse(rawText) as Record<string, unknown>;
+80
View File
@@ -0,0 +1,80 @@
import { AiChannel } from "../generated/prisma/client";
import { AstrbotProvider } from "../src/ai/providers/astrbot.provider";
describe("AstrbotProvider", () => {
afterEach(() => {
jest.restoreAllMocks();
});
it("should stop reading once the end event arrives", async () => {
const encoder = new TextEncoder();
let pullCount = 0;
const stream = new ReadableStream<Uint8Array>({
pull(controller) {
pullCount += 1;
if (pullCount === 1) {
controller.enqueue(
encoder.encode('data: {"type":"session_id","data":null,"session_id":"session_1"}\n\n')
);
return;
}
if (pullCount === 2) {
controller.enqueue(
encoder.encode(
'data: {"type":"plain","data":"TodoList AstrBot 已连接","streaming":false,"chain_type":null}\n\n'
)
);
return;
}
if (pullCount === 3) {
controller.enqueue(
encoder.encode('data: {"type":"end","data":"","streaming":false}\n\n')
);
return;
}
return new Promise(() => undefined);
}
});
jest.spyOn(globalThis, "fetch").mockResolvedValue(
new Response(stream, {
status: 200,
headers: {
"Content-Type": "text/event-stream"
}
})
);
const provider = new AstrbotProvider();
const result = await Promise.race([
provider.execute(
{
channel: AiChannel.ASTRBOT,
source: "binding",
sourceId: "binding_1",
providerName: "",
model: null,
endpoint: "http://127.0.0.1:6185",
apiKey: "abk_test"
},
{
userId: "user_1",
message: "ping",
sessionId: null
}
),
new Promise<never>((_, reject) => {
setTimeout(() => reject(new Error("provider timeout")), 1000);
})
]);
expect(result.content).toBe("TodoList AstrBot 已连接");
expect(result.sessionId).toBe("session_1");
expect(pullCount).toBeGreaterThanOrEqual(3);
});
});