first commit
This commit is contained in:
187
pkg/middleware/dynamic_policies_test.go
Normal file
187
pkg/middleware/dynamic_policies_test.go
Normal file
@@ -0,0 +1,187 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
|
||||
settingsModels "goaresv3/app/settings/models"
|
||||
"goaresv3/config"
|
||||
)
|
||||
|
||||
func setupMiddlewareDB(t *testing.T) {
|
||||
t.Helper()
|
||||
dbName := strings.ReplaceAll(strings.ToLower(t.Name()), "/", "_")
|
||||
dsn := "file:" + dbName + "?mode=memory&cache=shared"
|
||||
db, err := gorm.Open(sqlite.Open(dsn), &gorm.Config{})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to open sqlite db: %v", err)
|
||||
}
|
||||
if err := db.AutoMigrate(
|
||||
&settingsModels.CorsWhitelist{},
|
||||
&settingsModels.CorsBlacklist{},
|
||||
&settingsModels.RateLimitSetting{},
|
||||
); err != nil {
|
||||
t.Fatalf("failed to migrate middleware models: %v", err)
|
||||
}
|
||||
config.DB = db
|
||||
}
|
||||
|
||||
func resetRateLimitState() {
|
||||
rateLimitMu.Lock()
|
||||
defer rateLimitMu.Unlock()
|
||||
rateLimitBuckets = map[string]rateLimitBucket{}
|
||||
}
|
||||
|
||||
func TestDynamicCORSBlacklistBlocks(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
setupMiddlewareDB(t)
|
||||
resetRateLimitState()
|
||||
|
||||
if err := config.DB.Create(&settingsModels.CorsBlacklist{
|
||||
Origin: "https://blocked.example.com",
|
||||
IsActive: true,
|
||||
}).Error; err != nil {
|
||||
t.Fatalf("failed to seed blacklist: %v", err)
|
||||
}
|
||||
|
||||
r := gin.New()
|
||||
r.Use(DynamicCORS(), DynamicRateLimit())
|
||||
r.GET("/ping", func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"ok": true}) })
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/ping", nil)
|
||||
req.Header.Set("Origin", "https://blocked.example.com")
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusForbidden {
|
||||
t.Fatalf("expected 403 for blacklisted origin, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDynamicRateLimitSkipsWhitelistedOrigin(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
setupMiddlewareDB(t)
|
||||
resetRateLimitState()
|
||||
|
||||
if err := config.DB.Create(&settingsModels.CorsWhitelist{
|
||||
Origin: "https://trusted.example.com",
|
||||
IsActive: true,
|
||||
}).Error; err != nil {
|
||||
t.Fatalf("failed to seed whitelist: %v", err)
|
||||
}
|
||||
if err := config.DB.Create(&settingsModels.RateLimitSetting{
|
||||
Name: "api",
|
||||
MaxRequests: 1,
|
||||
WindowSeconds: 60,
|
||||
IsActive: true,
|
||||
}).Error; err != nil {
|
||||
t.Fatalf("failed to seed api rate limit: %v", err)
|
||||
}
|
||||
|
||||
r := gin.New()
|
||||
r.Use(DynamicCORS(), DynamicRateLimit())
|
||||
r.GET("/ping", func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"ok": true}) })
|
||||
|
||||
for i := 0; i < 3; i++ {
|
||||
req := httptest.NewRequest(http.MethodGet, "/ping", nil)
|
||||
req.Header.Set("Origin", "https://trusted.example.com")
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200 for whitelisted origin request #%d, got %d", i+1, w.Code)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDynamicRateLimitAppliesToNonListedOrigin(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
setupMiddlewareDB(t)
|
||||
resetRateLimitState()
|
||||
|
||||
if err := config.DB.Create(&settingsModels.RateLimitSetting{
|
||||
Name: "api",
|
||||
MaxRequests: 2,
|
||||
WindowSeconds: 60,
|
||||
IsActive: true,
|
||||
}).Error; err != nil {
|
||||
t.Fatalf("failed to seed api rate limit: %v", err)
|
||||
}
|
||||
|
||||
r := gin.New()
|
||||
r.Use(DynamicCORS(), DynamicRateLimit())
|
||||
r.GET("/ping", func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"ok": true}) })
|
||||
|
||||
for i := 0; i < 2; i++ {
|
||||
req := httptest.NewRequest(http.MethodGet, "/ping", nil)
|
||||
req.Header.Set("Origin", "https://unknown.example.com")
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200 on allowed request #%d, got %d", i+1, w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/ping", nil)
|
||||
req.Header.Set("Origin", "https://unknown.example.com")
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusTooManyRequests {
|
||||
t.Fatalf("expected 429 for over-limit request, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDynamicRateLimitLoginAndRegisterSeparateRules(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
setupMiddlewareDB(t)
|
||||
resetRateLimitState()
|
||||
|
||||
seed := []settingsModels.RateLimitSetting{
|
||||
{Name: "api/v1/auth/login", MaxRequests: 1, WindowSeconds: 60, IsActive: true},
|
||||
{Name: "api/v1/auth/register", MaxRequests: 2, WindowSeconds: 60, IsActive: true},
|
||||
{Name: "api", MaxRequests: 10, WindowSeconds: 60, IsActive: true},
|
||||
}
|
||||
for _, item := range seed {
|
||||
it := item
|
||||
if err := config.DB.Create(&it).Error; err != nil {
|
||||
t.Fatalf("failed to seed rate limit %s: %v", it.Name, err)
|
||||
}
|
||||
}
|
||||
|
||||
r := gin.New()
|
||||
r.Use(DynamicCORS(), DynamicRateLimit())
|
||||
r.POST("/api/v1/auth/login", func(c *gin.Context) { c.Status(http.StatusOK) })
|
||||
r.POST("/api/v1/auth/register", func(c *gin.Context) { c.Status(http.StatusOK) })
|
||||
|
||||
login1 := httptest.NewRecorder()
|
||||
r.ServeHTTP(login1, httptest.NewRequest(http.MethodPost, "/api/v1/auth/login", nil))
|
||||
if login1.Code != http.StatusOK {
|
||||
t.Fatalf("expected login first request 200, got %d", login1.Code)
|
||||
}
|
||||
login2 := httptest.NewRecorder()
|
||||
r.ServeHTTP(login2, httptest.NewRequest(http.MethodPost, "/api/v1/auth/login", nil))
|
||||
if login2.Code != http.StatusTooManyRequests {
|
||||
t.Fatalf("expected login second request 429, got %d", login2.Code)
|
||||
}
|
||||
|
||||
reg1 := httptest.NewRecorder()
|
||||
r.ServeHTTP(reg1, httptest.NewRequest(http.MethodPost, "/api/v1/auth/register", nil))
|
||||
if reg1.Code != http.StatusOK {
|
||||
t.Fatalf("expected register first request 200, got %d", reg1.Code)
|
||||
}
|
||||
reg2 := httptest.NewRecorder()
|
||||
r.ServeHTTP(reg2, httptest.NewRequest(http.MethodPost, "/api/v1/auth/register", nil))
|
||||
if reg2.Code != http.StatusOK {
|
||||
t.Fatalf("expected register second request 200, got %d", reg2.Code)
|
||||
}
|
||||
reg3 := httptest.NewRecorder()
|
||||
r.ServeHTTP(reg3, httptest.NewRequest(http.MethodPost, "/api/v1/auth/register", nil))
|
||||
if reg3.Code != http.StatusTooManyRequests {
|
||||
t.Fatalf("expected register third request 429, got %d", reg3.Code)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user