first commit
This commit is contained in:
433
core/providers/bedrock/signer.go
Normal file
433
core/providers/bedrock/signer.go
Normal file
@@ -0,0 +1,433 @@
|
||||
package bedrock
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-sdk-go-v2/config"
|
||||
"github.com/aws/smithy-go/encoding/httpbinding"
|
||||
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
|
||||
schemas "github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
const (
|
||||
signingAlgorithm = "AWS4-HMAC-SHA256"
|
||||
amzDateKey = "X-Amz-Date"
|
||||
amzSecurityToken = "X-Amz-Security-Token"
|
||||
timeFormat = "20060102T150405Z"
|
||||
shortTimeFormat = "20060102"
|
||||
)
|
||||
|
||||
// Headers to ignore during signing
|
||||
var ignoredHeaders = map[string]struct{}{
|
||||
"authorization": {},
|
||||
"user-agent": {},
|
||||
"x-amzn-trace-id": {},
|
||||
"expect": {},
|
||||
"transfer-encoding": {},
|
||||
}
|
||||
|
||||
// signingKeyCache caches derived signing keys to avoid recomputation
|
||||
type signingKeyCache struct {
|
||||
cache map[string]cachedKey
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
type cachedKey struct {
|
||||
key []byte
|
||||
date string // YYYYMMDD format
|
||||
accessKey string
|
||||
}
|
||||
|
||||
var keyCache = &signingKeyCache{
|
||||
cache: make(map[string]cachedKey),
|
||||
}
|
||||
|
||||
// hmacSHA256 computes HMAC-SHA256
|
||||
func hmacSHA256(key, data []byte) []byte {
|
||||
h := hmac.New(sha256.New, key)
|
||||
h.Write(data)
|
||||
return h.Sum(nil)
|
||||
}
|
||||
|
||||
// deriveSigningKey derives the AWS signing key
|
||||
func deriveSigningKey(secret, dateStamp, region, service string) []byte {
|
||||
kDate := hmacSHA256([]byte("AWS4"+secret), []byte(dateStamp))
|
||||
kRegion := hmacSHA256(kDate, []byte(region))
|
||||
kService := hmacSHA256(kRegion, []byte(service))
|
||||
kSigning := hmacSHA256(kService, []byte("aws4_request"))
|
||||
return kSigning
|
||||
}
|
||||
|
||||
// getSigningKey retrieves or computes the signing key with caching
|
||||
func getSigningKey(accessKey, secretKey, dateStamp, region, service string) []byte {
|
||||
cacheKey := fmt.Sprintf("%s/%s/%s/%s", accessKey, dateStamp, region, service)
|
||||
|
||||
keyCache.mu.RLock()
|
||||
if cached, ok := keyCache.cache[cacheKey]; ok && cached.accessKey == accessKey && cached.date == dateStamp {
|
||||
keyCache.mu.RUnlock()
|
||||
return cached.key
|
||||
}
|
||||
keyCache.mu.RUnlock()
|
||||
|
||||
keyCache.mu.Lock()
|
||||
defer keyCache.mu.Unlock()
|
||||
|
||||
// Double-check after acquiring write lock
|
||||
if cached, ok := keyCache.cache[cacheKey]; ok && cached.accessKey == accessKey && cached.date == dateStamp {
|
||||
return cached.key
|
||||
}
|
||||
|
||||
key := deriveSigningKey(secretKey, dateStamp, region, service)
|
||||
keyCache.cache[cacheKey] = cachedKey{
|
||||
key: key,
|
||||
date: dateStamp,
|
||||
accessKey: accessKey,
|
||||
}
|
||||
|
||||
return key
|
||||
}
|
||||
|
||||
// stripExcessSpaces removes excess spaces from a string
|
||||
func stripExcessSpaces(str string) string {
|
||||
str = strings.TrimSpace(str)
|
||||
if !strings.Contains(str, " ") {
|
||||
return str
|
||||
}
|
||||
|
||||
var result strings.Builder
|
||||
result.Grow(len(str))
|
||||
prevWasSpace := false
|
||||
|
||||
for _, ch := range str {
|
||||
if ch == ' ' {
|
||||
if !prevWasSpace {
|
||||
result.WriteRune(ch)
|
||||
}
|
||||
prevWasSpace = true
|
||||
} else {
|
||||
result.WriteRune(ch)
|
||||
prevWasSpace = false
|
||||
}
|
||||
}
|
||||
|
||||
return result.String()
|
||||
}
|
||||
|
||||
// percentEncodeRFC3986 encodes a string per RFC 3986
|
||||
// Keep unreserved characters (A-Z, a-z, 0-9, -, _, ., ~) as-is
|
||||
// Percent-encode everything else as %HH using uppercase hex
|
||||
func percentEncodeRFC3986(s string) string {
|
||||
var result strings.Builder
|
||||
result.Grow(len(s))
|
||||
|
||||
for i := 0; i < len(s); i++ {
|
||||
b := s[i]
|
||||
// RFC 3986 unreserved characters
|
||||
if (b >= 'A' && b <= 'Z') ||
|
||||
(b >= 'a' && b <= 'z') ||
|
||||
(b >= '0' && b <= '9') ||
|
||||
b == '-' || b == '_' || b == '.' || b == '~' {
|
||||
result.WriteByte(b)
|
||||
} else {
|
||||
// Percent-encode with uppercase hex
|
||||
result.WriteByte('%')
|
||||
result.WriteByte(uppercaseHex(b >> 4))
|
||||
result.WriteByte(uppercaseHex(b & 0x0F))
|
||||
}
|
||||
}
|
||||
|
||||
return result.String()
|
||||
}
|
||||
|
||||
// uppercaseHex returns the uppercase hex character for a nibble (0-15)
|
||||
func uppercaseHex(b byte) byte {
|
||||
if b < 10 {
|
||||
return '0' + b
|
||||
}
|
||||
return 'A' + (b - 10)
|
||||
}
|
||||
|
||||
// percentDecode decodes percent-encoded sequences in a string without treating + as space
|
||||
// This differs from url.QueryUnescape which uses form encoding (+ becomes space)
|
||||
func percentDecode(s string) string {
|
||||
// Quick check if there are any percent signs
|
||||
if !strings.Contains(s, "%") {
|
||||
return s
|
||||
}
|
||||
|
||||
var result strings.Builder
|
||||
result.Grow(len(s))
|
||||
|
||||
for i := 0; i < len(s); {
|
||||
if s[i] == '%' && i+2 < len(s) {
|
||||
// Try to decode the hex sequence
|
||||
if h1 := unhex(s[i+1]); h1 >= 0 {
|
||||
if h2 := unhex(s[i+2]); h2 >= 0 {
|
||||
result.WriteByte(byte(h1<<4 | h2))
|
||||
i += 3
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
result.WriteByte(s[i])
|
||||
i++
|
||||
}
|
||||
|
||||
return result.String()
|
||||
}
|
||||
|
||||
// unhex converts a hex character to its value, or -1 if not a hex char
|
||||
func unhex(c byte) int {
|
||||
switch {
|
||||
case '0' <= c && c <= '9':
|
||||
return int(c - '0')
|
||||
case 'a' <= c && c <= 'f':
|
||||
return int(c - 'a' + 10)
|
||||
case 'A' <= c && c <= 'F':
|
||||
return int(c - 'A' + 10)
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
// queryPair represents a query parameter name-value pair
|
||||
type queryPair struct {
|
||||
encodedName string
|
||||
encodedValue string
|
||||
}
|
||||
|
||||
// buildCanonicalQueryString builds a canonical query string per AWS SigV4 spec
|
||||
// using proper RFC 3986 percent-encoding
|
||||
func buildCanonicalQueryString(queryString string) string {
|
||||
if queryString == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Split the raw query string on '&' into pairs
|
||||
rawPairs := strings.Split(queryString, "&")
|
||||
pairs := make([]queryPair, 0, len(rawPairs))
|
||||
|
||||
for _, rawPair := range rawPairs {
|
||||
if rawPair == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Split on the first '=' to get name and value
|
||||
var name, value string
|
||||
if idx := strings.IndexByte(rawPair, '='); idx >= 0 {
|
||||
name = rawPair[:idx]
|
||||
value = rawPair[idx+1:]
|
||||
} else {
|
||||
// No '=' means name only, empty value
|
||||
name = rawPair
|
||||
value = ""
|
||||
}
|
||||
|
||||
// Decode percent-encoded sequences first to normalize (handles already-encoded values)
|
||||
// then encode per RFC 3986 to ensure consistent encoding
|
||||
// Note: We use percentDecode instead of url.QueryUnescape because the latter
|
||||
// treats + as space (form encoding), but we need + to encode as %2B
|
||||
decodedName := percentDecode(name)
|
||||
decodedValue := percentDecode(value)
|
||||
|
||||
// Percent-encode name and value per RFC 3986
|
||||
encodedName := percentEncodeRFC3986(decodedName)
|
||||
encodedValue := percentEncodeRFC3986(decodedValue)
|
||||
|
||||
pairs = append(pairs, queryPair{
|
||||
encodedName: encodedName,
|
||||
encodedValue: encodedValue,
|
||||
})
|
||||
}
|
||||
|
||||
// Sort pairs lexicographically by encoded name, then by encoded value
|
||||
sort.Slice(pairs, func(i, j int) bool {
|
||||
if pairs[i].encodedName != pairs[j].encodedName {
|
||||
return pairs[i].encodedName < pairs[j].encodedName
|
||||
}
|
||||
return pairs[i].encodedValue < pairs[j].encodedValue
|
||||
})
|
||||
|
||||
// Join encoded pairs with '&'
|
||||
var result strings.Builder
|
||||
for i, pair := range pairs {
|
||||
if i > 0 {
|
||||
result.WriteByte('&')
|
||||
}
|
||||
result.WriteString(pair.encodedName)
|
||||
result.WriteByte('=')
|
||||
result.WriteString(pair.encodedValue)
|
||||
}
|
||||
|
||||
return result.String()
|
||||
}
|
||||
|
||||
// signAWSRequestFastHTTP signs a fasthttp request using AWS Signature Version 4
|
||||
// This is a native implementation that avoids allocating http.Request
|
||||
func signAWSRequestFastHTTP(
|
||||
ctx context.Context,
|
||||
req *fasthttp.Request,
|
||||
body []byte,
|
||||
accessKey, secretKey string,
|
||||
sessionToken *string,
|
||||
region, service string,
|
||||
) *schemas.BifrostError {
|
||||
// Get AWS credentials if not provided
|
||||
if accessKey == "" && secretKey == "" {
|
||||
cfg, err := config.LoadDefaultConfig(ctx, config.WithRegion(region))
|
||||
if err != nil {
|
||||
return providerUtils.NewBifrostOperationError("failed to load aws config", err)
|
||||
}
|
||||
creds, err := cfg.Credentials.Retrieve(ctx)
|
||||
if err != nil {
|
||||
return providerUtils.NewBifrostOperationError("failed to retrieve aws credentials", err)
|
||||
}
|
||||
accessKey = creds.AccessKeyID
|
||||
secretKey = creds.SecretAccessKey
|
||||
if creds.SessionToken != "" {
|
||||
st := creds.SessionToken
|
||||
sessionToken = &st
|
||||
}
|
||||
}
|
||||
|
||||
// Get current time
|
||||
now := time.Now().UTC()
|
||||
amzDate := now.Format(timeFormat)
|
||||
dateStamp := now.Format(shortTimeFormat)
|
||||
|
||||
// Parse URI
|
||||
uri := req.URI()
|
||||
host := string(uri.Host())
|
||||
path := string(uri.Path())
|
||||
if path == "" {
|
||||
path = "/"
|
||||
}
|
||||
queryString := string(uri.QueryString())
|
||||
|
||||
// Escape path for canonical URI (Bedrock doesn't disable escaping)
|
||||
canonicalURI := httpbinding.EscapePath(path, false)
|
||||
|
||||
// Calculate payload hash
|
||||
hash := sha256.Sum256(body)
|
||||
payloadHash := hex.EncodeToString(hash[:])
|
||||
|
||||
// Set required headers
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set(amzDateKey, amzDate)
|
||||
if sessionToken != nil && *sessionToken != "" {
|
||||
req.Header.Set(amzSecurityToken, *sessionToken)
|
||||
}
|
||||
|
||||
// Build canonical headers
|
||||
var headerNames []string
|
||||
headerMap := make(map[string][]string)
|
||||
|
||||
// Always include host
|
||||
headerNames = append(headerNames, "host")
|
||||
headerMap["host"] = []string{host}
|
||||
|
||||
// Include content-length if body is present
|
||||
if cl := req.Header.ContentLength(); cl >= 0 {
|
||||
headerNames = append(headerNames, "content-length")
|
||||
headerMap["content-length"] = []string{strconv.Itoa(cl)}
|
||||
}
|
||||
|
||||
// Collect other headers
|
||||
for key, value := range req.Header.All() {
|
||||
keyStr := strings.ToLower(string(key))
|
||||
|
||||
// Skip ignored headers
|
||||
if _, ignore := ignoredHeaders[keyStr]; ignore {
|
||||
continue
|
||||
}
|
||||
|
||||
// Skip if already handled
|
||||
if keyStr == "host" || keyStr == "content-length" {
|
||||
continue
|
||||
}
|
||||
|
||||
if _, exists := headerMap[keyStr]; !exists {
|
||||
headerNames = append(headerNames, keyStr)
|
||||
}
|
||||
headerMap[keyStr] = append(headerMap[keyStr], string(value))
|
||||
}
|
||||
|
||||
// Sort header names
|
||||
sort.Strings(headerNames)
|
||||
|
||||
// Build canonical headers string
|
||||
var canonicalHeaders strings.Builder
|
||||
for _, name := range headerNames {
|
||||
canonicalHeaders.WriteString(name)
|
||||
canonicalHeaders.WriteRune(':')
|
||||
|
||||
values := headerMap[name]
|
||||
for i, v := range values {
|
||||
cleanedValue := stripExcessSpaces(v)
|
||||
canonicalHeaders.WriteString(cleanedValue)
|
||||
if i < len(values)-1 {
|
||||
canonicalHeaders.WriteRune(',')
|
||||
}
|
||||
}
|
||||
canonicalHeaders.WriteRune('\n')
|
||||
}
|
||||
|
||||
signedHeaders := strings.Join(headerNames, ";")
|
||||
|
||||
// Build canonical query string using RFC 3986 encoding
|
||||
canonicalQueryString := buildCanonicalQueryString(queryString)
|
||||
|
||||
// Build canonical request
|
||||
canonicalRequest := strings.Join([]string{
|
||||
string(req.Header.Method()),
|
||||
canonicalURI,
|
||||
canonicalQueryString,
|
||||
canonicalHeaders.String(),
|
||||
signedHeaders,
|
||||
payloadHash,
|
||||
}, "\n")
|
||||
|
||||
// Build credential scope
|
||||
credentialScope := strings.Join([]string{
|
||||
dateStamp,
|
||||
region,
|
||||
service,
|
||||
"aws4_request",
|
||||
}, "/")
|
||||
|
||||
// Build string to sign
|
||||
canonicalRequestHash := sha256.Sum256([]byte(canonicalRequest))
|
||||
stringToSign := strings.Join([]string{
|
||||
signingAlgorithm,
|
||||
amzDate,
|
||||
credentialScope,
|
||||
hex.EncodeToString(canonicalRequestHash[:]),
|
||||
}, "\n")
|
||||
|
||||
// Calculate signature
|
||||
signingKey := getSigningKey(accessKey, secretKey, dateStamp, region, service)
|
||||
signature := hex.EncodeToString(hmacSHA256(signingKey, []byte(stringToSign)))
|
||||
|
||||
// Build authorization header
|
||||
authHeader := fmt.Sprintf("%s Credential=%s/%s, SignedHeaders=%s, Signature=%s",
|
||||
signingAlgorithm,
|
||||
accessKey,
|
||||
credentialScope,
|
||||
signedHeaders,
|
||||
signature,
|
||||
)
|
||||
|
||||
req.Header.Set("Authorization", authHeader)
|
||||
|
||||
return nil
|
||||
}
|
||||
Reference in New Issue
Block a user