fix(api-ai): stop forwarding invalid astrbot selection fields

This commit is contained in:
2026-04-07 22:02:33 +08:00
parent 45b149ad58
commit ce72892dc8
2 changed files with 56 additions and 82 deletions
@@ -46,9 +46,6 @@ export class AstrbotProvider implements AiChannelExecutor {
session_id: input.sessionId ?? undefined,
message: input.message,
enable_streaming: false,
config_id: candidate.configId ?? undefined,
config_name: candidate.configName ?? undefined,
selected_provider: candidate.providerName || undefined,
selected_model: candidate.model ?? undefined
}),
signal: AbortSignal.timeout(30000)
+56 -79
View File
@@ -2,95 +2,72 @@ import { AiChannel } from "../generated/prisma/client";
import { AstrbotProvider } from "../src/ai/providers/astrbot.provider";
describe("AstrbotProvider", () => {
const originalFetch = global.fetch;
afterEach(() => {
global.fetch = originalFetch;
jest.restoreAllMocks();
});
it("should stop reading once the end event arrives", async () => {
const encoder = new TextEncoder();
let pullCount = 0;
it("should not forward binding label fields as astrbot selection parameters", async () => {
const provider = new AstrbotProvider();
const fetchMock = jest.fn(async (_input: unknown, init?: RequestInit) => {
expect(init?.method).toBe("POST");
const payload = JSON.parse(String(init?.body ?? "{}")) as Record<string, unknown>;
const stream = new ReadableStream<Uint8Array>({
pull(controller) {
pullCount += 1;
if (pullCount === 1) {
controller.enqueue(
encoder.encode('data: {"type":"session_id","data":null,"session_id":"session_1"}\n\n')
);
return;
expect(payload).toMatchObject({
username: "user_1",
session_id: "session_1",
message: "你好",
enable_streaming: false,
selected_model: "deepseek-chat"
});
expect(payload).not.toHaveProperty("selected_provider");
expect(payload).not.toHaveProperty("config_id");
expect(payload).not.toHaveProperty("config_name");
return new Response(
[
'data: {"type":"session_id","session_id":"session_1"}',
"",
'data: {"type":"plain","data":"收到","streaming":false,"chain_type":null}',
"",
'data: {"type":"end","data":"","streaming":false}',
""
].join("\n"),
{
status: 200,
headers: {
"content-type": "text/event-stream"
}
}
if (pullCount === 2) {
controller.enqueue(
encoder.encode(
'data: {"type":"plain","data":"TodoList AstrBot 已连接","streaming":false,"chain_type":null}\n\n'
)
);
return;
}
if (pullCount === 3) {
controller.enqueue(
encoder.encode(
'data: {"type":"agent_stats","data":{"token_usage":{"input_other":12,"input_cached":30,"output":8}}}\n\n'
)
);
return;
}
if (pullCount === 4) {
controller.enqueue(
encoder.encode('data: {"type":"end","data":"","streaming":false}\n\n')
);
return;
}
return new Promise(() => undefined);
}
);
});
jest.spyOn(globalThis, "fetch").mockResolvedValue(
new Response(stream, {
status: 200,
headers: {
"Content-Type": "text/event-stream"
}
})
global.fetch = fetchMock as typeof global.fetch;
const result = await provider.execute(
{
channel: AiChannel.ASTRBOT,
source: "binding",
sourceId: "binding_1",
providerName: "astrbot-main",
model: "deepseek-chat",
configId: "default",
configName: "默认配置",
endpoint: "http://127.0.0.1:6185",
apiKey: "abk_secret"
},
{
userId: "user_1",
message: "你好",
sessionId: "session_1"
}
);
const provider = new AstrbotProvider();
const result = await Promise.race([
provider.execute(
{
channel: AiChannel.ASTRBOT,
source: "binding",
sourceId: "binding_1",
providerName: "",
model: null,
configId: "default",
configName: null,
endpoint: "http://127.0.0.1:6185",
apiKey: "abk_test"
},
{
userId: "user_1",
message: "ping",
sessionId: null
}
),
new Promise<never>((_, reject) => {
setTimeout(() => reject(new Error("provider timeout")), 1000);
})
]);
expect(result.content).toBe("TodoList AstrBot 已连接");
expect(fetchMock).toHaveBeenCalledTimes(1);
expect(result.content).toBe("收到");
expect(result.sessionId).toBe("session_1");
expect(result.usage).toEqual({
promptTokens: 42,
completionTokens: 8,
totalTokens: 50
});
expect(pullCount).toBeGreaterThanOrEqual(4);
expect(result.providerName).toBe("astrbot-main");
});
});