package middlewares import ( "encoding/json" "errors" configs "goGin/config" database "goGin/app/database/config" "goGin/app/database/models" "log" "net/http" "strings" "github.com/gin-gonic/gin" "github.com/redis/go-redis/v9" ) const ( corsWhitelistActiveCacheKey = "cors:active:whitelist" corsBlacklistActiveCacheKey = "cors:active:blacklist" corsCacheTTLSeconds = 60 ) var ( allowedMethods = "GET,POST,PUT,PATCH,DELETE,OPTIONS" allowedHeaders = "Authorization,Content-Type,Accept,Origin,X-Requested-With" ) // DynamicCORS validates request Origin using DB-backed whitelist/blacklist with Redis caching. func DynamicCORS() gin.HandlerFunc { return func(c *gin.Context) { origin := strings.TrimSpace(c.GetHeader("Origin")) if origin == "" { c.Next() return } if database.DB == nil { corsLogf("[cors][skip] database unavailable origin=%s path=%s", origin, c.Request.URL.Path) c.Next() return } originKey := strings.ToLower(origin) // Keep same-origin requests working even if DB entries are missing. if origin == requestBaseURL(c) { corsLogf("[cors][allow] same-origin origin=%s path=%s", origin, c.Request.URL.Path) setCORSHeaders(c, origin) if c.Request.Method == http.MethodOptions { c.AbortWithStatus(http.StatusNoContent) return } c.Next() return } blacklist, err := loadActiveOriginSet(corsBlacklistActiveCacheKey, true) if err != nil { c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "cors blacklist lookup failed"}) return } if blacklist[originKey] { log.Printf("[cors][blocked] blacklist origin=%s path=%s", origin, c.Request.URL.Path) c.AbortWithStatusJSON(http.StatusForbidden, gin.H{"error": "origin is blocked by CORS policy"}) return } whitelist, err := loadActiveOriginSet(corsWhitelistActiveCacheKey, false) if err != nil { c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "cors whitelist lookup failed"}) return } if !whitelist[originKey] { log.Printf("[cors][blocked] not-whitelisted origin=%s path=%s", origin, c.Request.URL.Path) c.AbortWithStatusJSON(http.StatusForbidden, gin.H{"error": "origin is not allowed by CORS policy"}) return } corsLogf("[cors][allow] origin=%s path=%s", origin, c.Request.URL.Path) setCORSHeaders(c, origin) if c.Request.Method == http.MethodOptions { c.AbortWithStatus(http.StatusNoContent) return } c.Next() } } func setCORSHeaders(c *gin.Context, origin string) { c.Header("Vary", "Origin") c.Header("Access-Control-Allow-Origin", origin) c.Header("Access-Control-Allow-Methods", allowedMethods) c.Header("Access-Control-Allow-Headers", allowedHeaders) c.Header("Access-Control-Allow-Credentials", "true") c.Header("Access-Control-Max-Age", "600") } func requestBaseURL(c *gin.Context) string { scheme := c.Request.Header.Get("X-Forwarded-Proto") if scheme == "" { if c.Request.TLS != nil { scheme = "https" } else { scheme = "http" } } return scheme + "://" + c.Request.Host } func loadActiveOriginSet(cacheKey string, isBlacklist bool) (map[string]bool, error) { out := make(map[string]bool) if cached, err := database.Get(cacheKey); err == nil { corsLogf("[cors][cache-hit] key=%s", cacheKey) var origins []string if jsonErr := json.Unmarshal([]byte(cached), &origins); jsonErr == nil { for _, origin := range origins { out[strings.ToLower(strings.TrimSpace(origin))] = true } return out, nil } } else if !errors.Is(err, redis.Nil) { return nil, err } corsLogf("[cors][cache-miss] key=%s", cacheKey) var origins []string var dbErr error if isBlacklist { dbErr = database.DB.Model(&models.CorsBlacklist{}). Where("is_active = ?", true). Pluck("origin", &origins).Error } else { dbErr = database.DB.Model(&models.CorsWhitelist{}). Where("is_active = ?", true). Pluck("origin", &origins).Error } if dbErr != nil { return nil, dbErr } for _, origin := range origins { out[strings.ToLower(strings.TrimSpace(origin))] = true } cacheBytes, _ := json.Marshal(origins) _ = database.SetEx(cacheKey, string(cacheBytes), corsCacheTTLSeconds) return out, nil } func corsLogf(format string, args ...interface{}) { if configs.AppConfig != nil && (configs.AppConfig.Debug || configs.AppConfig.CorsDebug) { log.Printf(format, args...) } }