From 1f8b539b68ce2be27f1cacfa0cfb0e83bf18d527 Mon Sep 17 00:00:00 2001 From: Yaosanqi137 Date: Tue, 7 Apr 2026 22:56:22 +0800 Subject: [PATCH] feat(api-ai): add user and ip rate limiting --- apps/api/src/ai/ai-rate-limit.service.ts | 123 ++++++++++++++++++ apps/api/src/ai/ai.controller.ts | 14 ++- apps/api/src/ai/ai.module.ts | 9 +- apps/api/src/ai/ai.service.ts | 34 ++++- apps/api/test/ai.spec.ts | 153 ++++++++++++++++++++++- 5 files changed, 325 insertions(+), 8 deletions(-) create mode 100644 apps/api/src/ai/ai-rate-limit.service.ts diff --git a/apps/api/src/ai/ai-rate-limit.service.ts b/apps/api/src/ai/ai-rate-limit.service.ts new file mode 100644 index 0000000..e84b2de --- /dev/null +++ b/apps/api/src/ai/ai-rate-limit.service.ts @@ -0,0 +1,123 @@ +import { Injectable } from "@nestjs/common"; +import { ConfigService } from "@nestjs/config"; + +type AiRateLimitBucket = { + count: number; + resetAt: number; +}; + +export type AiRateLimitResult = + | { + allowed: true; + } + | { + allowed: false; + reason: "USER" | "IP"; + retryAfterMs: number; + limit: number; + windowMs: number; + }; + +@Injectable() +export class AiRateLimitService { + private readonly userBuckets = new Map(); + private readonly ipBuckets = new Map(); + private readonly windowMs: number; + private readonly userLimit: number; + private readonly ipLimit: number; + + constructor(private readonly configService: ConfigService) { + this.windowMs = this.readPositiveInt("AI_RATE_LIMIT_WINDOW_MS", 60_000); + this.userLimit = this.readPositiveInt("AI_RATE_LIMIT_USER_MAX", 20); + this.ipLimit = this.readPositiveInt("AI_RATE_LIMIT_IP_MAX", 60); + } + + consume(userId: string, clientIp: string | null): AiRateLimitResult { + const now = Date.now(); + const userBucket = this.getBucket(this.userBuckets, userId, now); + if (userBucket.count >= this.userLimit) { + return { + allowed: false, + reason: "USER", + retryAfterMs: Math.max(0, userBucket.resetAt - now), + limit: this.userLimit, + windowMs: this.windowMs + }; + } + + const normalizedIp = this.normalizeIp(clientIp); + const ipBucket = normalizedIp ? this.getBucket(this.ipBuckets, normalizedIp, now) : null; + if (ipBucket && ipBucket.count >= this.ipLimit) { + return { + allowed: false, + reason: "IP", + retryAfterMs: Math.max(0, ipBucket.resetAt - now), + limit: this.ipLimit, + windowMs: this.windowMs + }; + } + + userBucket.count += 1; + if (ipBucket) { + ipBucket.count += 1; + } + + this.cleanupExpiredBuckets(this.userBuckets, now); + this.cleanupExpiredBuckets(this.ipBuckets, now); + + return { + allowed: true + }; + } + + private getBucket( + buckets: Map, + key: string, + now: number + ): AiRateLimitBucket { + const currentBucket = buckets.get(key); + if (!currentBucket || now >= currentBucket.resetAt) { + const nextBucket: AiRateLimitBucket = { + count: 0, + resetAt: now + this.windowMs + }; + buckets.set(key, nextBucket); + return nextBucket; + } + + return currentBucket; + } + + private cleanupExpiredBuckets(buckets: Map, now: number): void { + if (buckets.size <= 256) { + return; + } + + for (const [key, bucket] of buckets.entries()) { + if (now >= bucket.resetAt) { + buckets.delete(key); + } + } + } + + private normalizeIp(clientIp: string | null): string | null { + if (!clientIp) { + return null; + } + + const normalizedIp = clientIp.trim(); + return normalizedIp.length > 0 ? normalizedIp : null; + } + + private readPositiveInt(key: string, fallbackValue: number): number { + const rawValue = this.configService.get(key); + const parsedValue = + typeof rawValue === "number" ? rawValue : Number.parseInt(String(rawValue ?? ""), 10); + + if (!Number.isFinite(parsedValue) || parsedValue <= 0) { + return fallbackValue; + } + + return parsedValue; + } +} diff --git a/apps/api/src/ai/ai.controller.ts b/apps/api/src/ai/ai.controller.ts index f9f0c16..99afb38 100644 --- a/apps/api/src/ai/ai.controller.ts +++ b/apps/api/src/ai/ai.controller.ts @@ -1,4 +1,13 @@ -import { Body, Controller, Get, Headers, Post, Query, UnauthorizedException } from "@nestjs/common"; +import { + Body, + Controller, + Get, + Headers, + Ip, + Post, + Query, + UnauthorizedException +} from "@nestjs/common"; import { AiChatDto } from "./dto/ai-chat.dto"; import { ListAiUsageLogsQueryDto } from "./dto/list-ai-usage-logs-query.dto"; import { UpsertAiProviderBindingDto } from "./dto/upsert-ai-provider-binding.dto"; @@ -39,9 +48,10 @@ export class AiController { @Post("chat") async chat( @Headers("x-user-id") userIdHeader: string | string[] | undefined, + @Ip() clientIp: string, @Body() body: AiChatDto ): Promise { - return this.aiService.chat(this.resolveUserId(userIdHeader), body); + return this.aiService.chat(this.resolveUserId(userIdHeader), body, clientIp); } private resolveUserId(userIdHeader: string | string[] | undefined): string { diff --git a/apps/api/src/ai/ai.module.ts b/apps/api/src/ai/ai.module.ts index 2655525..a17544a 100644 --- a/apps/api/src/ai/ai.module.ts +++ b/apps/api/src/ai/ai.module.ts @@ -1,5 +1,6 @@ import { Module } from "@nestjs/common"; import { PrismaModule } from "../prisma/prisma.module"; +import { AiRateLimitService } from "./ai-rate-limit.service"; import { AiController } from "./ai.controller"; import { AiProviderRegistryService } from "./ai-provider-registry.service"; import { AiService } from "./ai.service"; @@ -9,6 +10,12 @@ import { OpenAiCompatibleProvider } from "./providers/openai-compatible.provider @Module({ imports: [PrismaModule], controllers: [AiController], - providers: [AiService, AiProviderRegistryService, OpenAiCompatibleProvider, AstrbotProvider] + providers: [ + AiService, + AiRateLimitService, + AiProviderRegistryService, + OpenAiCompatibleProvider, + AstrbotProvider + ] }) export class AiModule {} diff --git a/apps/api/src/ai/ai.service.ts b/apps/api/src/ai/ai.service.ts index 92cc275..5e7e354 100644 --- a/apps/api/src/ai/ai.service.ts +++ b/apps/api/src/ai/ai.service.ts @@ -1,4 +1,11 @@ -import { BadGatewayException, BadRequestException, Injectable, Logger } from "@nestjs/common"; +import { + BadGatewayException, + BadRequestException, + HttpException, + HttpStatus, + Injectable, + Logger +} from "@nestjs/common"; import { AiChannel, AiUsageLog, @@ -10,6 +17,7 @@ import { } from "../../generated/prisma/client"; import { PrismaService } from "../prisma/prisma.service"; import { DataEncryptionService } from "../security/data-encryption.service"; +import { AiRateLimitService } from "./ai-rate-limit.service"; import { AiProviderRegistryService } from "./ai-provider-registry.service"; import { AiChatDto } from "./dto/ai-chat.dto"; import { ListAiUsageLogsQueryDto } from "./dto/list-ai-usage-logs-query.dto"; @@ -105,7 +113,8 @@ export class AiService { constructor( private readonly prismaService: PrismaService, private readonly aiProviderRegistryService: AiProviderRegistryService, - private readonly dataEncryptionService: DataEncryptionService + private readonly dataEncryptionService: DataEncryptionService, + private readonly aiRateLimitService: AiRateLimitService ) {} async listBindings(userId: string): Promise { @@ -242,7 +251,26 @@ export class AiService { return this.serializeBinding(result); } - async chat(userId: string, dto: AiChatDto): Promise { + async chat( + userId: string, + dto: AiChatDto, + clientIp: string | null = null + ): Promise { + const rateLimitResult = this.aiRateLimitService.consume(userId, clientIp); + if (!rateLimitResult.allowed) { + throw new HttpException( + { + message: "AI 请求过于频繁,请稍后再试", + code: "AI_RATE_LIMITED", + dimension: rateLimitResult.reason === "USER" ? "user" : "ip", + retryAfterMs: rateLimitResult.retryAfterMs, + limit: rateLimitResult.limit, + windowMs: rateLimitResult.windowMs + }, + HttpStatus.TOO_MANY_REQUESTS + ); + } + const attempts: AiRouteAttempt[] = []; const plan = await this.buildRoutePlan(userId, dto.channel ?? null); const promptMessage = await this.buildPromptMessage(userId, dto.message, dto.localTasks ?? []); diff --git a/apps/api/test/ai.spec.ts b/apps/api/test/ai.spec.ts index be8dcdc..486d016 100644 --- a/apps/api/test/ai.spec.ts +++ b/apps/api/test/ai.spec.ts @@ -12,6 +12,7 @@ import { } from "../generated/prisma/client"; import { AiController } from "../src/ai/ai.controller"; import { AiProviderRegistryService } from "../src/ai/ai-provider-registry.service"; +import { AiRateLimitService } from "../src/ai/ai-rate-limit.service"; import { AiService } from "../src/ai/ai.service"; import { AiChatInput, @@ -410,6 +411,7 @@ describe("AiController (integration)", () => { controllers: [AiController], providers: [ AiService, + AiRateLimitService, DataEncryptionService, { provide: PrismaService, @@ -418,8 +420,22 @@ describe("AiController (integration)", () => { { provide: ConfigService, useValue: { - get: (key: string) => - key === "DATA_ENCRYPTION_SECRET" ? "test-data-encryption-secret" : undefined + get: (key: string) => { + if (key === "DATA_ENCRYPTION_SECRET") { + return "test-data-encryption-secret"; + } + if (key === "AI_RATE_LIMIT_WINDOW_MS") { + return 60_000; + } + if (key === "AI_RATE_LIMIT_USER_MAX") { + return 2; + } + if (key === "AI_RATE_LIMIT_IP_MAX") { + return 3; + } + + return undefined; + } } }, { @@ -911,6 +927,139 @@ describe("AiController (integration)", () => { ]); expect(prismaService.getUsageLogs()).toEqual([]); }); + + it("should rate limit ai chat by user in the same window", async () => { + prismaService.seedBinding({ + id: "binding_astrbot_rate_limit_user", + userId: "user_1", + channel: AiChannel.ASTRBOT, + providerName: "", + model: null, + configId: "default", + configName: null, + encryptedApiKey: "abk_astrbot", + endpoint: "http://127.0.0.1:6185", + isDefault: true, + isEnabled: true + }); + + await request(app.getHttpServer()) + .post("/ai/chat") + .set("x-user-id", "user_1") + .set("x-forwarded-for", "203.0.113.10") + .send({ + message: "第一条" + }) + .expect(201); + + await request(app.getHttpServer()) + .post("/ai/chat") + .set("x-user-id", "user_1") + .set("x-forwarded-for", "203.0.113.10") + .send({ + message: "第二条" + }) + .expect(201); + + const response = await request(app.getHttpServer()) + .post("/ai/chat") + .set("x-user-id", "user_1") + .set("x-forwarded-for", "203.0.113.10") + .send({ + message: "第三条" + }) + .expect(429); + + expect(response.body).toMatchObject({ + message: "AI 请求过于频繁,请稍后再试", + code: "AI_RATE_LIMITED", + dimension: "user", + limit: 2, + windowMs: 60000 + }); + expect(response.body.retryAfterMs).toEqual(expect.any(Number)); + expect(astrbotExecutor.inputs).toHaveLength(2); + expect(prismaService.getUsageLogs()).toHaveLength(2); + }); + + it("should rate limit ai chat by ip across different users", async () => { + prismaService.seedBinding({ + id: "binding_astrbot_rate_limit_ip_user_1", + userId: "user_1", + channel: AiChannel.ASTRBOT, + providerName: "", + model: null, + configId: "default", + configName: null, + encryptedApiKey: "abk_astrbot", + endpoint: "http://127.0.0.1:6185", + isDefault: true, + isEnabled: true + }); + prismaService.seedBinding({ + id: "binding_astrbot_rate_limit_ip_user_2", + userId: "user_2", + channel: AiChannel.ASTRBOT, + providerName: "", + model: null, + configId: "default", + configName: null, + encryptedApiKey: "abk_astrbot", + endpoint: "http://127.0.0.1:6185", + isDefault: true, + isEnabled: true + }); + + const sharedIp = "198.51.100.7"; + + await request(app.getHttpServer()) + .post("/ai/chat") + .set("x-user-id", "user_1") + .set("x-forwarded-for", sharedIp) + .send({ + message: "用户一第一条" + }) + .expect(201); + + await request(app.getHttpServer()) + .post("/ai/chat") + .set("x-user-id", "user_2") + .set("x-forwarded-for", sharedIp) + .send({ + message: "用户二第一条" + }) + .expect(201); + + await request(app.getHttpServer()) + .post("/ai/chat") + .set("x-user-id", "user_1") + .set("x-forwarded-for", sharedIp) + .send({ + message: "用户一第二条" + }) + .expect(201); + + const response = await request(app.getHttpServer()) + .post("/ai/chat") + .set("x-user-id", "user_2") + .set("x-forwarded-for", sharedIp) + .send({ + message: "用户二第二条" + }) + .expect(429); + + expect(response.body).toMatchObject({ + message: "AI 请求过于频繁,请稍后再试", + code: "AI_RATE_LIMITED", + dimension: "ip", + limit: 3, + windowMs: 60000 + }); + expect(response.body.retryAfterMs).toEqual(expect.any(Number)); + expect(astrbotExecutor.inputs).toHaveLength(3); + expect(prismaService.getUsageLogs()).toHaveLength(3); + }); + it("should list usage logs with pagination and filters", async () => { prismaService.seedUsageLog({ id: "usage_log_1",