feat(api-ai): add user and ip rate limiting
This commit is contained in:
@@ -0,0 +1,123 @@
|
|||||||
|
import { Injectable } from "@nestjs/common";
|
||||||
|
import { ConfigService } from "@nestjs/config";
|
||||||
|
|
||||||
|
/**
 * Mutable counter state kept for a single rate-limit key (one user id or one IP).
 */
type AiRateLimitBucket = {
  /** Number of requests counted against this key in the current window. */
  count: number;
  /** Epoch-millisecond timestamp (same clock as `Date.now()`) at which this bucket expires. */
  resetAt: number;
};
|
||||||
|
|
||||||
|
/**
 * Outcome of a rate-limit check, discriminated by `allowed`.
 *
 * - `allowed: true`  — the request may proceed; no further details are attached.
 * - `allowed: false` — the request was rejected; the remaining fields describe the limit that was hit.
 */
export type AiRateLimitResult =
  | {
      allowed: true;
    }
  | {
      allowed: false;
      /** Which counter rejected the request: the per-user one or the per-IP one. */
      reason: "USER" | "IP";
      /** Milliseconds until the rejecting counter's window ends (never negative). */
      retryAfterMs: number;
      /** Maximum number of requests the rejecting counter permits per window. */
      limit: number;
      /** Length of the rate-limit window in milliseconds. */
      windowMs: number;
    };
|
||||||
|
|
||||||
|
/**
 * In-memory rate limiter for AI requests, counted independently per user id and per client IP.
 *
 * Each key owns a counter whose window of `windowMs` milliseconds starts on first use; once the
 * window has elapsed the counter is replaced by a fresh one. All state is held in this
 * instance's Maps.
 *
 * Configuration keys (read once in the constructor; missing or invalid values use the default):
 * - `AI_RATE_LIMIT_WINDOW_MS` — window length in milliseconds (default 60000)
 * - `AI_RATE_LIMIT_USER_MAX`  — requests allowed per user per window (default 20)
 * - `AI_RATE_LIMIT_IP_MAX`    — requests allowed per IP per window (default 60)
 */
@Injectable()
export class AiRateLimitService {
  /** A bucket map is only swept for expired entries once it holds more than this many keys. */
  private static readonly CLEANUP_SIZE_THRESHOLD = 256;

  /** Matches an IPv4-mapped IPv6 address (already lower-cased) and captures the IPv4 part. */
  private static readonly IPV4_MAPPED_PATTERN = /^::ffff:(\d{1,3}(?:\.\d{1,3}){3})$/;

  private readonly userBuckets = new Map<string, AiRateLimitBucket>();
  private readonly ipBuckets = new Map<string, AiRateLimitBucket>();
  private readonly windowMs: number;
  private readonly userLimit: number;
  private readonly ipLimit: number;

  /** Earliest timestamp at which the next expired-bucket sweep may run. */
  private nextSweepAt = 0;

  constructor(private readonly configService: ConfigService) {
    this.windowMs = this.readPositiveInt("AI_RATE_LIMIT_WINDOW_MS", 60_000);
    this.userLimit = this.readPositiveInt("AI_RATE_LIMIT_USER_MAX", 20);
    this.ipLimit = this.readPositiveInt("AI_RATE_LIMIT_IP_MAX", 60);
  }

  /**
   * Checks both limits and, if neither is exhausted, records one request against each counter.
   *
   * The user limit is evaluated first. A rejected request is not counted against either
   * counter. When `clientIp` is null or blank, only the user limit applies.
   *
   * @param userId   Key for the per-user counter.
   * @param clientIp Caller address for the per-IP counter, or null when unknown.
   * @returns `{ allowed: true }`, or a rejection describing the limit that was hit.
   */
  consume(userId: string, clientIp: string | null): AiRateLimitResult {
    const now = Date.now();

    const userBucket = this.getBucket(this.userBuckets, userId, now);
    if (userBucket.count >= this.userLimit) {
      return this.reject("USER", userBucket, this.userLimit, now);
    }

    const normalizedIp = this.normalizeIp(clientIp);
    const ipBucket = normalizedIp ? this.getBucket(this.ipBuckets, normalizedIp, now) : null;
    if (ipBucket && ipBucket.count >= this.ipLimit) {
      return this.reject("IP", ipBucket, this.ipLimit, now);
    }

    // Only count the request once both checks have passed.
    userBucket.count += 1;
    if (ipBucket) {
      ipBucket.count += 1;
    }

    this.sweepIfDue(now);

    return { allowed: true };
  }

  /** Builds the rejection result for the counter that is exhausted. */
  private reject(
    reason: "USER" | "IP",
    bucket: AiRateLimitBucket,
    limit: number,
    now: number
  ): AiRateLimitResult {
    return {
      allowed: false,
      reason,
      retryAfterMs: Math.max(0, bucket.resetAt - now),
      limit,
      windowMs: this.windowMs
    };
  }

  /**
   * Returns the live bucket for `key`, creating (and storing) a fresh zeroed bucket when none
   * exists or the stored one has reached its `resetAt`.
   */
  private getBucket(
    buckets: Map<string, AiRateLimitBucket>,
    key: string,
    now: number
  ): AiRateLimitBucket {
    const existing = buckets.get(key);
    if (existing && now < existing.resetAt) {
      return existing;
    }

    const fresh: AiRateLimitBucket = { count: 0, resetAt: now + this.windowMs };
    buckets.set(key, fresh);
    return fresh;
  }

  /**
   * Sweeps both bucket maps, but at most once per window. Without this throttle every allowed
   * request would scan the whole map as soon as it grew past the size threshold. Expired
   * buckets are replaced on access regardless, so delaying the sweep only affects memory.
   */
  private sweepIfDue(now: number): void {
    if (now < this.nextSweepAt) {
      return;
    }
    this.nextSweepAt = now + this.windowMs;

    this.cleanupExpiredBuckets(this.userBuckets, now);
    this.cleanupExpiredBuckets(this.ipBuckets, now);
  }

  /** Deletes expired entries from `buckets`; small maps are left alone. */
  private cleanupExpiredBuckets(buckets: Map<string, AiRateLimitBucket>, now: number): void {
    if (buckets.size <= AiRateLimitService.CLEANUP_SIZE_THRESHOLD) {
      return;
    }

    // Deleting the current entry while iterating a Map is well-defined in JavaScript.
    for (const [key, bucket] of buckets.entries()) {
      if (now >= bucket.resetAt) {
        buckets.delete(key);
      }
    }
  }

  /**
   * Canonicalizes a client address for use as a bucket key.
   *
   * Trims whitespace, lower-cases the text (IPv6 hex digits are case-insensitive) and unwraps
   * IPv4-mapped IPv6 addresses (`::ffff:a.b.c.d` becomes `a.b.c.d`) so that one client is not
   * split across several buckets. Returns null for null, empty or blank input.
   */
  private normalizeIp(clientIp: string | null): string | null {
    if (!clientIp) {
      return null;
    }

    const canonical = clientIp.trim().toLowerCase();
    if (canonical.length === 0) {
      return null;
    }

    const mapped = AiRateLimitService.IPV4_MAPPED_PATTERN.exec(canonical);
    return mapped?.[1] ?? canonical;
  }

  /**
   * Reads `key` from configuration as a positive integer.
   *
   * Numeric values are truncated toward zero and strings are parsed base-10, so `2.9` and
   * `"2.9"` both yield 2. Anything that is missing, not a number, not positive after
   * truncation, or outside the safe-integer range yields `fallbackValue`.
   */
  private readPositiveInt(key: string, fallbackValue: number): number {
    const rawValue = this.configService.get<string | number | undefined>(key);
    const parsedValue =
      typeof rawValue === "number"
        ? Math.trunc(rawValue)
        : Number.parseInt(String(rawValue ?? ""), 10);

    // isSafeInteger also rejects NaN and +/-Infinity, and keeps `now + windowMs` exact.
    if (!Number.isSafeInteger(parsedValue) || parsedValue <= 0) {
      return fallbackValue;
    }

    return parsedValue;
  }
}
|
||||||
@@ -1,4 +1,13 @@
|
|||||||
import { Body, Controller, Get, Headers, Post, Query, UnauthorizedException } from "@nestjs/common";
|
import {
|
||||||
|
Body,
|
||||||
|
Controller,
|
||||||
|
Get,
|
||||||
|
Headers,
|
||||||
|
Ip,
|
||||||
|
Post,
|
||||||
|
Query,
|
||||||
|
UnauthorizedException
|
||||||
|
} from "@nestjs/common";
|
||||||
import { AiChatDto } from "./dto/ai-chat.dto";
|
import { AiChatDto } from "./dto/ai-chat.dto";
|
||||||
import { ListAiUsageLogsQueryDto } from "./dto/list-ai-usage-logs-query.dto";
|
import { ListAiUsageLogsQueryDto } from "./dto/list-ai-usage-logs-query.dto";
|
||||||
import { UpsertAiProviderBindingDto } from "./dto/upsert-ai-provider-binding.dto";
|
import { UpsertAiProviderBindingDto } from "./dto/upsert-ai-provider-binding.dto";
|
||||||
@@ -39,9 +48,10 @@ export class AiController {
|
|||||||
@Post("chat")
|
@Post("chat")
|
||||||
async chat(
|
async chat(
|
||||||
@Headers("x-user-id") userIdHeader: string | string[] | undefined,
|
@Headers("x-user-id") userIdHeader: string | string[] | undefined,
|
||||||
|
@Ip() clientIp: string,
|
||||||
@Body() body: AiChatDto
|
@Body() body: AiChatDto
|
||||||
): Promise<AiChatResponse> {
|
): Promise<AiChatResponse> {
|
||||||
return this.aiService.chat(this.resolveUserId(userIdHeader), body);
|
return this.aiService.chat(this.resolveUserId(userIdHeader), body, clientIp);
|
||||||
}
|
}
|
||||||
|
|
||||||
private resolveUserId(userIdHeader: string | string[] | undefined): string {
|
private resolveUserId(userIdHeader: string | string[] | undefined): string {
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
import { Module } from "@nestjs/common";
|
import { Module } from "@nestjs/common";
|
||||||
import { PrismaModule } from "../prisma/prisma.module";
|
import { PrismaModule } from "../prisma/prisma.module";
|
||||||
|
import { AiRateLimitService } from "./ai-rate-limit.service";
|
||||||
import { AiController } from "./ai.controller";
|
import { AiController } from "./ai.controller";
|
||||||
import { AiProviderRegistryService } from "./ai-provider-registry.service";
|
import { AiProviderRegistryService } from "./ai-provider-registry.service";
|
||||||
import { AiService } from "./ai.service";
|
import { AiService } from "./ai.service";
|
||||||
@@ -9,6 +10,12 @@ import { OpenAiCompatibleProvider } from "./providers/openai-compatible.provider
|
|||||||
@Module({
|
@Module({
|
||||||
imports: [PrismaModule],
|
imports: [PrismaModule],
|
||||||
controllers: [AiController],
|
controllers: [AiController],
|
||||||
providers: [AiService, AiProviderRegistryService, OpenAiCompatibleProvider, AstrbotProvider]
|
providers: [
|
||||||
|
AiService,
|
||||||
|
AiRateLimitService,
|
||||||
|
AiProviderRegistryService,
|
||||||
|
OpenAiCompatibleProvider,
|
||||||
|
AstrbotProvider
|
||||||
|
]
|
||||||
})
|
})
|
||||||
export class AiModule {}
|
export class AiModule {}
|
||||||
|
|||||||
@@ -1,4 +1,11 @@
|
|||||||
import { BadGatewayException, BadRequestException, Injectable, Logger } from "@nestjs/common";
|
import {
|
||||||
|
BadGatewayException,
|
||||||
|
BadRequestException,
|
||||||
|
HttpException,
|
||||||
|
HttpStatus,
|
||||||
|
Injectable,
|
||||||
|
Logger
|
||||||
|
} from "@nestjs/common";
|
||||||
import {
|
import {
|
||||||
AiChannel,
|
AiChannel,
|
||||||
AiUsageLog,
|
AiUsageLog,
|
||||||
@@ -10,6 +17,7 @@ import {
|
|||||||
} from "../../generated/prisma/client";
|
} from "../../generated/prisma/client";
|
||||||
import { PrismaService } from "../prisma/prisma.service";
|
import { PrismaService } from "../prisma/prisma.service";
|
||||||
import { DataEncryptionService } from "../security/data-encryption.service";
|
import { DataEncryptionService } from "../security/data-encryption.service";
|
||||||
|
import { AiRateLimitService } from "./ai-rate-limit.service";
|
||||||
import { AiProviderRegistryService } from "./ai-provider-registry.service";
|
import { AiProviderRegistryService } from "./ai-provider-registry.service";
|
||||||
import { AiChatDto } from "./dto/ai-chat.dto";
|
import { AiChatDto } from "./dto/ai-chat.dto";
|
||||||
import { ListAiUsageLogsQueryDto } from "./dto/list-ai-usage-logs-query.dto";
|
import { ListAiUsageLogsQueryDto } from "./dto/list-ai-usage-logs-query.dto";
|
||||||
@@ -105,7 +113,8 @@ export class AiService {
|
|||||||
constructor(
|
constructor(
|
||||||
private readonly prismaService: PrismaService,
|
private readonly prismaService: PrismaService,
|
||||||
private readonly aiProviderRegistryService: AiProviderRegistryService,
|
private readonly aiProviderRegistryService: AiProviderRegistryService,
|
||||||
private readonly dataEncryptionService: DataEncryptionService
|
private readonly dataEncryptionService: DataEncryptionService,
|
||||||
|
private readonly aiRateLimitService: AiRateLimitService
|
||||||
) {}
|
) {}
|
||||||
|
|
||||||
async listBindings(userId: string): Promise<ListAiBindingsResponse> {
|
async listBindings(userId: string): Promise<ListAiBindingsResponse> {
|
||||||
@@ -242,7 +251,26 @@ export class AiService {
|
|||||||
return this.serializeBinding(result);
|
return this.serializeBinding(result);
|
||||||
}
|
}
|
||||||
|
|
||||||
async chat(userId: string, dto: AiChatDto): Promise<AiChatResponse> {
|
async chat(
|
||||||
|
userId: string,
|
||||||
|
dto: AiChatDto,
|
||||||
|
clientIp: string | null = null
|
||||||
|
): Promise<AiChatResponse> {
|
||||||
|
const rateLimitResult = this.aiRateLimitService.consume(userId, clientIp);
|
||||||
|
if (!rateLimitResult.allowed) {
|
||||||
|
throw new HttpException(
|
||||||
|
{
|
||||||
|
message: "AI 请求过于频繁,请稍后再试",
|
||||||
|
code: "AI_RATE_LIMITED",
|
||||||
|
dimension: rateLimitResult.reason === "USER" ? "user" : "ip",
|
||||||
|
retryAfterMs: rateLimitResult.retryAfterMs,
|
||||||
|
limit: rateLimitResult.limit,
|
||||||
|
windowMs: rateLimitResult.windowMs
|
||||||
|
},
|
||||||
|
HttpStatus.TOO_MANY_REQUESTS
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
const attempts: AiRouteAttempt[] = [];
|
const attempts: AiRouteAttempt[] = [];
|
||||||
const plan = await this.buildRoutePlan(userId, dto.channel ?? null);
|
const plan = await this.buildRoutePlan(userId, dto.channel ?? null);
|
||||||
const promptMessage = await this.buildPromptMessage(userId, dto.message, dto.localTasks ?? []);
|
const promptMessage = await this.buildPromptMessage(userId, dto.message, dto.localTasks ?? []);
|
||||||
|
|||||||
+151
-2
@@ -12,6 +12,7 @@ import {
|
|||||||
} from "../generated/prisma/client";
|
} from "../generated/prisma/client";
|
||||||
import { AiController } from "../src/ai/ai.controller";
|
import { AiController } from "../src/ai/ai.controller";
|
||||||
import { AiProviderRegistryService } from "../src/ai/ai-provider-registry.service";
|
import { AiProviderRegistryService } from "../src/ai/ai-provider-registry.service";
|
||||||
|
import { AiRateLimitService } from "../src/ai/ai-rate-limit.service";
|
||||||
import { AiService } from "../src/ai/ai.service";
|
import { AiService } from "../src/ai/ai.service";
|
||||||
import {
|
import {
|
||||||
AiChatInput,
|
AiChatInput,
|
||||||
@@ -410,6 +411,7 @@ describe("AiController (integration)", () => {
|
|||||||
controllers: [AiController],
|
controllers: [AiController],
|
||||||
providers: [
|
providers: [
|
||||||
AiService,
|
AiService,
|
||||||
|
AiRateLimitService,
|
||||||
DataEncryptionService,
|
DataEncryptionService,
|
||||||
{
|
{
|
||||||
provide: PrismaService,
|
provide: PrismaService,
|
||||||
@@ -418,8 +420,22 @@ describe("AiController (integration)", () => {
|
|||||||
{
|
{
|
||||||
provide: ConfigService,
|
provide: ConfigService,
|
||||||
useValue: {
|
useValue: {
|
||||||
get: (key: string) =>
|
get: (key: string) => {
|
||||||
key === "DATA_ENCRYPTION_SECRET" ? "test-data-encryption-secret" : undefined
|
if (key === "DATA_ENCRYPTION_SECRET") {
|
||||||
|
return "test-data-encryption-secret";
|
||||||
|
}
|
||||||
|
if (key === "AI_RATE_LIMIT_WINDOW_MS") {
|
||||||
|
return 60_000;
|
||||||
|
}
|
||||||
|
if (key === "AI_RATE_LIMIT_USER_MAX") {
|
||||||
|
return 2;
|
||||||
|
}
|
||||||
|
if (key === "AI_RATE_LIMIT_IP_MAX") {
|
||||||
|
return 3;
|
||||||
|
}
|
||||||
|
|
||||||
|
return undefined;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -911,6 +927,139 @@ describe("AiController (integration)", () => {
|
|||||||
]);
|
]);
|
||||||
expect(prismaService.getUsageLogs()).toEqual([]);
|
expect(prismaService.getUsageLogs()).toEqual([]);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
it("should rate limit ai chat by user in the same window", async () => {
|
||||||
|
prismaService.seedBinding({
|
||||||
|
id: "binding_astrbot_rate_limit_user",
|
||||||
|
userId: "user_1",
|
||||||
|
channel: AiChannel.ASTRBOT,
|
||||||
|
providerName: "",
|
||||||
|
model: null,
|
||||||
|
configId: "default",
|
||||||
|
configName: null,
|
||||||
|
encryptedApiKey: "abk_astrbot",
|
||||||
|
endpoint: "http://127.0.0.1:6185",
|
||||||
|
isDefault: true,
|
||||||
|
isEnabled: true
|
||||||
|
});
|
||||||
|
|
||||||
|
await request(app.getHttpServer())
|
||||||
|
.post("/ai/chat")
|
||||||
|
.set("x-user-id", "user_1")
|
||||||
|
.set("x-forwarded-for", "203.0.113.10")
|
||||||
|
.send({
|
||||||
|
message: "第一条"
|
||||||
|
})
|
||||||
|
.expect(201);
|
||||||
|
|
||||||
|
await request(app.getHttpServer())
|
||||||
|
.post("/ai/chat")
|
||||||
|
.set("x-user-id", "user_1")
|
||||||
|
.set("x-forwarded-for", "203.0.113.10")
|
||||||
|
.send({
|
||||||
|
message: "第二条"
|
||||||
|
})
|
||||||
|
.expect(201);
|
||||||
|
|
||||||
|
const response = await request(app.getHttpServer())
|
||||||
|
.post("/ai/chat")
|
||||||
|
.set("x-user-id", "user_1")
|
||||||
|
.set("x-forwarded-for", "203.0.113.10")
|
||||||
|
.send({
|
||||||
|
message: "第三条"
|
||||||
|
})
|
||||||
|
.expect(429);
|
||||||
|
|
||||||
|
expect(response.body).toMatchObject({
|
||||||
|
message: "AI 请求过于频繁,请稍后再试",
|
||||||
|
code: "AI_RATE_LIMITED",
|
||||||
|
dimension: "user",
|
||||||
|
limit: 2,
|
||||||
|
windowMs: 60000
|
||||||
|
});
|
||||||
|
expect(response.body.retryAfterMs).toEqual(expect.any(Number));
|
||||||
|
expect(astrbotExecutor.inputs).toHaveLength(2);
|
||||||
|
expect(prismaService.getUsageLogs()).toHaveLength(2);
|
||||||
|
});
|
||||||
|
|
||||||
|
it("should rate limit ai chat by ip across different users", async () => {
|
||||||
|
prismaService.seedBinding({
|
||||||
|
id: "binding_astrbot_rate_limit_ip_user_1",
|
||||||
|
userId: "user_1",
|
||||||
|
channel: AiChannel.ASTRBOT,
|
||||||
|
providerName: "",
|
||||||
|
model: null,
|
||||||
|
configId: "default",
|
||||||
|
configName: null,
|
||||||
|
encryptedApiKey: "abk_astrbot",
|
||||||
|
endpoint: "http://127.0.0.1:6185",
|
||||||
|
isDefault: true,
|
||||||
|
isEnabled: true
|
||||||
|
});
|
||||||
|
prismaService.seedBinding({
|
||||||
|
id: "binding_astrbot_rate_limit_ip_user_2",
|
||||||
|
userId: "user_2",
|
||||||
|
channel: AiChannel.ASTRBOT,
|
||||||
|
providerName: "",
|
||||||
|
model: null,
|
||||||
|
configId: "default",
|
||||||
|
configName: null,
|
||||||
|
encryptedApiKey: "abk_astrbot",
|
||||||
|
endpoint: "http://127.0.0.1:6185",
|
||||||
|
isDefault: true,
|
||||||
|
isEnabled: true
|
||||||
|
});
|
||||||
|
|
||||||
|
const sharedIp = "198.51.100.7";
|
||||||
|
|
||||||
|
await request(app.getHttpServer())
|
||||||
|
.post("/ai/chat")
|
||||||
|
.set("x-user-id", "user_1")
|
||||||
|
.set("x-forwarded-for", sharedIp)
|
||||||
|
.send({
|
||||||
|
message: "用户一第一条"
|
||||||
|
})
|
||||||
|
.expect(201);
|
||||||
|
|
||||||
|
await request(app.getHttpServer())
|
||||||
|
.post("/ai/chat")
|
||||||
|
.set("x-user-id", "user_2")
|
||||||
|
.set("x-forwarded-for", sharedIp)
|
||||||
|
.send({
|
||||||
|
message: "用户二第一条"
|
||||||
|
})
|
||||||
|
.expect(201);
|
||||||
|
|
||||||
|
await request(app.getHttpServer())
|
||||||
|
.post("/ai/chat")
|
||||||
|
.set("x-user-id", "user_1")
|
||||||
|
.set("x-forwarded-for", sharedIp)
|
||||||
|
.send({
|
||||||
|
message: "用户一第二条"
|
||||||
|
})
|
||||||
|
.expect(201);
|
||||||
|
|
||||||
|
const response = await request(app.getHttpServer())
|
||||||
|
.post("/ai/chat")
|
||||||
|
.set("x-user-id", "user_2")
|
||||||
|
.set("x-forwarded-for", sharedIp)
|
||||||
|
.send({
|
||||||
|
message: "用户二第二条"
|
||||||
|
})
|
||||||
|
.expect(429);
|
||||||
|
|
||||||
|
expect(response.body).toMatchObject({
|
||||||
|
message: "AI 请求过于频繁,请稍后再试",
|
||||||
|
code: "AI_RATE_LIMITED",
|
||||||
|
dimension: "ip",
|
||||||
|
limit: 3,
|
||||||
|
windowMs: 60000
|
||||||
|
});
|
||||||
|
expect(response.body.retryAfterMs).toEqual(expect.any(Number));
|
||||||
|
expect(astrbotExecutor.inputs).toHaveLength(3);
|
||||||
|
expect(prismaService.getUsageLogs()).toHaveLength(3);
|
||||||
|
});
|
||||||
|
|
||||||
it("should list usage logs with pagination and filters", async () => {
|
it("should list usage logs with pagination and filters", async () => {
|
||||||
prismaService.seedUsageLog({
|
prismaService.seedUsageLog({
|
||||||
id: "usage_log_1",
|
id: "usage_log_1",
|
||||||
|
|||||||
Reference in New Issue
Block a user