first commit
This commit is contained in:
254
configs/configs_test.go
Normal file
254
configs/configs_test.go
Normal file
@@ -0,0 +1,254 @@
|
||||
package configs
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// ─── normalizeOrigin ─────────────────────────────────────────────────────────
|
||||
|
||||
func TestNormalizeOrigin_ValidHTTP(t *testing.T) {
|
||||
if got := normalizeOrigin("http://localhost:3000"); got != "http://localhost:3000" {
|
||||
t.Fatalf("got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeOrigin_ValidHTTPS(t *testing.T) {
|
||||
if got := normalizeOrigin("https://example.com"); got != "https://example.com" {
|
||||
t.Fatalf("got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeOrigin_StripsTrailingSlash(t *testing.T) {
|
||||
got := normalizeOrigin("https://example.com/")
|
||||
// url.Parse keeps the trailing slash on host-only URLs, so we just check host is preserved
|
||||
if !strings.HasPrefix(got, "https://example.com") {
|
||||
t.Fatalf("unexpected result: %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeOrigin_UppercaseNormalized(t *testing.T) {
|
||||
got := normalizeOrigin("HTTP://EXAMPLE.COM")
|
||||
if got != "http://example.com" {
|
||||
t.Fatalf("expected lowercase, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeOrigin_Empty(t *testing.T) {
|
||||
if got := normalizeOrigin(""); got != "" {
|
||||
t.Fatalf("expected empty, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeOrigin_Whitespace(t *testing.T) {
|
||||
if got := normalizeOrigin(" "); got != "" {
|
||||
t.Fatalf("expected empty for whitespace, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeOrigin_NoScheme(t *testing.T) {
|
||||
if got := normalizeOrigin("example.com"); got != "" {
|
||||
t.Fatalf("expected empty for missing scheme, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeOrigin_QuotedValue(t *testing.T) {
|
||||
if got := normalizeOrigin(`'http://localhost:8080'`); got != "http://localhost:8080" {
|
||||
t.Fatalf("got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
// ─── parseOriginList ─────────────────────────────────────────────────────────
|
||||
|
||||
func TestParseOriginList_MultipleEntries(t *testing.T) {
|
||||
list := parseOriginList("http://localhost:3000,https://example.com")
|
||||
if len(list) != 2 {
|
||||
t.Fatalf("expected 2 entries, got %d", len(list))
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseOriginList_EmptyString(t *testing.T) {
|
||||
list := parseOriginList("")
|
||||
if len(list) != 0 {
|
||||
t.Fatalf("expected 0 entries, got %d: %v", len(list), list)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseOriginList_InvalidEntriesSkipped(t *testing.T) {
|
||||
list := parseOriginList("http://good.com,not-a-url,http://also-good.com")
|
||||
if len(list) != 2 {
|
||||
t.Fatalf("expected 2 valid entries, got %d: %v", len(list), list)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseOriginList_DuplicatesPassThrough(t *testing.T) {
|
||||
// parseOriginList itself does NOT deduplicate; bootstrapWhitelistOrigins does.
|
||||
list := parseOriginList("http://dup.com,http://dup.com")
|
||||
if len(list) != 2 {
|
||||
t.Fatalf("parseOriginList should keep duplicates, got %d", len(list))
|
||||
}
|
||||
}
|
||||
|
||||
// ─── envIntOr ────────────────────────────────────────────────────────────────
|
||||
|
||||
func TestEnvIntOr_MissingUsesDefault(t *testing.T) {
|
||||
t.Setenv("__TEST_INT_MISSING__", "")
|
||||
if got := envIntOr("__TEST_INT_MISSING__", 42); got != 42 {
|
||||
t.Fatalf("expected 42, got %d", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnvIntOr_ValidValue(t *testing.T) {
|
||||
t.Setenv("__TEST_INT__", "99")
|
||||
if got := envIntOr("__TEST_INT__", 1); got != 99 {
|
||||
t.Fatalf("expected 99, got %d", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnvIntOr_InvalidStringUsesDefault(t *testing.T) {
|
||||
t.Setenv("__TEST_INT_BAD__", "abc")
|
||||
if got := envIntOr("__TEST_INT_BAD__", 7); got != 7 {
|
||||
t.Fatalf("expected fallback 7, got %d", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnvIntOr_ZeroUsesDefault(t *testing.T) {
|
||||
t.Setenv("__TEST_INT_ZERO__", "0")
|
||||
if got := envIntOr("__TEST_INT_ZERO__", 5); got != 5 {
|
||||
t.Fatalf("expected fallback for 0, got %d", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnvIntOr_NegativeUsesDefault(t *testing.T) {
|
||||
t.Setenv("__TEST_INT_NEG__", "-1")
|
||||
if got := envIntOr("__TEST_INT_NEG__", 3); got != 3 {
|
||||
t.Fatalf("expected fallback for negative, got %d", got)
|
||||
}
|
||||
}
|
||||
|
||||
// ─── envInt64Or ──────────────────────────────────────────────────────────────
|
||||
|
||||
func TestEnvInt64Or_MissingUsesDefault(t *testing.T) {
|
||||
t.Setenv("__TEST_I64_MISSING__", "")
|
||||
if got := envInt64Or("__TEST_I64_MISSING__", 100); got != 100 {
|
||||
t.Fatalf("expected 100, got %d", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnvInt64Or_ValidValue(t *testing.T) {
|
||||
t.Setenv("__TEST_I64__", "999")
|
||||
if got := envInt64Or("__TEST_I64__", 1); got != 999 {
|
||||
t.Fatalf("expected 999, got %d", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnvInt64Or_InvalidUsesDefault(t *testing.T) {
|
||||
t.Setenv("__TEST_I64_BAD__", "not-a-number")
|
||||
if got := envInt64Or("__TEST_I64_BAD__", 50); got != 50 {
|
||||
t.Fatalf("expected fallback 50, got %d", got)
|
||||
}
|
||||
}
|
||||
|
||||
// ─── bootstrapRateLimitRules ─────────────────────────────────────────────────
|
||||
|
||||
func TestBootstrapRateLimitRules_ContainsThreeRules(t *testing.T) {
|
||||
rules := bootstrapRateLimitRules()
|
||||
if len(rules) != 3 {
|
||||
t.Fatalf("expected 3 rules, got %d", len(rules))
|
||||
}
|
||||
}
|
||||
|
||||
func TestBootstrapRateLimitRules_DefaultFallbacks(t *testing.T) {
|
||||
// Env vars temizlenerek varsayılan değerler test edilir.
|
||||
t.Setenv("RL_BOOTSTRAP_LOGIN_MAX_REQUESTS", "")
|
||||
t.Setenv("RL_BOOTSTRAP_LOGIN_WINDOW_SECONDS", "")
|
||||
|
||||
rules := bootstrapRateLimitRules()
|
||||
loginRule := rules[0]
|
||||
|
||||
if loginRule.MaxRequests != 10 {
|
||||
t.Fatalf("expected default MaxRequests=10, got %d", loginRule.MaxRequests)
|
||||
}
|
||||
if loginRule.WindowSeconds != 60 {
|
||||
t.Fatalf("expected default WindowSeconds=60, got %d", loginRule.WindowSeconds)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBootstrapRateLimitRules_EnvOverride(t *testing.T) {
|
||||
t.Setenv("RL_BOOTSTRAP_LOGIN_MAX_REQUESTS", "25")
|
||||
t.Setenv("RL_BOOTSTRAP_LOGIN_WINDOW_SECONDS", "120")
|
||||
|
||||
rules := bootstrapRateLimitRules()
|
||||
loginRule := rules[0]
|
||||
|
||||
if loginRule.MaxRequests != 25 {
|
||||
t.Fatalf("expected MaxRequests=25, got %d", loginRule.MaxRequests)
|
||||
}
|
||||
if loginRule.WindowSeconds != 120 {
|
||||
t.Fatalf("expected WindowSeconds=120, got %d", loginRule.WindowSeconds)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBootstrapRateLimitRules_AllNamesNonEmpty(t *testing.T) {
|
||||
for _, r := range bootstrapRateLimitRules() {
|
||||
if r.Name == "" {
|
||||
t.Errorf("rate limit rule has empty Name: %+v", r)
|
||||
}
|
||||
if r.Description == "" {
|
||||
t.Errorf("rate limit rule has empty Description: %+v", r)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ─── bootstrapWhitelistOrigins ───────────────────────────────────────────────
|
||||
|
||||
func TestBootstrapWhitelistOrigins_ContainsLocalDefaults(t *testing.T) {
|
||||
t.Setenv("CORS_BOOTSTRAP_WHITELIST_ORIGINS", "")
|
||||
t.Setenv("APP_BASE_URL", "")
|
||||
|
||||
origins := bootstrapWhitelistOrigins()
|
||||
required := []string{
|
||||
"http://localhost:8080",
|
||||
"http://localhost:3000",
|
||||
"http://localhost:5173",
|
||||
}
|
||||
|
||||
originSet := make(map[string]struct{}, len(origins))
|
||||
for _, o := range origins {
|
||||
originSet[o] = struct{}{}
|
||||
}
|
||||
|
||||
for _, want := range required {
|
||||
if _, ok := originSet[want]; !ok {
|
||||
t.Errorf("expected default origin %q to be present", want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBootstrapWhitelistOrigins_NoDuplicates(t *testing.T) {
|
||||
t.Setenv("CORS_BOOTSTRAP_WHITELIST_ORIGINS", "http://localhost:3000,http://localhost:3000")
|
||||
|
||||
origins := bootstrapWhitelistOrigins()
|
||||
seen := map[string]int{}
|
||||
for _, o := range origins {
|
||||
seen[o]++
|
||||
}
|
||||
for origin, count := range seen {
|
||||
if count > 1 {
|
||||
t.Errorf("duplicate origin in whitelist: %q (x%d)", origin, count)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBootstrapWhitelistOrigins_AppBaseURLIncluded(t *testing.T) {
|
||||
t.Setenv("APP_BASE_URL", "https://myapp.example.com")
|
||||
defer t.Setenv("APP_BASE_URL", "")
|
||||
|
||||
origins := bootstrapWhitelistOrigins()
|
||||
for _, o := range origins {
|
||||
if o == "https://myapp.example.com" {
|
||||
return
|
||||
}
|
||||
}
|
||||
t.Fatal("APP_BASE_URL origin not found in whitelist")
|
||||
}
|
||||
193
configs/db.go
Normal file
193
configs/db.go
Normal file
@@ -0,0 +1,193 @@
|
||||
package configs
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/url"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"gorm.io/driver/mysql"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
var DB *gorm.DB
|
||||
|
||||
// ConnectDB opens a MySQL connection via GORM.
|
||||
func ConnectDB() error {
|
||||
dsn := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?charset=utf8mb4&parseTime=True&loc=Local",
|
||||
os.Getenv("DB_USER"),
|
||||
os.Getenv("DB_PASSWORD"),
|
||||
os.Getenv("DB_HOST"),
|
||||
os.Getenv("DB_PORT"),
|
||||
os.Getenv("DB_NAME"),
|
||||
)
|
||||
|
||||
db, err := gorm.Open(mysql.Open(dsn), &gorm.Config{})
|
||||
if err != nil {
|
||||
return fmt.Errorf("database connection failed: %w", err)
|
||||
}
|
||||
|
||||
DB = db
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
|
||||
// SeedSecurityDefaults inserts known origins into CORS tables if missing.
|
||||
// Existing rows are left untouched.
|
||||
/*func SeedSecurityDefaults() error {
|
||||
if DB == nil {
|
||||
return fmt.Errorf("database is not connected")
|
||||
}
|
||||
|
||||
for _, origin := range bootstrapWhitelistOrigins() {
|
||||
item := settingsModels.CorsWhitelist{
|
||||
Origin: origin,
|
||||
IsActive: true,
|
||||
}
|
||||
if err := DB.Where("origin = ?", origin).FirstOrCreate(&item).Error; err != nil {
|
||||
return fmt.Errorf("seed cors whitelist (%s): %w", origin, err)
|
||||
}
|
||||
}
|
||||
|
||||
for _, origin := range parseOriginList(os.Getenv("CORS_BOOTSTRAP_BLACKLIST_ORIGINS")) {
|
||||
item := settingsModels.CorsBlacklist{
|
||||
Origin: origin,
|
||||
IsActive: true,
|
||||
}
|
||||
if err := DB.Where("origin = ?", origin).FirstOrCreate(&item).Error; err != nil {
|
||||
return fmt.Errorf("seed cors blacklist (%s): %w", origin, err)
|
||||
}
|
||||
}
|
||||
|
||||
for _, rule := range bootstrapRateLimitRules() {
|
||||
item := settingsModels.RateLimitSetting{
|
||||
Name: rule.Name,
|
||||
Description: rule.Description,
|
||||
MaxRequests: rule.MaxRequests,
|
||||
WindowSeconds: rule.WindowSeconds,
|
||||
IsActive: true,
|
||||
}
|
||||
if err := DB.Where("name = ?", rule.Name).FirstOrCreate(&item).Error; err != nil {
|
||||
return fmt.Errorf("seed rate limit (%s): %w", rule.Name, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}*/
|
||||
|
||||
type rateLimitSeedRule struct {
|
||||
Name string
|
||||
Description string
|
||||
MaxRequests int64
|
||||
WindowSeconds int
|
||||
}
|
||||
|
||||
func bootstrapRateLimitRules() []rateLimitSeedRule {
|
||||
return []rateLimitSeedRule{
|
||||
{
|
||||
Name: "api/v1/auth/login",
|
||||
Description: "Bootstrap login rate limit",
|
||||
MaxRequests: envInt64Or("RL_BOOTSTRAP_LOGIN_MAX_REQUESTS", 10),
|
||||
WindowSeconds: envIntOr("RL_BOOTSTRAP_LOGIN_WINDOW_SECONDS", 60),
|
||||
},
|
||||
{
|
||||
Name: "api/v1/auth/register",
|
||||
Description: "Bootstrap register rate limit",
|
||||
MaxRequests: envInt64Or("RL_BOOTSTRAP_REGISTER_MAX_REQUESTS", 5),
|
||||
WindowSeconds: envIntOr("RL_BOOTSTRAP_REGISTER_WINDOW_SECONDS", 60),
|
||||
},
|
||||
{
|
||||
Name: "api",
|
||||
Description: "Bootstrap default API rate limit",
|
||||
MaxRequests: envInt64Or("RL_BOOTSTRAP_API_MAX_REQUESTS", 120),
|
||||
WindowSeconds: envIntOr("RL_BOOTSTRAP_API_WINDOW_SECONDS", 60),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func bootstrapWhitelistOrigins() []string {
|
||||
uniq := map[string]struct{}{}
|
||||
var out []string
|
||||
add := func(v string) {
|
||||
origin := normalizeOrigin(v)
|
||||
if origin == "" {
|
||||
return
|
||||
}
|
||||
if _, ok := uniq[origin]; ok {
|
||||
return
|
||||
}
|
||||
uniq[origin] = struct{}{}
|
||||
out = append(out, origin)
|
||||
}
|
||||
|
||||
// Explicit bootstrap origins from env (comma separated).
|
||||
for _, v := range parseOriginList(os.Getenv("CORS_BOOTSTRAP_WHITELIST_ORIGINS")) {
|
||||
add(v)
|
||||
}
|
||||
|
||||
// Derive common origins from existing app URLs.
|
||||
add(os.Getenv("APP_BASE_URL"))
|
||||
add(os.Getenv("SOCIAL_AUTH_GOOGLE_REDIRECT_URL"))
|
||||
add(os.Getenv("SOCIAL_AUTH_GITHUB_REDIRECT_URL"))
|
||||
|
||||
// Safe local defaults for development.
|
||||
add("http://localhost:8080")
|
||||
add("http://localhost:3000")
|
||||
add("http://localhost:5173")
|
||||
add("http://127.0.0.1:8080")
|
||||
add("http://127.0.0.1:3000")
|
||||
add("http://127.0.0.1:5173")
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
func parseOriginList(raw string) []string {
|
||||
parts := strings.Split(raw, ",")
|
||||
out := make([]string, 0, len(parts))
|
||||
for _, p := range parts {
|
||||
v := normalizeOrigin(p)
|
||||
if v != "" {
|
||||
out = append(out, v)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func normalizeOrigin(v string) string {
|
||||
s := strings.TrimSpace(strings.Trim(v, `'"`))
|
||||
if s == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
u, err := url.Parse(s)
|
||||
if err != nil || u.Scheme == "" || u.Host == "" {
|
||||
return ""
|
||||
}
|
||||
return strings.ToLower(u.Scheme + "://" + u.Host)
|
||||
}
|
||||
|
||||
func envIntOr(key string, fallback int) int {
|
||||
raw := strings.TrimSpace(os.Getenv(key))
|
||||
if raw == "" {
|
||||
return fallback
|
||||
}
|
||||
n, err := strconv.Atoi(raw)
|
||||
if err != nil || n < 1 {
|
||||
return fallback
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
func envInt64Or(key string, fallback int64) int64 {
|
||||
raw := strings.TrimSpace(os.Getenv(key))
|
||||
if raw == "" {
|
||||
return fallback
|
||||
}
|
||||
n, err := strconv.ParseInt(raw, 10, 64)
|
||||
if err != nil || n < 1 {
|
||||
return fallback
|
||||
}
|
||||
return n
|
||||
}
|
||||
47
configs/redis.go
Normal file
47
configs/redis.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package configs
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
var RedisClient *redis.Client
|
||||
var Ctx = context.Background()
|
||||
|
||||
// ConnectRedis initializes the Redis connection
|
||||
func ConnectRedis() error {
|
||||
redisURL := os.Getenv("REDIS_URL")
|
||||
if redisURL != "" {
|
||||
opt, err := redis.ParseURL(redisURL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse redis url: %w", err)
|
||||
}
|
||||
RedisClient = redis.NewClient(opt)
|
||||
} else {
|
||||
host := os.Getenv("REDIS_HOST")
|
||||
if host == "" {
|
||||
host = "localhost"
|
||||
}
|
||||
port := os.Getenv("REDIS_PORT")
|
||||
if port == "" {
|
||||
port = "6379"
|
||||
}
|
||||
pass := os.Getenv("REDIS_PASSWORD")
|
||||
|
||||
RedisClient = redis.NewClient(&redis.Options{
|
||||
Addr: fmt.Sprintf("%s:%s", host, port),
|
||||
Password: pass,
|
||||
DB: 0, // Default DB
|
||||
})
|
||||
}
|
||||
|
||||
_, err := RedisClient.Ping(Ctx).Result()
|
||||
if err != nil {
|
||||
return fmt.Errorf("redis connection failed: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
Reference in New Issue
Block a user