first commit
This commit is contained in:
145
pkg/middleware/rate_limit_dynamic.go
Normal file
145
pkg/middleware/rate_limit_dynamic.go
Normal file
@@ -0,0 +1,145 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"goaresv3/app/settings/models"
|
||||
"goaresv3/config"
|
||||
)
|
||||
|
||||
type rateLimitBucket struct {
|
||||
Count int64
|
||||
WindowEnds time.Time
|
||||
}
|
||||
|
||||
var (
|
||||
rateLimitMu sync.Mutex
|
||||
rateLimitBuckets = map[string]rateLimitBucket{}
|
||||
)
|
||||
|
||||
// DynamicRateLimit enforces DB-backed rate-limit settings.
|
||||
// Rule selection order:
|
||||
// 1) Exact route path key (without leading slash), e.g. api/v1/auth/login
|
||||
// 2) "api" fallback
|
||||
func DynamicRateLimit() gin.HandlerFunc {
|
||||
debug := envBool("RATE_LIMIT_DEBUG", false)
|
||||
|
||||
return func(c *gin.Context) {
|
||||
pathKey := strings.TrimPrefix(c.FullPath(), "/")
|
||||
if pathKey == "" {
|
||||
pathKey = strings.TrimPrefix(c.Request.URL.Path, "/")
|
||||
}
|
||||
if shouldSkipRateLimit(pathKey) {
|
||||
policyLogf(debug, "[rate-limit] skip path=%s reason=skip-list", pathKey)
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
if isWhitelisted, ok := c.Get("origin_whitelisted"); ok {
|
||||
if v, castOK := isWhitelisted.(bool); castOK && v {
|
||||
// Whitelisted origins are excluded from rate limiting.
|
||||
policyLogf(debug, "[rate-limit] skip path=%s reason=origin-whitelisted", pathKey)
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if config.DB == nil {
|
||||
c.AbortWithStatusJSON(http.StatusServiceUnavailable, gin.H{"error": "database is not connected"})
|
||||
return
|
||||
}
|
||||
|
||||
rule, ok := resolveRateLimitRule(pathKey)
|
||||
if !ok {
|
||||
policyLogf(debug, "[rate-limit] no-rule path=%s", pathKey)
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
clientIP := c.ClientIP()
|
||||
windowDur := time.Duration(rule.WindowSeconds) * time.Second
|
||||
bucketKey := rule.Name + ":" + clientIP
|
||||
|
||||
now := time.Now()
|
||||
rateLimitMu.Lock()
|
||||
bucket, ok := rateLimitBuckets[bucketKey]
|
||||
if !ok || now.After(bucket.WindowEnds) {
|
||||
bucket = rateLimitBucket{
|
||||
Count: 0,
|
||||
WindowEnds: now.Add(windowDur),
|
||||
}
|
||||
}
|
||||
|
||||
bucket.Count++
|
||||
rateLimitBuckets[bucketKey] = bucket
|
||||
remaining := rule.MaxRequests - bucket.Count
|
||||
resetIn := int(time.Until(bucket.WindowEnds).Seconds())
|
||||
rateLimitMu.Unlock()
|
||||
|
||||
if remaining < 0 {
|
||||
policyLogf(debug, "[rate-limit] blocked path=%s rule=%s ip=%s limit=%d window=%d", pathKey, rule.Name, clientIP, rule.MaxRequests, rule.WindowSeconds)
|
||||
c.Header("Retry-After", strconvItoa(maxInt(resetIn, 1)))
|
||||
c.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{
|
||||
"error": "rate limit exceeded",
|
||||
"limit": rule.MaxRequests,
|
||||
"window_sec": rule.WindowSeconds,
|
||||
"retry_after": maxInt(resetIn, 1),
|
||||
"rule_name": rule.Name,
|
||||
"client_ip": clientIP,
|
||||
"request_path": pathKey,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.Header("X-RateLimit-Limit", strconvI64(rule.MaxRequests))
|
||||
c.Header("X-RateLimit-Remaining", strconvI64(maxI64(remaining, 0)))
|
||||
c.Header("X-RateLimit-Reset", strconvItoa(maxInt(resetIn, 0)))
|
||||
policyLogf(debug, "[rate-limit] pass path=%s rule=%s ip=%s remaining=%d", pathKey, rule.Name, clientIP, maxI64(remaining, 0))
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func shouldSkipRateLimit(pathKey string) bool {
|
||||
return strings.HasPrefix(pathKey, "swagger")
|
||||
}
|
||||
|
||||
func resolveRateLimitRule(pathKey string) (models.RateLimitSetting, bool) {
|
||||
var rule models.RateLimitSetting
|
||||
res := config.DB.Where("name = ? AND is_active = ?", pathKey, true).Limit(1).Find(&rule)
|
||||
if res.Error == nil && res.RowsAffected > 0 {
|
||||
return rule, true
|
||||
}
|
||||
res = config.DB.Where("name = ? AND is_active = ?", "api", true).Limit(1).Find(&rule)
|
||||
if res.Error == nil && res.RowsAffected > 0 {
|
||||
return rule, true
|
||||
}
|
||||
return models.RateLimitSetting{}, false
|
||||
}
|
||||
|
||||
func strconvItoa(v int) string {
|
||||
return strconv.FormatInt(int64(v), 10)
|
||||
}
|
||||
|
||||
func strconvI64(v int64) string {
|
||||
return strconv.FormatInt(v, 10)
|
||||
}
|
||||
|
||||
func maxInt(a, b int) int {
|
||||
if a > b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
func maxI64(a, b int64) int64 {
|
||||
if a > b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
Reference in New Issue
Block a user