feat(api-ai): persist usage logs

This commit is contained in:
2026-04-06 12:42:56 +08:00
parent 2ca790abf9
commit 45177e9fad
6 changed files with 205 additions and 2 deletions
+60 -1
View File
@@ -15,7 +15,12 @@ import { PrismaService } from "../prisma/prisma.service";
import { AiProviderRegistryService } from "./ai-provider-registry.service";
import { AiChatDto } from "./dto/ai-chat.dto";
import { UpsertAiProviderBindingDto } from "./dto/upsert-ai-provider-binding.dto";
import { AiResolvedRouteCandidate, AiRouteAttempt, AiRouteFailureError } from "./ai.types";
import {
AiResolvedRouteCandidate,
AiRouteAttempt,
AiRouteFailureError,
AiUsageMetrics
} from "./ai.types";
type AiBindingSummary = {
id: string;
@@ -198,6 +203,7 @@ export class AiService {
}
const executor = this.aiProviderRegistryService.getExecutor(entry.candidate.channel);
const startedAt = Date.now();
try {
const result = await executor.execute(entry.candidate, {
@@ -205,6 +211,7 @@ export class AiService {
message: dto.message,
sessionId: dto.sessionId ?? null
});
const latencyMs = Date.now() - startedAt;
attempts.push({
channel: result.channel,
@@ -214,6 +221,16 @@ export class AiService {
reasonCode: null,
reasonMessage: null
});
await this.recordUsageLog({
userId,
channel: result.channel,
providerName: result.providerName,
model: result.model,
usage: result.usage,
latencyMs,
success: true,
errorCode: null
});
return {
channel: result.channel,
@@ -224,8 +241,19 @@ export class AiService {
attempts
};
} catch (error) {
const latencyMs = Date.now() - startedAt;
const failureAttempt = this.toFailureAttempt(entry.candidate, error);
attempts.push(failureAttempt);
await this.recordUsageLog({
userId,
channel: failureAttempt.channel,
providerName: failureAttempt.providerName,
model: failureAttempt.model,
usage: null,
latencyMs,
success: false,
errorCode: failureAttempt.reasonCode
});
this.logger.warn(
`AI 通道降级:channel=${failureAttempt.channel} provider=${failureAttempt.providerName ?? "unknown"} code=${failureAttempt.reasonCode ?? "UNKNOWN"} message=${failureAttempt.reasonMessage ?? "unknown"}`
);
@@ -464,4 +492,35 @@ export class AiService {
return `${secret.slice(0, 4)}***${secret.slice(-2)}`;
}
private async recordUsageLog(input: {
userId: string;
channel: AiChannel;
providerName: string | null;
model: string | null;
usage: AiUsageMetrics | null;
latencyMs: number;
success: boolean;
errorCode: string | null;
}): Promise<void> {
try {
await this.prismaService.aiUsageLog.create({
data: {
userId: input.userId,
channel: input.channel,
providerName: input.providerName,
model: input.model,
promptTokens: input.usage?.promptTokens ?? 0,
completionTokens: input.usage?.completionTokens ?? 0,
totalTokens: input.usage?.totalTokens ?? 0,
latencyMs: input.latencyMs,
success: input.success,
errorCode: input.errorCode
}
});
} catch (error) {
const message = error instanceof Error ? error.message : "未知错误";
this.logger.warn(`写入 AI 使用日志失败:${message}`);
}
}
}
+7
View File
@@ -24,9 +24,16 @@ export type AiChatResult = {
model: string | null;
content: string;
sessionId: string | null;
usage: AiUsageMetrics | null;
raw: unknown;
};
export type AiUsageMetrics = {
promptTokens: number;
completionTokens: number;
totalTokens: number;
};
export type AiRouteAttempt = {
channel: AiChannel;
providerName: string | null;
@@ -133,6 +133,7 @@ export class AstrbotProvider implements AiChannelExecutor {
model: candidate.model,
content,
sessionId,
usage: this.extractUsage(events),
raw: events
};
}
@@ -248,4 +249,39 @@ export class AstrbotProvider implements AiChannelExecutor {
return fallback;
}
private extractUsage(events: Array<Record<string, unknown>>): AiChatResult["usage"] {
for (const event of events) {
if (this.readString(event["type"]) !== "agent_stats") {
continue;
}
const data = this.asRecord(event["data"]);
const tokenUsage = this.asRecord(data?.["token_usage"]);
if (!tokenUsage) {
continue;
}
const promptTokens =
(this.readNumber(tokenUsage["input_other"]) ?? 0) +
(this.readNumber(tokenUsage["input_cached"]) ?? 0);
const completionTokens = this.readNumber(tokenUsage["output"]) ?? 0;
return {
promptTokens,
completionTokens,
totalTokens: promptTokens + completionTokens
};
}
return null;
}
private asRecord(value: unknown): Record<string, unknown> | null {
return typeof value === "object" && value !== null ? (value as Record<string, unknown>) : null;
}
private readNumber(value: unknown): number | null {
return typeof value === "number" && Number.isFinite(value) ? value : null;
}
}
@@ -105,6 +105,7 @@ export class OpenAiCompatibleProvider implements AiChannelExecutor {
model: this.extractModel(payload) ?? candidate.model,
content,
sessionId: input.sessionId,
usage: this.extractUsage(payload),
raw: payload
};
}
@@ -176,6 +177,31 @@ export class OpenAiCompatibleProvider implements AiChannelExecutor {
return payload["model"];
}
private extractUsage(payload: unknown): AiChatResult["usage"] {
if (!this.isRecord(payload)) {
return null;
}
const usage = payload["usage"];
if (!this.isRecord(usage)) {
return null;
}
const promptTokens = this.readNumber(usage["prompt_tokens"]);
const completionTokens = this.readNumber(usage["completion_tokens"]);
const totalTokens = this.readNumber(usage["total_tokens"]);
if (promptTokens === null && completionTokens === null && totalTokens === null) {
return null;
}
return {
promptTokens: promptTokens ?? 0,
completionTokens: completionTokens ?? 0,
totalTokens: totalTokens ?? (promptTokens ?? 0) + (completionTokens ?? 0)
};
}
private extractErrorMessage(payload: unknown, fallback: string): string {
if (!this.isRecord(payload)) {
return fallback;
@@ -200,4 +226,8 @@ export class OpenAiCompatibleProvider implements AiChannelExecutor {
return fallback;
}
private readNumber(value: unknown): number | null {
return typeof value === "number" && Number.isFinite(value) ? value : null;
}
}
+57
View File
@@ -12,11 +12,25 @@ import {
} from "../src/ai/ai.types";
import { PrismaService } from "../src/prisma/prisma.service";
type AiUsageLogRecord = {
userId: string | null;
channel: AiChannel;
providerName: string | null;
model: string | null;
promptTokens: number;
completionTokens: number;
totalTokens: number;
latencyMs: number | null;
success: boolean;
errorCode: string | null;
};
class InMemoryAiPrismaService {
private bindingIdSequence = 1;
private publicPoolIdSequence = 1;
private bindings: AiProviderBinding[] = [];
private publicPools: AiPublicPoolConfig[] = [];
private usageLogs: AiUsageLogRecord[] = [];
readonly aiProviderBinding = {
findMany: async (args: {
@@ -164,6 +178,13 @@ class InMemoryAiPrismaService {
}
};
readonly aiUsageLog = {
create: async (args: { data: AiUsageLogRecord }) => {
this.usageLogs.push(args.data);
return args.data;
}
};
async $transaction<T>(callback: (tx: InMemoryAiPrismaService) => Promise<T>): Promise<T> {
return callback(this);
}
@@ -186,6 +207,10 @@ class InMemoryAiPrismaService {
...publicPool
});
}
getUsageLogs(): AiUsageLogRecord[] {
return [...this.usageLogs];
}
}
class StaticExecutor implements AiChannelExecutor {
@@ -214,6 +239,11 @@ class StaticExecutor implements AiChannelExecutor {
model: candidate.model,
content: result.content ?? "",
sessionId: "session_ai",
usage: {
promptTokens: 12,
completionTokens: 8,
totalTokens: 20
},
raw: null
};
}
@@ -368,6 +398,32 @@ describe("AiController (integration)", () => {
reasonMessage: null
}
]);
expect(prismaService.getUsageLogs()).toEqual([
{
userId: "user_1",
channel: AiChannel.USER_KEY,
providerName: "openai",
model: "gpt-4o-mini",
promptTokens: 0,
completionTokens: 0,
totalTokens: 0,
latencyMs: expect.any(Number),
success: false,
errorCode: "UPSTREAM_UNREACHABLE"
},
{
userId: "user_1",
channel: AiChannel.ASTRBOT,
providerName: "default",
model: null,
promptTokens: 12,
completionTokens: 8,
totalTokens: 20,
latencyMs: expect.any(Number),
success: true,
errorCode: null
}
]);
});
it("should allow astrbot binding with config id only", async () => {
@@ -428,5 +484,6 @@ describe("AiController (integration)", () => {
reasonMessage: "公共 AI 通道未开启"
}
]);
expect(prismaService.getUsageLogs()).toEqual([]);
});
});
+15 -1
View File
@@ -30,6 +30,15 @@ describe("AstrbotProvider", () => {
}
if (pullCount === 3) {
controller.enqueue(
encoder.encode(
'data: {"type":"agent_stats","data":{"token_usage":{"input_other":12,"input_cached":30,"output":8}}}\n\n'
)
);
return;
}
if (pullCount === 4) {
controller.enqueue(
encoder.encode('data: {"type":"end","data":"","streaming":false}\n\n')
);
@@ -77,6 +86,11 @@ describe("AstrbotProvider", () => {
expect(result.content).toBe("TodoList AstrBot 已连接");
expect(result.sessionId).toBe("session_1");
expect(pullCount).toBeGreaterThanOrEqual(3);
expect(result.usage).toEqual({
promptTokens: 42,
completionTokens: 8,
totalTokens: 50
});
expect(pullCount).toBeGreaterThanOrEqual(4);
});
});