From 45b149ad58979d9cacf8dc2f97b8f12342a24a8e Mon Sep 17 00:00:00 2001 From: Yaosanqi137 Date: Tue, 7 Apr 2026 00:56:50 +0800 Subject: [PATCH] fix(ai-context): include local tasks in prompt injection --- apps/api/src/ai/ai.service.ts | 142 +++++++++++++++++++++------- apps/api/src/ai/dto/ai-chat.dto.ts | 45 ++++++++- apps/api/test/ai.spec.ts | 114 +++++++++++++++++++++- apps/web/src/pages/ai-chat-page.tsx | 23 ++++- apps/web/src/services/ai-api.ts | 14 ++- 5 files changed, 299 insertions(+), 39 deletions(-) diff --git a/apps/api/src/ai/ai.service.ts b/apps/api/src/ai/ai.service.ts index 484837c..92cc275 100644 --- a/apps/api/src/ai/ai.service.ts +++ b/apps/api/src/ai/ai.service.ts @@ -70,6 +70,16 @@ type AiUsageLogSummary = { createdAt: string; }; +type AiContextTaskItem = { + id: string; + title: string; + priority: TaskPriority; + status: TaskStatus; + ddl: Date | null; + contentText: string | null; + updatedAt: Date; +}; + export type ListAiUsageLogsResponse = { items: AiUsageLogSummary[]; page: number; @@ -235,7 +245,7 @@ export class AiService { async chat(userId: string, dto: AiChatDto): Promise { const attempts: AiRouteAttempt[] = []; const plan = await this.buildRoutePlan(userId, dto.channel ?? null); - const promptMessage = await this.buildPromptMessage(userId, dto.message); + const promptMessage = await this.buildPromptMessage(userId, dto.message, dto.localTasks ?? []); for (const entry of plan) { if (entry.kind === "skip") { @@ -476,22 +486,29 @@ export class AiService { }; } - private async buildPromptMessage(userId: string, userMessage: string): Promise { - const taskSummary = await this.buildTaskContextSummary(userId); + private async buildPromptMessage( + userId: string, + userMessage: string, + localTasks: NonNullable + ): Promise { + const taskSummary = await this.buildTaskContextSummary(userId, localTasks); if (!taskSummary) { return userMessage; } return [ - "你是 TodoList 的 AI 助手,请优先结合用户当前未完成任务给出安排建议。", + "你是 TodoList 的 AI 助手,需要结合用户当前待办提供任务统筹建议。", "以下是系统整理的未完成任务摘要:", taskSummary, - "如果用户的问题与任务无关,也可以正常回答;如果相关,请优先考虑优先级、截止时间与执行顺序。", + "请优先根据这些任务的紧急度、截止时间和执行顺序回答,并给出明确可执行的建议。", `用户当前问题:${userMessage}` ].join("\n\n"); } - private async buildTaskContextSummary(userId: string): Promise { + private async buildTaskContextSummary( + userId: string, + localTasks: NonNullable + ): Promise { const tasks = await this.prismaService.task.findMany({ where: { userId, @@ -500,6 +517,7 @@ export class AiService { } }, select: { + id: true, title: true, priority: true, status: true, @@ -510,11 +528,93 @@ export class AiService { take: 20 }); - if (tasks.length === 0) { + const sortedTasks = this.sortContextTasks(this.mergeContextTasks(tasks, localTasks)); + if (sortedTasks.length === 0) { return null; } - const sortedTasks = [...tasks].sort((left, right) => { + const visibleTasks = sortedTasks.slice(0, this.maxContextTasks); + const lines = visibleTasks.map((task, index) => { + const parts = [ + `${index + 1}. ${task.title}`, + `优先级:${this.getPriorityLabel(task.priority)}`, + `状态:${this.getStatusLabel(task.status)}`, + `DDL:${task.ddl ? task.ddl.toISOString() : "未设置"}` + ]; + + const contentSnippet = this.getContentSnippet(task.contentText); + if (contentSnippet) { + parts.push(`内容摘要:${contentSnippet}`); + } + + return parts.join(" | "); + }); + + const omittedCount = sortedTasks.length - visibleTasks.length; + if (omittedCount > 0) { + lines.push(`另有 ${omittedCount} 条任务已省略。`); + } + + return [`共 ${sortedTasks.length} 条未完成任务。`, ...lines].join("\n"); + } + + private mergeContextTasks( + databaseTasks: Array<{ + id: string; + title: string; + priority: TaskPriority; + status: TaskStatus; + ddl: Date | null; + contentText: string | null; + updatedAt: Date; + }>, + localTasks: NonNullable + ): AiContextTaskItem[] { + const taskMap = new Map(); + + for (const task of databaseTasks) { + taskMap.set(task.id, { + id: task.id, + title: this.readDecryptedString(task.title) ?? "未命名任务", + priority: task.priority, + status: task.status, + ddl: task.ddl, + contentText: this.readDecryptedString(task.contentText), + updatedAt: task.updatedAt + }); + } + + for (const task of localTasks) { + if (task.status !== TaskStatus.TODO && task.status !== TaskStatus.IN_PROGRESS) { + continue; + } + + const currentTask = taskMap.get(task.id); + const nextTask: AiContextTaskItem = { + id: task.id, + title: task.title.trim().length > 0 ? task.title.trim() : "未命名任务", + priority: task.priority, + status: task.status, + ddl: typeof task.ddlAt === "number" ? new Date(task.ddlAt) : null, + contentText: + typeof task.contentText === "string" && task.contentText.trim().length > 0 + ? task.contentText + : null, + updatedAt: new Date(task.updatedAt) + }; + + if (!currentTask || nextTask.updatedAt.getTime() >= currentTask.updatedAt.getTime()) { + taskMap.set(task.id, nextTask); + } + } + + return [...taskMap.values()].filter( + (task) => task.status === TaskStatus.TODO || task.status === TaskStatus.IN_PROGRESS + ); + } + + private sortContextTasks(tasks: AiContextTaskItem[]): AiContextTaskItem[] { + return [...tasks].sort((left, right) => { const priorityDiff = this.getPriorityWeight(right.priority) - this.getPriorityWeight(left.priority); if (priorityDiff !== 0) { @@ -529,32 +629,6 @@ export class AiService { return right.updatedAt.getTime() - left.updatedAt.getTime(); }); - - const visibleTasks = sortedTasks.slice(0, this.maxContextTasks); - const lines = visibleTasks.map((task, index) => { - const taskTitle = this.readDecryptedString(task.title) ?? "未命名任务"; - const contentText = this.readDecryptedString(task.contentText); - const parts = [ - `${index + 1}. ${taskTitle}`, - `优先级:${this.getPriorityLabel(task.priority)}`, - `状态:${this.getStatusLabel(task.status)}`, - `DDL:${task.ddl ? task.ddl.toISOString() : "未设置"}` - ]; - - const contentSnippet = this.getContentSnippet(contentText); - if (contentSnippet) { - parts.push(`内容摘要:${contentSnippet}`); - } - - return parts.join(" | "); - }); - - const omittedCount = sortedTasks.length - visibleTasks.length; - if (omittedCount > 0) { - lines.push(`其余 ${omittedCount} 项未完成任务已省略。`); - } - - return [`共 ${sortedTasks.length} 项未完成任务。`, ...lines].join("\n"); } private toFailureAttempt(candidate: AiResolvedRouteCandidate, error: unknown): AiRouteAttempt { diff --git a/apps/api/src/ai/dto/ai-chat.dto.ts b/apps/api/src/ai/dto/ai-chat.dto.ts index 013b697..e2613ae 100644 --- a/apps/api/src/ai/dto/ai-chat.dto.ts +++ b/apps/api/src/ai/dto/ai-chat.dto.ts @@ -1,5 +1,42 @@ -import { IsEnum, IsOptional, IsString, MinLength } from "class-validator"; +import { Type } from "class-transformer"; +import { + IsArray, + IsEnum, + IsInt, + IsOptional, + IsString, + MinLength, + ValidateNested +} from "class-validator"; import { AiChannel } from "../../../generated/prisma/client"; +import { TaskPriority, TaskStatus } from "../../../generated/prisma/client"; + +export class LocalTaskContextItemDto { + @IsString() + @MinLength(1) + id!: string; + + @IsString() + @MinLength(1) + title!: string; + + @IsEnum(TaskPriority) + priority!: TaskPriority; + + @IsEnum(TaskStatus) + status!: TaskStatus; + + @IsOptional() + @IsInt() + ddlAt?: number | null; + + @IsOptional() + @IsString() + contentText?: string | null; + + @IsInt() + updatedAt!: number; +} export class AiChatDto { @IsString() @@ -14,4 +51,10 @@ export class AiChatDto { @IsOptional() @IsEnum(AiChannel) channel?: AiChannel; + + @IsOptional() + @IsArray() + @ValidateNested({ each: true }) + @Type(() => LocalTaskContextItemDto) + localTasks?: LocalTaskContextItemDto[]; } diff --git a/apps/api/test/ai.spec.ts b/apps/api/test/ai.spec.ts index f3a9478..be8dcdc 100644 --- a/apps/api/test/ai.spec.ts +++ b/apps/api/test/ai.spec.ts @@ -38,6 +38,7 @@ type AiUsageLogRecord = { }; type AiTaskRecord = { + id: string; userId: string; title: string; priority: TaskPriority; @@ -262,6 +263,7 @@ class InMemoryAiPrismaService { ); return filteredTasks.slice(0, args.take ?? filteredTasks.length).map((task) => ({ + id: task.id, title: task.title, priority: task.priority, status: task.status, @@ -385,11 +387,12 @@ describe("AiController (integration)", () => { let app: INestApplication; let prismaService: InMemoryAiPrismaService; let astrbotExecutor: StaticExecutor; + let openAiExecutor: StaticExecutor; beforeEach(async () => { prismaService = new InMemoryAiPrismaService(); - const openAiExecutor = new StaticExecutor((channel) => + openAiExecutor = new StaticExecutor((channel) => channel === AiChannel.USER_KEY ? { code: "UPSTREAM_UNREACHABLE", @@ -727,6 +730,7 @@ describe("AiController (integration)", () => { isEnabled: true }); prismaService.seedTask({ + id: "task_weekly_report", userId: "user_1", title: "今晚提交周报", priority: TaskPriority.URGENT, @@ -736,6 +740,7 @@ describe("AiController (integration)", () => { updatedAt: new Date("2026-04-06T08:00:00.000Z") }); prismaService.seedTask({ + id: "task_done_item", userId: "user_1", title: "整理已完成事项", priority: TaskPriority.LOW, @@ -761,6 +766,113 @@ describe("AiController (integration)", () => { expect(astrbotExecutor.inputs[0]?.message).toContain("用户当前问题:帮我安排今天剩余任务"); }); + it("should inject local unfinished tasks into ai prompt when database is empty", async () => { + prismaService.seedBinding({ + id: "binding_user_key_local_context", + userId: "user_1", + channel: AiChannel.USER_KEY, + providerName: "openai", + model: "gpt-4o-mini", + configId: null, + configName: null, + encryptedApiKey: "sk-user", + endpoint: "https://api.example.com", + isDefault: true, + isEnabled: true + }); + + const response = await request(app.getHttpServer()) + .post("/ai/chat") + .set("x-user-id", "user_1") + .send({ + message: "结合我的 TodoList 帮我排优先级", + channel: AiChannel.USER_KEY, + localTasks: [ + { + id: "local_task_1", + title: "准备明天答辩材料", + priority: TaskPriority.URGENT, + status: TaskStatus.IN_PROGRESS, + ddlAt: new Date("2026-04-07T13:00:00.000Z").getTime(), + contentText: "需要补齐演示文稿和总结页", + updatedAt: new Date("2026-04-07T09:00:00.000Z").getTime() + } + ] + }) + .expect(502); + + expect(response.body.attempts).toEqual([ + { + channel: AiChannel.USER_KEY, + providerName: "openai", + model: "gpt-4o-mini", + status: "failed", + reasonCode: "UPSTREAM_UNREACHABLE", + reasonMessage: "用户自备 Key 渠道暂时不可用" + } + ]); + expect(openAiExecutor.inputs).toHaveLength(1); + expect(openAiExecutor.inputs[0]?.message).toContain("准备明天答辩材料"); + expect(openAiExecutor.inputs[0]?.message).toContain("优先级:紧急"); + expect(openAiExecutor.inputs[0]?.message).toContain("内容摘要:需要补齐演示文稿和总结页"); + expect(openAiExecutor.inputs[0]?.message).toContain( + "用户当前问题:结合我的 TodoList 帮我排优先级" + ); + expect(astrbotExecutor.inputs).toHaveLength(0); + }); + + it("should prefer newer local task snapshot over older database task", async () => { + prismaService.seedBinding({ + id: "binding_astrbot_local_override", + userId: "user_1", + channel: AiChannel.ASTRBOT, + providerName: "", + model: null, + configId: "default", + configName: null, + encryptedApiKey: "abk_astrbot", + endpoint: "http://127.0.0.1:6185", + isDefault: true, + isEnabled: true + }); + prismaService.seedTask({ + id: "task_same_id", + userId: "user_1", + title: "旧标题", + priority: TaskPriority.LOW, + status: TaskStatus.TODO, + ddl: new Date("2026-04-08T10:00:00.000Z"), + contentText: "旧内容", + updatedAt: new Date("2026-04-07T08:00:00.000Z") + }); + + await request(app.getHttpServer()) + .post("/ai/chat") + .set("x-user-id", "user_1") + .send({ + message: "看看我最新要做什么", + channel: AiChannel.ASTRBOT, + localTasks: [ + { + id: "task_same_id", + title: "新标题", + priority: TaskPriority.HIGH, + status: TaskStatus.IN_PROGRESS, + ddlAt: new Date("2026-04-07T15:00:00.000Z").getTime(), + contentText: "新内容", + updatedAt: new Date("2026-04-07T12:00:00.000Z").getTime() + } + ] + }) + .expect(201); + + expect(astrbotExecutor.inputs.at(-1)?.message).toContain("新标题"); + expect(astrbotExecutor.inputs.at(-1)?.message).toContain("优先级:高"); + expect(astrbotExecutor.inputs.at(-1)?.message).toContain("内容摘要:新内容"); + expect(astrbotExecutor.inputs.at(-1)?.message).not.toContain("旧标题"); + expect(astrbotExecutor.inputs.at(-1)?.message).not.toContain("旧内容"); + }); + it("should return skipped attempts when no channel is available", async () => { const response = await request(app.getHttpServer()) .post("/ai/chat") diff --git a/apps/web/src/pages/ai-chat-page.tsx b/apps/web/src/pages/ai-chat-page.tsx index 746fca7..f48e9a8 100644 --- a/apps/web/src/pages/ai-chat-page.tsx +++ b/apps/web/src/pages/ai-chat-page.tsx @@ -17,6 +17,7 @@ import { type WebAiBindingSummary, type WebAiBindingsResponse, type WebAiChannel, + type WebAiLocalTaskContextItem, WebAiApiError } from "@/services/ai-api"; import { @@ -25,6 +26,7 @@ import { saveLocalAiChatSession, type LocalAiChatMessageRecord } from "@/services/local-ai-chat-repo"; +import { listLocalTasksByUser } from "@/services/local-task-repo"; import type { WebSession } from "@/services/session-storage"; import { CHANNEL_META, CHANNEL_ORDER } from "@/components/ai/ai-shared"; @@ -64,6 +66,23 @@ function appendMessage( }; } +function buildLocalTaskContext( + items: Awaited> +): WebAiLocalTaskContextItem[] { + return items + .filter((item) => item.status === "TODO" || item.status === "IN_PROGRESS") + .slice(0, 20) + .map((item) => ({ + id: item.id, + title: item.title, + priority: item.priority, + status: item.status, + ddlAt: item.ddlAt, + contentText: item.contentText, + updatedAt: item.updatedAt + })); +} + export function AiChatPage({ session }: AiChatPageProps) { const navigate = useNavigate(); const [bindingsResponse, setBindingsResponse] = useState(null); @@ -224,10 +243,12 @@ export function AiChatPage({ session }: AiChatPageProps) { ); try { + const localTasks = buildLocalTaskContext(await listLocalTasksByUser(session.user.id)); const response = await chatWithAi(session, { channel, message, - sessionId: sessionIds[channel] + sessionId: sessionIds[channel], + localTasks }); setSessionIds((current) => ({ diff --git a/apps/web/src/services/ai-api.ts b/apps/web/src/services/ai-api.ts index 5fd6a9f..b3836ad 100644 --- a/apps/web/src/services/ai-api.ts +++ b/apps/web/src/services/ai-api.ts @@ -56,6 +56,16 @@ export type WebAiChatResponse = { attempts: WebAiRouteAttempt[]; }; +export type WebAiLocalTaskContextItem = { + id: string; + title: string; + priority: "LOW" | "MEDIUM" | "HIGH" | "URGENT"; + status: "TODO" | "IN_PROGRESS" | "DONE" | "ARCHIVED"; + ddlAt: number | null; + contentText: string | null; + updatedAt: number; +}; + export class WebAiApiError extends Error { attempts: WebAiRouteAttempt[] | null; @@ -92,7 +102,7 @@ async function createApiError(response: Response): Promise { attempts?: WebAiRouteAttempt[]; }; const message = Array.isArray(body.message) - ? body.message.join(",") + ? body.message.join(";") : typeof body.message === "string" && body.message.trim().length > 0 ? body.message : `请求失败(${response.status})`; @@ -101,7 +111,6 @@ async function createApiError(response: Response): Promise { return new WebAiApiError(`请求失败(${response.status})`); } } - export async function listAiBindings(session: WebSession): Promise { const response = await fetch(`${resolveApiBaseUrl()}/ai/bindings`, { method: "GET", @@ -138,6 +147,7 @@ export async function chatWithAi( channel: WebAiChannel; message: string; sessionId?: string; + localTasks?: WebAiLocalTaskContextItem[]; } ): Promise { const response = await fetch(`${resolveApiBaseUrl()}/ai/chat`, {