package middleware import ( "encoding/json" "net/http" "net/http/httptest" "strings" "testing" "time" "ginimageApi/app/accounts/models" "ginimageApi/configs" "github.com/gin-gonic/gin" "github.com/golang-jwt/jwt/v5" "gorm.io/driver/sqlite" "gorm.io/gorm" ) func setupMiddlewareTestDB(t *testing.T) { t.Helper() prev := configs.DB dsn := "file:" + t.Name() + "?mode=memory&cache=shared" db, err := gorm.Open(sqlite.Open(dsn), &gorm.Config{}) if err != nil { t.Fatalf("failed to open sqlite: %v", err) } if err := db.AutoMigrate(&models.User{}); err != nil { t.Fatalf("failed to migrate: %v", err) } configs.DB = db t.Cleanup(func() { if sqlDB, err := db.DB(); err == nil { _ = sqlDB.Close() } configs.DB = prev }) } func TestGenerateAndParseAccessToken(t *testing.T) { t.Setenv("JWT_SECRET", "test-secret") token, err := GenerateAccessToken(99, "u@example.com", "u1", time.Minute) if err != nil { t.Fatalf("GenerateAccessToken failed: %v", err) } if got := len(strings.Split(token, ".")); got != 3 { t.Fatalf("expected standard JWT with 3 segments, got %d", got) } payload, err := parseAccessToken(token) if err != nil { t.Fatalf("parseAccessToken failed: %v", err) } if payload.UserID != 99 || payload.Email != "u@example.com" || payload.Username != "u1" { t.Fatalf("unexpected payload: %+v", payload) } } func TestParseAccessTokenExpired(t *testing.T) { t.Setenv("JWT_SECRET", "test-secret") token, err := GenerateAccessToken(1, "a@a.com", "a", -time.Second) if err != nil { t.Fatalf("GenerateAccessToken failed: %v", err) } if _, err := parseAccessToken(token); err == nil { t.Fatalf("expected parse error for expired token") } } func TestParseAccessTokenRejectsRefreshToken(t *testing.T) { t.Setenv("JWT_SECRET", "test-secret") token, _, err := GenerateRefreshToken(1, time.Minute) if err != nil { t.Fatalf("GenerateRefreshToken failed: %v", err) } if _, err := parseAccessToken(token); err == nil { t.Fatalf("expected parse error for refresh token") } } func TestParseAccessTokenRequiresUserID(t *testing.T) { t.Setenv("JWT_SECRET", "test-secret") claims := accessTokenClaims{ TokenType: "access", Email: "a@a.com", Username: "a", RegisteredClaims: jwt.RegisteredClaims{ IssuedAt: jwt.NewNumericDate(time.Now()), ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Minute)), }, } token, err := jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString([]byte("test-secret")) if err != nil { t.Fatalf("failed to sign token: %v", err) } if _, err := parseAccessToken(token); err == nil { t.Fatalf("expected parse error for missing user_id") } } func TestAuthRequired(t *testing.T) { gin.SetMode(gin.TestMode) t.Setenv("JWT_SECRET", "test-secret") token, err := GenerateAccessToken(7, "mail@example.com", "user7", time.Minute) if err != nil { t.Fatalf("token generate failed: %v", err) } r := gin.New() r.GET("/me", AuthRequired(), func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{ "user_id": c.GetUint("user_id"), "email": c.GetString("email"), "username": c.GetString("username"), }) }) req := httptest.NewRequest(http.MethodGet, "/me", nil) req.Header.Set("Authorization", "Bearer "+token) w := httptest.NewRecorder() r.ServeHTTP(w, req) if w.Code != http.StatusOK { t.Fatalf("expected 200, got %d", w.Code) } var body map[string]any if err := json.Unmarshal(w.Body.Bytes(), &body); err != nil { t.Fatalf("invalid json: %v", err) } if body["email"] != "mail@example.com" { t.Fatalf("expected email in context") } } func TestAuthRequiredRejectsInvalidToken(t *testing.T) { gin.SetMode(gin.TestMode) r := gin.New() r.GET("/me", AuthRequired(), func(c *gin.Context) { c.Status(http.StatusOK) }) req := httptest.NewRequest(http.MethodGet, "/me", nil) req.Header.Set("Authorization", "Bearer invalid") w := httptest.NewRecorder() r.ServeHTTP(w, req) if w.Code != http.StatusUnauthorized { t.Fatalf("expected 401, got %d", w.Code) } } func TestAuthRequiredRejectsRawAuthorizationToken(t *testing.T) { gin.SetMode(gin.TestMode) t.Setenv("JWT_SECRET", "test-secret") token, err := GenerateAccessToken(11, "raw@example.com", "rawuser", time.Minute) if err != nil { t.Fatalf("token generate failed: %v", err) } r := gin.New() r.GET("/me", AuthRequired(), func(c *gin.Context) { c.Status(http.StatusOK) }) req := httptest.NewRequest(http.MethodGet, "/me", nil) req.Header.Set("Authorization", token) w := httptest.NewRecorder() r.ServeHTTP(w, req) if w.Code != http.StatusUnauthorized { t.Fatalf("expected 401 for raw token without Bearer, got %d", w.Code) } } func TestAdminRequired(t *testing.T) { gin.SetMode(gin.TestMode) setupMiddlewareTestDB(t) isAdmin := true isUser := false admin := models.User{UserName: "admin", Email: "admin@example.com", Password: "x", IsAdmin: &isAdmin} user := models.User{UserName: "user", Email: "user@example.com", Password: "x", IsAdmin: &isUser} if err := configs.DB.Create(&admin).Error; err != nil { t.Fatalf("admin create failed: %v", err) } if err := configs.DB.Create(&user).Error; err != nil { t.Fatalf("user create failed: %v", err) } r := gin.New() r.POST("/admin", func(c *gin.Context) { c.Set("user_id", user.ID) c.Next() }, AdminRequired(), func(c *gin.Context) { c.Status(http.StatusOK) }) w := httptest.NewRecorder() r.ServeHTTP(w, httptest.NewRequest(http.MethodPost, "/admin", nil)) if w.Code != http.StatusForbidden { t.Fatalf("expected 403 for non-admin, got %d", w.Code) } r2 := gin.New() r2.POST("/admin", func(c *gin.Context) { c.Set("user_id", admin.ID) c.Next() }, AdminRequired(), func(c *gin.Context) { c.Status(http.StatusOK) }) w2 := httptest.NewRecorder() r2.ServeHTTP(w2, httptest.NewRequest(http.MethodPost, "/admin", nil)) if w2.Code != http.StatusOK { t.Fatalf("expected 200 for admin, got %d", w2.Code) } }