feat(ai-config): validate bindings before save

This commit is contained in:
2026-04-08 00:09:36 +08:00
parent 929b838e0f
commit e5948cd346
6 changed files with 381 additions and 15 deletions
+10 -1
View File
@@ -15,7 +15,8 @@ import {
AiChatResponse, AiChatResponse,
AiService, AiService,
ListAiBindingsResponse, ListAiBindingsResponse,
ListAiUsageLogsResponse ListAiUsageLogsResponse,
TestAiBindingResponse
} from "./ai.service"; } from "./ai.service";
@Controller("ai") @Controller("ai")
@@ -45,6 +46,14 @@ export class AiController {
return this.aiService.upsertBinding(this.resolveUserId(userIdHeader), body); return this.aiService.upsertBinding(this.resolveUserId(userIdHeader), body);
} }
@Post("bindings/test")
async testBinding(
@Headers("x-user-id") userIdHeader: string | string[] | undefined,
@Body() body: UpsertAiProviderBindingDto
): Promise<TestAiBindingResponse> {
return this.aiService.testBinding(this.resolveUserId(userIdHeader), body);
}
@Post("chat") @Post("chat")
async chat( async chat(
@Headers("x-user-id") userIdHeader: string | string[] | undefined, @Headers("x-user-id") userIdHeader: string | string[] | undefined,
+134
View File
@@ -104,6 +104,23 @@ export type AiChatResponse = {
attempts: AiRouteAttempt[]; attempts: AiRouteAttempt[];
}; };
export type TestAiBindingResponse =
| {
success: true;
channel: AiChannel;
providerName: string;
model: string | null;
contentPreview: string;
}
| {
success: false;
channel: AiChannel;
providerName: string;
model: string | null;
code: string;
message: string;
};
@Injectable() @Injectable()
export class AiService { export class AiService {
private readonly logger = new Logger(AiService.name); private readonly logger = new Logger(AiService.name);
@@ -251,6 +268,65 @@ export class AiService {
return this.serializeBinding(result); return this.serializeBinding(result);
} }
async testBinding(
userId: string,
dto: UpsertAiProviderBindingDto
): Promise<TestAiBindingResponse> {
if (dto.channel === AiChannel.PUBLIC_POOL) {
throw new BadRequestException("公共 AI 通道不能由用户自行测试");
}
const candidate = await this.buildTestCandidate(userId, dto);
const executor = this.aiProviderRegistryService.getExecutor(candidate.channel);
try {
const result = await executor.execute(candidate, {
userId,
message: "请只回复“连接成功”,不要添加其他内容。",
sessionId: null
});
return {
success: true,
channel: result.channel,
providerName: result.providerName,
model: result.model,
contentPreview: this.limitPreviewText(result.content)
};
} catch (error) {
if (error instanceof AiRouteFailureError) {
return {
success: false,
channel: error.channel,
providerName: error.providerName,
model: candidate.model,
code: error.code,
message: error.message
};
}
if (error instanceof Error) {
return {
success: false,
channel: candidate.channel,
providerName: candidate.providerName,
model: candidate.model,
code: "UNKNOWN_ERROR",
message: error.message
};
}
return {
success: false,
channel: candidate.channel,
providerName: candidate.providerName,
model: candidate.model,
code: "UNKNOWN_ERROR",
message: "未知错误"
};
}
}
async chat( async chat(
userId: string, userId: string,
dto: AiChatDto, dto: AiChatDto,
@@ -433,6 +509,55 @@ export class AiService {
}); });
} }
private async buildTestCandidate(
userId: string,
dto: UpsertAiProviderBindingDto
): Promise<AiResolvedRouteCandidate> {
const existingBinding = await this.prismaService.aiProviderBinding.findFirst({
where: {
userId,
channel: dto.channel
},
orderBy: {
updatedAt: "desc"
}
});
const mergedDto: UpsertAiProviderBindingDto = {
channel: dto.channel,
providerName:
dto.providerName ?? this.readDecryptedString(existingBinding?.providerName ?? null) ?? "",
model: dto.model ?? this.readDecryptedString(existingBinding?.model ?? null) ?? undefined,
configId:
dto.configId ?? this.readDecryptedString(existingBinding?.configId ?? null) ?? undefined,
configName:
dto.configName ??
this.readDecryptedString(existingBinding?.configName ?? null) ??
undefined,
endpoint:
dto.endpoint ?? this.readDecryptedString(existingBinding?.endpoint ?? null) ?? undefined,
apiKey:
dto.apiKey ??
this.readDecryptedString(existingBinding?.encryptedApiKey ?? null) ??
undefined,
isEnabled: dto.isEnabled ?? existingBinding?.isEnabled ?? true
};
this.validateBindingInput(mergedDto);
return {
channel: mergedDto.channel,
source: existingBinding ? "binding" : "binding",
sourceId: existingBinding?.id ?? null,
providerName: this.normalizeProviderName(mergedDto.providerName),
model: this.normalizeOptionalString(mergedDto.model),
configId: this.normalizeOptionalString(mergedDto.configId),
configName: this.normalizeOptionalString(mergedDto.configName),
endpoint: this.normalizeOptionalString(mergedDto.endpoint),
apiKey: this.normalizeOptionalString(mergedDto.apiKey)
};
}
private toBindingCandidate(binding: AiProviderBinding): AiResolvedRouteCandidate { private toBindingCandidate(binding: AiProviderBinding): AiResolvedRouteCandidate {
return { return {
channel: binding.channel, channel: binding.channel,
@@ -755,6 +880,15 @@ export class AiService {
return `${secret.slice(0, 4)}***${secret.slice(-2)}`; return `${secret.slice(0, 4)}***${secret.slice(-2)}`;
} }
private limitPreviewText(content: string): string {
const normalizedContent = content.replace(/\s+/g, " ").trim();
if (normalizedContent.length <= 60) {
return normalizedContent;
}
return `${normalizedContent.slice(0, 60)}...`;
}
private getPriorityWeight(priority: TaskPriority): number { private getPriorityWeight(priority: TaskPriority): number {
switch (priority) { switch (priority) {
case TaskPriority.URGENT: case TaskPriority.URGENT:
+97
View File
@@ -682,6 +682,103 @@ describe("AiController (integration)", () => {
}); });
}); });
it("should test binding with stored secret when api key is omitted", async () => {
prismaService.seedBinding({
id: "binding_user_key_test_existing_secret",
userId: "user_1",
channel: AiChannel.USER_KEY,
providerName: "airouter",
model: "gpt-4.1",
configId: null,
configName: null,
encryptedApiKey: "sk-existing",
endpoint: "https://api.example.com",
isDefault: false,
isEnabled: true
});
const executeSpy = jest.spyOn(openAiExecutor, "execute").mockResolvedValue({
channel: AiChannel.USER_KEY,
providerName: "airouter",
model: "gpt-4.1",
content: "连接成功",
sessionId: "session_binding_test",
usage: {
promptTokens: 1,
completionTokens: 1,
totalTokens: 2
},
raw: null
});
const response = await request(app.getHttpServer())
.post("/ai/bindings/test")
.set("x-user-id", "user_1")
.send({
channel: AiChannel.USER_KEY,
providerName: "airouter",
model: "gpt-4.1",
endpoint: "https://api.example.com"
})
.expect(201);
expect(response.body).toEqual({
success: true,
channel: AiChannel.USER_KEY,
providerName: "airouter",
model: "gpt-4.1",
contentPreview: "连接成功"
});
expect(executeSpy).toHaveBeenCalledWith(
expect.objectContaining({
channel: AiChannel.USER_KEY,
providerName: "airouter",
model: "gpt-4.1",
endpoint: "https://api.example.com",
apiKey: "sk-existing"
}),
expect.objectContaining({
userId: "user_1"
})
);
});
it("should return structured failure result when binding test fails", async () => {
prismaService.seedBinding({
id: "binding_user_key_test_failure",
userId: "user_1",
channel: AiChannel.USER_KEY,
providerName: "airouter",
model: "gpt-5.4",
configId: null,
configName: null,
encryptedApiKey: "sk-existing",
endpoint: "https://api.example.com",
isDefault: false,
isEnabled: true
});
const response = await request(app.getHttpServer())
.post("/ai/bindings/test")
.set("x-user-id", "user_1")
.send({
channel: AiChannel.USER_KEY,
providerName: "airouter",
model: "gpt-5.4",
endpoint: "https://api.example.com"
})
.expect(201);
expect(response.body).toEqual({
success: false,
channel: AiChannel.USER_KEY,
providerName: "airouter",
model: "gpt-5.4",
code: "UPSTREAM_UNREACHABLE",
message: "用户自备 Key 渠道暂时不可用"
});
});
it("should use selected channel without automatic fallback", async () => { it("should use selected channel without automatic fallback", async () => {
prismaService.seedBinding({ prismaService.seedBinding({
id: "binding_user_key_selected", id: "binding_user_key_selected",
+11
View File
@@ -1,4 +1,5 @@
import { useCallback, useEffect, useMemo, useRef, useState } from "react"; import { useCallback, useEffect, useMemo, useRef, useState } from "react";
import type { KeyboardEvent } from "react";
import { import {
Bot, Bot,
CircleAlert, CircleAlert,
@@ -287,6 +288,15 @@ export function AiChatPage({ session }: AiChatPageProps) {
} }
} }
function handleDraftKeyDown(event: KeyboardEvent<HTMLTextAreaElement>): void {
if (event.key !== "Enter" || event.shiftKey || event.nativeEvent.isComposing) {
return;
}
event.preventDefault();
void handleSendMessage();
}
return ( return (
<section className="space-y-4"> <section className="space-y-4">
<div className="rounded-[2rem] border border-border/70 bg-card/92 p-6 shadow-[0_24px_80px_-48px_rgba(15,23,42,0.55)]"> <div className="rounded-[2rem] border border-border/70 bg-card/92 p-6 shadow-[0_24px_80px_-48px_rgba(15,23,42,0.55)]">
@@ -484,6 +494,7 @@ export function AiChatPage({ session }: AiChatPageProps) {
<textarea <textarea
value={draftMessage} value={draftMessage}
onChange={(event) => setDraftMessage(event.target.value)} onChange={(event) => setDraftMessage(event.target.value)}
onKeyDown={handleDraftKeyDown}
placeholder="输入你的问题,例如:结合我当前待办,帮我排一下今天的优先级。" placeholder="输入你的问题,例如:结合我当前待办,帮我排一下今天的优先级。"
className="min-h-[140px] w-full rounded-2xl border border-border bg-background px-4 py-3 text-sm leading-7 outline-none transition-colors placeholder:text-muted-foreground focus:border-primary/40" className="min-h-[140px] w-full rounded-2xl border border-border bg-background px-4 py-3 text-sm leading-7 outline-none transition-colors placeholder:text-muted-foreground focus:border-primary/40"
/> />
+95 -14
View File
@@ -4,6 +4,7 @@ import { Button } from "@/components/ui/button";
import { cn } from "@/lib/utils"; import { cn } from "@/lib/utils";
import { import {
listAiBindings, listAiBindings,
testAiBinding,
upsertAiBinding, upsertAiBinding,
type WebAiBindingSummary, type WebAiBindingSummary,
type WebAiBindingsResponse, type WebAiBindingsResponse,
@@ -27,6 +28,12 @@ type NoticeState = {
message: string; message: string;
}; };
type ChannelNoticeState = NoticeState & {
detail?: string;
};
const TODOLIST_VERSION = "0.1.0";
function AiConfigCard({ function AiConfigCard({
channel, channel,
title, title,
@@ -36,7 +43,8 @@ function AiConfigCard({
onChange, onChange,
onSave, onSave,
saving, saving,
binding binding,
notice
}: { }: {
channel: Exclude<WebAiChannel, "PUBLIC_POOL">; channel: Exclude<WebAiChannel, "PUBLIC_POOL">;
title: string; title: string;
@@ -47,6 +55,7 @@ function AiConfigCard({
onSave: () => Promise<void>; onSave: () => Promise<void>;
saving: boolean; saving: boolean;
binding: WebAiBindingSummary | null; binding: WebAiBindingSummary | null;
notice: ChannelNoticeState | null;
}) { }) {
return ( return (
<section className="rounded-[2rem] border border-border/70 bg-card/92 p-5 shadow-[0_24px_80px_-48px_rgba(15,23,42,0.55)]"> <section className="rounded-[2rem] border border-border/70 bg-card/92 p-5 shadow-[0_24px_80px_-48px_rgba(15,23,42,0.55)]">
@@ -72,6 +81,27 @@ function AiConfigCard({
</span> </span>
</div> </div>
{notice ? (
<div
className={cn(
"mt-4 rounded-2xl border px-3 py-3 text-sm",
notice.tone === "success"
? "border-emerald-500/20 bg-emerald-500/10 text-emerald-700 dark:text-emerald-300"
: "border-destructive/20 bg-destructive/10 text-destructive"
)}
>
<div className="flex items-start gap-2">
<CheckCircle2 className="mt-0.5 size-4 shrink-0" />
<div className="min-w-0">
<div>{notice.message}</div>
{notice.detail ? (
<div className="mt-1 text-xs leading-6 opacity-80">{notice.detail}</div>
) : null}
</div>
</div>
</div>
) : null}
<div className="mt-5 grid gap-3 sm:grid-cols-2"> <div className="mt-5 grid gap-3 sm:grid-cols-2">
<label className="space-y-1.5"> <label className="space-y-1.5">
<span className="text-xs font-medium text-muted-foreground"></span> <span className="text-xs font-medium text-muted-foreground"></span>
@@ -200,16 +230,20 @@ function AiConfigCard({
{channel === "USER_KEY" {channel === "USER_KEY"
? "该配置按用户单独保存,适合接入你自己的服务商密钥。" ? "该配置按用户单独保存,适合接入你自己的服务商密钥。"
: "该配置按用户单独保存,适合直接复用 AstrBot 中已有的模型能力。"} : "该配置按用户单独保存,适合直接复用 AstrBot 中已有的模型能力。"}
<br />
</p> </p>
<Button type="button" onClick={() => void onSave()} disabled={saving}> <Button type="button" onClick={() => void onSave()} disabled={saving}>
{saving ? ( {saving ? (
<> <>
<LoaderCircle className="size-4 animate-spin" /> <LoaderCircle className="size-4 animate-spin" />
</> </>
) : formState.isEnabled ? (
"测试并保存"
) : ( ) : (
"保存配置" "保存草稿"
)} )}
</Button> </Button>
</div> </div>
@@ -224,6 +258,9 @@ export function SettingsPage({ session }: SettingsPageProps) {
const [refreshing, setRefreshing] = useState(false); const [refreshing, setRefreshing] = useState(false);
const [notice, setNotice] = useState<NoticeState | null>(null); const [notice, setNotice] = useState<NoticeState | null>(null);
const [savingChannel, setSavingChannel] = useState<WebAiChannel | null>(null); const [savingChannel, setSavingChannel] = useState<WebAiChannel | null>(null);
const [channelNotices, setChannelNotices] = useState<
Partial<Record<Exclude<WebAiChannel, "PUBLIC_POOL">, ChannelNoticeState>>
>({});
const [userKeyForm, setUserKeyForm] = useState<AiBindingFormState>(() => const [userKeyForm, setUserKeyForm] = useState<AiBindingFormState>(() =>
createAiBindingFormState() createAiBindingFormState()
); );
@@ -283,14 +320,49 @@ export function SettingsPage({ session }: SettingsPageProps) {
async function handleSaveChannel(channel: Exclude<WebAiChannel, "PUBLIC_POOL">): Promise<void> { async function handleSaveChannel(channel: Exclude<WebAiChannel, "PUBLIC_POOL">): Promise<void> {
const formState = channel === "USER_KEY" ? userKeyForm : astrbotForm; const formState = channel === "USER_KEY" ? userKeyForm : astrbotForm;
const binding = bindingMap.get(channel) ?? null; const binding = bindingMap.get(channel) ?? null;
const payload = buildAiBindingPayload(channel, formState, binding);
try { try {
setSavingChannel(channel); setSavingChannel(channel);
await upsertAiBinding(session, buildAiBindingPayload(channel, formState, binding)); setChannelNotices((current) => ({
setNotice({ ...current,
tone: "success", [channel]: undefined
message: channel === "USER_KEY" ? "自备厂商配置已保存。" : "AstrBot 配置已保存。" }));
}); if (payload.isEnabled) {
const testResult = await testAiBinding(session, payload);
if (!testResult.success) {
setChannelNotices((current) => ({
...current,
[channel]: {
tone: "error",
message: `连通性测试未通过:${testResult.message}`,
detail: binding
? "测试的是你当前编辑中的草稿配置。由于未保存,系统仍会继续使用上一份已保存配置,所以聊天可能依然正常。"
: "当前还没有已保存配置。请先修正表单中的地址、模型或密钥后再测试。"
}
}));
return;
}
}
await upsertAiBinding(session, payload);
setChannelNotices((current) => ({
...current,
[channel]: {
tone: "success",
message:
channel === "USER_KEY"
? payload.isEnabled
? "自备厂商连通性测试通过,配置已保存。"
: "自备厂商配置草稿已保存。"
: payload.isEnabled
? "AstrBot 连通性测试通过,配置已保存。"
: "AstrBot 配置草稿已保存。",
detail: payload.isEnabled
? "之后 AI 助手会使用这份刚保存的配置。"
: "当前只是保存草稿,未启用时不会参与实际聊天。"
}
}));
if (channel === "USER_KEY") { if (channel === "USER_KEY") {
setUserKeyForm((current) => ({ setUserKeyForm((current) => ({
...current, ...current,
@@ -304,10 +376,13 @@ export function SettingsPage({ session }: SettingsPageProps) {
} }
await loadBindings(); await loadBindings();
} catch (error) { } catch (error) {
setNotice({ setChannelNotices((current) => ({
tone: "error", ...current,
message: error instanceof Error ? error.message : "AI 配置保存失败" [channel]: {
}); tone: "error",
message: error instanceof Error ? error.message : "AI 配置保存失败"
}
}));
} finally { } finally {
setSavingChannel(null); setSavingChannel(null);
} }
@@ -323,11 +398,15 @@ export function SettingsPage({ session }: SettingsPageProps) {
</div> </div>
<h1 className="mt-2 text-2xl font-semibold tracking-tight text-foreground"> <h1 className="mt-2 text-2xl font-semibold tracking-tight text-foreground">
AI
</h1> </h1>
<p className="mt-2 text-sm leading-7 text-muted-foreground"> <p className="mt-2 text-sm leading-7 text-muted-foreground">
AI AI 使 AstrBot AI 使
</p> </p>
<div className="mt-3 inline-flex items-center rounded-full border border-border/70 bg-background/80 px-3 py-1 text-xs font-medium text-muted-foreground">
TodoList v{TODOLIST_VERSION}
</div>
</div> </div>
<Button <Button
@@ -396,6 +475,7 @@ export function SettingsPage({ session }: SettingsPageProps) {
onSave={() => handleSaveChannel("USER_KEY")} onSave={() => handleSaveChannel("USER_KEY")}
saving={savingChannel === "USER_KEY"} saving={savingChannel === "USER_KEY"}
binding={bindingMap.get("USER_KEY") ?? null} binding={bindingMap.get("USER_KEY") ?? null}
notice={channelNotices.USER_KEY ?? null}
/> />
<AiConfigCard <AiConfigCard
@@ -408,6 +488,7 @@ export function SettingsPage({ session }: SettingsPageProps) {
onSave={() => handleSaveChannel("ASTRBOT")} onSave={() => handleSaveChannel("ASTRBOT")}
saving={savingChannel === "ASTRBOT"} saving={savingChannel === "ASTRBOT"}
binding={bindingMap.get("ASTRBOT") ?? null} binding={bindingMap.get("ASTRBOT") ?? null}
notice={channelNotices.ASTRBOT ?? null}
/> />
<section className="rounded-[2rem] border border-border/70 bg-card/92 p-5 shadow-[0_24px_80px_-48px_rgba(15,23,42,0.55)]"> <section className="rounded-[2rem] border border-border/70 bg-card/92 p-5 shadow-[0_24px_80px_-48px_rgba(15,23,42,0.55)]">
+34
View File
@@ -47,6 +47,23 @@ export type UpsertWebAiBindingInput = {
isEnabled?: boolean; isEnabled?: boolean;
}; };
export type TestWebAiBindingResponse =
| {
success: true;
channel: Exclude<WebAiChannel, "PUBLIC_POOL">;
providerName: string;
model: string | null;
contentPreview: string;
}
| {
success: false;
channel: Exclude<WebAiChannel, "PUBLIC_POOL">;
providerName: string;
model: string | null;
code: string;
message: string;
};
export type WebAiChatResponse = { export type WebAiChatResponse = {
channel: WebAiChannel; channel: WebAiChannel;
providerName: string; providerName: string;
@@ -141,6 +158,23 @@ export async function upsertAiBinding(
return (await response.json()) as WebAiBindingSummary; return (await response.json()) as WebAiBindingSummary;
} }
export async function testAiBinding(
session: WebSession,
payload: UpsertWebAiBindingInput
): Promise<TestWebAiBindingResponse> {
const response = await fetch(`${resolveApiBaseUrl()}/ai/bindings/test`, {
method: "POST",
headers: createHeaders(session),
body: JSON.stringify(payload)
});
if (!response.ok) {
throw await createApiError(response);
}
return (await response.json()) as TestWebAiBindingResponse;
}
export async function chatWithAi( export async function chatWithAi(
session: WebSession, session: WebSession,
payload: { payload: {