feat(api-ai): add user and ip rate limiting

This commit is contained in:
2026-04-07 22:56:22 +08:00
parent ce72892dc8
commit 1f8b539b68
5 changed files with 325 additions and 8 deletions
+151 -2
View File
@@ -12,6 +12,7 @@ import {
} from "../generated/prisma/client";
import { AiController } from "../src/ai/ai.controller";
import { AiProviderRegistryService } from "../src/ai/ai-provider-registry.service";
import { AiRateLimitService } from "../src/ai/ai-rate-limit.service";
import { AiService } from "../src/ai/ai.service";
import {
AiChatInput,
@@ -410,6 +411,7 @@ describe("AiController (integration)", () => {
controllers: [AiController],
providers: [
AiService,
AiRateLimitService,
DataEncryptionService,
{
provide: PrismaService,
@@ -418,8 +420,22 @@ describe("AiController (integration)", () => {
{
provide: ConfigService,
useValue: {
get: (key: string) =>
key === "DATA_ENCRYPTION_SECRET" ? "test-data-encryption-secret" : undefined
get: (key: string) => {
if (key === "DATA_ENCRYPTION_SECRET") {
return "test-data-encryption-secret";
}
if (key === "AI_RATE_LIMIT_WINDOW_MS") {
return 60_000;
}
if (key === "AI_RATE_LIMIT_USER_MAX") {
return 2;
}
if (key === "AI_RATE_LIMIT_IP_MAX") {
return 3;
}
return undefined;
}
}
},
{
@@ -911,6 +927,139 @@ describe("AiController (integration)", () => {
]);
expect(prismaService.getUsageLogs()).toEqual([]);
});
it("should rate limit ai chat by user in the same window", async () => {
prismaService.seedBinding({
id: "binding_astrbot_rate_limit_user",
userId: "user_1",
channel: AiChannel.ASTRBOT,
providerName: "",
model: null,
configId: "default",
configName: null,
encryptedApiKey: "abk_astrbot",
endpoint: "http://127.0.0.1:6185",
isDefault: true,
isEnabled: true
});
await request(app.getHttpServer())
.post("/ai/chat")
.set("x-user-id", "user_1")
.set("x-forwarded-for", "203.0.113.10")
.send({
message: "第一条"
})
.expect(201);
await request(app.getHttpServer())
.post("/ai/chat")
.set("x-user-id", "user_1")
.set("x-forwarded-for", "203.0.113.10")
.send({
message: "第二条"
})
.expect(201);
const response = await request(app.getHttpServer())
.post("/ai/chat")
.set("x-user-id", "user_1")
.set("x-forwarded-for", "203.0.113.10")
.send({
message: "第三条"
})
.expect(429);
expect(response.body).toMatchObject({
message: "AI 请求过于频繁,请稍后再试",
code: "AI_RATE_LIMITED",
dimension: "user",
limit: 2,
windowMs: 60000
});
expect(response.body.retryAfterMs).toEqual(expect.any(Number));
expect(astrbotExecutor.inputs).toHaveLength(2);
expect(prismaService.getUsageLogs()).toHaveLength(2);
});
it("should rate limit ai chat by ip across different users", async () => {
prismaService.seedBinding({
id: "binding_astrbot_rate_limit_ip_user_1",
userId: "user_1",
channel: AiChannel.ASTRBOT,
providerName: "",
model: null,
configId: "default",
configName: null,
encryptedApiKey: "abk_astrbot",
endpoint: "http://127.0.0.1:6185",
isDefault: true,
isEnabled: true
});
prismaService.seedBinding({
id: "binding_astrbot_rate_limit_ip_user_2",
userId: "user_2",
channel: AiChannel.ASTRBOT,
providerName: "",
model: null,
configId: "default",
configName: null,
encryptedApiKey: "abk_astrbot",
endpoint: "http://127.0.0.1:6185",
isDefault: true,
isEnabled: true
});
const sharedIp = "198.51.100.7";
await request(app.getHttpServer())
.post("/ai/chat")
.set("x-user-id", "user_1")
.set("x-forwarded-for", sharedIp)
.send({
message: "用户一第一条"
})
.expect(201);
await request(app.getHttpServer())
.post("/ai/chat")
.set("x-user-id", "user_2")
.set("x-forwarded-for", sharedIp)
.send({
message: "用户二第一条"
})
.expect(201);
await request(app.getHttpServer())
.post("/ai/chat")
.set("x-user-id", "user_1")
.set("x-forwarded-for", sharedIp)
.send({
message: "用户一第二条"
})
.expect(201);
const response = await request(app.getHttpServer())
.post("/ai/chat")
.set("x-user-id", "user_2")
.set("x-forwarded-for", sharedIp)
.send({
message: "用户二第二条"
})
.expect(429);
expect(response.body).toMatchObject({
message: "AI 请求过于频繁,请稍后再试",
code: "AI_RATE_LIMITED",
dimension: "ip",
limit: 3,
windowMs: 60000
});
expect(response.body.retryAfterMs).toEqual(expect.any(Number));
expect(astrbotExecutor.inputs).toHaveLength(3);
expect(prismaService.getUsageLogs()).toHaveLength(3);
});
it("should list usage logs with pagination and filters", async () => {
prismaService.seedUsageLog({
id: "usage_log_1",