first commit

This commit is contained in:
Beyhan Oğur
2026-04-26 21:41:46 +03:00
commit b6e74bd024
56 changed files with 16114 additions and 0 deletions

63
pkg/middleware/auth.go Normal file
View File

@@ -0,0 +1,63 @@
package middleware
import (
"net/http"
"os"
"strings"
"github.com/gin-gonic/gin"
accountModels "goaresv3/app/accounts/models"
"goaresv3/config"
jwtHelper "goaresv3/pkg/jwt"
)
// AuthRequired validates the Bearer access token and injects claims into context.
func AuthRequired() gin.HandlerFunc {
return func(c *gin.Context) {
header := c.GetHeader("Authorization")
if !strings.HasPrefix(header, "Bearer ") {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "authorization header missing or malformed"})
return
}
tokenStr := strings.TrimPrefix(header, "Bearer ")
claims, err := jwtHelper.ValidateToken(tokenStr, os.Getenv("JWT_SECRET"))
if err != nil {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "invalid or expired access token"})
return
}
c.Set("user_id", claims.UserID)
c.Set("email", claims.Email)
c.Set("username", claims.UserName)
c.Next()
}
}
// AdminRequired checks whether the authenticated user has admin privileges.
func AdminRequired() gin.HandlerFunc {
return func(c *gin.Context) {
userID := c.GetUint("user_id")
if userID == 0 {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "unauthorized"})
return
}
if config.DB == nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "database is not connected"})
return
}
var user accountModels.User
if err := config.DB.Select("id", "is_admin").First(&user, userID).Error; err != nil {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "invalid user"})
return
}
if user.IsAdmin == nil || !*user.IsAdmin {
c.AbortWithStatusJSON(http.StatusForbidden, gin.H{"error": "admin role required"})
return
}
c.Next()
}
}

View File

@@ -0,0 +1,59 @@
package middleware
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
jwtHelper "goaresv3/pkg/jwt"
)
func TestAuthRequiredValidBearerPasses(t *testing.T) {
gin.SetMode(gin.TestMode)
t.Setenv("JWT_SECRET", "test-secret-1234567890")
token, err := jwtHelper.GenerateAccessToken(7, "u@example.com", "user7")
if err != nil {
t.Fatalf("failed to generate token: %v", err)
}
r := gin.New()
r.GET("/protected", AuthRequired(), func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"ok": true})
})
req := httptest.NewRequest(http.MethodGet, "/protected", nil)
req.Header.Set("Authorization", "Bearer "+token)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected 200, got %d", w.Code)
}
}
func TestAuthRequiredRawTokenRejected(t *testing.T) {
gin.SetMode(gin.TestMode)
t.Setenv("JWT_SECRET", "test-secret-1234567890")
token, err := jwtHelper.GenerateAccessToken(7, "u@example.com", "user7")
if err != nil {
t.Fatalf("failed to generate token: %v", err)
}
r := gin.New()
r.GET("/protected", AuthRequired(), func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"ok": true})
})
req := httptest.NewRequest(http.MethodGet, "/protected", nil)
req.Header.Set("Authorization", token)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusUnauthorized {
t.Fatalf("expected 401, got %d", w.Code)
}
}

View File

@@ -0,0 +1,68 @@
package middleware
import (
"net/http"
"strings"
"github.com/gin-gonic/gin"
"goaresv3/app/settings/models"
"goaresv3/config"
)
// DynamicCORS applies CORS policy from DB-backed whitelist/blacklist tables.
func DynamicCORS() gin.HandlerFunc {
debug := envBool("CORS_DEBUG", false)
return func(c *gin.Context) {
// Defaults for downstream middlewares (e.g. rate limit)
c.Set("origin_whitelisted", false)
c.Set("origin_blacklisted", false)
origin := strings.TrimSpace(c.GetHeader("Origin"))
if origin == "" {
policyLogf(debug, "[cors] skip: no origin method=%s path=%s", c.Request.Method, c.Request.URL.Path)
c.Next()
return
}
if config.DB == nil {
c.AbortWithStatusJSON(http.StatusServiceUnavailable, gin.H{"error": "database is not connected"})
return
}
var blocked models.CorsBlacklist
if err := config.DB.
Where("origin = ? AND is_active = ?", origin, true).
First(&blocked).Error; err == nil {
c.Set("origin_blacklisted", true)
policyLogf(debug, "[cors] blocked origin=%s method=%s path=%s", origin, c.Request.Method, c.Request.URL.Path)
c.AbortWithStatusJSON(http.StatusForbidden, gin.H{"error": "origin is blocked"})
return
}
var whitelisted models.CorsWhitelist
if err := config.DB.
Where("origin = ? AND is_active = ?", origin, true).
First(&whitelisted).Error; err == nil {
c.Set("origin_whitelisted", true)
policyLogf(debug, "[cors] whitelisted origin=%s method=%s path=%s", origin, c.Request.Method, c.Request.URL.Path)
} else {
policyLogf(debug, "[cors] pass(non-listed) origin=%s method=%s path=%s", origin, c.Request.Method, c.Request.URL.Path)
}
c.Header("Access-Control-Allow-Origin", origin)
c.Header("Vary", "Origin")
c.Header("Access-Control-Allow-Methods", "GET, POST, PUT, PATCH, DELETE, OPTIONS")
c.Header("Access-Control-Allow-Headers", "Authorization, Content-Type, Accept")
c.Header("Access-Control-Allow-Credentials", "true")
if c.Request.Method == http.MethodOptions {
policyLogf(debug, "[cors] preflight origin=%s path=%s", origin, c.Request.URL.Path)
c.AbortWithStatus(http.StatusNoContent)
return
}
c.Next()
}
}

View File

@@ -0,0 +1,187 @@
package middleware
import (
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/gin-gonic/gin"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
settingsModels "goaresv3/app/settings/models"
"goaresv3/config"
)
func setupMiddlewareDB(t *testing.T) {
t.Helper()
dbName := strings.ReplaceAll(strings.ToLower(t.Name()), "/", "_")
dsn := "file:" + dbName + "?mode=memory&cache=shared"
db, err := gorm.Open(sqlite.Open(dsn), &gorm.Config{})
if err != nil {
t.Fatalf("failed to open sqlite db: %v", err)
}
if err := db.AutoMigrate(
&settingsModels.CorsWhitelist{},
&settingsModels.CorsBlacklist{},
&settingsModels.RateLimitSetting{},
); err != nil {
t.Fatalf("failed to migrate middleware models: %v", err)
}
config.DB = db
}
func resetRateLimitState() {
rateLimitMu.Lock()
defer rateLimitMu.Unlock()
rateLimitBuckets = map[string]rateLimitBucket{}
}
func TestDynamicCORSBlacklistBlocks(t *testing.T) {
gin.SetMode(gin.TestMode)
setupMiddlewareDB(t)
resetRateLimitState()
if err := config.DB.Create(&settingsModels.CorsBlacklist{
Origin: "https://blocked.example.com",
IsActive: true,
}).Error; err != nil {
t.Fatalf("failed to seed blacklist: %v", err)
}
r := gin.New()
r.Use(DynamicCORS(), DynamicRateLimit())
r.GET("/ping", func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"ok": true}) })
req := httptest.NewRequest(http.MethodGet, "/ping", nil)
req.Header.Set("Origin", "https://blocked.example.com")
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusForbidden {
t.Fatalf("expected 403 for blacklisted origin, got %d", w.Code)
}
}
func TestDynamicRateLimitSkipsWhitelistedOrigin(t *testing.T) {
gin.SetMode(gin.TestMode)
setupMiddlewareDB(t)
resetRateLimitState()
if err := config.DB.Create(&settingsModels.CorsWhitelist{
Origin: "https://trusted.example.com",
IsActive: true,
}).Error; err != nil {
t.Fatalf("failed to seed whitelist: %v", err)
}
if err := config.DB.Create(&settingsModels.RateLimitSetting{
Name: "api",
MaxRequests: 1,
WindowSeconds: 60,
IsActive: true,
}).Error; err != nil {
t.Fatalf("failed to seed api rate limit: %v", err)
}
r := gin.New()
r.Use(DynamicCORS(), DynamicRateLimit())
r.GET("/ping", func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"ok": true}) })
for i := 0; i < 3; i++ {
req := httptest.NewRequest(http.MethodGet, "/ping", nil)
req.Header.Set("Origin", "https://trusted.example.com")
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected 200 for whitelisted origin request #%d, got %d", i+1, w.Code)
}
}
}
func TestDynamicRateLimitAppliesToNonListedOrigin(t *testing.T) {
gin.SetMode(gin.TestMode)
setupMiddlewareDB(t)
resetRateLimitState()
if err := config.DB.Create(&settingsModels.RateLimitSetting{
Name: "api",
MaxRequests: 2,
WindowSeconds: 60,
IsActive: true,
}).Error; err != nil {
t.Fatalf("failed to seed api rate limit: %v", err)
}
r := gin.New()
r.Use(DynamicCORS(), DynamicRateLimit())
r.GET("/ping", func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"ok": true}) })
for i := 0; i < 2; i++ {
req := httptest.NewRequest(http.MethodGet, "/ping", nil)
req.Header.Set("Origin", "https://unknown.example.com")
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected 200 on allowed request #%d, got %d", i+1, w.Code)
}
}
req := httptest.NewRequest(http.MethodGet, "/ping", nil)
req.Header.Set("Origin", "https://unknown.example.com")
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusTooManyRequests {
t.Fatalf("expected 429 for over-limit request, got %d", w.Code)
}
}
func TestDynamicRateLimitLoginAndRegisterSeparateRules(t *testing.T) {
gin.SetMode(gin.TestMode)
setupMiddlewareDB(t)
resetRateLimitState()
seed := []settingsModels.RateLimitSetting{
{Name: "api/v1/auth/login", MaxRequests: 1, WindowSeconds: 60, IsActive: true},
{Name: "api/v1/auth/register", MaxRequests: 2, WindowSeconds: 60, IsActive: true},
{Name: "api", MaxRequests: 10, WindowSeconds: 60, IsActive: true},
}
for _, item := range seed {
it := item
if err := config.DB.Create(&it).Error; err != nil {
t.Fatalf("failed to seed rate limit %s: %v", it.Name, err)
}
}
r := gin.New()
r.Use(DynamicCORS(), DynamicRateLimit())
r.POST("/api/v1/auth/login", func(c *gin.Context) { c.Status(http.StatusOK) })
r.POST("/api/v1/auth/register", func(c *gin.Context) { c.Status(http.StatusOK) })
login1 := httptest.NewRecorder()
r.ServeHTTP(login1, httptest.NewRequest(http.MethodPost, "/api/v1/auth/login", nil))
if login1.Code != http.StatusOK {
t.Fatalf("expected login first request 200, got %d", login1.Code)
}
login2 := httptest.NewRecorder()
r.ServeHTTP(login2, httptest.NewRequest(http.MethodPost, "/api/v1/auth/login", nil))
if login2.Code != http.StatusTooManyRequests {
t.Fatalf("expected login second request 429, got %d", login2.Code)
}
reg1 := httptest.NewRecorder()
r.ServeHTTP(reg1, httptest.NewRequest(http.MethodPost, "/api/v1/auth/register", nil))
if reg1.Code != http.StatusOK {
t.Fatalf("expected register first request 200, got %d", reg1.Code)
}
reg2 := httptest.NewRecorder()
r.ServeHTTP(reg2, httptest.NewRequest(http.MethodPost, "/api/v1/auth/register", nil))
if reg2.Code != http.StatusOK {
t.Fatalf("expected register second request 200, got %d", reg2.Code)
}
reg3 := httptest.NewRecorder()
r.ServeHTTP(reg3, httptest.NewRequest(http.MethodPost, "/api/v1/auth/register", nil))
if reg3.Code != http.StatusTooManyRequests {
t.Fatalf("expected register third request 429, got %d", reg3.Code)
}
}

View File

@@ -0,0 +1,29 @@
package middleware
import (
"log"
"os"
"strings"
)
func envBool(key string, fallback bool) bool {
raw := strings.TrimSpace(strings.ToLower(os.Getenv(key)))
if raw == "" {
return fallback
}
switch raw {
case "1", "true", "yes", "on":
return true
case "0", "false", "no", "off":
return false
default:
return fallback
}
}
func policyLogf(enabled bool, format string, args ...any) {
if !enabled {
return
}
log.Printf(format, args...)
}

View File

@@ -0,0 +1,145 @@
package middleware
import (
"net/http"
"strconv"
"strings"
"sync"
"time"
"github.com/gin-gonic/gin"
"goaresv3/app/settings/models"
"goaresv3/config"
)
type rateLimitBucket struct {
Count int64
WindowEnds time.Time
}
var (
rateLimitMu sync.Mutex
rateLimitBuckets = map[string]rateLimitBucket{}
)
// DynamicRateLimit enforces DB-backed rate-limit settings.
// Rule selection order:
// 1) Exact route path key (without leading slash), e.g. api/v1/auth/login
// 2) "api" fallback
func DynamicRateLimit() gin.HandlerFunc {
debug := envBool("RATE_LIMIT_DEBUG", false)
return func(c *gin.Context) {
pathKey := strings.TrimPrefix(c.FullPath(), "/")
if pathKey == "" {
pathKey = strings.TrimPrefix(c.Request.URL.Path, "/")
}
if shouldSkipRateLimit(pathKey) {
policyLogf(debug, "[rate-limit] skip path=%s reason=skip-list", pathKey)
c.Next()
return
}
if isWhitelisted, ok := c.Get("origin_whitelisted"); ok {
if v, castOK := isWhitelisted.(bool); castOK && v {
// Whitelisted origins are excluded from rate limiting.
policyLogf(debug, "[rate-limit] skip path=%s reason=origin-whitelisted", pathKey)
c.Next()
return
}
}
if config.DB == nil {
c.AbortWithStatusJSON(http.StatusServiceUnavailable, gin.H{"error": "database is not connected"})
return
}
rule, ok := resolveRateLimitRule(pathKey)
if !ok {
policyLogf(debug, "[rate-limit] no-rule path=%s", pathKey)
c.Next()
return
}
clientIP := c.ClientIP()
windowDur := time.Duration(rule.WindowSeconds) * time.Second
bucketKey := rule.Name + ":" + clientIP
now := time.Now()
rateLimitMu.Lock()
bucket, ok := rateLimitBuckets[bucketKey]
if !ok || now.After(bucket.WindowEnds) {
bucket = rateLimitBucket{
Count: 0,
WindowEnds: now.Add(windowDur),
}
}
bucket.Count++
rateLimitBuckets[bucketKey] = bucket
remaining := rule.MaxRequests - bucket.Count
resetIn := int(time.Until(bucket.WindowEnds).Seconds())
rateLimitMu.Unlock()
if remaining < 0 {
policyLogf(debug, "[rate-limit] blocked path=%s rule=%s ip=%s limit=%d window=%d", pathKey, rule.Name, clientIP, rule.MaxRequests, rule.WindowSeconds)
c.Header("Retry-After", strconvItoa(maxInt(resetIn, 1)))
c.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{
"error": "rate limit exceeded",
"limit": rule.MaxRequests,
"window_sec": rule.WindowSeconds,
"retry_after": maxInt(resetIn, 1),
"rule_name": rule.Name,
"client_ip": clientIP,
"request_path": pathKey,
})
return
}
c.Header("X-RateLimit-Limit", strconvI64(rule.MaxRequests))
c.Header("X-RateLimit-Remaining", strconvI64(maxI64(remaining, 0)))
c.Header("X-RateLimit-Reset", strconvItoa(maxInt(resetIn, 0)))
policyLogf(debug, "[rate-limit] pass path=%s rule=%s ip=%s remaining=%d", pathKey, rule.Name, clientIP, maxI64(remaining, 0))
c.Next()
}
}
func shouldSkipRateLimit(pathKey string) bool {
return strings.HasPrefix(pathKey, "swagger")
}
func resolveRateLimitRule(pathKey string) (models.RateLimitSetting, bool) {
var rule models.RateLimitSetting
res := config.DB.Where("name = ? AND is_active = ?", pathKey, true).Limit(1).Find(&rule)
if res.Error == nil && res.RowsAffected > 0 {
return rule, true
}
res = config.DB.Where("name = ? AND is_active = ?", "api", true).Limit(1).Find(&rule)
if res.Error == nil && res.RowsAffected > 0 {
return rule, true
}
return models.RateLimitSetting{}, false
}
func strconvItoa(v int) string {
return strconv.FormatInt(int64(v), 10)
}
func strconvI64(v int64) string {
return strconv.FormatInt(v, 10)
}
func maxInt(a, b int) int {
if a > b {
return a
}
return b
}
func maxI64(a, b int64) int64 {
if a > b {
return a
}
return b
}