Merge pull request #14 from Yaosanqi137/feature/p3-ai-routing
Feature/p3 ai routing
This commit is contained in:
@@ -63,3 +63,11 @@ MAIL_SMTP_PASS="replace-with-smtp-password"
|
||||
# 发件人显示名称与地址
|
||||
MAIL_FROM_NAME="TodoList"
|
||||
MAIL_FROM_ADDRESS="no-reply@example.com"
|
||||
|
||||
# [数据加密] 服务端敏感数据加密主密钥
|
||||
# 用于加密 AI 配置、任务内容、同步 payload、附件元数据等数据库字段
|
||||
# 请使用高强度随机字符串,生产环境务必单独保管
|
||||
DATA_ENCRYPTION_SECRET="replace-with-a-long-random-secret"
|
||||
|
||||
# [对象存储加密] 服务端对象加密策略,默认使用 AES256;如需关闭可填写 NONE
|
||||
S3_SERVER_SIDE_ENCRYPTION="AES256"
|
||||
|
||||
@@ -3,12 +3,13 @@
|
||||
"version": "0.1.0",
|
||||
"description": "TodoList API service",
|
||||
"scripts": {
|
||||
"prisma:generate": "prisma generate",
|
||||
"prisma:generate": "node -e \"require('node:fs').rmSync('generated/prisma', { recursive: true, force: true })\" && prisma generate",
|
||||
"prisma:format": "prisma format",
|
||||
"prisma:validate": "prisma validate",
|
||||
"prebuild": "pnpm run prisma:generate",
|
||||
"pretypecheck": "pnpm run prisma:generate",
|
||||
"pretest": "pnpm run prisma:generate",
|
||||
"data:reencrypt": "node -e \"require('node:fs').rmSync('.tmp-compile', { recursive: true, force: true })\" && tsc -p tsconfig.json --outDir .tmp-compile --noEmit false && node .tmp-compile/scripts/reencrypt-sensitive-data.js && node -e \"require('node:fs').rmSync('.tmp-compile', { recursive: true, force: true })\"",
|
||||
"start": "node dist/main.js",
|
||||
"start:dev": "ts-node-dev --respawn --transpile-only src/main.ts",
|
||||
"build": "tsc -p tsconfig.build.json",
|
||||
|
||||
@@ -63,7 +63,8 @@ enum NotificationStatus {
|
||||
|
||||
model User {
|
||||
id String @id @default(cuid())
|
||||
email String @unique
|
||||
email String
|
||||
emailHash String? @unique
|
||||
nickname String?
|
||||
avatarUrl String?
|
||||
status UserStatus @default(ACTIVE)
|
||||
@@ -97,11 +98,13 @@ model AuthIdentity {
|
||||
provider AuthProvider
|
||||
providerUserId String
|
||||
email String?
|
||||
emailHash String?
|
||||
createdAt DateTime @default(now())
|
||||
updatedAt DateTime @updatedAt
|
||||
user User @relation(fields: [userId], references: [id], onDelete: Cascade)
|
||||
|
||||
@@unique([provider, providerUserId])
|
||||
@@index([emailHash])
|
||||
@@index([userId])
|
||||
@@map("auth_identities")
|
||||
}
|
||||
@@ -273,6 +276,8 @@ model AiProviderBinding {
|
||||
channel AiChannel
|
||||
providerName String
|
||||
model String?
|
||||
configId String?
|
||||
configName String?
|
||||
encryptedApiKey String?
|
||||
endpoint String?
|
||||
isDefault Boolean @default(false)
|
||||
|
||||
@@ -0,0 +1,418 @@
|
||||
import "dotenv/config";
|
||||
import { PrismaPg } from "@prisma/adapter-pg";
|
||||
import { ConfigService } from "@nestjs/config";
|
||||
import { Prisma, PrismaClient } from "../generated/prisma/client";
|
||||
import { DataEncryptionService } from "../src/security/data-encryption.service";
|
||||
|
||||
type MigrationCounter = Record<
|
||||
| "users"
|
||||
| "authIdentities"
|
||||
| "aiBindings"
|
||||
| "publicPools"
|
||||
| "aiUsageLogs"
|
||||
| "tasks"
|
||||
| "attachments"
|
||||
| "syncOperations",
|
||||
number
|
||||
>;
|
||||
|
||||
function createEncryptionService(): DataEncryptionService {
|
||||
const configService = {
|
||||
get: (key: string) => process.env[key]
|
||||
} as ConfigService;
|
||||
|
||||
return new DataEncryptionService(configService);
|
||||
}
|
||||
|
||||
function encryptStringIfNeeded(
|
||||
value: string | null,
|
||||
dataEncryptionService: DataEncryptionService
|
||||
): string | null | undefined {
|
||||
if (value === null || dataEncryptionService.isEncryptedString(value)) {
|
||||
return undefined;
|
||||
}
|
||||
|
||||
return dataEncryptionService.encryptString(value) ?? null;
|
||||
}
|
||||
|
||||
function assignRequiredEncryptedString<T extends Record<string, unknown>, K extends keyof T>(
|
||||
target: T,
|
||||
key: K,
|
||||
value: string | null | undefined
|
||||
): void {
|
||||
if (typeof value === "string") {
|
||||
target[key] = value as T[K];
|
||||
}
|
||||
}
|
||||
|
||||
function assignOptionalEncryptedString<T extends Record<string, unknown>, K extends keyof T>(
|
||||
target: T,
|
||||
key: K,
|
||||
value: string | null | undefined
|
||||
): void {
|
||||
if (value !== undefined) {
|
||||
target[key] = value as T[K];
|
||||
}
|
||||
}
|
||||
|
||||
function encryptJsonIfNeeded(
|
||||
value: Prisma.JsonValue | null,
|
||||
dataEncryptionService: DataEncryptionService
|
||||
): Prisma.InputJsonValue | Prisma.NullableJsonNullValueInput | undefined {
|
||||
if (value === null) {
|
||||
return undefined;
|
||||
}
|
||||
|
||||
if (typeof value === "string" && dataEncryptionService.isEncryptedString(value)) {
|
||||
return undefined;
|
||||
}
|
||||
|
||||
return (dataEncryptionService.encryptJson(value as Prisma.InputJsonValue) ?? Prisma.JsonNull) as
|
||||
| Prisma.InputJsonValue
|
||||
| Prisma.NullableJsonNullValueInput;
|
||||
}
|
||||
|
||||
function resolvePlainString(
|
||||
value: string | null,
|
||||
dataEncryptionService: DataEncryptionService
|
||||
): string | null {
|
||||
if (value === null) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return dataEncryptionService.isEncryptedString(value)
|
||||
? (dataEncryptionService.decryptString(value) ?? null)
|
||||
: value;
|
||||
}
|
||||
|
||||
async function main(): Promise<void> {
|
||||
if (!process.env["DATABASE_URL"]) {
|
||||
throw new Error("缺少 DATABASE_URL,无法执行敏感数据迁移");
|
||||
}
|
||||
|
||||
if (!process.env["DATA_ENCRYPTION_SECRET"]) {
|
||||
throw new Error("缺少 DATA_ENCRYPTION_SECRET,无法执行敏感数据迁移");
|
||||
}
|
||||
|
||||
const prisma = new PrismaClient({
|
||||
adapter: new PrismaPg({
|
||||
connectionString: process.env["DATABASE_URL"]
|
||||
})
|
||||
});
|
||||
const dataEncryptionService = createEncryptionService();
|
||||
const counter: MigrationCounter = {
|
||||
users: 0,
|
||||
authIdentities: 0,
|
||||
aiBindings: 0,
|
||||
publicPools: 0,
|
||||
aiUsageLogs: 0,
|
||||
tasks: 0,
|
||||
attachments: 0,
|
||||
syncOperations: 0
|
||||
};
|
||||
|
||||
try {
|
||||
const users = await prisma.user.findMany({
|
||||
select: {
|
||||
id: true,
|
||||
email: true,
|
||||
emailHash: true,
|
||||
nickname: true,
|
||||
avatarUrl: true
|
||||
}
|
||||
});
|
||||
|
||||
for (const user of users) {
|
||||
const normalizedEmail = resolvePlainString(user.email, dataEncryptionService)?.toLowerCase();
|
||||
if (!normalizedEmail) {
|
||||
continue;
|
||||
}
|
||||
const nextEmailHash = dataEncryptionService.createLookupHash("user.email", normalizedEmail);
|
||||
const data: Prisma.UserUpdateInput = {};
|
||||
const email = encryptStringIfNeeded(user.email, dataEncryptionService);
|
||||
const nickname = encryptStringIfNeeded(user.nickname, dataEncryptionService);
|
||||
const avatarUrl = encryptStringIfNeeded(user.avatarUrl, dataEncryptionService);
|
||||
|
||||
assignRequiredEncryptedString(data, "email", email);
|
||||
if (user.emailHash !== nextEmailHash) {
|
||||
data.emailHash = nextEmailHash;
|
||||
}
|
||||
assignOptionalEncryptedString(data, "nickname", nickname);
|
||||
assignOptionalEncryptedString(data, "avatarUrl", avatarUrl);
|
||||
|
||||
if (Object.keys(data).length === 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
await prisma.user.update({
|
||||
where: {
|
||||
id: user.id
|
||||
},
|
||||
data
|
||||
});
|
||||
counter.users += 1;
|
||||
}
|
||||
|
||||
const authIdentities = await prisma.authIdentity.findMany({
|
||||
select: {
|
||||
id: true,
|
||||
email: true,
|
||||
emailHash: true
|
||||
}
|
||||
});
|
||||
|
||||
for (const authIdentity of authIdentities) {
|
||||
const data: Prisma.AuthIdentityUpdateInput = {};
|
||||
const email = encryptStringIfNeeded(authIdentity.email, dataEncryptionService);
|
||||
const normalizedIdentityEmail = resolvePlainString(authIdentity.email, dataEncryptionService);
|
||||
const nextEmailHash =
|
||||
normalizedIdentityEmail === null
|
||||
? null
|
||||
: dataEncryptionService.createLookupHash(
|
||||
"auth_identity.email",
|
||||
normalizedIdentityEmail.toLowerCase()
|
||||
);
|
||||
|
||||
assignOptionalEncryptedString(data, "email", email);
|
||||
if (authIdentity.emailHash !== nextEmailHash) {
|
||||
data.emailHash = nextEmailHash;
|
||||
}
|
||||
|
||||
if (Object.keys(data).length === 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
await prisma.authIdentity.update({
|
||||
where: {
|
||||
id: authIdentity.id
|
||||
},
|
||||
data
|
||||
});
|
||||
counter.authIdentities += 1;
|
||||
}
|
||||
|
||||
const aiBindings = await prisma.aiProviderBinding.findMany({
|
||||
select: {
|
||||
id: true,
|
||||
providerName: true,
|
||||
model: true,
|
||||
configId: true,
|
||||
configName: true,
|
||||
endpoint: true,
|
||||
encryptedApiKey: true
|
||||
}
|
||||
});
|
||||
|
||||
for (const binding of aiBindings) {
|
||||
const data: Prisma.AiProviderBindingUpdateInput = {};
|
||||
const providerName = encryptStringIfNeeded(binding.providerName, dataEncryptionService);
|
||||
const model = encryptStringIfNeeded(binding.model, dataEncryptionService);
|
||||
const configId = encryptStringIfNeeded(binding.configId, dataEncryptionService);
|
||||
const configName = encryptStringIfNeeded(binding.configName, dataEncryptionService);
|
||||
const endpoint = encryptStringIfNeeded(binding.endpoint, dataEncryptionService);
|
||||
const encryptedApiKey = encryptStringIfNeeded(binding.encryptedApiKey, dataEncryptionService);
|
||||
|
||||
assignRequiredEncryptedString(data, "providerName", providerName);
|
||||
assignOptionalEncryptedString(data, "model", model);
|
||||
assignOptionalEncryptedString(data, "configId", configId);
|
||||
assignOptionalEncryptedString(data, "configName", configName);
|
||||
assignOptionalEncryptedString(data, "endpoint", endpoint);
|
||||
assignOptionalEncryptedString(data, "encryptedApiKey", encryptedApiKey);
|
||||
|
||||
if (Object.keys(data).length === 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
await prisma.aiProviderBinding.update({
|
||||
where: {
|
||||
id: binding.id
|
||||
},
|
||||
data
|
||||
});
|
||||
counter.aiBindings += 1;
|
||||
}
|
||||
|
||||
const publicPools = await prisma.aiPublicPoolConfig.findMany({
|
||||
select: {
|
||||
id: true,
|
||||
providerName: true,
|
||||
model: true,
|
||||
endpoint: true,
|
||||
encryptedApiKey: true
|
||||
}
|
||||
});
|
||||
|
||||
for (const publicPool of publicPools) {
|
||||
const data: Prisma.AiPublicPoolConfigUpdateInput = {};
|
||||
const providerName = encryptStringIfNeeded(publicPool.providerName, dataEncryptionService);
|
||||
const model = encryptStringIfNeeded(publicPool.model, dataEncryptionService);
|
||||
const endpoint = encryptStringIfNeeded(publicPool.endpoint, dataEncryptionService);
|
||||
const encryptedApiKey = encryptStringIfNeeded(
|
||||
publicPool.encryptedApiKey,
|
||||
dataEncryptionService
|
||||
);
|
||||
|
||||
assignOptionalEncryptedString(data, "providerName", providerName);
|
||||
assignOptionalEncryptedString(data, "model", model);
|
||||
assignOptionalEncryptedString(data, "endpoint", endpoint);
|
||||
assignOptionalEncryptedString(data, "encryptedApiKey", encryptedApiKey);
|
||||
|
||||
if (Object.keys(data).length === 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
await prisma.aiPublicPoolConfig.update({
|
||||
where: {
|
||||
id: publicPool.id
|
||||
},
|
||||
data
|
||||
});
|
||||
counter.publicPools += 1;
|
||||
}
|
||||
|
||||
const aiUsageLogs = await prisma.aiUsageLog.findMany({
|
||||
select: {
|
||||
id: true,
|
||||
providerName: true,
|
||||
model: true
|
||||
}
|
||||
});
|
||||
|
||||
for (const aiUsageLog of aiUsageLogs) {
|
||||
const data: Prisma.AiUsageLogUpdateInput = {};
|
||||
const providerName = encryptStringIfNeeded(aiUsageLog.providerName, dataEncryptionService);
|
||||
const model = encryptStringIfNeeded(aiUsageLog.model, dataEncryptionService);
|
||||
|
||||
assignOptionalEncryptedString(data, "providerName", providerName);
|
||||
assignOptionalEncryptedString(data, "model", model);
|
||||
|
||||
if (Object.keys(data).length === 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
await prisma.aiUsageLog.update({
|
||||
where: {
|
||||
id: aiUsageLog.id
|
||||
},
|
||||
data
|
||||
});
|
||||
counter.aiUsageLogs += 1;
|
||||
}
|
||||
|
||||
const tasks = await prisma.task.findMany({
|
||||
select: {
|
||||
id: true,
|
||||
title: true,
|
||||
contentJson: true,
|
||||
contentText: true
|
||||
}
|
||||
});
|
||||
|
||||
for (const task of tasks) {
|
||||
const data: Prisma.TaskUpdateInput = {};
|
||||
const title = encryptStringIfNeeded(task.title, dataEncryptionService);
|
||||
const contentJson = encryptJsonIfNeeded(task.contentJson, dataEncryptionService);
|
||||
const contentText = encryptStringIfNeeded(task.contentText, dataEncryptionService);
|
||||
|
||||
assignRequiredEncryptedString(data, "title", title);
|
||||
if (contentJson !== undefined) {
|
||||
data.contentJson = contentJson;
|
||||
}
|
||||
assignOptionalEncryptedString(data, "contentText", contentText);
|
||||
|
||||
if (Object.keys(data).length === 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
await prisma.task.update({
|
||||
where: {
|
||||
id: task.id
|
||||
},
|
||||
data
|
||||
});
|
||||
counter.tasks += 1;
|
||||
}
|
||||
|
||||
const attachments = await prisma.attachment.findMany({
|
||||
select: {
|
||||
id: true,
|
||||
url: true,
|
||||
fileName: true,
|
||||
checksum: true
|
||||
}
|
||||
});
|
||||
|
||||
for (const attachment of attachments) {
|
||||
const data: Prisma.AttachmentUpdateInput = {};
|
||||
const url = encryptStringIfNeeded(attachment.url, dataEncryptionService);
|
||||
const fileName = encryptStringIfNeeded(attachment.fileName, dataEncryptionService);
|
||||
const checksum = encryptStringIfNeeded(attachment.checksum, dataEncryptionService);
|
||||
|
||||
assignRequiredEncryptedString(data, "url", url);
|
||||
assignOptionalEncryptedString(data, "fileName", fileName);
|
||||
assignOptionalEncryptedString(data, "checksum", checksum);
|
||||
|
||||
if (Object.keys(data).length === 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
await prisma.attachment.update({
|
||||
where: {
|
||||
id: attachment.id
|
||||
},
|
||||
data
|
||||
});
|
||||
counter.attachments += 1;
|
||||
}
|
||||
|
||||
const syncOperations = await prisma.syncOperation.findMany({
|
||||
select: {
|
||||
id: true,
|
||||
payload: true
|
||||
}
|
||||
});
|
||||
|
||||
for (const operation of syncOperations) {
|
||||
if (operation.payload === null) {
|
||||
continue;
|
||||
}
|
||||
|
||||
let nextPayload: string | null = null;
|
||||
if (typeof operation.payload === "string") {
|
||||
if (dataEncryptionService.isEncryptedString(operation.payload)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
nextPayload = dataEncryptionService.encryptString(operation.payload) ?? null;
|
||||
} else {
|
||||
nextPayload =
|
||||
dataEncryptionService.encryptString(JSON.stringify(operation.payload)) ?? null;
|
||||
}
|
||||
|
||||
if (nextPayload === null) {
|
||||
continue;
|
||||
}
|
||||
|
||||
await prisma.syncOperation.update({
|
||||
where: {
|
||||
id: operation.id
|
||||
},
|
||||
data: {
|
||||
payload: nextPayload
|
||||
}
|
||||
});
|
||||
counter.syncOperations += 1;
|
||||
}
|
||||
|
||||
console.log("敏感数据迁移完成");
|
||||
console.log(JSON.stringify(counter, null, 2));
|
||||
} finally {
|
||||
await prisma.$disconnect();
|
||||
}
|
||||
}
|
||||
|
||||
void main().catch((error: unknown) => {
|
||||
const message = error instanceof Error ? error.message : "未知错误";
|
||||
console.error(`敏感数据迁移失败:${message}`);
|
||||
process.exitCode = 1;
|
||||
});
|
||||
@@ -0,0 +1,28 @@
|
||||
import { Injectable } from "@nestjs/common";
|
||||
import { AiChannel } from "../../generated/prisma/client";
|
||||
import { AstrbotProvider } from "./providers/astrbot.provider";
|
||||
import { OpenAiCompatibleProvider } from "./providers/openai-compatible.provider";
|
||||
import { AiChannelExecutor } from "./ai.types";
|
||||
|
||||
@Injectable()
|
||||
export class AiProviderRegistryService {
|
||||
private readonly executors = new Map<AiChannel, AiChannelExecutor>();
|
||||
|
||||
constructor(
|
||||
openAiCompatibleProvider: OpenAiCompatibleProvider,
|
||||
astrbotProvider: AstrbotProvider
|
||||
) {
|
||||
this.executors.set(AiChannel.USER_KEY, openAiCompatibleProvider);
|
||||
this.executors.set(AiChannel.PUBLIC_POOL, openAiCompatibleProvider);
|
||||
this.executors.set(AiChannel.ASTRBOT, astrbotProvider);
|
||||
}
|
||||
|
||||
getExecutor(channel: AiChannel): AiChannelExecutor {
|
||||
const executor = this.executors.get(channel);
|
||||
if (!executor) {
|
||||
throw new Error(`未找到 ${channel} 对应的 AI 通道执行器`);
|
||||
}
|
||||
|
||||
return executor;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,123 @@
|
||||
import { Injectable } from "@nestjs/common";
|
||||
import { ConfigService } from "@nestjs/config";
|
||||
|
||||
type AiRateLimitBucket = {
|
||||
count: number;
|
||||
resetAt: number;
|
||||
};
|
||||
|
||||
export type AiRateLimitResult =
|
||||
| {
|
||||
allowed: true;
|
||||
}
|
||||
| {
|
||||
allowed: false;
|
||||
reason: "USER" | "IP";
|
||||
retryAfterMs: number;
|
||||
limit: number;
|
||||
windowMs: number;
|
||||
};
|
||||
|
||||
@Injectable()
|
||||
export class AiRateLimitService {
|
||||
private readonly userBuckets = new Map<string, AiRateLimitBucket>();
|
||||
private readonly ipBuckets = new Map<string, AiRateLimitBucket>();
|
||||
private readonly windowMs: number;
|
||||
private readonly userLimit: number;
|
||||
private readonly ipLimit: number;
|
||||
|
||||
constructor(private readonly configService: ConfigService) {
|
||||
this.windowMs = this.readPositiveInt("AI_RATE_LIMIT_WINDOW_MS", 60_000);
|
||||
this.userLimit = this.readPositiveInt("AI_RATE_LIMIT_USER_MAX", 20);
|
||||
this.ipLimit = this.readPositiveInt("AI_RATE_LIMIT_IP_MAX", 60);
|
||||
}
|
||||
|
||||
consume(userId: string, clientIp: string | null): AiRateLimitResult {
|
||||
const now = Date.now();
|
||||
const userBucket = this.getBucket(this.userBuckets, userId, now);
|
||||
if (userBucket.count >= this.userLimit) {
|
||||
return {
|
||||
allowed: false,
|
||||
reason: "USER",
|
||||
retryAfterMs: Math.max(0, userBucket.resetAt - now),
|
||||
limit: this.userLimit,
|
||||
windowMs: this.windowMs
|
||||
};
|
||||
}
|
||||
|
||||
const normalizedIp = this.normalizeIp(clientIp);
|
||||
const ipBucket = normalizedIp ? this.getBucket(this.ipBuckets, normalizedIp, now) : null;
|
||||
if (ipBucket && ipBucket.count >= this.ipLimit) {
|
||||
return {
|
||||
allowed: false,
|
||||
reason: "IP",
|
||||
retryAfterMs: Math.max(0, ipBucket.resetAt - now),
|
||||
limit: this.ipLimit,
|
||||
windowMs: this.windowMs
|
||||
};
|
||||
}
|
||||
|
||||
userBucket.count += 1;
|
||||
if (ipBucket) {
|
||||
ipBucket.count += 1;
|
||||
}
|
||||
|
||||
this.cleanupExpiredBuckets(this.userBuckets, now);
|
||||
this.cleanupExpiredBuckets(this.ipBuckets, now);
|
||||
|
||||
return {
|
||||
allowed: true
|
||||
};
|
||||
}
|
||||
|
||||
private getBucket(
|
||||
buckets: Map<string, AiRateLimitBucket>,
|
||||
key: string,
|
||||
now: number
|
||||
): AiRateLimitBucket {
|
||||
const currentBucket = buckets.get(key);
|
||||
if (!currentBucket || now >= currentBucket.resetAt) {
|
||||
const nextBucket: AiRateLimitBucket = {
|
||||
count: 0,
|
||||
resetAt: now + this.windowMs
|
||||
};
|
||||
buckets.set(key, nextBucket);
|
||||
return nextBucket;
|
||||
}
|
||||
|
||||
return currentBucket;
|
||||
}
|
||||
|
||||
private cleanupExpiredBuckets(buckets: Map<string, AiRateLimitBucket>, now: number): void {
|
||||
if (buckets.size <= 256) {
|
||||
return;
|
||||
}
|
||||
|
||||
for (const [key, bucket] of buckets.entries()) {
|
||||
if (now >= bucket.resetAt) {
|
||||
buckets.delete(key);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private normalizeIp(clientIp: string | null): string | null {
|
||||
if (!clientIp) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const normalizedIp = clientIp.trim();
|
||||
return normalizedIp.length > 0 ? normalizedIp : null;
|
||||
}
|
||||
|
||||
private readPositiveInt(key: string, fallbackValue: number): number {
|
||||
const rawValue = this.configService.get<string | number | undefined>(key);
|
||||
const parsedValue =
|
||||
typeof rawValue === "number" ? rawValue : Number.parseInt(String(rawValue ?? ""), 10);
|
||||
|
||||
if (!Number.isFinite(parsedValue) || parsedValue <= 0) {
|
||||
return fallbackValue;
|
||||
}
|
||||
|
||||
return parsedValue;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,74 @@
|
||||
import {
|
||||
Body,
|
||||
Controller,
|
||||
Get,
|
||||
Headers,
|
||||
Ip,
|
||||
Post,
|
||||
Query,
|
||||
UnauthorizedException
|
||||
} from "@nestjs/common";
|
||||
import { AiChatDto } from "./dto/ai-chat.dto";
|
||||
import { ListAiUsageLogsQueryDto } from "./dto/list-ai-usage-logs-query.dto";
|
||||
import { UpsertAiProviderBindingDto } from "./dto/upsert-ai-provider-binding.dto";
|
||||
import {
|
||||
AiChatResponse,
|
||||
AiService,
|
||||
ListAiBindingsResponse,
|
||||
ListAiUsageLogsResponse,
|
||||
TestAiBindingResponse
|
||||
} from "./ai.service";
|
||||
|
||||
@Controller("ai")
|
||||
export class AiController {
|
||||
constructor(private readonly aiService: AiService) {}
|
||||
|
||||
@Get("bindings")
|
||||
async listBindings(
|
||||
@Headers("x-user-id") userIdHeader: string | string[] | undefined
|
||||
): Promise<ListAiBindingsResponse> {
|
||||
return this.aiService.listBindings(this.resolveUserId(userIdHeader));
|
||||
}
|
||||
|
||||
@Get("usage-logs")
|
||||
async listUsageLogs(
|
||||
@Headers("x-user-id") userIdHeader: string | string[] | undefined,
|
||||
@Query() query: ListAiUsageLogsQueryDto
|
||||
): Promise<ListAiUsageLogsResponse> {
|
||||
return this.aiService.listUsageLogs(this.resolveUserId(userIdHeader), query);
|
||||
}
|
||||
|
||||
@Post("bindings")
|
||||
async upsertBinding(
|
||||
@Headers("x-user-id") userIdHeader: string | string[] | undefined,
|
||||
@Body() body: UpsertAiProviderBindingDto
|
||||
) {
|
||||
return this.aiService.upsertBinding(this.resolveUserId(userIdHeader), body);
|
||||
}
|
||||
|
||||
@Post("bindings/test")
|
||||
async testBinding(
|
||||
@Headers("x-user-id") userIdHeader: string | string[] | undefined,
|
||||
@Body() body: UpsertAiProviderBindingDto
|
||||
): Promise<TestAiBindingResponse> {
|
||||
return this.aiService.testBinding(this.resolveUserId(userIdHeader), body);
|
||||
}
|
||||
|
||||
@Post("chat")
|
||||
async chat(
|
||||
@Headers("x-user-id") userIdHeader: string | string[] | undefined,
|
||||
@Ip() clientIp: string,
|
||||
@Body() body: AiChatDto
|
||||
): Promise<AiChatResponse> {
|
||||
return this.aiService.chat(this.resolveUserId(userIdHeader), body, clientIp);
|
||||
}
|
||||
|
||||
private resolveUserId(userIdHeader: string | string[] | undefined): string {
|
||||
const userId = Array.isArray(userIdHeader) ? userIdHeader[0] : userIdHeader;
|
||||
if (!userId) {
|
||||
throw new UnauthorizedException("缺少用户上下文");
|
||||
}
|
||||
|
||||
return userId;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,21 @@
|
||||
import { Module } from "@nestjs/common";
|
||||
import { PrismaModule } from "../prisma/prisma.module";
|
||||
import { AiRateLimitService } from "./ai-rate-limit.service";
|
||||
import { AiController } from "./ai.controller";
|
||||
import { AiProviderRegistryService } from "./ai-provider-registry.service";
|
||||
import { AiService } from "./ai.service";
|
||||
import { AstrbotProvider } from "./providers/astrbot.provider";
|
||||
import { OpenAiCompatibleProvider } from "./providers/openai-compatible.provider";
|
||||
|
||||
@Module({
|
||||
imports: [PrismaModule],
|
||||
controllers: [AiController],
|
||||
providers: [
|
||||
AiService,
|
||||
AiRateLimitService,
|
||||
AiProviderRegistryService,
|
||||
OpenAiCompatibleProvider,
|
||||
AstrbotProvider
|
||||
]
|
||||
})
|
||||
export class AiModule {}
|
||||
@@ -0,0 +1,988 @@
|
||||
import {
|
||||
BadGatewayException,
|
||||
BadRequestException,
|
||||
HttpException,
|
||||
HttpStatus,
|
||||
Injectable,
|
||||
Logger
|
||||
} from "@nestjs/common";
|
||||
import {
|
||||
AiChannel,
|
||||
AiUsageLog,
|
||||
AiProviderBinding,
|
||||
AiPublicPoolConfig,
|
||||
Prisma,
|
||||
TaskPriority,
|
||||
TaskStatus
|
||||
} from "../../generated/prisma/client";
|
||||
import { PrismaService } from "../prisma/prisma.service";
|
||||
import { DataEncryptionService } from "../security/data-encryption.service";
|
||||
import { AiRateLimitService } from "./ai-rate-limit.service";
|
||||
import { AiProviderRegistryService } from "./ai-provider-registry.service";
|
||||
import { AiChatDto } from "./dto/ai-chat.dto";
|
||||
import { ListAiUsageLogsQueryDto } from "./dto/list-ai-usage-logs-query.dto";
|
||||
import { UpsertAiProviderBindingDto } from "./dto/upsert-ai-provider-binding.dto";
|
||||
import {
|
||||
AiResolvedRouteCandidate,
|
||||
AiRouteAttempt,
|
||||
AiRouteFailureError,
|
||||
AiUsageMetrics
|
||||
} from "./ai.types";
|
||||
|
||||
type AiBindingSummary = {
|
||||
id: string;
|
||||
channel: AiChannel;
|
||||
providerName: string;
|
||||
model: string | null;
|
||||
configId: string | null;
|
||||
configName: string | null;
|
||||
endpoint: string | null;
|
||||
isEnabled: boolean;
|
||||
hasApiKey: boolean;
|
||||
maskedApiKey: string | null;
|
||||
updatedAt: string;
|
||||
};
|
||||
|
||||
type AiRoutePlanEntry =
|
||||
| {
|
||||
kind: "candidate";
|
||||
candidate: AiResolvedRouteCandidate;
|
||||
}
|
||||
| {
|
||||
kind: "skip";
|
||||
attempt: AiRouteAttempt;
|
||||
};
|
||||
|
||||
export type ListAiBindingsResponse = {
|
||||
routeOrder: AiChannel[];
|
||||
bindings: AiBindingSummary[];
|
||||
publicPool: {
|
||||
enabled: boolean;
|
||||
providerName: string | null;
|
||||
model: string | null;
|
||||
hasApiKey: boolean;
|
||||
} | null;
|
||||
};
|
||||
|
||||
type AiUsageLogSummary = {
|
||||
id: string;
|
||||
channel: AiChannel;
|
||||
providerName: string | null;
|
||||
model: string | null;
|
||||
promptTokens: number;
|
||||
completionTokens: number;
|
||||
totalTokens: number;
|
||||
latencyMs: number | null;
|
||||
success: boolean;
|
||||
errorCode: string | null;
|
||||
createdAt: string;
|
||||
};
|
||||
|
||||
type AiContextTaskItem = {
|
||||
id: string;
|
||||
title: string;
|
||||
priority: TaskPriority;
|
||||
status: TaskStatus;
|
||||
ddl: Date | null;
|
||||
contentText: string | null;
|
||||
updatedAt: Date;
|
||||
};
|
||||
|
||||
export type ListAiUsageLogsResponse = {
|
||||
items: AiUsageLogSummary[];
|
||||
page: number;
|
||||
pageSize: number;
|
||||
total: number;
|
||||
};
|
||||
|
||||
export type AiChatResponse = {
|
||||
channel: AiChannel;
|
||||
providerName: string;
|
||||
model: string | null;
|
||||
content: string;
|
||||
sessionId: string | null;
|
||||
attempts: AiRouteAttempt[];
|
||||
};
|
||||
|
||||
export type TestAiBindingResponse =
|
||||
| {
|
||||
success: true;
|
||||
channel: AiChannel;
|
||||
providerName: string;
|
||||
model: string | null;
|
||||
contentPreview: string;
|
||||
}
|
||||
| {
|
||||
success: false;
|
||||
channel: AiChannel;
|
||||
providerName: string;
|
||||
model: string | null;
|
||||
code: string;
|
||||
message: string;
|
||||
};
|
||||
|
||||
@Injectable()
|
||||
export class AiService {
|
||||
private readonly logger = new Logger(AiService.name);
|
||||
private readonly maxContextTasks = 6;
|
||||
private readonly maxContextContentLength = 80;
|
||||
|
||||
constructor(
|
||||
private readonly prismaService: PrismaService,
|
||||
private readonly aiProviderRegistryService: AiProviderRegistryService,
|
||||
private readonly dataEncryptionService: DataEncryptionService,
|
||||
private readonly aiRateLimitService: AiRateLimitService
|
||||
) {}
|
||||
|
||||
async listBindings(userId: string): Promise<ListAiBindingsResponse> {
|
||||
const [bindings, publicPool] = await Promise.all([
|
||||
this.prismaService.aiProviderBinding.findMany({
|
||||
where: {
|
||||
userId
|
||||
},
|
||||
orderBy: [{ updatedAt: "desc" }]
|
||||
}),
|
||||
this.prismaService.aiPublicPoolConfig.findFirst({
|
||||
orderBy: {
|
||||
updatedAt: "desc"
|
||||
}
|
||||
})
|
||||
]);
|
||||
|
||||
const latestBindings = this.pickLatestBindingsByChannel(bindings);
|
||||
|
||||
return {
|
||||
routeOrder: [AiChannel.USER_KEY, AiChannel.ASTRBOT, AiChannel.PUBLIC_POOL],
|
||||
bindings: latestBindings.map((binding) => this.serializeBinding(binding)),
|
||||
publicPool: publicPool
|
||||
? {
|
||||
enabled: publicPool.enabled,
|
||||
providerName: this.readDecryptedString(publicPool.providerName),
|
||||
model: this.readDecryptedString(publicPool.model),
|
||||
hasApiKey: Boolean(publicPool.encryptedApiKey)
|
||||
}
|
||||
: null
|
||||
};
|
||||
}
|
||||
|
||||
async listUsageLogs(
|
||||
userId: string,
|
||||
query: ListAiUsageLogsQueryDto
|
||||
): Promise<ListAiUsageLogsResponse> {
|
||||
const page = query.page ?? 1;
|
||||
const pageSize = query.pageSize ?? 20;
|
||||
const skip = (page - 1) * pageSize;
|
||||
const where: Prisma.AiUsageLogWhereInput = {
|
||||
userId
|
||||
};
|
||||
|
||||
if (query.channel) {
|
||||
where.channel = query.channel;
|
||||
}
|
||||
|
||||
if (query.success !== undefined) {
|
||||
where.success = query.success;
|
||||
}
|
||||
|
||||
const [items, total] = await Promise.all([
|
||||
this.prismaService.aiUsageLog.findMany({
|
||||
where,
|
||||
orderBy: {
|
||||
createdAt: "desc"
|
||||
},
|
||||
skip,
|
||||
take: pageSize
|
||||
}),
|
||||
this.prismaService.aiUsageLog.count({
|
||||
where
|
||||
})
|
||||
]);
|
||||
|
||||
return {
|
||||
items: items.map((item) => this.serializeUsageLog(item)),
|
||||
page,
|
||||
pageSize,
|
||||
total
|
||||
};
|
||||
}
|
||||
|
||||
async upsertBinding(userId: string, dto: UpsertAiProviderBindingDto): Promise<AiBindingSummary> {
|
||||
if (dto.channel === AiChannel.PUBLIC_POOL) {
|
||||
throw new BadRequestException("公共 AI 通道只能由管理员配置");
|
||||
}
|
||||
|
||||
this.validateBindingInput(dto);
|
||||
|
||||
const result = await this.prismaService.$transaction(async (tx) => {
|
||||
const existingBinding = await tx.aiProviderBinding.findFirst({
|
||||
where: {
|
||||
userId,
|
||||
channel: dto.channel
|
||||
},
|
||||
orderBy: {
|
||||
updatedAt: "desc"
|
||||
}
|
||||
});
|
||||
|
||||
if (!existingBinding) {
|
||||
return tx.aiProviderBinding.create({
|
||||
data: {
|
||||
userId,
|
||||
channel: dto.channel,
|
||||
providerName: this.encryptRequiredString(this.normalizeProviderName(dto.providerName)),
|
||||
model: this.encryptOptionalString(dto.model),
|
||||
configId: this.encryptOptionalString(dto.configId),
|
||||
configName: this.encryptOptionalString(dto.configName),
|
||||
endpoint: this.encryptOptionalString(dto.endpoint),
|
||||
encryptedApiKey: this.encryptOptionalString(dto.apiKey),
|
||||
isEnabled: dto.isEnabled ?? true
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
const updateData: Prisma.AiProviderBindingUpdateInput = {
|
||||
channel: dto.channel,
|
||||
providerName: this.encryptRequiredString(this.normalizeProviderName(dto.providerName)),
|
||||
model: this.encryptOptionalString(dto.model),
|
||||
configId: this.encryptOptionalString(dto.configId),
|
||||
configName: this.encryptOptionalString(dto.configName),
|
||||
isEnabled: dto.isEnabled ?? existingBinding.isEnabled
|
||||
};
|
||||
|
||||
if (dto.endpoint !== undefined) {
|
||||
updateData.endpoint = this.encryptOptionalString(dto.endpoint);
|
||||
}
|
||||
|
||||
if (dto.apiKey !== undefined) {
|
||||
updateData.encryptedApiKey = this.encryptOptionalString(dto.apiKey);
|
||||
}
|
||||
|
||||
return tx.aiProviderBinding.update({
|
||||
where: {
|
||||
id: existingBinding.id
|
||||
},
|
||||
data: updateData
|
||||
});
|
||||
});
|
||||
|
||||
return this.serializeBinding(result);
|
||||
}
|
||||
|
||||
async testBinding(
|
||||
userId: string,
|
||||
dto: UpsertAiProviderBindingDto
|
||||
): Promise<TestAiBindingResponse> {
|
||||
if (dto.channel === AiChannel.PUBLIC_POOL) {
|
||||
throw new BadRequestException("公共 AI 通道不能由用户自行测试");
|
||||
}
|
||||
|
||||
const candidate = await this.buildTestCandidate(userId, dto);
|
||||
const executor = this.aiProviderRegistryService.getExecutor(candidate.channel);
|
||||
|
||||
try {
|
||||
const result = await executor.execute(candidate, {
|
||||
userId,
|
||||
message: "请只回复“连接成功”,不要添加其他内容。",
|
||||
sessionId: null
|
||||
});
|
||||
|
||||
return {
|
||||
success: true,
|
||||
channel: result.channel,
|
||||
providerName: result.providerName,
|
||||
model: result.model,
|
||||
contentPreview: this.limitPreviewText(result.content)
|
||||
};
|
||||
} catch (error) {
|
||||
if (error instanceof AiRouteFailureError) {
|
||||
return {
|
||||
success: false,
|
||||
channel: error.channel,
|
||||
providerName: error.providerName,
|
||||
model: candidate.model,
|
||||
code: error.code,
|
||||
message: error.message
|
||||
};
|
||||
}
|
||||
|
||||
if (error instanceof Error) {
|
||||
return {
|
||||
success: false,
|
||||
channel: candidate.channel,
|
||||
providerName: candidate.providerName,
|
||||
model: candidate.model,
|
||||
code: "UNKNOWN_ERROR",
|
||||
message: error.message
|
||||
};
|
||||
}
|
||||
|
||||
return {
|
||||
success: false,
|
||||
channel: candidate.channel,
|
||||
providerName: candidate.providerName,
|
||||
model: candidate.model,
|
||||
code: "UNKNOWN_ERROR",
|
||||
message: "未知错误"
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
async chat(
|
||||
userId: string,
|
||||
dto: AiChatDto,
|
||||
clientIp: string | null = null
|
||||
): Promise<AiChatResponse> {
|
||||
const rateLimitResult = this.aiRateLimitService.consume(userId, clientIp);
|
||||
if (!rateLimitResult.allowed) {
|
||||
throw new HttpException(
|
||||
{
|
||||
message: "AI 请求过于频繁,请稍后再试",
|
||||
code: "AI_RATE_LIMITED",
|
||||
dimension: rateLimitResult.reason === "USER" ? "user" : "ip",
|
||||
retryAfterMs: rateLimitResult.retryAfterMs,
|
||||
limit: rateLimitResult.limit,
|
||||
windowMs: rateLimitResult.windowMs
|
||||
},
|
||||
HttpStatus.TOO_MANY_REQUESTS
|
||||
);
|
||||
}
|
||||
|
||||
const attempts: AiRouteAttempt[] = [];
|
||||
const plan = await this.buildRoutePlan(userId, dto.channel ?? null);
|
||||
const promptMessage = await this.buildPromptMessage(userId, dto.message, dto.localTasks ?? []);
|
||||
|
||||
for (const entry of plan) {
|
||||
if (entry.kind === "skip") {
|
||||
attempts.push(entry.attempt);
|
||||
continue;
|
||||
}
|
||||
|
||||
const executor = this.aiProviderRegistryService.getExecutor(entry.candidate.channel);
|
||||
const startedAt = Date.now();
|
||||
|
||||
try {
|
||||
const result = await executor.execute(entry.candidate, {
|
||||
userId,
|
||||
message: promptMessage,
|
||||
sessionId: dto.sessionId ?? null
|
||||
});
|
||||
const latencyMs = Date.now() - startedAt;
|
||||
|
||||
attempts.push({
|
||||
channel: result.channel,
|
||||
providerName: result.providerName,
|
||||
model: result.model,
|
||||
status: "success",
|
||||
reasonCode: null,
|
||||
reasonMessage: null
|
||||
});
|
||||
await this.recordUsageLog({
|
||||
userId,
|
||||
channel: result.channel,
|
||||
providerName: result.providerName,
|
||||
model: result.model,
|
||||
usage: result.usage,
|
||||
latencyMs,
|
||||
success: true,
|
||||
errorCode: null
|
||||
});
|
||||
|
||||
return {
|
||||
channel: result.channel,
|
||||
providerName: result.providerName,
|
||||
model: result.model,
|
||||
content: result.content,
|
||||
sessionId: result.sessionId,
|
||||
attempts
|
||||
};
|
||||
} catch (error) {
|
||||
const latencyMs = Date.now() - startedAt;
|
||||
const failureAttempt = this.toFailureAttempt(entry.candidate, error);
|
||||
attempts.push(failureAttempt);
|
||||
await this.recordUsageLog({
|
||||
userId,
|
||||
channel: failureAttempt.channel,
|
||||
providerName: failureAttempt.providerName,
|
||||
model: failureAttempt.model,
|
||||
usage: null,
|
||||
latencyMs,
|
||||
success: false,
|
||||
errorCode: failureAttempt.reasonCode
|
||||
});
|
||||
this.logger.warn(
|
||||
`AI 通道降级:channel=${failureAttempt.channel} provider=${failureAttempt.providerName ?? "unknown"} code=${failureAttempt.reasonCode ?? "UNKNOWN"} message=${failureAttempt.reasonMessage ?? "unknown"}`
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
throw new BadGatewayException({
|
||||
message: "当前没有可用的 AI 通道,请稍后重试",
|
||||
attempts
|
||||
});
|
||||
}
|
||||
|
||||
private async buildRoutePlan(
|
||||
userId: string,
|
||||
selectedChannel: AiChannel | null
|
||||
): Promise<AiRoutePlanEntry[]> {
|
||||
const plan: AiRoutePlanEntry[] = [];
|
||||
const targetChannels = selectedChannel
|
||||
? [selectedChannel]
|
||||
: [AiChannel.USER_KEY, AiChannel.ASTRBOT, AiChannel.PUBLIC_POOL];
|
||||
|
||||
for (const channel of targetChannels) {
|
||||
if (channel === AiChannel.PUBLIC_POOL) {
|
||||
const publicPool = await this.findEnabledPublicPool();
|
||||
if (publicPool) {
|
||||
plan.push({
|
||||
kind: "candidate",
|
||||
candidate: this.toPublicPoolCandidate(publicPool)
|
||||
});
|
||||
} else {
|
||||
plan.push({
|
||||
kind: "skip",
|
||||
attempt: {
|
||||
channel: AiChannel.PUBLIC_POOL,
|
||||
providerName: null,
|
||||
model: null,
|
||||
status: "skipped",
|
||||
reasonCode: "PUBLIC_POOL_DISABLED",
|
||||
reasonMessage: "公共 AI 通道未开启"
|
||||
}
|
||||
});
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
const binding = await this.findPreferredBinding(userId, channel);
|
||||
if (binding) {
|
||||
plan.push({
|
||||
kind: "candidate",
|
||||
candidate: this.toBindingCandidate(binding)
|
||||
});
|
||||
continue;
|
||||
}
|
||||
|
||||
plan.push({
|
||||
kind: "skip",
|
||||
attempt: {
|
||||
channel,
|
||||
providerName: null,
|
||||
model: null,
|
||||
status: "skipped",
|
||||
reasonCode: "CHANNEL_NOT_CONFIGURED",
|
||||
reasonMessage:
|
||||
channel === AiChannel.USER_KEY
|
||||
? "当前用户未配置可用的自备 Key 通道"
|
||||
: "当前用户未配置可用的 AstrBot 通道"
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
return plan;
|
||||
}
|
||||
|
||||
private async findPreferredBinding(
|
||||
userId: string,
|
||||
channel: AiChannel
|
||||
): Promise<AiProviderBinding | null> {
|
||||
return this.prismaService.aiProviderBinding.findFirst({
|
||||
where: {
|
||||
userId,
|
||||
channel,
|
||||
isEnabled: true
|
||||
},
|
||||
orderBy: {
|
||||
updatedAt: "desc"
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
private async findEnabledPublicPool(): Promise<AiPublicPoolConfig | null> {
|
||||
return this.prismaService.aiPublicPoolConfig.findFirst({
|
||||
where: {
|
||||
enabled: true
|
||||
},
|
||||
orderBy: {
|
||||
updatedAt: "desc"
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
private async buildTestCandidate(
|
||||
userId: string,
|
||||
dto: UpsertAiProviderBindingDto
|
||||
): Promise<AiResolvedRouteCandidate> {
|
||||
const existingBinding = await this.prismaService.aiProviderBinding.findFirst({
|
||||
where: {
|
||||
userId,
|
||||
channel: dto.channel
|
||||
},
|
||||
orderBy: {
|
||||
updatedAt: "desc"
|
||||
}
|
||||
});
|
||||
|
||||
const mergedDto: UpsertAiProviderBindingDto = {
|
||||
channel: dto.channel,
|
||||
providerName:
|
||||
dto.providerName ?? this.readDecryptedString(existingBinding?.providerName ?? null) ?? "",
|
||||
model: dto.model ?? this.readDecryptedString(existingBinding?.model ?? null) ?? undefined,
|
||||
configId:
|
||||
dto.configId ?? this.readDecryptedString(existingBinding?.configId ?? null) ?? undefined,
|
||||
configName:
|
||||
dto.configName ??
|
||||
this.readDecryptedString(existingBinding?.configName ?? null) ??
|
||||
undefined,
|
||||
endpoint:
|
||||
dto.endpoint ?? this.readDecryptedString(existingBinding?.endpoint ?? null) ?? undefined,
|
||||
apiKey:
|
||||
dto.apiKey ??
|
||||
this.readDecryptedString(existingBinding?.encryptedApiKey ?? null) ??
|
||||
undefined,
|
||||
isEnabled: dto.isEnabled ?? existingBinding?.isEnabled ?? true
|
||||
};
|
||||
|
||||
this.validateBindingInput(mergedDto);
|
||||
|
||||
return {
|
||||
channel: mergedDto.channel,
|
||||
source: existingBinding ? "binding" : "binding",
|
||||
sourceId: existingBinding?.id ?? null,
|
||||
providerName: this.normalizeProviderName(mergedDto.providerName),
|
||||
model: this.normalizeOptionalString(mergedDto.model),
|
||||
configId: this.normalizeOptionalString(mergedDto.configId),
|
||||
configName: this.normalizeOptionalString(mergedDto.configName),
|
||||
endpoint: this.normalizeOptionalString(mergedDto.endpoint),
|
||||
apiKey: this.normalizeOptionalString(mergedDto.apiKey)
|
||||
};
|
||||
}
|
||||
|
||||
private toBindingCandidate(binding: AiProviderBinding): AiResolvedRouteCandidate {
|
||||
return {
|
||||
channel: binding.channel,
|
||||
source: "binding",
|
||||
sourceId: binding.id,
|
||||
providerName: this.readDecryptedString(binding.providerName) ?? "",
|
||||
model: this.readDecryptedString(binding.model),
|
||||
configId: this.readDecryptedString(binding.configId),
|
||||
configName: this.readDecryptedString(binding.configName),
|
||||
endpoint: this.readDecryptedString(binding.endpoint),
|
||||
apiKey: this.readDecryptedString(binding.encryptedApiKey)
|
||||
};
|
||||
}
|
||||
|
||||
private toPublicPoolCandidate(publicPool: AiPublicPoolConfig): AiResolvedRouteCandidate {
|
||||
return {
|
||||
channel: AiChannel.PUBLIC_POOL,
|
||||
source: "public_pool",
|
||||
sourceId: publicPool.id,
|
||||
providerName: this.readDecryptedString(publicPool.providerName) ?? "public-pool",
|
||||
model: this.readDecryptedString(publicPool.model),
|
||||
configId: null,
|
||||
configName: null,
|
||||
endpoint: this.readDecryptedString(publicPool.endpoint),
|
||||
apiKey: this.readDecryptedString(publicPool.encryptedApiKey)
|
||||
};
|
||||
}
|
||||
|
||||
private serializeBinding(binding: AiProviderBinding): AiBindingSummary {
|
||||
const decryptedProviderName = this.readDecryptedString(binding.providerName) ?? "";
|
||||
const decryptedModel = this.readDecryptedString(binding.model);
|
||||
const decryptedConfigId = this.readDecryptedString(binding.configId);
|
||||
const decryptedConfigName = this.readDecryptedString(binding.configName);
|
||||
const decryptedEndpoint = this.readDecryptedString(binding.endpoint);
|
||||
const decryptedApiKey = this.readDecryptedString(binding.encryptedApiKey);
|
||||
|
||||
return {
|
||||
id: binding.id,
|
||||
channel: binding.channel,
|
||||
providerName: decryptedProviderName,
|
||||
model: decryptedModel,
|
||||
configId: decryptedConfigId,
|
||||
configName: decryptedConfigName,
|
||||
endpoint: decryptedEndpoint,
|
||||
isEnabled: binding.isEnabled,
|
||||
hasApiKey: Boolean(binding.encryptedApiKey),
|
||||
maskedApiKey: this.maskSecret(decryptedApiKey),
|
||||
updatedAt: binding.updatedAt.toISOString()
|
||||
};
|
||||
}
|
||||
|
||||
private pickLatestBindingsByChannel(bindings: AiProviderBinding[]): AiProviderBinding[] {
|
||||
const bindingMap = new Map<AiChannel, AiProviderBinding>();
|
||||
|
||||
for (const binding of bindings) {
|
||||
if (!bindingMap.has(binding.channel)) {
|
||||
bindingMap.set(binding.channel, binding);
|
||||
}
|
||||
}
|
||||
|
||||
return [AiChannel.USER_KEY, AiChannel.ASTRBOT]
|
||||
.map((channel) => bindingMap.get(channel) ?? null)
|
||||
.filter((binding): binding is AiProviderBinding => binding !== null);
|
||||
}
|
||||
|
||||
private serializeUsageLog(log: AiUsageLog): AiUsageLogSummary {
|
||||
return {
|
||||
id: log.id,
|
||||
channel: log.channel,
|
||||
providerName: this.readDecryptedString(log.providerName),
|
||||
model: this.readDecryptedString(log.model),
|
||||
promptTokens: log.promptTokens,
|
||||
completionTokens: log.completionTokens,
|
||||
totalTokens: log.totalTokens,
|
||||
latencyMs: log.latencyMs,
|
||||
success: log.success,
|
||||
errorCode: log.errorCode,
|
||||
createdAt: log.createdAt.toISOString()
|
||||
};
|
||||
}
|
||||
|
||||
private async buildPromptMessage(
|
||||
userId: string,
|
||||
userMessage: string,
|
||||
localTasks: NonNullable<AiChatDto["localTasks"]>
|
||||
): Promise<string> {
|
||||
const taskSummary = await this.buildTaskContextSummary(userId, localTasks);
|
||||
if (!taskSummary) {
|
||||
return userMessage;
|
||||
}
|
||||
|
||||
return [
|
||||
"你是 TodoList 的 AI 助手,需要结合用户当前待办提供任务统筹建议。",
|
||||
"以下是系统整理的未完成任务摘要:",
|
||||
taskSummary,
|
||||
"请优先根据这些任务的紧急度、截止时间和执行顺序回答,并给出明确可执行的建议。",
|
||||
`用户当前问题:${userMessage}`
|
||||
].join("\n\n");
|
||||
}
|
||||
|
||||
private async buildTaskContextSummary(
|
||||
userId: string,
|
||||
localTasks: NonNullable<AiChatDto["localTasks"]>
|
||||
): Promise<string | null> {
|
||||
const tasks = await this.prismaService.task.findMany({
|
||||
where: {
|
||||
userId,
|
||||
status: {
|
||||
in: [TaskStatus.TODO, TaskStatus.IN_PROGRESS]
|
||||
}
|
||||
},
|
||||
select: {
|
||||
id: true,
|
||||
title: true,
|
||||
priority: true,
|
||||
status: true,
|
||||
ddl: true,
|
||||
contentText: true,
|
||||
updatedAt: true
|
||||
},
|
||||
take: 20
|
||||
});
|
||||
|
||||
const sortedTasks = this.sortContextTasks(this.mergeContextTasks(tasks, localTasks));
|
||||
if (sortedTasks.length === 0) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const visibleTasks = sortedTasks.slice(0, this.maxContextTasks);
|
||||
const lines = visibleTasks.map((task, index) => {
|
||||
const parts = [
|
||||
`${index + 1}. ${task.title}`,
|
||||
`优先级:${this.getPriorityLabel(task.priority)}`,
|
||||
`状态:${this.getStatusLabel(task.status)}`,
|
||||
`DDL:${task.ddl ? task.ddl.toISOString() : "未设置"}`
|
||||
];
|
||||
|
||||
const contentSnippet = this.getContentSnippet(task.contentText);
|
||||
if (contentSnippet) {
|
||||
parts.push(`内容摘要:${contentSnippet}`);
|
||||
}
|
||||
|
||||
return parts.join(" | ");
|
||||
});
|
||||
|
||||
const omittedCount = sortedTasks.length - visibleTasks.length;
|
||||
if (omittedCount > 0) {
|
||||
lines.push(`另有 ${omittedCount} 条任务已省略。`);
|
||||
}
|
||||
|
||||
return [`共 ${sortedTasks.length} 条未完成任务。`, ...lines].join("\n");
|
||||
}
|
||||
|
||||
private mergeContextTasks(
|
||||
databaseTasks: Array<{
|
||||
id: string;
|
||||
title: string;
|
||||
priority: TaskPriority;
|
||||
status: TaskStatus;
|
||||
ddl: Date | null;
|
||||
contentText: string | null;
|
||||
updatedAt: Date;
|
||||
}>,
|
||||
localTasks: NonNullable<AiChatDto["localTasks"]>
|
||||
): AiContextTaskItem[] {
|
||||
const taskMap = new Map<string, AiContextTaskItem>();
|
||||
|
||||
for (const task of databaseTasks) {
|
||||
taskMap.set(task.id, {
|
||||
id: task.id,
|
||||
title: this.readDecryptedString(task.title) ?? "未命名任务",
|
||||
priority: task.priority,
|
||||
status: task.status,
|
||||
ddl: task.ddl,
|
||||
contentText: this.readDecryptedString(task.contentText),
|
||||
updatedAt: task.updatedAt
|
||||
});
|
||||
}
|
||||
|
||||
for (const task of localTasks) {
|
||||
if (task.status !== TaskStatus.TODO && task.status !== TaskStatus.IN_PROGRESS) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const currentTask = taskMap.get(task.id);
|
||||
const nextTask: AiContextTaskItem = {
|
||||
id: task.id,
|
||||
title: task.title.trim().length > 0 ? task.title.trim() : "未命名任务",
|
||||
priority: task.priority,
|
||||
status: task.status,
|
||||
ddl: typeof task.ddlAt === "number" ? new Date(task.ddlAt) : null,
|
||||
contentText:
|
||||
typeof task.contentText === "string" && task.contentText.trim().length > 0
|
||||
? task.contentText
|
||||
: null,
|
||||
updatedAt: new Date(task.updatedAt)
|
||||
};
|
||||
|
||||
if (!currentTask || nextTask.updatedAt.getTime() >= currentTask.updatedAt.getTime()) {
|
||||
taskMap.set(task.id, nextTask);
|
||||
}
|
||||
}
|
||||
|
||||
return [...taskMap.values()].filter(
|
||||
(task) => task.status === TaskStatus.TODO || task.status === TaskStatus.IN_PROGRESS
|
||||
);
|
||||
}
|
||||
|
||||
private sortContextTasks(tasks: AiContextTaskItem[]): AiContextTaskItem[] {
|
||||
return [...tasks].sort((left, right) => {
|
||||
const priorityDiff =
|
||||
this.getPriorityWeight(right.priority) - this.getPriorityWeight(left.priority);
|
||||
if (priorityDiff !== 0) {
|
||||
return priorityDiff;
|
||||
}
|
||||
|
||||
const leftDdl = left.ddl?.getTime() ?? Number.POSITIVE_INFINITY;
|
||||
const rightDdl = right.ddl?.getTime() ?? Number.POSITIVE_INFINITY;
|
||||
if (leftDdl !== rightDdl) {
|
||||
return leftDdl - rightDdl;
|
||||
}
|
||||
|
||||
return right.updatedAt.getTime() - left.updatedAt.getTime();
|
||||
});
|
||||
}
|
||||
|
||||
private toFailureAttempt(candidate: AiResolvedRouteCandidate, error: unknown): AiRouteAttempt {
|
||||
if (error instanceof AiRouteFailureError) {
|
||||
return {
|
||||
channel: error.channel,
|
||||
providerName: error.providerName,
|
||||
model: candidate.model,
|
||||
status: "failed",
|
||||
reasonCode: error.code,
|
||||
reasonMessage: error.message
|
||||
};
|
||||
}
|
||||
|
||||
if (error instanceof Error) {
|
||||
return {
|
||||
channel: candidate.channel,
|
||||
providerName: candidate.providerName,
|
||||
model: candidate.model,
|
||||
status: "failed",
|
||||
reasonCode: "UNKNOWN_ERROR",
|
||||
reasonMessage: error.message
|
||||
};
|
||||
}
|
||||
|
||||
return {
|
||||
channel: candidate.channel,
|
||||
providerName: candidate.providerName,
|
||||
model: candidate.model,
|
||||
status: "failed",
|
||||
reasonCode: "UNKNOWN_ERROR",
|
||||
reasonMessage: "未知错误"
|
||||
};
|
||||
}
|
||||
|
||||
private normalizeOptionalString(value: string | undefined): string | null {
|
||||
if (value === undefined) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const normalizedValue = value.trim();
|
||||
return normalizedValue.length > 0 ? normalizedValue : null;
|
||||
}
|
||||
|
||||
private normalizeProviderName(value: string | undefined): string {
|
||||
return this.normalizeOptionalString(value) ?? "";
|
||||
}
|
||||
|
||||
private encryptOptionalString(value: string | undefined): string | null | undefined {
|
||||
const normalizedValue = this.normalizeOptionalString(value);
|
||||
return this.dataEncryptionService.encryptString(normalizedValue);
|
||||
}
|
||||
|
||||
private encryptRequiredString(value: string): string {
|
||||
const encryptedValue = this.dataEncryptionService.encryptString(value);
|
||||
if (!encryptedValue) {
|
||||
throw new BadRequestException("敏感配置加密失败");
|
||||
}
|
||||
|
||||
return encryptedValue;
|
||||
}
|
||||
|
||||
private readDecryptedString(value: string | null): string | null {
|
||||
const decryptedValue = this.dataEncryptionService.decryptString(value);
|
||||
return typeof decryptedValue === "string" ? decryptedValue : null;
|
||||
}
|
||||
|
||||
private validateBindingInput(dto: UpsertAiProviderBindingDto): void {
|
||||
const providerName = this.normalizeOptionalString(dto.providerName);
|
||||
const configId = this.normalizeOptionalString(dto.configId);
|
||||
const configName = this.normalizeOptionalString(dto.configName);
|
||||
|
||||
if (dto.channel === AiChannel.ASTRBOT) {
|
||||
if (!providerName && !configId && !configName) {
|
||||
throw new BadRequestException(
|
||||
"AstrBot 通道至少需要 providerName、configId、configName 三者之一"
|
||||
);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if (!providerName) {
|
||||
throw new BadRequestException("当前通道必须提供 providerName");
|
||||
}
|
||||
}
|
||||
|
||||
private maskSecret(secret: string | null): string | null {
|
||||
if (!secret) {
|
||||
return null;
|
||||
}
|
||||
|
||||
if (secret.length <= 6) {
|
||||
return "*".repeat(secret.length);
|
||||
}
|
||||
|
||||
return `${secret.slice(0, 4)}***${secret.slice(-2)}`;
|
||||
}
|
||||
|
||||
private limitPreviewText(content: string): string {
|
||||
const normalizedContent = content.replace(/\s+/g, " ").trim();
|
||||
if (normalizedContent.length <= 60) {
|
||||
return normalizedContent;
|
||||
}
|
||||
|
||||
return `${normalizedContent.slice(0, 60)}...`;
|
||||
}
|
||||
|
||||
private getPriorityWeight(priority: TaskPriority): number {
|
||||
switch (priority) {
|
||||
case TaskPriority.URGENT:
|
||||
return 4;
|
||||
case TaskPriority.HIGH:
|
||||
return 3;
|
||||
case TaskPriority.MEDIUM:
|
||||
return 2;
|
||||
case TaskPriority.LOW:
|
||||
return 1;
|
||||
default:
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
private getPriorityLabel(priority: TaskPriority): string {
|
||||
switch (priority) {
|
||||
case TaskPriority.URGENT:
|
||||
return "紧急";
|
||||
case TaskPriority.HIGH:
|
||||
return "高";
|
||||
case TaskPriority.MEDIUM:
|
||||
return "中";
|
||||
case TaskPriority.LOW:
|
||||
return "低";
|
||||
default:
|
||||
return String(priority);
|
||||
}
|
||||
}
|
||||
|
||||
private getStatusLabel(status: TaskStatus): string {
|
||||
switch (status) {
|
||||
case TaskStatus.TODO:
|
||||
return "待开始";
|
||||
case TaskStatus.IN_PROGRESS:
|
||||
return "进行中";
|
||||
case TaskStatus.DONE:
|
||||
return "已完成";
|
||||
case TaskStatus.ARCHIVED:
|
||||
return "已归档";
|
||||
default:
|
||||
return String(status);
|
||||
}
|
||||
}
|
||||
|
||||
private getContentSnippet(contentText: string | null): string | null {
|
||||
if (!contentText) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const normalizedContent = contentText.replace(/\s+/g, " ").trim();
|
||||
if (normalizedContent.length === 0) {
|
||||
return null;
|
||||
}
|
||||
|
||||
if (normalizedContent.length <= this.maxContextContentLength) {
|
||||
return normalizedContent;
|
||||
}
|
||||
|
||||
return `${normalizedContent.slice(0, this.maxContextContentLength)}...`;
|
||||
}
|
||||
|
||||
private async recordUsageLog(input: {
|
||||
userId: string;
|
||||
channel: AiChannel;
|
||||
providerName: string | null;
|
||||
model: string | null;
|
||||
usage: AiUsageMetrics | null;
|
||||
latencyMs: number;
|
||||
success: boolean;
|
||||
errorCode: string | null;
|
||||
}): Promise<void> {
|
||||
try {
|
||||
await this.prismaService.aiUsageLog.create({
|
||||
data: {
|
||||
userId: input.userId,
|
||||
channel: input.channel,
|
||||
providerName:
|
||||
input.providerName === null
|
||||
? null
|
||||
: this.dataEncryptionService.encryptString(input.providerName),
|
||||
model:
|
||||
input.model === null ? null : this.dataEncryptionService.encryptString(input.model),
|
||||
promptTokens: input.usage?.promptTokens ?? 0,
|
||||
completionTokens: input.usage?.completionTokens ?? 0,
|
||||
totalTokens: input.usage?.totalTokens ?? 0,
|
||||
latencyMs: input.latencyMs,
|
||||
success: input.success,
|
||||
errorCode: input.errorCode
|
||||
}
|
||||
});
|
||||
} catch (error) {
|
||||
const message = error instanceof Error ? error.message : "未知错误";
|
||||
this.logger.warn(`写入 AI 使用日志失败:${message}`);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,61 @@
|
||||
import { AiChannel } from "../../generated/prisma/client";
|
||||
|
||||
export type AiResolvedRouteCandidate = {
|
||||
channel: AiChannel;
|
||||
source: "binding" | "public_pool";
|
||||
sourceId: string | null;
|
||||
providerName: string;
|
||||
model: string | null;
|
||||
configId: string | null;
|
||||
configName: string | null;
|
||||
endpoint: string | null;
|
||||
apiKey: string | null;
|
||||
};
|
||||
|
||||
export type AiChatInput = {
|
||||
userId: string;
|
||||
message: string;
|
||||
sessionId: string | null;
|
||||
};
|
||||
|
||||
export type AiChatResult = {
|
||||
channel: AiChannel;
|
||||
providerName: string;
|
||||
model: string | null;
|
||||
content: string;
|
||||
sessionId: string | null;
|
||||
usage: AiUsageMetrics | null;
|
||||
raw: unknown;
|
||||
};
|
||||
|
||||
export type AiUsageMetrics = {
|
||||
promptTokens: number;
|
||||
completionTokens: number;
|
||||
totalTokens: number;
|
||||
};
|
||||
|
||||
export type AiRouteAttempt = {
|
||||
channel: AiChannel;
|
||||
providerName: string | null;
|
||||
model: string | null;
|
||||
status: "skipped" | "failed" | "success";
|
||||
reasonCode: string | null;
|
||||
reasonMessage: string | null;
|
||||
};
|
||||
|
||||
export class AiRouteFailureError extends Error {
|
||||
constructor(
|
||||
public readonly channel: AiChannel,
|
||||
public readonly providerName: string,
|
||||
public readonly code: string,
|
||||
message: string
|
||||
) {
|
||||
super(message);
|
||||
this.name = "AiRouteFailureError";
|
||||
Object.setPrototypeOf(this, new.target.prototype);
|
||||
}
|
||||
}
|
||||
|
||||
export interface AiChannelExecutor {
|
||||
execute(candidate: AiResolvedRouteCandidate, input: AiChatInput): Promise<AiChatResult>;
|
||||
}
|
||||
@@ -0,0 +1,60 @@
|
||||
import { Type } from "class-transformer";
|
||||
import {
|
||||
IsArray,
|
||||
IsEnum,
|
||||
IsInt,
|
||||
IsOptional,
|
||||
IsString,
|
||||
MinLength,
|
||||
ValidateNested
|
||||
} from "class-validator";
|
||||
import { AiChannel } from "../../../generated/prisma/client";
|
||||
import { TaskPriority, TaskStatus } from "../../../generated/prisma/client";
|
||||
|
||||
export class LocalTaskContextItemDto {
|
||||
@IsString()
|
||||
@MinLength(1)
|
||||
id!: string;
|
||||
|
||||
@IsString()
|
||||
@MinLength(1)
|
||||
title!: string;
|
||||
|
||||
@IsEnum(TaskPriority)
|
||||
priority!: TaskPriority;
|
||||
|
||||
@IsEnum(TaskStatus)
|
||||
status!: TaskStatus;
|
||||
|
||||
@IsOptional()
|
||||
@IsInt()
|
||||
ddlAt?: number | null;
|
||||
|
||||
@IsOptional()
|
||||
@IsString()
|
||||
contentText?: string | null;
|
||||
|
||||
@IsInt()
|
||||
updatedAt!: number;
|
||||
}
|
||||
|
||||
export class AiChatDto {
|
||||
@IsString()
|
||||
@MinLength(1)
|
||||
message!: string;
|
||||
|
||||
@IsOptional()
|
||||
@IsString()
|
||||
@MinLength(1)
|
||||
sessionId?: string;
|
||||
|
||||
@IsOptional()
|
||||
@IsEnum(AiChannel)
|
||||
channel?: AiChannel;
|
||||
|
||||
@IsOptional()
|
||||
@IsArray()
|
||||
@ValidateNested({ each: true })
|
||||
@Type(() => LocalTaskContextItemDto)
|
||||
localTasks?: LocalTaskContextItemDto[];
|
||||
}
|
||||
@@ -0,0 +1,48 @@
|
||||
import { Transform, Type } from "class-transformer";
|
||||
import { IsBoolean, IsEnum, IsInt, IsOptional, Max, Min } from "class-validator";
|
||||
import { AiChannel } from "../../../generated/prisma/client";
|
||||
|
||||
function normalizeBoolean(value: unknown): boolean | undefined {
|
||||
if (typeof value === "boolean") {
|
||||
return value;
|
||||
}
|
||||
|
||||
if (typeof value !== "string") {
|
||||
return undefined;
|
||||
}
|
||||
|
||||
const normalized = value.trim().toLowerCase();
|
||||
if (normalized === "true" || normalized === "1") {
|
||||
return true;
|
||||
}
|
||||
|
||||
if (normalized === "false" || normalized === "0") {
|
||||
return false;
|
||||
}
|
||||
|
||||
return undefined;
|
||||
}
|
||||
|
||||
export class ListAiUsageLogsQueryDto {
|
||||
@Type(() => Number)
|
||||
@IsOptional()
|
||||
@IsInt()
|
||||
@Min(1)
|
||||
page?: number;
|
||||
|
||||
@Type(() => Number)
|
||||
@IsOptional()
|
||||
@IsInt()
|
||||
@Min(1)
|
||||
@Max(100)
|
||||
pageSize?: number;
|
||||
|
||||
@IsOptional()
|
||||
@IsEnum(AiChannel)
|
||||
channel?: AiChannel;
|
||||
|
||||
@Transform(({ value }) => normalizeBoolean(value))
|
||||
@IsOptional()
|
||||
@IsBoolean()
|
||||
success?: boolean;
|
||||
}
|
||||
@@ -0,0 +1,47 @@
|
||||
import { AiChannel } from "../../../generated/prisma/client";
|
||||
import { IsBoolean, IsEnum, IsOptional, IsString, IsUrl, MinLength } from "class-validator";
|
||||
|
||||
export class UpsertAiProviderBindingDto {
|
||||
@IsEnum(AiChannel)
|
||||
channel!: AiChannel;
|
||||
|
||||
@IsOptional()
|
||||
@IsString()
|
||||
@MinLength(1)
|
||||
providerName?: string;
|
||||
|
||||
@IsOptional()
|
||||
@IsString()
|
||||
@MinLength(1)
|
||||
model?: string;
|
||||
|
||||
@IsOptional()
|
||||
@IsString()
|
||||
@MinLength(1)
|
||||
configId?: string;
|
||||
|
||||
@IsOptional()
|
||||
@IsString()
|
||||
@MinLength(1)
|
||||
configName?: string;
|
||||
|
||||
@IsOptional()
|
||||
@IsUrl(
|
||||
{
|
||||
require_tld: false
|
||||
},
|
||||
{
|
||||
message: "endpoint \u5fc5\u987b\u662f\u5408\u6cd5\u7684 URL"
|
||||
}
|
||||
)
|
||||
endpoint?: string;
|
||||
|
||||
@IsOptional()
|
||||
@IsString()
|
||||
@MinLength(1)
|
||||
apiKey?: string;
|
||||
|
||||
@IsOptional()
|
||||
@IsBoolean()
|
||||
isEnabled?: boolean;
|
||||
}
|
||||
@@ -0,0 +1,284 @@
|
||||
import { Injectable } from "@nestjs/common";
|
||||
import {
|
||||
AiChannelExecutor,
|
||||
AiChatInput,
|
||||
AiChatResult,
|
||||
AiResolvedRouteCandidate,
|
||||
AiRouteFailureError
|
||||
} from "../ai.types";
|
||||
|
||||
@Injectable()
|
||||
export class AstrbotProvider implements AiChannelExecutor {
|
||||
async execute(candidate: AiResolvedRouteCandidate, input: AiChatInput): Promise<AiChatResult> {
|
||||
const routeLabel =
|
||||
candidate.providerName || candidate.configName || candidate.configId || "astrbot";
|
||||
|
||||
if (!candidate.endpoint) {
|
||||
throw new AiRouteFailureError(
|
||||
candidate.channel,
|
||||
routeLabel,
|
||||
"MISSING_ENDPOINT",
|
||||
"缺少 AstrBot 服务地址配置"
|
||||
);
|
||||
}
|
||||
|
||||
if (!candidate.apiKey) {
|
||||
throw new AiRouteFailureError(
|
||||
candidate.channel,
|
||||
routeLabel,
|
||||
"MISSING_API_KEY",
|
||||
"缺少 AstrBot API Key 配置"
|
||||
);
|
||||
}
|
||||
|
||||
const requestUrl = this.buildRequestUrl(candidate.endpoint);
|
||||
|
||||
let response: Response;
|
||||
try {
|
||||
response = await fetch(requestUrl, {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
Authorization: `Bearer ${candidate.apiKey}`
|
||||
},
|
||||
body: JSON.stringify({
|
||||
username: input.userId,
|
||||
session_id: input.sessionId ?? undefined,
|
||||
message: input.message,
|
||||
enable_streaming: false,
|
||||
selected_model: candidate.model ?? undefined
|
||||
}),
|
||||
signal: AbortSignal.timeout(30000)
|
||||
});
|
||||
} catch (error) {
|
||||
throw new AiRouteFailureError(
|
||||
candidate.channel,
|
||||
routeLabel,
|
||||
"UPSTREAM_UNREACHABLE",
|
||||
this.toErrorMessage(error, "AstrBot 服务请求失败")
|
||||
);
|
||||
}
|
||||
|
||||
if (!response.ok) {
|
||||
const rawText = await response.text();
|
||||
throw new AiRouteFailureError(
|
||||
candidate.channel,
|
||||
routeLabel,
|
||||
`UPSTREAM_HTTP_${response.status}`,
|
||||
this.extractHttpErrorMessage(rawText, response.status)
|
||||
);
|
||||
}
|
||||
|
||||
const events = await this.readSseEvents(response);
|
||||
let content = "";
|
||||
let sessionId = input.sessionId;
|
||||
|
||||
for (const event of events) {
|
||||
const type = this.readString(event["type"]);
|
||||
if (type === "session_id") {
|
||||
sessionId = this.readString(event["session_id"]) ?? sessionId;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (type === "error") {
|
||||
throw new AiRouteFailureError(
|
||||
candidate.channel,
|
||||
routeLabel,
|
||||
this.readString(event["code"]) ?? "ASTRBOT_ERROR",
|
||||
this.readString(event["data"]) ?? "AstrBot 返回错误"
|
||||
);
|
||||
}
|
||||
|
||||
if (type !== "plain") {
|
||||
continue;
|
||||
}
|
||||
|
||||
const chainType = this.readString(event["chain_type"]);
|
||||
if (
|
||||
chainType === "reasoning" ||
|
||||
chainType === "tool_call" ||
|
||||
chainType === "tool_call_result"
|
||||
) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const data = this.readString(event["data"]);
|
||||
if (!data) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (event["streaming"] === true) {
|
||||
content += data;
|
||||
continue;
|
||||
}
|
||||
|
||||
content = data;
|
||||
}
|
||||
|
||||
if (!content.trim()) {
|
||||
throw new AiRouteFailureError(
|
||||
candidate.channel,
|
||||
routeLabel,
|
||||
"EMPTY_RESPONSE",
|
||||
"AstrBot 没有返回有效内容"
|
||||
);
|
||||
}
|
||||
|
||||
return {
|
||||
channel: candidate.channel,
|
||||
providerName: routeLabel,
|
||||
model: candidate.model,
|
||||
content,
|
||||
sessionId,
|
||||
usage: this.extractUsage(events),
|
||||
raw: events
|
||||
};
|
||||
}
|
||||
|
||||
private buildRequestUrl(endpoint: string): string {
|
||||
const normalizedEndpoint = endpoint.replace(/\/+$/, "");
|
||||
if (normalizedEndpoint.endsWith("/api/v1/chat")) {
|
||||
return normalizedEndpoint;
|
||||
}
|
||||
if (normalizedEndpoint.endsWith("/api/v1")) {
|
||||
return `${normalizedEndpoint}/chat`;
|
||||
}
|
||||
if (normalizedEndpoint.endsWith("/api")) {
|
||||
return `${normalizedEndpoint}/v1/chat`;
|
||||
}
|
||||
return `${normalizedEndpoint}/api/v1/chat`;
|
||||
}
|
||||
|
||||
private parseSseEvents(rawText: string): Array<Record<string, unknown>> {
|
||||
return rawText
|
||||
.split(/\r?\n\r?\n/)
|
||||
.map((block) =>
|
||||
block
|
||||
.split(/\r?\n/)
|
||||
.filter((line) => line.startsWith("data:"))
|
||||
.map((line) => line.slice(5).trim())
|
||||
.join("\n")
|
||||
)
|
||||
.filter((payload) => payload.length > 0)
|
||||
.map((payload) => {
|
||||
try {
|
||||
return JSON.parse(payload) as Record<string, unknown>;
|
||||
} catch {
|
||||
return null;
|
||||
}
|
||||
})
|
||||
.filter((item): item is Record<string, unknown> => item !== null);
|
||||
}
|
||||
|
||||
private async readSseEvents(response: Response): Promise<Array<Record<string, unknown>>> {
|
||||
if (!response.body) {
|
||||
return this.parseSseEvents(await response.text());
|
||||
}
|
||||
|
||||
const reader = response.body.getReader();
|
||||
const decoder = new TextDecoder();
|
||||
const events: Array<Record<string, unknown>> = [];
|
||||
let buffer = "";
|
||||
let reachedEndEvent = false;
|
||||
|
||||
try {
|
||||
while (!reachedEndEvent) {
|
||||
const { done, value } = await reader.read();
|
||||
if (done) {
|
||||
break;
|
||||
}
|
||||
|
||||
buffer += decoder.decode(value, { stream: true });
|
||||
const segments = buffer.split(/\r?\n\r?\n/);
|
||||
buffer = segments.pop() ?? "";
|
||||
|
||||
for (const segment of segments) {
|
||||
const parsedEvents = this.parseSseEvents(segment);
|
||||
for (const event of parsedEvents) {
|
||||
events.push(event);
|
||||
if (this.readString(event["type"]) === "end") {
|
||||
reachedEndEvent = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (reachedEndEvent) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const tail = `${buffer}${decoder.decode()}`;
|
||||
if (tail.trim().length > 0) {
|
||||
events.push(...this.parseSseEvents(tail));
|
||||
}
|
||||
} finally {
|
||||
await reader.cancel();
|
||||
}
|
||||
|
||||
return events;
|
||||
}
|
||||
|
||||
private extractHttpErrorMessage(rawText: string, statusCode: number): string {
|
||||
try {
|
||||
const payload = JSON.parse(rawText) as Record<string, unknown>;
|
||||
if (typeof payload["message"] === "string") {
|
||||
return payload["message"];
|
||||
}
|
||||
if (typeof payload["data"] === "string") {
|
||||
return payload["data"];
|
||||
}
|
||||
} catch {
|
||||
return `AstrBot 服务调用失败,状态码 ${statusCode}`;
|
||||
}
|
||||
|
||||
return `AstrBot 服务调用失败,状态码 ${statusCode}`;
|
||||
}
|
||||
|
||||
private readString(value: unknown): string | null {
|
||||
return typeof value === "string" ? value : null;
|
||||
}
|
||||
|
||||
private toErrorMessage(error: unknown, fallback: string): string {
|
||||
if (error instanceof Error && error.message) {
|
||||
return error.message;
|
||||
}
|
||||
|
||||
return fallback;
|
||||
}
|
||||
|
||||
private extractUsage(events: Array<Record<string, unknown>>): AiChatResult["usage"] {
|
||||
for (const event of events) {
|
||||
if (this.readString(event["type"]) !== "agent_stats") {
|
||||
continue;
|
||||
}
|
||||
|
||||
const data = this.asRecord(event["data"]);
|
||||
const tokenUsage = this.asRecord(data?.["token_usage"]);
|
||||
if (!tokenUsage) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const promptTokens =
|
||||
(this.readNumber(tokenUsage["input_other"]) ?? 0) +
|
||||
(this.readNumber(tokenUsage["input_cached"]) ?? 0);
|
||||
const completionTokens = this.readNumber(tokenUsage["output"]) ?? 0;
|
||||
|
||||
return {
|
||||
promptTokens,
|
||||
completionTokens,
|
||||
totalTokens: promptTokens + completionTokens
|
||||
};
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
|
||||
private asRecord(value: unknown): Record<string, unknown> | null {
|
||||
return typeof value === "object" && value !== null ? (value as Record<string, unknown>) : null;
|
||||
}
|
||||
|
||||
private readNumber(value: unknown): number | null {
|
||||
return typeof value === "number" && Number.isFinite(value) ? value : null;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,300 @@
|
||||
import { Injectable } from "@nestjs/common";
|
||||
import {
|
||||
AiChannelExecutor,
|
||||
AiChatInput,
|
||||
AiChatResult,
|
||||
AiResolvedRouteCandidate,
|
||||
AiRouteFailureError
|
||||
} from "../ai.types";
|
||||
|
||||
@Injectable()
|
||||
export class OpenAiCompatibleProvider implements AiChannelExecutor {
|
||||
async execute(candidate: AiResolvedRouteCandidate, input: AiChatInput): Promise<AiChatResult> {
|
||||
if (!candidate.endpoint) {
|
||||
throw new AiRouteFailureError(
|
||||
candidate.channel,
|
||||
candidate.providerName,
|
||||
"MISSING_ENDPOINT",
|
||||
"缺少 AI 服务地址配置"
|
||||
);
|
||||
}
|
||||
|
||||
if (!candidate.apiKey) {
|
||||
throw new AiRouteFailureError(
|
||||
candidate.channel,
|
||||
candidate.providerName,
|
||||
"MISSING_API_KEY",
|
||||
"缺少 AI 服务密钥配置"
|
||||
);
|
||||
}
|
||||
|
||||
if (!candidate.model) {
|
||||
throw new AiRouteFailureError(
|
||||
candidate.channel,
|
||||
candidate.providerName,
|
||||
"MISSING_MODEL",
|
||||
"缺少 AI 模型配置"
|
||||
);
|
||||
}
|
||||
|
||||
const requestUrl = this.buildRequestUrl(candidate.endpoint);
|
||||
|
||||
let response: Response;
|
||||
try {
|
||||
response = await fetch(requestUrl, {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
Authorization: `Bearer ${candidate.apiKey}`
|
||||
},
|
||||
body: JSON.stringify({
|
||||
model: candidate.model,
|
||||
messages: [
|
||||
{
|
||||
role: "user",
|
||||
content: input.message
|
||||
}
|
||||
],
|
||||
stream: false
|
||||
}),
|
||||
signal: AbortSignal.timeout(30000)
|
||||
});
|
||||
} catch (error) {
|
||||
throw new AiRouteFailureError(
|
||||
candidate.channel,
|
||||
candidate.providerName,
|
||||
"UPSTREAM_UNREACHABLE",
|
||||
this.toErrorMessage(error, "AI 服务请求失败")
|
||||
);
|
||||
}
|
||||
|
||||
let payload: unknown;
|
||||
try {
|
||||
payload = await response.json();
|
||||
} catch (error) {
|
||||
throw new AiRouteFailureError(
|
||||
candidate.channel,
|
||||
candidate.providerName,
|
||||
"INVALID_RESPONSE",
|
||||
this.toErrorMessage(error, "AI 服务返回了无法解析的数据")
|
||||
);
|
||||
}
|
||||
|
||||
if (!response.ok) {
|
||||
throw new AiRouteFailureError(
|
||||
candidate.channel,
|
||||
candidate.providerName,
|
||||
`UPSTREAM_HTTP_${response.status}`,
|
||||
this.extractErrorMessage(payload, `AI 服务调用失败,状态码 ${response.status}`)
|
||||
);
|
||||
}
|
||||
|
||||
const content = this.extractAssistantText(payload);
|
||||
if (!content.trim()) {
|
||||
throw new AiRouteFailureError(
|
||||
candidate.channel,
|
||||
candidate.providerName,
|
||||
"EMPTY_RESPONSE",
|
||||
"AI 服务没有返回有效内容"
|
||||
);
|
||||
}
|
||||
|
||||
return {
|
||||
channel: candidate.channel,
|
||||
providerName: candidate.providerName,
|
||||
model: this.extractModel(payload) ?? candidate.model,
|
||||
content,
|
||||
sessionId: input.sessionId,
|
||||
usage: this.extractUsage(payload),
|
||||
raw: payload
|
||||
};
|
||||
}
|
||||
|
||||
private buildRequestUrl(endpoint: string): string {
|
||||
const normalizedEndpoint = endpoint.replace(/\/+$/, "");
|
||||
if (normalizedEndpoint.endsWith("/chat/completions")) {
|
||||
return normalizedEndpoint;
|
||||
}
|
||||
if (normalizedEndpoint.endsWith("/v1")) {
|
||||
return `${normalizedEndpoint}/chat/completions`;
|
||||
}
|
||||
return `${normalizedEndpoint}/v1/chat/completions`;
|
||||
}
|
||||
|
||||
private extractAssistantText(payload: unknown): string {
|
||||
const chatCompletionText = this.extractChatCompletionText(payload);
|
||||
if (chatCompletionText) {
|
||||
return chatCompletionText;
|
||||
}
|
||||
|
||||
const responsesText = this.extractResponsesApiText(payload);
|
||||
if (responsesText) {
|
||||
return responsesText;
|
||||
}
|
||||
|
||||
return "";
|
||||
}
|
||||
|
||||
private extractChatCompletionText(payload: unknown): string {
|
||||
if (!this.isRecord(payload)) {
|
||||
return "";
|
||||
}
|
||||
|
||||
const choices = payload["choices"];
|
||||
if (!Array.isArray(choices) || choices.length === 0) {
|
||||
return "";
|
||||
}
|
||||
|
||||
const firstChoice = choices[0];
|
||||
if (!this.isRecord(firstChoice)) {
|
||||
return "";
|
||||
}
|
||||
|
||||
const message = firstChoice["message"];
|
||||
if (this.isRecord(message)) {
|
||||
const messageContent = this.extractMessageContent(message["content"]);
|
||||
if (messageContent) {
|
||||
return messageContent;
|
||||
}
|
||||
}
|
||||
|
||||
if (typeof firstChoice["text"] === "string") {
|
||||
return firstChoice["text"];
|
||||
}
|
||||
|
||||
return "";
|
||||
}
|
||||
|
||||
private extractResponsesApiText(payload: unknown): string {
|
||||
if (!this.isRecord(payload)) {
|
||||
return "";
|
||||
}
|
||||
|
||||
if (typeof payload["output_text"] === "string") {
|
||||
return payload["output_text"];
|
||||
}
|
||||
|
||||
const output = payload["output"];
|
||||
if (!Array.isArray(output)) {
|
||||
return "";
|
||||
}
|
||||
|
||||
return output
|
||||
.map((item) => {
|
||||
if (!this.isRecord(item)) {
|
||||
return "";
|
||||
}
|
||||
|
||||
if (typeof item["text"] === "string") {
|
||||
return item["text"];
|
||||
}
|
||||
|
||||
return this.extractMessageContent(item["content"]);
|
||||
})
|
||||
.filter((item) => item.length > 0)
|
||||
.join("\n")
|
||||
.trim();
|
||||
}
|
||||
|
||||
private extractMessageContent(content: unknown): string {
|
||||
if (typeof content === "string") {
|
||||
return content;
|
||||
}
|
||||
|
||||
if (!Array.isArray(content)) {
|
||||
return "";
|
||||
}
|
||||
|
||||
return content
|
||||
.map((item) => this.extractContentPartText(item))
|
||||
.filter((item) => item.length > 0)
|
||||
.join("\n")
|
||||
.trim();
|
||||
}
|
||||
|
||||
private extractContentPartText(item: unknown): string {
|
||||
if (!this.isRecord(item)) {
|
||||
return "";
|
||||
}
|
||||
|
||||
if (typeof item["text"] === "string") {
|
||||
return item["text"];
|
||||
}
|
||||
|
||||
if (this.isRecord(item["text"]) && typeof item["text"]["value"] === "string") {
|
||||
return item["text"]["value"];
|
||||
}
|
||||
|
||||
if (typeof item["content"] === "string") {
|
||||
return item["content"];
|
||||
}
|
||||
|
||||
if (this.isRecord(item["content"]) && typeof item["content"]["text"] === "string") {
|
||||
return item["content"]["text"];
|
||||
}
|
||||
|
||||
return "";
|
||||
}
|
||||
|
||||
private extractModel(payload: unknown): string | null {
|
||||
if (!this.isRecord(payload) || typeof payload["model"] !== "string") {
|
||||
return null;
|
||||
}
|
||||
|
||||
return payload["model"];
|
||||
}
|
||||
|
||||
private extractUsage(payload: unknown): AiChatResult["usage"] {
|
||||
if (!this.isRecord(payload)) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const usage = payload["usage"];
|
||||
if (!this.isRecord(usage)) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const promptTokens = this.readNumber(usage["prompt_tokens"]);
|
||||
const completionTokens = this.readNumber(usage["completion_tokens"]);
|
||||
const totalTokens = this.readNumber(usage["total_tokens"]);
|
||||
|
||||
if (promptTokens === null && completionTokens === null && totalTokens === null) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return {
|
||||
promptTokens: promptTokens ?? 0,
|
||||
completionTokens: completionTokens ?? 0,
|
||||
totalTokens: totalTokens ?? (promptTokens ?? 0) + (completionTokens ?? 0)
|
||||
};
|
||||
}
|
||||
|
||||
private extractErrorMessage(payload: unknown, fallback: string): string {
|
||||
if (!this.isRecord(payload)) {
|
||||
return fallback;
|
||||
}
|
||||
|
||||
const error = payload["error"];
|
||||
if (!this.isRecord(error) || typeof error["message"] !== "string") {
|
||||
return fallback;
|
||||
}
|
||||
|
||||
return error["message"];
|
||||
}
|
||||
|
||||
private isRecord(value: unknown): value is Record<string, unknown> {
|
||||
return typeof value === "object" && value !== null;
|
||||
}
|
||||
|
||||
private toErrorMessage(error: unknown, fallback: string): string {
|
||||
if (error instanceof Error && error.message) {
|
||||
return error.message;
|
||||
}
|
||||
|
||||
return fallback;
|
||||
}
|
||||
|
||||
private readNumber(value: unknown): number | null {
|
||||
return typeof value === "number" && Number.isFinite(value) ? value : null;
|
||||
}
|
||||
}
|
||||
@@ -1,8 +1,11 @@
|
||||
import { Module } from "@nestjs/common";
|
||||
import { ConfigModule } from "@nestjs/config";
|
||||
import { resolve } from "node:path";
|
||||
import { AiModule } from "./ai/ai.module";
|
||||
import { AttachmentModule } from "./attachment/attachment.module";
|
||||
import { AuthModule } from "./auth/auth.module";
|
||||
import { PrismaModule } from "./prisma/prisma.module";
|
||||
import { SecurityModule } from "./security/security.module";
|
||||
import { SyncModule } from "./sync/sync.module";
|
||||
import { TaskModule } from "./task/task.module";
|
||||
|
||||
@@ -10,13 +13,15 @@ import { TaskModule } from "./task/task.module";
|
||||
imports: [
|
||||
ConfigModule.forRoot({
|
||||
isGlobal: true,
|
||||
envFilePath: ".env"
|
||||
envFilePath: [resolve(__dirname, "../.env"), ".env"]
|
||||
}),
|
||||
PrismaModule,
|
||||
SecurityModule,
|
||||
AuthModule,
|
||||
TaskModule,
|
||||
AttachmentModule,
|
||||
SyncModule
|
||||
SyncModule,
|
||||
AiModule
|
||||
]
|
||||
})
|
||||
export class AppModule {}
|
||||
|
||||
@@ -1,10 +1,16 @@
|
||||
import { randomUUID } from "node:crypto";
|
||||
import { Injectable, NotFoundException, PayloadTooLargeException } from "@nestjs/common";
|
||||
import {
|
||||
Injectable,
|
||||
InternalServerErrorException,
|
||||
NotFoundException,
|
||||
PayloadTooLargeException
|
||||
} from "@nestjs/common";
|
||||
import { ConfigService } from "@nestjs/config";
|
||||
import { PutObjectCommand, S3Client } from "@aws-sdk/client-s3";
|
||||
import { getSignedUrl } from "@aws-sdk/s3-request-presigner";
|
||||
import { AttachmentType } from "../../generated/prisma/client";
|
||||
import { PrismaService } from "../prisma/prisma.service";
|
||||
import { DataEncryptionService } from "../security/data-encryption.service";
|
||||
import { CompleteAttachmentDto } from "./dto/complete-attachment.dto";
|
||||
import { PresignAttachmentDto } from "./dto/presign-attachment.dto";
|
||||
|
||||
@@ -25,9 +31,7 @@ export type PresignAttachmentResponse = {
|
||||
usedBytes: string;
|
||||
remainingBytes: string;
|
||||
};
|
||||
headers: {
|
||||
"Content-Type": string;
|
||||
};
|
||||
headers: Record<string, string>;
|
||||
};
|
||||
|
||||
export type AttachmentResponse = {
|
||||
@@ -52,7 +56,8 @@ export class AttachmentService {
|
||||
|
||||
constructor(
|
||||
private readonly configService: ConfigService,
|
||||
private readonly prismaService: PrismaService
|
||||
private readonly prismaService: PrismaService,
|
||||
private readonly dataEncryptionService: DataEncryptionService
|
||||
) {}
|
||||
|
||||
async presignAttachment(
|
||||
@@ -67,15 +72,17 @@ export class AttachmentService {
|
||||
}
|
||||
|
||||
const bucket = this.getDefaultBucket();
|
||||
const objectKey = this.generateObjectKey(userId, body.fileName);
|
||||
const objectKey = this.generateObjectKey(body.fileName);
|
||||
const objectUrl = this.resolveObjectUrl(bucket, objectKey);
|
||||
const expiresInSeconds = this.getPresignExpiresInSeconds();
|
||||
const serverSideEncryption = this.getServerSideEncryptionMode();
|
||||
|
||||
const command = new PutObjectCommand({
|
||||
Bucket: bucket,
|
||||
Key: objectKey,
|
||||
ContentType: body.mimeType,
|
||||
ContentLength: body.fileSize
|
||||
ContentLength: body.fileSize,
|
||||
ServerSideEncryption: serverSideEncryption
|
||||
});
|
||||
|
||||
const uploadUrl = await getSignedUrl(this.getS3Client(), command, {
|
||||
@@ -94,9 +101,7 @@ export class AttachmentService {
|
||||
usedBytes: quotaInfo.usedBytes.toString(),
|
||||
remainingBytes: (quotaInfo.totalBytes - quotaInfo.usedBytes).toString()
|
||||
},
|
||||
headers: {
|
||||
"Content-Type": body.mimeType
|
||||
}
|
||||
headers: this.buildUploadHeaders(body.mimeType, serverSideEncryption)
|
||||
};
|
||||
}
|
||||
|
||||
@@ -139,14 +144,14 @@ export class AttachmentService {
|
||||
userId,
|
||||
taskId: body.taskId ?? null,
|
||||
type: body.type ?? this.resolveAttachmentType(body.mimeType),
|
||||
url: objectUrl,
|
||||
url: this.encryptRequiredString(objectUrl),
|
||||
mimeType: body.mimeType,
|
||||
fileName: body.fileName,
|
||||
fileName: this.encryptNullableString(body.fileName),
|
||||
fileSize: body.fileSize,
|
||||
width: body.width ?? null,
|
||||
height: body.height ?? null,
|
||||
durationMs: body.durationMs ?? null,
|
||||
checksum: body.checksum ?? null
|
||||
checksum: this.encryptNullableString(body.checksum)
|
||||
}
|
||||
});
|
||||
});
|
||||
@@ -155,14 +160,14 @@ export class AttachmentService {
|
||||
id: attachment.id,
|
||||
taskId: attachment.taskId,
|
||||
type: attachment.type,
|
||||
url: attachment.url,
|
||||
url: this.readDecryptedString(attachment.url) ?? objectUrl,
|
||||
mimeType: attachment.mimeType,
|
||||
fileName: attachment.fileName,
|
||||
fileName: this.readDecryptedString(attachment.fileName),
|
||||
fileSize: attachment.fileSize,
|
||||
width: attachment.width,
|
||||
height: attachment.height,
|
||||
durationMs: attachment.durationMs,
|
||||
checksum: attachment.checksum,
|
||||
checksum: this.readDecryptedString(attachment.checksum),
|
||||
createdAt: attachment.createdAt.toISOString(),
|
||||
updatedAt: attachment.updatedAt.toISOString()
|
||||
};
|
||||
@@ -204,10 +209,9 @@ export class AttachmentService {
|
||||
return Math.min(configValue, 604800);
|
||||
}
|
||||
|
||||
private generateObjectKey(userId: string, fileName: string): string {
|
||||
const safeFileName = fileName.replace(/[^\w.-]+/g, "_");
|
||||
private generateObjectKey(fileName: string): string {
|
||||
const datePrefix = new Date().toISOString().slice(0, 10);
|
||||
return `${userId}/${datePrefix}/${randomUUID()}-${safeFileName}`;
|
||||
return `attachments/${datePrefix}/${randomUUID()}${this.extractFileExtension(fileName)}`;
|
||||
}
|
||||
|
||||
private resolveObjectUrl(bucket: string, objectKey: string): string {
|
||||
@@ -232,6 +236,37 @@ export class AttachmentService {
|
||||
return AttachmentType.FILE;
|
||||
}
|
||||
|
||||
private buildUploadHeaders(
|
||||
mimeType: string,
|
||||
serverSideEncryption: "AES256" | undefined
|
||||
): Record<string, string> {
|
||||
const headers: Record<string, string> = {
|
||||
"Content-Type": mimeType
|
||||
};
|
||||
|
||||
if (serverSideEncryption) {
|
||||
headers["x-amz-server-side-encryption"] = serverSideEncryption;
|
||||
}
|
||||
|
||||
return headers;
|
||||
}
|
||||
|
||||
private getServerSideEncryptionMode(): "AES256" | undefined {
|
||||
const configValue =
|
||||
this.configService.get<string>("S3_SERVER_SIDE_ENCRYPTION")?.trim().toUpperCase() ?? "AES256";
|
||||
|
||||
if (configValue === "NONE" || configValue === "DISABLED") {
|
||||
return undefined;
|
||||
}
|
||||
|
||||
return "AES256";
|
||||
}
|
||||
|
||||
private extractFileExtension(fileName: string): string {
|
||||
const match = /\.[a-zA-Z0-9]{1,16}$/.exec(fileName);
|
||||
return match?.[0]?.toLowerCase() ?? "";
|
||||
}
|
||||
|
||||
private async ensureTaskOwnership(userId: string, taskId: string): Promise<void> {
|
||||
const task = await this.prismaService.task.findFirst({
|
||||
where: {
|
||||
@@ -279,4 +314,22 @@ export class AttachmentService {
|
||||
throw new PayloadTooLargeException("存储配额不足");
|
||||
}
|
||||
}
|
||||
|
||||
private encryptRequiredString(value: string): string {
|
||||
const encryptedValue = this.dataEncryptionService.encryptString(value);
|
||||
if (!encryptedValue) {
|
||||
throw new InternalServerErrorException("附件元数据加密失败");
|
||||
}
|
||||
|
||||
return encryptedValue;
|
||||
}
|
||||
|
||||
private encryptNullableString(value: string | null | undefined): string | null | undefined {
|
||||
return this.dataEncryptionService.encryptString(value);
|
||||
}
|
||||
|
||||
private readDecryptedString(value: string | null): string | null {
|
||||
const decryptedValue = this.dataEncryptionService.decryptString(value);
|
||||
return typeof decryptedValue === "string" ? decryptedValue : null;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import { randomUUID } from "node:crypto";
|
||||
import { authenticator } from "@otplib/preset-default";
|
||||
import { AuthMailService } from "./auth-mail.service";
|
||||
import { PrismaService } from "../prisma/prisma.service";
|
||||
import { DataEncryptionService } from "../security/data-encryption.service";
|
||||
|
||||
type EmailCodeEntry = {
|
||||
code: string;
|
||||
@@ -33,7 +34,8 @@ export class AuthService {
|
||||
private readonly configService: ConfigService,
|
||||
private readonly jwtService: JwtService,
|
||||
private readonly authMailService: AuthMailService,
|
||||
private readonly prismaService: PrismaService
|
||||
private readonly prismaService: PrismaService,
|
||||
private readonly dataEncryptionService: DataEncryptionService
|
||||
) {}
|
||||
|
||||
async sendEmailCode(email: string): Promise<{ success: boolean; expiresInSeconds: number }> {
|
||||
@@ -118,7 +120,10 @@ export class AuthService {
|
||||
}
|
||||
});
|
||||
|
||||
return this.issueTokens(entry.user);
|
||||
return this.issueTokens({
|
||||
id: entry.user.id,
|
||||
email: this.readRequiredEmail(entry.user.email)
|
||||
});
|
||||
}
|
||||
|
||||
async revokeRefreshToken(refreshToken: string): Promise<{ success: boolean }> {
|
||||
@@ -205,19 +210,27 @@ export class AuthService {
|
||||
}
|
||||
|
||||
private async getOrCreateUser(email: string): Promise<AuthUser> {
|
||||
return this.prismaService.user.upsert({
|
||||
const normalizedEmail = email.toLowerCase();
|
||||
const emailHash = this.dataEncryptionService.createLookupHash("user.email", normalizedEmail);
|
||||
const user = await this.prismaService.user.upsert({
|
||||
where: {
|
||||
email
|
||||
emailHash
|
||||
},
|
||||
update: {},
|
||||
create: {
|
||||
email
|
||||
email: this.encryptRequiredString(normalizedEmail),
|
||||
emailHash
|
||||
},
|
||||
select: {
|
||||
id: true,
|
||||
email: true
|
||||
}
|
||||
});
|
||||
|
||||
return {
|
||||
id: user.id,
|
||||
email: this.readRequiredEmail(user.email)
|
||||
};
|
||||
}
|
||||
|
||||
private generateCode(): string {
|
||||
@@ -254,4 +267,22 @@ export class AuthService {
|
||||
user
|
||||
};
|
||||
}
|
||||
|
||||
private encryptRequiredString(value: string): string {
|
||||
const encryptedValue = this.dataEncryptionService.encryptString(value);
|
||||
if (!encryptedValue) {
|
||||
throw new UnauthorizedException("用户敏感字段加密失败");
|
||||
}
|
||||
|
||||
return encryptedValue;
|
||||
}
|
||||
|
||||
private readRequiredEmail(value: string): string {
|
||||
const decryptedValue = this.dataEncryptionService.decryptString(value);
|
||||
if (typeof decryptedValue !== "string" || decryptedValue.length === 0) {
|
||||
throw new UnauthorizedException("用户邮箱解密失败");
|
||||
}
|
||||
|
||||
return decryptedValue;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,155 @@
|
||||
import { Injectable, InternalServerErrorException } from "@nestjs/common";
|
||||
import { ConfigService } from "@nestjs/config";
|
||||
import { Prisma } from "../../generated/prisma/client";
|
||||
import { createCipheriv, createDecipheriv, createHash, createHmac, randomBytes } from "node:crypto";
|
||||
|
||||
const ENCRYPTION_PREFIX = "encv1";
|
||||
const ENCRYPTION_ALGORITHM = "aes-256-gcm";
|
||||
const ENCRYPTION_IV_LENGTH = 12;
|
||||
|
||||
@Injectable()
|
||||
export class DataEncryptionService {
|
||||
constructor(private readonly configService: ConfigService) {}
|
||||
|
||||
isConfigured(): boolean {
|
||||
return Boolean(this.configService.get<string>("DATA_ENCRYPTION_SECRET"));
|
||||
}
|
||||
|
||||
isEncryptedString(value: string): boolean {
|
||||
return value.startsWith(`${ENCRYPTION_PREFIX}:`);
|
||||
}
|
||||
|
||||
encryptString(value: string | null | undefined): string | null | undefined {
|
||||
if (value === undefined) {
|
||||
return undefined;
|
||||
}
|
||||
|
||||
if (value === null) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const key = this.resolveKey();
|
||||
const iv = randomBytes(ENCRYPTION_IV_LENGTH);
|
||||
const cipher = createCipheriv(ENCRYPTION_ALGORITHM, key, iv);
|
||||
const encrypted = Buffer.concat([cipher.update(value, "utf8"), cipher.final()]);
|
||||
const authTag = cipher.getAuthTag();
|
||||
|
||||
return [
|
||||
ENCRYPTION_PREFIX,
|
||||
iv.toString("base64url"),
|
||||
authTag.toString("base64url"),
|
||||
encrypted.toString("base64url")
|
||||
].join(":");
|
||||
}
|
||||
|
||||
decryptString(value: string | null | undefined): string | null | undefined {
|
||||
if (value === undefined) {
|
||||
return undefined;
|
||||
}
|
||||
|
||||
if (value === null || !this.isEncryptedPayload(value)) {
|
||||
return value;
|
||||
}
|
||||
|
||||
const [prefix, ivText, authTagText, encryptedText] = value.split(":");
|
||||
if (prefix !== ENCRYPTION_PREFIX || !ivText || !authTagText || encryptedText === undefined) {
|
||||
throw new InternalServerErrorException("加密数据格式无效");
|
||||
}
|
||||
|
||||
try {
|
||||
const key = this.resolveKey();
|
||||
const decipher = createDecipheriv(
|
||||
ENCRYPTION_ALGORITHM,
|
||||
key,
|
||||
Buffer.from(ivText, "base64url")
|
||||
);
|
||||
decipher.setAuthTag(Buffer.from(authTagText, "base64url"));
|
||||
const decrypted = Buffer.concat([
|
||||
decipher.update(Buffer.from(encryptedText, "base64url")),
|
||||
decipher.final()
|
||||
]);
|
||||
|
||||
return decrypted.toString("utf8");
|
||||
} catch {
|
||||
throw new InternalServerErrorException("加密数据解密失败");
|
||||
}
|
||||
}
|
||||
|
||||
encryptJson(
|
||||
value: Prisma.InputJsonValue | null | undefined
|
||||
): Prisma.InputJsonValue | null | undefined {
|
||||
if (value === undefined) {
|
||||
return undefined;
|
||||
}
|
||||
|
||||
if (value === null) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return this.encryptString(JSON.stringify(value));
|
||||
}
|
||||
|
||||
decryptJson(value: Prisma.JsonValue | null): Prisma.JsonValue | null {
|
||||
if (value === null) {
|
||||
return null;
|
||||
}
|
||||
|
||||
if (typeof value !== "string" || !this.isEncryptedPayload(value)) {
|
||||
return value;
|
||||
}
|
||||
|
||||
const decrypted = this.decryptString(value);
|
||||
if (typeof decrypted !== "string") {
|
||||
throw new InternalServerErrorException("加密数据解密失败");
|
||||
}
|
||||
|
||||
try {
|
||||
return JSON.parse(decrypted) as Prisma.JsonValue;
|
||||
} catch {
|
||||
throw new InternalServerErrorException("加密 JSON 数据损坏");
|
||||
}
|
||||
}
|
||||
|
||||
decryptPayload(value: Prisma.JsonValue | null): string | null {
|
||||
if (value === null) {
|
||||
return null;
|
||||
}
|
||||
|
||||
if (typeof value === "string") {
|
||||
return this.decryptString(value) ?? null;
|
||||
}
|
||||
|
||||
return JSON.stringify(value);
|
||||
}
|
||||
|
||||
createLookupHash(scope: string, value: string): string {
|
||||
const normalizedScope = scope.trim().toLowerCase();
|
||||
if (!normalizedScope) {
|
||||
throw new InternalServerErrorException("缺少盲索引作用域");
|
||||
}
|
||||
|
||||
const secret = this.configService.get<string>("DATA_ENCRYPTION_SECRET");
|
||||
if (!secret) {
|
||||
throw new InternalServerErrorException("服务端未配置 DATA_ENCRYPTION_SECRET,无法生成盲索引");
|
||||
}
|
||||
|
||||
return createHmac("sha256", `lookup:${normalizedScope}:${secret}`)
|
||||
.update(value, "utf8")
|
||||
.digest("hex");
|
||||
}
|
||||
|
||||
private isEncryptedPayload(value: string): boolean {
|
||||
return this.isEncryptedString(value);
|
||||
}
|
||||
|
||||
private resolveKey(): Buffer {
|
||||
const secret = this.configService.get<string>("DATA_ENCRYPTION_SECRET");
|
||||
if (!secret) {
|
||||
throw new InternalServerErrorException(
|
||||
"服务端未配置 DATA_ENCRYPTION_SECRET,无法写入加密数据"
|
||||
);
|
||||
}
|
||||
|
||||
return createHash("sha256").update(secret, "utf8").digest();
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,9 @@
|
||||
import { Global, Module } from "@nestjs/common";
|
||||
import { DataEncryptionService } from "./data-encryption.service";
|
||||
|
||||
@Global()
|
||||
@Module({
|
||||
providers: [DataEncryptionService],
|
||||
exports: [DataEncryptionService]
|
||||
})
|
||||
export class SecurityModule {}
|
||||
@@ -1,6 +1,7 @@
|
||||
import { BadRequestException, Injectable } from "@nestjs/common";
|
||||
import { Prisma } from "../../generated/prisma/client";
|
||||
import { PrismaService } from "../prisma/prisma.service";
|
||||
import { DataEncryptionService } from "../security/data-encryption.service";
|
||||
import { SyncPullQueryDto } from "./dto/sync-pull.dto";
|
||||
import { SyncPushDto, SyncPushOperationDto } from "./dto/sync-push.dto";
|
||||
|
||||
@@ -60,7 +61,10 @@ export type SyncPullResponse = {
|
||||
|
||||
@Injectable()
|
||||
export class SyncService {
|
||||
constructor(private readonly prismaService: PrismaService) {}
|
||||
constructor(
|
||||
private readonly prismaService: PrismaService,
|
||||
private readonly dataEncryptionService: DataEncryptionService
|
||||
) {}
|
||||
|
||||
async pullOperations(userId: string, query: SyncPullQueryDto): Promise<SyncPullResponse> {
|
||||
const limit = query.limit ?? 100;
|
||||
@@ -137,7 +141,7 @@ export class SyncService {
|
||||
entityType: operation.entityType,
|
||||
entityId: operation.entityId,
|
||||
action: operation.action,
|
||||
payload: operation.payload,
|
||||
payload: this.dataEncryptionService.encryptString(operation.payload) ?? undefined,
|
||||
clientTs: new Date(operation.clientTs)
|
||||
},
|
||||
select: {
|
||||
@@ -252,15 +256,7 @@ export class SyncService {
|
||||
}
|
||||
|
||||
private serializePayload(payload: Prisma.JsonValue | null): string | null {
|
||||
if (payload === null) {
|
||||
return null;
|
||||
}
|
||||
|
||||
if (typeof payload === "string") {
|
||||
return payload;
|
||||
}
|
||||
|
||||
return JSON.stringify(payload);
|
||||
return this.dataEncryptionService.decryptPayload(payload);
|
||||
}
|
||||
|
||||
private parseCursor(cursor: string | undefined): SyncPullCursorState | null {
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import { Injectable, NotFoundException } from "@nestjs/common";
|
||||
import { Injectable, InternalServerErrorException, NotFoundException } from "@nestjs/common";
|
||||
import { Prisma, TaskPriority, TaskStatus } from "../../generated/prisma/client";
|
||||
import { PrismaService } from "../prisma/prisma.service";
|
||||
import { DataEncryptionService } from "../security/data-encryption.service";
|
||||
import { CreateTaskDto } from "./dto/create-task.dto";
|
||||
import { ListTasksQueryDto, TaskSortBy, TaskSortOrder } from "./dto/list-tasks-query.dto";
|
||||
import { UpdateTaskDto } from "./dto/update-task.dto";
|
||||
@@ -43,16 +44,48 @@ export type ListTasksResponse = {
|
||||
|
||||
@Injectable()
|
||||
export class TaskService {
|
||||
constructor(private readonly prismaService: PrismaService) {}
|
||||
constructor(
|
||||
private readonly prismaService: PrismaService,
|
||||
private readonly dataEncryptionService: DataEncryptionService
|
||||
) {}
|
||||
|
||||
async listTasks(userId: string, query: ListTasksQueryDto): Promise<ListTasksResponse> {
|
||||
const page = query.page ?? 1;
|
||||
const pageSize = query.pageSize ?? 20;
|
||||
const skip = (page - 1) * pageSize;
|
||||
const keyword = query.keyword?.trim() ?? "";
|
||||
|
||||
const where = this.buildWhereInput(userId, query);
|
||||
const where = this.buildWhereInput(userId, query, keyword.length === 0);
|
||||
const orderBy = this.buildOrderByInput(query);
|
||||
|
||||
if (keyword.length > 0) {
|
||||
const items = await this.prismaService.task.findMany({
|
||||
where,
|
||||
orderBy,
|
||||
include: {
|
||||
taskTags: {
|
||||
include: {
|
||||
tag: {
|
||||
select: {
|
||||
name: true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
const serializedItems = items.map((item: TaskEntity) => this.serializeTask(item));
|
||||
const filteredItems = serializedItems.filter((item) => this.matchesKeyword(item, keyword));
|
||||
|
||||
return {
|
||||
items: filteredItems.slice(skip, skip + pageSize),
|
||||
page,
|
||||
pageSize,
|
||||
total: filteredItems.length
|
||||
};
|
||||
}
|
||||
|
||||
const [items, total] = await Promise.all([
|
||||
this.prismaService.task.findMany({
|
||||
where,
|
||||
@@ -112,15 +145,18 @@ export class TaskService {
|
||||
const tagNames = this.normalizeTagNames(body.tagNames);
|
||||
const nextStatus = body.status ?? TaskStatus.TODO;
|
||||
const contentJson =
|
||||
body.contentJson !== undefined ? (body.contentJson as Prisma.InputJsonValue) : undefined;
|
||||
body.contentJson !== undefined
|
||||
? ((this.dataEncryptionService.encryptJson(body.contentJson as Prisma.InputJsonValue) ??
|
||||
Prisma.JsonNull) as Prisma.InputJsonValue | Prisma.NullableJsonNullValueInput)
|
||||
: undefined;
|
||||
|
||||
const task = await this.prismaService.$transaction(async (tx) => {
|
||||
const createdTask = await tx.task.create({
|
||||
data: {
|
||||
userId,
|
||||
title: body.title,
|
||||
title: this.encryptRequiredString(body.title),
|
||||
contentJson,
|
||||
contentText: body.contentText ?? null,
|
||||
contentText: this.encryptNullableString(body.contentText),
|
||||
priority: body.priority ?? TaskPriority.MEDIUM,
|
||||
status: nextStatus,
|
||||
ddl: body.ddl ? new Date(body.ddl) : null,
|
||||
@@ -172,13 +208,15 @@ export class TaskService {
|
||||
};
|
||||
|
||||
if (body.title !== undefined) {
|
||||
data.title = body.title;
|
||||
data.title = this.encryptRequiredString(body.title);
|
||||
}
|
||||
if (body.contentJson !== undefined) {
|
||||
data.contentJson = body.contentJson as Prisma.InputJsonValue;
|
||||
data.contentJson = (this.dataEncryptionService.encryptJson(
|
||||
body.contentJson as Prisma.InputJsonValue
|
||||
) ?? Prisma.JsonNull) as Prisma.InputJsonValue | Prisma.NullableJsonNullValueInput;
|
||||
}
|
||||
if (body.contentText !== undefined) {
|
||||
data.contentText = body.contentText;
|
||||
data.contentText = this.encryptNullableString(body.contentText);
|
||||
}
|
||||
if (body.priority !== undefined) {
|
||||
data.priority = body.priority;
|
||||
@@ -242,7 +280,11 @@ export class TaskService {
|
||||
return { success: true };
|
||||
}
|
||||
|
||||
private buildWhereInput(userId: string, query: ListTasksQueryDto): Prisma.TaskWhereInput {
|
||||
private buildWhereInput(
|
||||
userId: string,
|
||||
query: ListTasksQueryDto,
|
||||
includeKeyword: boolean
|
||||
): Prisma.TaskWhereInput {
|
||||
const where: Prisma.TaskWhereInput = {
|
||||
userId
|
||||
};
|
||||
@@ -267,7 +309,7 @@ export class TaskService {
|
||||
};
|
||||
}
|
||||
|
||||
if (query.keyword !== undefined && query.keyword.length > 0) {
|
||||
if (includeKeyword && query.keyword !== undefined && query.keyword.length > 0) {
|
||||
where.OR = [
|
||||
{
|
||||
title: {
|
||||
@@ -374,9 +416,9 @@ export class TaskService {
|
||||
private serializeTask(task: TaskEntity): TaskResponse {
|
||||
return {
|
||||
id: task.id,
|
||||
title: task.title,
|
||||
contentJson: task.contentJson,
|
||||
contentText: task.contentText,
|
||||
title: this.readDecryptedString(task.title) ?? "未命名任务",
|
||||
contentJson: this.dataEncryptionService.decryptJson(task.contentJson),
|
||||
contentText: this.readDecryptedString(task.contentText),
|
||||
priority: task.priority,
|
||||
status: task.status,
|
||||
ddl: task.ddl?.toISOString() ?? null,
|
||||
@@ -387,4 +429,30 @@ export class TaskService {
|
||||
updatedAt: task.updatedAt.toISOString()
|
||||
};
|
||||
}
|
||||
|
||||
private encryptRequiredString(value: string): string {
|
||||
const encryptedValue = this.dataEncryptionService.encryptString(value);
|
||||
if (!encryptedValue) {
|
||||
throw new InternalServerErrorException("任务字段加密失败");
|
||||
}
|
||||
|
||||
return encryptedValue;
|
||||
}
|
||||
|
||||
private encryptNullableString(value: string | null | undefined): string | null | undefined {
|
||||
return this.dataEncryptionService.encryptString(value);
|
||||
}
|
||||
|
||||
private readDecryptedString(value: string | null): string | null {
|
||||
const decryptedValue = this.dataEncryptionService.decryptString(value);
|
||||
return typeof decryptedValue === "string" ? decryptedValue : null;
|
||||
}
|
||||
|
||||
private matchesKeyword(task: TaskResponse, keyword: string): boolean {
|
||||
const lowerKeyword = keyword.toLocaleLowerCase();
|
||||
return (
|
||||
task.title.toLocaleLowerCase().includes(lowerKeyword) ||
|
||||
task.contentText?.toLocaleLowerCase().includes(lowerKeyword) === true
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,73 @@
|
||||
import { AiChannel } from "../generated/prisma/client";
|
||||
import { AstrbotProvider } from "../src/ai/providers/astrbot.provider";
|
||||
|
||||
describe("AstrbotProvider", () => {
|
||||
const originalFetch = global.fetch;
|
||||
|
||||
afterEach(() => {
|
||||
global.fetch = originalFetch;
|
||||
jest.restoreAllMocks();
|
||||
});
|
||||
|
||||
it("should not forward binding label fields as astrbot selection parameters", async () => {
|
||||
const provider = new AstrbotProvider();
|
||||
const fetchMock = jest.fn(async (_input: unknown, init?: RequestInit) => {
|
||||
expect(init?.method).toBe("POST");
|
||||
const payload = JSON.parse(String(init?.body ?? "{}")) as Record<string, unknown>;
|
||||
|
||||
expect(payload).toMatchObject({
|
||||
username: "user_1",
|
||||
session_id: "session_1",
|
||||
message: "你好",
|
||||
enable_streaming: false,
|
||||
selected_model: "deepseek-chat"
|
||||
});
|
||||
expect(payload).not.toHaveProperty("selected_provider");
|
||||
expect(payload).not.toHaveProperty("config_id");
|
||||
expect(payload).not.toHaveProperty("config_name");
|
||||
|
||||
return new Response(
|
||||
[
|
||||
'data: {"type":"session_id","session_id":"session_1"}',
|
||||
"",
|
||||
'data: {"type":"plain","data":"收到","streaming":false,"chain_type":null}',
|
||||
"",
|
||||
'data: {"type":"end","data":"","streaming":false}',
|
||||
""
|
||||
].join("\n"),
|
||||
{
|
||||
status: 200,
|
||||
headers: {
|
||||
"content-type": "text/event-stream"
|
||||
}
|
||||
}
|
||||
);
|
||||
});
|
||||
|
||||
global.fetch = fetchMock as typeof global.fetch;
|
||||
|
||||
const result = await provider.execute(
|
||||
{
|
||||
channel: AiChannel.ASTRBOT,
|
||||
source: "binding",
|
||||
sourceId: "binding_1",
|
||||
providerName: "astrbot-main",
|
||||
model: "deepseek-chat",
|
||||
configId: "default",
|
||||
configName: "默认配置",
|
||||
endpoint: "http://127.0.0.1:6185",
|
||||
apiKey: "abk_secret"
|
||||
},
|
||||
{
|
||||
userId: "user_1",
|
||||
message: "你好",
|
||||
sessionId: "session_1"
|
||||
}
|
||||
);
|
||||
|
||||
expect(fetchMock).toHaveBeenCalledTimes(1);
|
||||
expect(result.content).toBe("收到");
|
||||
expect(result.sessionId).toBe("session_1");
|
||||
expect(result.providerName).toBe("astrbot-main");
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,355 @@
|
||||
import { UnauthorizedException } from "@nestjs/common";
|
||||
import { ConfigService } from "@nestjs/config";
|
||||
import { JwtService } from "@nestjs/jwt";
|
||||
import { Test, TestingModule } from "@nestjs/testing";
|
||||
import { AuthMailService } from "../src/auth/auth-mail.service";
|
||||
import { AuthService } from "../src/auth/auth.service";
|
||||
import { PrismaService } from "../src/prisma/prisma.service";
|
||||
import { DataEncryptionService } from "../src/security/data-encryption.service";
|
||||
|
||||
type UserRecord = {
|
||||
id: string;
|
||||
email: string;
|
||||
emailHash: string;
|
||||
nickname: string | null;
|
||||
avatarUrl: string | null;
|
||||
};
|
||||
|
||||
type RefreshTokenRecord = {
|
||||
id: string;
|
||||
userId: string;
|
||||
tokenHash: string;
|
||||
expiresAt: Date;
|
||||
revokedAt: Date | null;
|
||||
createdAt: Date;
|
||||
};
|
||||
|
||||
type UserSecurityRecord = {
|
||||
userId: string;
|
||||
twoFactorEnabled: boolean;
|
||||
twoFactorSecret: string | null;
|
||||
};
|
||||
|
||||
class InMemoryAuthPrismaService {
|
||||
private userIdSequence = 1;
|
||||
private refreshTokenIdSequence = 1;
|
||||
private users: UserRecord[] = [];
|
||||
private refreshTokens: RefreshTokenRecord[] = [];
|
||||
private userSecurities: UserSecurityRecord[] = [];
|
||||
|
||||
readonly user = {
|
||||
upsert: async (args: {
|
||||
where: {
|
||||
emailHash: string;
|
||||
};
|
||||
update: Record<string, never>;
|
||||
create: {
|
||||
email: string;
|
||||
emailHash: string;
|
||||
};
|
||||
select: {
|
||||
id: true;
|
||||
email: true;
|
||||
};
|
||||
}) => {
|
||||
const existingUser = this.users.find((user) => user.emailHash === args.where.emailHash);
|
||||
if (existingUser) {
|
||||
return {
|
||||
id: existingUser.id,
|
||||
email: existingUser.email
|
||||
};
|
||||
}
|
||||
|
||||
const createdUser: UserRecord = {
|
||||
id: `user_${this.userIdSequence++}`,
|
||||
email: args.create.email,
|
||||
emailHash: args.create.emailHash,
|
||||
nickname: null,
|
||||
avatarUrl: null
|
||||
};
|
||||
this.users.push(createdUser);
|
||||
|
||||
return {
|
||||
id: createdUser.id,
|
||||
email: createdUser.email
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
readonly refreshToken = {
|
||||
create: async (args: {
|
||||
data: {
|
||||
userId: string;
|
||||
tokenHash: string;
|
||||
expiresAt: Date;
|
||||
};
|
||||
}) => {
|
||||
const refreshToken: RefreshTokenRecord = {
|
||||
id: `refresh_${this.refreshTokenIdSequence++}`,
|
||||
userId: args.data.userId,
|
||||
tokenHash: args.data.tokenHash,
|
||||
expiresAt: args.data.expiresAt,
|
||||
revokedAt: null,
|
||||
createdAt: new Date()
|
||||
};
|
||||
this.refreshTokens.push(refreshToken);
|
||||
return refreshToken;
|
||||
},
|
||||
|
||||
findUnique: async (args: {
|
||||
where: {
|
||||
tokenHash: string;
|
||||
};
|
||||
include: {
|
||||
user: {
|
||||
select: {
|
||||
id: true;
|
||||
email: true;
|
||||
};
|
||||
};
|
||||
};
|
||||
}) => {
|
||||
const refreshToken = this.refreshTokens.find(
|
||||
(item) => item.tokenHash === args.where.tokenHash
|
||||
);
|
||||
if (!refreshToken) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const user = this.users.find((item) => item.id === refreshToken.userId);
|
||||
if (!user) {
|
||||
throw new Error("user not found");
|
||||
}
|
||||
|
||||
return {
|
||||
...refreshToken,
|
||||
user: {
|
||||
id: user.id,
|
||||
email: user.email
|
||||
}
|
||||
};
|
||||
},
|
||||
|
||||
update: async (args: {
|
||||
where: {
|
||||
id: string;
|
||||
};
|
||||
data: {
|
||||
revokedAt: Date;
|
||||
};
|
||||
}) => {
|
||||
const refreshToken = this.refreshTokens.find((item) => item.id === args.where.id);
|
||||
if (!refreshToken) {
|
||||
throw new Error("refresh token not found");
|
||||
}
|
||||
|
||||
refreshToken.revokedAt = args.data.revokedAt;
|
||||
return refreshToken;
|
||||
},
|
||||
|
||||
updateMany: async (args: {
|
||||
where: {
|
||||
tokenHash: string;
|
||||
revokedAt: null;
|
||||
};
|
||||
data: {
|
||||
revokedAt: Date;
|
||||
};
|
||||
}) => {
|
||||
let count = 0;
|
||||
for (const refreshToken of this.refreshTokens) {
|
||||
if (refreshToken.tokenHash !== args.where.tokenHash || refreshToken.revokedAt !== null) {
|
||||
continue;
|
||||
}
|
||||
|
||||
refreshToken.revokedAt = args.data.revokedAt;
|
||||
count += 1;
|
||||
}
|
||||
|
||||
return { count };
|
||||
}
|
||||
};
|
||||
|
||||
readonly userSecurity = {
|
||||
upsert: async (args: {
|
||||
where: {
|
||||
userId: string;
|
||||
};
|
||||
update: {
|
||||
twoFactorSecret: string;
|
||||
twoFactorEnabled: boolean;
|
||||
};
|
||||
create: {
|
||||
userId: string;
|
||||
twoFactorSecret: string;
|
||||
twoFactorEnabled: boolean;
|
||||
};
|
||||
}) => {
|
||||
const existingSecurity = this.userSecurities.find(
|
||||
(item) => item.userId === args.where.userId
|
||||
);
|
||||
if (existingSecurity) {
|
||||
existingSecurity.twoFactorSecret = args.update.twoFactorSecret;
|
||||
existingSecurity.twoFactorEnabled = args.update.twoFactorEnabled;
|
||||
return existingSecurity;
|
||||
}
|
||||
|
||||
const createdSecurity: UserSecurityRecord = {
|
||||
userId: args.create.userId,
|
||||
twoFactorSecret: args.create.twoFactorSecret,
|
||||
twoFactorEnabled: args.create.twoFactorEnabled
|
||||
};
|
||||
this.userSecurities.push(createdSecurity);
|
||||
return createdSecurity;
|
||||
},
|
||||
|
||||
findUnique: async (args: {
|
||||
where: {
|
||||
userId: string;
|
||||
};
|
||||
select: {
|
||||
twoFactorSecret: true;
|
||||
};
|
||||
}) => {
|
||||
const security = this.userSecurities.find((item) => item.userId === args.where.userId);
|
||||
if (!security) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return {
|
||||
twoFactorSecret: security.twoFactorSecret
|
||||
};
|
||||
},
|
||||
|
||||
update: async (args: {
|
||||
where: {
|
||||
userId: string;
|
||||
};
|
||||
data: {
|
||||
twoFactorEnabled: boolean;
|
||||
};
|
||||
}) => {
|
||||
const security = this.userSecurities.find((item) => item.userId === args.where.userId);
|
||||
if (!security) {
|
||||
throw new Error("user security not found");
|
||||
}
|
||||
|
||||
security.twoFactorEnabled = args.data.twoFactorEnabled;
|
||||
return security;
|
||||
}
|
||||
};
|
||||
|
||||
getUsers(): UserRecord[] {
|
||||
return [...this.users];
|
||||
}
|
||||
}
|
||||
|
||||
class MockAuthMailService {
|
||||
readonly sentMessages: Array<{
|
||||
email: string;
|
||||
code: string;
|
||||
ttlSeconds: number;
|
||||
}> = [];
|
||||
|
||||
async sendLoginCode(email: string, code: string, ttlSeconds: number): Promise<void> {
|
||||
this.sentMessages.push({
|
||||
email,
|
||||
code,
|
||||
ttlSeconds
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
describe("AuthService", () => {
|
||||
let authService: AuthService;
|
||||
let prismaService: InMemoryAuthPrismaService;
|
||||
let authMailService: MockAuthMailService;
|
||||
|
||||
beforeEach(async () => {
|
||||
prismaService = new InMemoryAuthPrismaService();
|
||||
authMailService = new MockAuthMailService();
|
||||
|
||||
const moduleRef: TestingModule = await Test.createTestingModule({
|
||||
providers: [
|
||||
AuthService,
|
||||
DataEncryptionService,
|
||||
{
|
||||
provide: PrismaService,
|
||||
useValue: prismaService
|
||||
},
|
||||
{
|
||||
provide: AuthMailService,
|
||||
useValue: authMailService
|
||||
},
|
||||
{
|
||||
provide: JwtService,
|
||||
useValue: {
|
||||
signAsync: async (payload: Record<string, unknown>) =>
|
||||
`signed-${String(payload["sub"])}-${String(payload["email"])}`
|
||||
}
|
||||
},
|
||||
{
|
||||
provide: ConfigService,
|
||||
useValue: {
|
||||
get: (key: string) => {
|
||||
switch (key) {
|
||||
case "AUTH_EMAIL_CODE_TTL_SECONDS":
|
||||
return "300";
|
||||
case "AUTH_ACCESS_EXPIRES_IN_SECONDS":
|
||||
return "900";
|
||||
case "AUTH_REFRESH_EXPIRES_IN_SECONDS":
|
||||
return "2592000";
|
||||
case "AUTH_TOTP_ISSUER":
|
||||
return "TodoList";
|
||||
case "DATA_ENCRYPTION_SECRET":
|
||||
return "test-data-encryption-secret";
|
||||
default:
|
||||
return undefined;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}).compile();
|
||||
|
||||
authService = moduleRef.get(AuthService);
|
||||
});
|
||||
|
||||
it("should encrypt user email in database while keeping login flow available", async () => {
|
||||
await authService.sendEmailCode("User@Example.com");
|
||||
expect(authMailService.sentMessages).toHaveLength(1);
|
||||
expect(authMailService.sentMessages[0]?.email).toBe("user@example.com");
|
||||
|
||||
const loginResult = await authService.loginWithEmailCode(
|
||||
"USER@example.com",
|
||||
authMailService.sentMessages[0]?.code ?? ""
|
||||
);
|
||||
|
||||
expect(loginResult.user.email).toBe("user@example.com");
|
||||
expect(loginResult.accessToken).toContain("user@example.com");
|
||||
|
||||
const storedUser = prismaService.getUsers()[0];
|
||||
expect(storedUser?.email).not.toBe("user@example.com");
|
||||
expect(storedUser?.emailHash).toMatch(/^[a-f0-9]{64}$/);
|
||||
});
|
||||
|
||||
it("should decrypt user email when refreshing token", async () => {
|
||||
await authService.sendEmailCode("refresh@example.com");
|
||||
const loginResult = await authService.loginWithEmailCode(
|
||||
"refresh@example.com",
|
||||
authMailService.sentMessages[0]?.code ?? ""
|
||||
);
|
||||
|
||||
const refreshResult = await authService.refreshTokens(loginResult.refreshToken);
|
||||
expect(refreshResult.user.email).toBe("refresh@example.com");
|
||||
expect(refreshResult.accessToken).toContain("refresh@example.com");
|
||||
});
|
||||
|
||||
it("should reject invalid verification code", async () => {
|
||||
await authService.sendEmailCode("invalid@example.com");
|
||||
|
||||
await expect(
|
||||
authService.loginWithEmailCode("invalid@example.com", "000000")
|
||||
).rejects.toBeInstanceOf(UnauthorizedException);
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,80 @@
|
||||
import { AiChannel } from "../generated/prisma/client";
|
||||
import { OpenAiCompatibleProvider } from "../src/ai/providers/openai-compatible.provider";
|
||||
|
||||
describe("OpenAiCompatibleProvider", () => {
|
||||
const originalFetch = global.fetch;
|
||||
|
||||
afterEach(() => {
|
||||
global.fetch = originalFetch;
|
||||
jest.restoreAllMocks();
|
||||
});
|
||||
|
||||
it("should read text from responses style payload when chat content is empty", async () => {
|
||||
const provider = new OpenAiCompatibleProvider();
|
||||
const fetchMock = jest.fn(async (_input: unknown, init?: RequestInit) => {
|
||||
expect(init?.method).toBe("POST");
|
||||
|
||||
return new Response(
|
||||
JSON.stringify({
|
||||
id: "resp_123",
|
||||
object: "response",
|
||||
model: "gpt-5.4",
|
||||
output: [
|
||||
{
|
||||
id: "msg_123",
|
||||
type: "message",
|
||||
role: "assistant",
|
||||
content: [
|
||||
{
|
||||
type: "output_text",
|
||||
text: "今天优先先完成截止时间最近的任务。"
|
||||
}
|
||||
]
|
||||
}
|
||||
],
|
||||
usage: {
|
||||
prompt_tokens: 15,
|
||||
completion_tokens: 9,
|
||||
total_tokens: 24
|
||||
}
|
||||
}),
|
||||
{
|
||||
status: 200,
|
||||
headers: {
|
||||
"content-type": "application/json"
|
||||
}
|
||||
}
|
||||
);
|
||||
});
|
||||
|
||||
global.fetch = fetchMock as typeof global.fetch;
|
||||
|
||||
const result = await provider.execute(
|
||||
{
|
||||
channel: AiChannel.USER_KEY,
|
||||
source: "binding",
|
||||
sourceId: "binding_user_key_1",
|
||||
providerName: "airouter",
|
||||
model: "gpt-5.4",
|
||||
configId: null,
|
||||
configName: null,
|
||||
endpoint: "https://api.airouter.io/v1",
|
||||
apiKey: "sk_test"
|
||||
},
|
||||
{
|
||||
userId: "user_1",
|
||||
message: "帮我安排今天的任务",
|
||||
sessionId: null
|
||||
}
|
||||
);
|
||||
|
||||
expect(fetchMock).toHaveBeenCalledTimes(1);
|
||||
expect(result.content).toBe("今天优先先完成截止时间最近的任务。");
|
||||
expect(result.model).toBe("gpt-5.4");
|
||||
expect(result.usage).toEqual({
|
||||
promptTokens: 15,
|
||||
completionTokens: 9,
|
||||
totalTokens: 24
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -1,7 +1,9 @@
|
||||
import request from "supertest";
|
||||
import { INestApplication, ValidationPipe } from "@nestjs/common";
|
||||
import { ConfigService } from "@nestjs/config";
|
||||
import { Test, TestingModule } from "@nestjs/testing";
|
||||
import { PrismaService } from "../src/prisma/prisma.service";
|
||||
import { DataEncryptionService } from "../src/security/data-encryption.service";
|
||||
import { SyncController } from "../src/sync/sync.controller";
|
||||
import { SyncService } from "../src/sync/sync.service";
|
||||
|
||||
@@ -159,6 +161,10 @@ class InMemoryPrismaService {
|
||||
return this.syncOperations.length;
|
||||
}
|
||||
|
||||
getRawOperationById(opId: string): SyncOperationRecord | undefined {
|
||||
return this.syncOperations.find((operation) => operation.opId === opId);
|
||||
}
|
||||
|
||||
seedOperations(records: Array<Omit<SyncOperationRecord, "id">>): void {
|
||||
for (const record of records) {
|
||||
this.syncOperations.push({
|
||||
@@ -196,7 +202,18 @@ describe("SyncController (integration)", () => {
|
||||
|
||||
const moduleRef: TestingModule = await Test.createTestingModule({
|
||||
controllers: [SyncController],
|
||||
providers: [SyncService, { provide: PrismaService, useValue: prismaService }]
|
||||
providers: [
|
||||
SyncService,
|
||||
DataEncryptionService,
|
||||
{ provide: PrismaService, useValue: prismaService },
|
||||
{
|
||||
provide: ConfigService,
|
||||
useValue: {
|
||||
get: (key: string) =>
|
||||
key === "DATA_ENCRYPTION_SECRET" ? "test-data-encryption-secret" : undefined
|
||||
}
|
||||
}
|
||||
]
|
||||
}).compile();
|
||||
|
||||
app = moduleRef.createNestApplication();
|
||||
@@ -258,6 +275,9 @@ describe("SyncController (integration)", () => {
|
||||
})
|
||||
]);
|
||||
expect(prismaService.getOperationCount()).toBe(2);
|
||||
expect(prismaService.getRawOperationById("op-create-1")?.payload).not.toBe(
|
||||
'{"title":"浠诲姟涓€"}'
|
||||
);
|
||||
|
||||
const secondResponse = await request(app.getHttpServer())
|
||||
.post("/sync/push")
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
import request from "supertest";
|
||||
import { INestApplication, ValidationPipe } from "@nestjs/common";
|
||||
import { ConfigService } from "@nestjs/config";
|
||||
import { Test, TestingModule } from "@nestjs/testing";
|
||||
import { PrismaService } from "../src/prisma/prisma.service";
|
||||
import { DataEncryptionService } from "../src/security/data-encryption.service";
|
||||
import { TaskController } from "../src/task/task.controller";
|
||||
import { TaskService } from "../src/task/task.service";
|
||||
import { TaskPriority, TaskStatus } from "../generated/prisma/client";
|
||||
@@ -355,6 +357,10 @@ class InMemoryPrismaService {
|
||||
return runner(this);
|
||||
}
|
||||
|
||||
getRawTaskById(taskId: string): TaskRecord | undefined {
|
||||
return this.tasks.find((task) => task.id === taskId);
|
||||
}
|
||||
|
||||
private toTaskWithTags(
|
||||
task: TaskRecord
|
||||
): TaskRecord & { taskTags: Array<{ tag: { name: string } }> } {
|
||||
@@ -390,7 +396,15 @@ describe("TaskController (integration)", () => {
|
||||
controllers: [TaskController],
|
||||
providers: [
|
||||
TaskService,
|
||||
{ provide: PrismaService, useValue: prismaService as unknown as PrismaService }
|
||||
DataEncryptionService,
|
||||
{ provide: PrismaService, useValue: prismaService as unknown as PrismaService },
|
||||
{
|
||||
provide: ConfigService,
|
||||
useValue: {
|
||||
get: (key: string) =>
|
||||
key === "DATA_ENCRYPTION_SECRET" ? "test-data-encryption-secret" : undefined
|
||||
}
|
||||
}
|
||||
]
|
||||
}).compile();
|
||||
|
||||
@@ -425,6 +439,9 @@ describe("TaskController (integration)", () => {
|
||||
expect(createResponse.body.id).toBeDefined();
|
||||
expect(createResponse.body.tags).toEqual(["工作", "会议"]);
|
||||
const taskId = createResponse.body.id as string;
|
||||
const rawCreatedTask = prismaService.getRawTaskById(taskId);
|
||||
expect(rawCreatedTask?.title).not.toBe("准备周会");
|
||||
expect(rawCreatedTask?.contentText).not.toBe("整理本周进度");
|
||||
|
||||
const listResponse = await request(app.getHttpServer())
|
||||
.get("/tasks")
|
||||
|
||||
@@ -5,6 +5,6 @@
|
||||
"rootDir": ".",
|
||||
"outDir": "dist"
|
||||
},
|
||||
"include": ["src/**/*.ts", "generated/prisma/**/*.ts"],
|
||||
"include": ["src/**/*.ts", "scripts/**/*.ts", "generated/prisma/**/*.ts"],
|
||||
"exclude": ["dist", "node_modules"]
|
||||
}
|
||||
|
||||
+77
-9
@@ -17,8 +17,11 @@ import {
|
||||
import { Navigate, Route, Routes, useLocation, useNavigate } from "react-router-dom";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { AiChatPage } from "@/pages/ai-chat-page";
|
||||
import { EmailLoginPage } from "@/pages/email-login-page";
|
||||
import { OAuthCallbackPage } from "@/pages/oauth-callback-page";
|
||||
import { PlaceholderPage } from "@/pages/placeholder-page";
|
||||
import { SettingsPage } from "@/pages/settings-page";
|
||||
import { TodoShellPage } from "@/pages/todo-shell-page";
|
||||
import { revokeRefreshToken, type EmailLoginResult } from "@/services/auth-api";
|
||||
import {
|
||||
@@ -38,16 +41,19 @@ type SidebarItem = {
|
||||
key: string;
|
||||
label: string;
|
||||
icon: LucideIcon;
|
||||
path: string;
|
||||
};
|
||||
|
||||
const SIDEBAR_ITEMS: SidebarItem[] = [
|
||||
{ key: "dashboard", label: "概览面板", icon: LayoutDashboard },
|
||||
{ key: "todo", label: "待办事项", icon: ListTodo },
|
||||
{ key: "ai", label: "AI 建议", icon: Sparkles },
|
||||
{ key: "notice", label: "提醒中心", icon: Bell },
|
||||
{ key: "settings", label: "系统设置", icon: Settings }
|
||||
{ key: "dashboard", label: "概览面板", icon: LayoutDashboard, path: "/dashboard" },
|
||||
{ key: "todo", label: "待办事项", icon: ListTodo, path: "/todo" },
|
||||
{ key: "ai", label: "AI 助手", icon: Sparkles, path: "/ai" },
|
||||
{ key: "notice", label: "提醒中心", icon: Bell, path: "/notice" },
|
||||
{ key: "settings", label: "系统设置", icon: Settings, path: "/settings" }
|
||||
];
|
||||
|
||||
const READY_SIDEBAR_KEYS = new Set(["todo", "ai", "settings"]);
|
||||
|
||||
function toWebSession(payload: EmailLoginResult): WebSession {
|
||||
return {
|
||||
accessToken: payload.accessToken,
|
||||
@@ -104,7 +110,7 @@ function App() {
|
||||
saveSession(nextSession);
|
||||
setSession(nextSession);
|
||||
setMobileSidebarOpen(false);
|
||||
navigate("/", { replace: true });
|
||||
navigate("/todo", { replace: true });
|
||||
}
|
||||
|
||||
function handleBootstrapSession(nextSession: WebSession): void {
|
||||
@@ -136,14 +142,21 @@ function App() {
|
||||
<nav className="space-y-1">
|
||||
{SIDEBAR_ITEMS.map((item) => {
|
||||
const ItemIcon = item.icon;
|
||||
const isActive =
|
||||
location.pathname === item.path || location.pathname.startsWith(`${item.path}/`);
|
||||
return (
|
||||
<button
|
||||
key={item.key}
|
||||
type="button"
|
||||
className={cn(
|
||||
"group flex w-full items-center rounded-xl border border-transparent px-3 py-2.5 text-left transition-colors",
|
||||
"gap-3 hover:border-primary/25 hover:bg-primary/10"
|
||||
"gap-3 hover:border-primary/25 hover:bg-primary/10",
|
||||
isActive ? "border-primary/25 bg-primary/10" : null
|
||||
)}
|
||||
onClick={() => {
|
||||
navigate(item.path);
|
||||
setMobileSidebarOpen(false);
|
||||
}}
|
||||
>
|
||||
<ItemIcon className="size-5 shrink-0 text-primary" />
|
||||
{collapsed ? null : (
|
||||
@@ -151,9 +164,11 @@ function App() {
|
||||
<span className="text-sm whitespace-nowrap text-foreground">
|
||||
{item.label}
|
||||
</span>
|
||||
{READY_SIDEBAR_KEYS.has(item.key) ? null : (
|
||||
<span className="ml-auto whitespace-nowrap rounded-full border border-border bg-card px-2 py-0.5 text-[10px] text-muted-foreground">
|
||||
即将上线
|
||||
</span>
|
||||
)}
|
||||
</>
|
||||
)}
|
||||
</button>
|
||||
@@ -208,7 +223,10 @@ function App() {
|
||||
path="/auth/callback/:provider"
|
||||
element={<OAuthCallbackPage onBootstrapSession={handleBootstrapSession} />}
|
||||
/>
|
||||
<Route path="*" element={<Navigate to={session ? "/" : "/login/email"} replace />} />
|
||||
<Route
|
||||
path="*"
|
||||
element={<Navigate to={session ? "/todo" : "/login/email"} replace />}
|
||||
/>
|
||||
</Routes>
|
||||
</div>
|
||||
</main>
|
||||
@@ -294,6 +312,23 @@ function App() {
|
||||
<Routes>
|
||||
<Route
|
||||
path="/"
|
||||
element={<Navigate to={session ? "/todo" : "/login/email"} replace />}
|
||||
/>
|
||||
<Route
|
||||
path="/dashboard"
|
||||
element={
|
||||
session ? (
|
||||
<PlaceholderPage
|
||||
title="概览面板正在整理"
|
||||
description="这里后续会放任务统计、今日重点、AI 使用概况和提醒概览。当前先把导航和页面结构拆清楚。"
|
||||
/>
|
||||
) : (
|
||||
<Navigate to="/login/email" replace />
|
||||
)
|
||||
}
|
||||
/>
|
||||
<Route
|
||||
path="/todo"
|
||||
element={
|
||||
session ? (
|
||||
<TodoShellPage session={session} />
|
||||
@@ -302,9 +337,42 @@ function App() {
|
||||
)
|
||||
}
|
||||
/>
|
||||
<Route
|
||||
path="/ai"
|
||||
element={
|
||||
session ? (
|
||||
<AiChatPage session={session} />
|
||||
) : (
|
||||
<Navigate to="/login/email" replace />
|
||||
)
|
||||
}
|
||||
/>
|
||||
<Route
|
||||
path="/notice"
|
||||
element={
|
||||
session ? (
|
||||
<PlaceholderPage
|
||||
title="提醒中心即将接入"
|
||||
description="邮件提醒、Web Push 推送、任务到期前通知都会独立收敛到这里,而不是继续堆在任务页里。"
|
||||
/>
|
||||
) : (
|
||||
<Navigate to="/login/email" replace />
|
||||
)
|
||||
}
|
||||
/>
|
||||
<Route
|
||||
path="/settings"
|
||||
element={
|
||||
session ? (
|
||||
<SettingsPage session={session} />
|
||||
) : (
|
||||
<Navigate to="/login/email" replace />
|
||||
)
|
||||
}
|
||||
/>
|
||||
<Route
|
||||
path="*"
|
||||
element={<Navigate to={session ? "/" : "/login/email"} replace />}
|
||||
element={<Navigate to={session ? "/todo" : "/login/email"} replace />}
|
||||
/>
|
||||
</Routes>
|
||||
</div>
|
||||
|
||||
@@ -0,0 +1,72 @@
|
||||
import type { UpsertWebAiBindingInput, WebAiBindingSummary, WebAiChannel } from "@/services/ai-api";
|
||||
|
||||
export type AiBindingFormState = {
|
||||
providerName: string;
|
||||
model: string;
|
||||
endpoint: string;
|
||||
apiKey: string;
|
||||
configId: string;
|
||||
configName: string;
|
||||
isEnabled: boolean;
|
||||
};
|
||||
|
||||
export const CHANNEL_ORDER: WebAiChannel[] = ["USER_KEY", "ASTRBOT", "PUBLIC_POOL"];
|
||||
|
||||
export const CHANNEL_META: Record<
|
||||
WebAiChannel,
|
||||
{
|
||||
title: string;
|
||||
description: string;
|
||||
accentClassName: string;
|
||||
}
|
||||
> = {
|
||||
USER_KEY: {
|
||||
title: "自备厂商",
|
||||
description: "用户自行接入 OpenAI-Compatible 服务",
|
||||
accentClassName: "from-sky-500/15 via-transparent to-sky-500/5"
|
||||
},
|
||||
ASTRBOT: {
|
||||
title: "AstrBot",
|
||||
description: "复用你在 AstrBot 中维护的模型配置",
|
||||
accentClassName: "from-amber-500/15 via-transparent to-amber-500/5"
|
||||
},
|
||||
PUBLIC_POOL: {
|
||||
title: "公共 AI",
|
||||
description: "使用管理员开放的站点公共通道",
|
||||
accentClassName: "from-emerald-500/15 via-transparent to-emerald-500/5"
|
||||
}
|
||||
};
|
||||
|
||||
export function createAiBindingFormState(binding?: WebAiBindingSummary | null): AiBindingFormState {
|
||||
return {
|
||||
providerName: binding?.providerName ?? "",
|
||||
model: binding?.model ?? "",
|
||||
endpoint: binding?.endpoint ?? "",
|
||||
apiKey: "",
|
||||
configId: binding?.configId ?? "",
|
||||
configName: binding?.configName ?? "",
|
||||
isEnabled: binding?.isEnabled ?? true
|
||||
};
|
||||
}
|
||||
|
||||
export function trimAiOptionalValue(value: string): string | undefined {
|
||||
const normalized = value.trim();
|
||||
return normalized.length > 0 ? normalized : undefined;
|
||||
}
|
||||
|
||||
export function buildAiBindingPayload(
|
||||
channel: Exclude<WebAiChannel, "PUBLIC_POOL">,
|
||||
formState: AiBindingFormState,
|
||||
currentBinding: WebAiBindingSummary | null
|
||||
): UpsertWebAiBindingInput {
|
||||
return {
|
||||
channel,
|
||||
providerName: trimAiOptionalValue(formState.providerName),
|
||||
model: trimAiOptionalValue(formState.model),
|
||||
endpoint: trimAiOptionalValue(formState.endpoint),
|
||||
configId: trimAiOptionalValue(formState.configId),
|
||||
configName: trimAiOptionalValue(formState.configName),
|
||||
apiKey: trimAiOptionalValue(formState.apiKey) ?? undefined,
|
||||
isEnabled: formState.isEnabled ?? currentBinding?.isEnabled ?? true
|
||||
};
|
||||
}
|
||||
@@ -0,0 +1,539 @@
|
||||
import { useCallback, useEffect, useMemo, useRef, useState } from "react";
|
||||
import type { KeyboardEvent } from "react";
|
||||
import {
|
||||
Bot,
|
||||
CircleAlert,
|
||||
Globe2,
|
||||
KeyRound,
|
||||
LoaderCircle,
|
||||
PlugZap,
|
||||
SendHorizontal
|
||||
} from "lucide-react";
|
||||
import { useNavigate } from "react-router-dom";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { cn } from "@/lib/utils";
|
||||
import {
|
||||
chatWithAi,
|
||||
listAiBindings,
|
||||
type WebAiBindingSummary,
|
||||
type WebAiBindingsResponse,
|
||||
type WebAiChannel,
|
||||
type WebAiLocalTaskContextItem,
|
||||
WebAiApiError
|
||||
} from "@/services/ai-api";
|
||||
import {
|
||||
deleteLocalAiChatSession,
|
||||
listLocalAiChatSessions,
|
||||
saveLocalAiChatSession,
|
||||
type LocalAiChatMessageRecord
|
||||
} from "@/services/local-ai-chat-repo";
|
||||
import { listLocalTasksByUser } from "@/services/local-task-repo";
|
||||
import type { WebSession } from "@/services/session-storage";
|
||||
import { CHANNEL_META, CHANNEL_ORDER } from "@/components/ai/ai-shared";
|
||||
|
||||
type AiChatPageProps = {
|
||||
session: WebSession;
|
||||
};
|
||||
|
||||
type AiMessageRecord = LocalAiChatMessageRecord;
|
||||
|
||||
function createEmptyMessages(): Record<WebAiChannel, AiMessageRecord[]> {
|
||||
return {
|
||||
USER_KEY: [],
|
||||
ASTRBOT: [],
|
||||
PUBLIC_POOL: []
|
||||
};
|
||||
}
|
||||
|
||||
function createEmptySessionIds(): Partial<Record<WebAiChannel, string>> {
|
||||
return {};
|
||||
}
|
||||
|
||||
function formatTimeLabel(date = new Date()): string {
|
||||
return date.toLocaleTimeString("zh-CN", {
|
||||
hour: "2-digit",
|
||||
minute: "2-digit"
|
||||
});
|
||||
}
|
||||
|
||||
function appendMessage(
|
||||
records: Record<WebAiChannel, AiMessageRecord[]>,
|
||||
channel: WebAiChannel,
|
||||
message: AiMessageRecord
|
||||
): Record<WebAiChannel, AiMessageRecord[]> {
|
||||
return {
|
||||
...records,
|
||||
[channel]: [...records[channel], message]
|
||||
};
|
||||
}
|
||||
|
||||
function buildLocalTaskContext(
|
||||
items: Awaited<ReturnType<typeof listLocalTasksByUser>>
|
||||
): WebAiLocalTaskContextItem[] {
|
||||
return items
|
||||
.filter((item) => item.status === "TODO" || item.status === "IN_PROGRESS")
|
||||
.slice(0, 20)
|
||||
.map((item) => ({
|
||||
id: item.id,
|
||||
title: item.title,
|
||||
priority: item.priority,
|
||||
status: item.status,
|
||||
ddlAt: item.ddlAt,
|
||||
contentText: item.contentText,
|
||||
updatedAt: item.updatedAt
|
||||
}));
|
||||
}
|
||||
|
||||
export function AiChatPage({ session }: AiChatPageProps) {
|
||||
const navigate = useNavigate();
|
||||
const [bindingsResponse, setBindingsResponse] = useState<WebAiBindingsResponse | null>(null);
|
||||
const [loadingBindings, setLoadingBindings] = useState(true);
|
||||
const [refreshingBindings, setRefreshingBindings] = useState(false);
|
||||
const [activeChannel, setActiveChannel] = useState<WebAiChannel>("USER_KEY");
|
||||
const [messagesByChannel, setMessagesByChannel] = useState<
|
||||
Record<WebAiChannel, AiMessageRecord[]>
|
||||
>(() => createEmptyMessages());
|
||||
const [sessionIds, setSessionIds] = useState<Partial<Record<WebAiChannel, string>>>(() =>
|
||||
createEmptySessionIds()
|
||||
);
|
||||
const [draftMessage, setDraftMessage] = useState("");
|
||||
const [sending, setSending] = useState(false);
|
||||
const [loadError, setLoadError] = useState<string | null>(null);
|
||||
const [historyLoaded, setHistoryLoaded] = useState(false);
|
||||
const messagesEndRef = useRef<HTMLDivElement | null>(null);
|
||||
|
||||
const bindingMap = useMemo(() => {
|
||||
const map = new Map<WebAiChannel, WebAiBindingSummary>();
|
||||
for (const binding of bindingsResponse?.bindings ?? []) {
|
||||
map.set(binding.channel, binding);
|
||||
}
|
||||
return map;
|
||||
}, [bindingsResponse]);
|
||||
|
||||
const currentBinding =
|
||||
activeChannel === "PUBLIC_POOL" ? null : (bindingMap.get(activeChannel) ?? null);
|
||||
const publicPool = bindingsResponse?.publicPool ?? null;
|
||||
const currentMessages = messagesByChannel[activeChannel];
|
||||
|
||||
const loadBindings = useCallback(async (): Promise<void> => {
|
||||
setRefreshingBindings(true);
|
||||
setLoadError(null);
|
||||
|
||||
try {
|
||||
const response = await listAiBindings(session);
|
||||
setBindingsResponse(response);
|
||||
} catch (error) {
|
||||
setLoadError(error instanceof Error ? error.message : "AI 配置加载失败");
|
||||
} finally {
|
||||
setLoadingBindings(false);
|
||||
setRefreshingBindings(false);
|
||||
}
|
||||
}, [session]);
|
||||
|
||||
useEffect(() => {
|
||||
void loadBindings();
|
||||
}, [loadBindings]);
|
||||
|
||||
useEffect(() => {
|
||||
let cancelled = false;
|
||||
|
||||
async function loadLocalHistory(): Promise<void> {
|
||||
try {
|
||||
const records = await listLocalAiChatSessions(session.user.id);
|
||||
if (cancelled) {
|
||||
return;
|
||||
}
|
||||
|
||||
const nextMessages = createEmptyMessages();
|
||||
const nextSessionIds = createEmptySessionIds();
|
||||
|
||||
for (const record of records) {
|
||||
nextMessages[record.channel] = record.messages;
|
||||
if (record.sessionId) {
|
||||
nextSessionIds[record.channel] = record.sessionId;
|
||||
}
|
||||
}
|
||||
|
||||
setMessagesByChannel(nextMessages);
|
||||
setSessionIds(nextSessionIds);
|
||||
} finally {
|
||||
if (!cancelled) {
|
||||
setHistoryLoaded(true);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
setHistoryLoaded(false);
|
||||
void loadLocalHistory();
|
||||
|
||||
return () => {
|
||||
cancelled = true;
|
||||
};
|
||||
}, [session.user.id]);
|
||||
|
||||
useEffect(() => {
|
||||
if (!historyLoaded) {
|
||||
return;
|
||||
}
|
||||
|
||||
void Promise.all(
|
||||
CHANNEL_ORDER.map(async (channel) => {
|
||||
const messages = messagesByChannel[channel];
|
||||
const sessionId = sessionIds[channel] ?? null;
|
||||
|
||||
if (messages.length === 0 && sessionId === null) {
|
||||
await deleteLocalAiChatSession(session.user.id, channel);
|
||||
return;
|
||||
}
|
||||
|
||||
await saveLocalAiChatSession({
|
||||
userId: session.user.id,
|
||||
channel,
|
||||
sessionId,
|
||||
messages
|
||||
});
|
||||
})
|
||||
);
|
||||
}, [historyLoaded, messagesByChannel, session.user.id, sessionIds]);
|
||||
|
||||
useEffect(() => {
|
||||
messagesEndRef.current?.scrollIntoView({
|
||||
block: "end",
|
||||
behavior: "smooth"
|
||||
});
|
||||
}, [activeChannel, currentMessages.length]);
|
||||
|
||||
const sendBlockedReason = useMemo(() => {
|
||||
if (activeChannel === "PUBLIC_POOL") {
|
||||
if (!publicPool?.enabled) {
|
||||
return "管理员尚未开放公共 AI。";
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
|
||||
if (!currentBinding) {
|
||||
return activeChannel === "USER_KEY"
|
||||
? "你还没有配置自备厂商,请先前往系统设置 > AI 配置。"
|
||||
: "你还没有配置 AstrBot,请先前往系统设置 > AI 配置。";
|
||||
}
|
||||
|
||||
if (!currentBinding.isEnabled) {
|
||||
return "当前渠道已关闭,请先在系统设置 > AI 配置中启用。";
|
||||
}
|
||||
|
||||
return null;
|
||||
}, [activeChannel, currentBinding, publicPool]);
|
||||
|
||||
async function handleSendMessage(): Promise<void> {
|
||||
const message = draftMessage.trim();
|
||||
if (!message || sendBlockedReason || sending) {
|
||||
return;
|
||||
}
|
||||
|
||||
const channel = activeChannel;
|
||||
setSending(true);
|
||||
setDraftMessage("");
|
||||
setMessagesByChannel((current) =>
|
||||
appendMessage(current, channel, {
|
||||
id: crypto.randomUUID(),
|
||||
role: "user",
|
||||
content: message,
|
||||
meta: formatTimeLabel()
|
||||
})
|
||||
);
|
||||
|
||||
try {
|
||||
const localTasks = buildLocalTaskContext(await listLocalTasksByUser(session.user.id));
|
||||
const response = await chatWithAi(session, {
|
||||
channel,
|
||||
message,
|
||||
sessionId: sessionIds[channel],
|
||||
localTasks
|
||||
});
|
||||
|
||||
setSessionIds((current) => ({
|
||||
...current,
|
||||
[channel]: response.sessionId ?? current[channel]
|
||||
}));
|
||||
setMessagesByChannel((current) =>
|
||||
appendMessage(current, channel, {
|
||||
id: crypto.randomUUID(),
|
||||
role: "assistant",
|
||||
content: response.content,
|
||||
meta: `${CHANNEL_META[response.channel].title} · ${response.providerName}${response.model ? ` · ${response.model}` : ""}`
|
||||
})
|
||||
);
|
||||
} catch (error) {
|
||||
const apiError =
|
||||
error instanceof WebAiApiError
|
||||
? error
|
||||
: new WebAiApiError(error instanceof Error ? error.message : "AI 请求失败");
|
||||
const firstAttempt = apiError.attempts?.find((item) => item.reasonMessage);
|
||||
const content =
|
||||
firstAttempt?.reasonMessage && firstAttempt.reasonMessage !== apiError.message
|
||||
? `${apiError.message}\n${firstAttempt.reasonMessage}`
|
||||
: apiError.message;
|
||||
|
||||
setMessagesByChannel((current) =>
|
||||
appendMessage(current, channel, {
|
||||
id: crypto.randomUUID(),
|
||||
role: "system",
|
||||
content,
|
||||
meta: "调用失败"
|
||||
})
|
||||
);
|
||||
} finally {
|
||||
setSending(false);
|
||||
}
|
||||
}
|
||||
|
||||
function handleDraftKeyDown(event: KeyboardEvent<HTMLTextAreaElement>): void {
|
||||
if (event.key !== "Enter" || event.shiftKey || event.nativeEvent.isComposing) {
|
||||
return;
|
||||
}
|
||||
|
||||
event.preventDefault();
|
||||
void handleSendMessage();
|
||||
}
|
||||
|
||||
return (
|
||||
<section className="space-y-4">
|
||||
<div className="rounded-[2rem] border border-border/70 bg-card/92 p-6 shadow-[0_24px_80px_-48px_rgba(15,23,42,0.55)]">
|
||||
<div className="flex flex-col gap-4 lg:flex-row lg:items-end lg:justify-between">
|
||||
<div>
|
||||
<div className="flex items-center gap-2 text-sm font-medium text-primary">
|
||||
<Bot className="size-4" />
|
||||
AI 助手
|
||||
</div>
|
||||
<h1 className="mt-2 text-2xl font-semibold tracking-tight text-foreground">
|
||||
在独立页面中发起 AI 对话
|
||||
</h1>
|
||||
<p className="mt-2 text-sm leading-7 text-muted-foreground">
|
||||
聊天页面只负责问答和任务统筹。所有渠道配置统一放在系统设置中的 AI 配置页面。
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<div className="flex gap-3">
|
||||
<Button type="button" variant="outline" onClick={() => navigate("/settings")}>
|
||||
前往 AI 配置
|
||||
</Button>
|
||||
<Button
|
||||
type="button"
|
||||
variant="outline"
|
||||
onClick={() => void loadBindings()}
|
||||
disabled={refreshingBindings}
|
||||
>
|
||||
{refreshingBindings ? (
|
||||
<>
|
||||
<LoaderCircle className="size-4 animate-spin" />
|
||||
刷新中
|
||||
</>
|
||||
) : (
|
||||
"刷新状态"
|
||||
)}
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="grid gap-4 xl:grid-cols-[320px_minmax(0,1fr)]">
|
||||
<aside className="space-y-4 rounded-[2rem] border border-border/70 bg-card/92 p-4 shadow-[0_24px_80px_-48px_rgba(15,23,42,0.55)]">
|
||||
<div>
|
||||
<div className="text-sm font-semibold text-foreground">选择渠道</div>
|
||||
<div className="mt-1 text-xs leading-6 text-muted-foreground">
|
||||
前端只会使用你当前明确选中的那一个渠道。
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="space-y-2">
|
||||
{CHANNEL_ORDER.map((channel) => {
|
||||
const selected = activeChannel === channel;
|
||||
const binding = channel === "PUBLIC_POOL" ? null : (bindingMap.get(channel) ?? null);
|
||||
const enabled =
|
||||
channel === "PUBLIC_POOL"
|
||||
? Boolean(publicPool?.enabled)
|
||||
: Boolean(binding?.isEnabled);
|
||||
const statusLabel =
|
||||
channel === "PUBLIC_POOL"
|
||||
? publicPool?.enabled
|
||||
? "可使用"
|
||||
: "未开放"
|
||||
: binding
|
||||
? enabled
|
||||
? "已启用"
|
||||
: "已停用"
|
||||
: "未配置";
|
||||
const Icon =
|
||||
channel === "PUBLIC_POOL" ? Globe2 : channel === "ASTRBOT" ? PlugZap : KeyRound;
|
||||
|
||||
return (
|
||||
<button
|
||||
key={channel}
|
||||
type="button"
|
||||
className={cn(
|
||||
"w-full rounded-2xl border bg-gradient-to-br px-3 py-3 text-left transition-all",
|
||||
CHANNEL_META[channel].accentClassName,
|
||||
selected
|
||||
? "border-primary/45 ring-2 ring-primary/15"
|
||||
: "border-border/70 hover:border-primary/25 hover:bg-muted/35"
|
||||
)}
|
||||
onClick={() => setActiveChannel(channel)}
|
||||
>
|
||||
<div className="flex items-start justify-between gap-3">
|
||||
<div className="flex items-start gap-3">
|
||||
<span className="rounded-xl bg-background/85 p-2 text-primary shadow-sm">
|
||||
<Icon className="size-4" />
|
||||
</span>
|
||||
<div>
|
||||
<div className="text-sm font-semibold text-foreground">
|
||||
{CHANNEL_META[channel].title}
|
||||
</div>
|
||||
<div className="mt-1 text-xs leading-5 text-muted-foreground">
|
||||
{CHANNEL_META[channel].description}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<span
|
||||
className={cn(
|
||||
"rounded-full border px-2 py-0.5 text-[11px] font-medium",
|
||||
enabled
|
||||
? "border-emerald-500/25 bg-emerald-500/10 text-emerald-700 dark:text-emerald-300"
|
||||
: "border-border bg-background text-muted-foreground"
|
||||
)}
|
||||
>
|
||||
{statusLabel}
|
||||
</span>
|
||||
</div>
|
||||
</button>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
|
||||
{loadError ? (
|
||||
<div className="rounded-2xl border border-destructive/15 bg-destructive/8 px-3 py-2 text-sm text-destructive">
|
||||
{loadError}
|
||||
</div>
|
||||
) : null}
|
||||
|
||||
<div className="rounded-2xl border border-border/70 bg-background/80 px-3 py-3 text-xs leading-6 text-muted-foreground">
|
||||
<div className="font-medium text-foreground">当前渠道状态</div>
|
||||
<div className="mt-1">
|
||||
{loadingBindings
|
||||
? "正在加载配置..."
|
||||
: activeChannel === "PUBLIC_POOL"
|
||||
? publicPool?.enabled
|
||||
? "公共 AI 已开放,可直接发送。"
|
||||
: "公共 AI 未开放。"
|
||||
: currentBinding
|
||||
? currentBinding.isEnabled
|
||||
? "已配置并启用。"
|
||||
: "已配置,但当前关闭。"
|
||||
: "尚未配置。"}
|
||||
</div>
|
||||
</div>
|
||||
</aside>
|
||||
|
||||
<div className="flex min-h-[720px] flex-col overflow-hidden rounded-[2rem] border border-border/70 bg-card/92 shadow-[0_24px_80px_-48px_rgba(15,23,42,0.55)]">
|
||||
<div className="border-b border-border/70 px-5 py-4">
|
||||
<div className="text-sm font-semibold text-foreground">
|
||||
{CHANNEL_META[activeChannel].title}
|
||||
</div>
|
||||
<div className="mt-1 text-xs text-muted-foreground">
|
||||
发送消息时会自动附带你当前未完成任务的摘要。
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="min-h-0 flex-1 space-y-3 overflow-y-auto px-5 py-4">
|
||||
{currentMessages.length === 0 ? (
|
||||
<div className="rounded-2xl border border-dashed border-border bg-muted/35 p-4 text-sm leading-7 text-muted-foreground">
|
||||
<div className="font-medium text-foreground">暂无对话记录。</div>
|
||||
<div className="mt-1">
|
||||
你可以输入“帮我根据当前未完成任务安排今天下午的执行顺序”直接开始。
|
||||
</div>
|
||||
</div>
|
||||
) : (
|
||||
currentMessages.map((message) => (
|
||||
<div
|
||||
key={message.id}
|
||||
className={cn(
|
||||
"max-w-[92%] rounded-2xl px-4 py-3 text-sm leading-7 shadow-sm",
|
||||
message.role === "user"
|
||||
? "ml-auto bg-primary text-primary-foreground"
|
||||
: message.role === "assistant"
|
||||
? "border border-border/70 bg-background text-foreground"
|
||||
: "border border-destructive/15 bg-destructive/8 text-foreground"
|
||||
)}
|
||||
>
|
||||
<div className="whitespace-pre-wrap break-words">{message.content}</div>
|
||||
{message.meta ? (
|
||||
<div
|
||||
className={cn(
|
||||
"mt-2 text-[11px]",
|
||||
message.role === "user"
|
||||
? "text-primary-foreground/80"
|
||||
: "text-muted-foreground"
|
||||
)}
|
||||
>
|
||||
{message.meta}
|
||||
</div>
|
||||
) : null}
|
||||
</div>
|
||||
))
|
||||
)}
|
||||
<div ref={messagesEndRef} />
|
||||
</div>
|
||||
|
||||
<div className="border-t border-border/70 p-5">
|
||||
{sendBlockedReason ? (
|
||||
<div className="mb-3 rounded-2xl border border-amber-500/15 bg-amber-500/10 px-3 py-2 text-sm leading-6 text-amber-700 dark:text-amber-300">
|
||||
{sendBlockedReason}
|
||||
</div>
|
||||
) : null}
|
||||
|
||||
<textarea
|
||||
value={draftMessage}
|
||||
onChange={(event) => setDraftMessage(event.target.value)}
|
||||
onKeyDown={handleDraftKeyDown}
|
||||
placeholder="输入你的问题,例如:结合我当前待办,帮我排一下今天的优先级。"
|
||||
className="min-h-[140px] w-full rounded-2xl border border-border bg-background px-4 py-3 text-sm leading-7 outline-none transition-colors placeholder:text-muted-foreground focus:border-primary/40"
|
||||
/>
|
||||
|
||||
<div className="mt-3 flex items-center justify-between gap-3">
|
||||
<div className="flex items-center gap-2 text-xs text-muted-foreground">
|
||||
<CircleAlert className="size-4" />
|
||||
<span>当前只会使用你选中的渠道,不会在前端静默切换。</span>
|
||||
</div>
|
||||
<div className="flex gap-3">
|
||||
{sendBlockedReason ? (
|
||||
<Button type="button" variant="outline" onClick={() => navigate("/settings")}>
|
||||
去系统设置
|
||||
</Button>
|
||||
) : null}
|
||||
<Button
|
||||
type="button"
|
||||
onClick={() => void handleSendMessage()}
|
||||
disabled={
|
||||
sending || draftMessage.trim().length === 0 || sendBlockedReason !== null
|
||||
}
|
||||
>
|
||||
{sending ? (
|
||||
<>
|
||||
<LoaderCircle className="size-4 animate-spin" />
|
||||
发送中
|
||||
</>
|
||||
) : (
|
||||
<>
|
||||
<SendHorizontal className="size-4" />
|
||||
发送
|
||||
</>
|
||||
)}
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</section>
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,26 @@
|
||||
import type { LucideIcon } from "lucide-react";
|
||||
import { Construction } from "lucide-react";
|
||||
|
||||
type PlaceholderPageProps = {
|
||||
title: string;
|
||||
description: string;
|
||||
icon?: LucideIcon;
|
||||
};
|
||||
|
||||
export function PlaceholderPage({
|
||||
title,
|
||||
description,
|
||||
icon: Icon = Construction
|
||||
}: PlaceholderPageProps) {
|
||||
return (
|
||||
<section className="rounded-[2rem] border border-border/70 bg-card/92 p-8 shadow-[0_24px_80px_-48px_rgba(15,23,42,0.55)]">
|
||||
<div className="mx-auto max-w-2xl text-center">
|
||||
<div className="mx-auto flex h-16 w-16 items-center justify-center rounded-2xl bg-primary/10 text-primary">
|
||||
<Icon className="size-7" />
|
||||
</div>
|
||||
<h1 className="mt-5 text-2xl font-semibold tracking-tight text-foreground">{title}</h1>
|
||||
<p className="mt-3 text-sm leading-7 text-muted-foreground">{description}</p>
|
||||
</div>
|
||||
</section>
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,549 @@
|
||||
import { useCallback, useEffect, useMemo, useState } from "react";
|
||||
import { CheckCircle2, Globe2, KeyRound, LoaderCircle, PlugZap, Settings2 } from "lucide-react";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { cn } from "@/lib/utils";
|
||||
import {
|
||||
listAiBindings,
|
||||
testAiBinding,
|
||||
upsertAiBinding,
|
||||
type WebAiBindingSummary,
|
||||
type WebAiBindingsResponse,
|
||||
type WebAiChannel
|
||||
} from "@/services/ai-api";
|
||||
import type { WebSession } from "@/services/session-storage";
|
||||
import {
|
||||
buildAiBindingPayload,
|
||||
createAiBindingFormState,
|
||||
type AiBindingFormState
|
||||
} from "@/components/ai/ai-shared";
|
||||
|
||||
type SettingsPageProps = {
|
||||
session: WebSession;
|
||||
};
|
||||
|
||||
type SettingsTab = "ai" | "general";
|
||||
|
||||
type NoticeState = {
|
||||
tone: "success" | "error";
|
||||
message: string;
|
||||
};
|
||||
|
||||
type ChannelNoticeState = NoticeState & {
|
||||
detail?: string;
|
||||
};
|
||||
|
||||
const TODOLIST_VERSION = "0.1.0";
|
||||
|
||||
function AiConfigCard({
|
||||
channel,
|
||||
title,
|
||||
description,
|
||||
icon: Icon,
|
||||
formState,
|
||||
onChange,
|
||||
onSave,
|
||||
saving,
|
||||
binding,
|
||||
notice
|
||||
}: {
|
||||
channel: Exclude<WebAiChannel, "PUBLIC_POOL">;
|
||||
title: string;
|
||||
description: string;
|
||||
icon: typeof KeyRound;
|
||||
formState: AiBindingFormState;
|
||||
onChange: React.Dispatch<React.SetStateAction<AiBindingFormState>>;
|
||||
onSave: () => Promise<void>;
|
||||
saving: boolean;
|
||||
binding: WebAiBindingSummary | null;
|
||||
notice: ChannelNoticeState | null;
|
||||
}) {
|
||||
return (
|
||||
<section className="rounded-[2rem] border border-border/70 bg-card/92 p-5 shadow-[0_24px_80px_-48px_rgba(15,23,42,0.55)]">
|
||||
<div className="flex items-start justify-between gap-3">
|
||||
<div className="flex items-start gap-3">
|
||||
<span className="rounded-2xl bg-primary/10 p-3 text-primary">
|
||||
<Icon className="size-5" />
|
||||
</span>
|
||||
<div>
|
||||
<h2 className="text-lg font-semibold tracking-tight text-foreground">{title}</h2>
|
||||
<p className="mt-1 text-sm leading-6 text-muted-foreground">{description}</p>
|
||||
</div>
|
||||
</div>
|
||||
<span
|
||||
className={cn(
|
||||
"rounded-full border px-2 py-0.5 text-[11px] font-medium",
|
||||
formState.isEnabled
|
||||
? "border-emerald-500/25 bg-emerald-500/10 text-emerald-700 dark:text-emerald-300"
|
||||
: "border-border bg-background text-muted-foreground"
|
||||
)}
|
||||
>
|
||||
{formState.isEnabled ? "已启用" : "已停用"}
|
||||
</span>
|
||||
</div>
|
||||
|
||||
{notice ? (
|
||||
<div
|
||||
className={cn(
|
||||
"mt-4 rounded-2xl border px-3 py-3 text-sm",
|
||||
notice.tone === "success"
|
||||
? "border-emerald-500/20 bg-emerald-500/10 text-emerald-700 dark:text-emerald-300"
|
||||
: "border-destructive/20 bg-destructive/10 text-destructive"
|
||||
)}
|
||||
>
|
||||
<div className="flex items-start gap-2">
|
||||
<CheckCircle2 className="mt-0.5 size-4 shrink-0" />
|
||||
<div className="min-w-0">
|
||||
<div>{notice.message}</div>
|
||||
{notice.detail ? (
|
||||
<div className="mt-1 text-xs leading-6 opacity-80">{notice.detail}</div>
|
||||
) : null}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
) : null}
|
||||
|
||||
<div className="mt-5 grid gap-3 sm:grid-cols-2">
|
||||
<label className="space-y-1.5">
|
||||
<span className="text-xs font-medium text-muted-foreground">服务商标识</span>
|
||||
<input
|
||||
className="h-10 w-full rounded-xl border border-border bg-background px-3 text-sm outline-none transition-colors focus:border-primary/40"
|
||||
value={formState.providerName}
|
||||
onChange={(event) =>
|
||||
onChange((current) => ({
|
||||
...current,
|
||||
providerName: event.target.value
|
||||
}))
|
||||
}
|
||||
placeholder={channel === "USER_KEY" ? "如 openai / deepseek / dashscope" : "可选"}
|
||||
/>
|
||||
</label>
|
||||
|
||||
<label className="space-y-1.5">
|
||||
<span className="text-xs font-medium text-muted-foreground">模型</span>
|
||||
<input
|
||||
className="h-10 w-full rounded-xl border border-border bg-background px-3 text-sm outline-none transition-colors focus:border-primary/40"
|
||||
value={formState.model}
|
||||
onChange={(event) =>
|
||||
onChange((current) => ({
|
||||
...current,
|
||||
model: event.target.value
|
||||
}))
|
||||
}
|
||||
placeholder={channel === "USER_KEY" ? "如 gpt-4o-mini" : "可选"}
|
||||
/>
|
||||
</label>
|
||||
</div>
|
||||
|
||||
<div className="mt-3 grid gap-3">
|
||||
<label className="space-y-1.5">
|
||||
<span className="text-xs font-medium text-muted-foreground">
|
||||
{channel === "USER_KEY" ? "接口地址" : "AstrBot 地址"}
|
||||
</span>
|
||||
<input
|
||||
className="h-10 w-full rounded-xl border border-border bg-background px-3 text-sm outline-none transition-colors focus:border-primary/40"
|
||||
value={formState.endpoint}
|
||||
onChange={(event) =>
|
||||
onChange((current) => ({
|
||||
...current,
|
||||
endpoint: event.target.value
|
||||
}))
|
||||
}
|
||||
placeholder={
|
||||
channel === "USER_KEY" ? "如 https://api.openai.com/v1" : "如 http://100.64.0.21:6185"
|
||||
}
|
||||
/>
|
||||
</label>
|
||||
|
||||
{channel === "ASTRBOT" ? (
|
||||
<div className="grid gap-3 sm:grid-cols-2">
|
||||
<label className="space-y-1.5">
|
||||
<span className="text-xs font-medium text-muted-foreground">configId</span>
|
||||
<input
|
||||
className="h-10 w-full rounded-xl border border-border bg-background px-3 text-sm outline-none transition-colors focus:border-primary/40"
|
||||
value={formState.configId}
|
||||
onChange={(event) =>
|
||||
onChange((current) => ({
|
||||
...current,
|
||||
configId: event.target.value
|
||||
}))
|
||||
}
|
||||
placeholder="如 default"
|
||||
/>
|
||||
</label>
|
||||
|
||||
<label className="space-y-1.5">
|
||||
<span className="text-xs font-medium text-muted-foreground">configName</span>
|
||||
<input
|
||||
className="h-10 w-full rounded-xl border border-border bg-background px-3 text-sm outline-none transition-colors focus:border-primary/40"
|
||||
value={formState.configName}
|
||||
onChange={(event) =>
|
||||
onChange((current) => ({
|
||||
...current,
|
||||
configName: event.target.value
|
||||
}))
|
||||
}
|
||||
placeholder="可选"
|
||||
/>
|
||||
</label>
|
||||
</div>
|
||||
) : null}
|
||||
|
||||
<label className="space-y-1.5">
|
||||
<span className="text-xs font-medium text-muted-foreground">
|
||||
{channel === "USER_KEY" ? "API Key" : "AstrBot API Key"}
|
||||
</span>
|
||||
<input
|
||||
className="h-10 w-full rounded-xl border border-border bg-background px-3 text-sm outline-none transition-colors focus:border-primary/40"
|
||||
value={formState.apiKey}
|
||||
onChange={(event) =>
|
||||
onChange((current) => ({
|
||||
...current,
|
||||
apiKey: event.target.value
|
||||
}))
|
||||
}
|
||||
placeholder={binding?.hasApiKey ? "留空则保持当前密钥不变" : "请输入密钥"}
|
||||
/>
|
||||
{binding?.maskedApiKey ? (
|
||||
<div className="text-xs text-muted-foreground">
|
||||
当前已保存密钥:{binding.maskedApiKey}
|
||||
</div>
|
||||
) : null}
|
||||
</label>
|
||||
</div>
|
||||
|
||||
<label className="mt-3 flex items-center gap-2 rounded-2xl border border-border/70 bg-background/70 px-3 py-2 text-sm text-foreground">
|
||||
<input
|
||||
type="checkbox"
|
||||
checked={formState.isEnabled}
|
||||
onChange={(event) =>
|
||||
onChange((current) => ({
|
||||
...current,
|
||||
isEnabled: event.target.checked
|
||||
}))
|
||||
}
|
||||
/>
|
||||
<span>保存后立即启用该渠道</span>
|
||||
</label>
|
||||
|
||||
<div className="mt-4 flex items-center justify-between gap-3">
|
||||
<p className="text-xs leading-6 text-muted-foreground">
|
||||
{channel === "USER_KEY"
|
||||
? "该配置按用户单独保存,适合接入你自己的服务商密钥。"
|
||||
: "该配置按用户单独保存,适合直接复用 AstrBot 中已有的模型能力。"}
|
||||
<br />
|
||||
测试基于你当前表单中的输入;如果测试失败,当前已保存并生效的旧配置不会被覆盖。
|
||||
</p>
|
||||
|
||||
<Button type="button" onClick={() => void onSave()} disabled={saving}>
|
||||
{saving ? (
|
||||
<>
|
||||
<LoaderCircle className="size-4 animate-spin" />
|
||||
处理中
|
||||
</>
|
||||
) : formState.isEnabled ? (
|
||||
"测试并保存"
|
||||
) : (
|
||||
"保存草稿"
|
||||
)}
|
||||
</Button>
|
||||
</div>
|
||||
</section>
|
||||
);
|
||||
}
|
||||
|
||||
export function SettingsPage({ session }: SettingsPageProps) {
|
||||
const [activeTab, setActiveTab] = useState<SettingsTab>("ai");
|
||||
const [bindingsResponse, setBindingsResponse] = useState<WebAiBindingsResponse | null>(null);
|
||||
const [loading, setLoading] = useState(true);
|
||||
const [refreshing, setRefreshing] = useState(false);
|
||||
const [notice, setNotice] = useState<NoticeState | null>(null);
|
||||
const [savingChannel, setSavingChannel] = useState<WebAiChannel | null>(null);
|
||||
const [channelNotices, setChannelNotices] = useState<
|
||||
Partial<Record<Exclude<WebAiChannel, "PUBLIC_POOL">, ChannelNoticeState>>
|
||||
>({});
|
||||
const [userKeyForm, setUserKeyForm] = useState<AiBindingFormState>(() =>
|
||||
createAiBindingFormState()
|
||||
);
|
||||
const [astrbotForm, setAstrbotForm] = useState<AiBindingFormState>(() =>
|
||||
createAiBindingFormState()
|
||||
);
|
||||
|
||||
const bindingMap = useMemo(() => {
|
||||
const map = new Map<WebAiChannel, WebAiBindingSummary>();
|
||||
for (const binding of bindingsResponse?.bindings ?? []) {
|
||||
map.set(binding.channel, binding);
|
||||
}
|
||||
return map;
|
||||
}, [bindingsResponse]);
|
||||
|
||||
const loadBindings = useCallback(async (): Promise<void> => {
|
||||
setRefreshing(true);
|
||||
|
||||
try {
|
||||
const response = await listAiBindings(session);
|
||||
setBindingsResponse(response);
|
||||
setUserKeyForm(
|
||||
createAiBindingFormState(response.bindings.find((item) => item.channel === "USER_KEY"))
|
||||
);
|
||||
setAstrbotForm(
|
||||
createAiBindingFormState(response.bindings.find((item) => item.channel === "ASTRBOT"))
|
||||
);
|
||||
} catch (error) {
|
||||
setNotice({
|
||||
tone: "error",
|
||||
message: error instanceof Error ? error.message : "AI 配置加载失败"
|
||||
});
|
||||
} finally {
|
||||
setLoading(false);
|
||||
setRefreshing(false);
|
||||
}
|
||||
}, [session]);
|
||||
|
||||
useEffect(() => {
|
||||
void loadBindings();
|
||||
}, [loadBindings]);
|
||||
|
||||
useEffect(() => {
|
||||
if (!notice) {
|
||||
return;
|
||||
}
|
||||
|
||||
const timer = window.setTimeout(() => {
|
||||
setNotice(null);
|
||||
}, 2800);
|
||||
|
||||
return () => {
|
||||
window.clearTimeout(timer);
|
||||
};
|
||||
}, [notice]);
|
||||
|
||||
async function handleSaveChannel(channel: Exclude<WebAiChannel, "PUBLIC_POOL">): Promise<void> {
|
||||
const formState = channel === "USER_KEY" ? userKeyForm : astrbotForm;
|
||||
const binding = bindingMap.get(channel) ?? null;
|
||||
const payload = buildAiBindingPayload(channel, formState, binding);
|
||||
|
||||
try {
|
||||
setSavingChannel(channel);
|
||||
setChannelNotices((current) => ({
|
||||
...current,
|
||||
[channel]: undefined
|
||||
}));
|
||||
if (payload.isEnabled) {
|
||||
const testResult = await testAiBinding(session, payload);
|
||||
if (!testResult.success) {
|
||||
setChannelNotices((current) => ({
|
||||
...current,
|
||||
[channel]: {
|
||||
tone: "error",
|
||||
message: `连通性测试未通过:${testResult.message}`,
|
||||
detail: binding
|
||||
? "测试的是你当前编辑中的草稿配置。由于未保存,系统仍会继续使用上一份已保存配置,所以聊天可能依然正常。"
|
||||
: "当前还没有已保存配置。请先修正表单中的地址、模型或密钥后再测试。"
|
||||
}
|
||||
}));
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
await upsertAiBinding(session, payload);
|
||||
setChannelNotices((current) => ({
|
||||
...current,
|
||||
[channel]: {
|
||||
tone: "success",
|
||||
message:
|
||||
channel === "USER_KEY"
|
||||
? payload.isEnabled
|
||||
? "自备厂商连通性测试通过,配置已保存。"
|
||||
: "自备厂商配置草稿已保存。"
|
||||
: payload.isEnabled
|
||||
? "AstrBot 连通性测试通过,配置已保存。"
|
||||
: "AstrBot 配置草稿已保存。",
|
||||
detail: payload.isEnabled
|
||||
? "之后 AI 助手会使用这份刚保存的配置。"
|
||||
: "当前只是保存草稿,未启用时不会参与实际聊天。"
|
||||
}
|
||||
}));
|
||||
if (channel === "USER_KEY") {
|
||||
setUserKeyForm((current) => ({
|
||||
...current,
|
||||
apiKey: ""
|
||||
}));
|
||||
} else {
|
||||
setAstrbotForm((current) => ({
|
||||
...current,
|
||||
apiKey: ""
|
||||
}));
|
||||
}
|
||||
await loadBindings();
|
||||
} catch (error) {
|
||||
setChannelNotices((current) => ({
|
||||
...current,
|
||||
[channel]: {
|
||||
tone: "error",
|
||||
message: error instanceof Error ? error.message : "AI 配置保存失败"
|
||||
}
|
||||
}));
|
||||
} finally {
|
||||
setSavingChannel(null);
|
||||
}
|
||||
}
|
||||
|
||||
return (
|
||||
<section className="space-y-4">
|
||||
<div className="rounded-[2rem] border border-border/70 bg-card/92 p-6 shadow-[0_24px_80px_-48px_rgba(15,23,42,0.55)]">
|
||||
<div className="flex flex-col gap-4 lg:flex-row lg:items-end lg:justify-between">
|
||||
<div>
|
||||
<div className="flex items-center gap-2 text-sm font-medium text-primary">
|
||||
<Settings2 className="size-4" />
|
||||
系统设置
|
||||
</div>
|
||||
<h1 className="mt-2 text-2xl font-semibold tracking-tight text-foreground">
|
||||
统一管理 AI 配置与系统选项
|
||||
</h1>
|
||||
<p className="mt-2 text-sm leading-7 text-muted-foreground">
|
||||
你可以在这里维护自备厂商、AstrBot、公共 AI 的使用状态,后续也会扩展提醒偏好、
|
||||
界面设置与存储信息等系统能力。
|
||||
</p>
|
||||
<div className="mt-3 inline-flex items-center rounded-full border border-border/70 bg-background/80 px-3 py-1 text-xs font-medium text-muted-foreground">
|
||||
TodoList v{TODOLIST_VERSION}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<Button
|
||||
type="button"
|
||||
variant="outline"
|
||||
onClick={() => void loadBindings()}
|
||||
disabled={refreshing}
|
||||
>
|
||||
{refreshing ? (
|
||||
<>
|
||||
<LoaderCircle className="size-4 animate-spin" />
|
||||
刷新中
|
||||
</>
|
||||
) : (
|
||||
"刷新配置"
|
||||
)}
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="flex flex-wrap gap-3">
|
||||
<Button
|
||||
type="button"
|
||||
variant={activeTab === "ai" ? "default" : "outline"}
|
||||
onClick={() => setActiveTab("ai")}
|
||||
>
|
||||
AI 配置
|
||||
</Button>
|
||||
<Button
|
||||
type="button"
|
||||
variant={activeTab === "general" ? "default" : "outline"}
|
||||
onClick={() => setActiveTab("general")}
|
||||
>
|
||||
其他设置
|
||||
</Button>
|
||||
</div>
|
||||
|
||||
{notice ? (
|
||||
<div
|
||||
className={cn(
|
||||
"flex items-start gap-2 rounded-2xl border px-3 py-2 text-sm",
|
||||
notice.tone === "success"
|
||||
? "border-emerald-500/20 bg-emerald-500/10 text-emerald-700 dark:text-emerald-300"
|
||||
: "border-destructive/20 bg-destructive/10 text-destructive"
|
||||
)}
|
||||
>
|
||||
<CheckCircle2 className="mt-0.5 size-4 shrink-0" />
|
||||
<span>{notice.message}</span>
|
||||
</div>
|
||||
) : null}
|
||||
|
||||
{activeTab === "ai" ? (
|
||||
loading ? (
|
||||
<div className="rounded-[2rem] border border-border/70 bg-card/92 p-6 text-sm text-muted-foreground shadow-[0_24px_80px_-48px_rgba(15,23,42,0.55)]">
|
||||
正在加载 AI 配置...
|
||||
</div>
|
||||
) : (
|
||||
<div className="space-y-4">
|
||||
<AiConfigCard
|
||||
channel="USER_KEY"
|
||||
title="自备厂商"
|
||||
description="当前支持 OpenAI-Compatible 接口。Google 原生协议、阿里云原生协议将单独适配。"
|
||||
icon={KeyRound}
|
||||
formState={userKeyForm}
|
||||
onChange={setUserKeyForm}
|
||||
onSave={() => handleSaveChannel("USER_KEY")}
|
||||
saving={savingChannel === "USER_KEY"}
|
||||
binding={bindingMap.get("USER_KEY") ?? null}
|
||||
notice={channelNotices.USER_KEY ?? null}
|
||||
/>
|
||||
|
||||
<AiConfigCard
|
||||
channel="ASTRBOT"
|
||||
title="AstrBot"
|
||||
description="填写 AstrBot 地址与 API Key 后,即可在 AI 助手页面中使用你的 AstrBot 渠道。"
|
||||
icon={PlugZap}
|
||||
formState={astrbotForm}
|
||||
onChange={setAstrbotForm}
|
||||
onSave={() => handleSaveChannel("ASTRBOT")}
|
||||
saving={savingChannel === "ASTRBOT"}
|
||||
binding={bindingMap.get("ASTRBOT") ?? null}
|
||||
notice={channelNotices.ASTRBOT ?? null}
|
||||
/>
|
||||
|
||||
<section className="rounded-[2rem] border border-border/70 bg-card/92 p-5 shadow-[0_24px_80px_-48px_rgba(15,23,42,0.55)]">
|
||||
<div className="flex items-start justify-between gap-3">
|
||||
<div className="flex items-start gap-3">
|
||||
<span className="rounded-2xl bg-primary/10 p-3 text-primary">
|
||||
<Globe2 className="size-5" />
|
||||
</span>
|
||||
<div>
|
||||
<h2 className="text-lg font-semibold tracking-tight text-foreground">
|
||||
公共 AI
|
||||
</h2>
|
||||
<p className="mt-1 text-sm leading-6 text-muted-foreground">
|
||||
该渠道由管理员统一维护,普通用户仅可查看状态和使用,不可修改。
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
<span
|
||||
className={cn(
|
||||
"rounded-full border px-2 py-0.5 text-[11px] font-medium",
|
||||
bindingsResponse?.publicPool?.enabled
|
||||
? "border-emerald-500/25 bg-emerald-500/10 text-emerald-700 dark:text-emerald-300"
|
||||
: "border-border bg-background text-muted-foreground"
|
||||
)}
|
||||
>
|
||||
{bindingsResponse?.publicPool?.enabled ? "已开放" : "未开放"}
|
||||
</span>
|
||||
</div>
|
||||
|
||||
<div className="mt-4 rounded-2xl border border-border/70 bg-background/80 p-4 text-sm leading-7 text-muted-foreground">
|
||||
<div>
|
||||
提供商:
|
||||
<span className="ml-2 text-foreground">
|
||||
{bindingsResponse?.publicPool?.providerName || "未设置"}
|
||||
</span>
|
||||
</div>
|
||||
<div>
|
||||
模型:
|
||||
<span className="ml-2 text-foreground">
|
||||
{bindingsResponse?.publicPool?.model || "未设置"}
|
||||
</span>
|
||||
</div>
|
||||
</div>
|
||||
</section>
|
||||
</div>
|
||||
)
|
||||
) : (
|
||||
<section className="rounded-[2rem] border border-border/70 bg-card/92 p-8 shadow-[0_24px_80px_-48px_rgba(15,23,42,0.55)]">
|
||||
<h2 className="text-xl font-semibold tracking-tight text-foreground">其他设置</h2>
|
||||
<p className="mt-3 text-sm leading-7 text-muted-foreground">
|
||||
这里后续会接入站点外观、提醒偏好、存储配额展示等系统设置项。当前先把 AI
|
||||
配置独立出来,避免继续堆在任务页面。
|
||||
</p>
|
||||
</section>
|
||||
)}
|
||||
</section>
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,198 @@
|
||||
import type { WebSession } from "@/services/session-storage";
|
||||
|
||||
export type WebAiChannel = "USER_KEY" | "ASTRBOT" | "PUBLIC_POOL";
|
||||
|
||||
export type WebAiRouteAttempt = {
|
||||
channel: WebAiChannel;
|
||||
providerName: string | null;
|
||||
model: string | null;
|
||||
status: "skipped" | "failed" | "success";
|
||||
reasonCode: string | null;
|
||||
reasonMessage: string | null;
|
||||
};
|
||||
|
||||
export type WebAiBindingSummary = {
|
||||
id: string;
|
||||
channel: WebAiChannel;
|
||||
providerName: string;
|
||||
model: string | null;
|
||||
configId: string | null;
|
||||
configName: string | null;
|
||||
endpoint: string | null;
|
||||
isEnabled: boolean;
|
||||
hasApiKey: boolean;
|
||||
maskedApiKey: string | null;
|
||||
updatedAt: string;
|
||||
};
|
||||
|
||||
export type WebAiBindingsResponse = {
|
||||
routeOrder: WebAiChannel[];
|
||||
bindings: WebAiBindingSummary[];
|
||||
publicPool: {
|
||||
enabled: boolean;
|
||||
providerName: string | null;
|
||||
model: string | null;
|
||||
hasApiKey: boolean;
|
||||
} | null;
|
||||
};
|
||||
|
||||
export type UpsertWebAiBindingInput = {
|
||||
channel: Exclude<WebAiChannel, "PUBLIC_POOL">;
|
||||
providerName?: string;
|
||||
model?: string;
|
||||
configId?: string;
|
||||
configName?: string;
|
||||
endpoint?: string;
|
||||
apiKey?: string;
|
||||
isEnabled?: boolean;
|
||||
};
|
||||
|
||||
export type TestWebAiBindingResponse =
|
||||
| {
|
||||
success: true;
|
||||
channel: Exclude<WebAiChannel, "PUBLIC_POOL">;
|
||||
providerName: string;
|
||||
model: string | null;
|
||||
contentPreview: string;
|
||||
}
|
||||
| {
|
||||
success: false;
|
||||
channel: Exclude<WebAiChannel, "PUBLIC_POOL">;
|
||||
providerName: string;
|
||||
model: string | null;
|
||||
code: string;
|
||||
message: string;
|
||||
};
|
||||
|
||||
export type WebAiChatResponse = {
|
||||
channel: WebAiChannel;
|
||||
providerName: string;
|
||||
model: string | null;
|
||||
content: string;
|
||||
sessionId: string | null;
|
||||
attempts: WebAiRouteAttempt[];
|
||||
};
|
||||
|
||||
export type WebAiLocalTaskContextItem = {
|
||||
id: string;
|
||||
title: string;
|
||||
priority: "LOW" | "MEDIUM" | "HIGH" | "URGENT";
|
||||
status: "TODO" | "IN_PROGRESS" | "DONE" | "ARCHIVED";
|
||||
ddlAt: number | null;
|
||||
contentText: string | null;
|
||||
updatedAt: number;
|
||||
};
|
||||
|
||||
export class WebAiApiError extends Error {
|
||||
attempts: WebAiRouteAttempt[] | null;
|
||||
|
||||
constructor(message: string, attempts?: WebAiRouteAttempt[] | null) {
|
||||
super(message);
|
||||
this.name = "WebAiApiError";
|
||||
this.attempts = attempts ?? null;
|
||||
}
|
||||
}
|
||||
|
||||
const DEFAULT_API_BASE_URL = "http://localhost:3000";
|
||||
|
||||
function resolveApiBaseUrl(): string {
|
||||
const envBaseUrl = import.meta.env.VITE_API_BASE_URL as string | undefined;
|
||||
if (!envBaseUrl) {
|
||||
return DEFAULT_API_BASE_URL;
|
||||
}
|
||||
|
||||
return envBaseUrl.replace(/\/+$/, "");
|
||||
}
|
||||
|
||||
function createHeaders(session: WebSession): HeadersInit {
|
||||
return {
|
||||
"Content-Type": "application/json",
|
||||
Authorization: `Bearer ${session.accessToken}`,
|
||||
"x-user-id": session.user.id
|
||||
};
|
||||
}
|
||||
|
||||
async function createApiError(response: Response): Promise<WebAiApiError> {
|
||||
try {
|
||||
const body = (await response.json()) as {
|
||||
message?: string | string[];
|
||||
attempts?: WebAiRouteAttempt[];
|
||||
};
|
||||
const message = Array.isArray(body.message)
|
||||
? body.message.join(";")
|
||||
: typeof body.message === "string" && body.message.trim().length > 0
|
||||
? body.message
|
||||
: `请求失败(${response.status})`;
|
||||
return new WebAiApiError(message, body.attempts ?? null);
|
||||
} catch {
|
||||
return new WebAiApiError(`请求失败(${response.status})`);
|
||||
}
|
||||
}
|
||||
export async function listAiBindings(session: WebSession): Promise<WebAiBindingsResponse> {
|
||||
const response = await fetch(`${resolveApiBaseUrl()}/ai/bindings`, {
|
||||
method: "GET",
|
||||
headers: createHeaders(session)
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
throw await createApiError(response);
|
||||
}
|
||||
|
||||
return (await response.json()) as WebAiBindingsResponse;
|
||||
}
|
||||
|
||||
export async function upsertAiBinding(
|
||||
session: WebSession,
|
||||
payload: UpsertWebAiBindingInput
|
||||
): Promise<WebAiBindingSummary> {
|
||||
const response = await fetch(`${resolveApiBaseUrl()}/ai/bindings`, {
|
||||
method: "POST",
|
||||
headers: createHeaders(session),
|
||||
body: JSON.stringify(payload)
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
throw await createApiError(response);
|
||||
}
|
||||
|
||||
return (await response.json()) as WebAiBindingSummary;
|
||||
}
|
||||
|
||||
export async function testAiBinding(
|
||||
session: WebSession,
|
||||
payload: UpsertWebAiBindingInput
|
||||
): Promise<TestWebAiBindingResponse> {
|
||||
const response = await fetch(`${resolveApiBaseUrl()}/ai/bindings/test`, {
|
||||
method: "POST",
|
||||
headers: createHeaders(session),
|
||||
body: JSON.stringify(payload)
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
throw await createApiError(response);
|
||||
}
|
||||
|
||||
return (await response.json()) as TestWebAiBindingResponse;
|
||||
}
|
||||
|
||||
export async function chatWithAi(
|
||||
session: WebSession,
|
||||
payload: {
|
||||
channel: WebAiChannel;
|
||||
message: string;
|
||||
sessionId?: string;
|
||||
localTasks?: WebAiLocalTaskContextItem[];
|
||||
}
|
||||
): Promise<WebAiChatResponse> {
|
||||
const response = await fetch(`${resolveApiBaseUrl()}/ai/chat`, {
|
||||
method: "POST",
|
||||
headers: createHeaders(session),
|
||||
body: JSON.stringify(payload)
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
throw await createApiError(response);
|
||||
}
|
||||
|
||||
return (await response.json()) as WebAiChatResponse;
|
||||
}
|
||||
@@ -0,0 +1,98 @@
|
||||
import { localDb, type LocalAiChatSessionRecord } from "@/services/local-db";
|
||||
import type { WebAiChannel } from "@/services/ai-api";
|
||||
import {
|
||||
decryptAiChatSessionRecord,
|
||||
encryptAiChatSessionRecord
|
||||
} from "@/services/local-sensitive-codec";
|
||||
|
||||
export type LocalAiChatMessageRecord = {
|
||||
id: string;
|
||||
role: "user" | "assistant" | "system";
|
||||
content: string;
|
||||
meta?: string;
|
||||
};
|
||||
|
||||
export type SaveLocalAiChatSessionInput = {
|
||||
userId: string;
|
||||
channel: WebAiChannel;
|
||||
sessionId: string | null;
|
||||
messages: LocalAiChatMessageRecord[];
|
||||
};
|
||||
|
||||
export type LocalAiChatSessionSnapshot = {
|
||||
channel: WebAiChannel;
|
||||
sessionId: string | null;
|
||||
messages: LocalAiChatMessageRecord[];
|
||||
};
|
||||
|
||||
function createSessionKey(userId: string, channel: WebAiChannel): string {
|
||||
return `${userId}:${channel}`;
|
||||
}
|
||||
|
||||
function parseMessages(messagesJson: string): LocalAiChatMessageRecord[] {
|
||||
try {
|
||||
const parsed = JSON.parse(messagesJson) as unknown;
|
||||
if (!Array.isArray(parsed)) {
|
||||
return [];
|
||||
}
|
||||
|
||||
return parsed.filter((item): item is LocalAiChatMessageRecord => {
|
||||
if (!item || typeof item !== "object") {
|
||||
return false;
|
||||
}
|
||||
|
||||
const record = item as Record<string, unknown>;
|
||||
return (
|
||||
typeof record["id"] === "string" &&
|
||||
(record["role"] === "user" ||
|
||||
record["role"] === "assistant" ||
|
||||
record["role"] === "system") &&
|
||||
typeof record["content"] === "string" &&
|
||||
(record["meta"] === undefined || typeof record["meta"] === "string")
|
||||
);
|
||||
});
|
||||
} catch {
|
||||
return [];
|
||||
}
|
||||
}
|
||||
|
||||
function toSnapshot(record: LocalAiChatSessionRecord): LocalAiChatSessionSnapshot {
|
||||
return {
|
||||
channel: record.channel,
|
||||
sessionId: record.sessionId,
|
||||
messages: parseMessages(record.messagesJson)
|
||||
};
|
||||
}
|
||||
|
||||
export async function listLocalAiChatSessions(
|
||||
userId: string
|
||||
): Promise<LocalAiChatSessionSnapshot[]> {
|
||||
const records = await localDb.aiChatSessions.where("userId").equals(userId).toArray();
|
||||
const decryptedRecords = await Promise.all(
|
||||
records.map((record) => decryptAiChatSessionRecord(record))
|
||||
);
|
||||
return decryptedRecords.map(toSnapshot);
|
||||
}
|
||||
|
||||
export async function saveLocalAiChatSession(
|
||||
input: SaveLocalAiChatSessionInput
|
||||
): Promise<LocalAiChatSessionRecord> {
|
||||
const record = await encryptAiChatSessionRecord({
|
||||
key: createSessionKey(input.userId, input.channel),
|
||||
userId: input.userId,
|
||||
channel: input.channel,
|
||||
sessionId: input.sessionId,
|
||||
messagesJson: JSON.stringify(input.messages),
|
||||
updatedAt: Date.now()
|
||||
});
|
||||
|
||||
await localDb.aiChatSessions.put(record);
|
||||
return record;
|
||||
}
|
||||
|
||||
export async function deleteLocalAiChatSession(
|
||||
userId: string,
|
||||
channel: WebAiChannel
|
||||
): Promise<void> {
|
||||
await localDb.aiChatSessions.delete(createSessionKey(userId, channel));
|
||||
}
|
||||
@@ -0,0 +1,126 @@
|
||||
const LOCAL_CRYPTO_KEY_STORAGE_KEY = "todolist.web.local-crypto-key";
|
||||
const LOCAL_CRYPTO_PREFIX = "locv1";
|
||||
const LOCAL_CRYPTO_IV_LENGTH = 12;
|
||||
const LOCAL_CRYPTO_KEY_LENGTH = 32;
|
||||
|
||||
let cachedLocalCryptoKeyPromise: Promise<CryptoKey> | null = null;
|
||||
|
||||
function toArrayBuffer(bytes: Uint8Array): ArrayBuffer {
|
||||
return bytes.buffer.slice(bytes.byteOffset, bytes.byteOffset + bytes.byteLength) as ArrayBuffer;
|
||||
}
|
||||
|
||||
function bytesToBase64Url(bytes: Uint8Array): string {
|
||||
let binary = "";
|
||||
const chunkSize = 0x8000;
|
||||
|
||||
for (let index = 0; index < bytes.length; index += chunkSize) {
|
||||
const chunk = bytes.subarray(index, index + chunkSize);
|
||||
binary += String.fromCharCode(...chunk);
|
||||
}
|
||||
|
||||
return btoa(binary).replace(/\+/g, "-").replace(/\//g, "_").replace(/=+$/u, "");
|
||||
}
|
||||
|
||||
function base64UrlToBytes(value: string): Uint8Array {
|
||||
const normalizedValue = value.replace(/-/g, "+").replace(/_/g, "/");
|
||||
const paddedValue = normalizedValue + "=".repeat((4 - (normalizedValue.length % 4 || 4)) % 4);
|
||||
const binary = atob(paddedValue);
|
||||
const bytes = new Uint8Array(binary.length);
|
||||
|
||||
for (let index = 0; index < binary.length; index += 1) {
|
||||
bytes[index] = binary.charCodeAt(index);
|
||||
}
|
||||
|
||||
return bytes;
|
||||
}
|
||||
|
||||
function createRandomKeyBytes(): Uint8Array {
|
||||
const bytes = new Uint8Array(LOCAL_CRYPTO_KEY_LENGTH);
|
||||
crypto.getRandomValues(bytes);
|
||||
return bytes;
|
||||
}
|
||||
|
||||
async function resolveLocalCryptoKey(): Promise<CryptoKey> {
|
||||
if (cachedLocalCryptoKeyPromise) {
|
||||
return cachedLocalCryptoKeyPromise;
|
||||
}
|
||||
|
||||
cachedLocalCryptoKeyPromise = (async () => {
|
||||
const savedKey = window.localStorage.getItem(LOCAL_CRYPTO_KEY_STORAGE_KEY);
|
||||
const keyBytes = savedKey ? base64UrlToBytes(savedKey) : createRandomKeyBytes();
|
||||
|
||||
if (!savedKey) {
|
||||
window.localStorage.setItem(LOCAL_CRYPTO_KEY_STORAGE_KEY, bytesToBase64Url(keyBytes));
|
||||
}
|
||||
|
||||
return crypto.subtle.importKey("raw", toArrayBuffer(keyBytes), "AES-GCM", false, [
|
||||
"encrypt",
|
||||
"decrypt"
|
||||
]);
|
||||
})();
|
||||
|
||||
return cachedLocalCryptoKeyPromise;
|
||||
}
|
||||
|
||||
export function isLocalEncryptedString(value: string): boolean {
|
||||
return value.startsWith(`${LOCAL_CRYPTO_PREFIX}:`);
|
||||
}
|
||||
|
||||
export async function encryptLocalString(
|
||||
value: string | null | undefined
|
||||
): Promise<string | null | undefined> {
|
||||
if (value === undefined || value === null) {
|
||||
return value;
|
||||
}
|
||||
|
||||
if (isLocalEncryptedString(value)) {
|
||||
return value;
|
||||
}
|
||||
|
||||
const key = await resolveLocalCryptoKey();
|
||||
const iv = crypto.getRandomValues(new Uint8Array(LOCAL_CRYPTO_IV_LENGTH));
|
||||
const plaintext = new TextEncoder().encode(value);
|
||||
const encryptedBuffer = await crypto.subtle.encrypt(
|
||||
{
|
||||
name: "AES-GCM",
|
||||
iv
|
||||
},
|
||||
key,
|
||||
plaintext
|
||||
);
|
||||
|
||||
return `${LOCAL_CRYPTO_PREFIX}:${bytesToBase64Url(iv)}:${bytesToBase64Url(new Uint8Array(encryptedBuffer))}`;
|
||||
}
|
||||
|
||||
export async function decryptLocalString(
|
||||
value: string | null | undefined
|
||||
): Promise<string | null | undefined> {
|
||||
if (value === undefined || value === null) {
|
||||
return value;
|
||||
}
|
||||
|
||||
if (!isLocalEncryptedString(value)) {
|
||||
return value;
|
||||
}
|
||||
|
||||
const [prefix, ivText, encryptedText] = value.split(":");
|
||||
if (prefix !== LOCAL_CRYPTO_PREFIX || !ivText || !encryptedText) {
|
||||
return null;
|
||||
}
|
||||
|
||||
try {
|
||||
const key = await resolveLocalCryptoKey();
|
||||
const decryptedBuffer = await crypto.subtle.decrypt(
|
||||
{
|
||||
name: "AES-GCM",
|
||||
iv: toArrayBuffer(base64UrlToBytes(ivText))
|
||||
},
|
||||
key,
|
||||
toArrayBuffer(base64UrlToBytes(encryptedText))
|
||||
);
|
||||
|
||||
return new TextDecoder().decode(decryptedBuffer);
|
||||
} catch {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
@@ -69,12 +69,22 @@ export type LocalSyncInboxRecord = {
|
||||
appliedAt: number | null;
|
||||
};
|
||||
|
||||
export type LocalAiChatSessionRecord = {
|
||||
key: string;
|
||||
userId: string;
|
||||
channel: "USER_KEY" | "ASTRBOT" | "PUBLIC_POOL";
|
||||
sessionId: string | null;
|
||||
messagesJson: string;
|
||||
updatedAt: number;
|
||||
};
|
||||
|
||||
class TodoLocalDb extends Dexie {
|
||||
declare tasks: Table<LocalTaskRecord, string>;
|
||||
declare opLogs: Table<LocalOpLogRecord, string>;
|
||||
declare taskDrafts: Table<LocalTaskDraftRecord, string>;
|
||||
declare syncStates: Table<LocalSyncStateRecord, string>;
|
||||
declare syncInbox: Table<LocalSyncInboxRecord, string>;
|
||||
declare aiChatSessions: Table<LocalAiChatSessionRecord, string>;
|
||||
|
||||
constructor() {
|
||||
super("todolist-web-db");
|
||||
@@ -117,11 +127,21 @@ class TodoLocalDb extends Dexie {
|
||||
});
|
||||
});
|
||||
|
||||
this.version(5).stores({
|
||||
tasks: "&id,userId,status,priority,ddlAt,updatedAt,deletedAt",
|
||||
op_logs: "&opId,entityId,entityType,action,clientTs,syncedAt",
|
||||
task_drafts: "&taskId,userId,updatedAt",
|
||||
sync_states: "&userId,updatedAt,lastSyncedAt",
|
||||
sync_inbox: "&opId,userId,entityId,serverTs,appliedAt",
|
||||
ai_chat_sessions: "&key,userId,channel,updatedAt"
|
||||
});
|
||||
|
||||
this.tasks = this.table("tasks");
|
||||
this.opLogs = this.table("op_logs");
|
||||
this.taskDrafts = this.table("task_drafts");
|
||||
this.syncStates = this.table("sync_states");
|
||||
this.syncInbox = this.table("sync_inbox");
|
||||
this.aiChatSessions = this.table("ai_chat_sessions");
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,150 @@
|
||||
import type {
|
||||
LocalAiChatSessionRecord,
|
||||
LocalOpLogRecord,
|
||||
LocalSyncInboxRecord,
|
||||
LocalTaskDraftRecord,
|
||||
LocalTaskRecord
|
||||
} from "@/services/local-db";
|
||||
import {
|
||||
decryptLocalString,
|
||||
encryptLocalString,
|
||||
isLocalEncryptedString
|
||||
} from "@/services/local-crypto";
|
||||
|
||||
export function shouldEncryptTaskRecord(record: LocalTaskRecord): boolean {
|
||||
return (
|
||||
!isLocalEncryptedString(record.title) ||
|
||||
(typeof record.contentJson === "string" && !isLocalEncryptedString(record.contentJson)) ||
|
||||
(typeof record.contentText === "string" && !isLocalEncryptedString(record.contentText))
|
||||
);
|
||||
}
|
||||
|
||||
export async function encryptTaskRecord(record: LocalTaskRecord): Promise<LocalTaskRecord> {
|
||||
return {
|
||||
...record,
|
||||
title: (await encryptLocalString(record.title)) ?? record.title,
|
||||
contentJson: (await encryptLocalString(record.contentJson)) ?? null,
|
||||
contentText: (await encryptLocalString(record.contentText)) ?? null
|
||||
};
|
||||
}
|
||||
|
||||
export async function decryptTaskRecord(record: LocalTaskRecord): Promise<LocalTaskRecord> {
|
||||
const title = await decryptLocalString(record.title);
|
||||
const contentJson = await decryptLocalString(record.contentJson);
|
||||
const contentText = await decryptLocalString(record.contentText);
|
||||
|
||||
return {
|
||||
...record,
|
||||
title: typeof title === "string" && title.trim().length > 0 ? title : "未命名任务",
|
||||
contentJson: typeof contentJson === "string" ? contentJson : null,
|
||||
contentText: typeof contentText === "string" ? contentText : null
|
||||
};
|
||||
}
|
||||
|
||||
export function shouldEncryptTaskDraft(record: LocalTaskDraftRecord): boolean {
|
||||
return (
|
||||
!isLocalEncryptedString(record.title) ||
|
||||
(typeof record.contentJson === "string" && !isLocalEncryptedString(record.contentJson)) ||
|
||||
!isLocalEncryptedString(record.contentText)
|
||||
);
|
||||
}
|
||||
|
||||
export async function encryptTaskDraftRecord(
|
||||
record: LocalTaskDraftRecord
|
||||
): Promise<LocalTaskDraftRecord> {
|
||||
return {
|
||||
...record,
|
||||
title: (await encryptLocalString(record.title)) ?? record.title,
|
||||
contentJson: (await encryptLocalString(record.contentJson)) ?? null,
|
||||
contentText: (await encryptLocalString(record.contentText)) ?? ""
|
||||
};
|
||||
}
|
||||
|
||||
export async function decryptTaskDraftRecord(
|
||||
record: LocalTaskDraftRecord
|
||||
): Promise<LocalTaskDraftRecord> {
|
||||
const title = await decryptLocalString(record.title);
|
||||
const contentJson = await decryptLocalString(record.contentJson);
|
||||
const contentText = await decryptLocalString(record.contentText);
|
||||
|
||||
return {
|
||||
...record,
|
||||
title: typeof title === "string" ? title : "",
|
||||
contentJson: typeof contentJson === "string" ? contentJson : null,
|
||||
contentText: typeof contentText === "string" ? contentText : ""
|
||||
};
|
||||
}
|
||||
|
||||
export function shouldEncryptOpLogRecord(record: LocalOpLogRecord): boolean {
|
||||
return !isLocalEncryptedString(record.payload);
|
||||
}
|
||||
|
||||
export async function encryptOpLogRecord(record: LocalOpLogRecord): Promise<LocalOpLogRecord> {
|
||||
return {
|
||||
...record,
|
||||
payload: (await encryptLocalString(record.payload)) ?? record.payload
|
||||
};
|
||||
}
|
||||
|
||||
export async function decryptOpLogRecord(record: LocalOpLogRecord): Promise<LocalOpLogRecord> {
|
||||
const payload = await decryptLocalString(record.payload);
|
||||
|
||||
return {
|
||||
...record,
|
||||
payload: typeof payload === "string" ? payload : record.payload
|
||||
};
|
||||
}
|
||||
|
||||
export function shouldEncryptSyncInboxRecord(record: LocalSyncInboxRecord): boolean {
|
||||
return typeof record.payload === "string" && !isLocalEncryptedString(record.payload);
|
||||
}
|
||||
|
||||
export async function encryptSyncInboxRecord(
|
||||
record: LocalSyncInboxRecord
|
||||
): Promise<LocalSyncInboxRecord> {
|
||||
return {
|
||||
...record,
|
||||
payload: (await encryptLocalString(record.payload)) ?? null
|
||||
};
|
||||
}
|
||||
|
||||
export async function decryptSyncInboxRecord(
|
||||
record: LocalSyncInboxRecord
|
||||
): Promise<LocalSyncInboxRecord> {
|
||||
const payload = await decryptLocalString(record.payload);
|
||||
|
||||
return {
|
||||
...record,
|
||||
payload: typeof payload === "string" ? payload : null
|
||||
};
|
||||
}
|
||||
|
||||
export function shouldEncryptAiChatSessionRecord(record: LocalAiChatSessionRecord): boolean {
|
||||
return (
|
||||
!isLocalEncryptedString(record.messagesJson) ||
|
||||
(typeof record.sessionId === "string" && !isLocalEncryptedString(record.sessionId))
|
||||
);
|
||||
}
|
||||
|
||||
export async function encryptAiChatSessionRecord(
|
||||
record: LocalAiChatSessionRecord
|
||||
): Promise<LocalAiChatSessionRecord> {
|
||||
return {
|
||||
...record,
|
||||
sessionId: (await encryptLocalString(record.sessionId)) ?? null,
|
||||
messagesJson: (await encryptLocalString(record.messagesJson)) ?? "[]"
|
||||
};
|
||||
}
|
||||
|
||||
export async function decryptAiChatSessionRecord(
|
||||
record: LocalAiChatSessionRecord
|
||||
): Promise<LocalAiChatSessionRecord> {
|
||||
const sessionId = await decryptLocalString(record.sessionId);
|
||||
const messagesJson = await decryptLocalString(record.messagesJson);
|
||||
|
||||
return {
|
||||
...record,
|
||||
sessionId: typeof sessionId === "string" ? sessionId : null,
|
||||
messagesJson: typeof messagesJson === "string" ? messagesJson : "[]"
|
||||
};
|
||||
}
|
||||
@@ -1,19 +1,25 @@
|
||||
import {
|
||||
import {
|
||||
localDb,
|
||||
type LocalOpLogRecord,
|
||||
type LocalSyncInboxRecord,
|
||||
type LocalSyncStateRecord
|
||||
} from "@/services/local-db";
|
||||
import {
|
||||
decryptOpLogRecord,
|
||||
decryptSyncInboxRecord,
|
||||
encryptSyncInboxRecord
|
||||
} from "@/services/local-sensitive-codec";
|
||||
import type { SyncPullItem } from "@/services/sync-api";
|
||||
|
||||
export const MAX_SYNC_RETRY_COUNT = 5;
|
||||
|
||||
export async function listPendingSyncOperations(limit = 20): Promise<LocalOpLogRecord[]> {
|
||||
const records = await localDb.opLogs.orderBy("clientTs").toArray();
|
||||
|
||||
return records
|
||||
const pendingRecords = records
|
||||
.filter((record) => record.syncedAt === null && record.retryCount < MAX_SYNC_RETRY_COUNT)
|
||||
.slice(0, limit);
|
||||
|
||||
return Promise.all(pendingRecords.map((record) => decryptOpLogRecord(record)));
|
||||
}
|
||||
|
||||
export async function countPendingSyncOperations(): Promise<number> {
|
||||
@@ -100,7 +106,9 @@ export async function enqueueRemoteSyncOperations(
|
||||
}
|
||||
|
||||
const receivedAt = Date.now();
|
||||
const records: LocalSyncInboxRecord[] = operations.map((operation) => ({
|
||||
const records = await Promise.all(
|
||||
operations.map(async (operation) =>
|
||||
encryptSyncInboxRecord({
|
||||
opId: operation.opId,
|
||||
userId,
|
||||
entityId: operation.entityId,
|
||||
@@ -112,7 +120,9 @@ export async function enqueueRemoteSyncOperations(
|
||||
serverTs: new Date(operation.serverTs).getTime(),
|
||||
receivedAt,
|
||||
appliedAt: null
|
||||
}));
|
||||
})
|
||||
)
|
||||
);
|
||||
|
||||
await localDb.syncInbox.bulkPut(records);
|
||||
return records.length;
|
||||
@@ -123,8 +133,7 @@ export async function listPendingRemoteOperations(
|
||||
limit = 100
|
||||
): Promise<LocalSyncInboxRecord[]> {
|
||||
const records = await localDb.syncInbox.where("userId").equals(userId).toArray();
|
||||
|
||||
return records
|
||||
const pendingRecords = records
|
||||
.filter((record) => record.appliedAt === null)
|
||||
.sort((left, right) => {
|
||||
if (left.serverTs !== right.serverTs) {
|
||||
@@ -138,6 +147,8 @@ export async function listPendingRemoteOperations(
|
||||
return left.opId.localeCompare(right.opId);
|
||||
})
|
||||
.slice(0, limit);
|
||||
|
||||
return Promise.all(pendingRecords.map((record) => decryptSyncInboxRecord(record)));
|
||||
}
|
||||
|
||||
export async function markRemoteOperationsApplied(
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import { localDb, type LocalTaskDraftRecord } from "@/services/local-db";
|
||||
import { decryptTaskDraftRecord, encryptTaskDraftRecord } from "@/services/local-sensitive-codec";
|
||||
|
||||
export type SaveLocalTaskDraftInput = {
|
||||
taskId: string;
|
||||
@@ -12,7 +13,12 @@ export type SaveLocalTaskDraftInput = {
|
||||
};
|
||||
|
||||
export async function getLocalTaskDraft(taskId: string): Promise<LocalTaskDraftRecord | undefined> {
|
||||
return localDb.taskDrafts.get(taskId);
|
||||
const draft = await localDb.taskDrafts.get(taskId);
|
||||
if (!draft) {
|
||||
return undefined;
|
||||
}
|
||||
|
||||
return decryptTaskDraftRecord(draft);
|
||||
}
|
||||
|
||||
export async function saveLocalTaskDraft(
|
||||
@@ -23,7 +29,7 @@ export async function saveLocalTaskDraft(
|
||||
updatedAt: Date.now()
|
||||
};
|
||||
|
||||
await localDb.taskDrafts.put(draft);
|
||||
await localDb.taskDrafts.put(await encryptTaskDraftRecord(draft));
|
||||
return draft;
|
||||
}
|
||||
|
||||
|
||||
@@ -6,6 +6,11 @@
|
||||
type LocalTaskStatus,
|
||||
type SyncActionType
|
||||
} from "@/services/local-db";
|
||||
import {
|
||||
decryptTaskRecord,
|
||||
encryptOpLogRecord,
|
||||
encryptTaskRecord
|
||||
} from "@/services/local-sensitive-codec";
|
||||
|
||||
const DEVICE_ID_STORAGE_KEY = "todolist.web.device-id";
|
||||
|
||||
@@ -83,7 +88,8 @@ function createSyncTaskPayload(payload: SyncTaskPayload): string {
|
||||
|
||||
export async function listLocalTasksByUser(userId: string): Promise<LocalTaskRecord[]> {
|
||||
const tasks = await localDb.tasks.where("userId").equals(userId).toArray();
|
||||
return tasks
|
||||
const decryptedTasks = await Promise.all(tasks.map((task) => decryptTaskRecord(task)));
|
||||
return decryptedTasks
|
||||
.filter((task) => task.deletedAt === null)
|
||||
.sort((left, right) => right.updatedAt - left.updatedAt);
|
||||
}
|
||||
@@ -94,7 +100,7 @@ export async function getLocalTaskById(id: string): Promise<LocalTaskRecord | un
|
||||
return undefined;
|
||||
}
|
||||
|
||||
return task;
|
||||
return decryptTaskRecord(task);
|
||||
}
|
||||
|
||||
export async function createLocalTask(input: CreateLocalTaskInput): Promise<LocalTaskRecord> {
|
||||
@@ -134,8 +140,8 @@ export async function createLocalTask(input: CreateLocalTaskInput): Promise<Loca
|
||||
);
|
||||
|
||||
await localDb.transaction("rw", localDb.tasks, localDb.opLogs, async () => {
|
||||
await localDb.tasks.add(task);
|
||||
await localDb.opLogs.add(opLog);
|
||||
await localDb.tasks.add(await encryptTaskRecord(task));
|
||||
await localDb.opLogs.add(await encryptOpLogRecord(opLog));
|
||||
});
|
||||
|
||||
return task;
|
||||
@@ -178,8 +184,8 @@ export async function updateLocalTask(
|
||||
);
|
||||
|
||||
await localDb.transaction("rw", localDb.tasks, localDb.opLogs, async () => {
|
||||
await localDb.tasks.put(nextTask);
|
||||
await localDb.opLogs.add(opLog);
|
||||
await localDb.tasks.put(await encryptTaskRecord(nextTask));
|
||||
await localDb.opLogs.add(await encryptOpLogRecord(opLog));
|
||||
});
|
||||
|
||||
return nextTask;
|
||||
@@ -211,8 +217,8 @@ export async function deleteLocalTask(id: string): Promise<boolean> {
|
||||
);
|
||||
|
||||
await localDb.transaction("rw", localDb.tasks, localDb.opLogs, async () => {
|
||||
await localDb.tasks.put(nextTask);
|
||||
await localDb.opLogs.add(opLog);
|
||||
await localDb.tasks.put(await encryptTaskRecord(nextTask));
|
||||
await localDb.opLogs.add(await encryptOpLogRecord(opLog));
|
||||
});
|
||||
|
||||
return true;
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { localDb } from "@/services/local-db";
|
||||
import { listLocalTasksByUser } from "@/services/local-task-repo";
|
||||
|
||||
export const DEFAULT_CLOUD_QUOTA_BYTES = 100 * 1024 * 1024;
|
||||
|
||||
@@ -18,13 +18,9 @@ function measureTextBytes(value: string | null): number {
|
||||
}
|
||||
|
||||
export async function getStorageQuotaSnapshot(userId: string): Promise<StorageQuotaSnapshot> {
|
||||
const tasks = await localDb.tasks.where("userId").equals(userId).toArray();
|
||||
const tasks = await listLocalTasksByUser(userId);
|
||||
|
||||
const usedBytes = tasks.reduce((total, task) => {
|
||||
if (task.deletedAt !== null) {
|
||||
return total;
|
||||
}
|
||||
|
||||
return (
|
||||
total +
|
||||
measureTextBytes(task.title) +
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import {
|
||||
import {
|
||||
localDb,
|
||||
type LocalSyncInboxRecord,
|
||||
type LocalTaskPriority,
|
||||
@@ -6,6 +6,11 @@ import {
|
||||
type LocalTaskStatus
|
||||
} from "@/services/local-db";
|
||||
import { listPendingRemoteOperations } from "@/services/local-sync-repo";
|
||||
import {
|
||||
decryptTaskRecord,
|
||||
encryptTaskRecord,
|
||||
shouldEncryptTaskRecord
|
||||
} from "@/services/local-sensitive-codec";
|
||||
|
||||
const TASK_PRIORITY_VALUES: LocalTaskPriority[] = ["LOW", "MEDIUM", "HIGH", "URGENT"];
|
||||
const TASK_STATUS_VALUES: LocalTaskStatus[] = ["TODO", "IN_PROGRESS", "DONE", "ARCHIVED"];
|
||||
@@ -246,11 +251,14 @@ export async function applyPendingRemoteOperations(userId: string): Promise<numb
|
||||
continue;
|
||||
}
|
||||
|
||||
const currentTask = await localDb.tasks.get(operation.entityId);
|
||||
const storedTask = await localDb.tasks.get(operation.entityId);
|
||||
const currentTask = storedTask ? await decryptTaskRecord(storedTask) : undefined;
|
||||
const incomingTask = buildIncomingTaskRecord(operation, currentTask);
|
||||
|
||||
if (shouldApplyIncomingTask(currentTask, incomingTask, operation)) {
|
||||
await localDb.tasks.put(incomingTask);
|
||||
await localDb.tasks.put(await encryptTaskRecord(incomingTask));
|
||||
} else if (storedTask && currentTask && shouldEncryptTaskRecord(storedTask)) {
|
||||
await localDb.tasks.put(await encryptTaskRecord(currentTask));
|
||||
}
|
||||
|
||||
await localDb.syncInbox.update(operation.opId, { appliedAt });
|
||||
|
||||
Reference in New Issue
Block a user