diff --git a/apps/api/src/ai/ai.service.ts b/apps/api/src/ai/ai.service.ts index 2269d2d..efee87e 100644 --- a/apps/api/src/ai/ai.service.ts +++ b/apps/api/src/ai/ai.service.ts @@ -9,7 +9,9 @@ import { AiChannel, AiProviderBinding, AiPublicPoolConfig, - Prisma + Prisma, + TaskPriority, + TaskStatus } from "../../generated/prisma/client"; import { PrismaService } from "../prisma/prisma.service"; import { AiProviderRegistryService } from "./ai-provider-registry.service"; @@ -71,6 +73,8 @@ export type AiChatResponse = { @Injectable() export class AiService { private readonly logger = new Logger(AiService.name); + private readonly maxContextTasks = 6; + private readonly maxContextContentLength = 80; constructor( private readonly prismaService: PrismaService, @@ -195,6 +199,7 @@ export class AiService { async chat(userId: string, dto: AiChatDto): Promise { const attempts: AiRouteAttempt[] = []; const plan = await this.buildRoutePlan(userId, dto.bindingId ?? null); + const promptMessage = await this.buildPromptMessage(userId, dto.message); for (const entry of plan) { if (entry.kind === "skip") { @@ -208,7 +213,7 @@ export class AiService { try { const result = await executor.execute(entry.candidate, { userId, - message: dto.message, + message: promptMessage, sessionId: dto.sessionId ?? null }); const latencyMs = Date.now() - startedAt; @@ -416,6 +421,85 @@ export class AiService { }; } + private async buildPromptMessage(userId: string, userMessage: string): Promise { + const taskSummary = await this.buildTaskContextSummary(userId); + if (!taskSummary) { + return userMessage; + } + + return [ + "你是 TodoList 的 AI 助手,请优先结合用户当前未完成任务给出安排建议。", + "以下是系统整理的未完成任务摘要:", + taskSummary, + "如果用户的问题与任务无关,也可以正常回答;如果相关,请优先考虑优先级、截止时间与执行顺序。", + `用户当前问题:${userMessage}` + ].join("\n\n"); + } + + private async buildTaskContextSummary(userId: string): Promise { + const tasks = await this.prismaService.task.findMany({ + where: { + userId, + status: { + in: [TaskStatus.TODO, TaskStatus.IN_PROGRESS] + } + }, + select: { + title: true, + priority: true, + status: true, + ddl: true, + contentText: true, + updatedAt: true + }, + take: 20 + }); + + if (tasks.length === 0) { + return null; + } + + const sortedTasks = [...tasks].sort((left, right) => { + const priorityDiff = + this.getPriorityWeight(right.priority) - this.getPriorityWeight(left.priority); + if (priorityDiff !== 0) { + return priorityDiff; + } + + const leftDdl = left.ddl?.getTime() ?? Number.POSITIVE_INFINITY; + const rightDdl = right.ddl?.getTime() ?? Number.POSITIVE_INFINITY; + if (leftDdl !== rightDdl) { + return leftDdl - rightDdl; + } + + return right.updatedAt.getTime() - left.updatedAt.getTime(); + }); + + const visibleTasks = sortedTasks.slice(0, this.maxContextTasks); + const lines = visibleTasks.map((task, index) => { + const parts = [ + `${index + 1}. ${task.title}`, + `优先级:${this.getPriorityLabel(task.priority)}`, + `状态:${this.getStatusLabel(task.status)}`, + `DDL:${task.ddl ? task.ddl.toISOString() : "未设置"}` + ]; + + const contentSnippet = this.getContentSnippet(task.contentText); + if (contentSnippet) { + parts.push(`内容摘要:${contentSnippet}`); + } + + return parts.join(" | "); + }); + + const omittedCount = sortedTasks.length - visibleTasks.length; + if (omittedCount > 0) { + lines.push(`其余 ${omittedCount} 项未完成任务已省略。`); + } + + return [`共 ${sortedTasks.length} 项未完成任务。`, ...lines].join("\n"); + } + private toFailureAttempt(candidate: AiResolvedRouteCandidate, error: unknown): AiRouteAttempt { if (error instanceof AiRouteFailureError) { return { @@ -493,6 +577,68 @@ export class AiService { return `${secret.slice(0, 4)}***${secret.slice(-2)}`; } + private getPriorityWeight(priority: TaskPriority): number { + switch (priority) { + case TaskPriority.URGENT: + return 4; + case TaskPriority.HIGH: + return 3; + case TaskPriority.MEDIUM: + return 2; + case TaskPriority.LOW: + return 1; + default: + return 0; + } + } + + private getPriorityLabel(priority: TaskPriority): string { + switch (priority) { + case TaskPriority.URGENT: + return "紧急"; + case TaskPriority.HIGH: + return "高"; + case TaskPriority.MEDIUM: + return "中"; + case TaskPriority.LOW: + return "低"; + default: + return String(priority); + } + } + + private getStatusLabel(status: TaskStatus): string { + switch (status) { + case TaskStatus.TODO: + return "待开始"; + case TaskStatus.IN_PROGRESS: + return "进行中"; + case TaskStatus.DONE: + return "已完成"; + case TaskStatus.ARCHIVED: + return "已归档"; + default: + return String(status); + } + } + + private getContentSnippet(contentText: string | null): string | null { + if (!contentText) { + return null; + } + + const normalizedContent = contentText.replace(/\s+/g, " ").trim(); + if (normalizedContent.length === 0) { + return null; + } + + if (normalizedContent.length <= this.maxContextContentLength) { + return normalizedContent; + } + + return `${normalizedContent.slice(0, this.maxContextContentLength)}...`; + } + private async recordUsageLog(input: { userId: string; channel: AiChannel; diff --git a/apps/api/test/ai.spec.ts b/apps/api/test/ai.spec.ts index 85705ef..38f1b7f 100644 --- a/apps/api/test/ai.spec.ts +++ b/apps/api/test/ai.spec.ts @@ -1,11 +1,18 @@ import request from "supertest"; import { INestApplication, ValidationPipe } from "@nestjs/common"; import { Test, TestingModule } from "@nestjs/testing"; -import { AiChannel, AiProviderBinding, AiPublicPoolConfig } from "../generated/prisma/client"; +import { + AiChannel, + AiProviderBinding, + AiPublicPoolConfig, + TaskPriority, + TaskStatus +} from "../generated/prisma/client"; import { AiController } from "../src/ai/ai.controller"; import { AiProviderRegistryService } from "../src/ai/ai-provider-registry.service"; import { AiService } from "../src/ai/ai.service"; import { + AiChatInput, AiChannelExecutor, AiResolvedRouteCandidate, AiRouteFailureError @@ -25,12 +32,23 @@ type AiUsageLogRecord = { errorCode: string | null; }; +type AiTaskRecord = { + userId: string; + title: string; + priority: TaskPriority; + status: TaskStatus; + ddl: Date | null; + contentText: string | null; + updatedAt: Date; +}; + class InMemoryAiPrismaService { private bindingIdSequence = 1; private publicPoolIdSequence = 1; private bindings: AiProviderBinding[] = []; private publicPools: AiPublicPoolConfig[] = []; private usageLogs: AiUsageLogRecord[] = []; + private tasks: AiTaskRecord[] = []; readonly aiProviderBinding = { findMany: async (args: { @@ -185,6 +203,31 @@ class InMemoryAiPrismaService { } }; + readonly task = { + findMany: async (args: { + where: { + userId: string; + status: { + in: TaskStatus[]; + }; + }; + take?: number; + }) => { + const filteredTasks = this.tasks.filter( + (task) => task.userId === args.where.userId && args.where.status.in.includes(task.status) + ); + + return filteredTasks.slice(0, args.take ?? filteredTasks.length).map((task) => ({ + title: task.title, + priority: task.priority, + status: task.status, + ddl: task.ddl, + contentText: task.contentText, + updatedAt: task.updatedAt + })); + } + }; + async $transaction(callback: (tx: InMemoryAiPrismaService) => Promise): Promise { return callback(this); } @@ -211,9 +254,18 @@ class InMemoryAiPrismaService { getUsageLogs(): AiUsageLogRecord[] { return [...this.usageLogs]; } + + seedTask(task: AiTaskRecord): void { + this.tasks.push(task); + } } class StaticExecutor implements AiChannelExecutor { + readonly inputs: Array<{ + candidate: AiResolvedRouteCandidate; + message: string; + }> = []; + constructor( private readonly resolver: (channel: AiChannel) => { content?: string; @@ -222,7 +274,12 @@ class StaticExecutor implements AiChannelExecutor { } ) {} - async execute(candidate: AiResolvedRouteCandidate) { + async execute(candidate: AiResolvedRouteCandidate, input: AiChatInput) { + this.inputs.push({ + candidate, + message: input.message + }); + const result = this.resolver(candidate.channel); if (result.code) { throw new AiRouteFailureError( @@ -252,6 +309,7 @@ class StaticExecutor implements AiChannelExecutor { describe("AiController (integration)", () => { let app: INestApplication; let prismaService: InMemoryAiPrismaService; + let astrbotExecutor: StaticExecutor; beforeEach(async () => { prismaService = new InMemoryAiPrismaService(); @@ -266,7 +324,7 @@ describe("AiController (integration)", () => { content: "公共 AI 已接管" } ); - const astrbotExecutor = new StaticExecutor(() => ({ + astrbotExecutor = new StaticExecutor(() => ({ content: "AstrBot 已接管" })); @@ -448,6 +506,55 @@ describe("AiController (integration)", () => { }); }); + it("should inject unfinished task summary into ai prompt", async () => { + prismaService.seedBinding({ + id: "binding_astrbot_context", + userId: "user_1", + channel: AiChannel.ASTRBOT, + providerName: "", + model: null, + configId: "default", + configName: null, + encryptedApiKey: "abk_astrbot", + endpoint: "http://127.0.0.1:6185", + isDefault: true, + isEnabled: true + }); + prismaService.seedTask({ + userId: "user_1", + title: "今晚提交周报", + priority: TaskPriority.URGENT, + status: TaskStatus.IN_PROGRESS, + ddl: new Date("2026-04-06T12:00:00.000Z"), + contentText: "需要汇总 AI 路由、AstrBot 接入和同步模块进度", + updatedAt: new Date("2026-04-06T08:00:00.000Z") + }); + prismaService.seedTask({ + userId: "user_1", + title: "整理已完成事项", + priority: TaskPriority.LOW, + status: TaskStatus.DONE, + ddl: null, + contentText: "这条任务不应该出现在上下文里", + updatedAt: new Date("2026-04-06T07:00:00.000Z") + }); + + await request(app.getHttpServer()) + .post("/ai/chat") + .set("x-user-id", "user_1") + .send({ + message: "帮我安排今天剩余任务" + }) + .expect(201); + + expect(astrbotExecutor.inputs).toHaveLength(1); + expect(astrbotExecutor.inputs[0]?.message).toContain("以下是系统整理的未完成任务摘要"); + expect(astrbotExecutor.inputs[0]?.message).toContain("今晚提交周报"); + expect(astrbotExecutor.inputs[0]?.message).toContain("优先级:紧急"); + expect(astrbotExecutor.inputs[0]?.message).not.toContain("整理已完成事项"); + expect(astrbotExecutor.inputs[0]?.message).toContain("用户当前问题:帮我安排今天剩余任务"); + }); + it("should return skipped attempts when no channel is available", async () => { const response = await request(app.getHttpServer()) .post("/ai/chat")