first commit
This commit is contained in:
63
middlewares/auth_middleware.go
Normal file
63
middlewares/auth_middleware.go
Normal file
@@ -0,0 +1,63 @@
|
||||
package middlewares
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"goFiber/services"
|
||||
|
||||
"github.com/gofiber/fiber/v3"
|
||||
)
|
||||
|
||||
const authClaimsKey = "auth_claims"
|
||||
|
||||
func RequireAuth(c fiber.Ctx) error {
|
||||
authHeader := strings.TrimSpace(c.Get("Authorization"))
|
||||
if authHeader == "" {
|
||||
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"error": "authorization header is required"})
|
||||
}
|
||||
|
||||
parts := strings.SplitN(authHeader, " ", 2)
|
||||
if len(parts) != 2 || !strings.EqualFold(parts[0], "Bearer") || strings.TrimSpace(parts[1]) == "" {
|
||||
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"error": "invalid authorization format, expected: Bearer <token>"})
|
||||
}
|
||||
|
||||
jwtService := services.NewJWTService()
|
||||
claims, err := jwtService.ValidateToken(strings.TrimSpace(parts[1]))
|
||||
if err != nil {
|
||||
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"error": "invalid token"})
|
||||
}
|
||||
if claims.TokenType != services.TokenTypeAccess {
|
||||
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"error": "access token required"})
|
||||
}
|
||||
|
||||
c.Locals(authClaimsKey, claims)
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
func RequireAdmin(c fiber.Ctx) error {
|
||||
claims, ok := GetAuthClaims(c)
|
||||
if !ok {
|
||||
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"error": "unauthorized"})
|
||||
}
|
||||
if !claims.IsAdmin {
|
||||
return c.Status(fiber.StatusForbidden).JSON(fiber.Map{"error": "admin role required"})
|
||||
}
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
func RequireNormalUser(c fiber.Ctx) error {
|
||||
claims, ok := GetAuthClaims(c)
|
||||
if !ok {
|
||||
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"error": "unauthorized"})
|
||||
}
|
||||
if claims.IsAdmin {
|
||||
return c.Status(fiber.StatusForbidden).JSON(fiber.Map{"error": "only normal users can access this endpoint"})
|
||||
}
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
func GetAuthClaims(c fiber.Ctx) (*services.JWTClaim, bool) {
|
||||
raw := c.Locals(authClaimsKey)
|
||||
claims, ok := raw.(*services.JWTClaim)
|
||||
return claims, ok
|
||||
}
|
||||
137
middlewares/dynamic_cors.go
Normal file
137
middlewares/dynamic_cors.go
Normal file
@@ -0,0 +1,137 @@
|
||||
package middlewares
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
configs "goFiber/config"
|
||||
database "goFiber/database/config"
|
||||
"goFiber/database/models"
|
||||
"log"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/gofiber/fiber/v3"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
const (
|
||||
corsWhitelistActiveCacheKey = "cors:active:whitelist"
|
||||
corsBlacklistActiveCacheKey = "cors:active:blacklist"
|
||||
corsCacheTTLSeconds = 60
|
||||
)
|
||||
|
||||
var (
|
||||
allowedMethods = "GET,POST,PUT,PATCH,DELETE,OPTIONS"
|
||||
allowedHeaders = "Authorization,Content-Type,Accept,Origin,X-Requested-With"
|
||||
)
|
||||
|
||||
// DynamicCORS validates request Origin using DB-backed whitelist/blacklist with Redis caching.
|
||||
func DynamicCORS() fiber.Handler {
|
||||
return func(c fiber.Ctx) error {
|
||||
origin := strings.TrimSpace(c.Get("Origin"))
|
||||
if origin == "" {
|
||||
return c.Next()
|
||||
}
|
||||
if database.DB == nil {
|
||||
corsLogf("[cors][skip] database unavailable origin=%s path=%s", origin, c.Path())
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
originKey := strings.ToLower(origin)
|
||||
// Keep same-origin requests working even if DB entries are missing.
|
||||
if origin == requestBaseURL(c) {
|
||||
corsLogf("[cors][allow] same-origin origin=%s path=%s", origin, c.Path())
|
||||
setCORSHeaders(c, origin)
|
||||
if c.Method() == http.MethodOptions {
|
||||
return c.SendStatus(http.StatusNoContent)
|
||||
}
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
blacklist, err := loadActiveOriginSet(corsBlacklistActiveCacheKey, true)
|
||||
if err != nil {
|
||||
return c.Status(http.StatusInternalServerError).JSON(fiber.Map{"error": "cors blacklist lookup failed"})
|
||||
}
|
||||
if blacklist[originKey] {
|
||||
log.Printf("[cors][blocked] blacklist origin=%s path=%s", origin, c.Path())
|
||||
return c.Status(http.StatusForbidden).JSON(fiber.Map{"error": "origin is blocked by CORS policy"})
|
||||
}
|
||||
|
||||
whitelist, err := loadActiveOriginSet(corsWhitelistActiveCacheKey, false)
|
||||
if err != nil {
|
||||
return c.Status(http.StatusInternalServerError).JSON(fiber.Map{"error": "cors whitelist lookup failed"})
|
||||
}
|
||||
if !whitelist[originKey] {
|
||||
log.Printf("[cors][blocked] not-whitelisted origin=%s path=%s", origin, c.Path())
|
||||
return c.Status(http.StatusForbidden).JSON(fiber.Map{"error": "origin is not allowed by CORS policy"})
|
||||
}
|
||||
|
||||
corsLogf("[cors][allow] origin=%s path=%s", origin, c.Path())
|
||||
setCORSHeaders(c, origin)
|
||||
if c.Method() == http.MethodOptions {
|
||||
return c.SendStatus(http.StatusNoContent)
|
||||
}
|
||||
return c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func setCORSHeaders(c fiber.Ctx, origin string) {
|
||||
c.Set("Vary", "Origin")
|
||||
c.Set("Access-Control-Allow-Origin", origin)
|
||||
c.Set("Access-Control-Allow-Methods", allowedMethods)
|
||||
c.Set("Access-Control-Allow-Headers", allowedHeaders)
|
||||
c.Set("Access-Control-Allow-Credentials", "true")
|
||||
c.Set("Access-Control-Max-Age", "600")
|
||||
}
|
||||
|
||||
func requestBaseURL(c fiber.Ctx) string {
|
||||
return c.Protocol() + "://" + c.Get("Host")
|
||||
}
|
||||
|
||||
func loadActiveOriginSet(cacheKey string, isBlacklist bool) (map[string]bool, error) {
|
||||
out := make(map[string]bool)
|
||||
|
||||
if cached, err := database.Get(cacheKey); err == nil {
|
||||
corsLogf("[cors][cache-hit] key=%s", cacheKey)
|
||||
var origins []string
|
||||
if jsonErr := json.Unmarshal([]byte(cached), &origins); jsonErr == nil {
|
||||
for _, origin := range origins {
|
||||
out[strings.ToLower(strings.TrimSpace(origin))] = true
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
} else if !errors.Is(err, redis.Nil) {
|
||||
return nil, err
|
||||
}
|
||||
corsLogf("[cors][cache-miss] key=%s", cacheKey)
|
||||
|
||||
var origins []string
|
||||
var dbErr error
|
||||
if isBlacklist {
|
||||
dbErr = database.DB.Model(&models.CorsBlacklist{}).
|
||||
Where("is_active = ?", true).
|
||||
Pluck("origin", &origins).Error
|
||||
} else {
|
||||
dbErr = database.DB.Model(&models.CorsWhitelist{}).
|
||||
Where("is_active = ?", true).
|
||||
Pluck("origin", &origins).Error
|
||||
}
|
||||
if dbErr != nil {
|
||||
return nil, dbErr
|
||||
}
|
||||
|
||||
for _, origin := range origins {
|
||||
out[strings.ToLower(strings.TrimSpace(origin))] = true
|
||||
}
|
||||
|
||||
cacheBytes, _ := json.Marshal(origins)
|
||||
_ = database.SetEx(cacheKey, string(cacheBytes), corsCacheTTLSeconds)
|
||||
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func corsLogf(format string, args ...interface{}) {
|
||||
if configs.AppConfig != nil && configs.AppConfig.CorsDebug {
|
||||
log.Printf(format, args...)
|
||||
}
|
||||
}
|
||||
122
middlewares/rate_limit.go
Normal file
122
middlewares/rate_limit.go
Normal file
@@ -0,0 +1,122 @@
|
||||
package middlewares
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
configs "goFiber/config"
|
||||
database "goFiber/database/config"
|
||||
"goFiber/database/models"
|
||||
"log"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v3"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type rateLimitRuntime struct {
|
||||
Name string `json:"name"`
|
||||
MaxRequests int64 `json:"max_requests"`
|
||||
WindowSeconds int `json:"window_seconds"`
|
||||
IsActive bool `json:"is_active"`
|
||||
}
|
||||
|
||||
// RequireRateLimit applies Redis-backed per-IP rate limiting by setting name.
|
||||
func RequireRateLimit(name string, fallbackMax int64, fallbackWindowSeconds int) fiber.Handler {
|
||||
return func(c fiber.Ctx) error {
|
||||
if database.DB == nil {
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
setting, err := loadRateLimitRuntime(name, fallbackMax, fallbackWindowSeconds)
|
||||
if err != nil {
|
||||
return c.Status(http.StatusInternalServerError).JSON(fiber.Map{"error": "rate limit configuration error"})
|
||||
}
|
||||
if !setting.IsActive {
|
||||
return c.Next()
|
||||
}
|
||||
if database.RedisClient == nil {
|
||||
rateLimitLogf("[rate-limit][warn] redis unavailable, skipping enforcement name=%s", setting.Name)
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
ip := strings.TrimSpace(c.IP())
|
||||
if ip == "" {
|
||||
ip = "unknown"
|
||||
}
|
||||
|
||||
counterKey := fmt.Sprintf("ratelimit:%s:%s", setting.Name, ip)
|
||||
count, err := database.RedisClient.Incr(context.Background(), counterKey).Result()
|
||||
if err != nil {
|
||||
return c.Status(http.StatusInternalServerError).JSON(fiber.Map{"error": "rate limit check failed"})
|
||||
}
|
||||
if count == 1 {
|
||||
_ = database.RedisClient.Expire(context.Background(), counterKey, time.Duration(setting.WindowSeconds)*time.Second).Err()
|
||||
}
|
||||
|
||||
if count > setting.MaxRequests {
|
||||
ttl, _ := database.RedisClient.TTL(context.Background(), counterKey).Result()
|
||||
retryAfter := int(ttl.Seconds())
|
||||
if retryAfter < 1 {
|
||||
retryAfter = setting.WindowSeconds
|
||||
}
|
||||
c.Set("Retry-After", strconv.Itoa(retryAfter))
|
||||
log.Printf("[rate-limit][blocked] name=%s ip=%s count=%d max=%d window=%ds", setting.Name, ip, count, setting.MaxRequests, setting.WindowSeconds)
|
||||
return c.Status(http.StatusTooManyRequests).JSON(fiber.Map{
|
||||
"error": "too many requests",
|
||||
"retry_after": retryAfter,
|
||||
})
|
||||
}
|
||||
|
||||
rateLimitLogf("[rate-limit][allow] name=%s ip=%s count=%d max=%d window=%ds", setting.Name, ip, count, setting.MaxRequests, setting.WindowSeconds)
|
||||
return c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func loadRateLimitRuntime(name string, fallbackMax int64, fallbackWindowSeconds int) (*rateLimitRuntime, error) {
|
||||
cacheKey := "ratelimit:setting:" + name
|
||||
if cached, err := database.Get(cacheKey); err == nil {
|
||||
var s rateLimitRuntime
|
||||
if jsonErr := json.Unmarshal([]byte(cached), &s); jsonErr == nil {
|
||||
return &s, nil
|
||||
}
|
||||
} else if !errors.Is(err, redis.Nil) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
setting := &rateLimitRuntime{
|
||||
Name: name,
|
||||
MaxRequests: fallbackMax,
|
||||
WindowSeconds: fallbackWindowSeconds,
|
||||
IsActive: true,
|
||||
}
|
||||
|
||||
var dbSetting models.RateLimitSetting
|
||||
if err := database.DB.Where("name = ?", name).First(&dbSetting).Error; err != nil {
|
||||
if !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, err
|
||||
}
|
||||
rateLimitLogf("[rate-limit][config] setting=%s not found, using fallback max=%d window=%ds", name, fallbackMax, fallbackWindowSeconds)
|
||||
} else {
|
||||
setting.MaxRequests = dbSetting.MaxRequests
|
||||
setting.WindowSeconds = dbSetting.WindowSeconds
|
||||
setting.IsActive = dbSetting.IsActive
|
||||
rateLimitLogf("[rate-limit][config] loaded from db name=%s active=%t max=%d window=%ds", name, setting.IsActive, setting.MaxRequests, setting.WindowSeconds)
|
||||
}
|
||||
|
||||
cacheJSON, _ := json.Marshal(setting)
|
||||
_ = database.SetEx(cacheKey, string(cacheJSON), 60)
|
||||
|
||||
return setting, nil
|
||||
}
|
||||
|
||||
func rateLimitLogf(format string, args ...interface{}) {
|
||||
if configs.AppConfig != nil && configs.AppConfig.CorsDebug {
|
||||
log.Printf(format, args...)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user