// Package providers implements various LLM providers and their utility functions. // This file contains common utility functions used across different provider implementations. package utils import ( "bytes" "compress/gzip" "context" "crypto/tls" "crypto/x509" "encoding/json" "errors" "fmt" "io" "math/rand" "net" "net/http" "net/textproto" "net/url" "regexp" "slices" "sort" "strings" "sync" "sync/atomic" "time" "github.com/bytedance/sonic" "github.com/maximhq/bifrost/core/network" schemas "github.com/maximhq/bifrost/core/schemas" "github.com/tidwall/gjson" "github.com/tidwall/sjson" "github.com/valyala/fasthttp" "github.com/valyala/fasthttp/fasthttpproxy" ) // sortedAPI is a sonic encoder/decoder that sorts map keys during marshaling. // This ensures deterministic JSON output for map[string]interface{} values, // which is critical for LLM prompt caching (e.g., Anthropic cache keying). var sortedAPI = sonic.Config{SortMapKeys: true}.Froze() // MarshalSorted marshals v to JSON with map keys sorted alphabetically. func MarshalSorted(v interface{}) ([]byte, error) { return sortedAPI.Marshal(v) } // MarshalSortedIndent marshals v to indented JSON with map keys sorted alphabetically. func MarshalSortedIndent(v interface{}, prefix, indent string) ([]byte, error) { return sortedAPI.MarshalIndent(v, prefix, indent) } // SetJSONField sets a field in JSON bytes without disturbing other fields' ordering. // Uses in-place byte manipulation for minimal allocations and preserves nested structure. func SetJSONField(data []byte, path string, value interface{}) ([]byte, error) { return sjson.SetBytes(data, path, value) } // DeleteJSONField deletes a field from JSON bytes without disturbing other fields' ordering. // Uses in-place byte manipulation for minimal allocations and preserves nested structure. func DeleteJSONField(data []byte, path string) ([]byte, error) { return sjson.DeleteBytes(data, path) } // JSONFieldExists checks if a field exists in JSON bytes. func JSONFieldExists(data []byte, path string) bool { return gjson.GetBytes(data, path).Exists() } // GetJSONField retrieves a field value from JSON bytes without parsing the entire document. func GetJSONField(data []byte, path string) gjson.Result { return gjson.GetBytes(data, path) } // logger is the global logger for the provider utils (thread-safe via atomic.Pointer). var logger atomic.Pointer[schemas.Logger] // noopLogger is a no-op implementation of schemas.Logger. type noopLogger struct{} func (noopLogger) Debug(string, ...any) {} func (noopLogger) Info(string, ...any) {} func (noopLogger) Warn(string, ...any) {} func (noopLogger) Error(string, ...any) {} func (noopLogger) Fatal(string, ...any) {} func (noopLogger) SetLevel(schemas.LogLevel) {} func (noopLogger) SetOutputType(schemas.LoggerOutputType) {} func (noopLogger) LogHTTPRequest(schemas.LogLevel, string) schemas.LogEventBuilder { return schemas.NoopLogEvent } // Initialize with noop logger func init() { var noop schemas.Logger = &noopLogger{} logger.Store(&noop) } // SetLogger sets the logger for the provider utils (thread-safe). func SetLogger(l schemas.Logger) { logger.Store(&l) } // getLogger returns the current logger (thread-safe). func getLogger() schemas.Logger { return *logger.Load() } var UnsupportedSpeechStreamModels = []string{"tts-1", "tts-1-hd"} // noop is a reusable no-op function returned by MakeRequestWithContext on the normal path. var noop = func() {} // MakeRequestWithContext makes a request with a context and returns the latency, error, and a // wait function. The wait function MUST be called (typically via defer) before releasing the // request or response objects. On the normal path it is a no-op. On the context-cancellation // path it blocks until the background client.Do goroutine finishes, preventing a data race // between the still-running goroutine and the caller's release of req/resp. // // IMPORTANT: This function does NOT truly cancel the underlying fasthttp network request if the // context is done. The fasthttp client call will continue in its goroutine until it completes // or times out based on its own settings. This function merely stops *waiting* for the // fasthttp call and returns an error related to the context. func MakeRequestWithContext(ctx context.Context, client *fasthttp.Client, req *fasthttp.Request, resp *fasthttp.Response) (time.Duration, *schemas.BifrostError, func()) { startTime := time.Now() errChan := make(chan error, 1) go func() { // client.Do is a blocking call. // It will send an error (or nil for success) to errChan when it completes. errChan <- client.Do(req, resp) }() select { case <-ctx.Done(): // Context was cancelled (e.g., deadline exceeded or manual cancellation). // Calculate latency even for cancelled requests latency := time.Since(startTime) // Return a wait function that blocks until the background goroutine finishes. // The caller MUST invoke this (via defer) before releasing req/resp to avoid // a data race with the still-running client.Do goroutine. if errors.Is(ctx.Err(), context.DeadlineExceeded) { statusCode := 504 errorType := schemas.RequestTimedOut return latency, &schemas.BifrostError{ IsBifrostError: true, StatusCode: &statusCode, Error: &schemas.ErrorField{ Type: &errorType, Message: fmt.Sprintf("Request timed out by context: %v", ctx.Err()), Error: ctx.Err(), }, }, func() { <-errChan } } statusCode := 499 errorType := schemas.RequestCancelled return latency, &schemas.BifrostError{ IsBifrostError: true, StatusCode: &statusCode, Error: &schemas.ErrorField{ Type: &errorType, Message: fmt.Sprintf("Request cancelled by context: %v", ctx.Err()), Error: ctx.Err(), }, }, func() { <-errChan } case err := <-errChan: // The fasthttp.Do call completed. // Calculate latency for both successful and failed requests latency := time.Since(startTime) if err != nil { if errors.Is(err, context.Canceled) { return latency, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Type: schemas.Ptr(schemas.RequestCancelled), Message: schemas.ErrRequestCancelled, Error: err, }, }, noop } // Check for timeout errors first before checking net.OpError to avoid misclassification if errors.Is(err, fasthttp.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) { return latency, NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err), noop } // Check if error implements net.Error and has Timeout() == true var netErr net.Error if errors.As(err, &netErr) && netErr.Timeout() { return latency, NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err), noop } // Check for DNS lookup and network errors after timeout checks var opErr *net.OpError var dnsErr *net.DNSError if errors.As(err, &opErr) || errors.As(err, &dnsErr) { return latency, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: schemas.ErrProviderNetworkError, Error: err, }, }, noop } // The HTTP request itself failed (e.g., connection error, fasthttp timeout). return latency, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: schemas.ErrProviderDoRequest, Error: err, }, }, noop } // HTTP request was successful from fasthttp's perspective (err is nil). // The caller should check resp.StatusCode() for HTTP-level errors (4xx, 5xx). return latency, nil, noop } } // Deprecated: ConfigureRetry is now handled internally by ConfigureDialer. // This function is kept for backward compatibility but is no longer needed. func ConfigureRetry(client *fasthttp.Client) *fasthttp.Client { client.RetryIfErr = network.StaleConnectionRetryIfErr return client } // ConfigureDialer configures the client's connection behavior: // 1. Sets up the stale-connection retry policy (see network.StaleConnectionRetryIfErr). // 2. Wraps the Dial function to enable TCP keepalive on all connections, // proactively detecting dead connections before fasthttp tries to reuse them. // // Must be called AFTER ConfigureProxy (which may set client.Dial to a proxy // dialer), so the keepalive wrapper composes on top of the proxy connection. // // Keepalive parameters: // - Idle 10s: first probe after 10s of inactivity (well under the 30s MaxIdleConnDuration) // - Interval 5s: subsequent probes every 5s // - Count 3: close after 3 failed probes // // Dead connections are detected within ~25s (10 + 5*3), before the 30s // MaxIdleConnDuration expires and the connection is reused. func ConfigureDialer(client *fasthttp.Client) *fasthttp.Client { // Configure stale-connection retry policy client.RetryIfErr = network.StaleConnectionRetryIfErr existingDial := client.Dial existingDialTimeout := client.DialTimeout keepAliveCfg := net.KeepAliveConfig{ Enable: true, Idle: 10 * time.Second, Interval: 5 * time.Second, Count: 3, } client.Dial = func(addr string) (net.Conn, error) { var conn net.Conn var err error switch { case existingDial != nil: // Proxy or custom dial function is set — use it, then enable keepalive conn, err = existingDial(addr) case existingDialTimeout != nil: // Preserve dial-timeout behavior conn, err = existingDialTimeout(addr, client.ReadTimeout) default: conn, err = (&net.Dialer{ Timeout: client.ReadTimeout, KeepAliveConfig: keepAliveCfg, }).Dial("tcp", addr) } if err != nil { return nil, err } // Enable TCP keepalive on the connection if tcpConn, ok := conn.(*net.TCPConn); ok { _ = tcpConn.SetKeepAliveConfig(keepAliveCfg) } return conn, nil } return client } // ConfigureProxy sets up a proxy for the fasthttp client based on the provided configuration. // It supports HTTP, SOCKS5, and environment-based proxy configurations. // Returns the configured client or the original client if proxy configuration is invalid. func ConfigureProxy(client *fasthttp.Client, proxyConfig *schemas.ProxyConfig, logger schemas.Logger) *fasthttp.Client { if proxyConfig == nil { return client } var dialFunc fasthttp.DialFunc // Create the appropriate proxy based on type switch proxyConfig.Type { case schemas.NoProxy: return client case schemas.HTTPProxy: if proxyConfig.URL == "" { getLogger().Warn("Warning: HTTP proxy URL is required for setting up proxy") return client } proxyURL := proxyConfig.URL if proxyConfig.Username != "" && proxyConfig.Password != "" { parsedURL, err := url.Parse(proxyConfig.URL) if err != nil { getLogger().Warn("Invalid proxy configuration: invalid HTTP proxy URL") return client } // Set user and password in the parsed URL parsedURL.User = url.UserPassword(proxyConfig.Username, proxyConfig.Password) proxyURL = parsedURL.String() } dialFunc = fasthttpproxy.FasthttpHTTPDialer(proxyURL) case schemas.Socks5Proxy: if proxyConfig.URL == "" { getLogger().Warn("Warning: SOCKS5 proxy URL is required for setting up proxy") return client } proxyURL := proxyConfig.URL // Add authentication if provided if proxyConfig.Username != "" && proxyConfig.Password != "" { parsedURL, err := url.Parse(proxyConfig.URL) if err != nil { getLogger().Warn("Invalid proxy configuration: invalid SOCKS5 proxy URL") return client } // Set user and password in the parsed URL parsedURL.User = url.UserPassword(proxyConfig.Username, proxyConfig.Password) proxyURL = parsedURL.String() } dialFunc = fasthttpproxy.FasthttpSocksDialer(proxyURL) case schemas.EnvProxy: // Use environment variables for proxy configuration dialFunc = fasthttpproxy.FasthttpProxyHTTPDialer() default: getLogger().Warn("Invalid proxy configuration: unsupported proxy type: %s", proxyConfig.Type) return client } if dialFunc != nil { client.Dial = dialFunc } // Configure custom CA certificate if provided if proxyConfig.CACertPEM != "" { tlsConfig, err := createTLSConfigWithCA(proxyConfig.CACertPEM) if err != nil { getLogger().Warn("Failed to configure custom CA certificate: %v", err) } else { client.TLSConfig = tlsConfig } } return client } // createTLSConfigWithCA creates a TLS configuration with a custom CA certificate // appended to the system root CA pool. func createTLSConfigWithCA(caCertPEM string) (*tls.Config, error) { // Get the system root CA pool rootCAs, err := x509.SystemCertPool() if err != nil { // If we can't get system certs, create a new pool rootCAs = x509.NewCertPool() } // Append the custom CA certificate if !rootCAs.AppendCertsFromPEM([]byte(caCertPEM)) { return nil, fmt.Errorf("failed to parse CA certificate PEM") } return &tls.Config{ RootCAs: rootCAs, MinVersion: tls.VersionTLS12, }, nil } // ConfigureTLS applies TLS settings from NetworkConfig to the fasthttp client. // It merges with any existing TLSConfig (e.g., from ConfigureProxy). func ConfigureTLS(client *fasthttp.Client, networkConfig schemas.NetworkConfig, logger schemas.Logger) *fasthttp.Client { if !networkConfig.InsecureSkipVerify && networkConfig.CACertPEM == "" { return client } tlsConfig := client.TLSConfig if tlsConfig == nil { tlsConfig = &tls.Config{MinVersion: tls.VersionTLS12} } else { tlsConfig = tlsConfig.Clone() } if networkConfig.InsecureSkipVerify { logger.Warn("insecure_skip_verify is enabled for provider — TLS certificate verification is disabled. Not recommended for production.") tlsConfig.InsecureSkipVerify = true } if networkConfig.CACertPEM != "" { caTLSConfig, err := createTLSConfigWithCA(networkConfig.CACertPEM) if err != nil { logger.Warn("Failed to configure custom CA certificate for provider: %v", err) } else { if tlsConfig.RootCAs != nil { tlsConfig.RootCAs = tlsConfig.RootCAs.Clone() // Merge: append network CA to existing pool (e.g. from proxy) if !tlsConfig.RootCAs.AppendCertsFromPEM([]byte(networkConfig.CACertPEM)) { logger.Warn("Failed to append CA certificate to existing TLS config") } } else { tlsConfig.RootCAs = caTLSConfig.RootCAs } } } client.TLSConfig = tlsConfig return client } // hopByHopHeaders are HTTP/1.1 headers that must not be forwarded by proxies. var hopByHopHeaders = map[string]bool{ "connection": true, "proxy-connection": true, "keep-alive": true, "proxy-authenticate": true, "proxy-authorization": true, "te": true, "trailer": true, "transfer-encoding": true, "upgrade": true, } // filterHeaders filters out hop-by-hop headers and returns only the allowed headers. func filterHeaders(headers map[string][]string) map[string][]string { filtered := make(map[string][]string, len(headers)) for k, v := range headers { if !hopByHopHeaders[strings.ToLower(k)] { filtered[k] = v } } return filtered } // providerResponseFilterHeaders are headers to exclude when forwarding provider response headers. // These are transport-level headers that don't apply when re-serving the response. var providerResponseFilterHeaders = map[string]bool{ "content-length": true, "content-encoding": true, "transfer-encoding": true, "connection": true, "keep-alive": true, "proxy-connection": true, "proxy-authenticate": true, "proxy-authorization": true, "authorization": true, "cookie": true, "set-cookie": true, "set-cookie2": true, "www-authenticate": true, "te": true, "trailer": true, "upgrade": true, "host": true, "date": true, "server": true, "alt-svc": true, "strict-transport-security": true, "content-type": true, "access-control-allow-origin": true, "access-control-allow-methods": true, "access-control-allow-headers": true, "access-control-expose-headers": true, "access-control-allow-credentials": true, "access-control-max-age": true, } // ExtractProviderResponseHeaders extracts and filters response headers from a // fasthttp response. Transport-level headers are excluded. func ExtractProviderResponseHeaders(resp *fasthttp.Response) map[string]string { if resp == nil { return nil } headers := make(map[string]string) resp.Header.VisitAll(func(key, value []byte) { k := string(key) if providerResponseFilterHeaders[strings.ToLower(k)] { return } v := string(value) if existing, ok := headers[k]; ok && existing != "" { headers[k] = existing + ", " + v } else { headers[k] = v } }) if len(headers) == 0 { return nil } return headers } // ExtractProviderResponseHeadersFromHTTP extracts and filters response headers // from a standard net/http response. Transport-level headers are excluded. // Used by providers like Bedrock that use net/http instead of fasthttp. func ExtractProviderResponseHeadersFromHTTP(resp *http.Response) map[string]string { if resp == nil { return nil } headers := make(map[string]string) for k, values := range resp.Header { if !providerResponseFilterHeaders[strings.ToLower(k)] && len(values) > 0 { headers[k] = strings.Join(values, ", ") } } if len(headers) == 0 { return nil } return headers } // SetExtraHeaders sets additional headers from NetworkConfig to the fasthttp request. // This allows users to configure custom headers for their provider requests. // Header keys are canonicalized using textproto.CanonicalMIMEHeaderKey to avoid duplicates. // It accepts a list of headers (all canonicalized) to skip for security reasons. // Headers are only set if they don't already exist on the request to avoid overwriting important headers. func SetExtraHeaders(ctx context.Context, req *fasthttp.Request, extraHeaders map[string]string, skipHeaders []string) { for key, value := range extraHeaders { canonicalKey := textproto.CanonicalMIMEHeaderKey(key) if skipHeaders != nil { if slices.Contains(skipHeaders, key) { continue } } // Only set the header if it doesn't already exist to avoid overwriting important headers if len(req.Header.Peek(canonicalKey)) == 0 { req.Header.Set(canonicalKey, value) } } // Give priority to extra headers in the context if extraHeaders, ok := (ctx).Value(schemas.BifrostContextKeyExtraHeaders).(map[string][]string); ok { for k, values := range filterHeaders(extraHeaders) { if skipHeaders != nil && slices.Contains(skipHeaders, strings.ToLower(k)) { continue } for i, v := range values { if i == 0 { req.Header.Set(k, v) } else { req.Header.Add(k, v) } } } } } // GetPathFromContext gets the path from the context, if it exists, otherwise returns the default path. func GetPathFromContext(ctx context.Context, defaultPath string) string { if pathInContext, ok := ctx.Value(schemas.BifrostContextKeyURLPath).(string); ok { return pathInContext } return defaultPath } // GetRequestPath gets the request path from the context, if it exists, checking for path overrides in the custom provider config. // It returns the resolved value and a boolean indicating whether the value is a full absolute URL. // If the boolean is false, the returned string is a path (leading slash ensured). func GetRequestPath(ctx context.Context, defaultPath string, customProviderConfig *schemas.CustomProviderConfig, requestType schemas.RequestType) (string, bool) { // If path/url set in context, return it. if pathInContext, ok := ctx.Value(schemas.BifrostContextKeyURLPath).(string); ok { trimmed := strings.TrimSpace(pathInContext) if u, err := url.Parse(trimmed); err == nil && u != nil && u.IsAbs() && u.Host != "" { return trimmed, true } return trimmed, false } // If path override set in custom provider config, return it. if customProviderConfig != nil && customProviderConfig.RequestPathOverrides != nil { if raw, ok := customProviderConfig.RequestPathOverrides[requestType]; ok { override := strings.TrimSpace(raw) if override == "" { return defaultPath, false } // Treat absolute URLs with scheme+host as full URLs. if u, err := url.Parse(override); err == nil && u != nil && u.IsAbs() && u.Host != "" { return override, true } // Otherwise treat as a path override (ensure leading slash). if !strings.HasPrefix(override, "/") { override = "/" + override } return override, false } } // Return default path. return defaultPath, false } type RequestBodyGetter interface { GetRawRequestBody() []byte } // CheckAndGetRawRequestBody checks if the raw request body should be used, and returns it if it exists. func CheckAndGetRawRequestBody(ctx context.Context, request RequestBodyGetter) ([]byte, bool) { if rawBody, ok := ctx.Value(schemas.BifrostContextKeyUseRawRequestBody).(bool); ok && rawBody { return request.GetRawRequestBody(), true } return nil, false } type RequestBodyWithExtraParams interface { GetExtraParams() map[string]interface{} } type RequestBodyConverter func() (RequestBodyWithExtraParams, error) // IsLargePayloadPassthroughEnabled returns true when large payload mode has already // prepared an upstream body reader in context. func IsLargePayloadPassthroughEnabled(ctx context.Context) bool { if ctx == nil { return false } isLargePayload, ok := ctx.Value(schemas.BifrostContextKeyLargePayloadMode).(bool) if !ok || !isLargePayload { return false } reader, ok := ctx.Value(schemas.BifrostContextKeyLargePayloadReader).(io.Reader) return ok && reader != nil } // ApplyLargePayloadRequestBody applies the request body reader from context to the // outgoing provider request. Returns true when a streaming body was applied. func ApplyLargePayloadRequestBody(ctx context.Context, req *fasthttp.Request) bool { return ApplyLargePayloadRequestBodyWithModelNormalization(ctx, req, "") } const largePayloadModelRewriteScanBytes = 256 * 1024 // ApplyLargePayloadRequestBodyWithModelNormalization applies the streaming body // reader from context and optionally rewrites prefixed model values for JSON // passthrough requests (for example "openai/gpt-5" -> "gpt-5"). // This preserves low-memory streaming while keeping large-payload behavior // aligned with the normal parsed path that strips provider prefixes. func ApplyLargePayloadRequestBodyWithModelNormalization( ctx context.Context, req *fasthttp.Request, defaultProvider schemas.ModelProvider, ) bool { if req == nil || !IsLargePayloadPassthroughEnabled(ctx) { return false } bodyReader, _ := ctx.Value(schemas.BifrostContextKeyLargePayloadReader).(io.Reader) bodySize := -1 if contentLength, ok := ctx.Value(schemas.BifrostContextKeyLargePayloadContentLength).(int); ok { bodySize = contentLength } if contentType, ok := ctx.Value(schemas.BifrostContextKeyLargePayloadContentType).(string); ok && contentType != "" { ctLower := strings.ToLower(contentType) if metadata, ok := ctx.Value(schemas.BifrostContextKeyLargePayloadMetadata).(*schemas.LargePayloadMetadata); ok && metadata != nil { if rawModel := strings.TrimSpace(metadata.Model); rawModel != "" && defaultProvider != "" { _, normalizedModel := schemas.ParseModelString(rawModel, defaultProvider) if normalizedModel != "" && normalizedModel != rawModel { if strings.Contains(ctLower, "application/json") { rewrittenReader, sizeDelta := RewriteLargePayloadModelInJSONPrefix(bodyReader, rawModel, normalizedModel) bodyReader = rewrittenReader if bodySize >= 0 { bodySize += sizeDelta } } else if strings.Contains(ctLower, "multipart/form-data") { rewrittenReader, sizeDelta := RewriteLargePayloadModelInMultipartPrefix(bodyReader, rawModel, normalizedModel) bodyReader = rewrittenReader if bodySize >= 0 { bodySize += sizeDelta } } } } } req.Header.SetContentType(contentType) } req.SetBodyStream(bodyReader, bodySize) return true } // RewriteLargePayloadModelInJSONPrefix reads the first 256KB of a streaming body, // rewrites the "model" JSON value from fromModel to toModel, and returns a // combined reader (rewritten prefix + remaining stream) with the size delta. func RewriteLargePayloadModelInJSONPrefix(reader io.Reader, fromModel, toModel string) (io.Reader, int) { if reader == nil || fromModel == "" || toModel == "" || fromModel == toModel { return reader, 0 } prefix := make([]byte, largePayloadModelRewriteScanBytes) n, err := io.ReadFull(reader, prefix) if n == 0 && err != nil { return reader, 0 } prefix = prefix[:n] rewrittenPrefix, changed := rewriteJSONModelValue(prefix, fromModel, toModel) if !changed { return io.MultiReader(bytes.NewReader(prefix), reader), 0 } return io.MultiReader(bytes.NewReader(rewrittenPrefix), reader), len(rewrittenPrefix) - len(prefix) } func rewriteJSONModelValue(data []byte, fromModel, toModel string) ([]byte, bool) { if len(data) == 0 || fromModel == "" || toModel == "" || fromModel == toModel { return data, false } pattern := []byte(`"model"`) searchFrom := 0 for { match := bytes.Index(data[searchFrom:], pattern) if match < 0 { return data, false } idx := searchFrom + match + len(pattern) for idx < len(data) && (data[idx] == ' ' || data[idx] == '\t' || data[idx] == '\r' || data[idx] == '\n') { idx++ } if idx >= len(data) || data[idx] != ':' { searchFrom += match + len(pattern) continue } idx++ for idx < len(data) && (data[idx] == ' ' || data[idx] == '\t' || data[idx] == '\r' || data[idx] == '\n') { idx++ } if idx >= len(data) || data[idx] != '"' { searchFrom += match + len(pattern) continue } valueStart := idx + 1 valueEnd := valueStart escaped := false for valueEnd < len(data) { ch := data[valueEnd] if escaped { escaped = false valueEnd++ continue } if ch == '\\' { escaped = true valueEnd++ continue } if ch == '"' { break } valueEnd++ } if valueEnd >= len(data) { return data, false } if string(data[valueStart:valueEnd]) != fromModel { searchFrom = valueEnd + 1 continue } rewritten := make([]byte, 0, len(data)-len(fromModel)+len(toModel)) rewritten = append(rewritten, data[:valueStart]...) rewritten = append(rewritten, toModel...) rewritten = append(rewritten, data[valueEnd:]...) return rewritten, true } } // RewriteLargePayloadModelInMultipartPrefix reads the first 256KB of a streaming // multipart body, finds the model form field value, and rewrites it from fromModel // to toModel. The model field appears early in multipart bodies (typically the first // form field), so scanning the prefix is sufficient. func RewriteLargePayloadModelInMultipartPrefix(reader io.Reader, fromModel, toModel string) (io.Reader, int) { if reader == nil || fromModel == "" || toModel == "" || fromModel == toModel { return reader, 0 } prefix := make([]byte, largePayloadModelRewriteScanBytes) n, err := io.ReadFull(reader, prefix) if n == 0 && err != nil { return reader, 0 } prefix = prefix[:n] // In multipart, the model value appears as: // ...name="model"\r\n\r\nopenai/whisper-1\r\n--boundary... // A direct byte replacement of fromModel→toModel in the prefix is safe because // the model string (e.g. "openai/whisper-1") is unique within the form metadata. from := []byte(fromModel) to := []byte(toModel) if idx := bytes.Index(prefix, from); idx >= 0 { rewritten := make([]byte, 0, len(prefix)-len(from)+len(to)) rewritten = append(rewritten, prefix[:idx]...) rewritten = append(rewritten, to...) rewritten = append(rewritten, prefix[idx+len(from):]...) return io.MultiReader(bytes.NewReader(rewritten), reader), len(rewritten) - len(prefix) } return io.MultiReader(bytes.NewReader(prefix), reader), 0 } // DrainLargePayloadRemainder drains any unread bytes from the large payload reader. // This is useful for request types that may receive an upstream response before the // incoming client upload is fully consumed (for example, lightweight preflight APIs). // Example failure this prevents: fronting proxy returns 502/broken-pipe when backend // responds early while client is still uploading a large body. func DrainLargePayloadRemainder(ctx context.Context) { if !IsLargePayloadPassthroughEnabled(ctx) { return } bodyReader, _ := ctx.Value(schemas.BifrostContextKeyLargePayloadReader).(io.Reader) if bodyReader == nil { return } _, _ = io.Copy(io.Discard, bodyReader) } // CloneFastHTTPClientConfig creates a fresh fasthttp.Client by copying only // config fields from base. // Never copy fasthttp.Client by value: it contains internal pools and locks. // Example failure this prevents: parallel load regressions with unexpected buffering // behavior after `cloned := *base` copies of active clients. func CloneFastHTTPClientConfig(base *fasthttp.Client) *fasthttp.Client { if base == nil { return &fasthttp.Client{} } return &fasthttp.Client{ Transport: base.Transport, DialTimeout: base.DialTimeout, Dial: base.Dial, TLSConfig: base.TLSConfig, RetryIfErr: base.RetryIfErr, ConfigureClient: base.ConfigureClient, Name: base.Name, MaxConnsPerHost: base.MaxConnsPerHost, MaxIdleConnDuration: base.MaxIdleConnDuration, MaxConnDuration: base.MaxConnDuration, MaxIdemponentCallAttempts: base.MaxIdemponentCallAttempts, ReadBufferSize: base.ReadBufferSize, WriteBufferSize: base.WriteBufferSize, ReadTimeout: base.ReadTimeout, WriteTimeout: base.WriteTimeout, MaxResponseBodySize: base.MaxResponseBodySize, MaxConnWaitTimeout: base.MaxConnWaitTimeout, ConnPoolStrategy: base.ConnPoolStrategy, NoDefaultUserAgentHeader: base.NoDefaultUserAgentHeader, DialDualStack: base.DialDualStack, DisableHeaderNamesNormalizing: base.DisableHeaderNamesNormalizing, DisablePathNormalizing: base.DisablePathNormalizing, StreamResponseBody: base.StreamResponseBody, } } // BuildStreamingClient returns a fasthttp.Client suitable for long-lived SSE // or EventStream responses. It clones base's dialer/proxy/TLS/pool settings, // then clears Read/Write timeouts and MaxConnDuration so fasthttp does not // pre-empt a healthy stream. StreamResponseBody is forced on. // // Per-chunk idle detection is enforced at the application layer via // NewIdleTimeoutReader (see GetStreamIdleTimeout / StreamIdleTimeoutInSeconds). // The initial TCP/TLS dial still honors the base client's ReadTimeout because // the Dial closure installed by ConfigureDialer reads client.ReadTimeout from // the base client pointer captured at ConfigureDialer call time — cloning copies // that closure verbatim, so zeroing the clone's ReadTimeout does not affect dial. func BuildStreamingClient(base *fasthttp.Client) *fasthttp.Client { c := CloneFastHTTPClientConfig(base) c.ReadTimeout = 0 c.WriteTimeout = 0 c.MaxConnDuration = 0 c.StreamResponseBody = true return c } // BuildStreamingHTTPClient returns an *http.Client for long-lived streaming // responses over net/http (e.g. Bedrock EventStream). It reuses the base's // Transport (safe for concurrent use by multiple clients) and sets Timeout=0 // so Client.Timeout does not cap the entire request lifecycle including body // reads. The transport's ResponseHeaderTimeout still bounds the initial // response-headers wait; per-chunk idle is enforced by NewIdleTimeoutReader. func BuildStreamingHTTPClient(base *http.Client) *http.Client { if base == nil { return &http.Client{} } return &http.Client{ Transport: base.Transport, CheckRedirect: base.CheckRedirect, Jar: base.Jar, } } // decompressBodyStreamIfGzip checks Content-Encoding for gzip and wraps the stream // with on-the-fly decompression using a pooled gzip.Reader. Clears Content-Encoding // header so downstream consumers don't double-decompress. Returns original reader // unchanged if not gzip-encoded or if gzip reader creation fails. func decompressBodyStreamIfGzip(resp *fasthttp.Response, stream io.Reader) (*gzip.Reader, io.Reader, bool) { ce := strings.ToLower(strings.TrimSpace(string(resp.Header.Peek("Content-Encoding")))) if !strings.Contains(ce, "gzip") { return nil, stream, false } gz, err := AcquireGzipReader(stream) if err != nil { ReleaseGzipReader(gz) return nil, stream, false } resp.Header.Del("Content-Encoding") return gz, gz, true } // DecompressStreamBody returns a reader for consuming the response body, with // on-the-fly gzip decompression when Content-Encoding indicates gzip. The response // object is NOT modified (no SetBodyStream call), so the original requestStream // remains live for proper cleanup by ReleaseStreamingResponse. Clears the // Content-Encoding header to prevent double-decompression. // // Returns: // - io.Reader: the reader to use for scanning (gzip reader if gzip-encoded, // original body stream otherwise). // - func(): cleanup function that releases the gzip reader back to the pool. // Must be called (typically via defer) after streaming is complete. func DecompressStreamBody(resp *fasthttp.Response) (io.Reader, func()) { bodyStream := resp.BodyStream() if bodyStream == nil { // Return an empty reader instead of nil to prevent panics in callers // that pass the reader to bufio.NewScanner without nil checks. return bytes.NewReader(nil), func() {} } gz, decompressed, wasGzip := decompressBodyStreamIfGzip(resp, bodyStream) if !wasGzip { return bodyStream, func() {} } return decompressed, func() { ReleaseGzipReader(gz) } } // DrainNonSSEStreamResponse checks if the upstream response is a Server-Sent Events stream. // If not SSE, drains the body to io.Discard to prevent bufio.Scanner buffer bloat on // non-line-delimited data. Returns true if body was drained (caller should skip scanner). // We intentionally do not touch valid SSE bodies here: callers must continue reading from // the reader returned by DecompressStreamBody, and draining SSE in this helper would consume // the stream before the scanner/manual event loop starts. func DrainNonSSEStreamResponse(resp *fasthttp.Response) bool { ct := strings.ToLower(string(resp.Header.ContentType())) if strings.Contains(ct, "text/event-stream") { return false } if bodyStream := resp.BodyStream(); bodyStream != nil { _, _ = io.Copy(io.Discard, bodyStream) } return true } // MergeExtraParams merges extraParams into jsonMap, handling nested maps recursively. func MergeExtraParams(jsonMap map[string]interface{}, extraParams map[string]interface{}) { for k, v := range extraParams { if existingVal, exists := jsonMap[k]; exists { if existingMap, ok := existingVal.(map[string]interface{}); ok { if newMap, ok := v.(map[string]interface{}); ok { MergeExtraParams(existingMap, newMap) continue } } } jsonMap[k] = v } } // MergeExtraParamsIntoJSON merges extra params into serialized JSON while preserving // the original key ordering. This avoids the order-destroying roundtrip through // map[string]interface{} that would lose key ordering in tool schemas and other // order-sensitive JSON structures. func MergeExtraParamsIntoJSON(jsonBody []byte, extraParams map[string]interface{}) ([]byte, error) { trimmed := bytes.TrimSpace(jsonBody) if len(trimmed) < 2 || trimmed[0] != '{' { return jsonBody, nil // not a JSON object, return as-is } // Parse existing JSON into ordered key-value pairs using encoding/json // (not sonic) to preserve document key order via token-by-token parsing. dec := json.NewDecoder(bytes.NewReader(trimmed)) dec.UseNumber() if _, err := dec.Token(); err != nil { // '{' return jsonBody, nil } type kvPair struct { key string val json.RawMessage } var pairs []kvPair seen := make(map[string]int) for dec.More() { tok, err := dec.Token() if err != nil { return jsonBody, nil } key, ok := tok.(string) if !ok { return jsonBody, nil } var val json.RawMessage if err := dec.Decode(&val); err != nil { return jsonBody, nil } seen[key] = len(pairs) pairs = append(pairs, kvPair{key, val}) } // Add/merge extra params (deterministic order for new keys) extraKeys := make([]string, 0, len(extraParams)) for k := range extraParams { extraKeys = append(extraKeys, k) } sort.Strings(extraKeys) for _, k := range extraKeys { v := extraParams[k] newValBytes, err := MarshalSorted(v) if err != nil { continue } if idx, exists := seen[k]; exists { // If both existing and new are JSON objects, merge recursively existingTrimmed := bytes.TrimSpace(pairs[idx].val) newTrimmed := bytes.TrimSpace(newValBytes) if len(existingTrimmed) > 0 && existingTrimmed[0] == '{' && len(newTrimmed) > 0 && newTrimmed[0] == '{' { var existingMap, newMap map[string]interface{} existingDec := json.NewDecoder(bytes.NewReader(existingTrimmed)) existingDec.UseNumber() newDec := json.NewDecoder(bytes.NewReader(newTrimmed)) newDec.UseNumber() if existingDec.Decode(&existingMap) == nil { if newDec.Decode(&newMap) == nil { MergeExtraParams(existingMap, newMap) if merged, err := MarshalSorted(existingMap); err == nil { pairs[idx].val = merged continue } } } } // Non-object or merge failed: overwrite in place (preserving position) pairs[idx].val = newValBytes } else { // New key: append at end pairs = append(pairs, kvPair{k, newValBytes}) } } // Rebuild compact JSON, then indent for consistent formatting var compact bytes.Buffer compact.WriteByte('{') for i, kv := range pairs { if i > 0 { compact.WriteByte(',') } keyBytes, err := sonic.Marshal(kv.key) if err != nil { return jsonBody, err } compact.Write(keyBytes) compact.WriteByte(':') // Use trimmed value to remove any existing indentation compact.Write(bytes.TrimSpace(kv.val)) } compact.WriteByte('}') // Re-indent to match the expected formatting var indented bytes.Buffer if err := json.Indent(&indented, compact.Bytes(), "", " "); err != nil { return compact.Bytes(), nil } return indented.Bytes(), nil } // CheckContextAndGetRequestBody checks if the raw request body should be used, and returns it if it exists. func CheckContextAndGetRequestBody(ctx context.Context, request RequestBodyGetter, requestConverter RequestBodyConverter) ([]byte, *schemas.BifrostError) { if IsLargePayloadPassthroughEnabled(ctx) { return nil, nil } rawBody, ok := CheckAndGetRawRequestBody(ctx, request) if !ok { convertedBody, err := requestConverter() if err != nil { return nil, NewBifrostOperationError(schemas.ErrRequestBodyConversion, err) } if convertedBody == nil { return nil, NewBifrostOperationError("request body is not provided", nil) } jsonBody, err := MarshalSortedIndent(convertedBody, "", " ") if err != nil { return nil, NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err) } // Merge ExtraParams into the JSON if passthrough is enabled if ctx.Value(schemas.BifrostContextKeyPassthroughExtraParams) != nil && ctx.Value(schemas.BifrostContextKeyPassthroughExtraParams) == true { extraParams := convertedBody.GetExtraParams() if len(extraParams) > 0 { // Use order-preserving merge to avoid destroying key ordering in // tool schemas and other order-sensitive JSON structures. jsonBody, err = MergeExtraParamsIntoJSON(jsonBody, extraParams) if err != nil { return nil, NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err) } } } return jsonBody, nil } else { return rawBody, nil } } // SetExtraHeadersHTTP sets additional headers from NetworkConfig to the standard HTTP request. // This allows users to configure custom headers for their provider requests. // Header keys are canonicalized using textproto.CanonicalMIMEHeaderKey to avoid duplicates. // It accepts a list of headers (all canonicalized) to skip for security reasons. // Headers are only set if they don't already exist on the request to avoid overwriting important headers. func SetExtraHeadersHTTP(ctx context.Context, req *http.Request, extraHeaders map[string]string, skipHeaders []string) { for key, value := range extraHeaders { canonicalKey := textproto.CanonicalMIMEHeaderKey(key) if skipHeaders != nil { if slices.Contains(skipHeaders, key) { continue } } // Only set the header if it doesn't already exist to avoid overwriting important headers if req.Header.Get(canonicalKey) == "" { req.Header.Set(canonicalKey, value) } } // Give priority to extra headers in the context if extraHeaders, ok := (ctx).Value(schemas.BifrostContextKeyExtraHeaders).(map[string][]string); ok { for k, values := range filterHeaders(extraHeaders) { if skipHeaders != nil && slices.Contains(skipHeaders, strings.ToLower(k)) { continue } for i, v := range values { if i == 0 { req.Header.Set(k, v) } else { req.Header.Add(k, v) } } } } } // HandleProviderAPIError processes error responses from provider APIs. // It attempts to unmarshal the error response and returns a BifrostError // with the appropriate status code and error information. // HTML detection only runs if JSON parsing fails to avoid expensive regex operations // on responses that are almost certainly valid JSON. errorResp must be a pointer to // the target struct for unmarshaling. func HandleProviderAPIError(resp *fasthttp.Response, errorResp any) *schemas.BifrostError { statusCode := resp.StatusCode() // Decode body decodedBody, err := CheckAndDecodeBody(resp) if err != nil { // Decode failed - still capture raw body for RawResponse rawBody := resp.Body() var rawErrorResponse interface{} if len(rawBody) > 0 { // Try to unmarshal, but if that fails, store as string if unmarshalErr := sonic.Unmarshal(rawBody, &rawErrorResponse); unmarshalErr != nil { rawErrorResponse = string(rawBody) } } return &schemas.BifrostError{ IsBifrostError: false, StatusCode: &statusCode, Error: &schemas.ErrorField{ Message: err.Error(), }, ExtraFields: schemas.BifrostErrorExtraFields{ RawResponse: rawErrorResponse, }, } } // Try to unmarshal decoded body for RawResponse var rawErrorResponse interface{} if err := sonic.Unmarshal(decodedBody, &rawErrorResponse); err != nil { // Store raw body as string for RawResponse when JSON parsing fails // Continue to HTML detection and proper error handling below rawErrorResponse = string(decodedBody) } // Check for empty response trimmed := strings.TrimSpace(string(decodedBody)) if len(trimmed) == 0 { // Provide a more descriptive error message based on HTTP status code var errorMessage string switch statusCode { case 401: errorMessage = "authentication failed: unauthorized (401) - check your API key" case 403: errorMessage = "access forbidden (403) - your API key may not have permission for this operation" case 404: errorMessage = "resource not found (404)" case 429: errorMessage = "rate limit exceeded (429)" case 500, 502, 503, 504: errorMessage = fmt.Sprintf("provider server error (%d)", statusCode) default: errorMessage = fmt.Sprintf("%s (HTTP %d)", schemas.ErrProviderResponseEmpty, statusCode) } return &schemas.BifrostError{ IsBifrostError: false, StatusCode: &statusCode, Error: &schemas.ErrorField{ Message: errorMessage, }, ExtraFields: schemas.BifrostErrorExtraFields{ RawResponse: rawErrorResponse, }, } } // Try JSON parsing first if err := sonic.Unmarshal(decodedBody, errorResp); err == nil { // JSON parsing succeeded, return success return &schemas.BifrostError{ IsBifrostError: false, StatusCode: &statusCode, Error: &schemas.ErrorField{}, ExtraFields: schemas.BifrostErrorExtraFields{ RawResponse: rawErrorResponse, }, } } // JSON parsing failed - now check if it's an HTML response (expensive operation) if IsHTMLResponse(resp, decodedBody) { return &schemas.BifrostError{ IsBifrostError: false, StatusCode: &statusCode, Error: &schemas.ErrorField{ Message: schemas.ErrProviderResponseHTML, Error: errors.New(string(decodedBody)), }, ExtraFields: schemas.BifrostErrorExtraFields{ RawResponse: rawErrorResponse, }, } } // Not HTML either - return raw response as error message message := fmt.Sprintf("provider API error: %s", string(decodedBody)) return &schemas.BifrostError{ IsBifrostError: false, StatusCode: &statusCode, Error: &schemas.ErrorField{ Message: message, }, ExtraFields: schemas.BifrostErrorExtraFields{ RawResponse: rawErrorResponse, }, } } // EnrichError attaches the raw request and response to a BifrostError. // Returns the request and response from provider embedded in BifrostError.ExtraFields. func EnrichError( ctx *schemas.BifrostContext, bifrostErr *schemas.BifrostError, requestBody []byte, responseBody []byte, sendBackRawRequest bool, sendBackRawResponse bool, ) *schemas.BifrostError { if bifrostErr == nil { return bifrostErr } if ShouldSendBackRawRequest(ctx, sendBackRawRequest) && len(requestBody) > 0 { // Store as json.RawMessage to preserve exact JSON bytes (including key ordering). // Compact to remove insignificant whitespace that would break SSE framing. bifrostErr.ExtraFields.RawRequest = compactRawJSON(requestBody) } else { bifrostErr.ExtraFields.RawRequest = nil } if ShouldSendBackRawResponse(ctx, sendBackRawResponse) { if len(responseBody) > 0 { bifrostErr.ExtraFields.RawResponse = compactRawJSON(responseBody) } } else { bifrostErr.ExtraFields.RawResponse = nil } return bifrostErr } // HandleProviderResponse handles common response parsing logic for provider responses. // It attempts to parse the response body into the provided response type // and returns either the parsed response or a BifrostError if parsing fails. // If sendBackRawResponse is true, it returns the raw response interface, otherwise nil. // HTML detection only runs if JSON parsing fails to avoid expensive regex operations // on responses that are almost certainly valid JSON. func HandleProviderResponse[T any](responseBody []byte, response *T, requestBody []byte, sendBackRawRequest bool, sendBackRawResponse bool) (rawRequest interface{}, rawResponse interface{}, bifrostErr *schemas.BifrostError) { // Check for empty response trimmed := strings.TrimSpace(string(responseBody)) if len(trimmed) == 0 { return nil, nil, &schemas.BifrostError{ IsBifrostError: true, Error: &schemas.ErrorField{ Message: schemas.ErrProviderResponseEmpty, }, } } // Skip raw request capture if requestBody is nil (e.g., for GET requests) shouldCaptureRawRequest := sendBackRawRequest && requestBody != nil if shouldCaptureRawRequest { // Store as json.RawMessage to preserve the exact JSON bytes (including key ordering). // Previously this used sonic.Unmarshal into interface{} which created map[string]interface{} // and destroyed key ordering in tool schemas and other order-sensitive structures. // Compact to remove insignificant whitespace that would break SSE framing. rawRequest = compactRawJSON(requestBody) } if sendBackRawResponse { rawResponse = compactRawJSON(responseBody) } // Unmarshal the structured response structuredErr := sonic.Unmarshal(responseBody, response) if structuredErr != nil { // JSON parsing failed - check if it's an HTML response (expensive operation) if IsHTMLResponse(nil, responseBody) { return nil, nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: schemas.ErrProviderResponseHTML, Error: errors.New(string(responseBody)), }, } } return nil, nil, &schemas.BifrostError{ IsBifrostError: true, Error: &schemas.ErrorField{ Message: schemas.ErrProviderResponseUnmarshal, Error: structuredErr, }, } } if shouldCaptureRawRequest || sendBackRawResponse { return rawRequest, rawResponse, nil } return nil, nil, nil } // compactRawJSON removes insignificant whitespace from JSON bytes, returning a // json.RawMessage safe for SSE streaming (no literal newlines). Falls back to // the original bytes if compaction fails (e.g., invalid JSON). func compactRawJSON(data []byte) json.RawMessage { var buf bytes.Buffer if err := schemas.Compact(&buf, data); err == nil { return json.RawMessage(buf.Bytes()) } return json.RawMessage(data) } // ParseAndSetRawRequest stores the raw request body in the extra fields. // Uses json.RawMessage to preserve the exact JSON bytes (including key ordering). // The body is compacted to remove insignificant whitespace, which prevents // literal newlines from breaking SSE data-line framing during streaming. func ParseAndSetRawRequest(extraFields *schemas.BifrostResponseExtraFields, jsonBody []byte) { if len(jsonBody) > 0 { extraFields.RawRequest = compactRawJSON(jsonBody) } } // ParseAndSetRawRequestIfJSON parses the request body if it's JSON and sets the raw request in the extra fields. func ParseAndSetRawRequestIfJSON(fasthttpReq *fasthttp.Request, extraFields *schemas.BifrostResponseExtraFields) { extraFields.RawRequest = nil contentType := strings.ToLower(string(fasthttpReq.Header.ContentType())) if strings.Contains(contentType, "application/json") { body := append([]byte(nil), fasthttpReq.Body()...) ParseAndSetRawRequest(extraFields, body) } } // NewUnsupportedOperationError creates a standardized error for unsupported operations. // This helper reduces code duplication across providers that don't support certain operations. func NewUnsupportedOperationError(requestType schemas.RequestType, providerName schemas.ModelProvider) *schemas.BifrostError { return &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: fmt.Sprintf("%s is not supported by %s provider", requestType, providerName), Code: schemas.Ptr("unsupported_operation"), }, } } // CheckOperationAllowed enforces per-op gating using schemas.Operation. // Behavior: // - If no gating is configured (config == nil or AllowedRequests == nil), the operation is allowed. // - If gating is configured, returns an error when the operation is not explicitly allowed. func CheckOperationAllowed(defaultProvider schemas.ModelProvider, config *schemas.CustomProviderConfig, operation schemas.RequestType) *schemas.BifrostError { // No gating configured => allowed if config == nil || config.AllowedRequests == nil { return nil } // Explicitly allowed? if config.IsOperationAllowed(operation) { return nil } // Gated and not allowed resolved := GetProviderName(defaultProvider, config) return NewUnsupportedOperationError(operation, resolved) } // CheckAndDecodeBody checks the content encoding and decodes the body accordingly. // It returns a copy of the body to avoid race conditions when the response is released // back to fasthttp's buffer pool. Uses pooled gzip readers to reduce GC pressure. func CheckAndDecodeBody(resp *fasthttp.Response) ([]byte, error) { contentEncoding := strings.ToLower(strings.TrimSpace(string(resp.Header.Peek("Content-Encoding")))) if strings.Contains(contentEncoding, "gzip") { body := resp.Body() if len(body) == 0 { return nil, nil } reader := bytes.NewReader(body) gz, err := AcquireGzipReader(reader) if err != nil { return nil, err } defer ReleaseGzipReader(gz) decompressed, err := io.ReadAll(gz) if err != nil { return nil, err } return decompressed, nil } // Copy the body to avoid race conditions when response is released back to pool body := resp.Body() result := make([]byte, len(body)) copy(result, body) return result, nil } // IsHTMLResponse checks if the response is HTML by examining the Content-Type header // and/or the response body for HTML indicators. func IsHTMLResponse(resp *fasthttp.Response, body []byte) bool { // Check Content-Type header first (most reliable indicator) if resp != nil { contentType := strings.ToLower(string(resp.Header.Peek("Content-Type"))) if strings.Contains(contentType, "text/html") { return true } } // If body is small, it's unlikely to be HTML if len(body) < 20 { return false } // Check for HTML indicators in body bodyLower := strings.ToLower(string(body)) // Look for common HTML tags or DOCTYPE htmlIndicators := []string{ "", "

", "

", "

", "

", " maxBodySize { body = body[:maxBodySize] } bodyStr := string(body) bodyLower := strings.ToLower(bodyStr) // Try to extract title first if idx := strings.Index(bodyLower, ""); idx != -1 { endIdx := strings.Index(bodyLower[idx:], "") if endIdx != -1 { title := strings.TrimSpace(bodyStr[idx+7 : idx+endIdx]) if title != "" && title != "Error" { return title } } } // Try to extract from h1, h2, h3 tags (common for error pages) for _, tag := range []string{"h1", "h2", "h3"} { pattern := fmt.Sprintf("<%s[^>]*>([^<]+)", tag, tag) re := regexp.MustCompile("(?i)" + pattern) if matches := re.FindStringSubmatch(bodyStr); len(matches) > 1 { msg := strings.TrimSpace(matches[1]) if msg != "" { return msg } } } // Try to extract from meta description pattern := ` 1 { msg := strings.TrimSpace(matches[1]) if msg != "" { return msg } } // Extract visible text: remove script and style tags, then extract text // Remove script and style tags and their content re = regexp.MustCompile(`(?i)]*>.*?|]*>.*?`) cleaned := re.ReplaceAllString(bodyStr, "") // Remove HTML tags re = regexp.MustCompile(`<[^>]+>`) cleaned = re.ReplaceAllString(cleaned, " ") // Clean up whitespace and get first meaningful sentence sentences := strings.FieldsFunc(cleaned, func(r rune) bool { return r == '\n' || r == '\r' }) for _, sentence := range sentences { trimmed := strings.TrimSpace(sentence) if len(trimmed) > 10 && len(trimmed) < 500 { // Limit to first 200 chars to avoid very long messages if len(trimmed) > 200 { trimmed = trimmed[:200] + "..." } return trimmed } } // If all else fails, return a generic message with status code context return "HTML error response received from provider" } // JSONLParseResult holds parsed items and any line-level errors encountered during parsing. type JSONLParseResult struct { Errors []schemas.BatchError } // ParseJSONL parses JSONL data line by line, calling the provided callback for each line. // It collects parse errors with line numbers rather than silently skipping failed lines. // The callback receives the line bytes and returns an error if parsing fails. // This function operates directly on byte slices to avoid unnecessary string conversions. func ParseJSONL(data []byte, parseLine func(line []byte) error) JSONLParseResult { result := JSONLParseResult{} lineNum := 0 start := 0 for i := 0; i <= len(data); i++ { // Check for newline or end of data if i == len(data) || data[i] == '\n' { lineNum++ // Extract the line (excluding the newline character) end := i if end > start { line := data[start:end] // Trim trailing carriage return for Windows-style line endings if len(line) > 0 && line[len(line)-1] == '\r' { line = line[:len(line)-1] } // Skip empty lines if len(line) > 0 { if err := parseLine(line); err != nil { lineNumCopy := lineNum result.Errors = append(result.Errors, schemas.BatchError{ Code: "parse_error", Message: err.Error(), Line: &lineNumCopy, }) } } } start = i + 1 } } return result } // NewConfigurationError creates a standardized error for configuration errors. // This helper reduces code duplication across providers that have configuration errors. func NewConfigurationError(message string) *schemas.BifrostError { return &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: message, }, } } // NewBifrostOperationError creates a standardized error for bifrost operation errors. // This helper reduces code duplication across providers that have bifrost operation errors. func NewBifrostOperationError(message string, err error) *schemas.BifrostError { return &schemas.BifrostError{ IsBifrostError: true, Error: &schemas.ErrorField{ Message: message, Error: err, }, } } // NewBifrostTimeoutError creates a standardized error for provider request timeout errors. // Sets StatusCode to 504 (Gateway Timeout) and Error.Type to RequestTimedOut, // consistent with HandleStreamTimeout for streaming requests. func NewBifrostTimeoutError(message string, err error) *schemas.BifrostError { statusCode := 504 errorType := schemas.RequestTimedOut return &schemas.BifrostError{ IsBifrostError: true, StatusCode: &statusCode, Error: &schemas.ErrorField{ Message: message, Type: &errorType, Error: err, }, } } // NewProviderAPIError creates a standardized error for provider API errors. // This helper reduces code duplication across providers that have provider API errors. func NewProviderAPIError(message string, err error, statusCode int, errorType *string, eventID *string) *schemas.BifrostError { return &schemas.BifrostError{ IsBifrostError: false, StatusCode: &statusCode, Type: errorType, EventID: eventID, Error: &schemas.ErrorField{ Message: message, Error: err, Type: errorType, }, } } // ShouldSendBackRawRequest checks if raw request bytes should be captured. // bifrost.go always writes BifrostContextKeyCaptureRawRequest before provider dispatch, // combining provider config, per-request overrides, and store_raw_request_response. // The default parameter is a fallback for callers outside the normal bifrost dispatch path. func ShouldSendBackRawRequest(ctx context.Context, defaultSendBackRawRequest bool) bool { if capture, ok := ctx.Value(schemas.BifrostContextKeyCaptureRawRequest).(bool); ok { return capture } return defaultSendBackRawRequest } // ShouldSendBackRawResponse checks if raw response bytes should be captured. // bifrost.go always writes BifrostContextKeyCaptureRawResponse before provider dispatch, // combining provider config, per-request overrides, and store_raw_request_response. // The default parameter is a fallback for callers outside the normal bifrost dispatch path. func ShouldSendBackRawResponse(ctx context.Context, defaultSendBackRawResponse bool) bool { if capture, ok := ctx.Value(schemas.BifrostContextKeyCaptureRawResponse).(bool); ok { return capture } return defaultSendBackRawResponse } // SendCreatedEventResponsesChunk sends a ResponsesStreamResponseTypeCreated event. func SendCreatedEventResponsesChunk(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, startTime time.Time, responseChan chan *schemas.BifrostStreamChunk, postHookSpanFinalizer func(context.Context)) { firstChunk := &schemas.BifrostResponsesStreamResponse{ Type: schemas.ResponsesStreamResponseTypeCreated, SequenceNumber: 0, Response: &schemas.BifrostResponsesResponse{}, ExtraFields: schemas.BifrostResponseExtraFields{ ChunkIndex: 0, Latency: time.Since(startTime).Milliseconds(), }, } // TODO add bifrost response pooling here bifrostResponse := &schemas.BifrostResponse{ ResponsesStreamResponse: firstChunk, } ProcessAndSendResponse(ctx, postHookRunner, bifrostResponse, responseChan, postHookSpanFinalizer) } // SendInProgressEventResponsesChunk sends a ResponsesStreamResponseTypeInProgress event func SendInProgressEventResponsesChunk(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, startTime time.Time, responseChan chan *schemas.BifrostStreamChunk, postHookSpanFinalizer func(context.Context)) { chunk := &schemas.BifrostResponsesStreamResponse{ Type: schemas.ResponsesStreamResponseTypeInProgress, SequenceNumber: 1, Response: &schemas.BifrostResponsesResponse{}, ExtraFields: schemas.BifrostResponseExtraFields{ ChunkIndex: 1, Latency: time.Since(startTime).Milliseconds(), }, } // TODO add bifrost response pooling here bifrostResponse := &schemas.BifrostResponse{ ResponsesStreamResponse: chunk, } ProcessAndSendResponse(ctx, postHookRunner, bifrostResponse, responseChan, postHookSpanFinalizer) } // BuildClientStreamChunk constructs a BifrostStreamChunk from post-hook results. // It never mutates the shared processedResponse or processedError objects — when raw fields // need to be stripped (captured for storage but not for send-back), it shallow-copies each // inner response struct and nils only the appropriate per-side field on those copies. // This is safe for concurrent PostLLMHook goroutines that still hold references to the originals. func BuildClientStreamChunk(ctx context.Context, processedResponse *schemas.BifrostResponse, processedError *schemas.BifrostError) *schemas.BifrostStreamChunk { dropReq, _ := ctx.Value(schemas.BifrostContextKeyDropRawRequestFromClient).(bool) dropResp, _ := ctx.Value(schemas.BifrostContextKeyDropRawResponseFromClient).(bool) drop := dropReq || dropResp streamResponse := &schemas.BifrostStreamChunk{} if processedResponse != nil { streamResponse.BifrostTextCompletionResponse = processedResponse.TextCompletionResponse streamResponse.BifrostChatResponse = processedResponse.ChatResponse streamResponse.BifrostResponsesStreamResponse = processedResponse.ResponsesStreamResponse streamResponse.BifrostSpeechStreamResponse = processedResponse.SpeechStreamResponse streamResponse.BifrostTranscriptionStreamResponse = processedResponse.TranscriptionStreamResponse streamResponse.BifrostImageGenerationStreamResponse = processedResponse.ImageGenerationStreamResponse // Strip raw fields from client-facing copies without mutating the shared objects // that PostLLMHook goroutines may still be reading. if drop { if streamResponse.BifrostTextCompletionResponse != nil { cp := *streamResponse.BifrostTextCompletionResponse if dropReq { cp.ExtraFields.RawRequest = nil } if dropResp { cp.ExtraFields.RawResponse = nil } streamResponse.BifrostTextCompletionResponse = &cp } if streamResponse.BifrostChatResponse != nil { cp := *streamResponse.BifrostChatResponse if dropReq { cp.ExtraFields.RawRequest = nil } if dropResp { cp.ExtraFields.RawResponse = nil } streamResponse.BifrostChatResponse = &cp } if streamResponse.BifrostResponsesStreamResponse != nil { cp := *streamResponse.BifrostResponsesStreamResponse if dropReq { cp.ExtraFields.RawRequest = nil } if dropResp { cp.ExtraFields.RawResponse = nil } streamResponse.BifrostResponsesStreamResponse = &cp } if streamResponse.BifrostSpeechStreamResponse != nil { cp := *streamResponse.BifrostSpeechStreamResponse if dropReq { cp.ExtraFields.RawRequest = nil } if dropResp { cp.ExtraFields.RawResponse = nil } streamResponse.BifrostSpeechStreamResponse = &cp } if streamResponse.BifrostTranscriptionStreamResponse != nil { cp := *streamResponse.BifrostTranscriptionStreamResponse if dropReq { cp.ExtraFields.RawRequest = nil } if dropResp { cp.ExtraFields.RawResponse = nil } streamResponse.BifrostTranscriptionStreamResponse = &cp } if streamResponse.BifrostImageGenerationStreamResponse != nil { cp := *streamResponse.BifrostImageGenerationStreamResponse if dropReq { cp.ExtraFields.RawRequest = nil } if dropResp { cp.ExtraFields.RawResponse = nil } streamResponse.BifrostImageGenerationStreamResponse = &cp } } } if processedError != nil { if drop { // Strip raw fields from a client-facing copy without mutating the shared error object. errCopy := *processedError if dropReq { errCopy.ExtraFields.RawRequest = nil } if dropResp { errCopy.ExtraFields.RawResponse = nil } streamResponse.BifrostError = &errCopy } else { streamResponse.BifrostError = processedError } } return streamResponse } // ProcessAndSendResponse handles post-hook processing and sends the response to the channel. // This utility reduces code duplication across streaming implementations by encapsulating // the common pattern of running post hooks, handling errors, and sending responses with // proper context cancellation handling. // It also completes the deferred LLM span when the final chunk is sent (StreamEndIndicator is true). func ProcessAndSendResponse( ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, response *schemas.BifrostResponse, responseChan chan *schemas.BifrostStreamChunk, postHookSpanFinalizer func(context.Context), ) { // Accumulate chunk for tracing (common for all providers) if tracer, ok := ctx.Value(schemas.BifrostContextKeyTracer).(schemas.Tracer); ok && tracer != nil { if traceID, ok := ctx.Value(schemas.BifrostContextKeyTraceID).(string); ok && traceID != "" { tracer.AddStreamingChunk(traceID, response) } } // Run post hooks on the response (note: accumulated chunks above contain pre-hook data) processedResponse, processedError := postHookRunner(ctx, response, nil) if HandleStreamControlSkip(processedError) { // Even if skipping, complete the deferred span if this is the final chunk if isFinalChunk := ctx.Value(schemas.BifrostContextKeyStreamEndIndicator); isFinalChunk != nil { if final, ok := isFinalChunk.(bool); ok && final { completeDeferredSpan(ctx, processedResponse, processedError, postHookSpanFinalizer) } } return } streamResponse := BuildClientStreamChunk(ctx, processedResponse, processedError) select { case responseChan <- streamResponse: case <-ctx.Done(): return } // Check if this is the final chunk and complete deferred span with post-processed data if isFinalChunk := ctx.Value(schemas.BifrostContextKeyStreamEndIndicator); isFinalChunk != nil { if final, ok := isFinalChunk.(bool); ok && final { completeDeferredSpan(ctx, processedResponse, processedError, postHookSpanFinalizer) } } } // ProcessAndSendBifrostError handles post-hook processing and sends the bifrost error to the channel. // This utility reduces code duplication across streaming implementations by encapsulating // the common pattern of running post hooks, handling errors, and sending responses with // proper context cancellation handling. // It also completes the deferred LLM span when the final chunk is sent (StreamEndIndicator is true). func ProcessAndSendBifrostError( ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, bifrostErr *schemas.BifrostError, responseChan chan *schemas.BifrostStreamChunk, logger schemas.Logger, postHookSpanFinalizer func(context.Context), ) { // Run post hooks first so span reflects post-processed data processedResponse, processedError := postHookRunner(ctx, nil, bifrostErr) if HandleStreamControlSkip(processedError) { // Even if skipping, complete the deferred span if this is the final chunk if isFinalChunk := ctx.Value(schemas.BifrostContextKeyStreamEndIndicator); isFinalChunk != nil { if final, ok := isFinalChunk.(bool); ok && final { completeDeferredSpan(ctx, processedResponse, processedError, postHookSpanFinalizer) } } return } streamResponse := BuildClientStreamChunk(ctx, processedResponse, processedError) select { case responseChan <- streamResponse: case <-ctx.Done(): } // Check if this is the final chunk and complete deferred span with post-processed data if isFinalChunk := ctx.Value(schemas.BifrostContextKeyStreamEndIndicator); isFinalChunk != nil { if final, ok := isFinalChunk.(bool); ok && final { completeDeferredSpan(ctx, processedResponse, processedError, postHookSpanFinalizer) } } } // EnsureStreamFinalizerCalled invokes the post-hook span finalizer registered // on ctx, if any. Designed to be deferred as the last line of defence in a // provider's streaming goroutine (next to SetupStreamCancellation's cleanup): // // defer providerUtils.EnsureStreamFinalizerCalled(ctx) // // On a normal stream end the finalizer is already invoked when the final chunk // is processed (via completeDeferredSpan). The registration wraps the closure // in sync.Once, so this safety-net call is a noop in that case. It only does // real work when the streaming goroutine exits without reaching the final-chunk // path — e.g. a panic mid-stream — which would otherwise leak the plugin // pipeline back-reference held by the finalizer closure. // // Panics inside the finalizer are recovered and logged so they never mask an // in-flight panic that triggered the defer. func EnsureStreamFinalizerCalled(ctx context.Context, finalizer func(context.Context)) { defer func() { if r := recover(); r != nil { getLogger().Debug("recovered panic in deferred stream finalizer: %v", r) } }() if finalizer == nil { return } finalizer(ctx) } // SetupStreamCancellation spawns a goroutine that closes the body stream when // the context is cancelled or deadline exceeded, unblocking any blocked Read/Scan operations. // Returns a cleanup function that MUST be called when streaming is done to // prevent the goroutine from closing the stream during normal operation. // Works with both fasthttp's BodyStream() (io.Reader) and net/http's resp.Body (io.ReadCloser). func SetupStreamCancellation(ctx context.Context, bodyStream io.Reader, logger schemas.Logger) (cleanup func()) { done := make(chan struct{}) closed := make(chan struct{}) go func() { defer close(closed) select { case <-ctx.Done(): // Context cancelled or deadline exceeded - close the body stream to unblock reads if closer, ok := bodyStream.(io.Closer); ok { if err := closer.Close(); err != nil { getLogger().Debug(fmt.Sprintf("Error closing body stream on context done: %v", err)) } } case <-done: // If context was also cancelled (race between done and ctx.Done), // still close the body stream to unblock the drain in ReleaseStreamingResponse. if ctx.Err() != nil { if closer, ok := bodyStream.(io.Closer); ok { if err := closer.Close(); err != nil { getLogger().Debug(fmt.Sprintf("Error closing body stream on done with cancelled context: %v", err)) } } } } }() return func() { close(done) <-closed // Wait for goroutine to finish closing the stream before ReleaseStreamingResponse drains } } // DefaultStreamIdleTimeout is how long a stream read can block with zero data // before bifrost considers the connection stalled and closes it. This protects // against providers that stop sending data but keep the TCP connection open // (e.g., Azure TPM throttling). const DefaultStreamIdleTimeout = 60 * time.Second // SetStreamIdleTimeoutIfEmpty sets the stream idle timeout on the context from // the provider's network config, but only if no valid timeout is already present. // This allows upstream layers (transport, headers) to set the timeout first, // with the provider config acting as a fallback. func SetStreamIdleTimeoutIfEmpty(ctx *schemas.BifrostContext, configSeconds int) { if existing, ok := ctx.Value(schemas.BifrostContextKeyStreamIdleTimeout).(time.Duration); ok && existing > 0 { return // already set from upstream (transport/header), respect it } if configSeconds > 0 { ctx.SetValue(schemas.BifrostContextKeyStreamIdleTimeout, time.Duration(configSeconds)*time.Second) } } // GetStreamIdleTimeout reads the per-chunk idle timeout from context, // falling back to DefaultStreamIdleTimeout if not set. func GetStreamIdleTimeout(ctx *schemas.BifrostContext) time.Duration { if timeout, ok := ctx.Value(schemas.BifrostContextKeyStreamIdleTimeout).(time.Duration); ok && timeout > 0 { return timeout } return DefaultStreamIdleTimeout } // idleTimeoutReader wraps an io.Reader and closes the underlying body stream // if no data arrives within the configured timeout. This unblocks any pending // Read() call on the wrapped reader. type idleTimeoutReader struct { reader io.Reader bodyStream io.Reader // closed via type assertion to io.Closer on timeout timeout time.Duration timer *time.Timer once sync.Once } // NewIdleTimeoutReader wraps reader with idle detection. If reader.Read() returns // no data for the given timeout duration, bodyStream is closed to unblock the read. // bodyStream must implement io.Closer for the timeout to take effect; if it does not, // the wrapper still functions but cannot force-close the stream. // Returns the wrapped reader and a cleanup function that MUST be called (via defer) // when streaming is complete, to stop the timer and prevent premature closure. func NewIdleTimeoutReader(reader io.Reader, bodyStream io.Reader, timeout time.Duration) (io.Reader, func()) { if timeout <= 0 { timeout = DefaultStreamIdleTimeout } r := &idleTimeoutReader{ reader: reader, bodyStream: bodyStream, timeout: timeout, } r.timer = time.AfterFunc(timeout, func() { r.once.Do(func() { if closer, ok := r.bodyStream.(io.Closer); ok { closer.Close() } }) }) return r, func() { r.timer.Stop() } } func (r *idleTimeoutReader) Read(p []byte) (int, error) { n, err := r.reader.Read(p) if n > 0 { r.timer.Reset(r.timeout) } return n, err } // HandleStreamCancellation should be called when a streaming goroutine exits // due to context cancellation. It ensures proper cleanup by: // 1. Checking if StreamEndIndicator was already set (to avoid duplicate handling) // 2. Setting StreamEndIndicator to true // 3. Sending a cancellation error through PostHook chain // // This is critical for the logging plugin to update log status from "processing" to "error" // when a client disconnects mid-stream. func HandleStreamCancellation( ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, responseChan chan *schemas.BifrostStreamChunk, logger schemas.Logger, postHookSpanFinalizer func(context.Context), ) { // Check if already handled (StreamEndIndicator already set) if indicator := ctx.GetAndSetValue(schemas.BifrostContextKeyStreamEndIndicator, true); indicator != nil { if set, ok := indicator.(bool); ok && set { return // Already handled } } // Create cancellation error cancelErr := &schemas.BifrostError{ StatusCode: schemas.Ptr(499), // Client Closed Request Error: &schemas.ErrorField{ Message: "Request cancelled: client disconnected", Type: schemas.Ptr(schemas.RequestCancelled), }, } // Send through PostHook chain - this updates the log to "error" status ProcessAndSendBifrostError(ctx, postHookRunner, cancelErr, responseChan, logger, postHookSpanFinalizer) } // HandleStreamTimeout should be called when a streaming goroutine exits // due to context deadline exceeded. It ensures proper cleanup by: // 1. Checking if StreamEndIndicator was already set (to avoid duplicate handling) // 2. Setting StreamEndIndicator to true // 3. Sending a timeout error through PostHook chain // // This is critical for the logging plugin to update log status from "processing" to "error" // when a request times out mid-stream. func HandleStreamTimeout( ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, responseChan chan *schemas.BifrostStreamChunk, logger schemas.Logger, postHookSpanFinalizer func(context.Context), ) { // Check if already handled (StreamEndIndicator already set) if indicator := ctx.GetAndSetValue(schemas.BifrostContextKeyStreamEndIndicator, true); indicator != nil { if set, ok := indicator.(bool); ok && set { return // Already handled } } // Create timeout error timeoutErr := &schemas.BifrostError{ StatusCode: schemas.Ptr(504), // Gateway Timeout Error: &schemas.ErrorField{ Message: "Request timed out: deadline exceeded", Type: schemas.Ptr(schemas.RequestTimedOut), }, } // Send through PostHook chain - this updates the log to "error" status ProcessAndSendBifrostError(ctx, postHookRunner, timeoutErr, responseChan, logger, postHookSpanFinalizer) } // ProcessAndSendError handles post-hook processing and sends the error to the channel. // This utility reduces code duplication across streaming implementations by encapsulating // the common pattern of running post hooks, handling errors, and sending responses with // proper context cancellation handling. func ProcessAndSendError( ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, err error, responseChan chan *schemas.BifrostStreamChunk, logger schemas.Logger, postHookSpanFinalizer func(context.Context), ) { // Send scanner error through channel bifrostError := &schemas.BifrostError{ IsBifrostError: true, Error: &schemas.ErrorField{ Message: fmt.Sprintf("Error reading stream: %v", err), Error: err, }, } processedResponse, processedError := postHookRunner(ctx, nil, bifrostError) if HandleStreamControlSkip(processedError) { return } streamResponse := &schemas.BifrostStreamChunk{} if processedResponse != nil { streamResponse.BifrostTextCompletionResponse = processedResponse.TextCompletionResponse streamResponse.BifrostChatResponse = processedResponse.ChatResponse streamResponse.BifrostResponsesStreamResponse = processedResponse.ResponsesStreamResponse streamResponse.BifrostSpeechStreamResponse = processedResponse.SpeechStreamResponse streamResponse.BifrostTranscriptionStreamResponse = processedResponse.TranscriptionStreamResponse } if processedError != nil { streamResponse.BifrostError = processedError } select { case responseChan <- streamResponse: case <-ctx.Done(): } } // CreateBifrostTextCompletionChunkResponse creates a bifrost text completion chunk response. func CreateBifrostTextCompletionChunkResponse( id string, usage *schemas.BifrostLLMUsage, finishReason *string, currentChunkIndex int, requestType schemas.RequestType, ) *schemas.BifrostTextCompletionResponse { response := &schemas.BifrostTextCompletionResponse{ ID: id, Object: "text_completion", Usage: usage, Choices: []schemas.BifrostResponseChoice{ { FinishReason: finishReason, TextCompletionResponseChoice: &schemas.TextCompletionResponseChoice{}, // empty delta }, }, ExtraFields: schemas.BifrostResponseExtraFields{ ChunkIndex: currentChunkIndex + 1, }, } return response } // CreateBifrostChatCompletionChunkResponse creates a bifrost chat completion chunk response. func CreateBifrostChatCompletionChunkResponse( id string, usage *schemas.BifrostLLMUsage, finishReason *string, currentChunkIndex int, model string, created int, ) *schemas.BifrostChatResponse { response := &schemas.BifrostChatResponse{ ID: id, Model: model, Created: created, Object: "chat.completion.chunk", Usage: usage, Choices: []schemas.BifrostResponseChoice{ { FinishReason: finishReason, ChatStreamResponseChoice: &schemas.ChatStreamResponseChoice{ Delta: &schemas.ChatStreamResponseChoiceDelta{}, // empty delta }, }, }, ExtraFields: schemas.BifrostResponseExtraFields{ ChunkIndex: currentChunkIndex + 1, }, } return response } // HandleStreamControlSkip checks if the stream control should be skipped. func HandleStreamControlSkip(bifrostErr *schemas.BifrostError) bool { if bifrostErr == nil || bifrostErr.StreamControl == nil { return false } if bifrostErr.StreamControl.SkipStream != nil && *bifrostErr.StreamControl.SkipStream { if bifrostErr.StreamControl.LogError != nil && *bifrostErr.StreamControl.LogError { getLogger().Warn("Error in stream: " + bifrostErr.Error.Message) } return true } return false } // GetProviderName extracts the provider name from custom provider configuration. // If a custom provider key is specified, it returns that; otherwise, it returns the default provider. // Note: CustomProviderKey is internally set by Bifrost and should always match the provider name. func GetProviderName(defaultProvider schemas.ModelProvider, customConfig *schemas.CustomProviderConfig) schemas.ModelProvider { if customConfig != nil { if key := strings.TrimSpace(customConfig.CustomProviderKey); key != "" { return schemas.ModelProvider(key) } } return defaultProvider } // ProviderSendsDoneMarker returns true if the provider sends the [DONE] marker in streaming responses. // Some OpenAI-compatible providers (like Cerebras) don't send [DONE] and instead end the stream // after sending the finish_reason. This function helps determine the correct stream termination logic. func ProviderSendsDoneMarker(providerName schemas.ModelProvider) bool { switch providerName { case schemas.Cerebras, schemas.Perplexity, schemas.HuggingFace: // Cerebras, Perplexity, and HuggingFace don't send [DONE] marker, ends stream after finish_reason return false default: // Default to expecting [DONE] marker for safety return true } } func ProviderIsResponsesAPINative(providerName schemas.ModelProvider) bool { switch providerName { case schemas.OpenAI, schemas.OpenRouter, schemas.Azure: return true default: return false } } // ReleaseStreamingResponse releases a streaming response by draining the body stream and releasing the response. func ReleaseStreamingResponse(resp *fasthttp.Response) { defer func() { if r := recover(); r != nil { getLogger().Error("recovered panic in ReleaseStreamingResponse: %v", r) } // Always release the response to prevent leaks, even after a panic fasthttp.ReleaseResponse(resp) }() // Drain any remaining data from the body stream before releasing. // This prevents "whitespace in header" errors when the connection is reused // (see: https://github.com/valyala/fasthttp/issues/1743). if bodyStream := resp.BodyStream(); bodyStream != nil { if _, err := io.Copy(io.Discard, bodyStream); err != nil { getLogger().Warn("failed to drain streaming response body before release (may cause stale connection reuse): %v", err) } if closer, ok := bodyStream.(io.Closer); ok { if err := closer.Close(); err != nil { getLogger().Warn("failed to close streaming response body: %v", err) } } } } // GetBifrostResponseForStreamResponse converts the provided responses to a bifrost response. func GetBifrostResponseForStreamResponse( textCompletionResponse *schemas.BifrostTextCompletionResponse, chatResponse *schemas.BifrostChatResponse, responsesStreamResponse *schemas.BifrostResponsesStreamResponse, speechStreamResponse *schemas.BifrostSpeechStreamResponse, transcriptionStreamResponse *schemas.BifrostTranscriptionStreamResponse, imageGenerationStreamResponse *schemas.BifrostImageGenerationStreamResponse, ) *schemas.BifrostResponse { // TODO add bifrost response pooling here bifrostResponse := &schemas.BifrostResponse{} switch { case textCompletionResponse != nil: bifrostResponse.TextCompletionResponse = textCompletionResponse return bifrostResponse case chatResponse != nil: bifrostResponse.ChatResponse = chatResponse return bifrostResponse case responsesStreamResponse != nil: bifrostResponse.ResponsesStreamResponse = responsesStreamResponse return bifrostResponse case speechStreamResponse != nil: bifrostResponse.SpeechStreamResponse = speechStreamResponse return bifrostResponse case transcriptionStreamResponse != nil: bifrostResponse.TranscriptionStreamResponse = transcriptionStreamResponse return bifrostResponse case imageGenerationStreamResponse != nil: bifrostResponse.ImageGenerationStreamResponse = imageGenerationStreamResponse return bifrostResponse } return nil } // aggregateListModelsResponses merges multiple BifrostListModelsResponse objects into a single response. // It concatenates all model arrays, deduplicates based on model ID, sums up latencies across all responses, // and concatenates raw responses into an array. // When duplicate IDs are found, the first occurrence is kept to maintain the original ordering. func aggregateListModelsResponses(responses []*schemas.BifrostListModelsResponse) *schemas.BifrostListModelsResponse { if len(responses) == 0 { return &schemas.BifrostListModelsResponse{ Data: []schemas.Model{}, } } // Always apply deduplication, even for single responses // Use a map to track unique model IDs for efficient deduplication seenIDs := make(map[string]struct{}) aggregated := &schemas.BifrostListModelsResponse{ Data: make([]schemas.Model, 0), } // Aggregate all models with deduplication, and collect raw responses var rawResponses []interface{} for _, response := range responses { if response == nil { continue } // Add models, skipping duplicates based on ID for _, model := range response.Data { if _, exists := seenIDs[model.ID]; !exists { seenIDs[model.ID] = struct{}{} aggregated.Data = append(aggregated.Data, model) } } // Collect raw response if present if response.ExtraFields.RawResponse != nil { rawResponses = append(rawResponses, response.ExtraFields.RawResponse) } } // Sort models alphabetically by ID sort.Slice(aggregated.Data, func(i, j int) bool { return aggregated.Data[i].ID < aggregated.Data[j].ID }) if len(rawResponses) > 0 { aggregated.ExtraFields.RawResponse = rawResponses } return aggregated } // extractSuccessfulListModelsResponses extracts successful responses from a results channel // and tracks per-key status information. This utility reduces code duplication across providers // for handling multi-key ListModels requests. func extractSuccessfulListModelsResponses(results chan schemas.ListModelsByKeyResult, provider schemas.ModelProvider) ([]*schemas.BifrostListModelsResponse, []schemas.KeyStatus, *schemas.BifrostError) { var successfulResponses []*schemas.BifrostListModelsResponse var keyStatuses []schemas.KeyStatus var lastError *schemas.BifrostError for result := range results { if result.Err != nil { errMsg := "unknown error" if errorField := result.Err.Error; errorField != nil { if errorField.Message != "" { errMsg = errorField.Message } else if errorField.Error != nil { errMsg = errorField.Error.Error() } } getLogger().Warn(fmt.Sprintf("failed to list models with key %s: %s", result.KeyID, errMsg)) keyStatuses = append(keyStatuses, schemas.KeyStatus{ KeyID: result.KeyID, Provider: provider, Status: schemas.KeyStatusListModelsFailed, Error: result.Err, }) lastError = result.Err continue } keyStatuses = append(keyStatuses, schemas.KeyStatus{ KeyID: result.KeyID, Provider: provider, Status: schemas.KeyStatusSuccess, }) successfulResponses = append(successfulResponses, result.Response) } if len(successfulResponses) == 0 { if lastError != nil { return nil, keyStatuses, lastError } return nil, keyStatuses, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "all keys failed to list models", }, } } return successfulResponses, keyStatuses, nil } // HandleKeylessListModelsRequest wraps a list models request for keyless providers // and automatically populates the KeyStatuses field with provider-level status tracking. // This centralizes the status tracking logic for keyless providers. func HandleKeylessListModelsRequest( provider schemas.ModelProvider, listFunc func() (*schemas.BifrostListModelsResponse, *schemas.BifrostError), ) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { resp, bifrostErr := listFunc() keyStatus := schemas.KeyStatus{ KeyID: "", // Empty for keyless providers Provider: provider, } // If request failed, attach status to error if bifrostErr != nil { keyStatus.Status = schemas.KeyStatusListModelsFailed keyStatus.Error = bifrostErr bifrostErr.ExtraFields.KeyStatuses = []schemas.KeyStatus{keyStatus} return nil, bifrostErr } // Success case if resp != nil { keyStatus.Status = schemas.KeyStatusSuccess resp.KeyStatuses = []schemas.KeyStatus{keyStatus} return resp, nil } return resp, bifrostErr } // HandleMultipleListModelsRequests handles multiple list models requests concurrently for different keys. // It launches concurrent requests for all keys and waits for all goroutines to complete. // It returns the aggregated response with per-key status information or an error if the request fails. func HandleMultipleListModelsRequests( ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostListModelsRequest, listModelsByKey func(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError), ) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { startTime := time.Now() results := make(chan schemas.ListModelsByKeyResult, len(keys)) var wg sync.WaitGroup // Launch concurrent requests for all keys for _, key := range keys { wg.Add(1) go func(k schemas.Key) { defer wg.Done() // Should never panic, but if it does, we need to handle it gracefully defer func() { if r := recover(); r != nil { getLogger().Error("panic in listModelsByKey for key %s (%s): %v", k.Name, k.ID, r) results <- schemas.ListModelsByKeyResult{ Err: &schemas.BifrostError{ IsBifrostError: true, Error: &schemas.ErrorField{ Message: "internal error while listing models for key", }, }, KeyID: k.ID, } } }() resp, bifrostErr := listModelsByKey(ctx, k, request) results <- schemas.ListModelsByKeyResult{Response: resp, Err: bifrostErr, KeyID: k.ID} }(key) } // Wait for all goroutines to complete wg.Wait() close(results) successfulResponses, keyStatuses, err := extractSuccessfulListModelsResponses(results, request.Provider) if err != nil { // Attach key statuses to error's ExtraFields err.ExtraFields.KeyStatuses = keyStatuses return nil, err } // Aggregate all successful responses response := aggregateListModelsResponses(successfulResponses) response = response.ApplyPagination(request.PageSize, request.PageToken) // Attach key statuses to response response.KeyStatuses = keyStatuses // Set ExtraFields latency := time.Since(startTime) response.ExtraFields.Latency = latency.Milliseconds() return response, nil } // GetRandomString generates a random alphanumeric string of the given length. func GetRandomString(length int) string { if length <= 0 { return "" } randomSource := rand.New(rand.NewSource(time.Now().UnixNano())) letters := []rune("abcdef0123456789") b := make([]rune, length) for i := range b { b[i] = letters[randomSource.Intn(len(letters))] } return string(b) } // GetReasoningEffortFromBudgetTokens maps a reasoning token budget to OpenAI reasoning effort. // Valid values: none, low, medium, high func GetReasoningEffortFromBudgetTokens( budgetTokens int, minBudgetTokens int, maxTokens int, ) string { if budgetTokens <= 0 { return "none" } // Defensive defaults if maxTokens <= 0 { return "medium" } // Normalize budget if budgetTokens < minBudgetTokens { budgetTokens = minBudgetTokens } if budgetTokens > maxTokens { budgetTokens = maxTokens } // Avoid division by zero if maxTokens <= minBudgetTokens { return "high" } ratio := float64(budgetTokens-minBudgetTokens) / float64(maxTokens-minBudgetTokens) switch { case ratio <= 0.25: return "low" case ratio <= 0.60: return "medium" default: return "high" } } // GetBudgetTokensFromReasoningEffort converts reasoning effort into a reasoning token budget. // effort ∈ {"none", "minimal", "low", "medium", "high", "xhigh", "max"} func GetBudgetTokensFromReasoningEffort( effort string, minBudgetTokens int, maxTokens int, ) (int, error) { if effort == "none" { return 0, nil } if minBudgetTokens > maxTokens { return 0, fmt.Errorf("max_tokens must be greater than %d for reasoning", minBudgetTokens) } // Defensive defaults if maxTokens <= minBudgetTokens { return minBudgetTokens, nil } var ratio float64 switch effort { case "minimal": ratio = 0.025 case "low": ratio = 0.15 case "medium": ratio = 0.425 case "high": ratio = 0.80 case "xhigh": ratio = 0.92 case "max": ratio = 1.0 default: // Unknown effort → safe default ratio = 0.425 } budget := minBudgetTokens + int(ratio*float64(maxTokens-minBudgetTokens)) return budget, nil } // completeDeferredSpan completes the deferred LLM span for streaming requests. // This is called when the final chunk is processed (when StreamEndIndicator is true). // It retrieves the deferred span handle from TraceStore using the trace ID from context, // populates response attributes from accumulated chunks, and ends the span. func completeDeferredSpan(ctx *schemas.BifrostContext, result *schemas.BifrostResponse, err *schemas.BifrostError, postHookSpanFinalizer func(context.Context)) { if ctx == nil { return } // Get the trace ID from context (this IS available in the provider's goroutine) traceID, ok := ctx.Value(schemas.BifrostContextKeyTraceID).(string) if !ok || traceID == "" { return } // Get the tracer from context tracerVal := ctx.Value(schemas.BifrostContextKeyTracer) if tracerVal == nil { return } tracer, ok := tracerVal.(schemas.Tracer) if !ok || tracer == nil { return } // Get the deferred span handle from TraceStore using trace ID handle := tracer.GetDeferredSpanHandle(traceID) if handle == nil { return } // Set total latency from the final chunk if result != nil { extraFields := result.GetExtraFields() if extraFields.Latency > 0 { tracer.SetAttribute(handle, "gen_ai.response.total_latency_ms", extraFields.Latency) } } // Get accumulated response with full data (content, tool calls, reasoning, etc.) // This builds a complete BifrostResponse from all the streaming chunks accumulatedResp, ttftNs, chunkCount := tracer.GetAccumulatedChunks(traceID) // Set TTFT and chunk count attributes regardless of accumulated response availability // (GetAccumulatedChunks may return nil response while still providing valid metrics) if ttftNs > 0 { tracer.SetAttribute(handle, schemas.AttrTimeToFirstToken, ttftNs) } if chunkCount > 0 { tracer.SetAttribute(handle, schemas.AttrTotalChunks, chunkCount) } if accumulatedResp != nil { // Use accumulated response for attributes (includes full content, tool calls, etc.) tracer.PopulateLLMResponseAttributes(ctx, handle, accumulatedResp, err) } else if result != nil { // Fall back to final chunk if no accumulated data (shouldn't happen normally) tracer.PopulateLLMResponseAttributes(ctx, handle, result, err) } // Finalize aggregated post-hook spans before ending the LLM span // This creates one span per plugin with average execution time // We need to set the llm.call span ID in context so post-hook spans become its children if postHookSpanFinalizer != nil { // Get the deferred span ID (the llm.call span) to set as parent for post-hook spans spanID := tracer.GetDeferredSpanID(traceID) if spanID != "" { finalizerCtx := context.WithValue(ctx, schemas.BifrostContextKeySpanID, spanID) postHookSpanFinalizer(finalizerCtx) } else { postHookSpanFinalizer(ctx) } } // End span with appropriate status if err != nil { if err.Error != nil { tracer.SetAttribute(handle, "error", err.Error.Message) } if err.StatusCode != nil { tracer.SetAttribute(handle, "status_code", *err.StatusCode) } tracer.EndSpan(handle, schemas.SpanStatusError, "streaming request failed") } else { tracer.EndSpan(handle, schemas.SpanStatusOk, "") } // Clear the deferred span from TraceStore tracer.ClearDeferredSpan(traceID) } // CheckAndSetDefaultProvider checks if the default provider should be used based on the context. // It returns the default provider if it should be used, otherwise it returns an empty string. // Checks if the direct key is set in the context, or if key selection is skipped. // Or if the available providers are set in the context and the default provider is in the list. func CheckAndSetDefaultProvider(ctx *schemas.BifrostContext, defaultProvider schemas.ModelProvider) schemas.ModelProvider { if ctx != nil { if ctx.Value(schemas.BifrostContextKeyDirectKey) != nil || ctx.Value(schemas.BifrostContextKeySkipKeySelection) != nil { return defaultProvider } if ctx.Value(schemas.BifrostContextKeyAvailableProviders) != nil { availableProviders, ok := ctx.Value(schemas.BifrostContextKeyAvailableProviders).([]schemas.ModelProvider) if !ok || len(availableProviders) == 0 { return "" } getLogger().Debug("[Provider] Available providers: %v, checking %s", availableProviders, defaultProvider) if slices.Contains(availableProviders, defaultProvider) { return defaultProvider } return "" } return defaultProvider } return defaultProvider } // ModelMatchesDenylist reports whether any of the candidate model IDs matches // an entry in denylist, using both exact and base-model (SameBaseModel) matching. // Empty candidates are skipped. Returns false immediately if denylist is empty. func ModelMatchesDenylist(denylist []string, candidates ...string) bool { if len(denylist) == 0 { return false } for _, c := range candidates { if c == "" { continue } if slices.Contains(denylist, c) { return true } for _, d := range denylist { if schemas.SameBaseModel(d, c) { return true } } } return false }