package bedrock import ( "context" "crypto/hmac" "crypto/sha256" "encoding/hex" "fmt" "sort" "strconv" "strings" "sync" "time" "github.com/aws/aws-sdk-go-v2/config" "github.com/aws/smithy-go/encoding/httpbinding" providerUtils "github.com/maximhq/bifrost/core/providers/utils" schemas "github.com/maximhq/bifrost/core/schemas" "github.com/valyala/fasthttp" ) const ( signingAlgorithm = "AWS4-HMAC-SHA256" amzDateKey = "X-Amz-Date" amzSecurityToken = "X-Amz-Security-Token" timeFormat = "20060102T150405Z" shortTimeFormat = "20060102" ) // Headers to ignore during signing var ignoredHeaders = map[string]struct{}{ "authorization": {}, "user-agent": {}, "x-amzn-trace-id": {}, "expect": {}, "transfer-encoding": {}, } // signingKeyCache caches derived signing keys to avoid recomputation type signingKeyCache struct { cache map[string]cachedKey mu sync.RWMutex } type cachedKey struct { key []byte date string // YYYYMMDD format accessKey string } var keyCache = &signingKeyCache{ cache: make(map[string]cachedKey), } // hmacSHA256 computes HMAC-SHA256 func hmacSHA256(key, data []byte) []byte { h := hmac.New(sha256.New, key) h.Write(data) return h.Sum(nil) } // deriveSigningKey derives the AWS signing key func deriveSigningKey(secret, dateStamp, region, service string) []byte { kDate := hmacSHA256([]byte("AWS4"+secret), []byte(dateStamp)) kRegion := hmacSHA256(kDate, []byte(region)) kService := hmacSHA256(kRegion, []byte(service)) kSigning := hmacSHA256(kService, []byte("aws4_request")) return kSigning } // getSigningKey retrieves or computes the signing key with caching func getSigningKey(accessKey, secretKey, dateStamp, region, service string) []byte { cacheKey := fmt.Sprintf("%s/%s/%s/%s", accessKey, dateStamp, region, service) keyCache.mu.RLock() if cached, ok := keyCache.cache[cacheKey]; ok && cached.accessKey == accessKey && cached.date == dateStamp { keyCache.mu.RUnlock() return cached.key } keyCache.mu.RUnlock() keyCache.mu.Lock() defer keyCache.mu.Unlock() // Double-check after acquiring write lock if cached, ok := keyCache.cache[cacheKey]; ok && cached.accessKey == accessKey && cached.date == dateStamp { return cached.key } key := deriveSigningKey(secretKey, dateStamp, region, service) keyCache.cache[cacheKey] = cachedKey{ key: key, date: dateStamp, accessKey: accessKey, } return key } // stripExcessSpaces removes excess spaces from a string func stripExcessSpaces(str string) string { str = strings.TrimSpace(str) if !strings.Contains(str, " ") { return str } var result strings.Builder result.Grow(len(str)) prevWasSpace := false for _, ch := range str { if ch == ' ' { if !prevWasSpace { result.WriteRune(ch) } prevWasSpace = true } else { result.WriteRune(ch) prevWasSpace = false } } return result.String() } // percentEncodeRFC3986 encodes a string per RFC 3986 // Keep unreserved characters (A-Z, a-z, 0-9, -, _, ., ~) as-is // Percent-encode everything else as %HH using uppercase hex func percentEncodeRFC3986(s string) string { var result strings.Builder result.Grow(len(s)) for i := 0; i < len(s); i++ { b := s[i] // RFC 3986 unreserved characters if (b >= 'A' && b <= 'Z') || (b >= 'a' && b <= 'z') || (b >= '0' && b <= '9') || b == '-' || b == '_' || b == '.' || b == '~' { result.WriteByte(b) } else { // Percent-encode with uppercase hex result.WriteByte('%') result.WriteByte(uppercaseHex(b >> 4)) result.WriteByte(uppercaseHex(b & 0x0F)) } } return result.String() } // uppercaseHex returns the uppercase hex character for a nibble (0-15) func uppercaseHex(b byte) byte { if b < 10 { return '0' + b } return 'A' + (b - 10) } // percentDecode decodes percent-encoded sequences in a string without treating + as space // This differs from url.QueryUnescape which uses form encoding (+ becomes space) func percentDecode(s string) string { // Quick check if there are any percent signs if !strings.Contains(s, "%") { return s } var result strings.Builder result.Grow(len(s)) for i := 0; i < len(s); { if s[i] == '%' && i+2 < len(s) { // Try to decode the hex sequence if h1 := unhex(s[i+1]); h1 >= 0 { if h2 := unhex(s[i+2]); h2 >= 0 { result.WriteByte(byte(h1<<4 | h2)) i += 3 continue } } } result.WriteByte(s[i]) i++ } return result.String() } // unhex converts a hex character to its value, or -1 if not a hex char func unhex(c byte) int { switch { case '0' <= c && c <= '9': return int(c - '0') case 'a' <= c && c <= 'f': return int(c - 'a' + 10) case 'A' <= c && c <= 'F': return int(c - 'A' + 10) } return -1 } // queryPair represents a query parameter name-value pair type queryPair struct { encodedName string encodedValue string } // buildCanonicalQueryString builds a canonical query string per AWS SigV4 spec // using proper RFC 3986 percent-encoding func buildCanonicalQueryString(queryString string) string { if queryString == "" { return "" } // Split the raw query string on '&' into pairs rawPairs := strings.Split(queryString, "&") pairs := make([]queryPair, 0, len(rawPairs)) for _, rawPair := range rawPairs { if rawPair == "" { continue } // Split on the first '=' to get name and value var name, value string if idx := strings.IndexByte(rawPair, '='); idx >= 0 { name = rawPair[:idx] value = rawPair[idx+1:] } else { // No '=' means name only, empty value name = rawPair value = "" } // Decode percent-encoded sequences first to normalize (handles already-encoded values) // then encode per RFC 3986 to ensure consistent encoding // Note: We use percentDecode instead of url.QueryUnescape because the latter // treats + as space (form encoding), but we need + to encode as %2B decodedName := percentDecode(name) decodedValue := percentDecode(value) // Percent-encode name and value per RFC 3986 encodedName := percentEncodeRFC3986(decodedName) encodedValue := percentEncodeRFC3986(decodedValue) pairs = append(pairs, queryPair{ encodedName: encodedName, encodedValue: encodedValue, }) } // Sort pairs lexicographically by encoded name, then by encoded value sort.Slice(pairs, func(i, j int) bool { if pairs[i].encodedName != pairs[j].encodedName { return pairs[i].encodedName < pairs[j].encodedName } return pairs[i].encodedValue < pairs[j].encodedValue }) // Join encoded pairs with '&' var result strings.Builder for i, pair := range pairs { if i > 0 { result.WriteByte('&') } result.WriteString(pair.encodedName) result.WriteByte('=') result.WriteString(pair.encodedValue) } return result.String() } // signAWSRequestFastHTTP signs a fasthttp request using AWS Signature Version 4 // This is a native implementation that avoids allocating http.Request func signAWSRequestFastHTTP( ctx context.Context, req *fasthttp.Request, body []byte, accessKey, secretKey string, sessionToken *string, region, service string, ) *schemas.BifrostError { // Get AWS credentials if not provided if accessKey == "" && secretKey == "" { cfg, err := config.LoadDefaultConfig(ctx, config.WithRegion(region)) if err != nil { return providerUtils.NewBifrostOperationError("failed to load aws config", err) } creds, err := cfg.Credentials.Retrieve(ctx) if err != nil { return providerUtils.NewBifrostOperationError("failed to retrieve aws credentials", err) } accessKey = creds.AccessKeyID secretKey = creds.SecretAccessKey if creds.SessionToken != "" { st := creds.SessionToken sessionToken = &st } } // Get current time now := time.Now().UTC() amzDate := now.Format(timeFormat) dateStamp := now.Format(shortTimeFormat) // Parse URI uri := req.URI() host := string(uri.Host()) path := string(uri.Path()) if path == "" { path = "/" } queryString := string(uri.QueryString()) // Escape path for canonical URI (Bedrock doesn't disable escaping) canonicalURI := httpbinding.EscapePath(path, false) // Calculate payload hash hash := sha256.Sum256(body) payloadHash := hex.EncodeToString(hash[:]) // Set required headers req.Header.Set("Content-Type", "application/json") req.Header.Set("Accept", "application/json") req.Header.Set(amzDateKey, amzDate) if sessionToken != nil && *sessionToken != "" { req.Header.Set(amzSecurityToken, *sessionToken) } // Build canonical headers var headerNames []string headerMap := make(map[string][]string) // Always include host headerNames = append(headerNames, "host") headerMap["host"] = []string{host} // Include content-length if body is present if cl := req.Header.ContentLength(); cl >= 0 { headerNames = append(headerNames, "content-length") headerMap["content-length"] = []string{strconv.Itoa(cl)} } // Collect other headers for key, value := range req.Header.All() { keyStr := strings.ToLower(string(key)) // Skip ignored headers if _, ignore := ignoredHeaders[keyStr]; ignore { continue } // Skip if already handled if keyStr == "host" || keyStr == "content-length" { continue } if _, exists := headerMap[keyStr]; !exists { headerNames = append(headerNames, keyStr) } headerMap[keyStr] = append(headerMap[keyStr], string(value)) } // Sort header names sort.Strings(headerNames) // Build canonical headers string var canonicalHeaders strings.Builder for _, name := range headerNames { canonicalHeaders.WriteString(name) canonicalHeaders.WriteRune(':') values := headerMap[name] for i, v := range values { cleanedValue := stripExcessSpaces(v) canonicalHeaders.WriteString(cleanedValue) if i < len(values)-1 { canonicalHeaders.WriteRune(',') } } canonicalHeaders.WriteRune('\n') } signedHeaders := strings.Join(headerNames, ";") // Build canonical query string using RFC 3986 encoding canonicalQueryString := buildCanonicalQueryString(queryString) // Build canonical request canonicalRequest := strings.Join([]string{ string(req.Header.Method()), canonicalURI, canonicalQueryString, canonicalHeaders.String(), signedHeaders, payloadHash, }, "\n") // Build credential scope credentialScope := strings.Join([]string{ dateStamp, region, service, "aws4_request", }, "/") // Build string to sign canonicalRequestHash := sha256.Sum256([]byte(canonicalRequest)) stringToSign := strings.Join([]string{ signingAlgorithm, amzDate, credentialScope, hex.EncodeToString(canonicalRequestHash[:]), }, "\n") // Calculate signature signingKey := getSigningKey(accessKey, secretKey, dateStamp, region, service) signature := hex.EncodeToString(hmacSHA256(signingKey, []byte(stringToSign))) // Build authorization header authHeader := fmt.Sprintf("%s Credential=%s/%s, SignedHeaders=%s, Signature=%s", signingAlgorithm, accessKey, credentialScope, signedHeaders, signature, ) req.Header.Set("Authorization", authHeader) return nil }