fix(ai-context): include local tasks in prompt injection

This commit is contained in:
2026-04-07 00:56:50 +08:00
parent 1564d2dd30
commit 45b149ad58
5 changed files with 299 additions and 39 deletions
+108 -34
View File
@@ -70,6 +70,16 @@ type AiUsageLogSummary = {
createdAt: string;
};
type AiContextTaskItem = {
id: string;
title: string;
priority: TaskPriority;
status: TaskStatus;
ddl: Date | null;
contentText: string | null;
updatedAt: Date;
};
export type ListAiUsageLogsResponse = {
items: AiUsageLogSummary[];
page: number;
@@ -235,7 +245,7 @@ export class AiService {
async chat(userId: string, dto: AiChatDto): Promise<AiChatResponse> {
const attempts: AiRouteAttempt[] = [];
const plan = await this.buildRoutePlan(userId, dto.channel ?? null);
const promptMessage = await this.buildPromptMessage(userId, dto.message);
const promptMessage = await this.buildPromptMessage(userId, dto.message, dto.localTasks ?? []);
for (const entry of plan) {
if (entry.kind === "skip") {
@@ -476,22 +486,29 @@ export class AiService {
};
}
private async buildPromptMessage(userId: string, userMessage: string): Promise<string> {
const taskSummary = await this.buildTaskContextSummary(userId);
private async buildPromptMessage(
userId: string,
userMessage: string,
localTasks: NonNullable<AiChatDto["localTasks"]>
): Promise<string> {
const taskSummary = await this.buildTaskContextSummary(userId, localTasks);
if (!taskSummary) {
return userMessage;
}
return [
"你是 TodoList 的 AI 助手,请优先结合用户当前未完成任务给出安排建议。",
"你是 TodoList 的 AI 助手,需要结合用户当前待办提供任务统筹建议。",
"以下是系统整理的未完成任务摘要:",
taskSummary,
"如果用户的问题与任务无关,也可以正常回答;如果相关,请优先考虑优先级、截止时间执行顺序。",
"请优先根据这些任务的紧急度、截止时间执行顺序回答,并给出明确可执行的建议。",
`用户当前问题:${userMessage}`
].join("\n\n");
}
private async buildTaskContextSummary(userId: string): Promise<string | null> {
private async buildTaskContextSummary(
userId: string,
localTasks: NonNullable<AiChatDto["localTasks"]>
): Promise<string | null> {
const tasks = await this.prismaService.task.findMany({
where: {
userId,
@@ -500,6 +517,7 @@ export class AiService {
}
},
select: {
id: true,
title: true,
priority: true,
status: true,
@@ -510,11 +528,93 @@ export class AiService {
take: 20
});
if (tasks.length === 0) {
const sortedTasks = this.sortContextTasks(this.mergeContextTasks(tasks, localTasks));
if (sortedTasks.length === 0) {
return null;
}
const sortedTasks = [...tasks].sort((left, right) => {
const visibleTasks = sortedTasks.slice(0, this.maxContextTasks);
const lines = visibleTasks.map((task, index) => {
const parts = [
`${index + 1}. ${task.title}`,
`优先级:${this.getPriorityLabel(task.priority)}`,
`状态:${this.getStatusLabel(task.status)}`,
`DDL${task.ddl ? task.ddl.toISOString() : "未设置"}`
];
const contentSnippet = this.getContentSnippet(task.contentText);
if (contentSnippet) {
parts.push(`内容摘要:${contentSnippet}`);
}
return parts.join(" | ");
});
const omittedCount = sortedTasks.length - visibleTasks.length;
if (omittedCount > 0) {
lines.push(`另有 ${omittedCount} 条任务已省略。`);
}
return [`${sortedTasks.length} 条未完成任务。`, ...lines].join("\n");
}
private mergeContextTasks(
databaseTasks: Array<{
id: string;
title: string;
priority: TaskPriority;
status: TaskStatus;
ddl: Date | null;
contentText: string | null;
updatedAt: Date;
}>,
localTasks: NonNullable<AiChatDto["localTasks"]>
): AiContextTaskItem[] {
const taskMap = new Map<string, AiContextTaskItem>();
for (const task of databaseTasks) {
taskMap.set(task.id, {
id: task.id,
title: this.readDecryptedString(task.title) ?? "未命名任务",
priority: task.priority,
status: task.status,
ddl: task.ddl,
contentText: this.readDecryptedString(task.contentText),
updatedAt: task.updatedAt
});
}
for (const task of localTasks) {
if (task.status !== TaskStatus.TODO && task.status !== TaskStatus.IN_PROGRESS) {
continue;
}
const currentTask = taskMap.get(task.id);
const nextTask: AiContextTaskItem = {
id: task.id,
title: task.title.trim().length > 0 ? task.title.trim() : "未命名任务",
priority: task.priority,
status: task.status,
ddl: typeof task.ddlAt === "number" ? new Date(task.ddlAt) : null,
contentText:
typeof task.contentText === "string" && task.contentText.trim().length > 0
? task.contentText
: null,
updatedAt: new Date(task.updatedAt)
};
if (!currentTask || nextTask.updatedAt.getTime() >= currentTask.updatedAt.getTime()) {
taskMap.set(task.id, nextTask);
}
}
return [...taskMap.values()].filter(
(task) => task.status === TaskStatus.TODO || task.status === TaskStatus.IN_PROGRESS
);
}
private sortContextTasks(tasks: AiContextTaskItem[]): AiContextTaskItem[] {
return [...tasks].sort((left, right) => {
const priorityDiff =
this.getPriorityWeight(right.priority) - this.getPriorityWeight(left.priority);
if (priorityDiff !== 0) {
@@ -529,32 +629,6 @@ export class AiService {
return right.updatedAt.getTime() - left.updatedAt.getTime();
});
const visibleTasks = sortedTasks.slice(0, this.maxContextTasks);
const lines = visibleTasks.map((task, index) => {
const taskTitle = this.readDecryptedString(task.title) ?? "未命名任务";
const contentText = this.readDecryptedString(task.contentText);
const parts = [
`${index + 1}. ${taskTitle}`,
`优先级:${this.getPriorityLabel(task.priority)}`,
`状态:${this.getStatusLabel(task.status)}`,
`DDL${task.ddl ? task.ddl.toISOString() : "未设置"}`
];
const contentSnippet = this.getContentSnippet(contentText);
if (contentSnippet) {
parts.push(`内容摘要:${contentSnippet}`);
}
return parts.join(" | ");
});
const omittedCount = sortedTasks.length - visibleTasks.length;
if (omittedCount > 0) {
lines.push(`其余 ${omittedCount} 项未完成任务已省略。`);
}
return [`${sortedTasks.length} 项未完成任务。`, ...lines].join("\n");
}
private toFailureAttempt(candidate: AiResolvedRouteCandidate, error: unknown): AiRouteAttempt {
+44 -1
View File
@@ -1,5 +1,42 @@
import { IsEnum, IsOptional, IsString, MinLength } from "class-validator";
import { Type } from "class-transformer";
import {
IsArray,
IsEnum,
IsInt,
IsOptional,
IsString,
MinLength,
ValidateNested
} from "class-validator";
import { AiChannel } from "../../../generated/prisma/client";
import { TaskPriority, TaskStatus } from "../../../generated/prisma/client";
export class LocalTaskContextItemDto {
@IsString()
@MinLength(1)
id!: string;
@IsString()
@MinLength(1)
title!: string;
@IsEnum(TaskPriority)
priority!: TaskPriority;
@IsEnum(TaskStatus)
status!: TaskStatus;
@IsOptional()
@IsInt()
ddlAt?: number | null;
@IsOptional()
@IsString()
contentText?: string | null;
@IsInt()
updatedAt!: number;
}
export class AiChatDto {
@IsString()
@@ -14,4 +51,10 @@ export class AiChatDto {
@IsOptional()
@IsEnum(AiChannel)
channel?: AiChannel;
@IsOptional()
@IsArray()
@ValidateNested({ each: true })
@Type(() => LocalTaskContextItemDto)
localTasks?: LocalTaskContextItemDto[];
}
+113 -1
View File
@@ -38,6 +38,7 @@ type AiUsageLogRecord = {
};
type AiTaskRecord = {
id: string;
userId: string;
title: string;
priority: TaskPriority;
@@ -262,6 +263,7 @@ class InMemoryAiPrismaService {
);
return filteredTasks.slice(0, args.take ?? filteredTasks.length).map((task) => ({
id: task.id,
title: task.title,
priority: task.priority,
status: task.status,
@@ -385,11 +387,12 @@ describe("AiController (integration)", () => {
let app: INestApplication;
let prismaService: InMemoryAiPrismaService;
let astrbotExecutor: StaticExecutor;
let openAiExecutor: StaticExecutor;
beforeEach(async () => {
prismaService = new InMemoryAiPrismaService();
const openAiExecutor = new StaticExecutor((channel) =>
openAiExecutor = new StaticExecutor((channel) =>
channel === AiChannel.USER_KEY
? {
code: "UPSTREAM_UNREACHABLE",
@@ -727,6 +730,7 @@ describe("AiController (integration)", () => {
isEnabled: true
});
prismaService.seedTask({
id: "task_weekly_report",
userId: "user_1",
title: "今晚提交周报",
priority: TaskPriority.URGENT,
@@ -736,6 +740,7 @@ describe("AiController (integration)", () => {
updatedAt: new Date("2026-04-06T08:00:00.000Z")
});
prismaService.seedTask({
id: "task_done_item",
userId: "user_1",
title: "整理已完成事项",
priority: TaskPriority.LOW,
@@ -761,6 +766,113 @@ describe("AiController (integration)", () => {
expect(astrbotExecutor.inputs[0]?.message).toContain("用户当前问题:帮我安排今天剩余任务");
});
it("should inject local unfinished tasks into ai prompt when database is empty", async () => {
prismaService.seedBinding({
id: "binding_user_key_local_context",
userId: "user_1",
channel: AiChannel.USER_KEY,
providerName: "openai",
model: "gpt-4o-mini",
configId: null,
configName: null,
encryptedApiKey: "sk-user",
endpoint: "https://api.example.com",
isDefault: true,
isEnabled: true
});
const response = await request(app.getHttpServer())
.post("/ai/chat")
.set("x-user-id", "user_1")
.send({
message: "结合我的 TodoList 帮我排优先级",
channel: AiChannel.USER_KEY,
localTasks: [
{
id: "local_task_1",
title: "准备明天答辩材料",
priority: TaskPriority.URGENT,
status: TaskStatus.IN_PROGRESS,
ddlAt: new Date("2026-04-07T13:00:00.000Z").getTime(),
contentText: "需要补齐演示文稿和总结页",
updatedAt: new Date("2026-04-07T09:00:00.000Z").getTime()
}
]
})
.expect(502);
expect(response.body.attempts).toEqual([
{
channel: AiChannel.USER_KEY,
providerName: "openai",
model: "gpt-4o-mini",
status: "failed",
reasonCode: "UPSTREAM_UNREACHABLE",
reasonMessage: "用户自备 Key 渠道暂时不可用"
}
]);
expect(openAiExecutor.inputs).toHaveLength(1);
expect(openAiExecutor.inputs[0]?.message).toContain("准备明天答辩材料");
expect(openAiExecutor.inputs[0]?.message).toContain("优先级:紧急");
expect(openAiExecutor.inputs[0]?.message).toContain("内容摘要:需要补齐演示文稿和总结页");
expect(openAiExecutor.inputs[0]?.message).toContain(
"用户当前问题:结合我的 TodoList 帮我排优先级"
);
expect(astrbotExecutor.inputs).toHaveLength(0);
});
it("should prefer newer local task snapshot over older database task", async () => {
prismaService.seedBinding({
id: "binding_astrbot_local_override",
userId: "user_1",
channel: AiChannel.ASTRBOT,
providerName: "",
model: null,
configId: "default",
configName: null,
encryptedApiKey: "abk_astrbot",
endpoint: "http://127.0.0.1:6185",
isDefault: true,
isEnabled: true
});
prismaService.seedTask({
id: "task_same_id",
userId: "user_1",
title: "旧标题",
priority: TaskPriority.LOW,
status: TaskStatus.TODO,
ddl: new Date("2026-04-08T10:00:00.000Z"),
contentText: "旧内容",
updatedAt: new Date("2026-04-07T08:00:00.000Z")
});
await request(app.getHttpServer())
.post("/ai/chat")
.set("x-user-id", "user_1")
.send({
message: "看看我最新要做什么",
channel: AiChannel.ASTRBOT,
localTasks: [
{
id: "task_same_id",
title: "新标题",
priority: TaskPriority.HIGH,
status: TaskStatus.IN_PROGRESS,
ddlAt: new Date("2026-04-07T15:00:00.000Z").getTime(),
contentText: "新内容",
updatedAt: new Date("2026-04-07T12:00:00.000Z").getTime()
}
]
})
.expect(201);
expect(astrbotExecutor.inputs.at(-1)?.message).toContain("新标题");
expect(astrbotExecutor.inputs.at(-1)?.message).toContain("优先级:高");
expect(astrbotExecutor.inputs.at(-1)?.message).toContain("内容摘要:新内容");
expect(astrbotExecutor.inputs.at(-1)?.message).not.toContain("旧标题");
expect(astrbotExecutor.inputs.at(-1)?.message).not.toContain("旧内容");
});
it("should return skipped attempts when no channel is available", async () => {
const response = await request(app.getHttpServer())
.post("/ai/chat")