diff --git a/apps/api/src/ai/ai.service.ts b/apps/api/src/ai/ai.service.ts index b5bf24e..2269d2d 100644 --- a/apps/api/src/ai/ai.service.ts +++ b/apps/api/src/ai/ai.service.ts @@ -15,7 +15,12 @@ import { PrismaService } from "../prisma/prisma.service"; import { AiProviderRegistryService } from "./ai-provider-registry.service"; import { AiChatDto } from "./dto/ai-chat.dto"; import { UpsertAiProviderBindingDto } from "./dto/upsert-ai-provider-binding.dto"; -import { AiResolvedRouteCandidate, AiRouteAttempt, AiRouteFailureError } from "./ai.types"; +import { + AiResolvedRouteCandidate, + AiRouteAttempt, + AiRouteFailureError, + AiUsageMetrics +} from "./ai.types"; type AiBindingSummary = { id: string; @@ -198,6 +203,7 @@ export class AiService { } const executor = this.aiProviderRegistryService.getExecutor(entry.candidate.channel); + const startedAt = Date.now(); try { const result = await executor.execute(entry.candidate, { @@ -205,6 +211,7 @@ export class AiService { message: dto.message, sessionId: dto.sessionId ?? null }); + const latencyMs = Date.now() - startedAt; attempts.push({ channel: result.channel, @@ -214,6 +221,16 @@ export class AiService { reasonCode: null, reasonMessage: null }); + await this.recordUsageLog({ + userId, + channel: result.channel, + providerName: result.providerName, + model: result.model, + usage: result.usage, + latencyMs, + success: true, + errorCode: null + }); return { channel: result.channel, @@ -224,8 +241,19 @@ export class AiService { attempts }; } catch (error) { + const latencyMs = Date.now() - startedAt; const failureAttempt = this.toFailureAttempt(entry.candidate, error); attempts.push(failureAttempt); + await this.recordUsageLog({ + userId, + channel: failureAttempt.channel, + providerName: failureAttempt.providerName, + model: failureAttempt.model, + usage: null, + latencyMs, + success: false, + errorCode: failureAttempt.reasonCode + }); this.logger.warn( `AI 通道降级:channel=${failureAttempt.channel} provider=${failureAttempt.providerName ?? "unknown"} code=${failureAttempt.reasonCode ?? "UNKNOWN"} message=${failureAttempt.reasonMessage ?? "unknown"}` ); @@ -464,4 +492,35 @@ export class AiService { return `${secret.slice(0, 4)}***${secret.slice(-2)}`; } + + private async recordUsageLog(input: { + userId: string; + channel: AiChannel; + providerName: string | null; + model: string | null; + usage: AiUsageMetrics | null; + latencyMs: number; + success: boolean; + errorCode: string | null; + }): Promise { + try { + await this.prismaService.aiUsageLog.create({ + data: { + userId: input.userId, + channel: input.channel, + providerName: input.providerName, + model: input.model, + promptTokens: input.usage?.promptTokens ?? 0, + completionTokens: input.usage?.completionTokens ?? 0, + totalTokens: input.usage?.totalTokens ?? 0, + latencyMs: input.latencyMs, + success: input.success, + errorCode: input.errorCode + } + }); + } catch (error) { + const message = error instanceof Error ? error.message : "未知错误"; + this.logger.warn(`写入 AI 使用日志失败:${message}`); + } + } } diff --git a/apps/api/src/ai/ai.types.ts b/apps/api/src/ai/ai.types.ts index 5c52915..ccb8088 100644 --- a/apps/api/src/ai/ai.types.ts +++ b/apps/api/src/ai/ai.types.ts @@ -24,9 +24,16 @@ export type AiChatResult = { model: string | null; content: string; sessionId: string | null; + usage: AiUsageMetrics | null; raw: unknown; }; +export type AiUsageMetrics = { + promptTokens: number; + completionTokens: number; + totalTokens: number; +}; + export type AiRouteAttempt = { channel: AiChannel; providerName: string | null; diff --git a/apps/api/src/ai/providers/astrbot.provider.ts b/apps/api/src/ai/providers/astrbot.provider.ts index a82cb99..4f9936e 100644 --- a/apps/api/src/ai/providers/astrbot.provider.ts +++ b/apps/api/src/ai/providers/astrbot.provider.ts @@ -133,6 +133,7 @@ export class AstrbotProvider implements AiChannelExecutor { model: candidate.model, content, sessionId, + usage: this.extractUsage(events), raw: events }; } @@ -248,4 +249,39 @@ export class AstrbotProvider implements AiChannelExecutor { return fallback; } + + private extractUsage(events: Array>): AiChatResult["usage"] { + for (const event of events) { + if (this.readString(event["type"]) !== "agent_stats") { + continue; + } + + const data = this.asRecord(event["data"]); + const tokenUsage = this.asRecord(data?.["token_usage"]); + if (!tokenUsage) { + continue; + } + + const promptTokens = + (this.readNumber(tokenUsage["input_other"]) ?? 0) + + (this.readNumber(tokenUsage["input_cached"]) ?? 0); + const completionTokens = this.readNumber(tokenUsage["output"]) ?? 0; + + return { + promptTokens, + completionTokens, + totalTokens: promptTokens + completionTokens + }; + } + + return null; + } + + private asRecord(value: unknown): Record | null { + return typeof value === "object" && value !== null ? (value as Record) : null; + } + + private readNumber(value: unknown): number | null { + return typeof value === "number" && Number.isFinite(value) ? value : null; + } } diff --git a/apps/api/src/ai/providers/openai-compatible.provider.ts b/apps/api/src/ai/providers/openai-compatible.provider.ts index 2ca1723..0c52099 100644 --- a/apps/api/src/ai/providers/openai-compatible.provider.ts +++ b/apps/api/src/ai/providers/openai-compatible.provider.ts @@ -105,6 +105,7 @@ export class OpenAiCompatibleProvider implements AiChannelExecutor { model: this.extractModel(payload) ?? candidate.model, content, sessionId: input.sessionId, + usage: this.extractUsage(payload), raw: payload }; } @@ -176,6 +177,31 @@ export class OpenAiCompatibleProvider implements AiChannelExecutor { return payload["model"]; } + private extractUsage(payload: unknown): AiChatResult["usage"] { + if (!this.isRecord(payload)) { + return null; + } + + const usage = payload["usage"]; + if (!this.isRecord(usage)) { + return null; + } + + const promptTokens = this.readNumber(usage["prompt_tokens"]); + const completionTokens = this.readNumber(usage["completion_tokens"]); + const totalTokens = this.readNumber(usage["total_tokens"]); + + if (promptTokens === null && completionTokens === null && totalTokens === null) { + return null; + } + + return { + promptTokens: promptTokens ?? 0, + completionTokens: completionTokens ?? 0, + totalTokens: totalTokens ?? (promptTokens ?? 0) + (completionTokens ?? 0) + }; + } + private extractErrorMessage(payload: unknown, fallback: string): string { if (!this.isRecord(payload)) { return fallback; @@ -200,4 +226,8 @@ export class OpenAiCompatibleProvider implements AiChannelExecutor { return fallback; } + + private readNumber(value: unknown): number | null { + return typeof value === "number" && Number.isFinite(value) ? value : null; + } } diff --git a/apps/api/test/ai.spec.ts b/apps/api/test/ai.spec.ts index 113d037..85705ef 100644 --- a/apps/api/test/ai.spec.ts +++ b/apps/api/test/ai.spec.ts @@ -12,11 +12,25 @@ import { } from "../src/ai/ai.types"; import { PrismaService } from "../src/prisma/prisma.service"; +type AiUsageLogRecord = { + userId: string | null; + channel: AiChannel; + providerName: string | null; + model: string | null; + promptTokens: number; + completionTokens: number; + totalTokens: number; + latencyMs: number | null; + success: boolean; + errorCode: string | null; +}; + class InMemoryAiPrismaService { private bindingIdSequence = 1; private publicPoolIdSequence = 1; private bindings: AiProviderBinding[] = []; private publicPools: AiPublicPoolConfig[] = []; + private usageLogs: AiUsageLogRecord[] = []; readonly aiProviderBinding = { findMany: async (args: { @@ -164,6 +178,13 @@ class InMemoryAiPrismaService { } }; + readonly aiUsageLog = { + create: async (args: { data: AiUsageLogRecord }) => { + this.usageLogs.push(args.data); + return args.data; + } + }; + async $transaction(callback: (tx: InMemoryAiPrismaService) => Promise): Promise { return callback(this); } @@ -186,6 +207,10 @@ class InMemoryAiPrismaService { ...publicPool }); } + + getUsageLogs(): AiUsageLogRecord[] { + return [...this.usageLogs]; + } } class StaticExecutor implements AiChannelExecutor { @@ -214,6 +239,11 @@ class StaticExecutor implements AiChannelExecutor { model: candidate.model, content: result.content ?? "", sessionId: "session_ai", + usage: { + promptTokens: 12, + completionTokens: 8, + totalTokens: 20 + }, raw: null }; } @@ -368,6 +398,32 @@ describe("AiController (integration)", () => { reasonMessage: null } ]); + expect(prismaService.getUsageLogs()).toEqual([ + { + userId: "user_1", + channel: AiChannel.USER_KEY, + providerName: "openai", + model: "gpt-4o-mini", + promptTokens: 0, + completionTokens: 0, + totalTokens: 0, + latencyMs: expect.any(Number), + success: false, + errorCode: "UPSTREAM_UNREACHABLE" + }, + { + userId: "user_1", + channel: AiChannel.ASTRBOT, + providerName: "default", + model: null, + promptTokens: 12, + completionTokens: 8, + totalTokens: 20, + latencyMs: expect.any(Number), + success: true, + errorCode: null + } + ]); }); it("should allow astrbot binding with config id only", async () => { @@ -428,5 +484,6 @@ describe("AiController (integration)", () => { reasonMessage: "公共 AI 通道未开启" } ]); + expect(prismaService.getUsageLogs()).toEqual([]); }); }); diff --git a/apps/api/test/astrbot-provider.spec.ts b/apps/api/test/astrbot-provider.spec.ts index a8fcee8..98f3862 100644 --- a/apps/api/test/astrbot-provider.spec.ts +++ b/apps/api/test/astrbot-provider.spec.ts @@ -30,6 +30,15 @@ describe("AstrbotProvider", () => { } if (pullCount === 3) { + controller.enqueue( + encoder.encode( + 'data: {"type":"agent_stats","data":{"token_usage":{"input_other":12,"input_cached":30,"output":8}}}\n\n' + ) + ); + return; + } + + if (pullCount === 4) { controller.enqueue( encoder.encode('data: {"type":"end","data":"","streaming":false}\n\n') ); @@ -77,6 +86,11 @@ describe("AstrbotProvider", () => { expect(result.content).toBe("TodoList AstrBot 已连接"); expect(result.sessionId).toBe("session_1"); - expect(pullCount).toBeGreaterThanOrEqual(3); + expect(result.usage).toEqual({ + promptTokens: 42, + completionTokens: 8, + totalTokens: 50 + }); + expect(pullCount).toBeGreaterThanOrEqual(4); }); });