feat(api-ai): inject unfinished task summary
This commit is contained in:
@@ -9,7 +9,9 @@ import {
|
|||||||
AiChannel,
|
AiChannel,
|
||||||
AiProviderBinding,
|
AiProviderBinding,
|
||||||
AiPublicPoolConfig,
|
AiPublicPoolConfig,
|
||||||
Prisma
|
Prisma,
|
||||||
|
TaskPriority,
|
||||||
|
TaskStatus
|
||||||
} from "../../generated/prisma/client";
|
} from "../../generated/prisma/client";
|
||||||
import { PrismaService } from "../prisma/prisma.service";
|
import { PrismaService } from "../prisma/prisma.service";
|
||||||
import { AiProviderRegistryService } from "./ai-provider-registry.service";
|
import { AiProviderRegistryService } from "./ai-provider-registry.service";
|
||||||
@@ -71,6 +73,8 @@ export type AiChatResponse = {
|
|||||||
@Injectable()
|
@Injectable()
|
||||||
export class AiService {
|
export class AiService {
|
||||||
private readonly logger = new Logger(AiService.name);
|
private readonly logger = new Logger(AiService.name);
|
||||||
|
private readonly maxContextTasks = 6;
|
||||||
|
private readonly maxContextContentLength = 80;
|
||||||
|
|
||||||
constructor(
|
constructor(
|
||||||
private readonly prismaService: PrismaService,
|
private readonly prismaService: PrismaService,
|
||||||
@@ -195,6 +199,7 @@ export class AiService {
|
|||||||
async chat(userId: string, dto: AiChatDto): Promise<AiChatResponse> {
|
async chat(userId: string, dto: AiChatDto): Promise<AiChatResponse> {
|
||||||
const attempts: AiRouteAttempt[] = [];
|
const attempts: AiRouteAttempt[] = [];
|
||||||
const plan = await this.buildRoutePlan(userId, dto.bindingId ?? null);
|
const plan = await this.buildRoutePlan(userId, dto.bindingId ?? null);
|
||||||
|
const promptMessage = await this.buildPromptMessage(userId, dto.message);
|
||||||
|
|
||||||
for (const entry of plan) {
|
for (const entry of plan) {
|
||||||
if (entry.kind === "skip") {
|
if (entry.kind === "skip") {
|
||||||
@@ -208,7 +213,7 @@ export class AiService {
|
|||||||
try {
|
try {
|
||||||
const result = await executor.execute(entry.candidate, {
|
const result = await executor.execute(entry.candidate, {
|
||||||
userId,
|
userId,
|
||||||
message: dto.message,
|
message: promptMessage,
|
||||||
sessionId: dto.sessionId ?? null
|
sessionId: dto.sessionId ?? null
|
||||||
});
|
});
|
||||||
const latencyMs = Date.now() - startedAt;
|
const latencyMs = Date.now() - startedAt;
|
||||||
@@ -416,6 +421,85 @@ export class AiService {
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private async buildPromptMessage(userId: string, userMessage: string): Promise<string> {
|
||||||
|
const taskSummary = await this.buildTaskContextSummary(userId);
|
||||||
|
if (!taskSummary) {
|
||||||
|
return userMessage;
|
||||||
|
}
|
||||||
|
|
||||||
|
return [
|
||||||
|
"你是 TodoList 的 AI 助手,请优先结合用户当前未完成任务给出安排建议。",
|
||||||
|
"以下是系统整理的未完成任务摘要:",
|
||||||
|
taskSummary,
|
||||||
|
"如果用户的问题与任务无关,也可以正常回答;如果相关,请优先考虑优先级、截止时间与执行顺序。",
|
||||||
|
`用户当前问题:${userMessage}`
|
||||||
|
].join("\n\n");
|
||||||
|
}
|
||||||
|
|
||||||
|
private async buildTaskContextSummary(userId: string): Promise<string | null> {
|
||||||
|
const tasks = await this.prismaService.task.findMany({
|
||||||
|
where: {
|
||||||
|
userId,
|
||||||
|
status: {
|
||||||
|
in: [TaskStatus.TODO, TaskStatus.IN_PROGRESS]
|
||||||
|
}
|
||||||
|
},
|
||||||
|
select: {
|
||||||
|
title: true,
|
||||||
|
priority: true,
|
||||||
|
status: true,
|
||||||
|
ddl: true,
|
||||||
|
contentText: true,
|
||||||
|
updatedAt: true
|
||||||
|
},
|
||||||
|
take: 20
|
||||||
|
});
|
||||||
|
|
||||||
|
if (tasks.length === 0) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
const sortedTasks = [...tasks].sort((left, right) => {
|
||||||
|
const priorityDiff =
|
||||||
|
this.getPriorityWeight(right.priority) - this.getPriorityWeight(left.priority);
|
||||||
|
if (priorityDiff !== 0) {
|
||||||
|
return priorityDiff;
|
||||||
|
}
|
||||||
|
|
||||||
|
const leftDdl = left.ddl?.getTime() ?? Number.POSITIVE_INFINITY;
|
||||||
|
const rightDdl = right.ddl?.getTime() ?? Number.POSITIVE_INFINITY;
|
||||||
|
if (leftDdl !== rightDdl) {
|
||||||
|
return leftDdl - rightDdl;
|
||||||
|
}
|
||||||
|
|
||||||
|
return right.updatedAt.getTime() - left.updatedAt.getTime();
|
||||||
|
});
|
||||||
|
|
||||||
|
const visibleTasks = sortedTasks.slice(0, this.maxContextTasks);
|
||||||
|
const lines = visibleTasks.map((task, index) => {
|
||||||
|
const parts = [
|
||||||
|
`${index + 1}. ${task.title}`,
|
||||||
|
`优先级:${this.getPriorityLabel(task.priority)}`,
|
||||||
|
`状态:${this.getStatusLabel(task.status)}`,
|
||||||
|
`DDL:${task.ddl ? task.ddl.toISOString() : "未设置"}`
|
||||||
|
];
|
||||||
|
|
||||||
|
const contentSnippet = this.getContentSnippet(task.contentText);
|
||||||
|
if (contentSnippet) {
|
||||||
|
parts.push(`内容摘要:${contentSnippet}`);
|
||||||
|
}
|
||||||
|
|
||||||
|
return parts.join(" | ");
|
||||||
|
});
|
||||||
|
|
||||||
|
const omittedCount = sortedTasks.length - visibleTasks.length;
|
||||||
|
if (omittedCount > 0) {
|
||||||
|
lines.push(`其余 ${omittedCount} 项未完成任务已省略。`);
|
||||||
|
}
|
||||||
|
|
||||||
|
return [`共 ${sortedTasks.length} 项未完成任务。`, ...lines].join("\n");
|
||||||
|
}
|
||||||
|
|
||||||
private toFailureAttempt(candidate: AiResolvedRouteCandidate, error: unknown): AiRouteAttempt {
|
private toFailureAttempt(candidate: AiResolvedRouteCandidate, error: unknown): AiRouteAttempt {
|
||||||
if (error instanceof AiRouteFailureError) {
|
if (error instanceof AiRouteFailureError) {
|
||||||
return {
|
return {
|
||||||
@@ -493,6 +577,68 @@ export class AiService {
|
|||||||
return `${secret.slice(0, 4)}***${secret.slice(-2)}`;
|
return `${secret.slice(0, 4)}***${secret.slice(-2)}`;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private getPriorityWeight(priority: TaskPriority): number {
|
||||||
|
switch (priority) {
|
||||||
|
case TaskPriority.URGENT:
|
||||||
|
return 4;
|
||||||
|
case TaskPriority.HIGH:
|
||||||
|
return 3;
|
||||||
|
case TaskPriority.MEDIUM:
|
||||||
|
return 2;
|
||||||
|
case TaskPriority.LOW:
|
||||||
|
return 1;
|
||||||
|
default:
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private getPriorityLabel(priority: TaskPriority): string {
|
||||||
|
switch (priority) {
|
||||||
|
case TaskPriority.URGENT:
|
||||||
|
return "紧急";
|
||||||
|
case TaskPriority.HIGH:
|
||||||
|
return "高";
|
||||||
|
case TaskPriority.MEDIUM:
|
||||||
|
return "中";
|
||||||
|
case TaskPriority.LOW:
|
||||||
|
return "低";
|
||||||
|
default:
|
||||||
|
return String(priority);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private getStatusLabel(status: TaskStatus): string {
|
||||||
|
switch (status) {
|
||||||
|
case TaskStatus.TODO:
|
||||||
|
return "待开始";
|
||||||
|
case TaskStatus.IN_PROGRESS:
|
||||||
|
return "进行中";
|
||||||
|
case TaskStatus.DONE:
|
||||||
|
return "已完成";
|
||||||
|
case TaskStatus.ARCHIVED:
|
||||||
|
return "已归档";
|
||||||
|
default:
|
||||||
|
return String(status);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private getContentSnippet(contentText: string | null): string | null {
|
||||||
|
if (!contentText) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
const normalizedContent = contentText.replace(/\s+/g, " ").trim();
|
||||||
|
if (normalizedContent.length === 0) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (normalizedContent.length <= this.maxContextContentLength) {
|
||||||
|
return normalizedContent;
|
||||||
|
}
|
||||||
|
|
||||||
|
return `${normalizedContent.slice(0, this.maxContextContentLength)}...`;
|
||||||
|
}
|
||||||
|
|
||||||
private async recordUsageLog(input: {
|
private async recordUsageLog(input: {
|
||||||
userId: string;
|
userId: string;
|
||||||
channel: AiChannel;
|
channel: AiChannel;
|
||||||
|
|||||||
+110
-3
@@ -1,11 +1,18 @@
|
|||||||
import request from "supertest";
|
import request from "supertest";
|
||||||
import { INestApplication, ValidationPipe } from "@nestjs/common";
|
import { INestApplication, ValidationPipe } from "@nestjs/common";
|
||||||
import { Test, TestingModule } from "@nestjs/testing";
|
import { Test, TestingModule } from "@nestjs/testing";
|
||||||
import { AiChannel, AiProviderBinding, AiPublicPoolConfig } from "../generated/prisma/client";
|
import {
|
||||||
|
AiChannel,
|
||||||
|
AiProviderBinding,
|
||||||
|
AiPublicPoolConfig,
|
||||||
|
TaskPriority,
|
||||||
|
TaskStatus
|
||||||
|
} from "../generated/prisma/client";
|
||||||
import { AiController } from "../src/ai/ai.controller";
|
import { AiController } from "../src/ai/ai.controller";
|
||||||
import { AiProviderRegistryService } from "../src/ai/ai-provider-registry.service";
|
import { AiProviderRegistryService } from "../src/ai/ai-provider-registry.service";
|
||||||
import { AiService } from "../src/ai/ai.service";
|
import { AiService } from "../src/ai/ai.service";
|
||||||
import {
|
import {
|
||||||
|
AiChatInput,
|
||||||
AiChannelExecutor,
|
AiChannelExecutor,
|
||||||
AiResolvedRouteCandidate,
|
AiResolvedRouteCandidate,
|
||||||
AiRouteFailureError
|
AiRouteFailureError
|
||||||
@@ -25,12 +32,23 @@ type AiUsageLogRecord = {
|
|||||||
errorCode: string | null;
|
errorCode: string | null;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
type AiTaskRecord = {
|
||||||
|
userId: string;
|
||||||
|
title: string;
|
||||||
|
priority: TaskPriority;
|
||||||
|
status: TaskStatus;
|
||||||
|
ddl: Date | null;
|
||||||
|
contentText: string | null;
|
||||||
|
updatedAt: Date;
|
||||||
|
};
|
||||||
|
|
||||||
class InMemoryAiPrismaService {
|
class InMemoryAiPrismaService {
|
||||||
private bindingIdSequence = 1;
|
private bindingIdSequence = 1;
|
||||||
private publicPoolIdSequence = 1;
|
private publicPoolIdSequence = 1;
|
||||||
private bindings: AiProviderBinding[] = [];
|
private bindings: AiProviderBinding[] = [];
|
||||||
private publicPools: AiPublicPoolConfig[] = [];
|
private publicPools: AiPublicPoolConfig[] = [];
|
||||||
private usageLogs: AiUsageLogRecord[] = [];
|
private usageLogs: AiUsageLogRecord[] = [];
|
||||||
|
private tasks: AiTaskRecord[] = [];
|
||||||
|
|
||||||
readonly aiProviderBinding = {
|
readonly aiProviderBinding = {
|
||||||
findMany: async (args: {
|
findMany: async (args: {
|
||||||
@@ -185,6 +203,31 @@ class InMemoryAiPrismaService {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
readonly task = {
|
||||||
|
findMany: async (args: {
|
||||||
|
where: {
|
||||||
|
userId: string;
|
||||||
|
status: {
|
||||||
|
in: TaskStatus[];
|
||||||
|
};
|
||||||
|
};
|
||||||
|
take?: number;
|
||||||
|
}) => {
|
||||||
|
const filteredTasks = this.tasks.filter(
|
||||||
|
(task) => task.userId === args.where.userId && args.where.status.in.includes(task.status)
|
||||||
|
);
|
||||||
|
|
||||||
|
return filteredTasks.slice(0, args.take ?? filteredTasks.length).map((task) => ({
|
||||||
|
title: task.title,
|
||||||
|
priority: task.priority,
|
||||||
|
status: task.status,
|
||||||
|
ddl: task.ddl,
|
||||||
|
contentText: task.contentText,
|
||||||
|
updatedAt: task.updatedAt
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
async $transaction<T>(callback: (tx: InMemoryAiPrismaService) => Promise<T>): Promise<T> {
|
async $transaction<T>(callback: (tx: InMemoryAiPrismaService) => Promise<T>): Promise<T> {
|
||||||
return callback(this);
|
return callback(this);
|
||||||
}
|
}
|
||||||
@@ -211,9 +254,18 @@ class InMemoryAiPrismaService {
|
|||||||
getUsageLogs(): AiUsageLogRecord[] {
|
getUsageLogs(): AiUsageLogRecord[] {
|
||||||
return [...this.usageLogs];
|
return [...this.usageLogs];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
seedTask(task: AiTaskRecord): void {
|
||||||
|
this.tasks.push(task);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
class StaticExecutor implements AiChannelExecutor {
|
class StaticExecutor implements AiChannelExecutor {
|
||||||
|
readonly inputs: Array<{
|
||||||
|
candidate: AiResolvedRouteCandidate;
|
||||||
|
message: string;
|
||||||
|
}> = [];
|
||||||
|
|
||||||
constructor(
|
constructor(
|
||||||
private readonly resolver: (channel: AiChannel) => {
|
private readonly resolver: (channel: AiChannel) => {
|
||||||
content?: string;
|
content?: string;
|
||||||
@@ -222,7 +274,12 @@ class StaticExecutor implements AiChannelExecutor {
|
|||||||
}
|
}
|
||||||
) {}
|
) {}
|
||||||
|
|
||||||
async execute(candidate: AiResolvedRouteCandidate) {
|
async execute(candidate: AiResolvedRouteCandidate, input: AiChatInput) {
|
||||||
|
this.inputs.push({
|
||||||
|
candidate,
|
||||||
|
message: input.message
|
||||||
|
});
|
||||||
|
|
||||||
const result = this.resolver(candidate.channel);
|
const result = this.resolver(candidate.channel);
|
||||||
if (result.code) {
|
if (result.code) {
|
||||||
throw new AiRouteFailureError(
|
throw new AiRouteFailureError(
|
||||||
@@ -252,6 +309,7 @@ class StaticExecutor implements AiChannelExecutor {
|
|||||||
describe("AiController (integration)", () => {
|
describe("AiController (integration)", () => {
|
||||||
let app: INestApplication;
|
let app: INestApplication;
|
||||||
let prismaService: InMemoryAiPrismaService;
|
let prismaService: InMemoryAiPrismaService;
|
||||||
|
let astrbotExecutor: StaticExecutor;
|
||||||
|
|
||||||
beforeEach(async () => {
|
beforeEach(async () => {
|
||||||
prismaService = new InMemoryAiPrismaService();
|
prismaService = new InMemoryAiPrismaService();
|
||||||
@@ -266,7 +324,7 @@ describe("AiController (integration)", () => {
|
|||||||
content: "公共 AI 已接管"
|
content: "公共 AI 已接管"
|
||||||
}
|
}
|
||||||
);
|
);
|
||||||
const astrbotExecutor = new StaticExecutor(() => ({
|
astrbotExecutor = new StaticExecutor(() => ({
|
||||||
content: "AstrBot 已接管"
|
content: "AstrBot 已接管"
|
||||||
}));
|
}));
|
||||||
|
|
||||||
@@ -448,6 +506,55 @@ describe("AiController (integration)", () => {
|
|||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
|
it("should inject unfinished task summary into ai prompt", async () => {
|
||||||
|
prismaService.seedBinding({
|
||||||
|
id: "binding_astrbot_context",
|
||||||
|
userId: "user_1",
|
||||||
|
channel: AiChannel.ASTRBOT,
|
||||||
|
providerName: "",
|
||||||
|
model: null,
|
||||||
|
configId: "default",
|
||||||
|
configName: null,
|
||||||
|
encryptedApiKey: "abk_astrbot",
|
||||||
|
endpoint: "http://127.0.0.1:6185",
|
||||||
|
isDefault: true,
|
||||||
|
isEnabled: true
|
||||||
|
});
|
||||||
|
prismaService.seedTask({
|
||||||
|
userId: "user_1",
|
||||||
|
title: "今晚提交周报",
|
||||||
|
priority: TaskPriority.URGENT,
|
||||||
|
status: TaskStatus.IN_PROGRESS,
|
||||||
|
ddl: new Date("2026-04-06T12:00:00.000Z"),
|
||||||
|
contentText: "需要汇总 AI 路由、AstrBot 接入和同步模块进度",
|
||||||
|
updatedAt: new Date("2026-04-06T08:00:00.000Z")
|
||||||
|
});
|
||||||
|
prismaService.seedTask({
|
||||||
|
userId: "user_1",
|
||||||
|
title: "整理已完成事项",
|
||||||
|
priority: TaskPriority.LOW,
|
||||||
|
status: TaskStatus.DONE,
|
||||||
|
ddl: null,
|
||||||
|
contentText: "这条任务不应该出现在上下文里",
|
||||||
|
updatedAt: new Date("2026-04-06T07:00:00.000Z")
|
||||||
|
});
|
||||||
|
|
||||||
|
await request(app.getHttpServer())
|
||||||
|
.post("/ai/chat")
|
||||||
|
.set("x-user-id", "user_1")
|
||||||
|
.send({
|
||||||
|
message: "帮我安排今天剩余任务"
|
||||||
|
})
|
||||||
|
.expect(201);
|
||||||
|
|
||||||
|
expect(astrbotExecutor.inputs).toHaveLength(1);
|
||||||
|
expect(astrbotExecutor.inputs[0]?.message).toContain("以下是系统整理的未完成任务摘要");
|
||||||
|
expect(astrbotExecutor.inputs[0]?.message).toContain("今晚提交周报");
|
||||||
|
expect(astrbotExecutor.inputs[0]?.message).toContain("优先级:紧急");
|
||||||
|
expect(astrbotExecutor.inputs[0]?.message).not.toContain("整理已完成事项");
|
||||||
|
expect(astrbotExecutor.inputs[0]?.message).toContain("用户当前问题:帮我安排今天剩余任务");
|
||||||
|
});
|
||||||
|
|
||||||
it("should return skipped attempts when no channel is available", async () => {
|
it("should return skipped attempts when no channel is available", async () => {
|
||||||
const response = await request(app.getHttpServer())
|
const response = await request(app.getHttpServer())
|
||||||
.post("/ai/chat")
|
.post("/ai/chat")
|
||||||
|
|||||||
Reference in New Issue
Block a user