fix(api-ai): stop forwarding invalid astrbot selection fields
This commit is contained in:
@@ -46,9 +46,6 @@ export class AstrbotProvider implements AiChannelExecutor {
|
|||||||
session_id: input.sessionId ?? undefined,
|
session_id: input.sessionId ?? undefined,
|
||||||
message: input.message,
|
message: input.message,
|
||||||
enable_streaming: false,
|
enable_streaming: false,
|
||||||
config_id: candidate.configId ?? undefined,
|
|
||||||
config_name: candidate.configName ?? undefined,
|
|
||||||
selected_provider: candidate.providerName || undefined,
|
|
||||||
selected_model: candidate.model ?? undefined
|
selected_model: candidate.model ?? undefined
|
||||||
}),
|
}),
|
||||||
signal: AbortSignal.timeout(30000)
|
signal: AbortSignal.timeout(30000)
|
||||||
|
|||||||
@@ -2,95 +2,72 @@ import { AiChannel } from "../generated/prisma/client";
|
|||||||
import { AstrbotProvider } from "../src/ai/providers/astrbot.provider";
|
import { AstrbotProvider } from "../src/ai/providers/astrbot.provider";
|
||||||
|
|
||||||
describe("AstrbotProvider", () => {
|
describe("AstrbotProvider", () => {
|
||||||
|
const originalFetch = global.fetch;
|
||||||
|
|
||||||
afterEach(() => {
|
afterEach(() => {
|
||||||
|
global.fetch = originalFetch;
|
||||||
jest.restoreAllMocks();
|
jest.restoreAllMocks();
|
||||||
});
|
});
|
||||||
|
|
||||||
it("should stop reading once the end event arrives", async () => {
|
it("should not forward binding label fields as astrbot selection parameters", async () => {
|
||||||
const encoder = new TextEncoder();
|
const provider = new AstrbotProvider();
|
||||||
let pullCount = 0;
|
const fetchMock = jest.fn(async (_input: unknown, init?: RequestInit) => {
|
||||||
|
expect(init?.method).toBe("POST");
|
||||||
|
const payload = JSON.parse(String(init?.body ?? "{}")) as Record<string, unknown>;
|
||||||
|
|
||||||
const stream = new ReadableStream<Uint8Array>({
|
expect(payload).toMatchObject({
|
||||||
pull(controller) {
|
username: "user_1",
|
||||||
pullCount += 1;
|
session_id: "session_1",
|
||||||
if (pullCount === 1) {
|
message: "你好",
|
||||||
controller.enqueue(
|
enable_streaming: false,
|
||||||
encoder.encode('data: {"type":"session_id","data":null,"session_id":"session_1"}\n\n')
|
selected_model: "deepseek-chat"
|
||||||
);
|
});
|
||||||
return;
|
expect(payload).not.toHaveProperty("selected_provider");
|
||||||
|
expect(payload).not.toHaveProperty("config_id");
|
||||||
|
expect(payload).not.toHaveProperty("config_name");
|
||||||
|
|
||||||
|
return new Response(
|
||||||
|
[
|
||||||
|
'data: {"type":"session_id","session_id":"session_1"}',
|
||||||
|
"",
|
||||||
|
'data: {"type":"plain","data":"收到","streaming":false,"chain_type":null}',
|
||||||
|
"",
|
||||||
|
'data: {"type":"end","data":"","streaming":false}',
|
||||||
|
""
|
||||||
|
].join("\n"),
|
||||||
|
{
|
||||||
|
status: 200,
|
||||||
|
headers: {
|
||||||
|
"content-type": "text/event-stream"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
);
|
||||||
if (pullCount === 2) {
|
|
||||||
controller.enqueue(
|
|
||||||
encoder.encode(
|
|
||||||
'data: {"type":"plain","data":"TodoList AstrBot 已连接","streaming":false,"chain_type":null}\n\n'
|
|
||||||
)
|
|
||||||
);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (pullCount === 3) {
|
|
||||||
controller.enqueue(
|
|
||||||
encoder.encode(
|
|
||||||
'data: {"type":"agent_stats","data":{"token_usage":{"input_other":12,"input_cached":30,"output":8}}}\n\n'
|
|
||||||
)
|
|
||||||
);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (pullCount === 4) {
|
|
||||||
controller.enqueue(
|
|
||||||
encoder.encode('data: {"type":"end","data":"","streaming":false}\n\n')
|
|
||||||
);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
return new Promise(() => undefined);
|
|
||||||
}
|
|
||||||
});
|
});
|
||||||
|
|
||||||
jest.spyOn(globalThis, "fetch").mockResolvedValue(
|
global.fetch = fetchMock as typeof global.fetch;
|
||||||
new Response(stream, {
|
|
||||||
status: 200,
|
const result = await provider.execute(
|
||||||
headers: {
|
{
|
||||||
"Content-Type": "text/event-stream"
|
channel: AiChannel.ASTRBOT,
|
||||||
}
|
source: "binding",
|
||||||
})
|
sourceId: "binding_1",
|
||||||
|
providerName: "astrbot-main",
|
||||||
|
model: "deepseek-chat",
|
||||||
|
configId: "default",
|
||||||
|
configName: "默认配置",
|
||||||
|
endpoint: "http://127.0.0.1:6185",
|
||||||
|
apiKey: "abk_secret"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
userId: "user_1",
|
||||||
|
message: "你好",
|
||||||
|
sessionId: "session_1"
|
||||||
|
}
|
||||||
);
|
);
|
||||||
|
|
||||||
const provider = new AstrbotProvider();
|
expect(fetchMock).toHaveBeenCalledTimes(1);
|
||||||
|
expect(result.content).toBe("收到");
|
||||||
const result = await Promise.race([
|
|
||||||
provider.execute(
|
|
||||||
{
|
|
||||||
channel: AiChannel.ASTRBOT,
|
|
||||||
source: "binding",
|
|
||||||
sourceId: "binding_1",
|
|
||||||
providerName: "",
|
|
||||||
model: null,
|
|
||||||
configId: "default",
|
|
||||||
configName: null,
|
|
||||||
endpoint: "http://127.0.0.1:6185",
|
|
||||||
apiKey: "abk_test"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
userId: "user_1",
|
|
||||||
message: "ping",
|
|
||||||
sessionId: null
|
|
||||||
}
|
|
||||||
),
|
|
||||||
new Promise<never>((_, reject) => {
|
|
||||||
setTimeout(() => reject(new Error("provider timeout")), 1000);
|
|
||||||
})
|
|
||||||
]);
|
|
||||||
|
|
||||||
expect(result.content).toBe("TodoList AstrBot 已连接");
|
|
||||||
expect(result.sessionId).toBe("session_1");
|
expect(result.sessionId).toBe("session_1");
|
||||||
expect(result.usage).toEqual({
|
expect(result.providerName).toBe("astrbot-main");
|
||||||
promptTokens: 42,
|
|
||||||
completionTokens: 8,
|
|
||||||
totalTokens: 50
|
|
||||||
});
|
|
||||||
expect(pullCount).toBeGreaterThanOrEqual(4);
|
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|||||||
Reference in New Issue
Block a user