feat(api-ai): inject unfinished task summary

This commit is contained in:
2026-04-06 12:57:14 +08:00
parent 45177e9fad
commit 4578116a30
2 changed files with 258 additions and 5 deletions
+148 -2
View File
@@ -9,7 +9,9 @@ import {
AiChannel, AiChannel,
AiProviderBinding, AiProviderBinding,
AiPublicPoolConfig, AiPublicPoolConfig,
Prisma Prisma,
TaskPriority,
TaskStatus
} from "../../generated/prisma/client"; } from "../../generated/prisma/client";
import { PrismaService } from "../prisma/prisma.service"; import { PrismaService } from "../prisma/prisma.service";
import { AiProviderRegistryService } from "./ai-provider-registry.service"; import { AiProviderRegistryService } from "./ai-provider-registry.service";
@@ -71,6 +73,8 @@ export type AiChatResponse = {
@Injectable() @Injectable()
export class AiService { export class AiService {
private readonly logger = new Logger(AiService.name); private readonly logger = new Logger(AiService.name);
private readonly maxContextTasks = 6;
private readonly maxContextContentLength = 80;
constructor( constructor(
private readonly prismaService: PrismaService, private readonly prismaService: PrismaService,
@@ -195,6 +199,7 @@ export class AiService {
async chat(userId: string, dto: AiChatDto): Promise<AiChatResponse> { async chat(userId: string, dto: AiChatDto): Promise<AiChatResponse> {
const attempts: AiRouteAttempt[] = []; const attempts: AiRouteAttempt[] = [];
const plan = await this.buildRoutePlan(userId, dto.bindingId ?? null); const plan = await this.buildRoutePlan(userId, dto.bindingId ?? null);
const promptMessage = await this.buildPromptMessage(userId, dto.message);
for (const entry of plan) { for (const entry of plan) {
if (entry.kind === "skip") { if (entry.kind === "skip") {
@@ -208,7 +213,7 @@ export class AiService {
try { try {
const result = await executor.execute(entry.candidate, { const result = await executor.execute(entry.candidate, {
userId, userId,
message: dto.message, message: promptMessage,
sessionId: dto.sessionId ?? null sessionId: dto.sessionId ?? null
}); });
const latencyMs = Date.now() - startedAt; const latencyMs = Date.now() - startedAt;
@@ -416,6 +421,85 @@ export class AiService {
}; };
} }
private async buildPromptMessage(userId: string, userMessage: string): Promise<string> {
const taskSummary = await this.buildTaskContextSummary(userId);
if (!taskSummary) {
return userMessage;
}
return [
"你是 TodoList 的 AI 助手,请优先结合用户当前未完成任务给出安排建议。",
"以下是系统整理的未完成任务摘要:",
taskSummary,
"如果用户的问题与任务无关,也可以正常回答;如果相关,请优先考虑优先级、截止时间与执行顺序。",
`用户当前问题:${userMessage}`
].join("\n\n");
}
private async buildTaskContextSummary(userId: string): Promise<string | null> {
const tasks = await this.prismaService.task.findMany({
where: {
userId,
status: {
in: [TaskStatus.TODO, TaskStatus.IN_PROGRESS]
}
},
select: {
title: true,
priority: true,
status: true,
ddl: true,
contentText: true,
updatedAt: true
},
take: 20
});
if (tasks.length === 0) {
return null;
}
const sortedTasks = [...tasks].sort((left, right) => {
const priorityDiff =
this.getPriorityWeight(right.priority) - this.getPriorityWeight(left.priority);
if (priorityDiff !== 0) {
return priorityDiff;
}
const leftDdl = left.ddl?.getTime() ?? Number.POSITIVE_INFINITY;
const rightDdl = right.ddl?.getTime() ?? Number.POSITIVE_INFINITY;
if (leftDdl !== rightDdl) {
return leftDdl - rightDdl;
}
return right.updatedAt.getTime() - left.updatedAt.getTime();
});
const visibleTasks = sortedTasks.slice(0, this.maxContextTasks);
const lines = visibleTasks.map((task, index) => {
const parts = [
`${index + 1}. ${task.title}`,
`优先级:${this.getPriorityLabel(task.priority)}`,
`状态:${this.getStatusLabel(task.status)}`,
`DDL${task.ddl ? task.ddl.toISOString() : "未设置"}`
];
const contentSnippet = this.getContentSnippet(task.contentText);
if (contentSnippet) {
parts.push(`内容摘要:${contentSnippet}`);
}
return parts.join(" | ");
});
const omittedCount = sortedTasks.length - visibleTasks.length;
if (omittedCount > 0) {
lines.push(`其余 ${omittedCount} 项未完成任务已省略。`);
}
return [`${sortedTasks.length} 项未完成任务。`, ...lines].join("\n");
}
private toFailureAttempt(candidate: AiResolvedRouteCandidate, error: unknown): AiRouteAttempt { private toFailureAttempt(candidate: AiResolvedRouteCandidate, error: unknown): AiRouteAttempt {
if (error instanceof AiRouteFailureError) { if (error instanceof AiRouteFailureError) {
return { return {
@@ -493,6 +577,68 @@ export class AiService {
return `${secret.slice(0, 4)}***${secret.slice(-2)}`; return `${secret.slice(0, 4)}***${secret.slice(-2)}`;
} }
private getPriorityWeight(priority: TaskPriority): number {
switch (priority) {
case TaskPriority.URGENT:
return 4;
case TaskPriority.HIGH:
return 3;
case TaskPriority.MEDIUM:
return 2;
case TaskPriority.LOW:
return 1;
default:
return 0;
}
}
private getPriorityLabel(priority: TaskPriority): string {
switch (priority) {
case TaskPriority.URGENT:
return "紧急";
case TaskPriority.HIGH:
return "高";
case TaskPriority.MEDIUM:
return "中";
case TaskPriority.LOW:
return "低";
default:
return String(priority);
}
}
private getStatusLabel(status: TaskStatus): string {
switch (status) {
case TaskStatus.TODO:
return "待开始";
case TaskStatus.IN_PROGRESS:
return "进行中";
case TaskStatus.DONE:
return "已完成";
case TaskStatus.ARCHIVED:
return "已归档";
default:
return String(status);
}
}
private getContentSnippet(contentText: string | null): string | null {
if (!contentText) {
return null;
}
const normalizedContent = contentText.replace(/\s+/g, " ").trim();
if (normalizedContent.length === 0) {
return null;
}
if (normalizedContent.length <= this.maxContextContentLength) {
return normalizedContent;
}
return `${normalizedContent.slice(0, this.maxContextContentLength)}...`;
}
private async recordUsageLog(input: { private async recordUsageLog(input: {
userId: string; userId: string;
channel: AiChannel; channel: AiChannel;
+110 -3
View File
@@ -1,11 +1,18 @@
import request from "supertest"; import request from "supertest";
import { INestApplication, ValidationPipe } from "@nestjs/common"; import { INestApplication, ValidationPipe } from "@nestjs/common";
import { Test, TestingModule } from "@nestjs/testing"; import { Test, TestingModule } from "@nestjs/testing";
import { AiChannel, AiProviderBinding, AiPublicPoolConfig } from "../generated/prisma/client"; import {
AiChannel,
AiProviderBinding,
AiPublicPoolConfig,
TaskPriority,
TaskStatus
} from "../generated/prisma/client";
import { AiController } from "../src/ai/ai.controller"; import { AiController } from "../src/ai/ai.controller";
import { AiProviderRegistryService } from "../src/ai/ai-provider-registry.service"; import { AiProviderRegistryService } from "../src/ai/ai-provider-registry.service";
import { AiService } from "../src/ai/ai.service"; import { AiService } from "../src/ai/ai.service";
import { import {
AiChatInput,
AiChannelExecutor, AiChannelExecutor,
AiResolvedRouteCandidate, AiResolvedRouteCandidate,
AiRouteFailureError AiRouteFailureError
@@ -25,12 +32,23 @@ type AiUsageLogRecord = {
errorCode: string | null; errorCode: string | null;
}; };
type AiTaskRecord = {
userId: string;
title: string;
priority: TaskPriority;
status: TaskStatus;
ddl: Date | null;
contentText: string | null;
updatedAt: Date;
};
class InMemoryAiPrismaService { class InMemoryAiPrismaService {
private bindingIdSequence = 1; private bindingIdSequence = 1;
private publicPoolIdSequence = 1; private publicPoolIdSequence = 1;
private bindings: AiProviderBinding[] = []; private bindings: AiProviderBinding[] = [];
private publicPools: AiPublicPoolConfig[] = []; private publicPools: AiPublicPoolConfig[] = [];
private usageLogs: AiUsageLogRecord[] = []; private usageLogs: AiUsageLogRecord[] = [];
private tasks: AiTaskRecord[] = [];
readonly aiProviderBinding = { readonly aiProviderBinding = {
findMany: async (args: { findMany: async (args: {
@@ -185,6 +203,31 @@ class InMemoryAiPrismaService {
} }
}; };
readonly task = {
findMany: async (args: {
where: {
userId: string;
status: {
in: TaskStatus[];
};
};
take?: number;
}) => {
const filteredTasks = this.tasks.filter(
(task) => task.userId === args.where.userId && args.where.status.in.includes(task.status)
);
return filteredTasks.slice(0, args.take ?? filteredTasks.length).map((task) => ({
title: task.title,
priority: task.priority,
status: task.status,
ddl: task.ddl,
contentText: task.contentText,
updatedAt: task.updatedAt
}));
}
};
async $transaction<T>(callback: (tx: InMemoryAiPrismaService) => Promise<T>): Promise<T> { async $transaction<T>(callback: (tx: InMemoryAiPrismaService) => Promise<T>): Promise<T> {
return callback(this); return callback(this);
} }
@@ -211,9 +254,18 @@ class InMemoryAiPrismaService {
getUsageLogs(): AiUsageLogRecord[] { getUsageLogs(): AiUsageLogRecord[] {
return [...this.usageLogs]; return [...this.usageLogs];
} }
seedTask(task: AiTaskRecord): void {
this.tasks.push(task);
}
} }
class StaticExecutor implements AiChannelExecutor { class StaticExecutor implements AiChannelExecutor {
readonly inputs: Array<{
candidate: AiResolvedRouteCandidate;
message: string;
}> = [];
constructor( constructor(
private readonly resolver: (channel: AiChannel) => { private readonly resolver: (channel: AiChannel) => {
content?: string; content?: string;
@@ -222,7 +274,12 @@ class StaticExecutor implements AiChannelExecutor {
} }
) {} ) {}
async execute(candidate: AiResolvedRouteCandidate) { async execute(candidate: AiResolvedRouteCandidate, input: AiChatInput) {
this.inputs.push({
candidate,
message: input.message
});
const result = this.resolver(candidate.channel); const result = this.resolver(candidate.channel);
if (result.code) { if (result.code) {
throw new AiRouteFailureError( throw new AiRouteFailureError(
@@ -252,6 +309,7 @@ class StaticExecutor implements AiChannelExecutor {
describe("AiController (integration)", () => { describe("AiController (integration)", () => {
let app: INestApplication; let app: INestApplication;
let prismaService: InMemoryAiPrismaService; let prismaService: InMemoryAiPrismaService;
let astrbotExecutor: StaticExecutor;
beforeEach(async () => { beforeEach(async () => {
prismaService = new InMemoryAiPrismaService(); prismaService = new InMemoryAiPrismaService();
@@ -266,7 +324,7 @@ describe("AiController (integration)", () => {
content: "公共 AI 已接管" content: "公共 AI 已接管"
} }
); );
const astrbotExecutor = new StaticExecutor(() => ({ astrbotExecutor = new StaticExecutor(() => ({
content: "AstrBot 已接管" content: "AstrBot 已接管"
})); }));
@@ -448,6 +506,55 @@ describe("AiController (integration)", () => {
}); });
}); });
it("should inject unfinished task summary into ai prompt", async () => {
prismaService.seedBinding({
id: "binding_astrbot_context",
userId: "user_1",
channel: AiChannel.ASTRBOT,
providerName: "",
model: null,
configId: "default",
configName: null,
encryptedApiKey: "abk_astrbot",
endpoint: "http://127.0.0.1:6185",
isDefault: true,
isEnabled: true
});
prismaService.seedTask({
userId: "user_1",
title: "今晚提交周报",
priority: TaskPriority.URGENT,
status: TaskStatus.IN_PROGRESS,
ddl: new Date("2026-04-06T12:00:00.000Z"),
contentText: "需要汇总 AI 路由、AstrBot 接入和同步模块进度",
updatedAt: new Date("2026-04-06T08:00:00.000Z")
});
prismaService.seedTask({
userId: "user_1",
title: "整理已完成事项",
priority: TaskPriority.LOW,
status: TaskStatus.DONE,
ddl: null,
contentText: "这条任务不应该出现在上下文里",
updatedAt: new Date("2026-04-06T07:00:00.000Z")
});
await request(app.getHttpServer())
.post("/ai/chat")
.set("x-user-id", "user_1")
.send({
message: "帮我安排今天剩余任务"
})
.expect(201);
expect(astrbotExecutor.inputs).toHaveLength(1);
expect(astrbotExecutor.inputs[0]?.message).toContain("以下是系统整理的未完成任务摘要");
expect(astrbotExecutor.inputs[0]?.message).toContain("今晚提交周报");
expect(astrbotExecutor.inputs[0]?.message).toContain("优先级:紧急");
expect(astrbotExecutor.inputs[0]?.message).not.toContain("整理已完成事项");
expect(astrbotExecutor.inputs[0]?.message).toContain("用户当前问题:帮我安排今天剩余任务");
});
it("should return skipped attempts when no channel is available", async () => { it("should return skipped attempts when no channel is available", async () => {
const response = await request(app.getHttpServer()) const response = await request(app.getHttpServer())
.post("/ai/chat") .post("/ai/chat")