first commit
This commit is contained in:
614
core/schemas/provider.go
Normal file
614
core/schemas/provider.go
Normal file
@@ -0,0 +1,614 @@
|
||||
// Package schemas defines the core schemas and types used by the Bifrost system.
|
||||
package schemas
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"maps"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
DefaultMaxRetries = 0
|
||||
DefaultRetryBackoffInitial = 500 * time.Millisecond
|
||||
DefaultRetryBackoffMax = 5 * time.Second
|
||||
DefaultRequestTimeoutInSeconds = 30
|
||||
DefaultMaxConnDurationInSeconds = 300 // 5 minutes — forces connection recycling to prevent stale connections from NAT/LB silent drops
|
||||
DefaultBufferSize = 5000
|
||||
DefaultConcurrency = 1000
|
||||
DefaultStreamBufferSize = 256
|
||||
DefaultStreamIdleTimeoutInSeconds = 60 // Idle timeout per stream chunk — if no data for this many seconds, bifrost closes the connection
|
||||
DefaultMaxConnsPerHost = 5000
|
||||
MaxConnsPerHostUpperBound = 10000
|
||||
DefaultMaxIdleConnsPerHost = 40
|
||||
)
|
||||
|
||||
// Pre-defined errors for provider operations
|
||||
const (
|
||||
ErrProviderRequestTimedOut = "request timed out (default is 30 seconds). You can increase it by setting the default_request_timeout_in_seconds in the network_config or in UI - Providers > Provider Name > Network Config."
|
||||
ErrRequestCancelled = "request cancelled by caller"
|
||||
ErrRequestBodyConversion = "failed to convert bifrost request to the expected provider request body"
|
||||
ErrProviderRequestMarshal = "failed to marshal request body to JSON"
|
||||
ErrProviderCreateRequest = "failed to create HTTP request to provider API"
|
||||
ErrProviderDoRequest = "failed to execute HTTP request to provider API"
|
||||
ErrProviderNetworkError = "network error occurred while connecting to provider API (DNS lookup, connection refused, etc.)"
|
||||
ErrProviderResponseDecode = "failed to decode response body from provider API"
|
||||
ErrProviderResponseUnmarshal = "failed to unmarshal response from provider API"
|
||||
ErrProviderResponseEmpty = "empty response received from provider"
|
||||
ErrProviderResponseHTML = "HTML response received from provider"
|
||||
ErrProviderRawRequestUnmarshal = "failed to unmarshal raw request from provider API"
|
||||
ErrProviderRawResponseUnmarshal = "failed to unmarshal raw response from provider API"
|
||||
ErrProviderResponseDecompress = "failed to decompress provider's response"
|
||||
)
|
||||
|
||||
// NetworkConfig represents the network configuration for provider connections.
|
||||
// ExtraHeaders is automatically copied during provider initialization to prevent data races.
|
||||
//
|
||||
// RetryBackoffInitial and RetryBackoffMax are stored internally as time.Duration (nanoseconds),
|
||||
// but are serialized/deserialized to/from JSON as milliseconds (integers).
|
||||
// This means:
|
||||
// - In JSON: values are represented as milliseconds (e.g., 1000 means 1000ms)
|
||||
// - In Go: values are time.Duration (e.g., 1000ms = 1000000000 nanoseconds)
|
||||
// - When unmarshaling from JSON: a value of 1000 is interpreted as 1000ms, not 1000ns
|
||||
// - When marshaling to JSON: a time.Duration is converted to milliseconds
|
||||
type NetworkConfig struct {
|
||||
// BaseURL is supported for OpenAI, Anthropic, Cohere, Mistral, and Ollama providers (required for Ollama)
|
||||
BaseURL string `json:"base_url,omitempty"` // Base URL for the provider (optional)
|
||||
ExtraHeaders map[string]string `json:"extra_headers,omitempty"` // Additional headers to include in requests (optional)
|
||||
DefaultRequestTimeoutInSeconds int `json:"default_request_timeout_in_seconds"` // Default timeout for requests
|
||||
MaxRetries int `json:"max_retries"` // Maximum number of retries
|
||||
RetryBackoffInitial time.Duration `json:"retry_backoff_initial"` // Initial backoff duration (stored as nanoseconds, JSON as milliseconds)
|
||||
RetryBackoffMax time.Duration `json:"retry_backoff_max"` // Maximum backoff duration (stored as nanoseconds, JSON as milliseconds)
|
||||
InsecureSkipVerify bool `json:"insecure_skip_verify,omitempty"` // Disables TLS certificate verification for provider connections
|
||||
CACertPEM string `json:"ca_cert_pem,omitempty"` // PEM-encoded CA certificate to trust for provider endpoint connections
|
||||
StreamIdleTimeoutInSeconds int `json:"stream_idle_timeout_in_seconds,omitempty"` // Idle timeout per stream chunk (0 = use default 60s)
|
||||
MaxConnsPerHost int `json:"max_conns_per_host,omitempty"` // Max TCP connections per provider host (default: 5000)
|
||||
EnforceHTTP2 bool `json:"enforce_http2,omitempty"` // Force HTTP/2 on provider connections (relevant for net/http-based providers like Bedrock)
|
||||
BetaHeaderOverrides map[string]bool `json:"beta_header_overrides,omitempty"` // Override default beta header support per provider (keys are prefixes like "redact-thinking-")
|
||||
}
|
||||
|
||||
// UnmarshalJSON customizes JSON unmarshaling for NetworkConfig.
|
||||
// RetryBackoffInitial and RetryBackoffMax are interpreted as milliseconds in JSON,
|
||||
// but stored as time.Duration (nanoseconds) internally.
|
||||
func (nc *NetworkConfig) UnmarshalJSON(data []byte) error {
|
||||
// Use an alias type to avoid infinite recursion
|
||||
type NetworkConfigAlias struct {
|
||||
BaseURL string `json:"base_url,omitempty"`
|
||||
ExtraHeaders map[string]string `json:"extra_headers,omitempty"`
|
||||
DefaultRequestTimeoutInSeconds int `json:"default_request_timeout_in_seconds"`
|
||||
MaxRetries int `json:"max_retries"`
|
||||
RetryBackoffInitial int64 `json:"retry_backoff_initial"` // milliseconds in JSON
|
||||
RetryBackoffMax int64 `json:"retry_backoff_max"` // milliseconds in JSON
|
||||
InsecureSkipVerify bool `json:"insecure_skip_verify,omitempty"`
|
||||
CACertPEM string `json:"ca_cert_pem,omitempty"`
|
||||
StreamIdleTimeoutInSeconds int `json:"stream_idle_timeout_in_seconds,omitempty"`
|
||||
MaxConnsPerHost int `json:"max_conns_per_host,omitempty"`
|
||||
EnforceHTTP2 bool `json:"enforce_http2,omitempty"`
|
||||
BetaHeaderOverrides map[string]bool `json:"beta_header_overrides,omitempty"`
|
||||
}
|
||||
|
||||
var alias NetworkConfigAlias
|
||||
if err := json.Unmarshal(data, &alias); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Copy all fields
|
||||
nc.BaseURL = alias.BaseURL
|
||||
nc.ExtraHeaders = alias.ExtraHeaders
|
||||
nc.DefaultRequestTimeoutInSeconds = alias.DefaultRequestTimeoutInSeconds
|
||||
nc.MaxRetries = alias.MaxRetries
|
||||
nc.InsecureSkipVerify = alias.InsecureSkipVerify
|
||||
nc.CACertPEM = alias.CACertPEM
|
||||
nc.StreamIdleTimeoutInSeconds = alias.StreamIdleTimeoutInSeconds
|
||||
nc.MaxConnsPerHost = alias.MaxConnsPerHost
|
||||
nc.EnforceHTTP2 = alias.EnforceHTTP2
|
||||
nc.BetaHeaderOverrides = alias.BetaHeaderOverrides
|
||||
|
||||
// Convert milliseconds to time.Duration (nanoseconds)
|
||||
// Only convert if value is greater than 0
|
||||
if alias.RetryBackoffInitial > 0 {
|
||||
nc.RetryBackoffInitial = time.Duration(alias.RetryBackoffInitial) * time.Millisecond
|
||||
}
|
||||
if alias.RetryBackoffMax > 0 {
|
||||
nc.RetryBackoffMax = time.Duration(alias.RetryBackoffMax) * time.Millisecond
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// MarshalJSON customizes JSON marshaling for NetworkConfig.
|
||||
// RetryBackoffInitial and RetryBackoffMax are converted from time.Duration (nanoseconds)
|
||||
// to milliseconds (integers) in JSON.
|
||||
func (nc NetworkConfig) MarshalJSON() ([]byte, error) {
|
||||
// Use an alias type to avoid infinite recursion
|
||||
type NetworkConfigAlias struct {
|
||||
BaseURL string `json:"base_url,omitempty"`
|
||||
ExtraHeaders map[string]string `json:"extra_headers,omitempty"`
|
||||
DefaultRequestTimeoutInSeconds int `json:"default_request_timeout_in_seconds"`
|
||||
MaxRetries int `json:"max_retries"`
|
||||
RetryBackoffInitial int64 `json:"retry_backoff_initial"` // milliseconds in JSON
|
||||
RetryBackoffMax int64 `json:"retry_backoff_max"` // milliseconds in JSON
|
||||
InsecureSkipVerify bool `json:"insecure_skip_verify,omitempty"`
|
||||
CACertPEM string `json:"ca_cert_pem,omitempty"`
|
||||
StreamIdleTimeoutInSeconds int `json:"stream_idle_timeout_in_seconds,omitempty"`
|
||||
MaxConnsPerHost int `json:"max_conns_per_host,omitempty"`
|
||||
EnforceHTTP2 bool `json:"enforce_http2,omitempty"`
|
||||
BetaHeaderOverrides map[string]bool `json:"beta_header_overrides,omitempty"`
|
||||
}
|
||||
|
||||
alias := NetworkConfigAlias{
|
||||
BaseURL: nc.BaseURL,
|
||||
ExtraHeaders: nc.ExtraHeaders,
|
||||
DefaultRequestTimeoutInSeconds: nc.DefaultRequestTimeoutInSeconds,
|
||||
MaxRetries: nc.MaxRetries,
|
||||
// Convert time.Duration (nanoseconds) to milliseconds
|
||||
RetryBackoffInitial: int64(nc.RetryBackoffInitial / time.Millisecond),
|
||||
RetryBackoffMax: int64(nc.RetryBackoffMax / time.Millisecond),
|
||||
InsecureSkipVerify: nc.InsecureSkipVerify,
|
||||
CACertPEM: nc.CACertPEM,
|
||||
StreamIdleTimeoutInSeconds: nc.StreamIdleTimeoutInSeconds,
|
||||
MaxConnsPerHost: nc.MaxConnsPerHost,
|
||||
EnforceHTTP2: nc.EnforceHTTP2,
|
||||
BetaHeaderOverrides: nc.BetaHeaderOverrides,
|
||||
}
|
||||
|
||||
return json.Marshal(alias)
|
||||
}
|
||||
|
||||
// Redacted returns a redacted copy of the network configuration with CACertPEM masked.
|
||||
func (nc *NetworkConfig) Redacted() *NetworkConfig {
|
||||
if nc == nil {
|
||||
return nil
|
||||
}
|
||||
redacted := *nc
|
||||
if nc.CACertPEM != "" {
|
||||
redacted.CACertPEM = "<REDACTED>"
|
||||
}
|
||||
return &redacted
|
||||
}
|
||||
|
||||
// DefaultNetworkConfig is the default network configuration for provider connections.
|
||||
var DefaultNetworkConfig = NetworkConfig{
|
||||
DefaultRequestTimeoutInSeconds: DefaultRequestTimeoutInSeconds,
|
||||
MaxRetries: DefaultMaxRetries,
|
||||
RetryBackoffInitial: DefaultRetryBackoffInitial,
|
||||
RetryBackoffMax: DefaultRetryBackoffMax,
|
||||
StreamIdleTimeoutInSeconds: DefaultStreamIdleTimeoutInSeconds,
|
||||
MaxConnsPerHost: DefaultMaxConnsPerHost,
|
||||
}
|
||||
|
||||
// ConcurrencyAndBufferSize represents configuration for concurrent operations and buffer sizes.
|
||||
type ConcurrencyAndBufferSize struct {
|
||||
Concurrency int `json:"concurrency"` // Number of concurrent operations. Also used as the initial pool size for the provider reponses.
|
||||
BufferSize int `json:"buffer_size"` // Size of the buffer
|
||||
}
|
||||
|
||||
// DefaultConcurrencyAndBufferSize is the default concurrency and buffer size for provider operations.
|
||||
var DefaultConcurrencyAndBufferSize = ConcurrencyAndBufferSize{
|
||||
Concurrency: DefaultConcurrency,
|
||||
BufferSize: DefaultBufferSize,
|
||||
}
|
||||
|
||||
// ProxyType defines the type of proxy to use for connections.
|
||||
type ProxyType string
|
||||
|
||||
const (
|
||||
// NoProxy indicates no proxy should be used
|
||||
NoProxy ProxyType = "none"
|
||||
// HTTPProxy indicates an HTTP proxy should be used
|
||||
HTTPProxy ProxyType = "http"
|
||||
// Socks5Proxy indicates a SOCKS5 proxy should be used
|
||||
Socks5Proxy ProxyType = "socks5"
|
||||
// EnvProxy indicates the proxy should be read from environment variables
|
||||
EnvProxy ProxyType = "environment"
|
||||
)
|
||||
|
||||
// ProxyConfig holds the configuration for proxy settings.
|
||||
type ProxyConfig struct {
|
||||
Type ProxyType `json:"type"` // Type of proxy to use
|
||||
URL string `json:"url"` // URL of the proxy server
|
||||
Username string `json:"username"` // Username for proxy authentication
|
||||
Password string `json:"password"` // Password for proxy authentication
|
||||
CACertPEM string `json:"ca_cert_pem"` // PEM-encoded CA certificate to trust for TLS connections through the proxy
|
||||
}
|
||||
|
||||
// IsRedactedValue returns true if the value is redacted.
|
||||
func (pc *ProxyConfig) IsRedactedValue(value string) bool {
|
||||
return value == "<REDACTED>" || value == "********"
|
||||
}
|
||||
|
||||
// Redacted returns a redacted copy of the proxy configuration.
|
||||
func (pc *ProxyConfig) Redacted() *ProxyConfig {
|
||||
// Create redacted config with same structure but redacted values
|
||||
redactedConfig := ProxyConfig{
|
||||
Type: pc.Type,
|
||||
URL: pc.URL,
|
||||
Username: pc.Username,
|
||||
}
|
||||
if pc.Password != "" {
|
||||
redactedConfig.Password = "<REDACTED>"
|
||||
}
|
||||
if pc.CACertPEM != "" {
|
||||
redactedConfig.CACertPEM = "<REDACTED>"
|
||||
}
|
||||
return &redactedConfig
|
||||
}
|
||||
|
||||
// AllowedRequests controls which operations are permitted.
|
||||
// A nil *AllowedRequests means "all operations allowed."
|
||||
// A non-nil value only allows fields set to true; omitted or false fields are disallowed.
|
||||
type AllowedRequests struct {
|
||||
ListModels bool `json:"list_models"`
|
||||
TextCompletion bool `json:"text_completion"`
|
||||
TextCompletionStream bool `json:"text_completion_stream"`
|
||||
ChatCompletion bool `json:"chat_completion"`
|
||||
ChatCompletionStream bool `json:"chat_completion_stream"`
|
||||
Responses bool `json:"responses"`
|
||||
ResponsesStream bool `json:"responses_stream"`
|
||||
CountTokens bool `json:"count_tokens"`
|
||||
Embedding bool `json:"embedding"`
|
||||
Rerank bool `json:"rerank"`
|
||||
OCR bool `json:"ocr"`
|
||||
Speech bool `json:"speech"`
|
||||
SpeechStream bool `json:"speech_stream"`
|
||||
Transcription bool `json:"transcription"`
|
||||
TranscriptionStream bool `json:"transcription_stream"`
|
||||
ImageGeneration bool `json:"image_generation"`
|
||||
ImageGenerationStream bool `json:"image_generation_stream"`
|
||||
ImageEdit bool `json:"image_edit"`
|
||||
ImageEditStream bool `json:"image_edit_stream"`
|
||||
ImageVariation bool `json:"image_variation"`
|
||||
VideoGeneration bool `json:"video_generation"`
|
||||
VideoRetrieve bool `json:"video_retrieve"`
|
||||
VideoDownload bool `json:"video_download"`
|
||||
VideoDelete bool `json:"video_delete"`
|
||||
VideoList bool `json:"video_list"`
|
||||
VideoRemix bool `json:"video_remix"`
|
||||
BatchCreate bool `json:"batch_create"`
|
||||
BatchList bool `json:"batch_list"`
|
||||
BatchRetrieve bool `json:"batch_retrieve"`
|
||||
BatchCancel bool `json:"batch_cancel"`
|
||||
BatchDelete bool `json:"batch_delete"`
|
||||
BatchResults bool `json:"batch_results"`
|
||||
FileUpload bool `json:"file_upload"`
|
||||
FileList bool `json:"file_list"`
|
||||
FileRetrieve bool `json:"file_retrieve"`
|
||||
FileDelete bool `json:"file_delete"`
|
||||
FileContent bool `json:"file_content"`
|
||||
ContainerCreate bool `json:"container_create"`
|
||||
ContainerList bool `json:"container_list"`
|
||||
ContainerRetrieve bool `json:"container_retrieve"`
|
||||
ContainerDelete bool `json:"container_delete"`
|
||||
ContainerFileCreate bool `json:"container_file_create"`
|
||||
ContainerFileList bool `json:"container_file_list"`
|
||||
ContainerFileRetrieve bool `json:"container_file_retrieve"`
|
||||
ContainerFileContent bool `json:"container_file_content"`
|
||||
ContainerFileDelete bool `json:"container_file_delete"`
|
||||
Passthrough bool `json:"passthrough"`
|
||||
PassthroughStream bool `json:"passthrough_stream"`
|
||||
WebSocketResponses bool `json:"websocket_responses"`
|
||||
Realtime bool `json:"realtime"`
|
||||
}
|
||||
|
||||
// IsOperationAllowed checks if a specific operation is allowed
|
||||
func (ar *AllowedRequests) IsOperationAllowed(operation RequestType) bool {
|
||||
if ar == nil {
|
||||
return true // Default to allowed if no restrictions
|
||||
}
|
||||
|
||||
switch operation {
|
||||
case ListModelsRequest:
|
||||
return ar.ListModels
|
||||
case TextCompletionRequest:
|
||||
return ar.TextCompletion
|
||||
case TextCompletionStreamRequest:
|
||||
return ar.TextCompletionStream
|
||||
case ChatCompletionRequest:
|
||||
return ar.ChatCompletion
|
||||
case ChatCompletionStreamRequest:
|
||||
return ar.ChatCompletionStream
|
||||
case ResponsesRequest:
|
||||
return ar.Responses
|
||||
case ResponsesStreamRequest:
|
||||
return ar.ResponsesStream
|
||||
case CountTokensRequest:
|
||||
return ar.CountTokens
|
||||
case EmbeddingRequest:
|
||||
return ar.Embedding
|
||||
case RerankRequest:
|
||||
return ar.Rerank
|
||||
case OCRRequest:
|
||||
return ar.OCR
|
||||
case SpeechRequest:
|
||||
return ar.Speech
|
||||
case SpeechStreamRequest:
|
||||
return ar.SpeechStream
|
||||
case TranscriptionRequest:
|
||||
return ar.Transcription
|
||||
case TranscriptionStreamRequest:
|
||||
return ar.TranscriptionStream
|
||||
case ImageGenerationRequest:
|
||||
return ar.ImageGeneration
|
||||
case ImageGenerationStreamRequest:
|
||||
return ar.ImageGenerationStream
|
||||
case ImageEditRequest:
|
||||
return ar.ImageEdit
|
||||
case ImageEditStreamRequest:
|
||||
return ar.ImageEditStream
|
||||
case ImageVariationRequest:
|
||||
return ar.ImageVariation
|
||||
case VideoGenerationRequest:
|
||||
return ar.VideoGeneration
|
||||
case VideoRetrieveRequest:
|
||||
return ar.VideoRetrieve
|
||||
case VideoDownloadRequest:
|
||||
return ar.VideoDownload
|
||||
case VideoDeleteRequest:
|
||||
return ar.VideoDelete
|
||||
case VideoListRequest:
|
||||
return ar.VideoList
|
||||
case VideoRemixRequest:
|
||||
return ar.VideoRemix
|
||||
case BatchCreateRequest:
|
||||
return ar.BatchCreate
|
||||
case BatchListRequest:
|
||||
return ar.BatchList
|
||||
case BatchRetrieveRequest:
|
||||
return ar.BatchRetrieve
|
||||
case BatchCancelRequest:
|
||||
return ar.BatchCancel
|
||||
case BatchDeleteRequest:
|
||||
return ar.BatchDelete
|
||||
case BatchResultsRequest:
|
||||
return ar.BatchResults
|
||||
case FileUploadRequest:
|
||||
return ar.FileUpload
|
||||
case FileListRequest:
|
||||
return ar.FileList
|
||||
case FileRetrieveRequest:
|
||||
return ar.FileRetrieve
|
||||
case FileDeleteRequest:
|
||||
return ar.FileDelete
|
||||
case FileContentRequest:
|
||||
return ar.FileContent
|
||||
case ContainerCreateRequest:
|
||||
return ar.ContainerCreate
|
||||
case ContainerListRequest:
|
||||
return ar.ContainerList
|
||||
case ContainerRetrieveRequest:
|
||||
return ar.ContainerRetrieve
|
||||
case ContainerDeleteRequest:
|
||||
return ar.ContainerDelete
|
||||
case ContainerFileCreateRequest:
|
||||
return ar.ContainerFileCreate
|
||||
case ContainerFileListRequest:
|
||||
return ar.ContainerFileList
|
||||
case ContainerFileRetrieveRequest:
|
||||
return ar.ContainerFileRetrieve
|
||||
case ContainerFileContentRequest:
|
||||
return ar.ContainerFileContent
|
||||
case ContainerFileDeleteRequest:
|
||||
return ar.ContainerFileDelete
|
||||
case PassthroughRequest:
|
||||
return ar.Passthrough
|
||||
case PassthroughStreamRequest:
|
||||
return ar.PassthroughStream
|
||||
case WebSocketResponsesRequest:
|
||||
return ar.WebSocketResponses
|
||||
case RealtimeRequest:
|
||||
return ar.Realtime
|
||||
default:
|
||||
return false // Default to not allowed for unknown operations
|
||||
}
|
||||
}
|
||||
|
||||
type CustomProviderConfig struct {
|
||||
CustomProviderKey string `json:"-"` // Custom provider key, internally set by Bifrost
|
||||
IsKeyLess bool `json:"is_key_less"` // Whether the custom provider requires a key (not allowed for Bedrock)
|
||||
BaseProviderType ModelProvider `json:"base_provider_type"` // Base provider type
|
||||
AllowedRequests *AllowedRequests `json:"allowed_requests,omitempty"` // Allowed requests for the custom provider
|
||||
RequestPathOverrides map[RequestType]string `json:"request_path_overrides,omitempty"` // Mapping of request type to its custom path which will override the default path of the provider (not allowed for Bedrock)
|
||||
}
|
||||
|
||||
// IsOperationAllowed checks if a specific operation is allowed for this custom provider
|
||||
func (cpc *CustomProviderConfig) IsOperationAllowed(operation RequestType) bool {
|
||||
if cpc == nil || cpc.AllowedRequests == nil {
|
||||
return true // Default to allowed if no restrictions
|
||||
}
|
||||
return cpc.AllowedRequests.IsOperationAllowed(operation)
|
||||
}
|
||||
|
||||
// ProviderConfig represents the complete configuration for a provider.
|
||||
// An array of ProviderConfig needs to be provided in GetConfigForProvider
|
||||
// in your account interface implementation.
|
||||
type ProviderConfig struct {
|
||||
NetworkConfig NetworkConfig `json:"network_config"` // Network configuration
|
||||
ConcurrencyAndBufferSize ConcurrencyAndBufferSize `json:"concurrency_and_buffer_size"` // Concurrency settings
|
||||
// Logger instance, can be provided by the user or bifrost default logger is used if not provided
|
||||
Logger Logger `json:"-"`
|
||||
ProxyConfig *ProxyConfig `json:"proxy_config,omitempty"` // Proxy configuration
|
||||
SendBackRawRequest bool `json:"send_back_raw_request"` // Send raw request back in the bifrost response (default: false)
|
||||
SendBackRawResponse bool `json:"send_back_raw_response"` // Send raw response back in the bifrost response (default: false)
|
||||
StoreRawRequestResponse bool `json:"store_raw_request_response"` // Capture raw request/response for internal logging only; strip from API responses returned to clients (default: false)
|
||||
CustomProviderConfig *CustomProviderConfig `json:"custom_provider_config,omitempty"`
|
||||
OpenAIConfig *OpenAIConfig `json:"openai_config,omitempty"`
|
||||
}
|
||||
|
||||
// OpenAIConfig holds OpenAI-specific provider configuration.
|
||||
type OpenAIConfig struct {
|
||||
DisableStore bool `json:"disable_store"` // When true, forces store=false on all outgoing OpenAI requests (default: false)
|
||||
}
|
||||
|
||||
func (config *ProviderConfig) CheckAndSetDefaults() {
|
||||
if config.ConcurrencyAndBufferSize.Concurrency == 0 {
|
||||
config.ConcurrencyAndBufferSize.Concurrency = DefaultConcurrency
|
||||
}
|
||||
|
||||
if config.ConcurrencyAndBufferSize.BufferSize == 0 {
|
||||
config.ConcurrencyAndBufferSize.BufferSize = DefaultBufferSize
|
||||
}
|
||||
|
||||
if config.NetworkConfig.DefaultRequestTimeoutInSeconds <= 0 {
|
||||
config.NetworkConfig.DefaultRequestTimeoutInSeconds = DefaultRequestTimeoutInSeconds
|
||||
}
|
||||
|
||||
if config.NetworkConfig.MaxRetries == 0 {
|
||||
config.NetworkConfig.MaxRetries = DefaultMaxRetries
|
||||
}
|
||||
|
||||
if config.NetworkConfig.RetryBackoffInitial == 0 {
|
||||
config.NetworkConfig.RetryBackoffInitial = DefaultRetryBackoffInitial
|
||||
}
|
||||
|
||||
if config.NetworkConfig.RetryBackoffMax == 0 {
|
||||
config.NetworkConfig.RetryBackoffMax = DefaultRetryBackoffMax
|
||||
}
|
||||
|
||||
if config.NetworkConfig.StreamIdleTimeoutInSeconds <= 0 {
|
||||
config.NetworkConfig.StreamIdleTimeoutInSeconds = DefaultStreamIdleTimeoutInSeconds
|
||||
}
|
||||
|
||||
if config.NetworkConfig.MaxConnsPerHost <= 0 {
|
||||
config.NetworkConfig.MaxConnsPerHost = DefaultMaxConnsPerHost
|
||||
} else if config.NetworkConfig.MaxConnsPerHost > MaxConnsPerHostUpperBound {
|
||||
config.NetworkConfig.MaxConnsPerHost = MaxConnsPerHostUpperBound
|
||||
}
|
||||
|
||||
// Create a defensive copy of ExtraHeaders to prevent data races
|
||||
if config.NetworkConfig.ExtraHeaders != nil {
|
||||
headersCopy := make(map[string]string, len(config.NetworkConfig.ExtraHeaders))
|
||||
maps.Copy(headersCopy, config.NetworkConfig.ExtraHeaders)
|
||||
config.NetworkConfig.ExtraHeaders = headersCopy
|
||||
}
|
||||
|
||||
// Create a defensive copy of BetaHeaderOverrides to prevent data races
|
||||
if config.NetworkConfig.BetaHeaderOverrides != nil {
|
||||
overridesCopy := make(map[string]bool, len(config.NetworkConfig.BetaHeaderOverrides))
|
||||
maps.Copy(overridesCopy, config.NetworkConfig.BetaHeaderOverrides)
|
||||
config.NetworkConfig.BetaHeaderOverrides = overridesCopy
|
||||
}
|
||||
}
|
||||
|
||||
type PostHookRunner func(ctx *BifrostContext, result *BifrostResponse, err *BifrostError) (*BifrostResponse, *BifrostError)
|
||||
|
||||
// Provider defines the interface for AI model providers.
|
||||
type Provider interface {
|
||||
// GetProviderKey returns the provider's identifier
|
||||
GetProviderKey() ModelProvider
|
||||
// ListModels performs a list models request
|
||||
ListModels(ctx *BifrostContext, keys []Key, request *BifrostListModelsRequest) (*BifrostListModelsResponse, *BifrostError)
|
||||
// TextCompletion performs a text completion request
|
||||
TextCompletion(ctx *BifrostContext, key Key, request *BifrostTextCompletionRequest) (*BifrostTextCompletionResponse, *BifrostError)
|
||||
// TextCompletionStream performs a text completion stream request.
|
||||
// postHookSpanFinalizer is invoked by the provider's stream goroutine on stream completion
|
||||
// (or on its panic-recovery defer) to finalize aggregated post-hook spans and release the
|
||||
// per-attempt plugin pipeline. Pass nil if the caller does not need finalization.
|
||||
TextCompletionStream(ctx *BifrostContext, postHookRunner PostHookRunner, postHookSpanFinalizer func(context.Context), key Key, request *BifrostTextCompletionRequest) (chan *BifrostStreamChunk, *BifrostError)
|
||||
// ChatCompletion performs a chat completion request
|
||||
ChatCompletion(ctx *BifrostContext, key Key, request *BifrostChatRequest) (*BifrostChatResponse, *BifrostError)
|
||||
// ChatCompletionStream performs a chat completion stream request
|
||||
ChatCompletionStream(ctx *BifrostContext, postHookRunner PostHookRunner, postHookSpanFinalizer func(context.Context), key Key, request *BifrostChatRequest) (chan *BifrostStreamChunk, *BifrostError)
|
||||
// Responses performs a completion request using the Responses API (uses chat completion request internally for non-openai providers)
|
||||
Responses(ctx *BifrostContext, key Key, request *BifrostResponsesRequest) (*BifrostResponsesResponse, *BifrostError)
|
||||
// ResponsesStream performs a completion request using the Responses API stream (uses chat completion stream request internally for non-openai providers)
|
||||
ResponsesStream(ctx *BifrostContext, postHookRunner PostHookRunner, postHookSpanFinalizer func(context.Context), key Key, request *BifrostResponsesRequest) (chan *BifrostStreamChunk, *BifrostError)
|
||||
// CountTokens performs a count tokens request
|
||||
CountTokens(ctx *BifrostContext, key Key, request *BifrostResponsesRequest) (*BifrostCountTokensResponse, *BifrostError)
|
||||
// Embedding performs an embedding request
|
||||
Embedding(ctx *BifrostContext, key Key, request *BifrostEmbeddingRequest) (*BifrostEmbeddingResponse, *BifrostError)
|
||||
// Rerank performs a rerank request to reorder documents by relevance to a query
|
||||
Rerank(ctx *BifrostContext, key Key, request *BifrostRerankRequest) (*BifrostRerankResponse, *BifrostError)
|
||||
// OCR performs an optical character recognition request on a document
|
||||
OCR(ctx *BifrostContext, key Key, request *BifrostOCRRequest) (*BifrostOCRResponse, *BifrostError)
|
||||
// Speech performs a text to speech request
|
||||
Speech(ctx *BifrostContext, key Key, request *BifrostSpeechRequest) (*BifrostSpeechResponse, *BifrostError)
|
||||
// SpeechStream performs a text to speech stream request
|
||||
SpeechStream(ctx *BifrostContext, postHookRunner PostHookRunner, postHookSpanFinalizer func(context.Context), key Key, request *BifrostSpeechRequest) (chan *BifrostStreamChunk, *BifrostError)
|
||||
// Transcription performs a transcription request
|
||||
Transcription(ctx *BifrostContext, key Key, request *BifrostTranscriptionRequest) (*BifrostTranscriptionResponse, *BifrostError)
|
||||
// TranscriptionStream performs a transcription stream request
|
||||
TranscriptionStream(ctx *BifrostContext, postHookRunner PostHookRunner, postHookSpanFinalizer func(context.Context), key Key, request *BifrostTranscriptionRequest) (chan *BifrostStreamChunk, *BifrostError)
|
||||
// ImageGeneration performs an image generation request
|
||||
ImageGeneration(ctx *BifrostContext, key Key, request *BifrostImageGenerationRequest) (
|
||||
*BifrostImageGenerationResponse, *BifrostError)
|
||||
// ImageGenerationStream performs an image generation stream request
|
||||
ImageGenerationStream(ctx *BifrostContext, postHookRunner PostHookRunner, postHookSpanFinalizer func(context.Context), key Key,
|
||||
request *BifrostImageGenerationRequest) (chan *BifrostStreamChunk, *BifrostError)
|
||||
// ImageEdit performs an image edit request
|
||||
ImageEdit(ctx *BifrostContext, key Key, request *BifrostImageEditRequest) (*BifrostImageGenerationResponse, *BifrostError)
|
||||
// ImageEditStream performs an image edit stream request
|
||||
ImageEditStream(ctx *BifrostContext, postHookRunner PostHookRunner, postHookSpanFinalizer func(context.Context), key Key,
|
||||
request *BifrostImageEditRequest) (chan *BifrostStreamChunk, *BifrostError)
|
||||
// ImageVariation performs an image variation request
|
||||
ImageVariation(ctx *BifrostContext, key Key, request *BifrostImageVariationRequest) (*BifrostImageGenerationResponse, *BifrostError)
|
||||
// VideoGeneration performs a video generation request
|
||||
VideoGeneration(ctx *BifrostContext, key Key, request *BifrostVideoGenerationRequest) (*BifrostVideoGenerationResponse, *BifrostError)
|
||||
// VideoRetrieve retrieves a video from the provider
|
||||
VideoRetrieve(ctx *BifrostContext, key Key, request *BifrostVideoRetrieveRequest) (*BifrostVideoGenerationResponse, *BifrostError)
|
||||
// VideoDownload downloads a video from the provider
|
||||
VideoDownload(ctx *BifrostContext, key Key, request *BifrostVideoDownloadRequest) (*BifrostVideoDownloadResponse, *BifrostError)
|
||||
// VideoDelete deletes a video from the provider
|
||||
VideoDelete(ctx *BifrostContext, key Key, request *BifrostVideoDeleteRequest) (*BifrostVideoDeleteResponse, *BifrostError)
|
||||
// VideoList lists videos from the provider
|
||||
VideoList(ctx *BifrostContext, key Key, request *BifrostVideoListRequest) (*BifrostVideoListResponse, *BifrostError)
|
||||
// VideoRemix remixes a video from the provider
|
||||
VideoRemix(ctx *BifrostContext, key Key, request *BifrostVideoRemixRequest) (*BifrostVideoGenerationResponse, *BifrostError)
|
||||
// BatchCreate creates a new batch job for asynchronous processing
|
||||
BatchCreate(ctx *BifrostContext, key Key, request *BifrostBatchCreateRequest) (*BifrostBatchCreateResponse, *BifrostError)
|
||||
// BatchList lists batch jobs
|
||||
BatchList(ctx *BifrostContext, keys []Key, request *BifrostBatchListRequest) (*BifrostBatchListResponse, *BifrostError)
|
||||
// BatchRetrieve retrieves a specific batch job
|
||||
BatchRetrieve(ctx *BifrostContext, keys []Key, request *BifrostBatchRetrieveRequest) (*BifrostBatchRetrieveResponse, *BifrostError)
|
||||
// BatchCancel cancels a batch job
|
||||
BatchCancel(ctx *BifrostContext, keys []Key, request *BifrostBatchCancelRequest) (*BifrostBatchCancelResponse, *BifrostError)
|
||||
// BatchDelete deletes a batch job
|
||||
BatchDelete(ctx *BifrostContext, keys []Key, request *BifrostBatchDeleteRequest) (*BifrostBatchDeleteResponse, *BifrostError)
|
||||
// BatchResults retrieves results from a completed batch job
|
||||
BatchResults(ctx *BifrostContext, keys []Key, request *BifrostBatchResultsRequest) (*BifrostBatchResultsResponse, *BifrostError)
|
||||
// FileUpload uploads a file to the provider
|
||||
FileUpload(ctx *BifrostContext, key Key, request *BifrostFileUploadRequest) (*BifrostFileUploadResponse, *BifrostError)
|
||||
// FileList lists files from the provider
|
||||
FileList(ctx *BifrostContext, keys []Key, request *BifrostFileListRequest) (*BifrostFileListResponse, *BifrostError)
|
||||
// FileRetrieve retrieves file metadata from the provider
|
||||
FileRetrieve(ctx *BifrostContext, keys []Key, request *BifrostFileRetrieveRequest) (*BifrostFileRetrieveResponse, *BifrostError)
|
||||
// FileDelete deletes a file from the provider
|
||||
FileDelete(ctx *BifrostContext, keys []Key, request *BifrostFileDeleteRequest) (*BifrostFileDeleteResponse, *BifrostError)
|
||||
// FileContent downloads file content from the provider
|
||||
FileContent(ctx *BifrostContext, keys []Key, request *BifrostFileContentRequest) (*BifrostFileContentResponse, *BifrostError)
|
||||
// ContainerCreate creates a new container
|
||||
ContainerCreate(ctx *BifrostContext, key Key, request *BifrostContainerCreateRequest) (*BifrostContainerCreateResponse, *BifrostError)
|
||||
// ContainerList lists containers
|
||||
ContainerList(ctx *BifrostContext, keys []Key, request *BifrostContainerListRequest) (*BifrostContainerListResponse, *BifrostError)
|
||||
// ContainerRetrieve retrieves a specific container
|
||||
ContainerRetrieve(ctx *BifrostContext, keys []Key, request *BifrostContainerRetrieveRequest) (*BifrostContainerRetrieveResponse, *BifrostError)
|
||||
// ContainerDelete deletes a container
|
||||
ContainerDelete(ctx *BifrostContext, keys []Key, request *BifrostContainerDeleteRequest) (*BifrostContainerDeleteResponse, *BifrostError)
|
||||
// ContainerFileCreate creates a file in a container
|
||||
ContainerFileCreate(ctx *BifrostContext, key Key, request *BifrostContainerFileCreateRequest) (*BifrostContainerFileCreateResponse, *BifrostError)
|
||||
// ContainerFileList lists files in a container
|
||||
ContainerFileList(ctx *BifrostContext, keys []Key, request *BifrostContainerFileListRequest) (*BifrostContainerFileListResponse, *BifrostError)
|
||||
// ContainerFileRetrieve retrieves a file from a container
|
||||
ContainerFileRetrieve(ctx *BifrostContext, keys []Key, request *BifrostContainerFileRetrieveRequest) (*BifrostContainerFileRetrieveResponse, *BifrostError)
|
||||
// ContainerFileContent retrieves the content of a file from a container
|
||||
ContainerFileContent(ctx *BifrostContext, keys []Key, request *BifrostContainerFileContentRequest) (*BifrostContainerFileContentResponse, *BifrostError)
|
||||
// ContainerFileDelete deletes a file from a container
|
||||
ContainerFileDelete(ctx *BifrostContext, keys []Key, request *BifrostContainerFileDeleteRequest) (*BifrostContainerFileDeleteResponse, *BifrostError)
|
||||
// Passthrough executes a non-streaming passthrough; body is fully buffered.
|
||||
Passthrough(ctx *BifrostContext, key Key, req *BifrostPassthroughRequest) (*BifrostPassthroughResponse, *BifrostError)
|
||||
// PassthroughStream executes a streaming passthrough, forwarding raw response bytes as BifrostStreamChunks.
|
||||
PassthroughStream(ctx *BifrostContext, postHookRunner PostHookRunner, postHookSpanFinalizer func(context.Context), key Key, req *BifrostPassthroughRequest) (chan *BifrostStreamChunk, *BifrostError)
|
||||
}
|
||||
|
||||
// WebSocketCapableProvider is an optional interface that providers can implement
|
||||
// to indicate support for the OpenAI Responses API WebSocket Mode.
|
||||
// Checked via type assertion: provider.(WebSocketCapableProvider).
|
||||
// Providers that implement this interface will have native WS upstream connections
|
||||
// instead of the HTTP bridge fallback for Responses WS mode.
|
||||
type WebSocketCapableProvider interface {
|
||||
// SupportsWebSocketMode returns true if the provider supports the Responses API WebSocket Mode.
|
||||
SupportsWebSocketMode() bool
|
||||
// WebSocketResponsesURL returns the WebSocket URL for the Responses API.
|
||||
WebSocketResponsesURL(key Key) string
|
||||
// WebSocketHeaders returns the headers required for the upstream WebSocket connection.
|
||||
WebSocketHeaders(key Key) map[string]string
|
||||
}
|
||||
Reference in New Issue
Block a user