package middlewares import ( "encoding/json" "errors" configs "goFiber/config" database "goFiber/database/config" "goFiber/database/models" "log" "net/http" "strings" "github.com/gofiber/fiber/v3" "github.com/redis/go-redis/v9" ) const ( corsWhitelistActiveCacheKey = "cors:active:whitelist" corsBlacklistActiveCacheKey = "cors:active:blacklist" corsCacheTTLSeconds = 60 ) var ( allowedMethods = "GET,POST,PUT,PATCH,DELETE,OPTIONS" allowedHeaders = "Authorization,Content-Type,Accept,Origin,X-Requested-With" ) // DynamicCORS validates request Origin using DB-backed whitelist/blacklist with Redis caching. func DynamicCORS() fiber.Handler { return func(c fiber.Ctx) error { origin := strings.TrimSpace(c.Get("Origin")) if origin == "" { return c.Next() } if database.DB == nil { corsLogf("[cors][skip] database unavailable origin=%s path=%s", origin, c.Path()) return c.Next() } originKey := strings.ToLower(origin) // Keep same-origin requests working even if DB entries are missing. if origin == requestBaseURL(c) { corsLogf("[cors][allow] same-origin origin=%s path=%s", origin, c.Path()) setCORSHeaders(c, origin) if c.Method() == http.MethodOptions { return c.SendStatus(http.StatusNoContent) } return c.Next() } blacklist, err := loadActiveOriginSet(corsBlacklistActiveCacheKey, true) if err != nil { return c.Status(http.StatusInternalServerError).JSON(fiber.Map{"error": "cors blacklist lookup failed"}) } if blacklist[originKey] { log.Printf("[cors][blocked] blacklist origin=%s path=%s", origin, c.Path()) return c.Status(http.StatusForbidden).JSON(fiber.Map{"error": "origin is blocked by CORS policy"}) } whitelist, err := loadActiveOriginSet(corsWhitelistActiveCacheKey, false) if err != nil { return c.Status(http.StatusInternalServerError).JSON(fiber.Map{"error": "cors whitelist lookup failed"}) } if !whitelist[originKey] { log.Printf("[cors][blocked] not-whitelisted origin=%s path=%s", origin, c.Path()) return c.Status(http.StatusForbidden).JSON(fiber.Map{"error": "origin is not allowed by CORS policy"}) } corsLogf("[cors][allow] origin=%s path=%s", origin, c.Path()) setCORSHeaders(c, origin) if c.Method() == http.MethodOptions { return c.SendStatus(http.StatusNoContent) } return c.Next() } } func setCORSHeaders(c fiber.Ctx, origin string) { c.Set("Vary", "Origin") c.Set("Access-Control-Allow-Origin", origin) c.Set("Access-Control-Allow-Methods", allowedMethods) c.Set("Access-Control-Allow-Headers", allowedHeaders) c.Set("Access-Control-Allow-Credentials", "true") c.Set("Access-Control-Max-Age", "600") } func requestBaseURL(c fiber.Ctx) string { return c.Protocol() + "://" + c.Get("Host") } func loadActiveOriginSet(cacheKey string, isBlacklist bool) (map[string]bool, error) { out := make(map[string]bool) if cached, err := database.Get(cacheKey); err == nil { corsLogf("[cors][cache-hit] key=%s", cacheKey) var origins []string if jsonErr := json.Unmarshal([]byte(cached), &origins); jsonErr == nil { for _, origin := range origins { out[strings.ToLower(strings.TrimSpace(origin))] = true } return out, nil } } else if !errors.Is(err, redis.Nil) { return nil, err } corsLogf("[cors][cache-miss] key=%s", cacheKey) var origins []string var dbErr error if isBlacklist { dbErr = database.DB.Model(&models.CorsBlacklist{}). Where("is_active = ?", true). Pluck("origin", &origins).Error } else { dbErr = database.DB.Model(&models.CorsWhitelist{}). Where("is_active = ?", true). Pluck("origin", &origins).Error } if dbErr != nil { return nil, dbErr } for _, origin := range origins { out[strings.ToLower(strings.TrimSpace(origin))] = true } cacheBytes, _ := json.Marshal(origins) _ = database.SetEx(cacheKey, string(cacheBytes), corsCacheTTLSeconds) return out, nil } func corsLogf(format string, args ...interface{}) { if configs.AppConfig != nil && configs.AppConfig.CorsDebug { log.Printf(format, args...) } }