feat(api-ai): persist usage logs
This commit is contained in:
@@ -15,7 +15,12 @@ import { PrismaService } from "../prisma/prisma.service";
|
||||
import { AiProviderRegistryService } from "./ai-provider-registry.service";
|
||||
import { AiChatDto } from "./dto/ai-chat.dto";
|
||||
import { UpsertAiProviderBindingDto } from "./dto/upsert-ai-provider-binding.dto";
|
||||
import { AiResolvedRouteCandidate, AiRouteAttempt, AiRouteFailureError } from "./ai.types";
|
||||
import {
|
||||
AiResolvedRouteCandidate,
|
||||
AiRouteAttempt,
|
||||
AiRouteFailureError,
|
||||
AiUsageMetrics
|
||||
} from "./ai.types";
|
||||
|
||||
type AiBindingSummary = {
|
||||
id: string;
|
||||
@@ -198,6 +203,7 @@ export class AiService {
|
||||
}
|
||||
|
||||
const executor = this.aiProviderRegistryService.getExecutor(entry.candidate.channel);
|
||||
const startedAt = Date.now();
|
||||
|
||||
try {
|
||||
const result = await executor.execute(entry.candidate, {
|
||||
@@ -205,6 +211,7 @@ export class AiService {
|
||||
message: dto.message,
|
||||
sessionId: dto.sessionId ?? null
|
||||
});
|
||||
const latencyMs = Date.now() - startedAt;
|
||||
|
||||
attempts.push({
|
||||
channel: result.channel,
|
||||
@@ -214,6 +221,16 @@ export class AiService {
|
||||
reasonCode: null,
|
||||
reasonMessage: null
|
||||
});
|
||||
await this.recordUsageLog({
|
||||
userId,
|
||||
channel: result.channel,
|
||||
providerName: result.providerName,
|
||||
model: result.model,
|
||||
usage: result.usage,
|
||||
latencyMs,
|
||||
success: true,
|
||||
errorCode: null
|
||||
});
|
||||
|
||||
return {
|
||||
channel: result.channel,
|
||||
@@ -224,8 +241,19 @@ export class AiService {
|
||||
attempts
|
||||
};
|
||||
} catch (error) {
|
||||
const latencyMs = Date.now() - startedAt;
|
||||
const failureAttempt = this.toFailureAttempt(entry.candidate, error);
|
||||
attempts.push(failureAttempt);
|
||||
await this.recordUsageLog({
|
||||
userId,
|
||||
channel: failureAttempt.channel,
|
||||
providerName: failureAttempt.providerName,
|
||||
model: failureAttempt.model,
|
||||
usage: null,
|
||||
latencyMs,
|
||||
success: false,
|
||||
errorCode: failureAttempt.reasonCode
|
||||
});
|
||||
this.logger.warn(
|
||||
`AI 通道降级:channel=${failureAttempt.channel} provider=${failureAttempt.providerName ?? "unknown"} code=${failureAttempt.reasonCode ?? "UNKNOWN"} message=${failureAttempt.reasonMessage ?? "unknown"}`
|
||||
);
|
||||
@@ -464,4 +492,35 @@ export class AiService {
|
||||
|
||||
return `${secret.slice(0, 4)}***${secret.slice(-2)}`;
|
||||
}
|
||||
|
||||
private async recordUsageLog(input: {
|
||||
userId: string;
|
||||
channel: AiChannel;
|
||||
providerName: string | null;
|
||||
model: string | null;
|
||||
usage: AiUsageMetrics | null;
|
||||
latencyMs: number;
|
||||
success: boolean;
|
||||
errorCode: string | null;
|
||||
}): Promise<void> {
|
||||
try {
|
||||
await this.prismaService.aiUsageLog.create({
|
||||
data: {
|
||||
userId: input.userId,
|
||||
channel: input.channel,
|
||||
providerName: input.providerName,
|
||||
model: input.model,
|
||||
promptTokens: input.usage?.promptTokens ?? 0,
|
||||
completionTokens: input.usage?.completionTokens ?? 0,
|
||||
totalTokens: input.usage?.totalTokens ?? 0,
|
||||
latencyMs: input.latencyMs,
|
||||
success: input.success,
|
||||
errorCode: input.errorCode
|
||||
}
|
||||
});
|
||||
} catch (error) {
|
||||
const message = error instanceof Error ? error.message : "未知错误";
|
||||
this.logger.warn(`写入 AI 使用日志失败:${message}`);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -24,9 +24,16 @@ export type AiChatResult = {
|
||||
model: string | null;
|
||||
content: string;
|
||||
sessionId: string | null;
|
||||
usage: AiUsageMetrics | null;
|
||||
raw: unknown;
|
||||
};
|
||||
|
||||
export type AiUsageMetrics = {
|
||||
promptTokens: number;
|
||||
completionTokens: number;
|
||||
totalTokens: number;
|
||||
};
|
||||
|
||||
export type AiRouteAttempt = {
|
||||
channel: AiChannel;
|
||||
providerName: string | null;
|
||||
|
||||
@@ -133,6 +133,7 @@ export class AstrbotProvider implements AiChannelExecutor {
|
||||
model: candidate.model,
|
||||
content,
|
||||
sessionId,
|
||||
usage: this.extractUsage(events),
|
||||
raw: events
|
||||
};
|
||||
}
|
||||
@@ -248,4 +249,39 @@ export class AstrbotProvider implements AiChannelExecutor {
|
||||
|
||||
return fallback;
|
||||
}
|
||||
|
||||
private extractUsage(events: Array<Record<string, unknown>>): AiChatResult["usage"] {
|
||||
for (const event of events) {
|
||||
if (this.readString(event["type"]) !== "agent_stats") {
|
||||
continue;
|
||||
}
|
||||
|
||||
const data = this.asRecord(event["data"]);
|
||||
const tokenUsage = this.asRecord(data?.["token_usage"]);
|
||||
if (!tokenUsage) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const promptTokens =
|
||||
(this.readNumber(tokenUsage["input_other"]) ?? 0) +
|
||||
(this.readNumber(tokenUsage["input_cached"]) ?? 0);
|
||||
const completionTokens = this.readNumber(tokenUsage["output"]) ?? 0;
|
||||
|
||||
return {
|
||||
promptTokens,
|
||||
completionTokens,
|
||||
totalTokens: promptTokens + completionTokens
|
||||
};
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
|
||||
private asRecord(value: unknown): Record<string, unknown> | null {
|
||||
return typeof value === "object" && value !== null ? (value as Record<string, unknown>) : null;
|
||||
}
|
||||
|
||||
private readNumber(value: unknown): number | null {
|
||||
return typeof value === "number" && Number.isFinite(value) ? value : null;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -105,6 +105,7 @@ export class OpenAiCompatibleProvider implements AiChannelExecutor {
|
||||
model: this.extractModel(payload) ?? candidate.model,
|
||||
content,
|
||||
sessionId: input.sessionId,
|
||||
usage: this.extractUsage(payload),
|
||||
raw: payload
|
||||
};
|
||||
}
|
||||
@@ -176,6 +177,31 @@ export class OpenAiCompatibleProvider implements AiChannelExecutor {
|
||||
return payload["model"];
|
||||
}
|
||||
|
||||
private extractUsage(payload: unknown): AiChatResult["usage"] {
|
||||
if (!this.isRecord(payload)) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const usage = payload["usage"];
|
||||
if (!this.isRecord(usage)) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const promptTokens = this.readNumber(usage["prompt_tokens"]);
|
||||
const completionTokens = this.readNumber(usage["completion_tokens"]);
|
||||
const totalTokens = this.readNumber(usage["total_tokens"]);
|
||||
|
||||
if (promptTokens === null && completionTokens === null && totalTokens === null) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return {
|
||||
promptTokens: promptTokens ?? 0,
|
||||
completionTokens: completionTokens ?? 0,
|
||||
totalTokens: totalTokens ?? (promptTokens ?? 0) + (completionTokens ?? 0)
|
||||
};
|
||||
}
|
||||
|
||||
private extractErrorMessage(payload: unknown, fallback: string): string {
|
||||
if (!this.isRecord(payload)) {
|
||||
return fallback;
|
||||
@@ -200,4 +226,8 @@ export class OpenAiCompatibleProvider implements AiChannelExecutor {
|
||||
|
||||
return fallback;
|
||||
}
|
||||
|
||||
private readNumber(value: unknown): number | null {
|
||||
return typeof value === "number" && Number.isFinite(value) ? value : null;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,11 +12,25 @@ import {
|
||||
} from "../src/ai/ai.types";
|
||||
import { PrismaService } from "../src/prisma/prisma.service";
|
||||
|
||||
type AiUsageLogRecord = {
|
||||
userId: string | null;
|
||||
channel: AiChannel;
|
||||
providerName: string | null;
|
||||
model: string | null;
|
||||
promptTokens: number;
|
||||
completionTokens: number;
|
||||
totalTokens: number;
|
||||
latencyMs: number | null;
|
||||
success: boolean;
|
||||
errorCode: string | null;
|
||||
};
|
||||
|
||||
class InMemoryAiPrismaService {
|
||||
private bindingIdSequence = 1;
|
||||
private publicPoolIdSequence = 1;
|
||||
private bindings: AiProviderBinding[] = [];
|
||||
private publicPools: AiPublicPoolConfig[] = [];
|
||||
private usageLogs: AiUsageLogRecord[] = [];
|
||||
|
||||
readonly aiProviderBinding = {
|
||||
findMany: async (args: {
|
||||
@@ -164,6 +178,13 @@ class InMemoryAiPrismaService {
|
||||
}
|
||||
};
|
||||
|
||||
readonly aiUsageLog = {
|
||||
create: async (args: { data: AiUsageLogRecord }) => {
|
||||
this.usageLogs.push(args.data);
|
||||
return args.data;
|
||||
}
|
||||
};
|
||||
|
||||
async $transaction<T>(callback: (tx: InMemoryAiPrismaService) => Promise<T>): Promise<T> {
|
||||
return callback(this);
|
||||
}
|
||||
@@ -186,6 +207,10 @@ class InMemoryAiPrismaService {
|
||||
...publicPool
|
||||
});
|
||||
}
|
||||
|
||||
getUsageLogs(): AiUsageLogRecord[] {
|
||||
return [...this.usageLogs];
|
||||
}
|
||||
}
|
||||
|
||||
class StaticExecutor implements AiChannelExecutor {
|
||||
@@ -214,6 +239,11 @@ class StaticExecutor implements AiChannelExecutor {
|
||||
model: candidate.model,
|
||||
content: result.content ?? "",
|
||||
sessionId: "session_ai",
|
||||
usage: {
|
||||
promptTokens: 12,
|
||||
completionTokens: 8,
|
||||
totalTokens: 20
|
||||
},
|
||||
raw: null
|
||||
};
|
||||
}
|
||||
@@ -368,6 +398,32 @@ describe("AiController (integration)", () => {
|
||||
reasonMessage: null
|
||||
}
|
||||
]);
|
||||
expect(prismaService.getUsageLogs()).toEqual([
|
||||
{
|
||||
userId: "user_1",
|
||||
channel: AiChannel.USER_KEY,
|
||||
providerName: "openai",
|
||||
model: "gpt-4o-mini",
|
||||
promptTokens: 0,
|
||||
completionTokens: 0,
|
||||
totalTokens: 0,
|
||||
latencyMs: expect.any(Number),
|
||||
success: false,
|
||||
errorCode: "UPSTREAM_UNREACHABLE"
|
||||
},
|
||||
{
|
||||
userId: "user_1",
|
||||
channel: AiChannel.ASTRBOT,
|
||||
providerName: "default",
|
||||
model: null,
|
||||
promptTokens: 12,
|
||||
completionTokens: 8,
|
||||
totalTokens: 20,
|
||||
latencyMs: expect.any(Number),
|
||||
success: true,
|
||||
errorCode: null
|
||||
}
|
||||
]);
|
||||
});
|
||||
|
||||
it("should allow astrbot binding with config id only", async () => {
|
||||
@@ -428,5 +484,6 @@ describe("AiController (integration)", () => {
|
||||
reasonMessage: "公共 AI 通道未开启"
|
||||
}
|
||||
]);
|
||||
expect(prismaService.getUsageLogs()).toEqual([]);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -30,6 +30,15 @@ describe("AstrbotProvider", () => {
|
||||
}
|
||||
|
||||
if (pullCount === 3) {
|
||||
controller.enqueue(
|
||||
encoder.encode(
|
||||
'data: {"type":"agent_stats","data":{"token_usage":{"input_other":12,"input_cached":30,"output":8}}}\n\n'
|
||||
)
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
if (pullCount === 4) {
|
||||
controller.enqueue(
|
||||
encoder.encode('data: {"type":"end","data":"","streaming":false}\n\n')
|
||||
);
|
||||
@@ -77,6 +86,11 @@ describe("AstrbotProvider", () => {
|
||||
|
||||
expect(result.content).toBe("TodoList AstrBot 已连接");
|
||||
expect(result.sessionId).toBe("session_1");
|
||||
expect(pullCount).toBeGreaterThanOrEqual(3);
|
||||
expect(result.usage).toEqual({
|
||||
promptTokens: 42,
|
||||
completionTokens: 8,
|
||||
totalTokens: 50
|
||||
});
|
||||
expect(pullCount).toBeGreaterThanOrEqual(4);
|
||||
});
|
||||
});
|
||||
|
||||
Reference in New Issue
Block a user