diff --git a/apps/api/src/ai/providers/astrbot.provider.ts b/apps/api/src/ai/providers/astrbot.provider.ts index 419139d..a413a3b 100644 --- a/apps/api/src/ai/providers/astrbot.provider.ts +++ b/apps/api/src/ai/providers/astrbot.provider.ts @@ -57,8 +57,8 @@ export class AstrbotProvider implements AiChannelExecutor { ); } - const rawText = await response.text(); if (!response.ok) { + const rawText = await response.text(); throw new AiRouteFailureError( candidate.channel, candidate.providerName, @@ -67,7 +67,7 @@ export class AstrbotProvider implements AiChannelExecutor { ); } - const events = this.parseSseEvents(rawText); + const events = await this.readSseEvents(response); let content = ""; let sessionId = input.sessionId; @@ -167,6 +167,55 @@ export class AstrbotProvider implements AiChannelExecutor { .filter((item): item is Record => item !== null); } + private async readSseEvents(response: Response): Promise>> { + if (!response.body) { + return this.parseSseEvents(await response.text()); + } + + const reader = response.body.getReader(); + const decoder = new TextDecoder(); + const events: Array> = []; + let buffer = ""; + let reachedEndEvent = false; + + try { + while (!reachedEndEvent) { + const { done, value } = await reader.read(); + if (done) { + break; + } + + buffer += decoder.decode(value, { stream: true }); + const segments = buffer.split(/\r?\n\r?\n/); + buffer = segments.pop() ?? ""; + + for (const segment of segments) { + const parsedEvents = this.parseSseEvents(segment); + for (const event of parsedEvents) { + events.push(event); + if (this.readString(event["type"]) === "end") { + reachedEndEvent = true; + break; + } + } + + if (reachedEndEvent) { + break; + } + } + } + + const tail = `${buffer}${decoder.decode()}`; + if (tail.trim().length > 0) { + events.push(...this.parseSseEvents(tail)); + } + } finally { + await reader.cancel(); + } + + return events; + } + private extractHttpErrorMessage(rawText: string, statusCode: number): string { try { const payload = JSON.parse(rawText) as Record; diff --git a/apps/api/test/astrbot-provider.spec.ts b/apps/api/test/astrbot-provider.spec.ts new file mode 100644 index 0000000..6190c4a --- /dev/null +++ b/apps/api/test/astrbot-provider.spec.ts @@ -0,0 +1,80 @@ +import { AiChannel } from "../generated/prisma/client"; +import { AstrbotProvider } from "../src/ai/providers/astrbot.provider"; + +describe("AstrbotProvider", () => { + afterEach(() => { + jest.restoreAllMocks(); + }); + + it("should stop reading once the end event arrives", async () => { + const encoder = new TextEncoder(); + let pullCount = 0; + + const stream = new ReadableStream({ + pull(controller) { + pullCount += 1; + if (pullCount === 1) { + controller.enqueue( + encoder.encode('data: {"type":"session_id","data":null,"session_id":"session_1"}\n\n') + ); + return; + } + + if (pullCount === 2) { + controller.enqueue( + encoder.encode( + 'data: {"type":"plain","data":"TodoList AstrBot 已连接","streaming":false,"chain_type":null}\n\n' + ) + ); + return; + } + + if (pullCount === 3) { + controller.enqueue( + encoder.encode('data: {"type":"end","data":"","streaming":false}\n\n') + ); + return; + } + + return new Promise(() => undefined); + } + }); + + jest.spyOn(globalThis, "fetch").mockResolvedValue( + new Response(stream, { + status: 200, + headers: { + "Content-Type": "text/event-stream" + } + }) + ); + + const provider = new AstrbotProvider(); + + const result = await Promise.race([ + provider.execute( + { + channel: AiChannel.ASTRBOT, + source: "binding", + sourceId: "binding_1", + providerName: "", + model: null, + endpoint: "http://127.0.0.1:6185", + apiKey: "abk_test" + }, + { + userId: "user_1", + message: "ping", + sessionId: null + } + ), + new Promise((_, reject) => { + setTimeout(() => reject(new Error("provider timeout")), 1000); + }) + ]); + + expect(result.content).toBe("TodoList AstrBot 已连接"); + expect(result.sessionId).toBe("session_1"); + expect(pullCount).toBeGreaterThanOrEqual(3); + }); +});