package middleware

import (
	"net/http"
	"net/http/httptest"
	"strings"
	"testing"

	"github.com/gin-gonic/gin"
	"gorm.io/driver/sqlite"
	"gorm.io/gorm"

	settingsModels "goaresv3/app/settings/models"
	"goaresv3/config"
)

// setupMiddlewareDB points config.DB at an in-memory SQLite database named
// after the running test and migrates the CORS whitelist/blacklist and
// rate-limit setting models.
//
// The previous config.DB is restored and the connection pool is closed when
// the test finishes. Without the close, the shared-cache database would stay
// alive for the rest of the process (e.g. under `go test -count=2`), and a
// re-run of the same test would seed duplicate rows into it.
func setupMiddlewareDB(t *testing.T) {
	t.Helper()

	// Replace the subtest separator so the test name stays a single DSN segment.
	dbName := strings.ReplaceAll(strings.ToLower(t.Name()), "/", "_")
	dsn := "file:" + dbName + "?mode=memory&cache=shared"

	db, err := gorm.Open(sqlite.Open(dsn), &gorm.Config{})
	if err != nil {
		t.Fatalf("failed to open sqlite db: %v", err)
	}
	sqlDB, err := db.DB()
	if err != nil {
		t.Fatalf("failed to access sqlite connection pool: %v", err)
	}

	prevDB := config.DB
	t.Cleanup(func() {
		config.DB = prevDB
		if err := sqlDB.Close(); err != nil {
			t.Errorf("failed to close sqlite db: %v", err)
		}
	})

	if err := db.AutoMigrate(
		&settingsModels.CorsWhitelist{},
		&settingsModels.CorsBlacklist{},
		&settingsModels.RateLimitSetting{},
	); err != nil {
		t.Fatalf("failed to migrate middleware models: %v", err)
	}

	config.DB = db
}

// resetRateLimitState discards every rate-limit bucket, under rateLimitMu.
func resetRateLimitState() {
	rateLimitMu.Lock()
	defer rateLimitMu.Unlock()
	rateLimitBuckets = map[string]rateLimitBucket{}
}

// prepareDynamicMiddlewareTest puts gin in test mode, installs a fresh
// database, and starts from empty rate-limit buckets. The buckets are also
// cleared when the test ends, so counters do not leak into later tests.
func prepareDynamicMiddlewareTest(t *testing.T) {
	t.Helper()
	gin.SetMode(gin.TestMode)
	setupMiddlewareDB(t)
	resetRateLimitState()
	t.Cleanup(resetRateLimitState)
}

// newDynamicMiddlewareRouter returns a gin engine with DynamicCORS and
// DynamicRateLimit installed (in that order) and a GET /ping route that
// answers 200 with {"ok": true}.
func newDynamicMiddlewareRouter() *gin.Engine {
	r := gin.New()
	r.Use(DynamicCORS(), DynamicRateLimit())
	r.GET("/ping", func(c *gin.Context) {
		c.JSON(http.StatusOK, gin.H{"ok": true})
	})
	return r
}

// serveDynamicMiddlewareRequest sends a body-less request through h and
// returns the recorded response. An empty origin leaves the Origin header
// unset.
func serveDynamicMiddlewareRequest(h http.Handler, method, path, origin string) *httptest.ResponseRecorder {
	req := httptest.NewRequest(method, path, nil)
	if origin != "" {
		req.Header.Set("Origin", origin)
	}
	w := httptest.NewRecorder()
	h.ServeHTTP(w, req)
	return w
}

func TestDynamicCORSBlacklistBlocks(t *testing.T) {
	prepareDynamicMiddlewareTest(t)

	if err := config.DB.Create(&settingsModels.CorsBlacklist{
		Origin:   "https://blocked.example.com",
		IsActive: true,
	}).Error; err != nil {
		t.Fatalf("failed to seed blacklist: %v", err)
	}

	r := newDynamicMiddlewareRouter()

	w := serveDynamicMiddlewareRequest(r, http.MethodGet, "/ping", "https://blocked.example.com")
	if w.Code != http.StatusForbidden {
		t.Fatalf("expected 403 for blacklisted origin, got %d", w.Code)
	}
}

func TestDynamicRateLimitSkipsWhitelistedOrigin(t *testing.T) {
	prepareDynamicMiddlewareTest(t)

	if err := config.DB.Create(&settingsModels.CorsWhitelist{
		Origin:   "https://trusted.example.com",
		IsActive: true,
	}).Error; err != nil {
		t.Fatalf("failed to seed whitelist: %v", err)
	}
	if err := config.DB.Create(&settingsModels.RateLimitSetting{
		Name:          "api",
		MaxRequests:   1,
		WindowSeconds: 60,
		IsActive:      true,
	}).Error; err != nil {
		t.Fatalf("failed to seed api rate limit: %v", err)
	}

	r := newDynamicMiddlewareRouter()

	// Three requests against a limit of one: all must pass because the
	// origin is whitelisted.
	for i := 0; i < 3; i++ {
		w := serveDynamicMiddlewareRequest(r, http.MethodGet, "/ping", "https://trusted.example.com")
		if w.Code != http.StatusOK {
			t.Fatalf("expected 200 for whitelisted origin request #%d, got %d", i+1, w.Code)
		}
	}
}

func TestDynamicRateLimitAppliesToNonListedOrigin(t *testing.T) {
	prepareDynamicMiddlewareTest(t)

	if err := config.DB.Create(&settingsModels.RateLimitSetting{
		Name:          "api",
		MaxRequests:   2,
		WindowSeconds: 60,
		IsActive:      true,
	}).Error; err != nil {
		t.Fatalf("failed to seed api rate limit: %v", err)
	}

	r := newDynamicMiddlewareRouter()
	const origin = "https://unknown.example.com"

	for i := 0; i < 2; i++ {
		w := serveDynamicMiddlewareRequest(r, http.MethodGet, "/ping", origin)
		if w.Code != http.StatusOK {
			t.Fatalf("expected 200 on allowed request #%d, got %d", i+1, w.Code)
		}
	}

	w := serveDynamicMiddlewareRequest(r, http.MethodGet, "/ping", origin)
	if w.Code != http.StatusTooManyRequests {
		t.Fatalf("expected 429 for over-limit request, got %d", w.Code)
	}
}

func TestDynamicRateLimitLoginAndRegisterSeparateRules(t *testing.T) {
	prepareDynamicMiddlewareTest(t)

	seed := []settingsModels.RateLimitSetting{
		{Name: "api/v1/auth/login", MaxRequests: 1, WindowSeconds: 60, IsActive: true},
		{Name: "api/v1/auth/register", MaxRequests: 2, WindowSeconds: 60, IsActive: true},
		{Name: "api", MaxRequests: 10, WindowSeconds: 60, IsActive: true},
	}
	// Index into the slice rather than copying each element just to take
	// its address.
	for i := range seed {
		if err := config.DB.Create(&seed[i]).Error; err != nil {
			t.Fatalf("failed to seed rate limit %s: %v", seed[i].Name, err)
		}
	}

	r := newDynamicMiddlewareRouter()
	r.POST("/api/v1/auth/login", func(c *gin.Context) { c.Status(http.StatusOK) })
	r.POST("/api/v1/auth/register", func(c *gin.Context) { c.Status(http.StatusOK) })

	// Order matters: the login limit (1) must trip without consuming the
	// separate register allowance (2).
	steps := []struct {
		path string
		want int
	}{
		{"/api/v1/auth/login", http.StatusOK},
		{"/api/v1/auth/login", http.StatusTooManyRequests},
		{"/api/v1/auth/register", http.StatusOK},
		{"/api/v1/auth/register", http.StatusOK},
		{"/api/v1/auth/register", http.StatusTooManyRequests},
	}
	for i, step := range steps {
		w := serveDynamicMiddlewareRequest(r, http.MethodPost, step.path, "")
		if w.Code != step.want {
			t.Fatalf("request #%d POST %s: expected %d, got %d", i+1, step.path, step.want, w.Code)
		}
	}
}