Files
bifrost/core/providers/utils/utils.go
Beyhan Oğur 880f412e2c first commit
2026-04-26 21:52:23 +03:00

2780 lines
97 KiB
Go

// Package utils implements utility functions shared by the various LLM provider implementations.
// This file contains common utility functions used across different provider implementations.
package utils
import (
"bytes"
"compress/gzip"
"context"
"crypto/tls"
"crypto/x509"
"encoding/json"
"errors"
"fmt"
"io"
"math/rand"
"net"
"net/http"
"net/textproto"
"net/url"
"regexp"
"slices"
"sort"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/bytedance/sonic"
"github.com/maximhq/bifrost/core/network"
schemas "github.com/maximhq/bifrost/core/schemas"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
"github.com/valyala/fasthttp"
"github.com/valyala/fasthttp/fasthttpproxy"
)
// sortedAPI is a frozen sonic configuration with SortMapKeys enabled, so map
// keys are sorted during marshaling.
// This ensures deterministic JSON output for map[string]interface{} values,
// which is critical for LLM prompt caching (e.g., Anthropic cache keying).
var sortedAPI = sonic.Config{SortMapKeys: true}.Froze()
// MarshalSorted marshals v to JSON through sortedAPI, so map keys are written
// in sorted order. It returns the encoded bytes together with any error
// reported by the encoder.
func MarshalSorted(v interface{}) ([]byte, error) {
return sortedAPI.Marshal(v)
}
// MarshalSortedIndent marshals v to indented JSON through sortedAPI, so map
// keys are written in sorted order. prefix and indent are passed unchanged to
// sortedAPI.MarshalIndent.
func MarshalSortedIndent(v interface{}, prefix, indent string) ([]byte, error) {
return sortedAPI.MarshalIndent(v, prefix, indent)
}
// SetJSONField sets the field at path in the JSON bytes data to value and
// returns the resulting bytes, without disturbing other fields' ordering.
// It delegates to sjson.SetBytes, so path uses sjson path syntax and any error
// comes from sjson.
func SetJSONField(data []byte, path string, value interface{}) ([]byte, error) {
return sjson.SetBytes(data, path, value)
}
// DeleteJSONField deletes the field at path from the JSON bytes data and
// returns the resulting bytes, without disturbing other fields' ordering.
// It delegates to sjson.DeleteBytes, so path uses sjson path syntax and any
// error comes from sjson.
func DeleteJSONField(data []byte, path string) ([]byte, error) {
return sjson.DeleteBytes(data, path)
}
// JSONFieldExists reports whether gjson finds a value at path in the JSON
// bytes data.
func JSONFieldExists(data []byte, path string) bool {
return gjson.GetBytes(data, path).Exists()
}
// GetJSONField retrieves the value at path from the JSON bytes data as a
// gjson.Result, without unmarshaling the document into Go values.
// Callers can use Result.Exists to tell a missing field from a present one.
func GetJSONField(data []byte, path string) gjson.Result {
return gjson.GetBytes(data, path)
}
// logger is the global logger for the provider utils (thread-safe via atomic.Pointer).
// It is written by init and SetLogger and read through getLogger.
var logger atomic.Pointer[schemas.Logger]
// noopLogger is a no-op implementation of schemas.Logger: every logging and
// configuration method ignores its arguments and does nothing.
type noopLogger struct{}
func (noopLogger) Debug(string, ...any) {}
func (noopLogger) Info(string, ...any) {}
func (noopLogger) Warn(string, ...any) {}
func (noopLogger) Error(string, ...any) {}
func (noopLogger) Fatal(string, ...any) {}
func (noopLogger) SetLevel(schemas.LogLevel) {}
func (noopLogger) SetOutputType(schemas.LoggerOutputType) {}
// LogHTTPRequest always returns schemas.NoopLogEvent, regardless of its arguments.
func (noopLogger) LogHTTPRequest(schemas.LogLevel, string) schemas.LogEventBuilder {
return schemas.NoopLogEvent
}
// init stores a noopLogger as the initial package logger, so getLogger returns
// a usable schemas.Logger even before SetLogger has been called.
func init() {
var noop schemas.Logger = &noopLogger{}
logger.Store(&noop)
}
// SetLogger sets the logger for the provider utils (thread-safe).
//
// A nil logger is replaced by the package's no-op logger. Storing a nil
// interface would make every later getLogger() call return nil and panic on
// the first method call.
func SetLogger(l schemas.Logger) {
	if l == nil {
		l = &noopLogger{}
	}
	logger.Store(&l)
}
// getLogger returns the logger most recently stored by init or SetLogger
// (thread-safe: it performs an atomic load of the logger pointer).
func getLogger() schemas.Logger {
return *logger.Load()
}
// UnsupportedSpeechStreamModels lists speech model names.
// NOTE(review): presumably these are models for which streaming speech
// requests are not supported — confirm against the callers.
var UnsupportedSpeechStreamModels = []string{"tts-1", "tts-1-hd"}
// noop is a reusable no-op function returned by MakeRequestWithContext as the
// wait function on every path where client.Do has already completed.
var noop = func() {}
// MakeRequestWithContext runs client.Do(req, resp) in a background goroutine and
// waits for either its completion or the cancellation of ctx, whichever comes
// first. It returns the elapsed time, a *schemas.BifrostError (nil on success),
// and a wait function.
//
// The wait function MUST be called (typically via defer) before req or resp are
// released. When client.Do has already finished it is a no-op. When the context
// ended first it blocks until the background client.Do goroutine finishes,
// preventing a data race between that goroutine and the caller's release of
// req/resp.
//
// IMPORTANT: This function does NOT truly cancel the underlying fasthttp network
// request when the context is done. The fasthttp call keeps running in its
// goroutine until it completes or times out based on its own settings; this
// function merely stops *waiting* for it and reports a context-related error.
//
// Error classification:
//   - context deadline exceeded: status 504, type schemas.RequestTimedOut
//   - context cancelled:         status 499, type schemas.RequestCancelled
//   - client.Do returned context.Canceled: type schemas.RequestCancelled
//   - fasthttp/context/net timeout: NewBifrostTimeoutError
//   - *net.OpError or *net.DNSError: schemas.ErrProviderNetworkError
//   - anything else: schemas.ErrProviderDoRequest
//
// A nil error only means the HTTP exchange succeeded; callers still have to
// inspect resp.StatusCode() for HTTP-level failures (4xx, 5xx).
func MakeRequestWithContext(ctx context.Context, client *fasthttp.Client, req *fasthttp.Request, resp *fasthttp.Response) (time.Duration, *schemas.BifrostError, func()) {
	begin := time.Now()

	// Buffered with capacity 1 so the goroutine can always deliver its result and
	// exit, even when nobody is receiving any more.
	done := make(chan error, 1)
	go func() {
		// client.Do blocks until the request completes (nil) or fails (error).
		done <- client.Do(req, resp)
	}()

	select {
	case <-ctx.Done():
		// The context ended first; latency is still reported for these requests.
		elapsed := time.Since(begin)
		cause := ctx.Err()

		// Blocks until the background client.Do goroutine has delivered its result.
		waitForDo := func() { <-done }

		// Default to "cancelled"; switch to "timed out" for an exceeded deadline.
		statusCode := 499
		errorType := schemas.RequestCancelled
		message := fmt.Sprintf("Request cancelled by context: %v", cause)
		if errors.Is(cause, context.DeadlineExceeded) {
			statusCode = 504
			errorType = schemas.RequestTimedOut
			message = fmt.Sprintf("Request timed out by context: %v", cause)
		}

		return elapsed, &schemas.BifrostError{
			IsBifrostError: true,
			StatusCode:     &statusCode,
			Error: &schemas.ErrorField{
				Type:    &errorType,
				Message: message,
				Error:   cause,
			},
		}, waitForDo

	case doErr := <-done:
		// client.Do finished; latency covers both successful and failed requests.
		elapsed := time.Since(begin)
		if doErr == nil {
			return elapsed, nil, noop
		}

		var (
			netErr net.Error
			opErr  *net.OpError
			dnsErr *net.DNSError
		)
		switch {
		case errors.Is(doErr, context.Canceled):
			return elapsed, &schemas.BifrostError{
				IsBifrostError: false,
				Error: &schemas.ErrorField{
					Type:    schemas.Ptr(schemas.RequestCancelled),
					Message: schemas.ErrRequestCancelled,
					Error:   doErr,
				},
			}, noop

		// Timeouts are checked before net.OpError so that they are not
		// misclassified as generic network errors.
		case errors.Is(doErr, fasthttp.ErrTimeout),
			errors.Is(doErr, context.DeadlineExceeded),
			errors.As(doErr, &netErr) && netErr.Timeout():
			return elapsed, NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, doErr), noop

		// DNS lookup and other network-level failures.
		case errors.As(doErr, &opErr), errors.As(doErr, &dnsErr):
			return elapsed, &schemas.BifrostError{
				IsBifrostError: false,
				Error: &schemas.ErrorField{
					Message: schemas.ErrProviderNetworkError,
					Error:   doErr,
				},
			}, noop

		// Any other failure of the HTTP request itself.
		default:
			return elapsed, &schemas.BifrostError{
				IsBifrostError: false,
				Error: &schemas.ErrorField{
					Message: schemas.ErrProviderDoRequest,
					Error:   doErr,
				},
			}, noop
		}
	}
}
// ConfigureRetry sets client.RetryIfErr to network.StaleConnectionRetryIfErr
// and returns the same client.
//
// Deprecated: ConfigureRetry is now handled internally by ConfigureDialer.
// This function is kept for backward compatibility but is no longer needed.
func ConfigureRetry(client *fasthttp.Client) *fasthttp.Client {
client.RetryIfErr = network.StaleConnectionRetryIfErr
return client
}
// ConfigureDialer configures the client's connection behavior and returns the
// same client:
//  1. Installs the stale-connection retry policy (see network.StaleConnectionRetryIfErr).
//  2. Replaces client.Dial with a wrapper that enables TCP keepalive on every
//     connection, so dead connections are detected proactively before fasthttp
//     tries to reuse them.
//
// Must be called AFTER ConfigureProxy (which may set client.Dial to a proxy
// dialer), so the keepalive wrapper composes on top of the proxy connection.
//
// The wrapper dials through, in order of preference: the previously installed
// client.Dial, the previously installed client.DialTimeout (using
// client.ReadTimeout as the timeout), or a plain net.Dialer whose timeout is
// client.ReadTimeout. client.ReadTimeout is read at dial time, not when
// ConfigureDialer is called.
//
// Keepalive parameters:
//   - Idle 10s: first probe after 10s of inactivity (well under the 30s MaxIdleConnDuration)
//   - Interval 5s: subsequent probes every 5s
//   - Count 3: close after 3 failed probes
//
// Dead connections are detected within ~25s (10 + 5*3), before the 30s
// MaxIdleConnDuration expires and the connection is reused.
func ConfigureDialer(client *fasthttp.Client) *fasthttp.Client {
	client.RetryIfErr = network.StaleConnectionRetryIfErr

	// Snapshot the dialers installed right now; the wrapper below replaces
	// client.Dial and must not call itself.
	prevDial, prevDialTimeout := client.Dial, client.DialTimeout

	keepAlive := net.KeepAliveConfig{
		Enable:   true,
		Idle:     10 * time.Second,
		Interval: 5 * time.Second,
		Count:    3,
	}

	// dialUnderlying opens the raw connection using the best available dialer.
	dialUnderlying := func(addr string) (net.Conn, error) {
		if prevDial != nil {
			// Proxy or custom dial function is set — use it.
			return prevDial(addr)
		}
		if prevDialTimeout != nil {
			// Preserve dial-timeout behavior.
			return prevDialTimeout(addr, client.ReadTimeout)
		}
		dialer := net.Dialer{
			Timeout:         client.ReadTimeout,
			KeepAliveConfig: keepAlive,
		}
		return dialer.Dial("tcp", addr)
	}

	client.Dial = func(addr string) (net.Conn, error) {
		conn, err := dialUnderlying(addr)
		if err != nil {
			return nil, err
		}
		// Enable TCP keepalive when the connection is a plain TCP connection;
		// a failure to apply the settings is deliberately ignored (best effort).
		if tcpConn, isTCP := conn.(*net.TCPConn); isTCP {
			_ = tcpConn.SetKeepAliveConfig(keepAlive)
		}
		return conn, nil
	}

	return client
}
// ConfigureProxy sets up a proxy for the fasthttp client based on the provided configuration.
// It supports HTTP, SOCKS5, and environment-based proxy configurations.
//
// Parameters:
//   - client: the client to configure; it is modified in place and returned.
//   - proxyConfig: proxy settings; nil or type schemas.NoProxy leaves the client untouched.
//   - logger: receives configuration warnings. When nil, the package logger
//     (see SetLogger) is used instead.
//
// Returns the configured client, or the original client unchanged if the proxy
// configuration is invalid (missing URL, unparsable URL, unsupported type).
//
// When proxyConfig.CACertPEM is set, the client's TLSConfig is replaced with one
// that trusts the system roots plus that certificate; a failure to parse the
// certificate is logged and otherwise ignored.
func ConfigureProxy(client *fasthttp.Client, proxyConfig *schemas.ProxyConfig, logger schemas.Logger) *fasthttp.Client {
	if proxyConfig == nil {
		return client
	}

	// Prefer the caller-supplied logger; fall back to the package logger so a
	// nil argument never causes a nil-interface call.
	log := logger
	if log == nil {
		log = getLogger()
	}

	var dialFunc fasthttp.DialFunc

	// Create the appropriate proxy dialer based on type.
	switch proxyConfig.Type {
	case schemas.NoProxy:
		return client

	case schemas.HTTPProxy:
		if proxyConfig.URL == "" {
			log.Warn("Warning: HTTP proxy URL is required for setting up proxy")
			return client
		}
		proxyURL, err := proxyURLWithCredentials(proxyConfig.URL, proxyConfig.Username, proxyConfig.Password)
		if err != nil {
			log.Warn("Invalid proxy configuration: invalid HTTP proxy URL")
			return client
		}
		dialFunc = fasthttpproxy.FasthttpHTTPDialer(proxyURL)

	case schemas.Socks5Proxy:
		if proxyConfig.URL == "" {
			log.Warn("Warning: SOCKS5 proxy URL is required for setting up proxy")
			return client
		}
		proxyURL, err := proxyURLWithCredentials(proxyConfig.URL, proxyConfig.Username, proxyConfig.Password)
		if err != nil {
			log.Warn("Invalid proxy configuration: invalid SOCKS5 proxy URL")
			return client
		}
		dialFunc = fasthttpproxy.FasthttpSocksDialer(proxyURL)

	case schemas.EnvProxy:
		// Use environment variables for proxy configuration.
		dialFunc = fasthttpproxy.FasthttpProxyHTTPDialer()

	default:
		log.Warn("Invalid proxy configuration: unsupported proxy type: %s", proxyConfig.Type)
		return client
	}

	if dialFunc != nil {
		client.Dial = dialFunc
	}

	// Configure custom CA certificate if provided.
	if proxyConfig.CACertPEM != "" {
		tlsConfig, err := createTLSConfigWithCA(proxyConfig.CACertPEM)
		if err != nil {
			log.Warn("Failed to configure custom CA certificate: %v", err)
		} else {
			client.TLSConfig = tlsConfig
		}
	}

	return client
}

// proxyURLWithCredentials returns rawURL with username and password embedded as
// user info. When either credential is empty, rawURL is returned unchanged and
// is not parsed; otherwise a parse failure of rawURL is returned as the error.
func proxyURLWithCredentials(rawURL, username, password string) (string, error) {
	if username == "" || password == "" {
		return rawURL, nil
	}
	parsed, err := url.Parse(rawURL)
	if err != nil {
		return "", err
	}
	parsed.User = url.UserPassword(username, password)
	return parsed.String(), nil
}
// createTLSConfigWithCA builds a TLS configuration (minimum version TLS 1.2)
// whose root pool is the system root CA pool plus the certificates found in
// caCertPEM. If the system pool cannot be loaded, an empty pool is used as the
// starting point instead.
//
// It returns an error, and a nil config, when no certificate could be parsed
// from caCertPEM.
func createTLSConfigWithCA(caCertPEM string) (*tls.Config, error) {
	// Start from the system roots; fall back to an empty pool when unavailable.
	pool, poolErr := x509.SystemCertPool()
	if poolErr != nil {
		pool = x509.NewCertPool()
	}

	// Add the custom CA certificate(s) on top.
	if appended := pool.AppendCertsFromPEM([]byte(caCertPEM)); !appended {
		return nil, fmt.Errorf("failed to parse CA certificate PEM")
	}

	cfg := &tls.Config{
		MinVersion: tls.VersionTLS12,
		RootCAs:    pool,
	}
	return cfg, nil
}
// ConfigureTLS applies the TLS settings from networkConfig to the fasthttp
// client and returns the same client. It merges with any existing TLSConfig
// (e.g., one installed by ConfigureProxy) by working on a clone of it.
//
// Behavior:
//   - When neither InsecureSkipVerify nor CACertPEM is set, the client is
//     returned untouched.
//   - InsecureSkipVerify disables certificate verification and logs a warning.
//   - CACertPEM is appended to a copy of the existing root pool, or, when there
//     is none, the roots become the system pool plus the custom CA. A CA that
//     cannot be parsed is logged and skipped.
//
// logger receives the warnings and is called directly, so it must not be nil
// when a warning can occur.
func ConfigureTLS(client *fasthttp.Client, networkConfig schemas.NetworkConfig, logger schemas.Logger) *fasthttp.Client {
	skipVerify := networkConfig.InsecureSkipVerify
	customCA := networkConfig.CACertPEM
	if !skipVerify && customCA == "" {
		return client
	}

	// Never mutate a config that may be shared: clone what is there, or start
	// from a fresh TLS 1.2+ config.
	var merged *tls.Config
	if client.TLSConfig != nil {
		merged = client.TLSConfig.Clone()
	} else {
		merged = &tls.Config{MinVersion: tls.VersionTLS12}
	}

	if skipVerify {
		logger.Warn("insecure_skip_verify is enabled for provider — TLS certificate verification is disabled. Not recommended for production.")
		merged.InsecureSkipVerify = true
	}

	if customCA != "" {
		caConfig, err := createTLSConfigWithCA(customCA)
		switch {
		case err != nil:
			logger.Warn("Failed to configure custom CA certificate for provider: %v", err)
		case merged.RootCAs == nil:
			merged.RootCAs = caConfig.RootCAs
		default:
			// Merge: append network CA to a copy of the existing pool (e.g. from proxy).
			pool := merged.RootCAs.Clone()
			if !pool.AppendCertsFromPEM([]byte(customCA)) {
				logger.Warn("Failed to append CA certificate to existing TLS config")
			}
			merged.RootCAs = pool
		}
	}

	client.TLSConfig = merged
	return client
}
// hopByHopHeaders are HTTP/1.1 headers that must not be forwarded by proxies.
// Keys are lower-case; callers lower-case the header name before the lookup.
var hopByHopHeaders = map[string]bool{
	"connection":          true,
	"keep-alive":          true,
	"proxy-authenticate":  true,
	"proxy-authorization": true,
	"proxy-connection":    true,
	"te":                  true,
	"trailer":             true,
	"transfer-encoding":   true,
	"upgrade":             true,
}

// filterHeaders returns a new map holding every entry of headers whose name is
// not a hop-by-hop header (compared case-insensitively). The input map is left
// untouched, the value slices are shared with it, and the result is never nil.
func filterHeaders(headers map[string][]string) map[string][]string {
	forwardable := make(map[string][]string, len(headers))
	for name, values := range headers {
		if hopByHopHeaders[strings.ToLower(name)] {
			continue
		}
		forwardable[name] = values
	}
	return forwardable
}
// providerResponseFilterHeaders are headers to exclude when forwarding provider response headers.
// These are transport-level headers that don't apply when re-serving the response.
// Keys are lower-case; ExtractProviderResponseHeaders and
// ExtractProviderResponseHeadersFromHTTP lower-case the header name before the lookup.
var providerResponseFilterHeaders = map[string]bool{
"content-length": true,
"content-encoding": true,
"transfer-encoding": true,
"connection": true,
"keep-alive": true,
"proxy-connection": true,
"proxy-authenticate": true,
"proxy-authorization": true,
"authorization": true,
"cookie": true,
"set-cookie": true,
"set-cookie2": true,
"www-authenticate": true,
"te": true,
"trailer": true,
"upgrade": true,
"host": true,
"date": true,
"server": true,
"alt-svc": true,
"strict-transport-security": true,
"content-type": true,
"access-control-allow-origin": true,
"access-control-allow-methods": true,
"access-control-allow-headers": true,
"access-control-expose-headers": true,
"access-control-allow-credentials": true,
"access-control-max-age": true,
}
// ExtractProviderResponseHeaders collects the headers of a fasthttp response
// into a map, dropping every header listed in providerResponseFilterHeaders
// (matched case-insensitively). Repeated headers are joined with ", ".
//
// It returns nil when resp is nil or when no header survives the filter.
func ExtractProviderResponseHeaders(resp *fasthttp.Response) map[string]string {
	if resp == nil {
		return nil
	}

	collected := make(map[string]string)
	resp.Header.VisitAll(func(rawKey, rawValue []byte) {
		name := string(rawKey)
		if providerResponseFilterHeaders[strings.ToLower(name)] {
			return
		}
		value := string(rawValue)
		// A non-empty earlier value means the header is repeated: join the values.
		if prior := collected[name]; prior != "" {
			collected[name] = prior + ", " + value
			return
		}
		collected[name] = value
	})

	if len(collected) == 0 {
		return nil
	}
	return collected
}
// ExtractProviderResponseHeadersFromHTTP collects the headers of a standard
// net/http response into a map, dropping every header listed in
// providerResponseFilterHeaders (matched case-insensitively) and every header
// without values. Multiple values of one header are joined with ", ".
// Used by providers like Bedrock that use net/http instead of fasthttp.
//
// It returns nil when resp is nil or when no header survives the filter.
func ExtractProviderResponseHeadersFromHTTP(resp *http.Response) map[string]string {
	if resp == nil {
		return nil
	}

	collected := make(map[string]string)
	for name, values := range resp.Header {
		if len(values) == 0 || providerResponseFilterHeaders[strings.ToLower(name)] {
			continue
		}
		collected[name] = strings.Join(values, ", ")
	}

	if len(collected) == 0 {
		return nil
	}
	return collected
}
// SetExtraHeaders applies additional headers to the fasthttp request from two
// sources, in this order:
//
//  1. extraHeaders (from NetworkConfig): keys are canonicalized with
//     textproto.CanonicalMIMEHeaderKey to avoid duplicates, and a header is only
//     set when the request does not already carry a value for it, so important
//     headers are never overwritten.
//  2. Headers stored in ctx under schemas.BifrostContextKeyExtraHeaders: these
//     take priority and replace existing values; hop-by-hop headers are dropped
//     first (see filterHeaders).
//
// skipHeaders lists header names that must never be set from either source
// (for security reasons). Names are compared case-insensitively, because HTTP
// header names are case-insensitive: an entry matches regardless of how the
// incoming key is cased. A nil or empty list skips nothing.
func SetExtraHeaders(ctx context.Context, req *fasthttp.Request, extraHeaders map[string]string, skipHeaders []string) {
	// shouldSkip reports whether name matches an entry of skipHeaders, ignoring case.
	shouldSkip := func(name string) bool {
		return slices.ContainsFunc(skipHeaders, func(skip string) bool {
			return strings.EqualFold(skip, name)
		})
	}

	for key, value := range extraHeaders {
		if shouldSkip(key) {
			continue
		}
		canonicalKey := textproto.CanonicalMIMEHeaderKey(key)
		// Only set the header if it doesn't already exist to avoid overwriting important headers.
		if len(req.Header.Peek(canonicalKey)) == 0 {
			req.Header.Set(canonicalKey, value)
		}
	}

	// Give priority to extra headers in the context.
	ctxHeaders, ok := ctx.Value(schemas.BifrostContextKeyExtraHeaders).(map[string][]string)
	if !ok {
		return
	}
	for name, values := range filterHeaders(ctxHeaders) {
		if shouldSkip(name) {
			continue
		}
		for i, v := range values {
			if i == 0 {
				// The first value replaces whatever is on the request.
				req.Header.Set(name, v)
			} else {
				// Further values are added alongside the first one.
				req.Header.Add(name, v)
			}
		}
	}
}
// GetPathFromContext returns the string stored in ctx under
// schemas.BifrostContextKeyURLPath, or defaultPath when the context holds no
// string value for that key. The context value is returned as-is.
func GetPathFromContext(ctx context.Context, defaultPath string) string {
	override, found := ctx.Value(schemas.BifrostContextKeyURLPath).(string)
	if !found {
		return defaultPath
	}
	return override
}
// GetRequestPath resolves the path (or full URL) to use for a request.
// It returns the resolved value and a boolean that is true when the value is a
// full absolute URL (scheme and host present) rather than a path.
//
// Resolution order:
//  1. A string stored in ctx under schemas.BifrostContextKeyURLPath. It is
//     trimmed of surrounding whitespace and returned otherwise unchanged (no
//     leading slash is added).
//  2. The override for requestType in customProviderConfig.RequestPathOverrides.
//     A blank override falls back to defaultPath; a path override gets a leading
//     slash when it lacks one.
//  3. defaultPath.
func GetRequestPath(ctx context.Context, defaultPath string, customProviderConfig *schemas.CustomProviderConfig, requestType schemas.RequestType) (string, bool) {
	// isFullURL reports whether candidate parses as an absolute URL with a host.
	isFullURL := func(candidate string) bool {
		parsed, err := url.Parse(candidate)
		return err == nil && parsed != nil && parsed.IsAbs() && parsed.Host != ""
	}

	// A path/URL set in the context wins over everything else.
	if fromCtx, ok := ctx.Value(schemas.BifrostContextKeyURLPath).(string); ok {
		fromCtx = strings.TrimSpace(fromCtx)
		return fromCtx, isFullURL(fromCtx)
	}

	// Without a configured override for this request type, use the default.
	if customProviderConfig == nil || customProviderConfig.RequestPathOverrides == nil {
		return defaultPath, false
	}
	raw, configured := customProviderConfig.RequestPathOverrides[requestType]
	if !configured {
		return defaultPath, false
	}

	override := strings.TrimSpace(raw)
	switch {
	case override == "":
		return defaultPath, false
	case isFullURL(override):
		// Absolute URLs with scheme+host are treated as full URLs.
		return override, true
	case strings.HasPrefix(override, "/"):
		return override, false
	default:
		// Otherwise treat as a path override and ensure the leading slash.
		return "/" + override, false
	}
}
// RequestBodyGetter is implemented by requests that can hand out their raw
// request body bytes (see CheckAndGetRawRequestBody).
type RequestBodyGetter interface {
GetRawRequestBody() []byte
}
// CheckAndGetRawRequestBody returns request.GetRawRequestBody() and true when
// ctx carries the boolean true under schemas.BifrostContextKeyUseRawRequestBody.
// In every other case it returns nil and false without touching request.
func CheckAndGetRawRequestBody(ctx context.Context, request RequestBodyGetter) ([]byte, bool) {
	// A missing or non-bool context value leaves useRaw at false.
	useRaw, _ := ctx.Value(schemas.BifrostContextKeyUseRawRequestBody).(bool)
	if !useRaw {
		return nil, false
	}
	return request.GetRawRequestBody(), true
}
// RequestBodyWithExtraParams is implemented by request bodies that expose
// additional parameters as a generic map.
type RequestBodyWithExtraParams interface {
GetExtraParams() map[string]interface{}
}
// RequestBodyConverter is a function that produces a RequestBodyWithExtraParams
// or reports an error.
type RequestBodyConverter func() (RequestBodyWithExtraParams, error)
// IsLargePayloadPassthroughEnabled returns true when large payload mode has
// already prepared an upstream body reader in context: ctx must be non-nil,
// carry the boolean true under schemas.BifrostContextKeyLargePayloadMode, and
// carry a non-nil io.Reader under schemas.BifrostContextKeyLargePayloadReader.
func IsLargePayloadPassthroughEnabled(ctx context.Context) bool {
	if ctx == nil {
		return false
	}
	// A missing or non-bool flag leaves modeOn at false.
	if modeOn, _ := ctx.Value(schemas.BifrostContextKeyLargePayloadMode).(bool); !modeOn {
		return false
	}
	body, hasBody := ctx.Value(schemas.BifrostContextKeyLargePayloadReader).(io.Reader)
	return hasBody && body != nil
}
// ApplyLargePayloadRequestBody applies the request body reader from context to the
// outgoing provider request. Returns true when a streaming body was applied.
// It calls ApplyLargePayloadRequestBodyWithModelNormalization with an empty
// default provider, which disables the model-prefix rewrite.
func ApplyLargePayloadRequestBody(ctx context.Context, req *fasthttp.Request) bool {
return ApplyLargePayloadRequestBodyWithModelNormalization(ctx, req, "")
}
// largePayloadModelRewriteScanBytes is the size (256 KiB) of the leading body
// window that RewriteLargePayloadModelInJSONPrefix and
// RewriteLargePayloadModelInMultipartPrefix read and scan for the model value.
const largePayloadModelRewriteScanBytes = 256 * 1024
// ApplyLargePayloadRequestBodyWithModelNormalization applies the streaming body
// reader from context to req and optionally rewrites prefixed model values for
// passthrough requests (for example "openai/gpt-5" -> "gpt-5").
// This preserves low-memory streaming while keeping large-payload behavior
// aligned with the normal parsed path that strips provider prefixes.
//
// The rewrite only happens when all of the following hold: the context carries
// a non-empty content type, large payload metadata with a non-blank model, and
// defaultProvider is non-empty; schemas.ParseModelString yields a non-empty
// model that differs from the raw one; and the content type contains
// "application/json" or "multipart/form-data". A known body size (>= 0) is
// adjusted by the size difference of the rewrite; an unknown size stays -1.
//
// It returns false, leaving req untouched, when req is nil or large payload
// passthrough is not enabled for ctx; otherwise the body stream is set and
// true is returned.
func ApplyLargePayloadRequestBodyWithModelNormalization(
	ctx context.Context,
	req *fasthttp.Request,
	defaultProvider schemas.ModelProvider,
) bool {
	if req == nil || !IsLargePayloadPassthroughEnabled(ctx) {
		return false
	}

	body, _ := ctx.Value(schemas.BifrostContextKeyLargePayloadReader).(io.Reader)

	// -1 signals an unknown body size to SetBodyStream.
	size := -1
	if knownLength, ok := ctx.Value(schemas.BifrostContextKeyLargePayloadContentLength).(int); ok {
		size = knownLength
	}

	contentType, _ := ctx.Value(schemas.BifrostContextKeyLargePayloadContentType).(string)
	if contentType != "" {
		rawModel := ""
		if meta, ok := ctx.Value(schemas.BifrostContextKeyLargePayloadMetadata).(*schemas.LargePayloadMetadata); ok && meta != nil {
			rawModel = strings.TrimSpace(meta.Model)
		}

		if rawModel != "" && defaultProvider != "" {
			_, normalized := schemas.ParseModelString(rawModel, defaultProvider)
			if normalized != "" && normalized != rawModel {
				lowered := strings.ToLower(contentType)
				delta := 0
				switch {
				case strings.Contains(lowered, "application/json"):
					body, delta = RewriteLargePayloadModelInJSONPrefix(body, rawModel, normalized)
				case strings.Contains(lowered, "multipart/form-data"):
					body, delta = RewriteLargePayloadModelInMultipartPrefix(body, rawModel, normalized)
				}
				// Keep a known content length in sync with the rewritten body.
				if size >= 0 {
					size += delta
				}
			}
		}

		req.Header.SetContentType(contentType)
	}

	req.SetBodyStream(body, size)
	return true
}
// RewriteLargePayloadModelInJSONPrefix reads up to the first 256KB of a
// streaming body, rewrites the "model" JSON value from fromModel to toModel
// within that prefix, and returns a combined reader (prefix followed by the
// remaining stream) together with the size delta in bytes.
//
// The reader is returned untouched with a delta of 0 when it is nil, when
// either model is empty, when both models are equal, or when nothing could be
// read from it. When the prefix holds no matching model value, the returned
// reader replays the prefix unchanged and the delta is 0.
func RewriteLargePayloadModelInJSONPrefix(reader io.Reader, fromModel, toModel string) (io.Reader, int) {
	if reader == nil || fromModel == "" || toModel == "" || fromModel == toModel {
		return reader, 0
	}

	window := make([]byte, largePayloadModelRewriteScanBytes)
	read, readErr := io.ReadFull(reader, window)
	if readErr != nil && read == 0 {
		return reader, 0
	}
	// A short read (body smaller than the window) is expected; use what arrived.
	head := window[:read]

	delta := 0
	if replaced, ok := rewriteJSONModelValue(head, fromModel, toModel); ok {
		delta = len(replaced) - len(head)
		head = replaced
	}
	return io.MultiReader(bytes.NewReader(head), reader), delta
}
// rewriteJSONModelValue scans data for a `"model"` key followed by a colon and
// a JSON string, and replaces the first such string whose raw (still escaped)
// content equals fromModel with toModel.
//
// It returns the rewritten copy and true on success. It returns data itself and
// false when data is empty, either model is empty, both models are equal, no
// matching key/value pair exists, or a candidate string value is not terminated
// within data. The scan is purely textual: it does not track JSON nesting.
func rewriteJSONModelValue(data []byte, fromModel, toModel string) ([]byte, bool) {
	if len(data) == 0 || fromModel == "" || toModel == "" || fromModel == toModel {
		return data, false
	}

	needle := []byte(`"model"`)

	// skipSpace returns the first position at or after pos that is not JSON whitespace.
	skipSpace := func(pos int) int {
		for pos < len(data) {
			switch data[pos] {
			case ' ', '\t', '\r', '\n':
				pos++
				continue
			}
			break
		}
		return pos
	}

	cursor := 0
	for {
		rel := bytes.Index(data[cursor:], needle)
		if rel < 0 {
			return data, false
		}
		afterKey := cursor + rel + len(needle)

		// The key must be followed by ':' ...
		colon := skipSpace(afterKey)
		if colon >= len(data) || data[colon] != ':' {
			cursor = afterKey
			continue
		}
		// ... and then by the opening quote of a string value.
		quote := skipSpace(colon + 1)
		if quote >= len(data) || data[quote] != '"' {
			cursor = afterKey
			continue
		}

		// Find the closing quote, stepping over backslash-escaped characters.
		start := quote + 1
		end := start
		for end < len(data) && data[end] != '"' {
			if data[end] == '\\' {
				end++
			}
			end++
		}
		if end >= len(data) {
			// The string runs past the end of data: give up.
			return data, false
		}

		if string(data[start:end]) != fromModel {
			cursor = end + 1
			continue
		}

		out := make([]byte, 0, len(data)-len(fromModel)+len(toModel))
		out = append(out, data[:start]...)
		out = append(out, toModel...)
		out = append(out, data[end:]...)
		return out, true
	}
}
// RewriteLargePayloadModelInMultipartPrefix reads up to the first 256KB of a
// streaming multipart body, replaces the first occurrence of fromModel in that
// prefix with toModel, and returns a combined reader (prefix followed by the
// remaining stream) together with the size delta in bytes.
//
// In multipart, the model value appears as:
//
//	...name="model"\r\n\r\nopenai/whisper-1\r\n--boundary...
//
// The replacement is a plain byte search for fromModel; it relies on the model
// string (e.g. "openai/whisper-1") being unique within the form metadata and on
// the model field appearing early in the body (typically as the first field).
//
// The reader is returned untouched with a delta of 0 when it is nil, when
// either model is empty, when both models are equal, or when nothing could be
// read from it. When fromModel is not found, the returned reader replays the
// prefix unchanged and the delta is 0.
func RewriteLargePayloadModelInMultipartPrefix(reader io.Reader, fromModel, toModel string) (io.Reader, int) {
	if reader == nil || fromModel == "" || toModel == "" || fromModel == toModel {
		return reader, 0
	}

	window := make([]byte, largePayloadModelRewriteScanBytes)
	read, readErr := io.ReadFull(reader, window)
	if readErr != nil && read == 0 {
		return reader, 0
	}
	head := window[:read]

	at := bytes.Index(head, []byte(fromModel))
	if at < 0 {
		return io.MultiReader(bytes.NewReader(head), reader), 0
	}

	// Splice: bytes before the match, the new model, bytes after the match.
	patched := make([]byte, at, len(head)-len(fromModel)+len(toModel))
	copy(patched, head[:at])
	patched = append(patched, toModel...)
	patched = append(patched, head[at+len(fromModel):]...)

	return io.MultiReader(bytes.NewReader(patched), reader), len(patched) - len(head)
}
// DrainLargePayloadRemainder reads and discards any unread bytes from the large
// payload reader stored in ctx. It does nothing when large payload passthrough
// is not enabled for ctx. Read errors are deliberately ignored (best effort).
//
// This is useful for request types that may receive an upstream response before
// the incoming client upload is fully consumed (for example, lightweight
// preflight APIs).
// Example failure this prevents: fronting proxy returns 502/broken-pipe when
// backend responds early while client is still uploading a large body.
func DrainLargePayloadRemainder(ctx context.Context) {
	if !IsLargePayloadPassthroughEnabled(ctx) {
		return
	}
	if remaining, _ := ctx.Value(schemas.BifrostContextKeyLargePayloadReader).(io.Reader); remaining != nil {
		_, _ = io.Copy(io.Discard, remaining)
	}
}
// CloneFastHTTPClientConfig creates a fresh fasthttp.Client by copying only
// config fields from base. A nil base yields an empty client.
//
// Never copy fasthttp.Client by value: it contains internal pools and locks.
// Example failure this prevents: parallel load regressions with unexpected
// buffering behavior after `cloned := *base` copies of active clients.
//
// Function- and pointer-valued fields (Dial, DialTimeout, TLSConfig, ...) are
// copied by reference, so the clone shares them with base.
func CloneFastHTTPClientConfig(base *fasthttp.Client) *fasthttp.Client {
	clone := &fasthttp.Client{}
	if base == nil {
		return clone
	}

	// Identity and request shaping.
	clone.Name = base.Name
	clone.NoDefaultUserAgentHeader = base.NoDefaultUserAgentHeader
	clone.DisableHeaderNamesNormalizing = base.DisableHeaderNamesNormalizing
	clone.DisablePathNormalizing = base.DisablePathNormalizing

	// Dialing, transport and TLS.
	clone.Transport = base.Transport
	clone.Dial = base.Dial
	clone.DialTimeout = base.DialTimeout
	clone.DialDualStack = base.DialDualStack
	clone.TLSConfig = base.TLSConfig
	clone.ConfigureClient = base.ConfigureClient

	// Retry behavior.
	clone.RetryIfErr = base.RetryIfErr
	clone.MaxIdemponentCallAttempts = base.MaxIdemponentCallAttempts

	// Connection pool limits.
	clone.MaxConnsPerHost = base.MaxConnsPerHost
	clone.MaxIdleConnDuration = base.MaxIdleConnDuration
	clone.MaxConnDuration = base.MaxConnDuration
	clone.MaxConnWaitTimeout = base.MaxConnWaitTimeout
	clone.ConnPoolStrategy = base.ConnPoolStrategy

	// Buffers, timeouts and body handling.
	clone.ReadBufferSize = base.ReadBufferSize
	clone.WriteBufferSize = base.WriteBufferSize
	clone.ReadTimeout = base.ReadTimeout
	clone.WriteTimeout = base.WriteTimeout
	clone.MaxResponseBodySize = base.MaxResponseBodySize
	clone.StreamResponseBody = base.StreamResponseBody

	return clone
}
// BuildStreamingClient returns a fasthttp.Client suitable for long-lived SSE
// or EventStream responses. It clones base's dialer/proxy/TLS/pool settings
// (see CloneFastHTTPClientConfig), then clears ReadTimeout, WriteTimeout and
// MaxConnDuration so fasthttp does not pre-empt a healthy stream, and forces
// StreamResponseBody on.
//
// Per-chunk idle detection is enforced at the application layer via
// NewIdleTimeoutReader (see GetStreamIdleTimeout / StreamIdleTimeoutInSeconds).
// The initial TCP/TLS dial still honors the base client's ReadTimeout because
// the Dial closure installed by ConfigureDialer reads client.ReadTimeout from
// the base client pointer captured at ConfigureDialer call time — cloning copies
// that closure verbatim, so zeroing the clone's ReadTimeout does not affect dial.
func BuildStreamingClient(base *fasthttp.Client) *fasthttp.Client {
	streaming := CloneFastHTTPClientConfig(base)

	// A zero value disables the respective limit.
	streaming.ReadTimeout, streaming.WriteTimeout = 0, 0
	streaming.MaxConnDuration = 0

	streaming.StreamResponseBody = true
	return streaming
}
// BuildStreamingHTTPClient returns an *http.Client for long-lived streaming
// responses over net/http (e.g. Bedrock EventStream). It reuses base's
// Transport (safe for concurrent use by multiple clients), CheckRedirect and
// Jar, and leaves Timeout at zero so Client.Timeout does not cap the entire
// request lifecycle including body reads. The transport's
// ResponseHeaderTimeout still bounds the initial response-headers wait;
// per-chunk idle is enforced by NewIdleTimeoutReader.
//
// A nil base yields an empty http.Client.
func BuildStreamingHTTPClient(base *http.Client) *http.Client {
	// Timeout is intentionally never copied: its zero value means "no timeout".
	streaming := &http.Client{}
	if base != nil {
		streaming.Transport = base.Transport
		streaming.CheckRedirect = base.CheckRedirect
		streaming.Jar = base.Jar
	}
	return streaming
}
// decompressBodyStreamIfGzip inspects resp's Content-Encoding header and, when
// it mentions gzip, wraps stream with on-the-fly decompression using a gzip
// reader obtained from AcquireGzipReader. On success the Content-Encoding
// header is removed from resp so downstream consumers don't double-decompress.
//
// Returns:
//   - the acquired *gzip.Reader (nil when no decompression is set up), which
//     the caller has to hand back via ReleaseGzipReader;
//   - the reader to consume: the gzip reader, or stream itself when the body is
//     not gzip-encoded or the gzip reader could not be created;
//   - whether decompression was set up.
//
// NOTE(review): on an AcquireGzipReader error its returned reader is passed to
// ReleaseGzipReader — presumably that call tolerates a nil reader; confirm.
func decompressBodyStreamIfGzip(resp *fasthttp.Response, stream io.Reader) (*gzip.Reader, io.Reader, bool) {
	rawEncoding := string(resp.Header.Peek("Content-Encoding"))
	encoding := strings.ToLower(strings.TrimSpace(rawEncoding))
	if !strings.Contains(encoding, "gzip") {
		return nil, stream, false
	}

	gzReader, acquireErr := AcquireGzipReader(stream)
	if acquireErr != nil {
		ReleaseGzipReader(gzReader)
		return nil, stream, false
	}

	resp.Header.Del("Content-Encoding")
	return gzReader, gzReader, true
}
// DecompressStreamBody returns a reader for consuming the response body, with
// on-the-fly gzip decompression when Content-Encoding indicates gzip. The
// response's body stream is NOT replaced (no SetBodyStream call), so the
// original stream remains live for proper cleanup by ReleaseStreamingResponse.
// When decompression is set up, the Content-Encoding header is cleared to
// prevent double-decompression.
//
// Returns:
//   - io.Reader: the reader to use for scanning (gzip reader if gzip-encoded,
//     original body stream otherwise). When the response has no body stream,
//     an empty reader is returned instead of nil, to prevent panics in callers
//     that pass the reader to bufio.NewScanner without nil checks.
//   - func(): cleanup function that releases the gzip reader back to the pool
//     (a no-op when no gzip reader is involved). Must be called (typically via
//     defer) after streaming is complete.
func DecompressStreamBody(resp *fasthttp.Response) (io.Reader, func()) {
	raw := resp.BodyStream()
	if raw == nil {
		return bytes.NewReader(nil), func() {}
	}

	gz, reader, gzipped := decompressBodyStreamIfGzip(resp, raw)
	if gzipped {
		release := func() {
			ReleaseGzipReader(gz)
		}
		return reader, release
	}
	return raw, func() {}
}
// DrainNonSSEStreamResponse reports whether the upstream response is NOT a
// Server-Sent Events stream and, in that case, discards its body stream.
//
// A response counts as SSE when its Content-Type contains "text/event-stream"
// (case-insensitive). SSE bodies are left untouched so the caller can keep
// reading them; anything else is copied to io.Discard so a line scanner is
// never pointed at non-line-delimited data. A true result means the body was
// handled here and the caller should skip its scanner loop.
func DrainNonSSEStreamResponse(resp *fasthttp.Response) bool {
	contentType := strings.ToLower(string(resp.Header.ContentType()))
	isSSE := strings.Contains(contentType, "text/event-stream")
	if !isSSE {
		body := resp.BodyStream()
		if body != nil {
			// Best-effort drain: the data is being thrown away, so a read
			// error changes nothing for the caller.
			_, _ = io.Copy(io.Discard, body)
		}
	}
	return !isSSE
}
// MergeExtraParams copies every entry of extraParams into jsonMap in place.
// When a key holds a map[string]interface{} on both sides the two maps are
// merged recursively; in every other case the value from extraParams replaces
// whatever jsonMap held for that key.
func MergeExtraParams(jsonMap map[string]interface{}, extraParams map[string]interface{}) {
	for key, incoming := range extraParams {
		// A missing key yields a nil interface, so the assertion simply fails.
		current, currentIsMap := jsonMap[key].(map[string]interface{})
		incomingMap, incomingIsMap := incoming.(map[string]interface{})
		if currentIsMap && incomingIsMap {
			MergeExtraParams(current, incomingMap)
			continue
		}
		jsonMap[key] = incoming
	}
}
// MergeExtraParamsIntoJSON merges extraParams into the top-level JSON object in
// jsonBody while keeping the existing top-level keys in their original order.
// It avoids decoding the whole document into map[string]interface{}, which
// would lose key ordering in tool schemas and other order-sensitive structures.
//
// Behavior:
//   - If jsonBody is not a JSON object, or its top level cannot be tokenized,
//     jsonBody is returned unchanged with a nil error.
//   - Extra keys are processed in sorted order. Keys not present in the body
//     are appended at the end; keys already present are updated in place.
//   - When both the existing and the new value are JSON objects they are merged
//     via MergeExtraParams and re-marshaled with MarshalSorted, so the keys
//     inside that merged object end up sorted rather than in document order.
//   - An extra param that fails to marshal is skipped silently.
//   - The result is re-indented with json.Indent; if indenting fails the
//     compact form is returned instead.
//
// A non-nil error is returned only when a key cannot be marshaled while
// rebuilding the object.
func MergeExtraParamsIntoJSON(jsonBody []byte, extraParams map[string]interface{}) ([]byte, error) {
trimmed := bytes.TrimSpace(jsonBody)
if len(trimmed) < 2 || trimmed[0] != '{' {
return jsonBody, nil // not a JSON object, return as-is
}
// Parse existing JSON into ordered key-value pairs using encoding/json
// (not sonic) to preserve document key order via token-by-token parsing.
dec := json.NewDecoder(bytes.NewReader(trimmed))
dec.UseNumber()
if _, err := dec.Token(); err != nil { // '{'
return jsonBody, nil
}
// kvPair is one top-level member: the key plus its value as raw bytes.
type kvPair struct {
key string
val json.RawMessage
}
var pairs []kvPair
// seen maps a key to its index in pairs. For a duplicated key the last
// occurrence wins here, while every occurrence stays in pairs.
seen := make(map[string]int)
for dec.More() {
tok, err := dec.Token()
if err != nil {
return jsonBody, nil
}
key, ok := tok.(string)
if !ok {
return jsonBody, nil
}
// Capture the value verbatim so its inner formatting and ordering survive.
var val json.RawMessage
if err := dec.Decode(&val); err != nil {
return jsonBody, nil
}
seen[key] = len(pairs)
pairs = append(pairs, kvPair{key, val})
}
// Add/merge extra params (deterministic order for new keys)
extraKeys := make([]string, 0, len(extraParams))
for k := range extraParams {
extraKeys = append(extraKeys, k)
}
sort.Strings(extraKeys)
for _, k := range extraKeys {
v := extraParams[k]
newValBytes, err := MarshalSorted(v)
if err != nil {
// Best effort: a value that cannot be marshaled is dropped, not reported.
continue
}
if idx, exists := seen[k]; exists {
// If both existing and new are JSON objects, merge recursively
existingTrimmed := bytes.TrimSpace(pairs[idx].val)
newTrimmed := bytes.TrimSpace(newValBytes)
if len(existingTrimmed) > 0 && existingTrimmed[0] == '{' &&
len(newTrimmed) > 0 && newTrimmed[0] == '{' {
var existingMap, newMap map[string]interface{}
// UseNumber keeps numeric literals as json.Number instead of float64.
existingDec := json.NewDecoder(bytes.NewReader(existingTrimmed))
existingDec.UseNumber()
newDec := json.NewDecoder(bytes.NewReader(newTrimmed))
newDec.UseNumber()
if existingDec.Decode(&existingMap) == nil {
if newDec.Decode(&newMap) == nil {
MergeExtraParams(existingMap, newMap)
// The merged object is emitted with sorted keys; its original
// inner key order is not retained.
if merged, err := MarshalSorted(existingMap); err == nil {
pairs[idx].val = merged
continue
}
}
}
}
// Non-object or merge failed: overwrite in place (preserving position)
pairs[idx].val = newValBytes
} else {
// New key: append at end
pairs = append(pairs, kvPair{k, newValBytes})
}
}
// Rebuild compact JSON, then indent for consistent formatting
var compact bytes.Buffer
compact.WriteByte('{')
for i, kv := range pairs {
if i > 0 {
compact.WriteByte(',')
}
keyBytes, err := sonic.Marshal(kv.key)
if err != nil {
return jsonBody, err
}
compact.Write(keyBytes)
compact.WriteByte(':')
// Use trimmed value to remove any existing indentation
compact.Write(bytes.TrimSpace(kv.val))
}
compact.WriteByte('}')
// Re-indent to match the expected formatting
var indented bytes.Buffer
if err := json.Indent(&indented, compact.Bytes(), "", " "); err != nil {
// Indenting failed: fall back to the compact form rather than erroring.
return compact.Bytes(), nil
}
return indented.Bytes(), nil
}
// CheckContextAndGetRequestBody resolves the bytes to send as the provider
// request body.
//
// Resolution order:
//  1. Large-payload passthrough enabled on ctx: returns (nil, nil).
//  2. A raw request body is available for request: it is returned as-is.
//  3. Otherwise requestConverter is invoked, its result is marshaled as indented
//     JSON with sorted map keys, and, when the passthrough-extra-params context
//     flag is the boolean true, the converted body's extra params are merged in
//     with the order-preserving MergeExtraParamsIntoJSON.
//
// A BifrostError is returned when conversion fails, when the converter yields a
// nil body, or when marshaling/merging fails.
func CheckContextAndGetRequestBody(ctx context.Context, request RequestBodyGetter, requestConverter RequestBodyConverter) ([]byte, *schemas.BifrostError) {
	if IsLargePayloadPassthroughEnabled(ctx) {
		return nil, nil
	}
	if rawBody, ok := CheckAndGetRawRequestBody(ctx, request); ok {
		return rawBody, nil
	}

	convertedBody, err := requestConverter()
	if err != nil {
		return nil, NewBifrostOperationError(schemas.ErrRequestBodyConversion, err)
	}
	if convertedBody == nil {
		return nil, NewBifrostOperationError("request body is not provided", nil)
	}

	jsonBody, err := MarshalSortedIndent(convertedBody, "", " ")
	if err != nil {
		return nil, NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err)
	}

	// Only a context value that is exactly the bool true enables the merge.
	passthrough, _ := ctx.Value(schemas.BifrostContextKeyPassthroughExtraParams).(bool)
	if !passthrough {
		return jsonBody, nil
	}
	extraParams := convertedBody.GetExtraParams()
	if len(extraParams) == 0 {
		return jsonBody, nil
	}

	// Order-preserving merge keeps key ordering in tool schemas and other
	// order-sensitive JSON structures intact.
	merged, err := MergeExtraParamsIntoJSON(jsonBody, extraParams)
	if err != nil {
		return nil, NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err)
	}
	return merged, nil
}
// SetExtraHeadersHTTP applies user-configured extra headers to a standard
// library HTTP request.
//
// Two sources are applied, in this order:
//  1. extraHeaders (from NetworkConfig): keys are canonicalized with
//     textproto.CanonicalMIMEHeaderKey and a header is only set when the
//     request does not already carry a non-empty value for it, so important
//     headers are never overwritten.
//  2. Headers stored on ctx under schemas.BifrostContextKeyExtraHeaders
//     (after filterHeaders): these take priority and replace any existing
//     value; additional values for the same key are appended.
//
// skipHeaders lists header names that must never be set from either source
// (for security reasons). HTTP header names are case-insensitive, so the skip
// list is matched case-insensitively: a caller-supplied "authorization" is
// blocked by a skip entry spelled "Authorization" and vice versa. A nil or
// empty skipHeaders skips nothing.
func SetExtraHeadersHTTP(ctx context.Context, req *http.Request, extraHeaders map[string]string, skipHeaders []string) {
	// shouldSkip reports whether name appears in skipHeaders, ignoring case.
	shouldSkip := func(name string) bool {
		return slices.ContainsFunc(skipHeaders, func(skip string) bool {
			return strings.EqualFold(skip, name)
		})
	}

	for key, value := range extraHeaders {
		if shouldSkip(key) {
			continue
		}
		canonicalKey := textproto.CanonicalMIMEHeaderKey(key)
		// Only set the header if it doesn't already exist to avoid overwriting important headers
		if req.Header.Get(canonicalKey) == "" {
			req.Header.Set(canonicalKey, value)
		}
	}

	// Give priority to extra headers in the context
	if ctxHeaders, ok := ctx.Value(schemas.BifrostContextKeyExtraHeaders).(map[string][]string); ok {
		for k, values := range filterHeaders(ctxHeaders) {
			if shouldSkip(k) {
				continue
			}
			for i, v := range values {
				if i == 0 {
					// First value replaces whatever was there before.
					req.Header.Set(k, v)
				} else {
					req.Header.Add(k, v)
				}
			}
		}
	}
}
// HandleProviderAPIError converts a provider error response into a BifrostError
// carrying the HTTP status code and the raw response.
//
// errorResp must be a pointer to the provider-specific error struct; when the
// body parses as JSON it is populated and the returned BifrostError has an
// empty ErrorField for the caller to fill in. The remaining outcomes are:
//   - body decoding fails: the decode error becomes the message and the
//     undecoded body is kept as RawResponse;
//   - the body is empty or whitespace: a message derived from the status code;
//   - the body is not JSON but looks like HTML: schemas.ErrProviderResponseHTML
//     with the body attached as the underlying error;
//   - anything else: "provider API error: <body>".
//
// HTML detection runs only after JSON parsing has failed. RawResponse holds the
// parsed JSON value when the body is valid JSON and the body as a string
// otherwise.
func HandleProviderAPIError(resp *fasthttp.Response, errorResp any) *schemas.BifrostError {
	statusCode := resp.StatusCode()

	// newError assembles the provider-side error shape shared by every exit.
	newError := func(field *schemas.ErrorField, raw interface{}) *schemas.BifrostError {
		return &schemas.BifrostError{
			IsBifrostError: false,
			StatusCode:     &statusCode,
			Error:          field,
			ExtraFields: schemas.BifrostErrorExtraFields{
				RawResponse: raw,
			},
		}
	}

	decodedBody, decodeErr := CheckAndDecodeBody(resp)
	if decodeErr != nil {
		// Decoding failed: still surface the undecoded body, as JSON when it
		// parses and as a plain string when it does not.
		var undecoded interface{}
		if body := resp.Body(); len(body) > 0 {
			if sonic.Unmarshal(body, &undecoded) != nil {
				undecoded = string(body)
			}
		}
		return newError(&schemas.ErrorField{Message: decodeErr.Error()}, undecoded)
	}

	// Generic view of the decoded body for RawResponse.
	var rawResponse interface{}
	if sonic.Unmarshal(decodedBody, &rawResponse) != nil {
		rawResponse = string(decodedBody)
	}

	// Empty body: describe the failure from the status code alone.
	if strings.TrimSpace(string(decodedBody)) == "" {
		var message string
		switch statusCode {
		case 401:
			message = "authentication failed: unauthorized (401) - check your API key"
		case 403:
			message = "access forbidden (403) - your API key may not have permission for this operation"
		case 404:
			message = "resource not found (404)"
		case 429:
			message = "rate limit exceeded (429)"
		case 500, 502, 503, 504:
			message = fmt.Sprintf("provider server error (%d)", statusCode)
		default:
			message = fmt.Sprintf("%s (HTTP %d)", schemas.ErrProviderResponseEmpty, statusCode)
		}
		return newError(&schemas.ErrorField{Message: message}, rawResponse)
	}

	// Structured parse into the caller's error type; the caller fills in the
	// ErrorField from errorResp.
	if sonic.Unmarshal(decodedBody, errorResp) == nil {
		return newError(&schemas.ErrorField{}, rawResponse)
	}

	// Not JSON: only now pay for HTML detection.
	if IsHTMLResponse(resp, decodedBody) {
		return newError(&schemas.ErrorField{
			Message: schemas.ErrProviderResponseHTML,
			Error:   errors.New(string(decodedBody)),
		}, rawResponse)
	}

	// Neither JSON nor HTML: echo the body in the message.
	return newError(&schemas.ErrorField{
		Message: fmt.Sprintf("provider API error: %s", string(decodedBody)),
	}, rawResponse)
}
// EnrichError sets or clears the raw request/response on bifrostErr.ExtraFields
// and returns the same error pointer. A nil bifrostErr is returned as nil.
//
// RawRequest is set to the compacted requestBody when raw-request capture is
// enabled (see ShouldSendBackRawRequest) and requestBody is non-empty; in every
// other case it is cleared. RawResponse is cleared when raw-response capture is
// disabled; when enabled it is replaced by the compacted responseBody only if
// responseBody is non-empty, otherwise the existing value is left in place.
func EnrichError(
	ctx *schemas.BifrostContext,
	bifrostErr *schemas.BifrostError,
	requestBody []byte,
	responseBody []byte,
	sendBackRawRequest bool,
	sendBackRawResponse bool,
) *schemas.BifrostError {
	if bifrostErr == nil {
		return nil
	}

	if !ShouldSendBackRawRequest(ctx, sendBackRawRequest) || len(requestBody) == 0 {
		bifrostErr.ExtraFields.RawRequest = nil
	} else {
		// Compacted raw JSON keeps the exact key order and drops whitespace
		// that would break SSE framing.
		bifrostErr.ExtraFields.RawRequest = compactRawJSON(requestBody)
	}

	switch {
	case !ShouldSendBackRawResponse(ctx, sendBackRawResponse):
		bifrostErr.ExtraFields.RawResponse = nil
	case len(responseBody) > 0:
		bifrostErr.ExtraFields.RawResponse = compactRawJSON(responseBody)
	}

	return bifrostErr
}
// HandleProviderResponse parses responseBody into response and optionally
// returns compacted raw copies of the request and response bodies.
//
// Returns:
//   - rawRequest: compacted requestBody when sendBackRawRequest is true and
//     requestBody is non-nil (it is nil for e.g. GET requests); otherwise nil.
//   - rawResponse: compacted responseBody when sendBackRawResponse is true;
//     otherwise nil.
//   - bifrostErr: non-nil when the body is empty/whitespace, when it fails to
//     unmarshal, or when it fails to unmarshal and looks like HTML. Both raw
//     values are nil whenever an error is returned.
//
// HTML detection runs only after JSON parsing has failed.
func HandleProviderResponse[T any](responseBody []byte, response *T, requestBody []byte, sendBackRawRequest bool, sendBackRawResponse bool) (rawRequest interface{}, rawResponse interface{}, bifrostErr *schemas.BifrostError) {
	if strings.TrimSpace(string(responseBody)) == "" {
		return nil, nil, &schemas.BifrostError{
			IsBifrostError: true,
			Error: &schemas.ErrorField{
				Message: schemas.ErrProviderResponseEmpty,
			},
		}
	}

	// Raw bodies are kept as compacted raw JSON so key ordering is preserved
	// and no literal newlines can break SSE framing.
	if sendBackRawRequest && requestBody != nil {
		rawRequest = compactRawJSON(requestBody)
	}
	if sendBackRawResponse {
		rawResponse = compactRawJSON(responseBody)
	}

	if parseErr := sonic.Unmarshal(responseBody, response); parseErr != nil {
		// Not valid JSON for the target type; distinguish HTML pages from
		// ordinary unmarshal failures.
		if IsHTMLResponse(nil, responseBody) {
			return nil, nil, &schemas.BifrostError{
				IsBifrostError: false,
				Error: &schemas.ErrorField{
					Message: schemas.ErrProviderResponseHTML,
					Error:   errors.New(string(responseBody)),
				},
			}
		}
		return nil, nil, &schemas.BifrostError{
			IsBifrostError: true,
			Error: &schemas.ErrorField{
				Message: schemas.ErrProviderResponseUnmarshal,
				Error:   parseErr,
			},
		}
	}

	// Both values are still nil when neither capture flag applied.
	return rawRequest, rawResponse, nil
}
// compactRawJSON returns data as a json.RawMessage after passing it through
// schemas.Compact. When compaction reports an error (e.g. invalid JSON) the
// original bytes are returned unchanged instead.
func compactRawJSON(data []byte) json.RawMessage {
	var compacted bytes.Buffer
	if schemas.Compact(&compacted, data) != nil {
		return json.RawMessage(data)
	}
	return json.RawMessage(compacted.Bytes())
}
// ParseAndSetRawRequest records jsonBody on extraFields.RawRequest as compacted
// raw JSON, which keeps the exact key ordering and avoids literal newlines that
// would break SSE data-line framing. An empty jsonBody leaves extraFields
// untouched.
func ParseAndSetRawRequest(extraFields *schemas.BifrostResponseExtraFields, jsonBody []byte) {
	if len(jsonBody) == 0 {
		return
	}
	extraFields.RawRequest = compactRawJSON(jsonBody)
}
// ParseAndSetRawRequestIfJSON resets extraFields.RawRequest and then, only when
// the request's Content-Type contains "application/json" (case-insensitive),
// stores a compacted copy of the request body there.
func ParseAndSetRawRequestIfJSON(fasthttpReq *fasthttp.Request, extraFields *schemas.BifrostResponseExtraFields) {
	extraFields.RawRequest = nil

	contentType := strings.ToLower(string(fasthttpReq.Header.ContentType()))
	if !strings.Contains(contentType, "application/json") {
		return
	}

	// Work on a private copy rather than the slice returned by the request.
	source := fasthttpReq.Body()
	bodyCopy := make([]byte, len(source))
	copy(bodyCopy, source)
	ParseAndSetRawRequest(extraFields, bodyCopy)
}
// NewUnsupportedOperationError builds the standard error returned when a
// provider does not support a request type. The error is not marked as a
// Bifrost error and carries the code "unsupported_operation".
func NewUnsupportedOperationError(requestType schemas.RequestType, providerName schemas.ModelProvider) *schemas.BifrostError {
	message := fmt.Sprintf("%s is not supported by %s provider", requestType, providerName)
	details := &schemas.ErrorField{
		Message: message,
		Code:    schemas.Ptr("unsupported_operation"),
	}
	return &schemas.BifrostError{
		IsBifrostError: false,
		Error:          details,
	}
}
// CheckOperationAllowed enforces per-operation gating for a custom provider.
//
// Behavior:
//   - No gating configured (config == nil or config.AllowedRequests == nil):
//     the operation is allowed and nil is returned.
//   - Gating configured and the operation explicitly allowed: nil is returned.
//   - Gating configured and the operation not allowed: an unsupported-operation
//     error naming the resolved provider is returned.
func CheckOperationAllowed(defaultProvider schemas.ModelProvider, config *schemas.CustomProviderConfig, operation schemas.RequestType) *schemas.BifrostError {
	gated := config != nil && config.AllowedRequests != nil
	if !gated || config.IsOperationAllowed(operation) {
		return nil
	}
	return NewUnsupportedOperationError(operation, GetProviderName(defaultProvider, config))
}
// CheckAndDecodeBody returns the response body, gunzipped when the
// Content-Encoding header mentions gzip (case-insensitive substring match).
//
// For non-gzip responses a copy of the body is returned so the result does not
// alias the response's own buffer. For gzip responses an empty body yields
// (nil, nil); otherwise the body is decompressed through a gzip reader obtained
// from AcquireGzipReader and released again before returning. Any acquisition
// or read error is returned with a nil slice.
func CheckAndDecodeBody(resp *fasthttp.Response) ([]byte, error) {
	encoding := strings.ToLower(strings.TrimSpace(string(resp.Header.Peek("Content-Encoding"))))

	if !strings.Contains(encoding, "gzip") {
		// Plain body: hand back a private copy.
		raw := resp.Body()
		out := make([]byte, len(raw))
		copy(out, raw)
		return out, nil
	}

	compressed := resp.Body()
	if len(compressed) == 0 {
		return nil, nil
	}

	gz, err := AcquireGzipReader(bytes.NewReader(compressed))
	if err != nil {
		return nil, err
	}
	defer ReleaseGzipReader(gz)

	plain, err := io.ReadAll(gz)
	if err != nil {
		// Discard any partial output on a failed read.
		return nil, err
	}
	return plain, nil
}
// IsHTMLResponse reports whether a response looks like HTML.
//
// When resp is non-nil and its Content-Type contains "text/html"
// (case-insensitive) the answer is true immediately. Otherwise the body is
// inspected: bodies shorter than 20 bytes are treated as not HTML, and longer
// bodies are HTML if their lower-cased text contains any common HTML marker.
func IsHTMLResponse(resp *fasthttp.Response, body []byte) bool {
	// The header is the most reliable signal, so it is consulted first.
	if resp != nil && strings.Contains(strings.ToLower(string(resp.Header.Peek("Content-Type"))), "text/html") {
		return true
	}

	// Too short to plausibly be an HTML document.
	if len(body) < 20 {
		return false
	}

	lowered := strings.ToLower(string(body))
	markers := []string{
		"<!doctype html",
		"<html",
		"<head",
		"<body",
		"<title>",
		"<h1>",
		"<h2>",
		"<h3>",
		"<p>",
		"<div",
	}
	return slices.ContainsFunc(markers, func(marker string) bool {
		return strings.Contains(lowered, marker)
	})
}
// maxBodySize caps how many bytes of an HTML body ExtractHTMLErrorMessage
// inspects, bounding the regex work done on very large or malicious responses.
const maxBodySize = 32 * 1024 // 32KB

// Patterns used by ExtractHTMLErrorMessage. They are compiled once at package
// initialisation instead of on every call.
var (
	// htmlErrTitleRe captures the content of the first <title> element. The
	// "s" flag lets the content span several lines.
	htmlErrTitleRe = regexp.MustCompile(`(?is)<title(?:\s[^>]*)?>(.*?)</title>`)
	// htmlErrHeadingRes capture the text of the first h1, h2 and h3 element,
	// in that priority order.
	htmlErrHeadingRes = []*regexp.Regexp{
		regexp.MustCompile(`(?i)<h1[^>]*>([^<]+)</h1>`),
		regexp.MustCompile(`(?i)<h2[^>]*>([^<]+)</h2>`),
		regexp.MustCompile(`(?i)<h3[^>]*>([^<]+)</h3>`),
	}
	// htmlErrMetaDescriptionRe captures the content of a description meta tag.
	htmlErrMetaDescriptionRe = regexp.MustCompile(`(?i)<meta\s+name="description"\s+content="([^"]+)"`)
	// htmlErrScriptStyleRe matches whole script/style elements including
	// multi-line bodies (hence the "s" flag).
	htmlErrScriptStyleRe = regexp.MustCompile(`(?is)<script[^>]*>.*?</script>|<style[^>]*>.*?</style>`)
	// htmlErrTagRe matches any remaining tag.
	htmlErrTagRe = regexp.MustCompile(`<[^>]+>`)
)

// ExtractHTMLErrorMessage extracts a human-readable error message from an HTML
// response body. Only the first maxBodySize bytes are inspected.
//
// Sources are tried in this order and the first non-empty hit wins:
//  1. the <title> element (ignored when empty or exactly "Error");
//  2. the first <h1>, then <h2>, then <h3> element;
//  3. the description <meta> tag;
//  4. the first line of visible text (script/style elements and tags removed)
//     that is longer than 10 and shorter than 500 bytes, truncated to at most
//     200 bytes plus "..." without splitting a UTF-8 character.
//
// When nothing matches a generic message is returned.
// UNUSED for now but could be useful in the future
func ExtractHTMLErrorMessage(body []byte) string {
	if len(body) > maxBodySize {
		body = body[:maxBodySize]
	}
	bodyStr := string(body)

	// Title first. Matching is done case-insensitively on the original text:
	// indices taken from a lower-cased copy are not valid for the original,
	// because lower-casing can change the byte length of non-ASCII input.
	if m := htmlErrTitleRe.FindStringSubmatch(bodyStr); len(m) > 1 {
		if title := strings.TrimSpace(m[1]); title != "" && title != "Error" {
			return title
		}
	}

	// Headings are common on error pages.
	for _, re := range htmlErrHeadingRes {
		if m := re.FindStringSubmatch(bodyStr); len(m) > 1 {
			if msg := strings.TrimSpace(m[1]); msg != "" {
				return msg
			}
		}
	}

	// Meta description.
	if m := htmlErrMetaDescriptionRe.FindStringSubmatch(bodyStr); len(m) > 1 {
		if msg := strings.TrimSpace(m[1]); msg != "" {
			return msg
		}
	}

	// Visible text: drop script/style elements with their content, then
	// replace every remaining tag with a space.
	cleaned := htmlErrScriptStyleRe.ReplaceAllString(bodyStr, "")
	cleaned = htmlErrTagRe.ReplaceAllString(cleaned, " ")

	lines := strings.FieldsFunc(cleaned, func(r rune) bool {
		return r == '\n' || r == '\r'
	})
	for _, line := range lines {
		trimmed := strings.TrimSpace(line)
		if len(trimmed) > 10 && len(trimmed) < 500 {
			// Limit to roughly the first 200 bytes to avoid very long messages.
			if len(trimmed) > 200 {
				cut := 200
				// Back up over UTF-8 continuation bytes (10xxxxxx) so the cut
				// never lands in the middle of a multi-byte character.
				for cut > 0 && trimmed[cut]&0xC0 == 0x80 {
					cut--
				}
				trimmed = trimmed[:cut] + "..."
			}
			return trimmed
		}
	}

	// If all else fails, return a generic message.
	return "HTML error response received from provider"
}
// JSONLParseResult is the outcome of ParseJSONL. It carries only the line-level
// errors collected while parsing; successfully parsed items are delivered to
// the caller through the parseLine callback, not stored here.
type JSONLParseResult struct {
// Errors lists one entry per line whose callback returned an error.
Errors []schemas.BatchError
}
// ParseJSONL walks data one '\n'-separated line at a time and hands each
// non-empty line to parseLine.
//
// A single trailing '\r' is stripped from each line (Windows line endings) and
// lines that are empty afterwards are skipped. Line numbers are 1-based and
// count every line, including skipped ones. When parseLine returns an error the
// line is not dropped silently: a "parse_error" entry with the error text and
// line number is appended to the result. Lines are passed as sub-slices of
// data, without string conversion.
func ParseJSONL(data []byte, parseLine func(line []byte) error) JSONLParseResult {
	var result JSONLParseResult

	remaining := data
	for lineNum := 1; ; lineNum++ {
		// Slice off the next line; without a newline the rest is the last line.
		line := remaining
		newline := bytes.IndexByte(remaining, '\n')
		if newline >= 0 {
			line = remaining[:newline]
		}

		line = bytes.TrimSuffix(line, []byte{'\r'})
		if len(line) > 0 {
			if err := parseLine(line); err != nil {
				// Take a copy so each error owns its own line number.
				failedLine := lineNum
				result.Errors = append(result.Errors, schemas.BatchError{
					Code:    "parse_error",
					Message: err.Error(),
					Line:    &failedLine,
				})
			}
		}

		if newline < 0 {
			break
		}
		remaining = remaining[newline+1:]
	}
	return result
}
// NewConfigurationError builds the standard error for a configuration problem.
// The result carries only the given message and is not marked as a Bifrost
// error.
func NewConfigurationError(message string) *schemas.BifrostError {
	details := &schemas.ErrorField{Message: message}
	return &schemas.BifrostError{
		IsBifrostError: false,
		Error:          details,
	}
}
// NewBifrostOperationError builds the standard error for a failure inside
// Bifrost itself. The result is marked as a Bifrost error and carries the
// message together with the underlying err (which may be nil).
func NewBifrostOperationError(message string, err error) *schemas.BifrostError {
	details := &schemas.ErrorField{
		Message: message,
		Error:   err,
	}
	return &schemas.BifrostError{
		IsBifrostError: true,
		Error:          details,
	}
}
// NewBifrostTimeoutError builds the standard error for a provider request that
// timed out. The result is marked as a Bifrost error, has StatusCode 504
// (Gateway Timeout) and Error.Type set to schemas.RequestTimedOut, and carries
// the message together with the underlying err.
func NewBifrostTimeoutError(message string, err error) *schemas.BifrostError {
	details := &schemas.ErrorField{
		Message: message,
		Type:    schemas.Ptr(schemas.RequestTimedOut),
		Error:   err,
	}
	return &schemas.BifrostError{
		IsBifrostError: true,
		StatusCode:     schemas.Ptr(504),
		Error:          details,
	}
}
// NewProviderAPIError builds the standard error for a failure reported by a
// provider API. The result is not marked as a Bifrost error; errorType is
// stored both on the top-level error and on its ErrorField, and statusCode and
// eventID are attached as given.
func NewProviderAPIError(message string, err error, statusCode int, errorType *string, eventID *string) *schemas.BifrostError {
	details := &schemas.ErrorField{
		Message: message,
		Error:   err,
		Type:    errorType,
	}
	return &schemas.BifrostError{
		IsBifrostError: false,
		StatusCode:     &statusCode,
		Type:           errorType,
		EventID:        eventID,
		Error:          details,
	}
}
// ShouldSendBackRawRequest reports whether raw request bytes should be
// captured. A boolean stored on ctx under
// schemas.BifrostContextKeyCaptureRawRequest takes precedence;
// defaultSendBackRawRequest is used only when no boolean is present there
// (e.g. for callers outside the normal dispatch path).
func ShouldSendBackRawRequest(ctx context.Context, defaultSendBackRawRequest bool) bool {
	capture, found := ctx.Value(schemas.BifrostContextKeyCaptureRawRequest).(bool)
	if !found {
		return defaultSendBackRawRequest
	}
	return capture
}
// ShouldSendBackRawResponse reports whether raw response bytes should be
// captured. A boolean stored on ctx under
// schemas.BifrostContextKeyCaptureRawResponse takes precedence;
// defaultSendBackRawResponse is used only when no boolean is present there
// (e.g. for callers outside the normal dispatch path).
func ShouldSendBackRawResponse(ctx context.Context, defaultSendBackRawResponse bool) bool {
	capture, found := ctx.Value(schemas.BifrostContextKeyCaptureRawResponse).(bool)
	if !found {
		return defaultSendBackRawResponse
	}
	return capture
}
// SendCreatedEventResponsesChunk emits the "created" event of a Responses
// stream: sequence number 0, chunk index 0, an empty response payload and the
// latency in milliseconds measured from startTime. The chunk is run through the
// post hooks and sent on responseChan via ProcessAndSendResponse.
func SendCreatedEventResponsesChunk(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, startTime time.Time, responseChan chan *schemas.BifrostStreamChunk, postHookSpanFinalizer func(context.Context)) {
	created := &schemas.BifrostResponsesStreamResponse{
		Type:           schemas.ResponsesStreamResponseTypeCreated,
		SequenceNumber: 0,
		Response:       &schemas.BifrostResponsesResponse{},
	}
	created.ExtraFields.ChunkIndex = 0
	created.ExtraFields.Latency = time.Since(startTime).Milliseconds()

	// TODO add bifrost response pooling here
	ProcessAndSendResponse(
		ctx,
		postHookRunner,
		&schemas.BifrostResponse{ResponsesStreamResponse: created},
		responseChan,
		postHookSpanFinalizer,
	)
}
// SendInProgressEventResponsesChunk emits the "in progress" event of a
// Responses stream: sequence number 1, chunk index 1, an empty response payload
// and the latency in milliseconds measured from startTime. The chunk is run
// through the post hooks and sent on responseChan via ProcessAndSendResponse.
func SendInProgressEventResponsesChunk(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, startTime time.Time, responseChan chan *schemas.BifrostStreamChunk, postHookSpanFinalizer func(context.Context)) {
	inProgress := &schemas.BifrostResponsesStreamResponse{
		Type:           schemas.ResponsesStreamResponseTypeInProgress,
		SequenceNumber: 1,
		Response:       &schemas.BifrostResponsesResponse{},
	}
	inProgress.ExtraFields.ChunkIndex = 1
	inProgress.ExtraFields.Latency = time.Since(startTime).Milliseconds()

	// TODO add bifrost response pooling here
	ProcessAndSendResponse(
		ctx,
		postHookRunner,
		&schemas.BifrostResponse{ResponsesStreamResponse: inProgress},
		responseChan,
		postHookSpanFinalizer,
	)
}
// BuildClientStreamChunk assembles the client-facing BifrostStreamChunk from
// the post-hook results.
//
// The two context flags BifrostContextKeyDropRawRequestFromClient and
// BifrostContextKeyDropRawResponseFromClient decide whether RawRequest and/or
// RawResponse are withheld from the client. When neither is set the chunk
// simply points at the given response parts and error. When at least one is
// set, every non-nil response part (and the error) is shallow-copied and the
// relevant raw field is cleared on the copy only, so processedResponse and
// processedError are never mutated and stay safe for other holders of those
// pointers.
func BuildClientStreamChunk(ctx context.Context, processedResponse *schemas.BifrostResponse, processedError *schemas.BifrostError) *schemas.BifrostStreamChunk {
	dropReq, _ := ctx.Value(schemas.BifrostContextKeyDropRawRequestFromClient).(bool)
	dropResp, _ := ctx.Value(schemas.BifrostContextKeyDropRawResponseFromClient).(bool)

	chunk := &schemas.BifrostStreamChunk{}

	// Fast path: nothing to strip, share the originals.
	if !dropReq && !dropResp {
		if processedResponse != nil {
			chunk.BifrostTextCompletionResponse = processedResponse.TextCompletionResponse
			chunk.BifrostChatResponse = processedResponse.ChatResponse
			chunk.BifrostResponsesStreamResponse = processedResponse.ResponsesStreamResponse
			chunk.BifrostSpeechStreamResponse = processedResponse.SpeechStreamResponse
			chunk.BifrostTranscriptionStreamResponse = processedResponse.TranscriptionStreamResponse
			chunk.BifrostImageGenerationStreamResponse = processedResponse.ImageGenerationStreamResponse
		}
		if processedError != nil {
			chunk.BifrostError = processedError
		}
		return chunk
	}

	// Stripping path: each non-nil part is copied before its raw fields are
	// cleared, leaving the shared originals intact.
	if processedResponse != nil {
		if src := processedResponse.TextCompletionResponse; src != nil {
			clone := *src
			if dropReq {
				clone.ExtraFields.RawRequest = nil
			}
			if dropResp {
				clone.ExtraFields.RawResponse = nil
			}
			chunk.BifrostTextCompletionResponse = &clone
		}
		if src := processedResponse.ChatResponse; src != nil {
			clone := *src
			if dropReq {
				clone.ExtraFields.RawRequest = nil
			}
			if dropResp {
				clone.ExtraFields.RawResponse = nil
			}
			chunk.BifrostChatResponse = &clone
		}
		if src := processedResponse.ResponsesStreamResponse; src != nil {
			clone := *src
			if dropReq {
				clone.ExtraFields.RawRequest = nil
			}
			if dropResp {
				clone.ExtraFields.RawResponse = nil
			}
			chunk.BifrostResponsesStreamResponse = &clone
		}
		if src := processedResponse.SpeechStreamResponse; src != nil {
			clone := *src
			if dropReq {
				clone.ExtraFields.RawRequest = nil
			}
			if dropResp {
				clone.ExtraFields.RawResponse = nil
			}
			chunk.BifrostSpeechStreamResponse = &clone
		}
		if src := processedResponse.TranscriptionStreamResponse; src != nil {
			clone := *src
			if dropReq {
				clone.ExtraFields.RawRequest = nil
			}
			if dropResp {
				clone.ExtraFields.RawResponse = nil
			}
			chunk.BifrostTranscriptionStreamResponse = &clone
		}
		if src := processedResponse.ImageGenerationStreamResponse; src != nil {
			clone := *src
			if dropReq {
				clone.ExtraFields.RawRequest = nil
			}
			if dropResp {
				clone.ExtraFields.RawResponse = nil
			}
			chunk.BifrostImageGenerationStreamResponse = &clone
		}
	}

	if processedError != nil {
		// Same treatment for the error: strip on a copy, never on the original.
		errClone := *processedError
		if dropReq {
			errClone.ExtraFields.RawRequest = nil
		}
		if dropResp {
			errClone.ExtraFields.RawResponse = nil
		}
		chunk.BifrostError = &errClone
	}
	return chunk
}
// ProcessAndSendResponse runs one streaming response chunk through the shared
// pipeline used by streaming providers:
//
//  1. If ctx carries a tracer and a non-empty trace ID, the (pre-hook) chunk is
//     handed to the tracer for accumulation.
//  2. The post hooks are run on the chunk.
//  3. If the hooks signal a stream-control skip, nothing is sent.
//  4. Otherwise the client-facing chunk is built and sent on responseChan,
//     unless ctx is done first, in which case the function returns at once.
//
// After a skip or a successful send, the deferred LLM span is completed with
// the post-processed data when ctx marks this chunk as the final one
// (BifrostContextKeyStreamEndIndicator is the bool true).
func ProcessAndSendResponse(
	ctx *schemas.BifrostContext,
	postHookRunner schemas.PostHookRunner,
	response *schemas.BifrostResponse,
	responseChan chan *schemas.BifrostStreamChunk,
	postHookSpanFinalizer func(context.Context),
) {
	// Accumulate chunk for tracing (common for all providers).
	if tracer, ok := ctx.Value(schemas.BifrostContextKeyTracer).(schemas.Tracer); ok && tracer != nil {
		traceID, _ := ctx.Value(schemas.BifrostContextKeyTraceID).(string)
		if traceID != "" {
			tracer.AddStreamingChunk(traceID, response)
		}
	}

	// Post hooks see the chunk after tracing, so accumulated chunks hold
	// pre-hook data.
	processedResponse, processedError := postHookRunner(ctx, response, nil)

	// finishIfFinal completes the deferred span when this is the last chunk.
	finishIfFinal := func() {
		if final, ok := ctx.Value(schemas.BifrostContextKeyStreamEndIndicator).(bool); ok && final {
			completeDeferredSpan(ctx, processedResponse, processedError, postHookSpanFinalizer)
		}
	}

	if HandleStreamControlSkip(processedError) {
		// Skipped chunks are not sent, but a final chunk still closes the span.
		finishIfFinal()
		return
	}

	select {
	case responseChan <- BuildClientStreamChunk(ctx, processedResponse, processedError):
	case <-ctx.Done():
		return
	}

	finishIfFinal()
}
// ProcessAndSendBifrostError runs a streaming error through the shared pipeline
// used by streaming providers:
//
//  1. The post hooks are run on the error, so the span reflects post-processed
//     data.
//  2. If the hooks signal a stream-control skip, nothing is sent.
//  3. Otherwise the client-facing chunk is built and sent on responseChan, or
//     the send is abandoned once ctx is done.
//
// In every case the deferred LLM span is completed afterwards when ctx marks
// this chunk as the final one (BifrostContextKeyStreamEndIndicator is the bool
// true). The logger parameter is not used by this function.
func ProcessAndSendBifrostError(
	ctx *schemas.BifrostContext,
	postHookRunner schemas.PostHookRunner,
	bifrostErr *schemas.BifrostError,
	responseChan chan *schemas.BifrostStreamChunk,
	logger schemas.Logger,
	postHookSpanFinalizer func(context.Context),
) {
	processedResponse, processedError := postHookRunner(ctx, nil, bifrostErr)

	// finishIfFinal completes the deferred span when this is the last chunk.
	finishIfFinal := func() {
		if final, ok := ctx.Value(schemas.BifrostContextKeyStreamEndIndicator).(bool); ok && final {
			completeDeferredSpan(ctx, processedResponse, processedError, postHookSpanFinalizer)
		}
	}

	if HandleStreamControlSkip(processedError) {
		// Skipped errors are not sent, but a final chunk still closes the span.
		finishIfFinal()
		return
	}

	// Unlike the response path, cancellation here does not return early: the
	// span check below still runs.
	select {
	case responseChan <- BuildClientStreamChunk(ctx, processedResponse, processedError):
	case <-ctx.Done():
	}

	finishIfFinal()
}
// EnsureStreamFinalizerCalled invokes finalizer with ctx, doing nothing when
// finalizer is nil. It is meant to be deferred in a provider's streaming
// goroutine as a last line of defence:
//
//	defer providerUtils.EnsureStreamFinalizerCalled(ctx, postHookSpanFinalizer)
//
// On a normal stream end the finalizer has already been invoked through
// completeDeferredSpan when the final chunk was processed, so this call is
// only useful when the goroutine exits without reaching that path (for
// example after a panic mid-stream).
// NOTE(review): this function calls finalizer unconditionally; it relies on
// the caller supplying a finalizer that is safe to invoke twice (e.g. wrapped
// in sync.Once at registration) — confirm at the registration site.
//
// A panic raised by the finalizer is recovered and logged at debug level so it
// cannot mask a panic that is already in flight.
func EnsureStreamFinalizerCalled(ctx context.Context, finalizer func(context.Context)) {
	if finalizer == nil {
		return
	}
	defer func() {
		if recovered := recover(); recovered != nil {
			getLogger().Debug("recovered panic in deferred stream finalizer: %v", recovered)
		}
	}()
	finalizer(ctx)
}
// SetupStreamCancellation starts a watcher goroutine that closes bodyStream as
// soon as ctx is done, which unblocks a Read/Scan that is stuck on the stream.
//
// The returned cleanup function MUST be called once streaming is finished. It
// signals the watcher to stop and then waits for it to exit, so any close the
// watcher performs has finished before the caller goes on to drain the
// response. If ctx happens to be done at the moment cleanup is called, the
// watcher still closes the stream before exiting.
//
// bodyStream is only closed when it implements io.Closer; this covers both
// fasthttp's BodyStream() and net/http's resp.Body. Close errors are logged at
// debug level through the package logger. The logger parameter is not used.
func SetupStreamCancellation(ctx context.Context, bodyStream io.Reader, logger schemas.Logger) (cleanup func()) {
	stop := make(chan struct{})
	exited := make(chan struct{})

	// closeBody closes the stream when it is closable and reports the result.
	closeBody := func() error {
		if closer, ok := bodyStream.(io.Closer); ok {
			return closer.Close()
		}
		return nil
	}

	go func() {
		defer close(exited)
		select {
		case <-ctx.Done():
			// Context cancelled or deadline exceeded: unblock pending reads.
			if err := closeBody(); err != nil {
				getLogger().Debug(fmt.Sprintf("Error closing body stream on context done: %v", err))
			}
		case <-stop:
			// Both channels may have been ready; honour a cancelled context
			// even though the stop signal won the select.
			if ctx.Err() == nil {
				return
			}
			if err := closeBody(); err != nil {
				getLogger().Debug(fmt.Sprintf("Error closing body stream on done with cancelled context: %v", err))
			}
		}
	}()

	return func() {
		close(stop)
		<-exited // wait for the watcher so its close cannot race the caller's drain
	}
}
// DefaultStreamIdleTimeout is the fallback idle window (60 seconds) for
// streaming reads: how long a read may go without receiving any data before
// bifrost treats the connection as stalled and closes it. It guards against
// providers that stop sending data but keep the TCP connection open
// (e.g., Azure TPM throttling).
//
// GetStreamIdleTimeout returns it when the context holds no positive timeout,
// and NewIdleTimeoutReader substitutes it for a non-positive timeout argument.
const DefaultStreamIdleTimeout = 60 * time.Second
// SetStreamIdleTimeoutIfEmpty stores the provider-configured stream idle
// timeout on ctx, acting purely as a fallback.
//
// A positive time.Duration already stored under
// BifrostContextKeyStreamIdleTimeout (set by an upstream layer such as the
// transport or a header) is left untouched. Otherwise configSeconds is
// converted to a duration and stored, provided it is greater than zero; a
// non-positive configSeconds leaves ctx unchanged.
func SetStreamIdleTimeoutIfEmpty(ctx *schemas.BifrostContext, configSeconds int) {
	current, isDuration := ctx.Value(schemas.BifrostContextKeyStreamIdleTimeout).(time.Duration)
	switch {
	case isDuration && current > 0:
		// An upstream value wins over provider configuration.
		return
	case configSeconds <= 0:
		// Nothing usable to fall back to.
		return
	}
	ctx.SetValue(schemas.BifrostContextKeyStreamIdleTimeout, time.Duration(configSeconds)*time.Second)
}
// GetStreamIdleTimeout returns the idle timeout stored on ctx under
// BifrostContextKeyStreamIdleTimeout. When the stored value is missing, is not
// a time.Duration, or is not positive, DefaultStreamIdleTimeout is returned.
func GetStreamIdleTimeout(ctx *schemas.BifrostContext) time.Duration {
	configured, isDuration := ctx.Value(schemas.BifrostContextKeyStreamIdleTimeout).(time.Duration)
	if !isDuration || configured <= 0 {
		return DefaultStreamIdleTimeout
	}
	return configured
}
// idleTimeoutReader is an io.Reader decorator that closes the underlying body
// stream when no data has arrived for a configured duration, which in turn
// unblocks a Read that is pending on the wrapped reader.
type idleTimeoutReader struct {
	reader     io.Reader     // source that Read delegates to
	bodyStream io.Reader     // closed on timeout if it also implements io.Closer
	timeout    time.Duration // idle window; the timer is re-armed with this value
	timer      *time.Timer   // fires the idle close
	once       sync.Once     // guarantees bodyStream is closed at most once by the timer
}

// NewIdleTimeoutReader wraps reader with idle detection. A timer is started
// immediately; each Read that yields data re-arms it. If the timer fires,
// bodyStream is closed (only if it implements io.Closer; otherwise the wrapper
// still reads normally but cannot force the stream shut).
//
// A non-positive timeout is replaced with DefaultStreamIdleTimeout.
//
// It returns the wrapped reader and a cleanup function that stops the timer.
// The cleanup MUST be called (typically via defer) when streaming is complete
// so the timer does not close the stream afterwards.
func NewIdleTimeoutReader(reader io.Reader, bodyStream io.Reader, timeout time.Duration) (io.Reader, func()) {
	if timeout <= 0 {
		timeout = DefaultStreamIdleTimeout
	}
	wrapped := &idleTimeoutReader{
		reader:     reader,
		bodyStream: bodyStream,
		timeout:    timeout,
	}
	wrapped.timer = time.AfterFunc(timeout, wrapped.closeOnIdle)
	stopTimer := func() { wrapped.timer.Stop() }
	return wrapped, stopTimer
}

// closeOnIdle is the timer callback: it closes bodyStream once, if closable.
func (r *idleTimeoutReader) closeOnIdle() {
	r.once.Do(func() {
		closer, closable := r.bodyStream.(io.Closer)
		if !closable {
			return
		}
		// Best effort: the close exists only to unblock the reader, so its
		// error is intentionally not reported.
		closer.Close()
	})
}

// Read delegates to the wrapped reader and re-arms the idle timer whenever at
// least one byte was returned. Reads that return no data leave the timer as is.
func (r *idleTimeoutReader) Read(p []byte) (int, error) {
	n, err := r.reader.Read(p)
	if n <= 0 {
		return n, err
	}
	r.timer.Reset(r.timeout)
	return n, err
}
// HandleStreamCancellation is the exit path for a streaming goroutine that
// stops because its context was cancelled.
//
// It atomically reads and sets the StreamEndIndicator on ctx to true. If the
// indicator was already true the stream end has been handled elsewhere and the
// function returns without doing anything else. Otherwise it builds a 499
// (Client Closed Request) error of type RequestCancelled and sends it through
// ProcessAndSendBifrostError, so the post-hook chain sees the cancellation.
//
// This matters for the logging plugin, which uses it to move a log entry from
// "processing" to "error" when a client disconnects mid-stream.
func HandleStreamCancellation(
	ctx *schemas.BifrostContext,
	postHookRunner schemas.PostHookRunner,
	responseChan chan *schemas.BifrostStreamChunk,
	logger schemas.Logger,
	postHookSpanFinalizer func(context.Context),
) {
	previous := ctx.GetAndSetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
	if alreadyEnded, isBool := previous.(bool); isBool && alreadyEnded {
		return // stream end was already handled
	}

	cancellation := &schemas.BifrostError{
		StatusCode: schemas.Ptr(499), // Client Closed Request
		Error: &schemas.ErrorField{
			Message: "Request cancelled: client disconnected",
			Type:    schemas.Ptr(schemas.RequestCancelled),
		},
	}

	// Routing the error through the post hooks is what updates the log status.
	ProcessAndSendBifrostError(ctx, postHookRunner, cancellation, responseChan, logger, postHookSpanFinalizer)
}
// HandleStreamTimeout is the exit path for a streaming goroutine that stops
// because its context deadline was exceeded.
//
// The StreamEndIndicator on ctx is read and set to true in one step. A
// previous value of true means the end of the stream was already handled, and
// the function returns immediately. Otherwise a 504 (Gateway Timeout) error of
// type RequestTimedOut is created and passed to ProcessAndSendBifrostError so
// that it travels through the post-hook chain.
//
// This matters for the logging plugin, which uses it to move a log entry from
// "processing" to "error" when a request times out mid-stream.
func HandleStreamTimeout(
	ctx *schemas.BifrostContext,
	postHookRunner schemas.PostHookRunner,
	responseChan chan *schemas.BifrostStreamChunk,
	logger schemas.Logger,
	postHookSpanFinalizer func(context.Context),
) {
	wasFinal, isBool := ctx.GetAndSetValue(schemas.BifrostContextKeyStreamEndIndicator, true).(bool)
	if isBool && wasFinal {
		// Someone else already finished this stream; avoid a duplicate error.
		return
	}

	ProcessAndSendBifrostError(
		ctx,
		postHookRunner,
		&schemas.BifrostError{
			StatusCode: schemas.Ptr(504), // Gateway Timeout
			Error: &schemas.ErrorField{
				Message: "Request timed out: deadline exceeded",
				Type:    schemas.Ptr(schemas.RequestTimedOut),
			},
		},
		responseChan,
		logger,
		postHookSpanFinalizer,
	)
}
// ProcessAndSendError wraps a plain Go error (typically a stream read/scan
// failure) in a BifrostError, runs it through the post-hook chain, and sends
// the outcome to responseChan.
//
// The wrapped error has IsBifrostError set to true and the message
// "Error reading stream: <err>". When the post-hook result requests a skip
// (see HandleStreamControlSkip) nothing is sent; otherwise a stream chunk is
// assembled from the processed response/error and sent, or the send is
// abandoned if ctx is done first.
//
// Like ProcessAndSendResponse and ProcessAndSendBifrostError, the deferred LLM
// span is completed afterwards when the StreamEndIndicator on ctx is true, so a
// stream that terminates with a read error does not leave its span open.
//
// logger is part of the signature but is not used by this function.
// postHookSpanFinalizer is forwarded to completeDeferredSpan and may be nil.
func ProcessAndSendError(
	ctx *schemas.BifrostContext,
	postHookRunner schemas.PostHookRunner,
	err error,
	responseChan chan *schemas.BifrostStreamChunk,
	logger schemas.Logger,
	postHookSpanFinalizer func(context.Context),
) {
	// Wrap the raw error so it can travel through the post-hook chain.
	bifrostError := &schemas.BifrostError{
		IsBifrostError: true,
		Error: &schemas.ErrorField{
			Message: fmt.Sprintf("Error reading stream: %v", err),
			Error:   err,
		},
	}
	processedResponse, processedError := postHookRunner(ctx, nil, bifrostError)

	if !HandleStreamControlSkip(processedError) {
		streamResponse := &schemas.BifrostStreamChunk{}
		if processedResponse != nil {
			streamResponse.BifrostTextCompletionResponse = processedResponse.TextCompletionResponse
			streamResponse.BifrostChatResponse = processedResponse.ChatResponse
			streamResponse.BifrostResponsesStreamResponse = processedResponse.ResponsesStreamResponse
			streamResponse.BifrostSpeechStreamResponse = processedResponse.SpeechStreamResponse
			streamResponse.BifrostTranscriptionStreamResponse = processedResponse.TranscriptionStreamResponse
		}
		if processedError != nil {
			streamResponse.BifrostError = processedError
		}
		select {
		case responseChan <- streamResponse:
		case <-ctx.Done():
		}
	}

	// Close the deferred span when this error ends the stream, whether or not
	// the chunk itself was delivered.
	if final, ok := ctx.Value(schemas.BifrostContextKeyStreamEndIndicator).(bool); ok && final {
		completeDeferredSpan(ctx, processedResponse, processedError, postHookSpanFinalizer)
	}
}
// CreateBifrostTextCompletionChunkResponse builds a text completion stream
// chunk with a single choice whose content is empty.
//
// The chunk carries the given id, usage and finishReason, has Object set to
// "text_completion", and its ChunkIndex is currentChunkIndex + 1.
// requestType is accepted but not used by this function.
func CreateBifrostTextCompletionChunkResponse(
	id string,
	usage *schemas.BifrostLLMUsage,
	finishReason *string,
	currentChunkIndex int,
	requestType schemas.RequestType,
) *schemas.BifrostTextCompletionResponse {
	closingChoice := schemas.BifrostResponseChoice{
		FinishReason:                 finishReason,
		TextCompletionResponseChoice: &schemas.TextCompletionResponseChoice{}, // empty delta
	}

	chunk := &schemas.BifrostTextCompletionResponse{
		ID:     id,
		Object: "text_completion",
		Usage:  usage,
	}
	chunk.Choices = []schemas.BifrostResponseChoice{closingChoice}
	chunk.ExtraFields.ChunkIndex = currentChunkIndex + 1
	return chunk
}
// CreateBifrostChatCompletionChunkResponse builds a chat completion stream
// chunk with a single choice whose delta is empty.
//
// The chunk carries the given id, model, created timestamp, usage and
// finishReason, has Object set to "chat.completion.chunk", and its ChunkIndex
// is currentChunkIndex + 1.
func CreateBifrostChatCompletionChunkResponse(
	id string,
	usage *schemas.BifrostLLMUsage,
	finishReason *string,
	currentChunkIndex int,
	model string,
	created int,
) *schemas.BifrostChatResponse {
	emptyDelta := &schemas.ChatStreamResponseChoiceDelta{}
	closingChoice := schemas.BifrostResponseChoice{
		FinishReason:             finishReason,
		ChatStreamResponseChoice: &schemas.ChatStreamResponseChoice{Delta: emptyDelta},
	}

	chunk := &schemas.BifrostChatResponse{
		ID:      id,
		Model:   model,
		Created: created,
		Object:  "chat.completion.chunk",
		Usage:   usage,
		Choices: []schemas.BifrostResponseChoice{closingChoice},
	}
	chunk.ExtraFields.ChunkIndex = currentChunkIndex + 1
	return chunk
}
// HandleStreamControlSkip reports whether the given post-hook error asks for
// the current stream chunk to be skipped, i.e. whether its StreamControl has
// SkipStream set to true.
//
// It returns false for a nil error, a nil StreamControl, or a nil/false
// SkipStream. When the chunk is to be skipped and StreamControl.LogError is
// true, the error message is logged as a warning first; an error without an
// Error field is logged as "unknown error" rather than dereferencing nil.
func HandleStreamControlSkip(bifrostErr *schemas.BifrostError) bool {
	if bifrostErr == nil || bifrostErr.StreamControl == nil {
		return false
	}
	control := bifrostErr.StreamControl
	if control.SkipStream == nil || !*control.SkipStream {
		return false
	}
	if control.LogError != nil && *control.LogError {
		// Error is optional on BifrostError (other code in this file checks it
		// for nil), so guard it before reading the message.
		message := "unknown error"
		if bifrostErr.Error != nil {
			message = bifrostErr.Error.Message
		}
		getLogger().Warn("Error in stream: " + message)
	}
	return true
}
// GetProviderName resolves the effective provider name.
//
// When customConfig is non-nil and its CustomProviderKey is non-empty after
// trimming surrounding whitespace, the trimmed key is returned as the provider
// name. In every other case defaultProvider is returned.
// Note: CustomProviderKey is internally set by Bifrost and should always match
// the provider name.
func GetProviderName(defaultProvider schemas.ModelProvider, customConfig *schemas.CustomProviderConfig) schemas.ModelProvider {
	if customConfig == nil {
		return defaultProvider
	}
	customKey := strings.TrimSpace(customConfig.CustomProviderKey)
	if customKey == "" {
		return defaultProvider
	}
	return schemas.ModelProvider(customKey)
}
// ProviderSendsDoneMarker reports whether the provider terminates its streaming
// responses with a [DONE] marker.
//
// Cerebras, Perplexity and HuggingFace do not send [DONE]; they end the stream
// after sending the finish_reason, so false is returned for them. Every other
// provider is assumed to send the marker, which is the safe default.
func ProviderSendsDoneMarker(providerName schemas.ModelProvider) bool {
	endsWithoutMarker := providerName == schemas.Cerebras ||
		providerName == schemas.Perplexity ||
		providerName == schemas.HuggingFace
	return !endsWithoutMarker
}
// ProviderIsResponsesAPINative reports whether providerName is one of the
// providers this package treats as natively supporting the Responses API:
// OpenAI, OpenRouter and Azure. It returns false for every other provider.
func ProviderIsResponsesAPINative(providerName schemas.ModelProvider) bool {
	return providerName == schemas.OpenAI ||
		providerName == schemas.OpenRouter ||
		providerName == schemas.Azure
}
// ReleaseStreamingResponse drains and closes the body stream of a streaming
// fasthttp response and then returns the response to fasthttp's pool.
//
// Draining leftover data before release prevents "whitespace in header" errors
// when the connection is reused
// (see: https://github.com/valyala/fasthttp/issues/1743). Drain and close
// failures are logged as warnings and do not stop the release. A panic raised
// while draining or closing is recovered and logged, and the response is
// released regardless.
func ReleaseStreamingResponse(resp *fasthttp.Response) {
	defer func() {
		if recovered := recover(); recovered != nil {
			getLogger().Error("recovered panic in ReleaseStreamingResponse: %v", recovered)
		}
		// Release unconditionally so a panic above cannot leak the response.
		fasthttp.ReleaseResponse(resp)
	}()

	body := resp.BodyStream()
	if body == nil {
		return
	}

	if _, drainErr := io.Copy(io.Discard, body); drainErr != nil {
		getLogger().Warn("failed to drain streaming response body before release (may cause stale connection reuse): %v", drainErr)
	}

	closer, closable := body.(io.Closer)
	if !closable {
		return
	}
	if closeErr := closer.Close(); closeErr != nil {
		getLogger().Warn("failed to close streaming response body: %v", closeErr)
	}
}
// GetBifrostResponseForStreamResponse wraps a single typed stream response in
// a BifrostResponse.
//
// The arguments are inspected in declaration order (text completion, chat,
// responses, speech, transcription, image generation); the first non-nil one
// is placed in the matching field of a new BifrostResponse and all others are
// ignored. nil is returned when every argument is nil.
func GetBifrostResponseForStreamResponse(
	textCompletionResponse *schemas.BifrostTextCompletionResponse,
	chatResponse *schemas.BifrostChatResponse,
	responsesStreamResponse *schemas.BifrostResponsesStreamResponse,
	speechStreamResponse *schemas.BifrostSpeechStreamResponse,
	transcriptionStreamResponse *schemas.BifrostTranscriptionStreamResponse,
	imageGenerationStreamResponse *schemas.BifrostImageGenerationStreamResponse,
) *schemas.BifrostResponse {
	// TODO add bifrost response pooling here
	if textCompletionResponse != nil {
		return &schemas.BifrostResponse{TextCompletionResponse: textCompletionResponse}
	}
	if chatResponse != nil {
		return &schemas.BifrostResponse{ChatResponse: chatResponse}
	}
	if responsesStreamResponse != nil {
		return &schemas.BifrostResponse{ResponsesStreamResponse: responsesStreamResponse}
	}
	if speechStreamResponse != nil {
		return &schemas.BifrostResponse{SpeechStreamResponse: speechStreamResponse}
	}
	if transcriptionStreamResponse != nil {
		return &schemas.BifrostResponse{TranscriptionStreamResponse: transcriptionStreamResponse}
	}
	if imageGenerationStreamResponse != nil {
		return &schemas.BifrostResponse{ImageGenerationStreamResponse: imageGenerationStreamResponse}
	}
	return nil
}
// aggregateListModelsResponses merges several list-models responses into one.
//
// Models from all non-nil responses are concatenated and deduplicated by model
// ID, keeping the first occurrence of each ID; the merged list is then sorted
// alphabetically by ID. Non-nil raw responses are collected, in input order,
// into a slice stored in ExtraFields.RawResponse (left unset when there are
// none). No other fields of the inputs are carried over.
//
// Deduplication is applied even to a single response. An empty input yields a
// response with an empty, non-nil Data slice.
func aggregateListModelsResponses(responses []*schemas.BifrostListModelsResponse) *schemas.BifrostListModelsResponse {
	merged := &schemas.BifrostListModelsResponse{
		Data: []schemas.Model{},
	}
	if len(responses) == 0 {
		return merged
	}

	known := make(map[string]struct{}) // model IDs already added to merged.Data
	var raws []interface{}
	for _, resp := range responses {
		if resp == nil {
			continue
		}
		for _, model := range resp.Data {
			if _, duplicate := known[model.ID]; duplicate {
				continue
			}
			known[model.ID] = struct{}{}
			merged.Data = append(merged.Data, model)
		}
		if raw := resp.ExtraFields.RawResponse; raw != nil {
			raws = append(raws, raw)
		}
	}

	// IDs are unique at this point, so the resulting order is deterministic.
	sort.Slice(merged.Data, func(a, b int) bool {
		return merged.Data[a].ID < merged.Data[b].ID
	})

	if len(raws) > 0 {
		merged.ExtraFields.RawResponse = raws
	}
	return merged
}
// extractSuccessfulListModelsResponses drains a per-key results channel and
// splits it into successful responses and per-key status records.
//
// The channel is read until it is closed. Each result produces one KeyStatus
// for the given provider: KeyStatusSuccess, or KeyStatusListModelsFailed with
// the error attached. Failures are also logged as warnings, using the error
// message, or the wrapped error's text, or "unknown error" as a last resort.
//
// If at least one key succeeded, the successful responses and all statuses are
// returned with a nil error. Otherwise the error of the last failed key is
// returned, or a generic "all keys failed to list models" error when the
// channel produced no results at all.
func extractSuccessfulListModelsResponses(results chan schemas.ListModelsByKeyResult, provider schemas.ModelProvider) ([]*schemas.BifrostListModelsResponse, []schemas.KeyStatus, *schemas.BifrostError) {
	var (
		succeeded []*schemas.BifrostListModelsResponse
		statuses  []schemas.KeyStatus
		lastErr   *schemas.BifrostError
	)

	for result := range results {
		if result.Err == nil {
			statuses = append(statuses, schemas.KeyStatus{
				KeyID:    result.KeyID,
				Provider: provider,
				Status:   schemas.KeyStatusSuccess,
			})
			succeeded = append(succeeded, result.Response)
			continue
		}

		// Pick the most specific description available for the log line.
		description := "unknown error"
		if field := result.Err.Error; field != nil {
			switch {
			case field.Message != "":
				description = field.Message
			case field.Error != nil:
				description = field.Error.Error()
			}
		}
		getLogger().Warn(fmt.Sprintf("failed to list models with key %s: %s", result.KeyID, description))

		statuses = append(statuses, schemas.KeyStatus{
			KeyID:    result.KeyID,
			Provider: provider,
			Status:   schemas.KeyStatusListModelsFailed,
			Error:    result.Err,
		})
		lastErr = result.Err
	}

	if len(succeeded) > 0 {
		return succeeded, statuses, nil
	}
	if lastErr != nil {
		return nil, statuses, lastErr
	}
	return nil, statuses, &schemas.BifrostError{
		IsBifrostError: false,
		Error: &schemas.ErrorField{
			Message: "all keys failed to list models",
		},
	}
}
// HandleKeylessListModelsRequest runs listFunc for a provider that needs no API
// key and records a single provider-level KeyStatus (with an empty KeyID) on
// the outcome.
//
// On failure the status KeyStatusListModelsFailed, carrying the error, is
// attached to the error's ExtraFields.KeyStatuses and (nil, error) is returned.
// On success the status KeyStatusSuccess is attached to the response's
// KeyStatuses and (response, nil) is returned. If listFunc returns neither a
// response nor an error, (nil, nil) is returned unchanged.
func HandleKeylessListModelsRequest(
	provider schemas.ModelProvider,
	listFunc func() (*schemas.BifrostListModelsResponse, *schemas.BifrostError),
) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) {
	resp, bifrostErr := listFunc()

	switch {
	case bifrostErr != nil:
		bifrostErr.ExtraFields.KeyStatuses = []schemas.KeyStatus{{
			KeyID:    "", // keyless providers have no key ID
			Provider: provider,
			Status:   schemas.KeyStatusListModelsFailed,
			Error:    bifrostErr,
		}}
		return nil, bifrostErr
	case resp != nil:
		resp.KeyStatuses = []schemas.KeyStatus{{
			KeyID:    "", // keyless providers have no key ID
			Provider: provider,
			Status:   schemas.KeyStatusSuccess,
		}}
		return resp, nil
	}

	// listFunc produced nothing at all; pass that through as is.
	return nil, nil
}
// HandleMultipleListModelsRequests lists models for every key concurrently and
// merges the results.
//
// One goroutine per key calls listModelsByKey; the function waits for all of
// them before reading any result. A panic inside listModelsByKey is recovered,
// logged, and reported as a failed result for that key.
//
// If no key succeeds, the error from extractSuccessfulListModelsResponses is
// returned with the per-key statuses attached to its ExtraFields. Otherwise the
// successful responses are aggregated, paginated with the request's PageSize
// and PageToken, annotated with the per-key statuses, and stamped with the
// total elapsed time in milliseconds.
func HandleMultipleListModelsRequests(
	ctx *schemas.BifrostContext,
	keys []schemas.Key,
	request *schemas.BifrostListModelsRequest,
	listModelsByKey func(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError),
) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) {
	began := time.Now()

	// Buffered to len(keys): each worker sends exactly one result, so no send
	// can block even though nothing reads until all workers are done.
	results := make(chan schemas.ListModelsByKeyResult, len(keys))

	var pending sync.WaitGroup
	pending.Add(len(keys))
	for _, key := range keys {
		go func(k schemas.Key) {
			defer pending.Done()
			// Should never panic, but turn a panic into a failed result
			// instead of crashing the process.
			defer func() {
				recovered := recover()
				if recovered == nil {
					return
				}
				getLogger().Error("panic in listModelsByKey for key %s (%s): %v", k.Name, k.ID, recovered)
				results <- schemas.ListModelsByKeyResult{
					Err: &schemas.BifrostError{
						IsBifrostError: true,
						Error: &schemas.ErrorField{
							Message: "internal error while listing models for key",
						},
					},
					KeyID: k.ID,
				}
			}()
			resp, bifrostErr := listModelsByKey(ctx, k, request)
			results <- schemas.ListModelsByKeyResult{Response: resp, Err: bifrostErr, KeyID: k.ID}
		}(key)
	}
	pending.Wait()
	close(results)

	succeeded, keyStatuses, bifrostErr := extractSuccessfulListModelsResponses(results, request.Provider)
	if bifrostErr != nil {
		bifrostErr.ExtraFields.KeyStatuses = keyStatuses
		return nil, bifrostErr
	}

	response := aggregateListModelsResponses(succeeded).ApplyPagination(request.PageSize, request.PageToken)
	response.KeyStatuses = keyStatuses
	response.ExtraFields.Latency = time.Since(began).Milliseconds()
	return response, nil
}
// GetRandomString returns a random string of the given length made of
// lowercase hexadecimal characters (0-9, a-f). A non-positive length yields
// the empty string.
//
// The characters come from math/rand's package-level generator, which is safe
// for concurrent use. Building a fresh time-seeded source on every call, as
// was done previously, can hand out identical strings to calls that observe
// the same clock reading. The result is not cryptographically secure and must
// not be used for secrets.
func GetRandomString(length int) string {
	if length <= 0 {
		return ""
	}
	const alphabet = "abcdef0123456789"
	buf := make([]byte, length)
	for i := range buf {
		buf[i] = alphabet[rand.Intn(len(alphabet))]
	}
	return string(buf)
}
// GetReasoningEffortFromBudgetTokens maps a reasoning token budget to an OpenAI
// reasoning effort level: "none", "low", "medium" or "high".
//
// Rules, applied in order:
//   - budgetTokens <= 0 gives "none".
//   - maxTokens <= 0 gives "medium" (defensive default).
//   - maxTokens <= minBudgetTokens gives "high" (the range is empty, so no
//     ratio can be computed).
//   - Otherwise budgetTokens is clamped to [minBudgetTokens, maxTokens] and its
//     position within that range decides: up to 25% is "low", up to 60% is
//     "medium", anything above is "high".
func GetReasoningEffortFromBudgetTokens(
	budgetTokens int,
	minBudgetTokens int,
	maxTokens int,
) string {
	switch {
	case budgetTokens <= 0:
		return "none"
	case maxTokens <= 0:
		return "medium"
	case maxTokens <= minBudgetTokens:
		// Empty range: avoids a division by zero below.
		return "high"
	}

	clamped := budgetTokens
	if clamped < minBudgetTokens {
		clamped = minBudgetTokens
	}
	if clamped > maxTokens {
		clamped = maxTokens
	}

	position := float64(clamped-minBudgetTokens) / float64(maxTokens-minBudgetTokens)
	if position <= 0.25 {
		return "low"
	}
	if position <= 0.60 {
		return "medium"
	}
	return "high"
}
// GetBudgetTokensFromReasoningEffort converts a reasoning effort level into a
// reasoning token budget between minBudgetTokens and maxTokens.
//
// effort is one of "none", "minimal", "low", "medium", "high", "xhigh", "max".
//   - "none" always yields 0, whatever the other arguments are.
//   - minBudgetTokens > maxTokens is an error.
//   - minBudgetTokens == maxTokens yields minBudgetTokens.
//   - Otherwise the budget is minBudgetTokens plus a fraction of the range
//     (truncated to an integer): minimal 2.5%, low 15%, medium 42.5%, high 80%,
//     xhigh 92%, max 100%. An unrecognised effort is treated like "medium".
func GetBudgetTokensFromReasoningEffort(
	effort string,
	minBudgetTokens int,
	maxTokens int,
) (int, error) {
	switch {
	case effort == "none":
		return 0, nil
	case minBudgetTokens > maxTokens:
		return 0, fmt.Errorf("max_tokens must be greater than %d for reasoning", minBudgetTokens)
	case maxTokens <= minBudgetTokens:
		// Empty range: the minimum is the only possible budget.
		return minBudgetTokens, nil
	}

	// Start from the safe default and override it for known effort levels.
	fraction := 0.425
	switch effort {
	case "minimal":
		fraction = 0.025
	case "low":
		fraction = 0.15
	case "medium":
		fraction = 0.425
	case "high":
		fraction = 0.80
	case "xhigh":
		fraction = 0.92
	case "max":
		fraction = 1.0
	}

	span := float64(maxTokens - minBudgetTokens)
	return minBudgetTokens + int(fraction*span), nil
}
// completeDeferredSpan finishes the deferred LLM span of a streaming request.
// The callers in this file invoke it once the StreamEndIndicator on ctx is true.
//
// It is a no-op when ctx is nil, when ctx carries no non-empty trace ID or no
// tracer, or when the tracer has no deferred span handle for the trace ID.
// Otherwise it, in order:
//  1. records the total latency taken from result, if positive;
//  2. records time-to-first-token and chunk count from the tracer's
//     accumulated chunks, if positive;
//  3. populates response attributes from the accumulated response, falling
//     back to result when nothing was accumulated;
//  4. runs postHookSpanFinalizer (if non-nil) with the deferred span's ID set
//     as the span ID on the context, so post-hook spans become its children;
//  5. ends the span with error status (recording message and status code)
//     when err is non-nil, or OK status otherwise;
//  6. clears the deferred span for the trace ID.
func completeDeferredSpan(ctx *schemas.BifrostContext, result *schemas.BifrostResponse, err *schemas.BifrostError, postHookSpanFinalizer func(context.Context)) {
	if ctx == nil {
		return
	}

	// The trace ID is available in the provider's goroutine via the context.
	traceID, hasTraceID := ctx.Value(schemas.BifrostContextKeyTraceID).(string)
	if !hasTraceID || traceID == "" {
		return
	}

	tracer, hasTracer := ctx.Value(schemas.BifrostContextKeyTracer).(schemas.Tracer)
	if !hasTracer || tracer == nil {
		return
	}

	handle := tracer.GetDeferredSpanHandle(traceID)
	if handle == nil {
		return
	}

	// Total latency comes from the final chunk.
	if result != nil {
		if latency := result.GetExtraFields().Latency; latency > 0 {
			tracer.SetAttribute(handle, "gen_ai.response.total_latency_ms", latency)
		}
	}

	// The metrics may be valid even when no accumulated response is returned,
	// so they are recorded independently of it.
	accumulatedResp, ttftNs, chunkCount := tracer.GetAccumulatedChunks(traceID)
	if ttftNs > 0 {
		tracer.SetAttribute(handle, schemas.AttrTimeToFirstToken, ttftNs)
	}
	if chunkCount > 0 {
		tracer.SetAttribute(handle, schemas.AttrTotalChunks, chunkCount)
	}

	switch {
	case accumulatedResp != nil:
		// Preferred: the response rebuilt from all streamed chunks.
		tracer.PopulateLLMResponseAttributes(ctx, handle, accumulatedResp, err)
	case result != nil:
		// Fallback: only the final chunk is available.
		tracer.PopulateLLMResponseAttributes(ctx, handle, result, err)
	}

	// Finalize the aggregated post-hook spans before the LLM span is ended.
	if postHookSpanFinalizer != nil {
		var finalizerCtx context.Context = ctx
		if spanID := tracer.GetDeferredSpanID(traceID); spanID != "" {
			// Parent the post-hook spans under the deferred llm.call span.
			finalizerCtx = context.WithValue(ctx, schemas.BifrostContextKeySpanID, spanID)
		}
		postHookSpanFinalizer(finalizerCtx)
	}

	if err == nil {
		tracer.EndSpan(handle, schemas.SpanStatusOk, "")
	} else {
		if err.Error != nil {
			tracer.SetAttribute(handle, "error", err.Error.Message)
		}
		if err.StatusCode != nil {
			tracer.SetAttribute(handle, "status_code", *err.StatusCode)
		}
		tracer.EndSpan(handle, schemas.SpanStatusError, "streaming request failed")
	}

	tracer.ClearDeferredSpan(traceID)
}
// CheckAndSetDefaultProvider decides whether defaultProvider may be used for
// the request described by ctx. It returns defaultProvider when it may, and an
// empty provider name when it may not.
//
// defaultProvider is returned when:
//   - ctx is nil;
//   - a direct key or the skip-key-selection flag is set on ctx;
//   - no available-providers value is set on ctx;
//   - the available-providers list on ctx contains defaultProvider.
//
// An empty name is returned when the available-providers value is set but is
// not a []schemas.ModelProvider, is empty, or does not contain defaultProvider.
func CheckAndSetDefaultProvider(ctx *schemas.BifrostContext, defaultProvider schemas.ModelProvider) schemas.ModelProvider {
	if ctx == nil {
		return defaultProvider
	}
	if ctx.Value(schemas.BifrostContextKeyDirectKey) != nil || ctx.Value(schemas.BifrostContextKeySkipKeySelection) != nil {
		return defaultProvider
	}

	rawProviders := ctx.Value(schemas.BifrostContextKeyAvailableProviders)
	if rawProviders == nil {
		return defaultProvider
	}

	allowed, isList := rawProviders.([]schemas.ModelProvider)
	if !isList || len(allowed) == 0 {
		return ""
	}
	getLogger().Debug("[Provider] Available providers: %v, checking %s", allowed, defaultProvider)
	if !slices.Contains(allowed, defaultProvider) {
		return ""
	}
	return defaultProvider
}
// ModelMatchesDenylist reports whether at least one candidate model ID is
// covered by denylist.
//
// A candidate matches when it equals a denylist entry exactly, or when
// schemas.SameBaseModel reports that an entry and the candidate share the same
// base model. Empty candidates are ignored, and an empty denylist never
// matches anything.
func ModelMatchesDenylist(denylist []string, candidates ...string) bool {
	if len(denylist) == 0 {
		return false
	}
	for _, candidate := range candidates {
		if candidate == "" {
			continue
		}
		// Exact match is checked over the whole list before base-model matching.
		if slices.Contains(denylist, candidate) {
			return true
		}
		sharesBase := func(entry string) bool {
			return schemas.SameBaseModel(entry, candidate)
		}
		if slices.ContainsFunc(denylist, sharesBase) {
			return true
		}
	}
	return false
}