package middleware import ( "net/http" "strconv" "strings" "sync" "time" "github.com/gin-gonic/gin" "goaresv3/app/settings/models" "goaresv3/config" ) type rateLimitBucket struct { Count int64 WindowEnds time.Time } var ( rateLimitMu sync.Mutex rateLimitBuckets = map[string]rateLimitBucket{} ) // DynamicRateLimit enforces DB-backed rate-limit settings. // Rule selection order: // 1) Exact route path key (without leading slash), e.g. api/v1/auth/login // 2) "api" fallback func DynamicRateLimit() gin.HandlerFunc { debug := envBool("RATE_LIMIT_DEBUG", false) return func(c *gin.Context) { pathKey := strings.TrimPrefix(c.FullPath(), "/") if pathKey == "" { pathKey = strings.TrimPrefix(c.Request.URL.Path, "/") } if shouldSkipRateLimit(pathKey) { policyLogf(debug, "[rate-limit] skip path=%s reason=skip-list", pathKey) c.Next() return } if isWhitelisted, ok := c.Get("origin_whitelisted"); ok { if v, castOK := isWhitelisted.(bool); castOK && v { // Whitelisted origins are excluded from rate limiting. policyLogf(debug, "[rate-limit] skip path=%s reason=origin-whitelisted", pathKey) c.Next() return } } if config.DB == nil { c.AbortWithStatusJSON(http.StatusServiceUnavailable, gin.H{"error": "database is not connected"}) return } rule, ok := resolveRateLimitRule(pathKey) if !ok { policyLogf(debug, "[rate-limit] no-rule path=%s", pathKey) c.Next() return } clientIP := c.ClientIP() windowDur := time.Duration(rule.WindowSeconds) * time.Second bucketKey := rule.Name + ":" + clientIP now := time.Now() rateLimitMu.Lock() bucket, ok := rateLimitBuckets[bucketKey] if !ok || now.After(bucket.WindowEnds) { bucket = rateLimitBucket{ Count: 0, WindowEnds: now.Add(windowDur), } } bucket.Count++ rateLimitBuckets[bucketKey] = bucket remaining := rule.MaxRequests - bucket.Count resetIn := int(time.Until(bucket.WindowEnds).Seconds()) rateLimitMu.Unlock() if remaining < 0 { policyLogf(debug, "[rate-limit] blocked path=%s rule=%s ip=%s limit=%d window=%d", pathKey, rule.Name, clientIP, rule.MaxRequests, rule.WindowSeconds) c.Header("Retry-After", strconvItoa(maxInt(resetIn, 1))) c.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{ "error": "rate limit exceeded", "limit": rule.MaxRequests, "window_sec": rule.WindowSeconds, "retry_after": maxInt(resetIn, 1), "rule_name": rule.Name, "client_ip": clientIP, "request_path": pathKey, }) return } c.Header("X-RateLimit-Limit", strconvI64(rule.MaxRequests)) c.Header("X-RateLimit-Remaining", strconvI64(maxI64(remaining, 0))) c.Header("X-RateLimit-Reset", strconvItoa(maxInt(resetIn, 0))) policyLogf(debug, "[rate-limit] pass path=%s rule=%s ip=%s remaining=%d", pathKey, rule.Name, clientIP, maxI64(remaining, 0)) c.Next() } } func shouldSkipRateLimit(pathKey string) bool { return strings.HasPrefix(pathKey, "swagger") } func resolveRateLimitRule(pathKey string) (models.RateLimitSetting, bool) { var rule models.RateLimitSetting res := config.DB.Where("name = ? AND is_active = ?", pathKey, true).Limit(1).Find(&rule) if res.Error == nil && res.RowsAffected > 0 { return rule, true } res = config.DB.Where("name = ? AND is_active = ?", "api", true).Limit(1).Find(&rule) if res.Error == nil && res.RowsAffected > 0 { return rule, true } return models.RateLimitSetting{}, false } func strconvItoa(v int) string { return strconv.FormatInt(int64(v), 10) } func strconvI64(v int64) string { return strconv.FormatInt(v, 10) } func maxInt(a, b int) int { if a > b { return a } return b } func maxI64(a, b int64) int64 { if a > b { return a } return b }