fix(api-ai): stop astrbot stream on end event
This commit is contained in:
@@ -57,8 +57,8 @@ export class AstrbotProvider implements AiChannelExecutor {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
const rawText = await response.text();
|
|
||||||
if (!response.ok) {
|
if (!response.ok) {
|
||||||
|
const rawText = await response.text();
|
||||||
throw new AiRouteFailureError(
|
throw new AiRouteFailureError(
|
||||||
candidate.channel,
|
candidate.channel,
|
||||||
candidate.providerName,
|
candidate.providerName,
|
||||||
@@ -67,7 +67,7 @@ export class AstrbotProvider implements AiChannelExecutor {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
const events = this.parseSseEvents(rawText);
|
const events = await this.readSseEvents(response);
|
||||||
let content = "";
|
let content = "";
|
||||||
let sessionId = input.sessionId;
|
let sessionId = input.sessionId;
|
||||||
|
|
||||||
@@ -167,6 +167,55 @@ export class AstrbotProvider implements AiChannelExecutor {
|
|||||||
.filter((item): item is Record<string, unknown> => item !== null);
|
.filter((item): item is Record<string, unknown> => item !== null);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private async readSseEvents(response: Response): Promise<Array<Record<string, unknown>>> {
|
||||||
|
if (!response.body) {
|
||||||
|
return this.parseSseEvents(await response.text());
|
||||||
|
}
|
||||||
|
|
||||||
|
const reader = response.body.getReader();
|
||||||
|
const decoder = new TextDecoder();
|
||||||
|
const events: Array<Record<string, unknown>> = [];
|
||||||
|
let buffer = "";
|
||||||
|
let reachedEndEvent = false;
|
||||||
|
|
||||||
|
try {
|
||||||
|
while (!reachedEndEvent) {
|
||||||
|
const { done, value } = await reader.read();
|
||||||
|
if (done) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
buffer += decoder.decode(value, { stream: true });
|
||||||
|
const segments = buffer.split(/\r?\n\r?\n/);
|
||||||
|
buffer = segments.pop() ?? "";
|
||||||
|
|
||||||
|
for (const segment of segments) {
|
||||||
|
const parsedEvents = this.parseSseEvents(segment);
|
||||||
|
for (const event of parsedEvents) {
|
||||||
|
events.push(event);
|
||||||
|
if (this.readString(event["type"]) === "end") {
|
||||||
|
reachedEndEvent = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (reachedEndEvent) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const tail = `${buffer}${decoder.decode()}`;
|
||||||
|
if (tail.trim().length > 0) {
|
||||||
|
events.push(...this.parseSseEvents(tail));
|
||||||
|
}
|
||||||
|
} finally {
|
||||||
|
await reader.cancel();
|
||||||
|
}
|
||||||
|
|
||||||
|
return events;
|
||||||
|
}
|
||||||
|
|
||||||
private extractHttpErrorMessage(rawText: string, statusCode: number): string {
|
private extractHttpErrorMessage(rawText: string, statusCode: number): string {
|
||||||
try {
|
try {
|
||||||
const payload = JSON.parse(rawText) as Record<string, unknown>;
|
const payload = JSON.parse(rawText) as Record<string, unknown>;
|
||||||
|
|||||||
@@ -0,0 +1,80 @@
|
|||||||
|
import { AiChannel } from "../generated/prisma/client";
|
||||||
|
import { AstrbotProvider } from "../src/ai/providers/astrbot.provider";
|
||||||
|
|
||||||
|
describe("AstrbotProvider", () => {
|
||||||
|
afterEach(() => {
|
||||||
|
jest.restoreAllMocks();
|
||||||
|
});
|
||||||
|
|
||||||
|
it("should stop reading once the end event arrives", async () => {
|
||||||
|
const encoder = new TextEncoder();
|
||||||
|
let pullCount = 0;
|
||||||
|
|
||||||
|
const stream = new ReadableStream<Uint8Array>({
|
||||||
|
pull(controller) {
|
||||||
|
pullCount += 1;
|
||||||
|
if (pullCount === 1) {
|
||||||
|
controller.enqueue(
|
||||||
|
encoder.encode('data: {"type":"session_id","data":null,"session_id":"session_1"}\n\n')
|
||||||
|
);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (pullCount === 2) {
|
||||||
|
controller.enqueue(
|
||||||
|
encoder.encode(
|
||||||
|
'data: {"type":"plain","data":"TodoList AstrBot 已连接","streaming":false,"chain_type":null}\n\n'
|
||||||
|
)
|
||||||
|
);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (pullCount === 3) {
|
||||||
|
controller.enqueue(
|
||||||
|
encoder.encode('data: {"type":"end","data":"","streaming":false}\n\n')
|
||||||
|
);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
return new Promise(() => undefined);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
jest.spyOn(globalThis, "fetch").mockResolvedValue(
|
||||||
|
new Response(stream, {
|
||||||
|
status: 200,
|
||||||
|
headers: {
|
||||||
|
"Content-Type": "text/event-stream"
|
||||||
|
}
|
||||||
|
})
|
||||||
|
);
|
||||||
|
|
||||||
|
const provider = new AstrbotProvider();
|
||||||
|
|
||||||
|
const result = await Promise.race([
|
||||||
|
provider.execute(
|
||||||
|
{
|
||||||
|
channel: AiChannel.ASTRBOT,
|
||||||
|
source: "binding",
|
||||||
|
sourceId: "binding_1",
|
||||||
|
providerName: "",
|
||||||
|
model: null,
|
||||||
|
endpoint: "http://127.0.0.1:6185",
|
||||||
|
apiKey: "abk_test"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
userId: "user_1",
|
||||||
|
message: "ping",
|
||||||
|
sessionId: null
|
||||||
|
}
|
||||||
|
),
|
||||||
|
new Promise<never>((_, reject) => {
|
||||||
|
setTimeout(() => reject(new Error("provider timeout")), 1000);
|
||||||
|
})
|
||||||
|
]);
|
||||||
|
|
||||||
|
expect(result.content).toBe("TodoList AstrBot 已连接");
|
||||||
|
expect(result.sessionId).toBe("session_1");
|
||||||
|
expect(pullCount).toBeGreaterThanOrEqual(3);
|
||||||
|
});
|
||||||
|
});
|
||||||
Reference in New Issue
Block a user