first commit
This commit is contained in:
154
app/middlewares/dynamic_cors.go
Normal file
154
app/middlewares/dynamic_cors.go
Normal file
@@ -0,0 +1,154 @@
|
||||
package middlewares
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
configs "goGin/config"
|
||||
database "goGin/app/database/config"
|
||||
"goGin/app/database/models"
|
||||
"log"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
const (
|
||||
corsWhitelistActiveCacheKey = "cors:active:whitelist"
|
||||
corsBlacklistActiveCacheKey = "cors:active:blacklist"
|
||||
corsCacheTTLSeconds = 60
|
||||
)
|
||||
|
||||
var (
|
||||
allowedMethods = "GET,POST,PUT,PATCH,DELETE,OPTIONS"
|
||||
allowedHeaders = "Authorization,Content-Type,Accept,Origin,X-Requested-With"
|
||||
)
|
||||
|
||||
// DynamicCORS validates request Origin using DB-backed whitelist/blacklist with Redis caching.
|
||||
func DynamicCORS() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
origin := strings.TrimSpace(c.GetHeader("Origin"))
|
||||
if origin == "" {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
if database.DB == nil {
|
||||
corsLogf("[cors][skip] database unavailable origin=%s path=%s", origin, c.Request.URL.Path)
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
originKey := strings.ToLower(origin)
|
||||
// Keep same-origin requests working even if DB entries are missing.
|
||||
if origin == requestBaseURL(c) {
|
||||
corsLogf("[cors][allow] same-origin origin=%s path=%s", origin, c.Request.URL.Path)
|
||||
setCORSHeaders(c, origin)
|
||||
if c.Request.Method == http.MethodOptions {
|
||||
c.AbortWithStatus(http.StatusNoContent)
|
||||
return
|
||||
}
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
blacklist, err := loadActiveOriginSet(corsBlacklistActiveCacheKey, true)
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "cors blacklist lookup failed"})
|
||||
return
|
||||
}
|
||||
if blacklist[originKey] {
|
||||
log.Printf("[cors][blocked] blacklist origin=%s path=%s", origin, c.Request.URL.Path)
|
||||
c.AbortWithStatusJSON(http.StatusForbidden, gin.H{"error": "origin is blocked by CORS policy"})
|
||||
return
|
||||
}
|
||||
|
||||
whitelist, err := loadActiveOriginSet(corsWhitelistActiveCacheKey, false)
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "cors whitelist lookup failed"})
|
||||
return
|
||||
}
|
||||
if !whitelist[originKey] {
|
||||
log.Printf("[cors][blocked] not-whitelisted origin=%s path=%s", origin, c.Request.URL.Path)
|
||||
c.AbortWithStatusJSON(http.StatusForbidden, gin.H{"error": "origin is not allowed by CORS policy"})
|
||||
return
|
||||
}
|
||||
|
||||
corsLogf("[cors][allow] origin=%s path=%s", origin, c.Request.URL.Path)
|
||||
setCORSHeaders(c, origin)
|
||||
if c.Request.Method == http.MethodOptions {
|
||||
c.AbortWithStatus(http.StatusNoContent)
|
||||
return
|
||||
}
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func setCORSHeaders(c *gin.Context, origin string) {
|
||||
c.Header("Vary", "Origin")
|
||||
c.Header("Access-Control-Allow-Origin", origin)
|
||||
c.Header("Access-Control-Allow-Methods", allowedMethods)
|
||||
c.Header("Access-Control-Allow-Headers", allowedHeaders)
|
||||
c.Header("Access-Control-Allow-Credentials", "true")
|
||||
c.Header("Access-Control-Max-Age", "600")
|
||||
}
|
||||
|
||||
func requestBaseURL(c *gin.Context) string {
|
||||
scheme := c.Request.Header.Get("X-Forwarded-Proto")
|
||||
if scheme == "" {
|
||||
if c.Request.TLS != nil {
|
||||
scheme = "https"
|
||||
} else {
|
||||
scheme = "http"
|
||||
}
|
||||
}
|
||||
return scheme + "://" + c.Request.Host
|
||||
}
|
||||
|
||||
func loadActiveOriginSet(cacheKey string, isBlacklist bool) (map[string]bool, error) {
|
||||
out := make(map[string]bool)
|
||||
|
||||
if cached, err := database.Get(cacheKey); err == nil {
|
||||
corsLogf("[cors][cache-hit] key=%s", cacheKey)
|
||||
var origins []string
|
||||
if jsonErr := json.Unmarshal([]byte(cached), &origins); jsonErr == nil {
|
||||
for _, origin := range origins {
|
||||
out[strings.ToLower(strings.TrimSpace(origin))] = true
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
} else if !errors.Is(err, redis.Nil) {
|
||||
return nil, err
|
||||
}
|
||||
corsLogf("[cors][cache-miss] key=%s", cacheKey)
|
||||
|
||||
var origins []string
|
||||
var dbErr error
|
||||
if isBlacklist {
|
||||
dbErr = database.DB.Model(&models.CorsBlacklist{}).
|
||||
Where("is_active = ?", true).
|
||||
Pluck("origin", &origins).Error
|
||||
} else {
|
||||
dbErr = database.DB.Model(&models.CorsWhitelist{}).
|
||||
Where("is_active = ?", true).
|
||||
Pluck("origin", &origins).Error
|
||||
}
|
||||
if dbErr != nil {
|
||||
return nil, dbErr
|
||||
}
|
||||
|
||||
for _, origin := range origins {
|
||||
out[strings.ToLower(strings.TrimSpace(origin))] = true
|
||||
}
|
||||
|
||||
cacheBytes, _ := json.Marshal(origins)
|
||||
_ = database.SetEx(cacheKey, string(cacheBytes), corsCacheTTLSeconds)
|
||||
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func corsLogf(format string, args ...interface{}) {
|
||||
if configs.AppConfig != nil && (configs.AppConfig.Debug || configs.AppConfig.CorsDebug) {
|
||||
log.Printf(format, args...)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user