137 lines
4.1 KiB
Go
137 lines
4.1 KiB
Go
package middlewares
|
|
|
|
import (
|
|
configs "ares/config"
|
|
database "ares/database/config"
|
|
"ares/database/models"
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"net/http"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/gofiber/fiber/v3"
|
|
"github.com/redis/go-redis/v9"
|
|
"go.uber.org/zap"
|
|
"gorm.io/gorm"
|
|
)
|
|
|
|
// rateLimitRuntime is the effective configuration of one named rate limiter.
// loadRateLimitRuntime JSON-encodes it into the settings cache, so the json
// tags below define the cached representation.
type rateLimitRuntime struct {
	// Name identifies the limiter; RequireRateLimit also embeds it in the
	// per-IP counter key.
	Name string `json:"name"`
	// MaxRequests is how many requests one client IP may make per window
	// before being rejected.
	MaxRequests int64 `json:"max_requests"`
	// WindowSeconds is the counter lifetime in seconds; it is also the
	// Retry-After fallback when the remaining TTL cannot be determined.
	WindowSeconds int `json:"window_seconds"`
	// IsActive turns enforcement on or off for this limiter.
	IsActive bool `json:"is_active"`
}
|
|
|
|
// RequireRateLimit applies Redis-backed per-IP rate limiting by setting name.
|
|
func RequireRateLimit(name string, fallbackMax int64, fallbackWindowSeconds int) fiber.Handler {
|
|
return func(c fiber.Ctx) error {
|
|
if database.DB == nil {
|
|
return c.Next()
|
|
}
|
|
|
|
setting, err := loadRateLimitRuntime(name, fallbackMax, fallbackWindowSeconds)
|
|
if err != nil {
|
|
return c.Status(http.StatusInternalServerError).JSON(fiber.Map{"error": "rate limit configuration error"})
|
|
}
|
|
if !setting.IsActive {
|
|
return c.Next()
|
|
}
|
|
if database.RedisClient == nil {
|
|
rateLimitLogf("[rate-limit][warn] redis unavailable, skipping enforcement name=%s", setting.Name)
|
|
return c.Next()
|
|
}
|
|
|
|
ip := strings.TrimSpace(c.IP())
|
|
if ip == "" {
|
|
ip = "unknown"
|
|
}
|
|
|
|
counterKey := fmt.Sprintf("ratelimit:%s:%s", setting.Name, ip)
|
|
count, err := database.RedisClient.Incr(context.Background(), counterKey).Result()
|
|
if err != nil {
|
|
return c.Status(http.StatusInternalServerError).JSON(fiber.Map{"error": "rate limit check failed"})
|
|
}
|
|
if count == 1 {
|
|
_ = database.RedisClient.Expire(context.Background(), counterKey, time.Duration(setting.WindowSeconds)*time.Second).Err()
|
|
}
|
|
|
|
if count > setting.MaxRequests {
|
|
ttl, _ := database.RedisClient.TTL(context.Background(), counterKey).Result()
|
|
retryAfter := int(ttl.Seconds())
|
|
if retryAfter < 1 {
|
|
retryAfter = setting.WindowSeconds
|
|
}
|
|
c.Set("Retry-After", strconv.Itoa(retryAfter))
|
|
if configs.Logger != nil {
|
|
configs.Logger.Warn("rate-limit blocked", zapFieldsForRateLimit(setting.Name, ip, count, setting.MaxRequests, setting.WindowSeconds)...)
|
|
}
|
|
return c.Status(http.StatusTooManyRequests).JSON(fiber.Map{
|
|
"error": "too many requests",
|
|
"retry_after": retryAfter,
|
|
})
|
|
}
|
|
|
|
rateLimitLogf("[rate-limit][allow] name=%s ip=%s count=%d max=%d window=%ds", setting.Name, ip, count, setting.MaxRequests, setting.WindowSeconds)
|
|
return c.Next()
|
|
}
|
|
}
|
|
|
|
func loadRateLimitRuntime(name string, fallbackMax int64, fallbackWindowSeconds int) (*rateLimitRuntime, error) {
|
|
cacheKey := "ratelimit:setting:" + name
|
|
if cached, err := database.Get(cacheKey); err == nil {
|
|
var s rateLimitRuntime
|
|
if jsonErr := json.Unmarshal([]byte(cached), &s); jsonErr == nil {
|
|
return &s, nil
|
|
}
|
|
} else if !errors.Is(err, redis.Nil) {
|
|
return nil, err
|
|
}
|
|
|
|
setting := &rateLimitRuntime{
|
|
Name: name,
|
|
MaxRequests: fallbackMax,
|
|
WindowSeconds: fallbackWindowSeconds,
|
|
IsActive: true,
|
|
}
|
|
|
|
var dbSetting models.RateLimitSetting
|
|
if err := database.DB.Where("name = ?", name).First(&dbSetting).Error; err != nil {
|
|
if !errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return nil, err
|
|
}
|
|
rateLimitLogf("[rate-limit][config] setting=%s not found, using fallback max=%d window=%ds", name, fallbackMax, fallbackWindowSeconds)
|
|
} else {
|
|
setting.MaxRequests = dbSetting.MaxRequests
|
|
setting.WindowSeconds = dbSetting.WindowSeconds
|
|
setting.IsActive = dbSetting.IsActive
|
|
rateLimitLogf("[rate-limit][config] loaded from db name=%s active=%t max=%d window=%ds", name, setting.IsActive, setting.MaxRequests, setting.WindowSeconds)
|
|
}
|
|
|
|
cacheJSON, _ := json.Marshal(setting)
|
|
_ = database.SetEx(cacheKey, string(cacheJSON), 60)
|
|
|
|
return setting, nil
|
|
}
|
|
|
|
func zapFieldsForRateLimit(name, ip string, count, max int64, window int) []zap.Field {
|
|
return []zap.Field{
|
|
zap.String("name", name),
|
|
zap.String("ip", ip),
|
|
zap.Int64("count", count),
|
|
zap.Int64("max", max),
|
|
zap.Int("window_seconds", window),
|
|
}
|
|
}
|
|
|
|
func rateLimitLogf(format string, args ...interface{}) {
|
|
if configs.AppConfig != nil && configs.AppConfig.CorsDebug {
|
|
if configs.Logger != nil {
|
|
configs.Logger.Sugar().Infof(format, args...)
|
|
}
|
|
}
|
|
}
|