diff --git a/apps/api/src/ai/ai.controller.ts b/apps/api/src/ai/ai.controller.ts index 99afb38..8756e08 100644 --- a/apps/api/src/ai/ai.controller.ts +++ b/apps/api/src/ai/ai.controller.ts @@ -15,7 +15,8 @@ import { AiChatResponse, AiService, ListAiBindingsResponse, - ListAiUsageLogsResponse + ListAiUsageLogsResponse, + TestAiBindingResponse } from "./ai.service"; @Controller("ai") @@ -45,6 +46,14 @@ export class AiController { return this.aiService.upsertBinding(this.resolveUserId(userIdHeader), body); } + @Post("bindings/test") + async testBinding( + @Headers("x-user-id") userIdHeader: string | string[] | undefined, + @Body() body: UpsertAiProviderBindingDto + ): Promise { + return this.aiService.testBinding(this.resolveUserId(userIdHeader), body); + } + @Post("chat") async chat( @Headers("x-user-id") userIdHeader: string | string[] | undefined, diff --git a/apps/api/src/ai/ai.service.ts b/apps/api/src/ai/ai.service.ts index 5e7e354..25cbf61 100644 --- a/apps/api/src/ai/ai.service.ts +++ b/apps/api/src/ai/ai.service.ts @@ -104,6 +104,23 @@ export type AiChatResponse = { attempts: AiRouteAttempt[]; }; +export type TestAiBindingResponse = + | { + success: true; + channel: AiChannel; + providerName: string; + model: string | null; + contentPreview: string; + } + | { + success: false; + channel: AiChannel; + providerName: string; + model: string | null; + code: string; + message: string; + }; + @Injectable() export class AiService { private readonly logger = new Logger(AiService.name); @@ -251,6 +268,65 @@ export class AiService { return this.serializeBinding(result); } + async testBinding( + userId: string, + dto: UpsertAiProviderBindingDto + ): Promise { + if (dto.channel === AiChannel.PUBLIC_POOL) { + throw new BadRequestException("公共 AI 通道不能由用户自行测试"); + } + + const candidate = await this.buildTestCandidate(userId, dto); + const executor = this.aiProviderRegistryService.getExecutor(candidate.channel); + + try { + const result = await executor.execute(candidate, { + userId, + message: "请只回复“连接成功”,不要添加其他内容。", + sessionId: null + }); + + return { + success: true, + channel: result.channel, + providerName: result.providerName, + model: result.model, + contentPreview: this.limitPreviewText(result.content) + }; + } catch (error) { + if (error instanceof AiRouteFailureError) { + return { + success: false, + channel: error.channel, + providerName: error.providerName, + model: candidate.model, + code: error.code, + message: error.message + }; + } + + if (error instanceof Error) { + return { + success: false, + channel: candidate.channel, + providerName: candidate.providerName, + model: candidate.model, + code: "UNKNOWN_ERROR", + message: error.message + }; + } + + return { + success: false, + channel: candidate.channel, + providerName: candidate.providerName, + model: candidate.model, + code: "UNKNOWN_ERROR", + message: "未知错误" + }; + } + } + async chat( userId: string, dto: AiChatDto, @@ -433,6 +509,55 @@ export class AiService { }); } + private async buildTestCandidate( + userId: string, + dto: UpsertAiProviderBindingDto + ): Promise { + const existingBinding = await this.prismaService.aiProviderBinding.findFirst({ + where: { + userId, + channel: dto.channel + }, + orderBy: { + updatedAt: "desc" + } + }); + + const mergedDto: UpsertAiProviderBindingDto = { + channel: dto.channel, + providerName: + dto.providerName ?? this.readDecryptedString(existingBinding?.providerName ?? null) ?? "", + model: dto.model ?? this.readDecryptedString(existingBinding?.model ?? null) ?? undefined, + configId: + dto.configId ?? this.readDecryptedString(existingBinding?.configId ?? null) ?? undefined, + configName: + dto.configName ?? + this.readDecryptedString(existingBinding?.configName ?? null) ?? + undefined, + endpoint: + dto.endpoint ?? this.readDecryptedString(existingBinding?.endpoint ?? null) ?? undefined, + apiKey: + dto.apiKey ?? + this.readDecryptedString(existingBinding?.encryptedApiKey ?? null) ?? + undefined, + isEnabled: dto.isEnabled ?? existingBinding?.isEnabled ?? true + }; + + this.validateBindingInput(mergedDto); + + return { + channel: mergedDto.channel, + source: existingBinding ? "binding" : "binding", + sourceId: existingBinding?.id ?? null, + providerName: this.normalizeProviderName(mergedDto.providerName), + model: this.normalizeOptionalString(mergedDto.model), + configId: this.normalizeOptionalString(mergedDto.configId), + configName: this.normalizeOptionalString(mergedDto.configName), + endpoint: this.normalizeOptionalString(mergedDto.endpoint), + apiKey: this.normalizeOptionalString(mergedDto.apiKey) + }; + } + private toBindingCandidate(binding: AiProviderBinding): AiResolvedRouteCandidate { return { channel: binding.channel, @@ -755,6 +880,15 @@ export class AiService { return `${secret.slice(0, 4)}***${secret.slice(-2)}`; } + private limitPreviewText(content: string): string { + const normalizedContent = content.replace(/\s+/g, " ").trim(); + if (normalizedContent.length <= 60) { + return normalizedContent; + } + + return `${normalizedContent.slice(0, 60)}...`; + } + private getPriorityWeight(priority: TaskPriority): number { switch (priority) { case TaskPriority.URGENT: diff --git a/apps/api/test/ai.spec.ts b/apps/api/test/ai.spec.ts index 486d016..7dc70e1 100644 --- a/apps/api/test/ai.spec.ts +++ b/apps/api/test/ai.spec.ts @@ -682,6 +682,103 @@ describe("AiController (integration)", () => { }); }); + it("should test binding with stored secret when api key is omitted", async () => { + prismaService.seedBinding({ + id: "binding_user_key_test_existing_secret", + userId: "user_1", + channel: AiChannel.USER_KEY, + providerName: "airouter", + model: "gpt-4.1", + configId: null, + configName: null, + encryptedApiKey: "sk-existing", + endpoint: "https://api.example.com", + isDefault: false, + isEnabled: true + }); + + const executeSpy = jest.spyOn(openAiExecutor, "execute").mockResolvedValue({ + channel: AiChannel.USER_KEY, + providerName: "airouter", + model: "gpt-4.1", + content: "连接成功", + sessionId: "session_binding_test", + usage: { + promptTokens: 1, + completionTokens: 1, + totalTokens: 2 + }, + raw: null + }); + + const response = await request(app.getHttpServer()) + .post("/ai/bindings/test") + .set("x-user-id", "user_1") + .send({ + channel: AiChannel.USER_KEY, + providerName: "airouter", + model: "gpt-4.1", + endpoint: "https://api.example.com" + }) + .expect(201); + + expect(response.body).toEqual({ + success: true, + channel: AiChannel.USER_KEY, + providerName: "airouter", + model: "gpt-4.1", + contentPreview: "连接成功" + }); + expect(executeSpy).toHaveBeenCalledWith( + expect.objectContaining({ + channel: AiChannel.USER_KEY, + providerName: "airouter", + model: "gpt-4.1", + endpoint: "https://api.example.com", + apiKey: "sk-existing" + }), + expect.objectContaining({ + userId: "user_1" + }) + ); + }); + + it("should return structured failure result when binding test fails", async () => { + prismaService.seedBinding({ + id: "binding_user_key_test_failure", + userId: "user_1", + channel: AiChannel.USER_KEY, + providerName: "airouter", + model: "gpt-5.4", + configId: null, + configName: null, + encryptedApiKey: "sk-existing", + endpoint: "https://api.example.com", + isDefault: false, + isEnabled: true + }); + + const response = await request(app.getHttpServer()) + .post("/ai/bindings/test") + .set("x-user-id", "user_1") + .send({ + channel: AiChannel.USER_KEY, + providerName: "airouter", + model: "gpt-5.4", + endpoint: "https://api.example.com" + }) + .expect(201); + + expect(response.body).toEqual({ + success: false, + channel: AiChannel.USER_KEY, + providerName: "airouter", + model: "gpt-5.4", + code: "UPSTREAM_UNREACHABLE", + message: "用户自备 Key 渠道暂时不可用" + }); + }); + it("should use selected channel without automatic fallback", async () => { prismaService.seedBinding({ id: "binding_user_key_selected", diff --git a/apps/web/src/pages/ai-chat-page.tsx b/apps/web/src/pages/ai-chat-page.tsx index f48e9a8..d514e63 100644 --- a/apps/web/src/pages/ai-chat-page.tsx +++ b/apps/web/src/pages/ai-chat-page.tsx @@ -1,4 +1,5 @@ import { useCallback, useEffect, useMemo, useRef, useState } from "react"; +import type { KeyboardEvent } from "react"; import { Bot, CircleAlert, @@ -287,6 +288,15 @@ export function AiChatPage({ session }: AiChatPageProps) { } } + function handleDraftKeyDown(event: KeyboardEvent): void { + if (event.key !== "Enter" || event.shiftKey || event.nativeEvent.isComposing) { + return; + } + + event.preventDefault(); + void handleSendMessage(); + } + return (
@@ -484,6 +494,7 @@ export function AiChatPage({ session }: AiChatPageProps) {