Files
goaresv3/config/database.go
Beyhan Oğur b6e74bd024 first commit
2026-04-26 21:41:46 +03:00

231 lines
5.5 KiB
Go

package config
import (
"fmt"
"net/url"
"os"
"strconv"
"strings"
accountModels "goaresv3/app/accounts/models"
blogModels "goaresv3/app/blog/models"
settingsModels "goaresv3/app/settings/models"
shopModels "goaresv3/app/shop/models"
"gorm.io/driver/mysql"
"gorm.io/gorm"
)
// DB is the package-wide GORM handle. It stays nil until ConnectDB succeeds;
// RunAutoMigrate and SeedSecurityDefaults return an error while it is nil.
var DB *gorm.DB
// ConnectDB opens a MySQL connection via GORM.
func ConnectDB() error {
dsn := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?charset=utf8mb4&parseTime=True&loc=Local",
os.Getenv("DB_USER"),
os.Getenv("DB_PASSWORD"),
os.Getenv("DB_HOST"),
os.Getenv("DB_PORT"),
os.Getenv("DB_NAME"),
)
db, err := gorm.Open(mysql.Open(dsn), &gorm.Config{})
if err != nil {
return fmt.Errorf("database connection failed: %w", err)
}
DB = db
return nil
}
// RunAutoMigrate runs GORM's AutoMigrate over every model registered here:
// the account, settings, shop and blog models.
//
// It returns an error when DB has not been set by ConnectDB, or a wrapped
// error when the migration itself fails.
func RunAutoMigrate() error {
	if DB == nil {
		return fmt.Errorf("database is not connected")
	}

	// The models are handed to AutoMigrate in exactly this order.
	models := []interface{}{
		// accounts
		&accountModels.User{},
		&accountModels.SocialAccount{},
		&accountModels.Profile{},
		// settings
		&settingsModels.Setting{},
		&settingsModels.Hero{},
		&settingsModels.CorsWhitelist{},
		&settingsModels.CorsBlacklist{},
		&settingsModels.RateLimitSetting{},
		// shop
		&shopModels.ProductCategory{},
		&shopModels.ProductTag{},
		&shopModels.Product{},
		&shopModels.ProductCategoryView{},
		&shopModels.ProductComment{},
		&shopModels.Cart{},
		&shopModels.CartItem{},
		// blog
		&blogModels.Category{},
		&blogModels.Tag{},
		&blogModels.Post{},
		&blogModels.CategoryView{},
		&blogModels.Comment{},
	}

	migrateErr := DB.AutoMigrate(models...)
	if migrateErr == nil {
		return nil
	}
	return fmt.Errorf("auto-migrate failed: %w", migrateErr)
}
// SeedSecurityDefaults makes sure the bootstrap security rows exist: CORS
// whitelist origins, CORS blacklist origins (taken from
// CORS_BOOTSTRAP_BLACKLIST_ORIGINS), and the bootstrap rate-limit rules.
// Each row is looked up by its key (origin or name) and created only when no
// match is found; rows that already exist are left untouched.
//
// It returns an error when DB has not been set by ConnectDB, or the first
// database error hit while seeding, wrapped with the offending key.
func SeedSecurityDefaults() error {
	if DB == nil {
		return fmt.Errorf("database is not connected")
	}

	allowed := bootstrapWhitelistOrigins()
	for i := range allowed {
		row := settingsModels.CorsWhitelist{IsActive: true, Origin: allowed[i]}
		if res := DB.Where("origin = ?", allowed[i]).FirstOrCreate(&row); res.Error != nil {
			return fmt.Errorf("seed cors whitelist (%s): %w", allowed[i], res.Error)
		}
	}

	blocked := parseOriginList(os.Getenv("CORS_BOOTSTRAP_BLACKLIST_ORIGINS"))
	for i := range blocked {
		row := settingsModels.CorsBlacklist{IsActive: true, Origin: blocked[i]}
		if res := DB.Where("origin = ?", blocked[i]).FirstOrCreate(&row); res.Error != nil {
			return fmt.Errorf("seed cors blacklist (%s): %w", blocked[i], res.Error)
		}
	}

	rules := bootstrapRateLimitRules()
	for i := range rules {
		seed := rules[i]
		row := settingsModels.RateLimitSetting{
			Name:          seed.Name,
			Description:   seed.Description,
			MaxRequests:   seed.MaxRequests,
			WindowSeconds: seed.WindowSeconds,
			IsActive:      true,
		}
		if res := DB.Where("name = ?", seed.Name).FirstOrCreate(&row); res.Error != nil {
			return fmt.Errorf("seed rate limit (%s): %w", seed.Name, res.Error)
		}
	}

	return nil
}
// rateLimitSeedRule describes one bootstrap rate-limit entry. The values are
// produced by bootstrapRateLimitRules and copied field-for-field into a
// settings RateLimitSetting row by SeedSecurityDefaults.
type rateLimitSeedRule struct {
// Name is the key used to look up an existing row ("name = ?").
// NOTE(review): the seeded names look like route prefixes — confirm how the
// rate limiter matches them.
Name string
// Description is free-form text stored with the rule.
Description string
// MaxRequests is copied into the seeded row; presumably the request cap per
// window — confirm against the limiter.
MaxRequests int64
// WindowSeconds is copied into the seeded row; presumably the window length
// in seconds — confirm against the limiter.
WindowSeconds int
}
// bootstrapRateLimitRules returns the three bootstrap rate-limit rules, in
// order: login, register, and the general API rule. Each limit and window is
// read from its RL_BOOTSTRAP_* environment variable and falls back to the
// default shown here (10/60, 5/60 and 120/60) when the variable is empty,
// not a number, or less than 1.
func bootstrapRateLimitRules() []rateLimitSeedRule {
	login := rateLimitSeedRule{
		Name:          "api/v1/auth/login",
		Description:   "Bootstrap login rate limit",
		MaxRequests:   envInt64Or("RL_BOOTSTRAP_LOGIN_MAX_REQUESTS", 10),
		WindowSeconds: envIntOr("RL_BOOTSTRAP_LOGIN_WINDOW_SECONDS", 60),
	}

	register := rateLimitSeedRule{
		Name:          "api/v1/auth/register",
		Description:   "Bootstrap register rate limit",
		MaxRequests:   envInt64Or("RL_BOOTSTRAP_REGISTER_MAX_REQUESTS", 5),
		WindowSeconds: envIntOr("RL_BOOTSTRAP_REGISTER_WINDOW_SECONDS", 60),
	}

	api := rateLimitSeedRule{
		Name:          "api",
		Description:   "Bootstrap default API rate limit",
		MaxRequests:   envInt64Or("RL_BOOTSTRAP_API_MAX_REQUESTS", 120),
		WindowSeconds: envIntOr("RL_BOOTSTRAP_API_WINDOW_SECONDS", 60),
	}

	return []rateLimitSeedRule{login, register, api}
}
func bootstrapWhitelistOrigins() []string {
uniq := map[string]struct{}{}
var out []string
add := func(v string) {
origin := normalizeOrigin(v)
if origin == "" {
return
}
if _, ok := uniq[origin]; ok {
return
}
uniq[origin] = struct{}{}
out = append(out, origin)
}
// Explicit bootstrap origins from env (comma separated).
for _, v := range parseOriginList(os.Getenv("CORS_BOOTSTRAP_WHITELIST_ORIGINS")) {
add(v)
}
// Derive common origins from existing app URLs.
add(os.Getenv("APP_BASE_URL"))
add(os.Getenv("SOCIAL_AUTH_GOOGLE_REDIRECT_URL"))
add(os.Getenv("SOCIAL_AUTH_GITHUB_REDIRECT_URL"))
// Safe local defaults for development.
add("http://localhost:8080")
add("http://localhost:3000")
add("http://localhost:5173")
add("http://127.0.0.1:8080")
add("http://127.0.0.1:3000")
add("http://127.0.0.1:5173")
return out
}
// parseOriginList splits a comma-separated string into normalized origins.
// Items that normalize to "" (empty, unparsable, or lacking a scheme or host)
// are skipped. Duplicates are kept and input order is preserved. The result
// is never nil, even when raw is empty.
func parseOriginList(raw string) []string {
	fields := strings.Split(raw, ",")
	origins := make([]string, 0, len(fields))
	for i := range fields {
		origin := normalizeOrigin(fields[i])
		if origin == "" {
			continue
		}
		origins = append(origins, origin)
	}
	return origins
}
// normalizeOrigin reduces a URL-like value to a lowercase
// "scheme://host[:port]" origin. Surrounding whitespace and single or double
// quotes are stripped first, whichever is outermost. Path, query, fragment
// and userinfo are discarded.
//
// It returns "" when v is empty after trimming, cannot be parsed as a URL, or
// has no scheme or no host.
func normalizeOrigin(v string) string {
	// Trim whitespace on both sides of the quote trim. Stripping quotes first
	// leaves them in place when whitespace sits outside them (e.g. the second
	// item of `'http://a', 'http://b'`), and the entry is then silently
	// rejected by url.Parse.
	s := strings.TrimSpace(strings.Trim(strings.TrimSpace(v), `'"`))
	if s == "" {
		return ""
	}

	u, err := url.Parse(s)
	if err != nil || u.Scheme == "" || u.Host == "" {
		return ""
	}

	return strings.ToLower(u.Scheme + "://" + u.Host)
}
// envIntOr reads the environment variable key as a positive int. The value is
// whitespace-trimmed before parsing. fallback is returned when the variable
// is unset or empty, is not a valid integer, or is less than 1.
func envIntOr(key string, fallback int) int {
	// Atoi rejects the empty string, so unset and blank values land on the
	// fallback path together with malformed ones.
	value, err := strconv.Atoi(strings.TrimSpace(os.Getenv(key)))
	if err == nil && value >= 1 {
		return value
	}
	return fallback
}
// envInt64Or reads the environment variable key as a positive base-10 int64.
// The value is whitespace-trimmed before parsing. fallback is returned when
// the variable is unset or empty, is not a valid 64-bit integer, or is less
// than 1.
func envInt64Or(key string, fallback int64) int64 {
	// ParseInt rejects the empty string, so unset and blank values land on
	// the fallback path together with malformed or out-of-range ones.
	value, err := strconv.ParseInt(strings.TrimSpace(os.Getenv(key)), 10, 64)
	if err == nil && value >= 1 {
		return value
	}
	return fallback
}