package configs

import (
	"fmt"
	"net"
	"net/url"
	"os"
	"strconv"
	"strings"

	"gorm.io/driver/mysql"
	"gorm.io/gorm"
)

// DB is the shared GORM handle. It is nil until ConnectDB succeeds.
var DB *gorm.DB

// ConnectDB opens a MySQL connection via GORM and stores the handle in DB.
//
// Connection parameters come from the DB_USER, DB_PASSWORD, DB_HOST, DB_PORT
// and DB_NAME environment variables. On failure DB is left unchanged and the
// gorm.Open error is returned wrapped.
func ConnectDB() error {
	// net.JoinHostPort brackets IPv6 literals ("::1" -> "[::1]:3306"); a plain
	// "%s:%s" would yield the unparseable "::1:3306". Brackets an operator has
	// already written in DB_HOST are stripped first so they are not doubled.
	host := strings.TrimSuffix(strings.TrimPrefix(os.Getenv("DB_HOST"), "["), "]")
	addr := net.JoinHostPort(host, os.Getenv("DB_PORT"))

	dsn := fmt.Sprintf("%s:%s@tcp(%s)/%s?charset=utf8mb4&parseTime=True&loc=Local",
		os.Getenv("DB_USER"),
		os.Getenv("DB_PASSWORD"),
		addr,
		os.Getenv("DB_NAME"),
	)

	db, err := gorm.Open(mysql.Open(dsn), &gorm.Config{})
	if err != nil {
		return fmt.Errorf("database connection failed: %w", err)
	}
	DB = db
	return nil
}

// SeedSecurityDefaults inserts bootstrap CORS whitelist/blacklist origins and
// bootstrap rate-limit rules when they are missing. Existing rows are left
// untouched (FirstOrCreate keyed on origin / name).
//
// NOTE(review): disabled — the body references settingsModels, which this file
// does not import. Re-enable together with that import.
/*func SeedSecurityDefaults() error {
	if DB == nil {
		return fmt.Errorf("database is not connected")
	}

	for _, origin := range bootstrapWhitelistOrigins() {
		item := settingsModels.CorsWhitelist{
			Origin:   origin,
			IsActive: true,
		}
		if err := DB.Where("origin = ?", origin).FirstOrCreate(&item).Error; err != nil {
			return fmt.Errorf("seed cors whitelist (%s): %w", origin, err)
		}
	}

	for _, origin := range parseOriginList(os.Getenv("CORS_BOOTSTRAP_BLACKLIST_ORIGINS")) {
		item := settingsModels.CorsBlacklist{
			Origin:   origin,
			IsActive: true,
		}
		if err := DB.Where("origin = ?", origin).FirstOrCreate(&item).Error; err != nil {
			return fmt.Errorf("seed cors blacklist (%s): %w", origin, err)
		}
	}

	for _, rule := range bootstrapRateLimitRules() {
		item := settingsModels.RateLimitSetting{
			Name:          rule.Name,
			Description:   rule.Description,
			MaxRequests:   rule.MaxRequests,
			WindowSeconds: rule.WindowSeconds,
			IsActive:      true,
		}
		if err := DB.Where("name = ?", rule.Name).FirstOrCreate(&item).Error; err != nil {
			return fmt.Errorf("seed rate limit (%s): %w", rule.Name, err)
		}
	}

	return nil
}*/

// rateLimitSeedRule describes one bootstrap rate-limit row.
type rateLimitSeedRule struct {
	Name          string // route key, e.g. "api/v1/auth/login"
	Description   string
	MaxRequests   int64 // presumably the request cap per window — confirm with the limiter
	WindowSeconds int   // window length in seconds (per the field name)
}

// bootstrapRateLimitRules returns the built-in login, register and default API
// rate-limit rules. Each limit can be overridden via its RL_BOOTSTRAP_* env
// variable; unset, non-numeric, or < 1 values fall back to the default shown.
func bootstrapRateLimitRules() []rateLimitSeedRule {
	return []rateLimitSeedRule{
		{
			Name:          "api/v1/auth/login",
			Description:   "Bootstrap login rate limit",
			MaxRequests:   envInt64Or("RL_BOOTSTRAP_LOGIN_MAX_REQUESTS", 10),
			WindowSeconds: envIntOr("RL_BOOTSTRAP_LOGIN_WINDOW_SECONDS", 60),
		},
		{
			Name:          "api/v1/auth/register",
			Description:   "Bootstrap register rate limit",
			MaxRequests:   envInt64Or("RL_BOOTSTRAP_REGISTER_MAX_REQUESTS", 5),
			WindowSeconds: envIntOr("RL_BOOTSTRAP_REGISTER_WINDOW_SECONDS", 60),
		},
		{
			Name:          "api",
			Description:   "Bootstrap default API rate limit",
			MaxRequests:   envInt64Or("RL_BOOTSTRAP_API_MAX_REQUESTS", 120),
			WindowSeconds: envIntOr("RL_BOOTSTRAP_API_WINDOW_SECONDS", 60),
		},
	}
}

// bootstrapWhitelistOrigins returns the deduplicated, normalized list of
// origins to seed into the CORS whitelist, in first-seen order:
// CORS_BOOTSTRAP_WHITELIST_ORIGINS, then origins derived from APP_BASE_URL and
// the social-auth redirect URLs, then fixed localhost development origins.
func bootstrapWhitelistOrigins() []string {
	uniq := map[string]struct{}{}
	var out []string

	add := func(v string) {
		origin := normalizeOrigin(v)
		if origin == "" {
			return
		}
		if _, ok := uniq[origin]; ok {
			return
		}
		uniq[origin] = struct{}{}
		out = append(out, origin)
	}

	// Explicit bootstrap origins from env (comma separated).
	for _, v := range parseOriginList(os.Getenv("CORS_BOOTSTRAP_WHITELIST_ORIGINS")) {
		add(v)
	}

	// Derive common origins from existing app URLs.
	add(os.Getenv("APP_BASE_URL"))
	add(os.Getenv("SOCIAL_AUTH_GOOGLE_REDIRECT_URL"))
	add(os.Getenv("SOCIAL_AUTH_GITHUB_REDIRECT_URL"))

	// Local defaults for development.
	// NOTE(review): these are added unconditionally, so production deployments
	// also whitelist localhost origins — consider gating on an environment flag.
	add("http://localhost:8080")
	add("http://localhost:3000")
	add("http://localhost:5173")
	add("http://127.0.0.1:8080")
	add("http://127.0.0.1:3000")
	add("http://127.0.0.1:5173")

	return out
}

// parseOriginList splits a comma-separated list and returns the entries that
// normalize to a valid origin, in input order. Duplicates are kept.
func parseOriginList(raw string) []string {
	parts := strings.Split(raw, ",")
	out := make([]string, 0, len(parts))
	for _, p := range parts {
		if v := normalizeOrigin(p); v != "" {
			out = append(out, v)
		}
	}
	return out
}

// normalizeOrigin reduces a URL-ish string to a lowercase "scheme://host[:port]"
// origin, or returns "" when it has no scheme or host. Surrounding whitespace
// and single/double quotes are ignored; path, query and userinfo are dropped.
// A scheme's default port (http:80, https:443) is omitted, matching how
// browsers serialize the Origin header.
func normalizeOrigin(v string) string {
	// Trim whitespace both before and after stripping quotes so entries such as
	// ` "https://a.example"` (space outside the quotes) are still accepted.
	s := strings.TrimSpace(strings.Trim(strings.TrimSpace(v), `'"`))
	if s == "" {
		return ""
	}

	u, err := url.Parse(s)
	if err != nil || u.Scheme == "" || u.Host == "" {
		return ""
	}

	scheme := strings.ToLower(u.Scheme)
	host := strings.ToLower(u.Host)

	// Drop a default port, or a dangling ":" when the port is empty.
	switch p := u.Port(); {
	case p == "", scheme == "http" && p == "80", scheme == "https" && p == "443":
		host = strings.TrimSuffix(host, ":"+p)
	}

	return scheme + "://" + host
}

// envIntOr returns the positive integer in env var key, or fallback when the
// variable is unset, blank, not an integer, or < 1.
func envIntOr(key string, fallback int) int {
	raw := strings.TrimSpace(os.Getenv(key))
	if raw == "" {
		return fallback
	}
	n, err := strconv.Atoi(raw)
	if err != nil || n < 1 {
		return fallback
	}
	return n
}

// envInt64Or is the int64 counterpart of envIntOr: it returns the positive
// integer in env var key, or fallback when unset, blank, invalid, or < 1.
func envInt64Or(key string, fallback int64) int64 {
	raw := strings.TrimSpace(os.Getenv(key))
	if raw == "" {
		return fallback
	}
	n, err := strconv.ParseInt(raw, 10, 64)
	if err != nil || n < 1 {
		return fallback
	}
	return n
}