feat(api-ai): add user and ip rate limiting

This commit is contained in:
2026-04-07 22:56:22 +08:00
parent ce72892dc8
commit 1f8b539b68
5 changed files with 325 additions and 8 deletions
+123
View File
@@ -0,0 +1,123 @@
import { Injectable } from "@nestjs/common";
import { ConfigService } from "@nestjs/config";
type AiRateLimitBucket = {
count: number;
resetAt: number;
};
export type AiRateLimitResult =
| {
allowed: true;
}
| {
allowed: false;
reason: "USER" | "IP";
retryAfterMs: number;
limit: number;
windowMs: number;
};
@Injectable()
export class AiRateLimitService {
private readonly userBuckets = new Map<string, AiRateLimitBucket>();
private readonly ipBuckets = new Map<string, AiRateLimitBucket>();
private readonly windowMs: number;
private readonly userLimit: number;
private readonly ipLimit: number;
constructor(private readonly configService: ConfigService) {
this.windowMs = this.readPositiveInt("AI_RATE_LIMIT_WINDOW_MS", 60_000);
this.userLimit = this.readPositiveInt("AI_RATE_LIMIT_USER_MAX", 20);
this.ipLimit = this.readPositiveInt("AI_RATE_LIMIT_IP_MAX", 60);
}
consume(userId: string, clientIp: string | null): AiRateLimitResult {
const now = Date.now();
const userBucket = this.getBucket(this.userBuckets, userId, now);
if (userBucket.count >= this.userLimit) {
return {
allowed: false,
reason: "USER",
retryAfterMs: Math.max(0, userBucket.resetAt - now),
limit: this.userLimit,
windowMs: this.windowMs
};
}
const normalizedIp = this.normalizeIp(clientIp);
const ipBucket = normalizedIp ? this.getBucket(this.ipBuckets, normalizedIp, now) : null;
if (ipBucket && ipBucket.count >= this.ipLimit) {
return {
allowed: false,
reason: "IP",
retryAfterMs: Math.max(0, ipBucket.resetAt - now),
limit: this.ipLimit,
windowMs: this.windowMs
};
}
userBucket.count += 1;
if (ipBucket) {
ipBucket.count += 1;
}
this.cleanupExpiredBuckets(this.userBuckets, now);
this.cleanupExpiredBuckets(this.ipBuckets, now);
return {
allowed: true
};
}
private getBucket(
buckets: Map<string, AiRateLimitBucket>,
key: string,
now: number
): AiRateLimitBucket {
const currentBucket = buckets.get(key);
if (!currentBucket || now >= currentBucket.resetAt) {
const nextBucket: AiRateLimitBucket = {
count: 0,
resetAt: now + this.windowMs
};
buckets.set(key, nextBucket);
return nextBucket;
}
return currentBucket;
}
private cleanupExpiredBuckets(buckets: Map<string, AiRateLimitBucket>, now: number): void {
if (buckets.size <= 256) {
return;
}
for (const [key, bucket] of buckets.entries()) {
if (now >= bucket.resetAt) {
buckets.delete(key);
}
}
}
private normalizeIp(clientIp: string | null): string | null {
if (!clientIp) {
return null;
}
const normalizedIp = clientIp.trim();
return normalizedIp.length > 0 ? normalizedIp : null;
}
private readPositiveInt(key: string, fallbackValue: number): number {
const rawValue = this.configService.get<string | number | undefined>(key);
const parsedValue =
typeof rawValue === "number" ? rawValue : Number.parseInt(String(rawValue ?? ""), 10);
if (!Number.isFinite(parsedValue) || parsedValue <= 0) {
return fallbackValue;
}
return parsedValue;
}
}
+12 -2
View File
@@ -1,4 +1,13 @@
import { Body, Controller, Get, Headers, Post, Query, UnauthorizedException } from "@nestjs/common"; import {
Body,
Controller,
Get,
Headers,
Ip,
Post,
Query,
UnauthorizedException
} from "@nestjs/common";
import { AiChatDto } from "./dto/ai-chat.dto"; import { AiChatDto } from "./dto/ai-chat.dto";
import { ListAiUsageLogsQueryDto } from "./dto/list-ai-usage-logs-query.dto"; import { ListAiUsageLogsQueryDto } from "./dto/list-ai-usage-logs-query.dto";
import { UpsertAiProviderBindingDto } from "./dto/upsert-ai-provider-binding.dto"; import { UpsertAiProviderBindingDto } from "./dto/upsert-ai-provider-binding.dto";
@@ -39,9 +48,10 @@ export class AiController {
@Post("chat") @Post("chat")
async chat( async chat(
@Headers("x-user-id") userIdHeader: string | string[] | undefined, @Headers("x-user-id") userIdHeader: string | string[] | undefined,
@Ip() clientIp: string,
@Body() body: AiChatDto @Body() body: AiChatDto
): Promise<AiChatResponse> { ): Promise<AiChatResponse> {
return this.aiService.chat(this.resolveUserId(userIdHeader), body); return this.aiService.chat(this.resolveUserId(userIdHeader), body, clientIp);
} }
private resolveUserId(userIdHeader: string | string[] | undefined): string { private resolveUserId(userIdHeader: string | string[] | undefined): string {
+8 -1
View File
@@ -1,5 +1,6 @@
import { Module } from "@nestjs/common"; import { Module } from "@nestjs/common";
import { PrismaModule } from "../prisma/prisma.module"; import { PrismaModule } from "../prisma/prisma.module";
import { AiRateLimitService } from "./ai-rate-limit.service";
import { AiController } from "./ai.controller"; import { AiController } from "./ai.controller";
import { AiProviderRegistryService } from "./ai-provider-registry.service"; import { AiProviderRegistryService } from "./ai-provider-registry.service";
import { AiService } from "./ai.service"; import { AiService } from "./ai.service";
@@ -9,6 +10,12 @@ import { OpenAiCompatibleProvider } from "./providers/openai-compatible.provider
@Module({ @Module({
imports: [PrismaModule], imports: [PrismaModule],
controllers: [AiController], controllers: [AiController],
providers: [AiService, AiProviderRegistryService, OpenAiCompatibleProvider, AstrbotProvider] providers: [
AiService,
AiRateLimitService,
AiProviderRegistryService,
OpenAiCompatibleProvider,
AstrbotProvider
]
}) })
export class AiModule {} export class AiModule {}
+31 -3
View File
@@ -1,4 +1,11 @@
import { BadGatewayException, BadRequestException, Injectable, Logger } from "@nestjs/common"; import {
BadGatewayException,
BadRequestException,
HttpException,
HttpStatus,
Injectable,
Logger
} from "@nestjs/common";
import { import {
AiChannel, AiChannel,
AiUsageLog, AiUsageLog,
@@ -10,6 +17,7 @@ import {
} from "../../generated/prisma/client"; } from "../../generated/prisma/client";
import { PrismaService } from "../prisma/prisma.service"; import { PrismaService } from "../prisma/prisma.service";
import { DataEncryptionService } from "../security/data-encryption.service"; import { DataEncryptionService } from "../security/data-encryption.service";
import { AiRateLimitService } from "./ai-rate-limit.service";
import { AiProviderRegistryService } from "./ai-provider-registry.service"; import { AiProviderRegistryService } from "./ai-provider-registry.service";
import { AiChatDto } from "./dto/ai-chat.dto"; import { AiChatDto } from "./dto/ai-chat.dto";
import { ListAiUsageLogsQueryDto } from "./dto/list-ai-usage-logs-query.dto"; import { ListAiUsageLogsQueryDto } from "./dto/list-ai-usage-logs-query.dto";
@@ -105,7 +113,8 @@ export class AiService {
constructor( constructor(
private readonly prismaService: PrismaService, private readonly prismaService: PrismaService,
private readonly aiProviderRegistryService: AiProviderRegistryService, private readonly aiProviderRegistryService: AiProviderRegistryService,
private readonly dataEncryptionService: DataEncryptionService private readonly dataEncryptionService: DataEncryptionService,
private readonly aiRateLimitService: AiRateLimitService
) {} ) {}
async listBindings(userId: string): Promise<ListAiBindingsResponse> { async listBindings(userId: string): Promise<ListAiBindingsResponse> {
@@ -242,7 +251,26 @@ export class AiService {
return this.serializeBinding(result); return this.serializeBinding(result);
} }
async chat(userId: string, dto: AiChatDto): Promise<AiChatResponse> { async chat(
userId: string,
dto: AiChatDto,
clientIp: string | null = null
): Promise<AiChatResponse> {
const rateLimitResult = this.aiRateLimitService.consume(userId, clientIp);
if (!rateLimitResult.allowed) {
throw new HttpException(
{
message: "AI 请求过于频繁,请稍后再试",
code: "AI_RATE_LIMITED",
dimension: rateLimitResult.reason === "USER" ? "user" : "ip",
retryAfterMs: rateLimitResult.retryAfterMs,
limit: rateLimitResult.limit,
windowMs: rateLimitResult.windowMs
},
HttpStatus.TOO_MANY_REQUESTS
);
}
const attempts: AiRouteAttempt[] = []; const attempts: AiRouteAttempt[] = [];
const plan = await this.buildRoutePlan(userId, dto.channel ?? null); const plan = await this.buildRoutePlan(userId, dto.channel ?? null);
const promptMessage = await this.buildPromptMessage(userId, dto.message, dto.localTasks ?? []); const promptMessage = await this.buildPromptMessage(userId, dto.message, dto.localTasks ?? []);
+151 -2
View File
@@ -12,6 +12,7 @@ import {
} from "../generated/prisma/client"; } from "../generated/prisma/client";
import { AiController } from "../src/ai/ai.controller"; import { AiController } from "../src/ai/ai.controller";
import { AiProviderRegistryService } from "../src/ai/ai-provider-registry.service"; import { AiProviderRegistryService } from "../src/ai/ai-provider-registry.service";
import { AiRateLimitService } from "../src/ai/ai-rate-limit.service";
import { AiService } from "../src/ai/ai.service"; import { AiService } from "../src/ai/ai.service";
import { import {
AiChatInput, AiChatInput,
@@ -410,6 +411,7 @@ describe("AiController (integration)", () => {
controllers: [AiController], controllers: [AiController],
providers: [ providers: [
AiService, AiService,
AiRateLimitService,
DataEncryptionService, DataEncryptionService,
{ {
provide: PrismaService, provide: PrismaService,
@@ -418,8 +420,22 @@ describe("AiController (integration)", () => {
{ {
provide: ConfigService, provide: ConfigService,
useValue: { useValue: {
get: (key: string) => get: (key: string) => {
key === "DATA_ENCRYPTION_SECRET" ? "test-data-encryption-secret" : undefined if (key === "DATA_ENCRYPTION_SECRET") {
return "test-data-encryption-secret";
}
if (key === "AI_RATE_LIMIT_WINDOW_MS") {
return 60_000;
}
if (key === "AI_RATE_LIMIT_USER_MAX") {
return 2;
}
if (key === "AI_RATE_LIMIT_IP_MAX") {
return 3;
}
return undefined;
}
} }
}, },
{ {
@@ -911,6 +927,139 @@ describe("AiController (integration)", () => {
]); ]);
expect(prismaService.getUsageLogs()).toEqual([]); expect(prismaService.getUsageLogs()).toEqual([]);
}); });
it("should rate limit ai chat by user in the same window", async () => {
prismaService.seedBinding({
id: "binding_astrbot_rate_limit_user",
userId: "user_1",
channel: AiChannel.ASTRBOT,
providerName: "",
model: null,
configId: "default",
configName: null,
encryptedApiKey: "abk_astrbot",
endpoint: "http://127.0.0.1:6185",
isDefault: true,
isEnabled: true
});
await request(app.getHttpServer())
.post("/ai/chat")
.set("x-user-id", "user_1")
.set("x-forwarded-for", "203.0.113.10")
.send({
message: "第一条"
})
.expect(201);
await request(app.getHttpServer())
.post("/ai/chat")
.set("x-user-id", "user_1")
.set("x-forwarded-for", "203.0.113.10")
.send({
message: "第二条"
})
.expect(201);
const response = await request(app.getHttpServer())
.post("/ai/chat")
.set("x-user-id", "user_1")
.set("x-forwarded-for", "203.0.113.10")
.send({
message: "第三条"
})
.expect(429);
expect(response.body).toMatchObject({
message: "AI 请求过于频繁,请稍后再试",
code: "AI_RATE_LIMITED",
dimension: "user",
limit: 2,
windowMs: 60000
});
expect(response.body.retryAfterMs).toEqual(expect.any(Number));
expect(astrbotExecutor.inputs).toHaveLength(2);
expect(prismaService.getUsageLogs()).toHaveLength(2);
});
it("should rate limit ai chat by ip across different users", async () => {
prismaService.seedBinding({
id: "binding_astrbot_rate_limit_ip_user_1",
userId: "user_1",
channel: AiChannel.ASTRBOT,
providerName: "",
model: null,
configId: "default",
configName: null,
encryptedApiKey: "abk_astrbot",
endpoint: "http://127.0.0.1:6185",
isDefault: true,
isEnabled: true
});
prismaService.seedBinding({
id: "binding_astrbot_rate_limit_ip_user_2",
userId: "user_2",
channel: AiChannel.ASTRBOT,
providerName: "",
model: null,
configId: "default",
configName: null,
encryptedApiKey: "abk_astrbot",
endpoint: "http://127.0.0.1:6185",
isDefault: true,
isEnabled: true
});
const sharedIp = "198.51.100.7";
await request(app.getHttpServer())
.post("/ai/chat")
.set("x-user-id", "user_1")
.set("x-forwarded-for", sharedIp)
.send({
message: "用户一第一条"
})
.expect(201);
await request(app.getHttpServer())
.post("/ai/chat")
.set("x-user-id", "user_2")
.set("x-forwarded-for", sharedIp)
.send({
message: "用户二第一条"
})
.expect(201);
await request(app.getHttpServer())
.post("/ai/chat")
.set("x-user-id", "user_1")
.set("x-forwarded-for", sharedIp)
.send({
message: "用户一第二条"
})
.expect(201);
const response = await request(app.getHttpServer())
.post("/ai/chat")
.set("x-user-id", "user_2")
.set("x-forwarded-for", sharedIp)
.send({
message: "用户二第二条"
})
.expect(429);
expect(response.body).toMatchObject({
message: "AI 请求过于频繁,请稍后再试",
code: "AI_RATE_LIMITED",
dimension: "ip",
limit: 3,
windowMs: 60000
});
expect(response.body.retryAfterMs).toEqual(expect.any(Number));
expect(astrbotExecutor.inputs).toHaveLength(3);
expect(prismaService.getUsageLogs()).toHaveLength(3);
});
it("should list usage logs with pagination and filters", async () => { it("should list usage logs with pagination and filters", async () => {
prismaService.seedUsageLog({ prismaService.seedUsageLog({
id: "usage_log_1", id: "usage_log_1",