first commit
This commit is contained in:
265
app/middleware/auth.go
Normal file
265
app/middleware/auth.go
Normal file
@@ -0,0 +1,265 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"ginimageApi/app/accounts/models"
|
||||
"ginimageApi/configs"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
)
|
||||
|
||||
type accessTokenPayload struct {
|
||||
UserID uint `json:"user_id"`
|
||||
Email string `json:"email"`
|
||||
Username string `json:"username"`
|
||||
Exp int64 `json:"exp"`
|
||||
}
|
||||
|
||||
type accessTokenClaims struct {
|
||||
TokenType string `json:"token_type"`
|
||||
UserID string `json:"user_id"`
|
||||
Email string `json:"email"`
|
||||
Username string `json:"username"`
|
||||
jwt.RegisteredClaims
|
||||
}
|
||||
|
||||
type refreshTokenClaims struct {
|
||||
TokenType string `json:"token_type"`
|
||||
UserID string `json:"user_id"`
|
||||
jwt.RegisteredClaims
|
||||
}
|
||||
|
||||
func jwtIssuer() string {
|
||||
issuer := os.Getenv("JWT_ISSUER")
|
||||
if issuer == "" {
|
||||
issuer = "ginimageApi"
|
||||
}
|
||||
return issuer
|
||||
}
|
||||
|
||||
func jwtAudience() string {
|
||||
audience := os.Getenv("JWT_AUDIENCE")
|
||||
if audience == "" {
|
||||
audience = "ginimageApi-client"
|
||||
}
|
||||
return audience
|
||||
}
|
||||
|
||||
func jwtSecret() string {
|
||||
secret := os.Getenv("JWT_SECRET")
|
||||
if secret == "" {
|
||||
secret = "dev-secret-change-me"
|
||||
}
|
||||
return secret
|
||||
}
|
||||
|
||||
func randomTokenID() (string, error) {
|
||||
b := make([]byte, 16)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return hex.EncodeToString(b), nil
|
||||
}
|
||||
|
||||
func GenerateAccessToken(userID uint, email, username string, ttl time.Duration) (string, error) {
|
||||
now := time.Now()
|
||||
tokenID, err := randomTokenID()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
claims := accessTokenClaims{
|
||||
TokenType: "access",
|
||||
UserID: strconv.FormatUint(uint64(userID), 10),
|
||||
Email: email,
|
||||
Username: username,
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
ID: tokenID,
|
||||
IssuedAt: jwt.NewNumericDate(now),
|
||||
ExpiresAt: jwt.NewNumericDate(now.Add(ttl)),
|
||||
},
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
return token.SignedString([]byte(jwtSecret()))
|
||||
}
|
||||
|
||||
func GenerateRefreshToken(userID uint, ttl time.Duration) (string, string, error) {
|
||||
now := time.Now()
|
||||
tokenID, err := randomTokenID()
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
claims := refreshTokenClaims{
|
||||
TokenType: "refresh",
|
||||
UserID: strconv.FormatUint(uint64(userID), 10),
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
ID: tokenID,
|
||||
IssuedAt: jwt.NewNumericDate(now),
|
||||
ExpiresAt: jwt.NewNumericDate(now.Add(ttl)),
|
||||
},
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
signed, err := token.SignedString([]byte(jwtSecret()))
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
return signed, tokenID, nil
|
||||
}
|
||||
|
||||
func parseAccessToken(token string) (accessTokenPayload, error) {
|
||||
parsed, err := jwt.ParseWithClaims(
|
||||
token,
|
||||
&accessTokenClaims{},
|
||||
func(t *jwt.Token) (any, error) {
|
||||
return []byte(jwtSecret()), nil
|
||||
},
|
||||
jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Alg()}),
|
||||
jwt.WithExpirationRequired(),
|
||||
)
|
||||
if err != nil {
|
||||
return accessTokenPayload{}, errors.New("token gecersiz")
|
||||
}
|
||||
|
||||
claims, ok := parsed.Claims.(*accessTokenClaims)
|
||||
if !ok || !parsed.Valid {
|
||||
return accessTokenPayload{}, errors.New("token gecersiz")
|
||||
}
|
||||
if claims.TokenType != "access" {
|
||||
return accessTokenPayload{}, errors.New("token type gecersiz")
|
||||
}
|
||||
|
||||
uid64, err := strconv.ParseUint(claims.UserID, 10, 64)
|
||||
if err != nil {
|
||||
return accessTokenPayload{}, errors.New("user_id claim gecersiz")
|
||||
}
|
||||
|
||||
exp := int64(0)
|
||||
if claims.ExpiresAt != nil {
|
||||
exp = claims.ExpiresAt.Time.Unix()
|
||||
}
|
||||
|
||||
return accessTokenPayload{
|
||||
UserID: uint(uid64),
|
||||
Email: claims.Email,
|
||||
Username: claims.Username,
|
||||
Exp: exp,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func bearerToken(c *gin.Context) (string, error) {
|
||||
header := strings.TrimSpace(c.GetHeader("Authorization"))
|
||||
if header == "" {
|
||||
return "", errors.New("authorization basligi yok")
|
||||
}
|
||||
|
||||
parts := strings.SplitN(header, " ", 2)
|
||||
if len(parts) != 2 || !strings.EqualFold(parts[0], "Bearer") {
|
||||
return "", errors.New("authorization formati gecersiz")
|
||||
}
|
||||
|
||||
token := strings.TrimSpace(parts[1])
|
||||
if token == "" {
|
||||
return "", errors.New("authorization formati gecersiz")
|
||||
}
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// AuthRequired access token dogrular ve kullanici bilgisini context'e yazar.
|
||||
func AuthRequired() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
token, err := bearerToken(c)
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
payload, err := parseAccessToken(token)
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.Set("user_id", payload.UserID)
|
||||
c.Set("email", payload.Email)
|
||||
c.Set("username", payload.Username)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// AdminRequired mutating endpointlerde kullanicinin admin oldugunu dogrular.
|
||||
func AdminRequired() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
userIDAny, ok := c.Get("user_id")
|
||||
if !ok {
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "kullanici bulunamadi"})
|
||||
return
|
||||
}
|
||||
|
||||
var userID uint
|
||||
switch v := userIDAny.(type) {
|
||||
case uint:
|
||||
userID = v
|
||||
case int:
|
||||
if v < 0 {
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "gecersiz kullanici"})
|
||||
return
|
||||
}
|
||||
userID = uint(v)
|
||||
case string:
|
||||
parsed, err := strconv.ParseUint(v, 10, 64)
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "gecersiz kullanici"})
|
||||
return
|
||||
}
|
||||
userID = uint(parsed)
|
||||
default:
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "gecersiz kullanici"})
|
||||
return
|
||||
}
|
||||
|
||||
var user models.User
|
||||
if err := configs.DB.First(&user, userID).Error; err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusForbidden, gin.H{"error": "admin yetkisi gerekli"})
|
||||
return
|
||||
}
|
||||
|
||||
if user.IsAdmin == nil || !*user.IsAdmin {
|
||||
c.AbortWithStatusJSON(http.StatusForbidden, gin.H{"error": "admin yetkisi gerekli"})
|
||||
return
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func BuildAccessTokenForUser(user models.User) (string, error) {
|
||||
return GenerateAccessToken(user.ID, user.Email, user.UserName, 15*time.Minute)
|
||||
}
|
||||
|
||||
func RefreshTokenExpiry() time.Duration {
|
||||
return 7 * 24 * time.Hour
|
||||
}
|
||||
|
||||
func AccessTokenTTL() time.Duration {
|
||||
return 15 * time.Minute
|
||||
}
|
||||
|
||||
func TokenPayloadDebug(token string) string {
|
||||
payload, err := parseAccessToken(token)
|
||||
if err != nil {
|
||||
return err.Error()
|
||||
}
|
||||
return fmt.Sprintf("uid=%d email=%s username=%s exp=%d", payload.UserID, payload.Email, payload.Username, payload.Exp)
|
||||
}
|
||||
231
app/middleware/auth_test.go
Normal file
231
app/middleware/auth_test.go
Normal file
@@ -0,0 +1,231 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"ginimageApi/app/accounts/models"
|
||||
"ginimageApi/configs"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
func setupMiddlewareTestDB(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
prev := configs.DB
|
||||
dsn := "file:" + t.Name() + "?mode=memory&cache=shared"
|
||||
db, err := gorm.Open(sqlite.Open(dsn), &gorm.Config{})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to open sqlite: %v", err)
|
||||
}
|
||||
if err := db.AutoMigrate(&models.User{}); err != nil {
|
||||
t.Fatalf("failed to migrate: %v", err)
|
||||
}
|
||||
configs.DB = db
|
||||
t.Cleanup(func() {
|
||||
if sqlDB, err := db.DB(); err == nil {
|
||||
_ = sqlDB.Close()
|
||||
}
|
||||
configs.DB = prev
|
||||
})
|
||||
}
|
||||
|
||||
func TestGenerateAndParseAccessToken(t *testing.T) {
|
||||
t.Setenv("JWT_SECRET", "test-secret")
|
||||
|
||||
token, err := GenerateAccessToken(99, "u@example.com", "u1", time.Minute)
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateAccessToken failed: %v", err)
|
||||
}
|
||||
if got := len(strings.Split(token, ".")); got != 3 {
|
||||
t.Fatalf("expected standard JWT with 3 segments, got %d", got)
|
||||
}
|
||||
|
||||
payload, err := parseAccessToken(token)
|
||||
if err != nil {
|
||||
t.Fatalf("parseAccessToken failed: %v", err)
|
||||
}
|
||||
|
||||
if payload.UserID != 99 || payload.Email != "u@example.com" || payload.Username != "u1" {
|
||||
t.Fatalf("unexpected payload: %+v", payload)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseAccessTokenExpired(t *testing.T) {
|
||||
t.Setenv("JWT_SECRET", "test-secret")
|
||||
|
||||
token, err := GenerateAccessToken(1, "a@a.com", "a", -time.Second)
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateAccessToken failed: %v", err)
|
||||
}
|
||||
|
||||
if _, err := parseAccessToken(token); err == nil {
|
||||
t.Fatalf("expected parse error for expired token")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseAccessTokenRejectsRefreshToken(t *testing.T) {
|
||||
t.Setenv("JWT_SECRET", "test-secret")
|
||||
|
||||
token, _, err := GenerateRefreshToken(1, time.Minute)
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateRefreshToken failed: %v", err)
|
||||
}
|
||||
|
||||
if _, err := parseAccessToken(token); err == nil {
|
||||
t.Fatalf("expected parse error for refresh token")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseAccessTokenRequiresUserID(t *testing.T) {
|
||||
t.Setenv("JWT_SECRET", "test-secret")
|
||||
|
||||
claims := accessTokenClaims{
|
||||
TokenType: "access",
|
||||
Email: "a@a.com",
|
||||
Username: "a",
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
IssuedAt: jwt.NewNumericDate(time.Now()),
|
||||
ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Minute)),
|
||||
},
|
||||
}
|
||||
|
||||
token, err := jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString([]byte("test-secret"))
|
||||
if err != nil {
|
||||
t.Fatalf("failed to sign token: %v", err)
|
||||
}
|
||||
|
||||
if _, err := parseAccessToken(token); err == nil {
|
||||
t.Fatalf("expected parse error for missing user_id")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthRequired(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
t.Setenv("JWT_SECRET", "test-secret")
|
||||
|
||||
token, err := GenerateAccessToken(7, "mail@example.com", "user7", time.Minute)
|
||||
if err != nil {
|
||||
t.Fatalf("token generate failed: %v", err)
|
||||
}
|
||||
|
||||
r := gin.New()
|
||||
r.GET("/me", AuthRequired(), func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"user_id": c.GetUint("user_id"),
|
||||
"email": c.GetString("email"),
|
||||
"username": c.GetString("username"),
|
||||
})
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/me", nil)
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
var body map[string]any
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &body); err != nil {
|
||||
t.Fatalf("invalid json: %v", err)
|
||||
}
|
||||
if body["email"] != "mail@example.com" {
|
||||
t.Fatalf("expected email in context")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthRequiredRejectsInvalidToken(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
r := gin.New()
|
||||
r.GET("/me", AuthRequired(), func(c *gin.Context) {
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/me", nil)
|
||||
req.Header.Set("Authorization", "Bearer invalid")
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusUnauthorized {
|
||||
t.Fatalf("expected 401, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthRequiredRejectsRawAuthorizationToken(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
t.Setenv("JWT_SECRET", "test-secret")
|
||||
|
||||
token, err := GenerateAccessToken(11, "raw@example.com", "rawuser", time.Minute)
|
||||
if err != nil {
|
||||
t.Fatalf("token generate failed: %v", err)
|
||||
}
|
||||
|
||||
r := gin.New()
|
||||
r.GET("/me", AuthRequired(), func(c *gin.Context) {
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/me", nil)
|
||||
req.Header.Set("Authorization", token)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusUnauthorized {
|
||||
t.Fatalf("expected 401 for raw token without Bearer, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdminRequired(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
setupMiddlewareTestDB(t)
|
||||
|
||||
isAdmin := true
|
||||
isUser := false
|
||||
admin := models.User{UserName: "admin", Email: "admin@example.com", Password: "x", IsAdmin: &isAdmin}
|
||||
user := models.User{UserName: "user", Email: "user@example.com", Password: "x", IsAdmin: &isUser}
|
||||
if err := configs.DB.Create(&admin).Error; err != nil {
|
||||
t.Fatalf("admin create failed: %v", err)
|
||||
}
|
||||
if err := configs.DB.Create(&user).Error; err != nil {
|
||||
t.Fatalf("user create failed: %v", err)
|
||||
}
|
||||
|
||||
r := gin.New()
|
||||
r.POST("/admin", func(c *gin.Context) {
|
||||
c.Set("user_id", user.ID)
|
||||
c.Next()
|
||||
}, AdminRequired(), func(c *gin.Context) {
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, httptest.NewRequest(http.MethodPost, "/admin", nil))
|
||||
if w.Code != http.StatusForbidden {
|
||||
t.Fatalf("expected 403 for non-admin, got %d", w.Code)
|
||||
}
|
||||
|
||||
r2 := gin.New()
|
||||
r2.POST("/admin", func(c *gin.Context) {
|
||||
c.Set("user_id", admin.ID)
|
||||
c.Next()
|
||||
}, AdminRequired(), func(c *gin.Context) {
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
|
||||
w2 := httptest.NewRecorder()
|
||||
r2.ServeHTTP(w2, httptest.NewRequest(http.MethodPost, "/admin", nil))
|
||||
if w2.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200 for admin, got %d", w2.Code)
|
||||
}
|
||||
}
|
||||
79
app/middleware/security.go
Normal file
79
app/middleware/security.go
Normal file
@@ -0,0 +1,79 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"os"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// DynamicCORS CORS davranisini ortama gore dinamik ayarlar.
|
||||
func DynamicCORS() gin.HandlerFunc {
|
||||
allowOrigin := os.Getenv("CORS_ALLOW_ORIGIN")
|
||||
if allowOrigin == "" {
|
||||
allowOrigin = "*"
|
||||
}
|
||||
|
||||
return func(c *gin.Context) {
|
||||
c.Writer.Header().Set("Access-Control-Allow-Origin", allowOrigin)
|
||||
c.Writer.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type")
|
||||
c.Writer.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, PATCH, DELETE, OPTIONS")
|
||||
if c.Request.Method == http.MethodOptions {
|
||||
c.AbortWithStatus(http.StatusNoContent)
|
||||
return
|
||||
}
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
type clientWindow struct {
|
||||
count int
|
||||
windowEnds time.Time
|
||||
}
|
||||
|
||||
// DynamicRateLimit IP bazli basit bir dakika penceresi limiti uygular.
|
||||
func DynamicRateLimit() gin.HandlerFunc {
|
||||
limit := 120
|
||||
if v := os.Getenv("RATE_LIMIT_RPM"); v != "" {
|
||||
if parsed, err := strconv.Atoi(v); err == nil && parsed > 0 {
|
||||
limit = parsed
|
||||
}
|
||||
}
|
||||
|
||||
var mu sync.Mutex
|
||||
clients := map[string]*clientWindow{}
|
||||
|
||||
return func(c *gin.Context) {
|
||||
ip := c.ClientIP()
|
||||
now := time.Now()
|
||||
|
||||
mu.Lock()
|
||||
entry, ok := clients[ip]
|
||||
if !ok || now.After(entry.windowEnds) {
|
||||
entry = &clientWindow{count: 0, windowEnds: now.Add(time.Minute)}
|
||||
clients[ip] = entry
|
||||
}
|
||||
|
||||
entry.count++
|
||||
remaining := limit - entry.count
|
||||
resetIn := int(time.Until(entry.windowEnds).Seconds())
|
||||
mu.Unlock()
|
||||
|
||||
c.Header("X-RateLimit-Limit", strconv.Itoa(limit))
|
||||
if remaining < 0 {
|
||||
remaining = 0
|
||||
}
|
||||
c.Header("X-RateLimit-Remaining", strconv.Itoa(remaining))
|
||||
c.Header("X-RateLimit-Reset", strconv.Itoa(resetIn))
|
||||
|
||||
if entry.count > limit {
|
||||
c.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{"error": "rate limit asildi"})
|
||||
return
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
75
app/middleware/security_test.go
Normal file
75
app/middleware/security_test.go
Normal file
@@ -0,0 +1,75 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func TestDynamicCORS(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
t.Setenv("CORS_ALLOW_ORIGIN", "http://example.com")
|
||||
|
||||
r := gin.New()
|
||||
r.Use(DynamicCORS())
|
||||
r.GET("/ping", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"ok": true})
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/ping", nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d", w.Code)
|
||||
}
|
||||
if got := w.Header().Get("Access-Control-Allow-Origin"); got != "http://example.com" {
|
||||
t.Fatalf("unexpected allow origin: %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDynamicCORSOptions(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
t.Setenv("CORS_ALLOW_ORIGIN", "*")
|
||||
|
||||
r := gin.New()
|
||||
r.Use(DynamicCORS())
|
||||
r.OPTIONS("/ping", func(c *gin.Context) {
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodOptions, "/ping", nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusNoContent {
|
||||
t.Fatalf("expected 204, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDynamicRateLimit(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
t.Setenv("RATE_LIMIT_RPM", "2")
|
||||
|
||||
r := gin.New()
|
||||
r.Use(DynamicRateLimit())
|
||||
r.GET("/limited", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"ok": true})
|
||||
})
|
||||
|
||||
for i := 1; i <= 3; i++ {
|
||||
req := httptest.NewRequest(http.MethodGet, "/limited", nil)
|
||||
req.RemoteAddr = "127.0.0.1:12345"
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if i < 3 && w.Code != http.StatusOK {
|
||||
t.Fatalf("request %d expected 200, got %d", i, w.Code)
|
||||
}
|
||||
if i == 3 && w.Code != http.StatusTooManyRequests {
|
||||
t.Fatalf("request %d expected 429, got %d", i, w.Code)
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user