first commit

This commit is contained in:
Beyhan Oğur
2026-04-26 21:52:23 +03:00
commit 880f412e2c
2662 changed files with 866266 additions and 0 deletions

View File

@@ -0,0 +1,433 @@
package bedrock
import (
"context"
"crypto/hmac"
"crypto/sha256"
"encoding/hex"
"fmt"
"sort"
"strconv"
"strings"
"sync"
"time"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/smithy-go/encoding/httpbinding"
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
schemas "github.com/maximhq/bifrost/core/schemas"
"github.com/valyala/fasthttp"
)
const (
signingAlgorithm = "AWS4-HMAC-SHA256"
amzDateKey = "X-Amz-Date"
amzSecurityToken = "X-Amz-Security-Token"
timeFormat = "20060102T150405Z"
shortTimeFormat = "20060102"
)
// Headers to ignore during signing
var ignoredHeaders = map[string]struct{}{
"authorization": {},
"user-agent": {},
"x-amzn-trace-id": {},
"expect": {},
"transfer-encoding": {},
}
// signingKeyCache caches derived signing keys to avoid recomputation
type signingKeyCache struct {
cache map[string]cachedKey
mu sync.RWMutex
}
type cachedKey struct {
key []byte
date string // YYYYMMDD format
accessKey string
}
var keyCache = &signingKeyCache{
cache: make(map[string]cachedKey),
}
// hmacSHA256 computes HMAC-SHA256
func hmacSHA256(key, data []byte) []byte {
h := hmac.New(sha256.New, key)
h.Write(data)
return h.Sum(nil)
}
// deriveSigningKey derives the AWS signing key
func deriveSigningKey(secret, dateStamp, region, service string) []byte {
kDate := hmacSHA256([]byte("AWS4"+secret), []byte(dateStamp))
kRegion := hmacSHA256(kDate, []byte(region))
kService := hmacSHA256(kRegion, []byte(service))
kSigning := hmacSHA256(kService, []byte("aws4_request"))
return kSigning
}
// getSigningKey retrieves or computes the signing key with caching
func getSigningKey(accessKey, secretKey, dateStamp, region, service string) []byte {
cacheKey := fmt.Sprintf("%s/%s/%s/%s", accessKey, dateStamp, region, service)
keyCache.mu.RLock()
if cached, ok := keyCache.cache[cacheKey]; ok && cached.accessKey == accessKey && cached.date == dateStamp {
keyCache.mu.RUnlock()
return cached.key
}
keyCache.mu.RUnlock()
keyCache.mu.Lock()
defer keyCache.mu.Unlock()
// Double-check after acquiring write lock
if cached, ok := keyCache.cache[cacheKey]; ok && cached.accessKey == accessKey && cached.date == dateStamp {
return cached.key
}
key := deriveSigningKey(secretKey, dateStamp, region, service)
keyCache.cache[cacheKey] = cachedKey{
key: key,
date: dateStamp,
accessKey: accessKey,
}
return key
}
// stripExcessSpaces removes excess spaces from a string
func stripExcessSpaces(str string) string {
str = strings.TrimSpace(str)
if !strings.Contains(str, " ") {
return str
}
var result strings.Builder
result.Grow(len(str))
prevWasSpace := false
for _, ch := range str {
if ch == ' ' {
if !prevWasSpace {
result.WriteRune(ch)
}
prevWasSpace = true
} else {
result.WriteRune(ch)
prevWasSpace = false
}
}
return result.String()
}
// percentEncodeRFC3986 encodes a string per RFC 3986
// Keep unreserved characters (A-Z, a-z, 0-9, -, _, ., ~) as-is
// Percent-encode everything else as %HH using uppercase hex
func percentEncodeRFC3986(s string) string {
var result strings.Builder
result.Grow(len(s))
for i := 0; i < len(s); i++ {
b := s[i]
// RFC 3986 unreserved characters
if (b >= 'A' && b <= 'Z') ||
(b >= 'a' && b <= 'z') ||
(b >= '0' && b <= '9') ||
b == '-' || b == '_' || b == '.' || b == '~' {
result.WriteByte(b)
} else {
// Percent-encode with uppercase hex
result.WriteByte('%')
result.WriteByte(uppercaseHex(b >> 4))
result.WriteByte(uppercaseHex(b & 0x0F))
}
}
return result.String()
}
// uppercaseHex returns the uppercase hex character for a nibble (0-15)
func uppercaseHex(b byte) byte {
if b < 10 {
return '0' + b
}
return 'A' + (b - 10)
}
// percentDecode decodes percent-encoded sequences in a string without treating + as space
// This differs from url.QueryUnescape which uses form encoding (+ becomes space)
func percentDecode(s string) string {
// Quick check if there are any percent signs
if !strings.Contains(s, "%") {
return s
}
var result strings.Builder
result.Grow(len(s))
for i := 0; i < len(s); {
if s[i] == '%' && i+2 < len(s) {
// Try to decode the hex sequence
if h1 := unhex(s[i+1]); h1 >= 0 {
if h2 := unhex(s[i+2]); h2 >= 0 {
result.WriteByte(byte(h1<<4 | h2))
i += 3
continue
}
}
}
result.WriteByte(s[i])
i++
}
return result.String()
}
// unhex converts a hex character to its value, or -1 if not a hex char
func unhex(c byte) int {
switch {
case '0' <= c && c <= '9':
return int(c - '0')
case 'a' <= c && c <= 'f':
return int(c - 'a' + 10)
case 'A' <= c && c <= 'F':
return int(c - 'A' + 10)
}
return -1
}
// queryPair represents a query parameter name-value pair
type queryPair struct {
encodedName string
encodedValue string
}
// buildCanonicalQueryString builds a canonical query string per AWS SigV4 spec
// using proper RFC 3986 percent-encoding
func buildCanonicalQueryString(queryString string) string {
if queryString == "" {
return ""
}
// Split the raw query string on '&' into pairs
rawPairs := strings.Split(queryString, "&")
pairs := make([]queryPair, 0, len(rawPairs))
for _, rawPair := range rawPairs {
if rawPair == "" {
continue
}
// Split on the first '=' to get name and value
var name, value string
if idx := strings.IndexByte(rawPair, '='); idx >= 0 {
name = rawPair[:idx]
value = rawPair[idx+1:]
} else {
// No '=' means name only, empty value
name = rawPair
value = ""
}
// Decode percent-encoded sequences first to normalize (handles already-encoded values)
// then encode per RFC 3986 to ensure consistent encoding
// Note: We use percentDecode instead of url.QueryUnescape because the latter
// treats + as space (form encoding), but we need + to encode as %2B
decodedName := percentDecode(name)
decodedValue := percentDecode(value)
// Percent-encode name and value per RFC 3986
encodedName := percentEncodeRFC3986(decodedName)
encodedValue := percentEncodeRFC3986(decodedValue)
pairs = append(pairs, queryPair{
encodedName: encodedName,
encodedValue: encodedValue,
})
}
// Sort pairs lexicographically by encoded name, then by encoded value
sort.Slice(pairs, func(i, j int) bool {
if pairs[i].encodedName != pairs[j].encodedName {
return pairs[i].encodedName < pairs[j].encodedName
}
return pairs[i].encodedValue < pairs[j].encodedValue
})
// Join encoded pairs with '&'
var result strings.Builder
for i, pair := range pairs {
if i > 0 {
result.WriteByte('&')
}
result.WriteString(pair.encodedName)
result.WriteByte('=')
result.WriteString(pair.encodedValue)
}
return result.String()
}
// signAWSRequestFastHTTP signs a fasthttp request using AWS Signature Version 4
// This is a native implementation that avoids allocating http.Request
func signAWSRequestFastHTTP(
ctx context.Context,
req *fasthttp.Request,
body []byte,
accessKey, secretKey string,
sessionToken *string,
region, service string,
) *schemas.BifrostError {
// Get AWS credentials if not provided
if accessKey == "" && secretKey == "" {
cfg, err := config.LoadDefaultConfig(ctx, config.WithRegion(region))
if err != nil {
return providerUtils.NewBifrostOperationError("failed to load aws config", err)
}
creds, err := cfg.Credentials.Retrieve(ctx)
if err != nil {
return providerUtils.NewBifrostOperationError("failed to retrieve aws credentials", err)
}
accessKey = creds.AccessKeyID
secretKey = creds.SecretAccessKey
if creds.SessionToken != "" {
st := creds.SessionToken
sessionToken = &st
}
}
// Get current time
now := time.Now().UTC()
amzDate := now.Format(timeFormat)
dateStamp := now.Format(shortTimeFormat)
// Parse URI
uri := req.URI()
host := string(uri.Host())
path := string(uri.Path())
if path == "" {
path = "/"
}
queryString := string(uri.QueryString())
// Escape path for canonical URI (Bedrock doesn't disable escaping)
canonicalURI := httpbinding.EscapePath(path, false)
// Calculate payload hash
hash := sha256.Sum256(body)
payloadHash := hex.EncodeToString(hash[:])
// Set required headers
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "application/json")
req.Header.Set(amzDateKey, amzDate)
if sessionToken != nil && *sessionToken != "" {
req.Header.Set(amzSecurityToken, *sessionToken)
}
// Build canonical headers
var headerNames []string
headerMap := make(map[string][]string)
// Always include host
headerNames = append(headerNames, "host")
headerMap["host"] = []string{host}
// Include content-length if body is present
if cl := req.Header.ContentLength(); cl >= 0 {
headerNames = append(headerNames, "content-length")
headerMap["content-length"] = []string{strconv.Itoa(cl)}
}
// Collect other headers
for key, value := range req.Header.All() {
keyStr := strings.ToLower(string(key))
// Skip ignored headers
if _, ignore := ignoredHeaders[keyStr]; ignore {
continue
}
// Skip if already handled
if keyStr == "host" || keyStr == "content-length" {
continue
}
if _, exists := headerMap[keyStr]; !exists {
headerNames = append(headerNames, keyStr)
}
headerMap[keyStr] = append(headerMap[keyStr], string(value))
}
// Sort header names
sort.Strings(headerNames)
// Build canonical headers string
var canonicalHeaders strings.Builder
for _, name := range headerNames {
canonicalHeaders.WriteString(name)
canonicalHeaders.WriteRune(':')
values := headerMap[name]
for i, v := range values {
cleanedValue := stripExcessSpaces(v)
canonicalHeaders.WriteString(cleanedValue)
if i < len(values)-1 {
canonicalHeaders.WriteRune(',')
}
}
canonicalHeaders.WriteRune('\n')
}
signedHeaders := strings.Join(headerNames, ";")
// Build canonical query string using RFC 3986 encoding
canonicalQueryString := buildCanonicalQueryString(queryString)
// Build canonical request
canonicalRequest := strings.Join([]string{
string(req.Header.Method()),
canonicalURI,
canonicalQueryString,
canonicalHeaders.String(),
signedHeaders,
payloadHash,
}, "\n")
// Build credential scope
credentialScope := strings.Join([]string{
dateStamp,
region,
service,
"aws4_request",
}, "/")
// Build string to sign
canonicalRequestHash := sha256.Sum256([]byte(canonicalRequest))
stringToSign := strings.Join([]string{
signingAlgorithm,
amzDate,
credentialScope,
hex.EncodeToString(canonicalRequestHash[:]),
}, "\n")
// Calculate signature
signingKey := getSigningKey(accessKey, secretKey, dateStamp, region, service)
signature := hex.EncodeToString(hmacSHA256(signingKey, []byte(stringToSign)))
// Build authorization header
authHeader := fmt.Sprintf("%s Credential=%s/%s, SignedHeaders=%s, Signature=%s",
signingAlgorithm,
accessKey,
credentialScope,
signedHeaders,
signature,
)
req.Header.Set("Authorization", authHeader)
return nil
}