first commit
This commit is contained in:
64
pkg/jwt/jwt.go
Normal file
64
pkg/jwt/jwt.go
Normal file
@@ -0,0 +1,64 @@
|
||||
package jwt
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
)
|
||||
|
||||
type Claims struct {
|
||||
UserID uint `json:"user_id"`
|
||||
Email string `json:"email"`
|
||||
UserName string `json:"username,omitempty"`
|
||||
jwt.RegisteredClaims
|
||||
}
|
||||
|
||||
// GenerateAccessToken creates a short-lived access token (15 minutes).
|
||||
func GenerateAccessToken(userID uint, email, userName string) (string, error) {
|
||||
secret := os.Getenv("JWT_SECRET")
|
||||
claims := Claims{
|
||||
UserID: userID,
|
||||
Email: email,
|
||||
UserName: userName,
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
ExpiresAt: jwt.NewNumericDate(time.Now().Add(15 * time.Minute)),
|
||||
IssuedAt: jwt.NewNumericDate(time.Now()),
|
||||
},
|
||||
}
|
||||
return jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString([]byte(secret))
|
||||
}
|
||||
|
||||
// GenerateRefreshToken creates a long-lived refresh token (7 days).
|
||||
func GenerateRefreshToken(userID uint, email, userName string) (string, error) {
|
||||
secret := os.Getenv("JWT_REFRESH_SECRET")
|
||||
claims := Claims{
|
||||
UserID: userID,
|
||||
Email: email,
|
||||
UserName: userName,
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
ExpiresAt: jwt.NewNumericDate(time.Now().Add(7 * 24 * time.Hour)),
|
||||
IssuedAt: jwt.NewNumericDate(time.Now()),
|
||||
},
|
||||
}
|
||||
return jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString([]byte(secret))
|
||||
}
|
||||
|
||||
// ValidateToken parses and validates a token string using the provided secret.
|
||||
func ValidateToken(tokenStr, secret string) (*Claims, error) {
|
||||
token, err := jwt.ParseWithClaims(tokenStr, &Claims{}, func(t *jwt.Token) (interface{}, error) {
|
||||
if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, errors.New("unexpected signing method")
|
||||
}
|
||||
return []byte(secret), nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
claims, ok := token.Claims.(*Claims)
|
||||
if !ok || !token.Valid {
|
||||
return nil, errors.New("invalid token")
|
||||
}
|
||||
return claims, nil
|
||||
}
|
||||
42
pkg/jwt/jwt_test.go
Normal file
42
pkg/jwt/jwt_test.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package jwt
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestGenerateAndValidateAccessToken(t *testing.T) {
|
||||
t.Setenv("JWT_SECRET", "test-secret-1234567890")
|
||||
|
||||
token, err := GenerateAccessToken(42, "user@example.com", "tester")
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateAccessToken returned error: %v", err)
|
||||
}
|
||||
|
||||
claims, err := ValidateToken(token, "test-secret-1234567890")
|
||||
if err != nil {
|
||||
t.Fatalf("ValidateToken returned error: %v", err)
|
||||
}
|
||||
|
||||
if claims.UserID != 42 {
|
||||
t.Fatalf("expected user_id=42, got %d", claims.UserID)
|
||||
}
|
||||
if claims.Email != "user@example.com" {
|
||||
t.Fatalf("expected email=user@example.com, got %s", claims.Email)
|
||||
}
|
||||
if claims.UserName != "tester" {
|
||||
t.Fatalf("expected username=tester, got %s", claims.UserName)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateTokenWrongSecretFails(t *testing.T) {
|
||||
t.Setenv("JWT_SECRET", "test-secret-1234567890")
|
||||
|
||||
token, err := GenerateAccessToken(1, "user@example.com", "tester")
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateAccessToken returned error: %v", err)
|
||||
}
|
||||
|
||||
if _, err := ValidateToken(token, "wrong-secret"); err == nil {
|
||||
t.Fatal("expected ValidateToken to fail with wrong secret")
|
||||
}
|
||||
}
|
||||
123
pkg/mailer/mailer.go
Normal file
123
pkg/mailer/mailer.go
Normal file
@@ -0,0 +1,123 @@
|
||||
package mailer
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net/smtp"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func envBool(key string) bool {
|
||||
v, err := strconv.ParseBool(strings.TrimSpace(os.Getenv(key)))
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
func Send(to, subject, body string) error {
|
||||
host := strings.TrimSpace(os.Getenv("EMAIL_HOST"))
|
||||
port := strings.TrimSpace(os.Getenv("EMAIL_PORT"))
|
||||
from := strings.TrimSpace(os.Getenv("EMAIL_FROM"))
|
||||
username := strings.TrimSpace(os.Getenv("EMAIL_HOST_USER"))
|
||||
password := strings.TrimSpace(os.Getenv("EMAIL_HOST_PASSWORD"))
|
||||
|
||||
if host == "" || port == "" || from == "" {
|
||||
return fmt.Errorf("email configuration is incomplete")
|
||||
}
|
||||
|
||||
addr := host + ":" + port
|
||||
msg := strings.Join([]string{
|
||||
"From: " + from,
|
||||
"To: " + to,
|
||||
"Subject: " + subject,
|
||||
"MIME-Version: 1.0",
|
||||
"Content-Type: text/plain; charset=UTF-8",
|
||||
"",
|
||||
body,
|
||||
}, "\r\n")
|
||||
|
||||
var auth smtp.Auth
|
||||
if username != "" {
|
||||
auth = smtp.PlainAuth("", username, password, host)
|
||||
}
|
||||
|
||||
useSSL := envBool("EMAIL_USE_SSL")
|
||||
useTLS := envBool("EMAIL_USE_TLS")
|
||||
|
||||
if useSSL {
|
||||
conn, err := tls.Dial("tcp", addr, &tls.Config{ServerName: host})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
client, err := smtp.NewClient(conn, host)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer client.Quit()
|
||||
|
||||
if auth != nil {
|
||||
if err := client.Auth(auth); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if err := client.Mail(from); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := client.Rcpt(to); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
w, err := client.Data()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := w.Write([]byte(msg)); err != nil {
|
||||
_ = w.Close()
|
||||
return err
|
||||
}
|
||||
return w.Close()
|
||||
}
|
||||
|
||||
if !useTLS {
|
||||
return smtp.SendMail(addr, auth, from, []string{to}, []byte(msg))
|
||||
}
|
||||
|
||||
client, err := smtp.Dial(addr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer client.Quit()
|
||||
|
||||
if err := client.StartTLS(&tls.Config{ServerName: host}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if auth != nil {
|
||||
if err := client.Auth(auth); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if err := client.Mail(from); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := client.Rcpt(to); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
w, err := client.Data()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := w.Write([]byte(msg)); err != nil {
|
||||
_ = w.Close()
|
||||
return err
|
||||
}
|
||||
return w.Close()
|
||||
}
|
||||
130
pkg/mailer/mailer_test.go
Normal file
130
pkg/mailer/mailer_test.go
Normal file
@@ -0,0 +1,130 @@
|
||||
package mailer
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestSendFailsWhenConfigMissing(t *testing.T) {
|
||||
t.Setenv("EMAIL_HOST", "")
|
||||
t.Setenv("EMAIL_PORT", "")
|
||||
t.Setenv("EMAIL_FROM", "")
|
||||
|
||||
if err := Send("user@example.com", "subj", "body"); err == nil {
|
||||
t.Fatal("expected error for incomplete email config")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendPlainSMTP(t *testing.T) {
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to listen: %v", err)
|
||||
}
|
||||
defer ln.Close()
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
conn, err := ln.Accept()
|
||||
if err != nil {
|
||||
done <- err
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
r := bufio.NewReader(conn)
|
||||
w := bufio.NewWriter(conn)
|
||||
write := func(s string) error {
|
||||
if _, err := w.WriteString(s); err != nil {
|
||||
return err
|
||||
}
|
||||
return w.Flush()
|
||||
}
|
||||
|
||||
if err := write("220 localhost Simple Mail Transfer Service Ready\r\n"); err != nil {
|
||||
done <- err
|
||||
return
|
||||
}
|
||||
|
||||
for {
|
||||
line, err := r.ReadString('\n')
|
||||
if err != nil {
|
||||
done <- err
|
||||
return
|
||||
}
|
||||
cmd := strings.TrimSpace(line)
|
||||
|
||||
switch {
|
||||
case strings.HasPrefix(strings.ToUpper(cmd), "EHLO") || strings.HasPrefix(strings.ToUpper(cmd), "HELO"):
|
||||
if err := write("250-localhost\r\n250 AUTH PLAIN\r\n"); err != nil {
|
||||
done <- err
|
||||
return
|
||||
}
|
||||
case strings.HasPrefix(strings.ToUpper(cmd), "MAIL FROM"):
|
||||
if err := write("250 OK\r\n"); err != nil {
|
||||
done <- err
|
||||
return
|
||||
}
|
||||
case strings.HasPrefix(strings.ToUpper(cmd), "RCPT TO"):
|
||||
if err := write("250 OK\r\n"); err != nil {
|
||||
done <- err
|
||||
return
|
||||
}
|
||||
case strings.HasPrefix(strings.ToUpper(cmd), "DATA"):
|
||||
if err := write("354 End data with <CR><LF>.<CR><LF>\r\n"); err != nil {
|
||||
done <- err
|
||||
return
|
||||
}
|
||||
for {
|
||||
d, err := r.ReadString('\n')
|
||||
if err != nil {
|
||||
done <- err
|
||||
return
|
||||
}
|
||||
if strings.TrimSpace(d) == "." {
|
||||
break
|
||||
}
|
||||
}
|
||||
if err := write("250 OK queued\r\n"); err != nil {
|
||||
done <- err
|
||||
return
|
||||
}
|
||||
case strings.HasPrefix(strings.ToUpper(cmd), "QUIT"):
|
||||
_ = write("221 Bye\r\n")
|
||||
done <- nil
|
||||
return
|
||||
default:
|
||||
if err := write("250 OK\r\n"); err != nil {
|
||||
done <- err
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
host, portStr, err := net.SplitHostPort(ln.Addr().String())
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse listener address: %v", err)
|
||||
}
|
||||
if _, err := strconv.Atoi(portStr); err != nil {
|
||||
t.Fatalf("invalid test smtp port: %v", err)
|
||||
}
|
||||
|
||||
t.Setenv("EMAIL_HOST", host)
|
||||
t.Setenv("EMAIL_PORT", portStr)
|
||||
t.Setenv("EMAIL_FROM", "noreply@example.com")
|
||||
t.Setenv("EMAIL_HOST_USER", "")
|
||||
t.Setenv("EMAIL_HOST_PASSWORD", "")
|
||||
t.Setenv("EMAIL_USE_TLS", "false")
|
||||
t.Setenv("EMAIL_USE_SSL", "false")
|
||||
|
||||
if err := Send("user@example.com", "Verify", "Hello"); err != nil {
|
||||
t.Fatalf("expected send success, got error: %v", err)
|
||||
}
|
||||
|
||||
if err := <-done; err != nil {
|
||||
t.Fatalf("smtp server finished with error: %v", err)
|
||||
}
|
||||
}
|
||||
63
pkg/middleware/auth.go
Normal file
63
pkg/middleware/auth.go
Normal file
@@ -0,0 +1,63 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
accountModels "goaresv3/app/accounts/models"
|
||||
"goaresv3/config"
|
||||
jwtHelper "goaresv3/pkg/jwt"
|
||||
)
|
||||
|
||||
// AuthRequired validates the Bearer access token and injects claims into context.
|
||||
func AuthRequired() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
header := c.GetHeader("Authorization")
|
||||
if !strings.HasPrefix(header, "Bearer ") {
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "authorization header missing or malformed"})
|
||||
return
|
||||
}
|
||||
|
||||
tokenStr := strings.TrimPrefix(header, "Bearer ")
|
||||
claims, err := jwtHelper.ValidateToken(tokenStr, os.Getenv("JWT_SECRET"))
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "invalid or expired access token"})
|
||||
return
|
||||
}
|
||||
|
||||
c.Set("user_id", claims.UserID)
|
||||
c.Set("email", claims.Email)
|
||||
c.Set("username", claims.UserName)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// AdminRequired checks whether the authenticated user has admin privileges.
|
||||
func AdminRequired() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
userID := c.GetUint("user_id")
|
||||
if userID == 0 {
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "unauthorized"})
|
||||
return
|
||||
}
|
||||
if config.DB == nil {
|
||||
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "database is not connected"})
|
||||
return
|
||||
}
|
||||
|
||||
var user accountModels.User
|
||||
if err := config.DB.Select("id", "is_admin").First(&user, userID).Error; err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "invalid user"})
|
||||
return
|
||||
}
|
||||
if user.IsAdmin == nil || !*user.IsAdmin {
|
||||
c.AbortWithStatusJSON(http.StatusForbidden, gin.H{"error": "admin role required"})
|
||||
return
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
59
pkg/middleware/auth_test.go
Normal file
59
pkg/middleware/auth_test.go
Normal file
@@ -0,0 +1,59 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
jwtHelper "goaresv3/pkg/jwt"
|
||||
)
|
||||
|
||||
func TestAuthRequiredValidBearerPasses(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
t.Setenv("JWT_SECRET", "test-secret-1234567890")
|
||||
|
||||
token, err := jwtHelper.GenerateAccessToken(7, "u@example.com", "user7")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to generate token: %v", err)
|
||||
}
|
||||
|
||||
r := gin.New()
|
||||
r.GET("/protected", AuthRequired(), func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"ok": true})
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/protected", nil)
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthRequiredRawTokenRejected(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
t.Setenv("JWT_SECRET", "test-secret-1234567890")
|
||||
|
||||
token, err := jwtHelper.GenerateAccessToken(7, "u@example.com", "user7")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to generate token: %v", err)
|
||||
}
|
||||
|
||||
r := gin.New()
|
||||
r.GET("/protected", AuthRequired(), func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"ok": true})
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/protected", nil)
|
||||
req.Header.Set("Authorization", token)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusUnauthorized {
|
||||
t.Fatalf("expected 401, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
68
pkg/middleware/cors_dynamic.go
Normal file
68
pkg/middleware/cors_dynamic.go
Normal file
@@ -0,0 +1,68 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"goaresv3/app/settings/models"
|
||||
"goaresv3/config"
|
||||
)
|
||||
|
||||
// DynamicCORS applies CORS policy from DB-backed whitelist/blacklist tables.
|
||||
func DynamicCORS() gin.HandlerFunc {
|
||||
debug := envBool("CORS_DEBUG", false)
|
||||
|
||||
return func(c *gin.Context) {
|
||||
// Defaults for downstream middlewares (e.g. rate limit)
|
||||
c.Set("origin_whitelisted", false)
|
||||
c.Set("origin_blacklisted", false)
|
||||
|
||||
origin := strings.TrimSpace(c.GetHeader("Origin"))
|
||||
if origin == "" {
|
||||
policyLogf(debug, "[cors] skip: no origin method=%s path=%s", c.Request.Method, c.Request.URL.Path)
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
if config.DB == nil {
|
||||
c.AbortWithStatusJSON(http.StatusServiceUnavailable, gin.H{"error": "database is not connected"})
|
||||
return
|
||||
}
|
||||
|
||||
var blocked models.CorsBlacklist
|
||||
if err := config.DB.
|
||||
Where("origin = ? AND is_active = ?", origin, true).
|
||||
First(&blocked).Error; err == nil {
|
||||
c.Set("origin_blacklisted", true)
|
||||
policyLogf(debug, "[cors] blocked origin=%s method=%s path=%s", origin, c.Request.Method, c.Request.URL.Path)
|
||||
c.AbortWithStatusJSON(http.StatusForbidden, gin.H{"error": "origin is blocked"})
|
||||
return
|
||||
}
|
||||
|
||||
var whitelisted models.CorsWhitelist
|
||||
if err := config.DB.
|
||||
Where("origin = ? AND is_active = ?", origin, true).
|
||||
First(&whitelisted).Error; err == nil {
|
||||
c.Set("origin_whitelisted", true)
|
||||
policyLogf(debug, "[cors] whitelisted origin=%s method=%s path=%s", origin, c.Request.Method, c.Request.URL.Path)
|
||||
} else {
|
||||
policyLogf(debug, "[cors] pass(non-listed) origin=%s method=%s path=%s", origin, c.Request.Method, c.Request.URL.Path)
|
||||
}
|
||||
|
||||
c.Header("Access-Control-Allow-Origin", origin)
|
||||
c.Header("Vary", "Origin")
|
||||
c.Header("Access-Control-Allow-Methods", "GET, POST, PUT, PATCH, DELETE, OPTIONS")
|
||||
c.Header("Access-Control-Allow-Headers", "Authorization, Content-Type, Accept")
|
||||
c.Header("Access-Control-Allow-Credentials", "true")
|
||||
|
||||
if c.Request.Method == http.MethodOptions {
|
||||
policyLogf(debug, "[cors] preflight origin=%s path=%s", origin, c.Request.URL.Path)
|
||||
c.AbortWithStatus(http.StatusNoContent)
|
||||
return
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
187
pkg/middleware/dynamic_policies_test.go
Normal file
187
pkg/middleware/dynamic_policies_test.go
Normal file
@@ -0,0 +1,187 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
|
||||
settingsModels "goaresv3/app/settings/models"
|
||||
"goaresv3/config"
|
||||
)
|
||||
|
||||
func setupMiddlewareDB(t *testing.T) {
|
||||
t.Helper()
|
||||
dbName := strings.ReplaceAll(strings.ToLower(t.Name()), "/", "_")
|
||||
dsn := "file:" + dbName + "?mode=memory&cache=shared"
|
||||
db, err := gorm.Open(sqlite.Open(dsn), &gorm.Config{})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to open sqlite db: %v", err)
|
||||
}
|
||||
if err := db.AutoMigrate(
|
||||
&settingsModels.CorsWhitelist{},
|
||||
&settingsModels.CorsBlacklist{},
|
||||
&settingsModels.RateLimitSetting{},
|
||||
); err != nil {
|
||||
t.Fatalf("failed to migrate middleware models: %v", err)
|
||||
}
|
||||
config.DB = db
|
||||
}
|
||||
|
||||
func resetRateLimitState() {
|
||||
rateLimitMu.Lock()
|
||||
defer rateLimitMu.Unlock()
|
||||
rateLimitBuckets = map[string]rateLimitBucket{}
|
||||
}
|
||||
|
||||
func TestDynamicCORSBlacklistBlocks(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
setupMiddlewareDB(t)
|
||||
resetRateLimitState()
|
||||
|
||||
if err := config.DB.Create(&settingsModels.CorsBlacklist{
|
||||
Origin: "https://blocked.example.com",
|
||||
IsActive: true,
|
||||
}).Error; err != nil {
|
||||
t.Fatalf("failed to seed blacklist: %v", err)
|
||||
}
|
||||
|
||||
r := gin.New()
|
||||
r.Use(DynamicCORS(), DynamicRateLimit())
|
||||
r.GET("/ping", func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"ok": true}) })
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/ping", nil)
|
||||
req.Header.Set("Origin", "https://blocked.example.com")
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusForbidden {
|
||||
t.Fatalf("expected 403 for blacklisted origin, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDynamicRateLimitSkipsWhitelistedOrigin(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
setupMiddlewareDB(t)
|
||||
resetRateLimitState()
|
||||
|
||||
if err := config.DB.Create(&settingsModels.CorsWhitelist{
|
||||
Origin: "https://trusted.example.com",
|
||||
IsActive: true,
|
||||
}).Error; err != nil {
|
||||
t.Fatalf("failed to seed whitelist: %v", err)
|
||||
}
|
||||
if err := config.DB.Create(&settingsModels.RateLimitSetting{
|
||||
Name: "api",
|
||||
MaxRequests: 1,
|
||||
WindowSeconds: 60,
|
||||
IsActive: true,
|
||||
}).Error; err != nil {
|
||||
t.Fatalf("failed to seed api rate limit: %v", err)
|
||||
}
|
||||
|
||||
r := gin.New()
|
||||
r.Use(DynamicCORS(), DynamicRateLimit())
|
||||
r.GET("/ping", func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"ok": true}) })
|
||||
|
||||
for i := 0; i < 3; i++ {
|
||||
req := httptest.NewRequest(http.MethodGet, "/ping", nil)
|
||||
req.Header.Set("Origin", "https://trusted.example.com")
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200 for whitelisted origin request #%d, got %d", i+1, w.Code)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDynamicRateLimitAppliesToNonListedOrigin(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
setupMiddlewareDB(t)
|
||||
resetRateLimitState()
|
||||
|
||||
if err := config.DB.Create(&settingsModels.RateLimitSetting{
|
||||
Name: "api",
|
||||
MaxRequests: 2,
|
||||
WindowSeconds: 60,
|
||||
IsActive: true,
|
||||
}).Error; err != nil {
|
||||
t.Fatalf("failed to seed api rate limit: %v", err)
|
||||
}
|
||||
|
||||
r := gin.New()
|
||||
r.Use(DynamicCORS(), DynamicRateLimit())
|
||||
r.GET("/ping", func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"ok": true}) })
|
||||
|
||||
for i := 0; i < 2; i++ {
|
||||
req := httptest.NewRequest(http.MethodGet, "/ping", nil)
|
||||
req.Header.Set("Origin", "https://unknown.example.com")
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200 on allowed request #%d, got %d", i+1, w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/ping", nil)
|
||||
req.Header.Set("Origin", "https://unknown.example.com")
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusTooManyRequests {
|
||||
t.Fatalf("expected 429 for over-limit request, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDynamicRateLimitLoginAndRegisterSeparateRules(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
setupMiddlewareDB(t)
|
||||
resetRateLimitState()
|
||||
|
||||
seed := []settingsModels.RateLimitSetting{
|
||||
{Name: "api/v1/auth/login", MaxRequests: 1, WindowSeconds: 60, IsActive: true},
|
||||
{Name: "api/v1/auth/register", MaxRequests: 2, WindowSeconds: 60, IsActive: true},
|
||||
{Name: "api", MaxRequests: 10, WindowSeconds: 60, IsActive: true},
|
||||
}
|
||||
for _, item := range seed {
|
||||
it := item
|
||||
if err := config.DB.Create(&it).Error; err != nil {
|
||||
t.Fatalf("failed to seed rate limit %s: %v", it.Name, err)
|
||||
}
|
||||
}
|
||||
|
||||
r := gin.New()
|
||||
r.Use(DynamicCORS(), DynamicRateLimit())
|
||||
r.POST("/api/v1/auth/login", func(c *gin.Context) { c.Status(http.StatusOK) })
|
||||
r.POST("/api/v1/auth/register", func(c *gin.Context) { c.Status(http.StatusOK) })
|
||||
|
||||
login1 := httptest.NewRecorder()
|
||||
r.ServeHTTP(login1, httptest.NewRequest(http.MethodPost, "/api/v1/auth/login", nil))
|
||||
if login1.Code != http.StatusOK {
|
||||
t.Fatalf("expected login first request 200, got %d", login1.Code)
|
||||
}
|
||||
login2 := httptest.NewRecorder()
|
||||
r.ServeHTTP(login2, httptest.NewRequest(http.MethodPost, "/api/v1/auth/login", nil))
|
||||
if login2.Code != http.StatusTooManyRequests {
|
||||
t.Fatalf("expected login second request 429, got %d", login2.Code)
|
||||
}
|
||||
|
||||
reg1 := httptest.NewRecorder()
|
||||
r.ServeHTTP(reg1, httptest.NewRequest(http.MethodPost, "/api/v1/auth/register", nil))
|
||||
if reg1.Code != http.StatusOK {
|
||||
t.Fatalf("expected register first request 200, got %d", reg1.Code)
|
||||
}
|
||||
reg2 := httptest.NewRecorder()
|
||||
r.ServeHTTP(reg2, httptest.NewRequest(http.MethodPost, "/api/v1/auth/register", nil))
|
||||
if reg2.Code != http.StatusOK {
|
||||
t.Fatalf("expected register second request 200, got %d", reg2.Code)
|
||||
}
|
||||
reg3 := httptest.NewRecorder()
|
||||
r.ServeHTTP(reg3, httptest.NewRequest(http.MethodPost, "/api/v1/auth/register", nil))
|
||||
if reg3.Code != http.StatusTooManyRequests {
|
||||
t.Fatalf("expected register third request 429, got %d", reg3.Code)
|
||||
}
|
||||
}
|
||||
29
pkg/middleware/log_flags.go
Normal file
29
pkg/middleware/log_flags.go
Normal file
@@ -0,0 +1,29 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"log"
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func envBool(key string, fallback bool) bool {
|
||||
raw := strings.TrimSpace(strings.ToLower(os.Getenv(key)))
|
||||
if raw == "" {
|
||||
return fallback
|
||||
}
|
||||
switch raw {
|
||||
case "1", "true", "yes", "on":
|
||||
return true
|
||||
case "0", "false", "no", "off":
|
||||
return false
|
||||
default:
|
||||
return fallback
|
||||
}
|
||||
}
|
||||
|
||||
func policyLogf(enabled bool, format string, args ...any) {
|
||||
if !enabled {
|
||||
return
|
||||
}
|
||||
log.Printf(format, args...)
|
||||
}
|
||||
145
pkg/middleware/rate_limit_dynamic.go
Normal file
145
pkg/middleware/rate_limit_dynamic.go
Normal file
@@ -0,0 +1,145 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"goaresv3/app/settings/models"
|
||||
"goaresv3/config"
|
||||
)
|
||||
|
||||
type rateLimitBucket struct {
|
||||
Count int64
|
||||
WindowEnds time.Time
|
||||
}
|
||||
|
||||
var (
|
||||
rateLimitMu sync.Mutex
|
||||
rateLimitBuckets = map[string]rateLimitBucket{}
|
||||
)
|
||||
|
||||
// DynamicRateLimit enforces DB-backed rate-limit settings.
|
||||
// Rule selection order:
|
||||
// 1) Exact route path key (without leading slash), e.g. api/v1/auth/login
|
||||
// 2) "api" fallback
|
||||
func DynamicRateLimit() gin.HandlerFunc {
|
||||
debug := envBool("RATE_LIMIT_DEBUG", false)
|
||||
|
||||
return func(c *gin.Context) {
|
||||
pathKey := strings.TrimPrefix(c.FullPath(), "/")
|
||||
if pathKey == "" {
|
||||
pathKey = strings.TrimPrefix(c.Request.URL.Path, "/")
|
||||
}
|
||||
if shouldSkipRateLimit(pathKey) {
|
||||
policyLogf(debug, "[rate-limit] skip path=%s reason=skip-list", pathKey)
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
if isWhitelisted, ok := c.Get("origin_whitelisted"); ok {
|
||||
if v, castOK := isWhitelisted.(bool); castOK && v {
|
||||
// Whitelisted origins are excluded from rate limiting.
|
||||
policyLogf(debug, "[rate-limit] skip path=%s reason=origin-whitelisted", pathKey)
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if config.DB == nil {
|
||||
c.AbortWithStatusJSON(http.StatusServiceUnavailable, gin.H{"error": "database is not connected"})
|
||||
return
|
||||
}
|
||||
|
||||
rule, ok := resolveRateLimitRule(pathKey)
|
||||
if !ok {
|
||||
policyLogf(debug, "[rate-limit] no-rule path=%s", pathKey)
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
clientIP := c.ClientIP()
|
||||
windowDur := time.Duration(rule.WindowSeconds) * time.Second
|
||||
bucketKey := rule.Name + ":" + clientIP
|
||||
|
||||
now := time.Now()
|
||||
rateLimitMu.Lock()
|
||||
bucket, ok := rateLimitBuckets[bucketKey]
|
||||
if !ok || now.After(bucket.WindowEnds) {
|
||||
bucket = rateLimitBucket{
|
||||
Count: 0,
|
||||
WindowEnds: now.Add(windowDur),
|
||||
}
|
||||
}
|
||||
|
||||
bucket.Count++
|
||||
rateLimitBuckets[bucketKey] = bucket
|
||||
remaining := rule.MaxRequests - bucket.Count
|
||||
resetIn := int(time.Until(bucket.WindowEnds).Seconds())
|
||||
rateLimitMu.Unlock()
|
||||
|
||||
if remaining < 0 {
|
||||
policyLogf(debug, "[rate-limit] blocked path=%s rule=%s ip=%s limit=%d window=%d", pathKey, rule.Name, clientIP, rule.MaxRequests, rule.WindowSeconds)
|
||||
c.Header("Retry-After", strconvItoa(maxInt(resetIn, 1)))
|
||||
c.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{
|
||||
"error": "rate limit exceeded",
|
||||
"limit": rule.MaxRequests,
|
||||
"window_sec": rule.WindowSeconds,
|
||||
"retry_after": maxInt(resetIn, 1),
|
||||
"rule_name": rule.Name,
|
||||
"client_ip": clientIP,
|
||||
"request_path": pathKey,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.Header("X-RateLimit-Limit", strconvI64(rule.MaxRequests))
|
||||
c.Header("X-RateLimit-Remaining", strconvI64(maxI64(remaining, 0)))
|
||||
c.Header("X-RateLimit-Reset", strconvItoa(maxInt(resetIn, 0)))
|
||||
policyLogf(debug, "[rate-limit] pass path=%s rule=%s ip=%s remaining=%d", pathKey, rule.Name, clientIP, maxI64(remaining, 0))
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func shouldSkipRateLimit(pathKey string) bool {
|
||||
return strings.HasPrefix(pathKey, "swagger")
|
||||
}
|
||||
|
||||
func resolveRateLimitRule(pathKey string) (models.RateLimitSetting, bool) {
|
||||
var rule models.RateLimitSetting
|
||||
res := config.DB.Where("name = ? AND is_active = ?", pathKey, true).Limit(1).Find(&rule)
|
||||
if res.Error == nil && res.RowsAffected > 0 {
|
||||
return rule, true
|
||||
}
|
||||
res = config.DB.Where("name = ? AND is_active = ?", "api", true).Limit(1).Find(&rule)
|
||||
if res.Error == nil && res.RowsAffected > 0 {
|
||||
return rule, true
|
||||
}
|
||||
return models.RateLimitSetting{}, false
|
||||
}
|
||||
|
||||
func strconvItoa(v int) string {
|
||||
return strconv.FormatInt(int64(v), 10)
|
||||
}
|
||||
|
||||
func strconvI64(v int64) string {
|
||||
return strconv.FormatInt(v, 10)
|
||||
}
|
||||
|
||||
func maxInt(a, b int) int {
|
||||
if a > b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
func maxI64(a, b int64) int64 {
|
||||
if a > b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
60
pkg/swaggerui/initializer.go
Normal file
60
pkg/swaggerui/initializer.go
Normal file
@@ -0,0 +1,60 @@
|
||||
package swaggerui
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
swaggerFiles "github.com/swaggo/files"
|
||||
ginSwagger "github.com/swaggo/gin-swagger"
|
||||
)
|
||||
|
||||
const initializerJS = `window.onload = function() {
|
||||
const ui = SwaggerUIBundle({
|
||||
url: "doc.json",
|
||||
dom_id: '#swagger-ui',
|
||||
validatorUrl: null,
|
||||
persistAuthorization: true,
|
||||
presets: [
|
||||
SwaggerUIBundle.presets.apis,
|
||||
SwaggerUIStandalonePreset
|
||||
],
|
||||
plugins: [
|
||||
SwaggerUIBundle.plugins.DownloadUrl
|
||||
],
|
||||
layout: "StandaloneLayout",
|
||||
docExpansion: "list",
|
||||
deepLinking: true,
|
||||
defaultModelsExpandDepth: 1,
|
||||
requestInterceptor: function(request) {
|
||||
const auth = request.headers.Authorization || request.headers.authorization
|
||||
|
||||
if (typeof auth === 'string') {
|
||||
const trimmed = auth.trim()
|
||||
|
||||
if (trimmed !== '' && !/^Bearer\s+/i.test(trimmed)) {
|
||||
request.headers.Authorization = 'Bearer ' + trimmed
|
||||
}
|
||||
}
|
||||
|
||||
return request
|
||||
}
|
||||
})
|
||||
|
||||
window.ui = ui
|
||||
}
|
||||
`
|
||||
|
||||
// Handler serves Swagger UI and overrides the initializer script to prefix raw tokens with Bearer.
|
||||
func Handler() gin.HandlerFunc {
|
||||
defaultHandler := ginSwagger.WrapHandler(swaggerFiles.Handler, ginSwagger.PersistAuthorization(true))
|
||||
|
||||
return func(c *gin.Context) {
|
||||
if strings.HasSuffix(c.Request.URL.Path, "/swagger-initializer.js") {
|
||||
c.Data(http.StatusOK, "application/javascript", []byte(initializerJS))
|
||||
return
|
||||
}
|
||||
|
||||
defaultHandler(c)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user