// Package bifrost provides the core implementation of the Bifrost system. // Bifrost is a unified interface for interacting with various AI model providers, // managing concurrent requests, and handling provider-specific configurations. package bifrost import ( "context" "errors" "fmt" "slices" "sort" "strings" "sync" "sync/atomic" "time" "github.com/bytedance/sonic" "github.com/google/uuid" "github.com/maximhq/bifrost/core/keyselectors" "github.com/maximhq/bifrost/core/mcp" "github.com/maximhq/bifrost/core/mcp/codemode/starlark" "github.com/maximhq/bifrost/core/providers/anthropic" "github.com/maximhq/bifrost/core/providers/azure" "github.com/maximhq/bifrost/core/providers/bedrock" "github.com/maximhq/bifrost/core/providers/cerebras" "github.com/maximhq/bifrost/core/providers/cohere" "github.com/maximhq/bifrost/core/providers/elevenlabs" "github.com/maximhq/bifrost/core/providers/fireworks" "github.com/maximhq/bifrost/core/providers/gemini" "github.com/maximhq/bifrost/core/providers/groq" "github.com/maximhq/bifrost/core/providers/huggingface" "github.com/maximhq/bifrost/core/providers/mistral" "github.com/maximhq/bifrost/core/providers/nebius" "github.com/maximhq/bifrost/core/providers/ollama" "github.com/maximhq/bifrost/core/providers/openai" "github.com/maximhq/bifrost/core/providers/openrouter" "github.com/maximhq/bifrost/core/providers/parasail" "github.com/maximhq/bifrost/core/providers/perplexity" "github.com/maximhq/bifrost/core/providers/replicate" "github.com/maximhq/bifrost/core/providers/runway" "github.com/maximhq/bifrost/core/providers/sgl" providerUtils "github.com/maximhq/bifrost/core/providers/utils" "github.com/maximhq/bifrost/core/providers/vertex" "github.com/maximhq/bifrost/core/providers/vllm" "github.com/maximhq/bifrost/core/providers/xai" schemas "github.com/maximhq/bifrost/core/schemas" "github.com/valyala/fasthttp" ) // ChannelMessage represents a message passed through the request channel. 
// It contains the request, response and error channels, and the request type.
type ChannelMessage struct {
	schemas.BifrostRequest                                       // embedded request carried by this message
	Context                *schemas.BifrostContext               // context travelling with the request
	Response               chan *schemas.BifrostResponse         // channel for the non-streaming response
	ResponseStream         chan chan *schemas.BifrostStreamChunk // channel that delivers the stream-chunk channel
	Err                    chan schemas.BifrostError             // channel for a request error
}

// Bifrost manages providers and maintains specified open channels for concurrent processing.
// It handles request routing, provider management, and response processing.
type Bifrost struct {
	ctx                 *schemas.BifrostContext             // root context created in Init
	cancel              context.CancelFunc                  // cancel function paired with ctx
	account             schemas.Account                     // account interface
	llmPlugins          atomic.Pointer[[]schemas.LLMPlugin] // list of llm plugins
	mcpPlugins          atomic.Pointer[[]schemas.MCPPlugin] // list of mcp plugins
	providers           atomic.Pointer[[]schemas.Provider]  // list of providers
	requestQueues       sync.Map                            // provider request queues (thread-safe), stores *ProviderQueue
	waitGroups          sync.Map                            // wait groups for each provider (thread-safe)
	providerMutexes     sync.Map                            // mutexes for each provider to prevent concurrent updates (thread-safe)
	channelMessagePool  sync.Pool                           // Pool for ChannelMessage objects, initial pool size is set in Init
	responseChannelPool sync.Pool                           // Pool for response channels, initial pool size is set in Init
	errorChannelPool    sync.Pool                           // Pool for error channels, initial pool size is set in Init
	responseStreamPool  sync.Pool                           // Pool for response stream channels, initial pool size is set in Init
	pluginPipelinePool  sync.Pool                           // Pool for PluginPipeline objects
	bifrostRequestPool  sync.Pool                           // Pool for BifrostRequest objects
	mcpRequestPool      sync.Pool                           // Pool for BifrostMCPRequest objects
	oauth2Provider      schemas.OAuth2Provider              // OAuth provider instance
	logger              schemas.Logger                      // logger instance, default logger is used if not provided
	tracer              atomic.Value                        // tracer for distributed tracing (stores *tracerWrapper so the concrete type never changes; wraps the default tracer if none is configured)
	MCPManager          mcp.MCPManagerInterface             // MCP integration manager (nil if MCP not configured)
	mcpInitOnce         sync.Once                           // Ensures MCP manager is initialized only once
	dropExcessRequests  atomic.Bool                         // If true, in cases where the queue is full, requests will not wait for the queue to be empty and will be dropped instead.
	keySelector         schemas.KeySelector                 // Custom key selector function
	kvStore             schemas.KVStore                     // optional KV store for session stickiness (nil = disabled)
}

// ProviderQueue wraps a provider's request channel with lifecycle management
// to prevent "send on closed channel" panics during provider removal/update.
// Producers must check the closing flag or select on the done channel before sending.
//
// Why pq.queue is NEVER closed:
//
// Closing a channel in Go causes any concurrent send to that channel to panic
// ("send on closed channel"). There is always a TOCTOU window between a
// producer's isClosing() check and its select { case pq.queue <- msg: ... }:
// the producer could pass isClosing() while the queue is open, get preempted,
// and resume only after the queue is closed. Go's selectgo evaluates select
// cases in a random order, so even having case <-pq.done: in the same select
// does not protect against this — if selectgo evaluates the send case first on
// a closed channel it panics immediately via goto sclose, before reaching done.
//
// To close pq.queue safely you would need a sender-side WaitGroup so that
// signalClosing could wait for every in-flight producer to finish. That adds
// non-trivial overhead on the hot request path.
//
// Instead, pq.done is the sole shutdown signal. Receiving from a closed channel
// is always safe (returns the zero value immediately), so:
//   - Workers exit via case <-pq.done: — safe
//   - Producers bail via case <-pq.done: — safe
//   - drainQueueWithErrors handles any messages that slip through the TOCTOU window
//
// pq.queue is garbage collected automatically:
//   - RemoveProvider calls requestQueues.Delete, dropping the map's reference.
//   - UpdateProvider calls requestQueues.Store with a new queue, dropping the
//     map's reference to oldPq.
Shutdown does not Delete at all — the whole // Bifrost instance is torn down. // In all cases, once no producer goroutine holds a reference to the // ProviderQueue, both the struct and pq.queue are eligible for GC. // No explicit close is needed. type ProviderQueue struct { queue chan *ChannelMessage // the actual request queue channel — never closed, see above done chan struct{} // closed by signalClosing() to signal shutdown; never written to otherwise closing uint32 // atomic: 0 = open, 1 = closing signalOnce sync.Once } func isLargePayloadPassthrough(ctx *schemas.BifrostContext) bool { if ctx == nil { return false } // Large payload mode intentionally skips JSON->Bifrost input materialization. // Example: a 400MB multipart/audio upload sets Input=nil by design; strict // non-nil validation here would reject valid passthrough requests. isLargePayload, _ := ctx.Value(schemas.BifrostContextKeyLargePayloadMode).(bool) if !isLargePayload { return false } // Verify reader is present (flag and reader are always set together by middleware) reader := ctx.Value(schemas.BifrostContextKeyLargePayloadReader) return reader != nil } // signalClosing signals the closing of the provider queue. // This is lock-free: uses atomic store and sync.Once to safely signal shutdown. func (pq *ProviderQueue) signalClosing() { pq.signalOnce.Do(func() { atomic.StoreUint32(&pq.closing, 1) close(pq.done) }) } // isClosing returns true if the provider queue is closing. // Uses atomic load for lock-free checking. func (pq *ProviderQueue) isClosing() bool { return atomic.LoadUint32(&pq.closing) == 1 } // PluginPipeline encapsulates the execution of plugin PreHooks and PostHooks, tracks how many plugins ran, and manages short-circuiting and error aggregation. 
type PluginPipeline struct {
	llmPlugins []schemas.LLMPlugin // LLM plugins handled by this pipeline
	mcpPlugins []schemas.MCPPlugin // MCP plugins handled by this pipeline
	logger     schemas.Logger
	tracer     schemas.Tracer

	// Number of PreHooks that were executed (used to determine which PostHooks to run in reverse order)
	executedPreHooks int

	// Errors from PreHooks and PostHooks
	preHookErrors  []error
	postHookErrors []error

	// streamingMu guards the streaming post-hook accumulators below. Per-chunk
	// writes (accumulatePluginTiming) run in the provider goroutine while the
	// end-of-stream finalizer (FinalizeStreamingPostHookSpans) and
	// resetPluginPipeline can run in a different goroutine, so unsynchronised
	// access triggers "concurrent map read and map write" panics.
	streamingMu         sync.Mutex
	postHookTimings     map[string]*pluginTimingAccumulator // keyed by plugin name
	postHookPluginOrder []string                            // order in which post-hooks ran (for nested span creation)
	chunkCount          int

	// Plugin logging: cached scoped contexts for streaming post-hooks (reused across chunks)
	streamScopedCtxs map[string]*schemas.BifrostContext
}

// pluginTimingAccumulator accumulates timing information for a plugin across streaming chunks
type pluginTimingAccumulator struct {
	totalDuration time.Duration // summed duration across invocations
	invocations   int           // number of invocations recorded
	errors        int           // number of errors recorded
}

// tracerWrapper wraps a Tracer to ensure atomic.Value stores consistent types.
// This is necessary because atomic.Value.Store() panics if called with values
// of different concrete types, even if they implement the same interface.
type tracerWrapper struct {
	tracer schemas.Tracer
}

// INITIALIZATION

// Init initializes a new Bifrost instance with the given configuration.
// It sets up the account, plugins, object pools, and initializes providers.
// Returns an error if initialization fails.
// Initial memory allocations happen here, as per the initial pool size.
func Init(ctx context.Context, config schemas.BifrostConfig) (*Bifrost, error) { if config.Account == nil { return nil, fmt.Errorf("account is required to initialize Bifrost") } if config.Logger == nil { config.Logger = NewDefaultLogger(schemas.LogLevelInfo) } providerUtils.SetLogger(config.Logger) // Initialize tracer (use NoOpTracer if not provided) tracer := config.Tracer if tracer == nil { tracer = schemas.DefaultTracer() } bifrostCtx, cancel := schemas.NewBifrostContextWithCancel(ctx) bifrost := &Bifrost{ ctx: bifrostCtx, cancel: cancel, account: config.Account, llmPlugins: atomic.Pointer[[]schemas.LLMPlugin]{}, mcpPlugins: atomic.Pointer[[]schemas.MCPPlugin]{}, requestQueues: sync.Map{}, waitGroups: sync.Map{}, keySelector: config.KeySelector, oauth2Provider: config.OAuth2Provider, logger: config.Logger, kvStore: config.KVStore, } bifrost.tracer.Store(&tracerWrapper{tracer: tracer}) if config.LLMPlugins == nil { config.LLMPlugins = make([]schemas.LLMPlugin, 0) } if config.MCPPlugins == nil { config.MCPPlugins = make([]schemas.MCPPlugin, 0) } bifrost.llmPlugins.Store(&config.LLMPlugins) bifrost.mcpPlugins.Store(&config.MCPPlugins) // Initialize providers slice bifrost.providers.Store(&[]schemas.Provider{}) bifrost.dropExcessRequests.Store(config.DropExcessRequests) if bifrost.keySelector == nil { bifrost.keySelector = keyselectors.WeightedRandom } // Initialize object pools bifrost.channelMessagePool = sync.Pool{ New: func() interface{} { return &ChannelMessage{} }, } bifrost.responseChannelPool = sync.Pool{ New: func() interface{} { return make(chan *schemas.BifrostResponse, 1) }, } bifrost.errorChannelPool = sync.Pool{ New: func() interface{} { return make(chan schemas.BifrostError, 1) }, } bifrost.responseStreamPool = sync.Pool{ New: func() interface{} { return make(chan chan *schemas.BifrostStreamChunk, 1) }, } bifrost.pluginPipelinePool = sync.Pool{ New: func() interface{} { return &PluginPipeline{ preHookErrors: make([]error, 0), postHookErrors: 
make([]error, 0), } }, } bifrost.bifrostRequestPool = sync.Pool{ New: func() interface{} { return &schemas.BifrostRequest{} }, } bifrost.mcpRequestPool = sync.Pool{ New: func() interface{} { return &schemas.BifrostMCPRequest{} }, } // Prewarm pools with multiple objects for range config.InitialPoolSize { // Create and put new objects directly into pools bifrost.channelMessagePool.Put(&ChannelMessage{}) bifrost.responseChannelPool.Put(make(chan *schemas.BifrostResponse, 1)) bifrost.errorChannelPool.Put(make(chan schemas.BifrostError, 1)) bifrost.responseStreamPool.Put(make(chan chan *schemas.BifrostStreamChunk, 1)) bifrost.pluginPipelinePool.Put(&PluginPipeline{ preHookErrors: make([]error, 0), postHookErrors: make([]error, 0), }) bifrost.bifrostRequestPool.Put(&schemas.BifrostRequest{}) bifrost.mcpRequestPool.Put(&schemas.BifrostMCPRequest{}) } providerKeys, err := bifrost.account.GetConfiguredProviders() if err != nil { return nil, err } // Initialize MCP manager if configured if config.MCPConfig != nil { bifrost.mcpInitOnce.Do(func() { // Set up plugin pipeline provider functions for executeCode tool hooks mcpConfig := *config.MCPConfig mcpConfig.PluginPipelineProvider = func() interface{} { return bifrost.getPluginPipeline() } mcpConfig.ReleasePluginPipeline = func(pipeline interface{}) { if pp, ok := pipeline.(*PluginPipeline); ok { bifrost.releasePluginPipeline(pp) } } // Create Starlark CodeMode for code execution var codeModeConfig *mcp.CodeModeConfig if mcpConfig.ToolManagerConfig != nil { codeModeConfig = &mcp.CodeModeConfig{ BindingLevel: mcpConfig.ToolManagerConfig.CodeModeBindingLevel, ToolExecutionTimeout: mcpConfig.ToolManagerConfig.ToolExecutionTimeout, } } codeMode := starlark.NewStarlarkCodeMode(codeModeConfig, bifrost.logger) bifrost.MCPManager = mcp.NewMCPManager(bifrostCtx, mcpConfig, bifrost.oauth2Provider, bifrost.logger, codeMode) bifrost.logger.Info("MCP integration initialized successfully") }) } // Create buffered channels for each 
provider and start workers for _, providerKey := range providerKeys { if strings.TrimSpace(string(providerKey)) == "" { bifrost.logger.Warn("provider key is empty, skipping init") continue } config, err := bifrost.account.GetConfigForProvider(providerKey) if err != nil { bifrost.logger.Warn("failed to get config for provider, skipping init: %v", err) continue } if config == nil { bifrost.logger.Warn("config is nil for provider %s, skipping init", providerKey) continue } // Lock the provider mutex during initialization providerMutex := bifrost.getProviderMutex(providerKey) providerMutex.Lock() err = bifrost.prepareProvider(providerKey, config) providerMutex.Unlock() if err != nil { bifrost.logger.Warn("failed to prepare provider %s: %v", providerKey, err) } } return bifrost, nil } // SetTracer sets the tracer for the Bifrost instance. func (bifrost *Bifrost) SetTracer(tracer schemas.Tracer) { if tracer == nil { // Fall back to no-op tracer if not provided tracer = schemas.DefaultTracer() } bifrost.tracer.Store(&tracerWrapper{tracer: tracer}) } // getTracer returns the tracer from atomic storage with type assertion. func (bifrost *Bifrost) getTracer() schemas.Tracer { return bifrost.tracer.Load().(*tracerWrapper).tracer } // ReloadConfig reloads the config from DB // Currently we update account, drop excess requests, and plugin lists // We will keep on adding other aspects as required func (bifrost *Bifrost) ReloadConfig(config schemas.BifrostConfig) error { bifrost.dropExcessRequests.Store(config.DropExcessRequests) return nil } // PUBLIC API METHODS // ListModelsRequest sends a list models request to the specified provider. 
func (bifrost *Bifrost) ListModelsRequest(ctx *schemas.BifrostContext, req *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { if req == nil { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "list models request is nil", }, ExtraFields: schemas.BifrostErrorExtraFields{ RequestType: schemas.ListModelsRequest, }, } } if req.Provider == "" { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "provider is required for list models request", }, ExtraFields: schemas.BifrostErrorExtraFields{ RequestType: schemas.ListModelsRequest, }, } } if ctx == nil { ctx = bifrost.ctx } bifrostReq := bifrost.getBifrostRequest() bifrostReq.RequestType = schemas.ListModelsRequest bifrostReq.ListModelsRequest = req resp, err := bifrost.handleRequest(ctx, bifrostReq) if err != nil { return nil, err } return resp.ListModelsResponse, nil } // ListAllModels lists all models from all configured providers. // It accumulates responses from all providers with a limit of 1000 per provider to get all results. 
func (bifrost *Bifrost) ListAllModels(ctx *schemas.BifrostContext, req *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { if req == nil { req = &schemas.BifrostListModelsRequest{} } if ctx == nil { ctx = bifrost.ctx } providerKeys, err := bifrost.GetConfiguredProviders() if err != nil { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: err.Error(), Error: err, }, ExtraFields: schemas.BifrostErrorExtraFields{ RequestType: schemas.ListModelsRequest, }, } } startTime := time.Now() // Result structure for collecting provider responses type providerResult struct { provider schemas.ModelProvider models []schemas.Model keyStatuses []schemas.KeyStatus err *schemas.BifrostError } results := make(chan providerResult, len(providerKeys)) var wg sync.WaitGroup // Launch concurrent requests for all providers for _, providerKey := range providerKeys { if strings.TrimSpace(string(providerKey)) == "" { continue } wg.Add(1) go func(providerKey schemas.ModelProvider) { defer wg.Done() providerCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline) providerCtx.SetValue(schemas.BifrostContextKeyRequestID, uuid.New().String()) providerModels := make([]schemas.Model, 0) var providerKeyStatuses []schemas.KeyStatus var providerErr *schemas.BifrostError // Create request for this provider with limit of 1000 providerRequest := &schemas.BifrostListModelsRequest{ Provider: providerKey, PageSize: schemas.DefaultPageSize, Unfiltered: req.Unfiltered, } iterations := 0 for { // check for context cancellation select { case <-ctx.Done(): bifrost.logger.Warn("context cancelled for provider %s", providerKey) return default: } iterations++ if iterations > schemas.MaxPaginationRequests { bifrost.logger.Warn("reached maximum pagination requests (%d) for provider %s, please increase the page size", schemas.MaxPaginationRequests, providerKey) break } response, bifrostErr := bifrost.ListModelsRequest(providerCtx, 
providerRequest) if bifrostErr != nil { // Skip logging "no keys found" and "not supported" errors as they are expected when a provider is not configured if !strings.Contains(bifrostErr.Error.Message, "no keys found") && !strings.Contains(bifrostErr.Error.Message, "not supported") { providerErr = bifrostErr bifrost.logger.Warn("failed to list models for provider %s: %s", providerKey, GetErrorMessage(bifrostErr)) } // Collect key statuses from error (failure case) if len(bifrostErr.ExtraFields.KeyStatuses) > 0 { providerKeyStatuses = append(providerKeyStatuses, bifrostErr.ExtraFields.KeyStatuses...) } break } if response == nil || len(response.Data) == 0 { break } providerModels = append(providerModels, response.Data...) if len(response.KeyStatuses) > 0 { providerKeyStatuses = append(providerKeyStatuses, response.KeyStatuses...) } // Check if there are more pages if response.NextPageToken == "" { break } // Set the page token for the next request providerRequest.PageToken = response.NextPageToken } results <- providerResult{ provider: providerKey, models: providerModels, keyStatuses: providerKeyStatuses, err: providerErr, } }(providerKey) } // Wait for all goroutines to complete wg.Wait() close(results) // Accumulate all models and key statuses from all providers allModels := make([]schemas.Model, 0) allKeyStatuses := make([]schemas.KeyStatus, 0) var firstError *schemas.BifrostError for result := range results { if len(result.models) > 0 { allModels = append(allModels, result.models...) } if len(result.keyStatuses) > 0 { allKeyStatuses = append(allKeyStatuses, result.keyStatuses...) 
} if result.err != nil && firstError == nil { firstError = result.err } } // If we couldn't get any models from any provider, return the first error if len(allModels) == 0 && firstError != nil { // Attach all key statuses to the error firstError.ExtraFields.KeyStatuses = allKeyStatuses return nil, firstError } // Sort models alphabetically by ID sort.Slice(allModels, func(i, j int) bool { return allModels[i].ID < allModels[j].ID }) // Return aggregated response with accumulated latency and key statuses response := &schemas.BifrostListModelsResponse{ Data: allModels, KeyStatuses: allKeyStatuses, ExtraFields: schemas.BifrostResponseExtraFields{ RequestType: schemas.ListModelsRequest, Latency: time.Since(startTime).Milliseconds(), }, } response = response.ApplyPagination(req.PageSize, req.PageToken) return response, nil } // TextCompletionRequest sends a text completion request to the specified provider. func (bifrost *Bifrost) TextCompletionRequest(ctx *schemas.BifrostContext, req *schemas.BifrostTextCompletionRequest) (*schemas.BifrostTextCompletionResponse, *schemas.BifrostError) { if req == nil { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "text completion request is nil", }, ExtraFields: schemas.BifrostErrorExtraFields{ RequestType: schemas.TextCompletionRequest, }, } } if (req.Input == nil || (req.Input.PromptStr == nil && req.Input.PromptArray == nil)) && !isLargePayloadPassthrough(ctx) { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "prompt not provided for text completion request", }, ExtraFields: schemas.BifrostErrorExtraFields{ RequestType: schemas.TextCompletionRequest, Provider: req.Provider, OriginalModelRequested: req.Model, ResolvedModelUsed: req.Model, }, } } // Preparing request bifrostReq := bifrost.getBifrostRequest() bifrostReq.RequestType = schemas.TextCompletionRequest bifrostReq.TextCompletionRequest = req response, err := 
bifrost.handleRequest(ctx, bifrostReq) if err != nil { return nil, err } // TODO: Release the response return response.TextCompletionResponse, nil } // TextCompletionStreamRequest sends a streaming text completion request to the specified provider. func (bifrost *Bifrost) TextCompletionStreamRequest(ctx *schemas.BifrostContext, req *schemas.BifrostTextCompletionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { if req == nil { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "text completion stream request is nil", }, ExtraFields: schemas.BifrostErrorExtraFields{ RequestType: schemas.TextCompletionStreamRequest, }, } } if (req.Input == nil || (req.Input.PromptStr == nil && req.Input.PromptArray == nil)) && !isLargePayloadPassthrough(ctx) { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "text not provided for text completion stream request", }, ExtraFields: schemas.BifrostErrorExtraFields{ RequestType: schemas.TextCompletionStreamRequest, Provider: req.Provider, OriginalModelRequested: req.Model, ResolvedModelUsed: req.Model, }, } } bifrostReq := bifrost.getBifrostRequest() bifrostReq.RequestType = schemas.TextCompletionStreamRequest bifrostReq.TextCompletionRequest = req return bifrost.handleStreamRequest(ctx, bifrostReq) } func (bifrost *Bifrost) makeChatCompletionRequest(ctx *schemas.BifrostContext, req *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError) { if req == nil { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "chat completion request is nil", }, ExtraFields: schemas.BifrostErrorExtraFields{ RequestType: schemas.ChatCompletionRequest, }, } } if req.Input == nil && !isLargePayloadPassthrough(ctx) { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "chats not provided for chat completion request", }, ExtraFields: 
schemas.BifrostErrorExtraFields{ RequestType: schemas.ChatCompletionRequest, Provider: req.Provider, OriginalModelRequested: req.Model, ResolvedModelUsed: req.Model, }, } } bifrostReq := bifrost.getBifrostRequest() bifrostReq.RequestType = schemas.ChatCompletionRequest bifrostReq.ChatRequest = req response, err := bifrost.handleRequest(ctx, bifrostReq) if err != nil { return nil, err } return response.ChatResponse, nil } // ChatCompletionRequest sends a chat completion request to the specified provider. func (bifrost *Bifrost) ChatCompletionRequest(ctx *schemas.BifrostContext, req *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError) { // If ctx is nil, use the bifrost context (defensive check for mcp agent mode) if ctx == nil { ctx = bifrost.ctx } response, err := bifrost.makeChatCompletionRequest(ctx, req) if err != nil { return nil, err } // Check if we should enter agent mode if bifrost.MCPManager != nil { return bifrost.MCPManager.CheckAndExecuteAgentForChatRequest( ctx, req, response, bifrost.makeChatCompletionRequest, bifrost.executeMCPToolWithHooks, ) } return response, nil } // ChatCompletionStreamRequest sends a chat completion stream request to the specified provider. 
func (bifrost *Bifrost) ChatCompletionStreamRequest(ctx *schemas.BifrostContext, req *schemas.BifrostChatRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { if req == nil { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "chat completion stream request is nil", }, ExtraFields: schemas.BifrostErrorExtraFields{ RequestType: schemas.ChatCompletionStreamRequest, }, } } if req.Input == nil && !isLargePayloadPassthrough(ctx) { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "chats not provided for chat completion request", }, ExtraFields: schemas.BifrostErrorExtraFields{ RequestType: schemas.ChatCompletionStreamRequest, Provider: req.Provider, OriginalModelRequested: req.Model, ResolvedModelUsed: req.Model, }, } } bifrostReq := bifrost.getBifrostRequest() bifrostReq.RequestType = schemas.ChatCompletionStreamRequest bifrostReq.ChatRequest = req return bifrost.handleStreamRequest(ctx, bifrostReq) } func (bifrost *Bifrost) makeResponsesRequest(ctx *schemas.BifrostContext, req *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) { if req == nil { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "responses request is nil", }, ExtraFields: schemas.BifrostErrorExtraFields{ RequestType: schemas.ResponsesRequest, }, } } // In large payload mode, Input is intentionally nil — body streams directly to upstream if req.Input == nil { isLargePayload, _ := ctx.Value(schemas.BifrostContextKeyLargePayloadMode).(bool) if !isLargePayload { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "responses not provided for responses request", }, ExtraFields: schemas.BifrostErrorExtraFields{ RequestType: schemas.ResponsesRequest, Provider: req.Provider, OriginalModelRequested: req.Model, ResolvedModelUsed: req.Model, }, } } } bifrostReq := 
bifrost.getBifrostRequest() bifrostReq.RequestType = schemas.ResponsesRequest bifrostReq.ResponsesRequest = req response, err := bifrost.handleRequest(ctx, bifrostReq) if err != nil { return nil, err } return response.ResponsesResponse, nil } // ResponsesRequest sends a responses request to the specified provider. func (bifrost *Bifrost) ResponsesRequest(ctx *schemas.BifrostContext, req *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) { // If ctx is nil, use the bifrost context (defensive check for mcp agent mode) if ctx == nil { ctx = bifrost.ctx } response, err := bifrost.makeResponsesRequest(ctx, req) if err != nil { return nil, err } // Check if we should enter agent mode if bifrost.MCPManager != nil { return bifrost.MCPManager.CheckAndExecuteAgentForResponsesRequest( ctx, req, response, bifrost.makeResponsesRequest, bifrost.executeMCPToolWithHooks, ) } return response, nil } // ResponsesStreamRequest sends a responses stream request to the specified provider. 
func (bifrost *Bifrost) ResponsesStreamRequest(ctx *schemas.BifrostContext, req *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { if req == nil { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "responses stream request is nil", }, ExtraFields: schemas.BifrostErrorExtraFields{ RequestType: schemas.ResponsesStreamRequest, }, } } // In large payload mode, Input is intentionally nil — body streams directly to upstream if req.Input == nil { isLargePayload, _ := ctx.Value(schemas.BifrostContextKeyLargePayloadMode).(bool) if !isLargePayload { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "responses not provided for responses stream request", }, ExtraFields: schemas.BifrostErrorExtraFields{ RequestType: schemas.ResponsesStreamRequest, Provider: req.Provider, OriginalModelRequested: req.Model, ResolvedModelUsed: req.Model, }, } } } bifrostReq := bifrost.getBifrostRequest() bifrostReq.RequestType = schemas.ResponsesStreamRequest bifrostReq.ResponsesRequest = req return bifrost.handleStreamRequest(ctx, bifrostReq) } // CountTokensRequest sends a count tokens request to the specified provider. 
func (bifrost *Bifrost) CountTokensRequest(ctx *schemas.BifrostContext, req *schemas.BifrostResponsesRequest) (*schemas.BifrostCountTokensResponse, *schemas.BifrostError) { if req == nil { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "count tokens request is nil", }, ExtraFields: schemas.BifrostErrorExtraFields{ RequestType: schemas.CountTokensRequest, }, } } if req.Input == nil && !isLargePayloadPassthrough(ctx) { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "input not provided for count tokens request", }, ExtraFields: schemas.BifrostErrorExtraFields{ RequestType: schemas.CountTokensRequest, Provider: req.Provider, OriginalModelRequested: req.Model, ResolvedModelUsed: req.Model, }, } } bifrostReq := bifrost.getBifrostRequest() bifrostReq.RequestType = schemas.CountTokensRequest bifrostReq.CountTokensRequest = req response, err := bifrost.handleRequest(ctx, bifrostReq) if err != nil { return nil, err } return response.CountTokensResponse, nil } // EmbeddingRequest sends an embedding request to the specified provider. 
func (bifrost *Bifrost) EmbeddingRequest(ctx *schemas.BifrostContext, req *schemas.BifrostEmbeddingRequest) (*schemas.BifrostEmbeddingResponse, *schemas.BifrostError) { if req == nil { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "embedding request is nil", }, ExtraFields: schemas.BifrostErrorExtraFields{ RequestType: schemas.EmbeddingRequest, }, } } hasExtraInputs := req.Params != nil && req.Params.ExtraParams != nil && (req.Params.ExtraParams["inputs"] != nil || req.Params.ExtraParams["images"] != nil) if (req.Input == nil || (req.Input.Text == nil && req.Input.Texts == nil && req.Input.Embedding == nil && req.Input.Embeddings == nil)) && !hasExtraInputs && !isLargePayloadPassthrough(ctx) { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "embedding input not provided for embedding request", }, ExtraFields: schemas.BifrostErrorExtraFields{ RequestType: schemas.EmbeddingRequest, Provider: req.Provider, OriginalModelRequested: req.Model, ResolvedModelUsed: req.Model, }, } } bifrostReq := bifrost.getBifrostRequest() bifrostReq.RequestType = schemas.EmbeddingRequest bifrostReq.EmbeddingRequest = req response, err := bifrost.handleRequest(ctx, bifrostReq) if err != nil { return nil, err } // TODO: Release the response return response.EmbeddingResponse, nil } // RerankRequest sends a rerank request to the specified provider. 
func (bifrost *Bifrost) RerankRequest(ctx *schemas.BifrostContext, req *schemas.BifrostRerankRequest) (*schemas.BifrostRerankResponse, *schemas.BifrostError) { if req == nil { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "rerank request is nil", }, ExtraFields: schemas.BifrostErrorExtraFields{ RequestType: schemas.RerankRequest, }, } } if strings.TrimSpace(req.Query) == "" { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "query not provided for rerank request", }, ExtraFields: schemas.BifrostErrorExtraFields{ RequestType: schemas.RerankRequest, Provider: req.Provider, OriginalModelRequested: req.Model, ResolvedModelUsed: req.Model, }, } } if len(req.Documents) == 0 { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "documents not provided for rerank request", }, ExtraFields: schemas.BifrostErrorExtraFields{ RequestType: schemas.RerankRequest, Provider: req.Provider, OriginalModelRequested: req.Model, ResolvedModelUsed: req.Model, }, } } for i, doc := range req.Documents { if strings.TrimSpace(doc.Text) == "" { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: fmt.Sprintf("document text is empty at index %d", i), }, ExtraFields: schemas.BifrostErrorExtraFields{ RequestType: schemas.RerankRequest, Provider: req.Provider, OriginalModelRequested: req.Model, ResolvedModelUsed: req.Model, }, } } } bifrostReq := bifrost.getBifrostRequest() bifrostReq.RequestType = schemas.RerankRequest bifrostReq.RerankRequest = req response, err := bifrost.handleRequest(ctx, bifrostReq) if err != nil { return nil, err } return response.RerankResponse, nil } // OCRRequest sends an OCR request to the specified provider. 
func (bifrost *Bifrost) OCRRequest(ctx *schemas.BifrostContext, req *schemas.BifrostOCRRequest) (*schemas.BifrostOCRResponse, *schemas.BifrostError) { if req == nil { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "ocr request is nil", }, ExtraFields: schemas.BifrostErrorExtraFields{ RequestType: schemas.OCRRequest, }, } } if strings.TrimSpace(string(req.Document.Type)) == "" { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "document type not provided for ocr request", }, ExtraFields: schemas.BifrostErrorExtraFields{ RequestType: schemas.OCRRequest, Provider: req.Provider, OriginalModelRequested: req.Model, ResolvedModelUsed: req.Model, }, } } if req.Document.Type == schemas.OCRDocumentTypeDocumentURL && (req.Document.DocumentURL == nil || strings.TrimSpace(*req.Document.DocumentURL) == "") { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "document_url not provided for document_url type ocr request", }, ExtraFields: schemas.BifrostErrorExtraFields{ RequestType: schemas.OCRRequest, Provider: req.Provider, OriginalModelRequested: req.Model, ResolvedModelUsed: req.Model, }, } } if req.Document.Type == schemas.OCRDocumentTypeImageURL && (req.Document.ImageURL == nil || strings.TrimSpace(*req.Document.ImageURL) == "") { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "image_url not provided for image_url type ocr request", }, ExtraFields: schemas.BifrostErrorExtraFields{ RequestType: schemas.OCRRequest, Provider: req.Provider, OriginalModelRequested: req.Model, ResolvedModelUsed: req.Model, }, } } bifrostReq := bifrost.getBifrostRequest() bifrostReq.RequestType = schemas.OCRRequest bifrostReq.OCRRequest = req response, err := bifrost.handleRequest(ctx, bifrostReq) if err != nil { return nil, err } return response.OCRResponse, nil } // SpeechRequest sends a speech request to the 
specified provider. func (bifrost *Bifrost) SpeechRequest(ctx *schemas.BifrostContext, req *schemas.BifrostSpeechRequest) (*schemas.BifrostSpeechResponse, *schemas.BifrostError) { if req == nil { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "speech request is nil", }, ExtraFields: schemas.BifrostErrorExtraFields{ RequestType: schemas.SpeechRequest, }, } } if (req.Input == nil || req.Input.Input == "") && !isLargePayloadPassthrough(ctx) { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "speech input not provided for speech request", }, ExtraFields: schemas.BifrostErrorExtraFields{ RequestType: schemas.SpeechRequest, Provider: req.Provider, OriginalModelRequested: req.Model, ResolvedModelUsed: req.Model, }, } } bifrostReq := bifrost.getBifrostRequest() bifrostReq.RequestType = schemas.SpeechRequest bifrostReq.SpeechRequest = req response, err := bifrost.handleRequest(ctx, bifrostReq) if err != nil { return nil, err } // TODO: Release the response return response.SpeechResponse, nil } // SpeechStreamRequest sends a speech stream request to the specified provider. 
func (bifrost *Bifrost) SpeechStreamRequest(ctx *schemas.BifrostContext, req *schemas.BifrostSpeechRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { if req == nil { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "speech stream request is nil", }, ExtraFields: schemas.BifrostErrorExtraFields{ RequestType: schemas.SpeechStreamRequest, }, } } if (req.Input == nil || req.Input.Input == "") && !isLargePayloadPassthrough(ctx) { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "speech input not provided for speech stream request", }, ExtraFields: schemas.BifrostErrorExtraFields{ RequestType: schemas.SpeechStreamRequest, Provider: req.Provider, OriginalModelRequested: req.Model, ResolvedModelUsed: req.Model, }, } } bifrostReq := bifrost.getBifrostRequest() bifrostReq.RequestType = schemas.SpeechStreamRequest bifrostReq.SpeechRequest = req return bifrost.handleStreamRequest(ctx, bifrostReq) } // TranscriptionRequest sends a transcription request to the specified provider. 
func (bifrost *Bifrost) TranscriptionRequest(ctx *schemas.BifrostContext, req *schemas.BifrostTranscriptionRequest) (*schemas.BifrostTranscriptionResponse, *schemas.BifrostError) { if req == nil { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "transcription request is nil", }, ExtraFields: schemas.BifrostErrorExtraFields{ RequestType: schemas.TranscriptionRequest, }, } } if (req.Input == nil || req.Input.File == nil) && !isLargePayloadPassthrough(ctx) { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "transcription input not provided for transcription request", }, ExtraFields: schemas.BifrostErrorExtraFields{ RequestType: schemas.TranscriptionRequest, Provider: req.Provider, OriginalModelRequested: req.Model, ResolvedModelUsed: req.Model, }, } } bifrostReq := bifrost.getBifrostRequest() bifrostReq.RequestType = schemas.TranscriptionRequest bifrostReq.TranscriptionRequest = req response, err := bifrost.handleRequest(ctx, bifrostReq) if err != nil { return nil, err } // TODO: Release the response return response.TranscriptionResponse, nil } // TranscriptionStreamRequest sends a transcription stream request to the specified provider. 
func (bifrost *Bifrost) TranscriptionStreamRequest(ctx *schemas.BifrostContext, req *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { if req == nil { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "transcription stream request is nil", }, ExtraFields: schemas.BifrostErrorExtraFields{ RequestType: schemas.TranscriptionStreamRequest, }, } } if (req.Input == nil || req.Input.File == nil) && !isLargePayloadPassthrough(ctx) { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "transcription input not provided for transcription stream request", }, ExtraFields: schemas.BifrostErrorExtraFields{ RequestType: schemas.TranscriptionStreamRequest, Provider: req.Provider, OriginalModelRequested: req.Model, ResolvedModelUsed: req.Model, }, } } bifrostReq := bifrost.getBifrostRequest() bifrostReq.RequestType = schemas.TranscriptionStreamRequest bifrostReq.TranscriptionRequest = req return bifrost.handleStreamRequest(ctx, bifrostReq) } // ImageGenerationRequest sends an image generation request to the specified provider. 
func (bifrost *Bifrost) ImageGenerationRequest(ctx *schemas.BifrostContext, req *schemas.BifrostImageGenerationRequest, ) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) { if req == nil { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "image generation request is nil", }, ExtraFields: schemas.BifrostErrorExtraFields{ RequestType: schemas.ImageGenerationRequest, }, } } if (req.Input == nil || req.Input.Prompt == "") && !isLargePayloadPassthrough(ctx) { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "prompt not provided for image generation request", }, ExtraFields: schemas.BifrostErrorExtraFields{ RequestType: schemas.ImageGenerationRequest, Provider: req.Provider, OriginalModelRequested: req.Model, ResolvedModelUsed: req.Model, }, } } bifrostReq := bifrost.getBifrostRequest() bifrostReq.RequestType = schemas.ImageGenerationRequest bifrostReq.ImageGenerationRequest = req response, err := bifrost.handleRequest(ctx, bifrostReq) if err != nil { return nil, err } if response == nil || response.ImageGenerationResponse == nil { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "received nil response from provider", }, ExtraFields: schemas.BifrostErrorExtraFields{ RequestType: schemas.ImageGenerationRequest, Provider: req.Provider, OriginalModelRequested: req.Model, ResolvedModelUsed: req.Model, }, } } return response.ImageGenerationResponse, nil } // ImageGenerationStreamRequest sends an image generation stream request to the specified provider. 
func (bifrost *Bifrost) ImageGenerationStreamRequest(ctx *schemas.BifrostContext, req *schemas.BifrostImageGenerationRequest, ) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { if req == nil { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "image generation stream request is nil", }, ExtraFields: schemas.BifrostErrorExtraFields{ RequestType: schemas.ImageGenerationStreamRequest, }, } } if (req.Input == nil || req.Input.Prompt == "") && !isLargePayloadPassthrough(ctx) { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "prompt not provided for image generation stream request", }, ExtraFields: schemas.BifrostErrorExtraFields{ RequestType: schemas.ImageGenerationStreamRequest, Provider: req.Provider, OriginalModelRequested: req.Model, ResolvedModelUsed: req.Model, }, } } bifrostReq := bifrost.getBifrostRequest() bifrostReq.RequestType = schemas.ImageGenerationStreamRequest bifrostReq.ImageGenerationRequest = req return bifrost.handleStreamRequest(ctx, bifrostReq) } // ImageEditRequest sends an image edit request to the specified provider. 
func (bifrost *Bifrost) ImageEditRequest(ctx *schemas.BifrostContext, req *schemas.BifrostImageEditRequest) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) { if req == nil { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "image edit request is nil", }, ExtraFields: schemas.BifrostErrorExtraFields{ RequestType: schemas.ImageEditRequest, }, } } if (req.Input == nil || req.Input.Images == nil || len(req.Input.Images) == 0) && !isLargePayloadPassthrough(ctx) { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "images not provided for image edit request", }, ExtraFields: schemas.BifrostErrorExtraFields{ RequestType: schemas.ImageEditRequest, Provider: req.Provider, OriginalModelRequested: req.Model, ResolvedModelUsed: req.Model, }, } } // Prompt is not required for certain operation types that work without a text prompt var imageEditParamsType *string if req.Params != nil { imageEditParamsType = req.Params.Type } if !isPromptOptionalImageEditType(imageEditParamsType) && (req.Input == nil || req.Input.Prompt == "") && !isLargePayloadPassthrough(ctx) { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "prompt not provided for image edit request", }, ExtraFields: schemas.BifrostErrorExtraFields{ RequestType: schemas.ImageEditRequest, Provider: req.Provider, OriginalModelRequested: req.Model, ResolvedModelUsed: req.Model, }, } } bifrostReq := bifrost.getBifrostRequest() bifrostReq.RequestType = schemas.ImageEditRequest bifrostReq.ImageEditRequest = req response, err := bifrost.handleRequest(ctx, bifrostReq) if err != nil { return nil, err } if response == nil || response.ImageGenerationResponse == nil { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "received nil response from provider", }, ExtraFields: schemas.BifrostErrorExtraFields{ RequestType: schemas.ImageEditRequest, 
Provider: req.Provider, OriginalModelRequested: req.Model, ResolvedModelUsed: req.Model, }, } } return response.ImageGenerationResponse, nil } // ImageEditStreamRequest sends an image edit stream request to the specified provider. func (bifrost *Bifrost) ImageEditStreamRequest(ctx *schemas.BifrostContext, req *schemas.BifrostImageEditRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { if req == nil { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "image edit stream request is nil", }, ExtraFields: schemas.BifrostErrorExtraFields{ RequestType: schemas.ImageEditStreamRequest, }, } } if (req.Input == nil || req.Input.Images == nil || len(req.Input.Images) == 0) && !isLargePayloadPassthrough(ctx) { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "images not provided for image edit stream request", }, ExtraFields: schemas.BifrostErrorExtraFields{ RequestType: schemas.ImageEditStreamRequest, Provider: req.Provider, OriginalModelRequested: req.Model, ResolvedModelUsed: req.Model, }, } } // Prompt is not required for certain operation types that work without a text prompt var imageEditStreamParamsType *string if req.Params != nil { imageEditStreamParamsType = req.Params.Type } if !isPromptOptionalImageEditType(imageEditStreamParamsType) && (req.Input == nil || req.Input.Prompt == "") && !isLargePayloadPassthrough(ctx) { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "prompt not provided for image edit stream request", }, ExtraFields: schemas.BifrostErrorExtraFields{ RequestType: schemas.ImageEditStreamRequest, Provider: req.Provider, OriginalModelRequested: req.Model, ResolvedModelUsed: req.Model, }, } } bifrostReq := bifrost.getBifrostRequest() bifrostReq.RequestType = schemas.ImageEditStreamRequest bifrostReq.ImageEditRequest = req return bifrost.handleStreamRequest(ctx, bifrostReq) } // ImageVariationRequest 
sends an image variation request to the specified provider. func (bifrost *Bifrost) ImageVariationRequest(ctx *schemas.BifrostContext, req *schemas.BifrostImageVariationRequest) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) { if req == nil { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "image variation request is nil", }, ExtraFields: schemas.BifrostErrorExtraFields{ RequestType: schemas.ImageVariationRequest, }, } } if (req.Input == nil || req.Input.Image.Image == nil || len(req.Input.Image.Image) == 0) && !isLargePayloadPassthrough(ctx) { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "image not provided for image variation request", }, ExtraFields: schemas.BifrostErrorExtraFields{ RequestType: schemas.ImageVariationRequest, Provider: req.Provider, OriginalModelRequested: req.Model, ResolvedModelUsed: req.Model, }, } } bifrostReq := bifrost.getBifrostRequest() bifrostReq.RequestType = schemas.ImageVariationRequest bifrostReq.ImageVariationRequest = req response, err := bifrost.handleRequest(ctx, bifrostReq) if err != nil { return nil, err } if response == nil || response.ImageGenerationResponse == nil { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "received nil response from provider", }, ExtraFields: schemas.BifrostErrorExtraFields{ RequestType: schemas.ImageVariationRequest, Provider: req.Provider, OriginalModelRequested: req.Model, ResolvedModelUsed: req.Model, }, } } return response.ImageGenerationResponse, nil } // VideoGenerationRequest sends a video generation request to the specified provider. 
func (bifrost *Bifrost) VideoGenerationRequest(ctx *schemas.BifrostContext, req *schemas.BifrostVideoGenerationRequest, ) (*schemas.BifrostVideoGenerationResponse, *schemas.BifrostError) { if req == nil { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "video generation request is nil", }, ExtraFields: schemas.BifrostErrorExtraFields{ RequestType: schemas.VideoGenerationRequest, }, } } if req.Input == nil || req.Input.Prompt == "" { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "prompt not provided for video generation request", }, ExtraFields: schemas.BifrostErrorExtraFields{ RequestType: schemas.VideoGenerationRequest, Provider: req.Provider, OriginalModelRequested: req.Model, ResolvedModelUsed: req.Model, }, } } bifrostReq := bifrost.getBifrostRequest() bifrostReq.RequestType = schemas.VideoGenerationRequest bifrostReq.VideoGenerationRequest = req response, err := bifrost.handleRequest(ctx, bifrostReq) if err != nil { return nil, err } if response == nil || response.VideoGenerationResponse == nil { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "received nil response from provider", }, ExtraFields: schemas.BifrostErrorExtraFields{ RequestType: schemas.VideoGenerationRequest, Provider: req.Provider, OriginalModelRequested: req.Model, ResolvedModelUsed: req.Model, }, } } return response.VideoGenerationResponse, nil } func (bifrost *Bifrost) VideoRetrieveRequest(ctx *schemas.BifrostContext, req *schemas.BifrostVideoRetrieveRequest) (*schemas.BifrostVideoGenerationResponse, *schemas.BifrostError) { if req == nil { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "video retrieve request is nil", }, ExtraFields: schemas.BifrostErrorExtraFields{ RequestType: schemas.VideoRetrieveRequest, }, } } if req.Provider == "" { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: 
&schemas.ErrorField{ Message: "provider is required for video retrieve request", }, ExtraFields: schemas.BifrostErrorExtraFields{ RequestType: schemas.VideoRetrieveRequest, }, } } if req.ID == "" { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "video_id is required for video retrieve request", }, ExtraFields: schemas.BifrostErrorExtraFields{ RequestType: schemas.VideoRetrieveRequest, Provider: req.Provider, }, } } bifrostReq := bifrost.getBifrostRequest() bifrostReq.RequestType = schemas.VideoRetrieveRequest bifrostReq.VideoRetrieveRequest = req response, err := bifrost.handleRequest(ctx, bifrostReq) if err != nil { return nil, err } if response == nil || response.VideoGenerationResponse == nil { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "received nil response from provider", }, ExtraFields: schemas.BifrostErrorExtraFields{ RequestType: schemas.VideoRetrieveRequest, Provider: req.Provider, }, } } return response.VideoGenerationResponse, nil } // VideoDownloadRequest downloads video content from the provider. 
func (bifrost *Bifrost) VideoDownloadRequest(ctx *schemas.BifrostContext, req *schemas.BifrostVideoDownloadRequest) (*schemas.BifrostVideoDownloadResponse, *schemas.BifrostError) { if req == nil { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "video download request is nil", }, ExtraFields: schemas.BifrostErrorExtraFields{ RequestType: schemas.VideoDownloadRequest, }, } } if req.Provider == "" { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "provider is required for video download request", }, ExtraFields: schemas.BifrostErrorExtraFields{ RequestType: schemas.VideoDownloadRequest, }, } } if req.ID == "" { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "video_id is required for video download request", }, ExtraFields: schemas.BifrostErrorExtraFields{ RequestType: schemas.VideoDownloadRequest, Provider: req.Provider, }, } } bifrostReq := bifrost.getBifrostRequest() bifrostReq.RequestType = schemas.VideoDownloadRequest bifrostReq.VideoDownloadRequest = req response, err := bifrost.handleRequest(ctx, bifrostReq) if err != nil { return nil, err } return response.VideoDownloadResponse, nil } func (bifrost *Bifrost) VideoRemixRequest(ctx *schemas.BifrostContext, req *schemas.BifrostVideoRemixRequest) (*schemas.BifrostVideoGenerationResponse, *schemas.BifrostError) { if req == nil { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "video remix request is nil", }, ExtraFields: schemas.BifrostErrorExtraFields{ RequestType: schemas.VideoRemixRequest, }, } } if req.Provider == "" { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "provider is required for video remix request", }, ExtraFields: schemas.BifrostErrorExtraFields{ RequestType: schemas.VideoRemixRequest, }, } } if req.ID == "" { return nil, &schemas.BifrostError{ IsBifrostError: 
false, Error: &schemas.ErrorField{ Message: "video_id is required for video remix request", }, ExtraFields: schemas.BifrostErrorExtraFields{ RequestType: schemas.VideoRemixRequest, Provider: req.Provider, }, } } if req.Input == nil || req.Input.Prompt == "" { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "prompt is required for video remix request", }, ExtraFields: schemas.BifrostErrorExtraFields{ RequestType: schemas.VideoRemixRequest, Provider: req.Provider, }, } } bifrostReq := bifrost.getBifrostRequest() bifrostReq.RequestType = schemas.VideoRemixRequest bifrostReq.VideoRemixRequest = req response, err := bifrost.handleRequest(ctx, bifrostReq) if err != nil { return nil, err } if response == nil || response.VideoGenerationResponse == nil { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "received nil response from provider", }, ExtraFields: schemas.BifrostErrorExtraFields{ RequestType: schemas.VideoRemixRequest, Provider: req.Provider, }, } } return response.VideoGenerationResponse, nil } func (bifrost *Bifrost) VideoListRequest(ctx *schemas.BifrostContext, req *schemas.BifrostVideoListRequest) (*schemas.BifrostVideoListResponse, *schemas.BifrostError) { if req == nil { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "video list request is nil", }, ExtraFields: schemas.BifrostErrorExtraFields{ RequestType: schemas.VideoListRequest, }, } } if req.Provider == "" { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "provider is required for video list request", }, ExtraFields: schemas.BifrostErrorExtraFields{ RequestType: schemas.VideoListRequest, }, } } bifrostReq := bifrost.getBifrostRequest() bifrostReq.RequestType = schemas.VideoListRequest bifrostReq.VideoListRequest = req response, err := bifrost.handleRequest(ctx, bifrostReq) if err != nil { return nil, err } return 
response.VideoListResponse, nil } func (bifrost *Bifrost) VideoDeleteRequest(ctx *schemas.BifrostContext, req *schemas.BifrostVideoDeleteRequest) (*schemas.BifrostVideoDeleteResponse, *schemas.BifrostError) { if req == nil { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "video delete request is nil", }, ExtraFields: schemas.BifrostErrorExtraFields{ RequestType: schemas.VideoDeleteRequest, }, } } if req.Provider == "" { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "provider is required for video delete request", }, ExtraFields: schemas.BifrostErrorExtraFields{ RequestType: schemas.VideoDeleteRequest, }, } } if req.ID == "" { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "video_id is required for video delete request", }, ExtraFields: schemas.BifrostErrorExtraFields{ RequestType: schemas.VideoDeleteRequest, Provider: req.Provider, }, } } bifrostReq := bifrost.getBifrostRequest() bifrostReq.RequestType = schemas.VideoDeleteRequest bifrostReq.VideoDeleteRequest = req response, err := bifrost.handleRequest(ctx, bifrostReq) if err != nil { return nil, err } return response.VideoDeleteResponse, nil } // BatchCreateRequest creates a new batch job for asynchronous processing. 
func (bifrost *Bifrost) BatchCreateRequest(ctx *schemas.BifrostContext, req *schemas.BifrostBatchCreateRequest) (*schemas.BifrostBatchCreateResponse, *schemas.BifrostError) { if req == nil { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "batch create request is nil", }, } } if req.Provider == "" { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "provider is required for batch create request", }, } } if req.InputFileID == "" && len(req.Requests) == 0 { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "either input_file_id or requests is required for batch create request", }, } } if ctx == nil { ctx = bifrost.ctx } provider := bifrost.getProviderByKey(req.Provider) if provider == nil { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "provider not found for batch create request", }, } } bifrostReq := bifrost.getBifrostRequest() bifrostReq.RequestType = schemas.BatchCreateRequest bifrostReq.BatchCreateRequest = req response, err := bifrost.handleRequest(ctx, bifrostReq) if err != nil { return nil, err } return response.BatchCreateResponse, nil } // BatchListRequest lists batch jobs for the specified provider. 
func (bifrost *Bifrost) BatchListRequest(ctx *schemas.BifrostContext, req *schemas.BifrostBatchListRequest) (*schemas.BifrostBatchListResponse, *schemas.BifrostError) { if req == nil { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "batch list request is nil", }, } } if req.Provider == "" { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "provider is required for batch list request", }, } } if ctx == nil { ctx = bifrost.ctx } bifrostReq := bifrost.getBifrostRequest() bifrostReq.RequestType = schemas.BatchListRequest bifrostReq.BatchListRequest = req response, err := bifrost.handleRequest(ctx, bifrostReq) if err != nil { return nil, err } return response.BatchListResponse, nil } // BatchRetrieveRequest retrieves a specific batch job. func (bifrost *Bifrost) BatchRetrieveRequest(ctx *schemas.BifrostContext, req *schemas.BifrostBatchRetrieveRequest) (*schemas.BifrostBatchRetrieveResponse, *schemas.BifrostError) { if req == nil { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "batch retrieve request is nil", }, } } if req.Provider == "" { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "provider is required for batch retrieve request", }, } } if req.BatchID == "" { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "batch_id is required for batch retrieve request", }, } } if ctx == nil { ctx = bifrost.ctx } bifrostReq := bifrost.getBifrostRequest() bifrostReq.RequestType = schemas.BatchRetrieveRequest bifrostReq.BatchRetrieveRequest = req response, err := bifrost.handleRequest(ctx, bifrostReq) if err != nil { return nil, err } return response.BatchRetrieveResponse, nil } // BatchCancelRequest cancels a batch job. 
func (bifrost *Bifrost) BatchCancelRequest(ctx *schemas.BifrostContext, req *schemas.BifrostBatchCancelRequest) (*schemas.BifrostBatchCancelResponse, *schemas.BifrostError) { if req == nil { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "batch cancel request is nil", }, } } if req.Provider == "" { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "provider is required for batch cancel request", }, } } if req.BatchID == "" { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "batch_id is required for batch cancel request", }, } } if ctx == nil { ctx = bifrost.ctx } bifrostReq := bifrost.getBifrostRequest() bifrostReq.RequestType = schemas.BatchCancelRequest bifrostReq.BatchCancelRequest = req response, err := bifrost.handleRequest(ctx, bifrostReq) if err != nil { return nil, err } return response.BatchCancelResponse, nil } // BatchDeleteRequest deletes a batch job. 
func (bifrost *Bifrost) BatchDeleteRequest(ctx *schemas.BifrostContext, req *schemas.BifrostBatchDeleteRequest) (*schemas.BifrostBatchDeleteResponse, *schemas.BifrostError) { if req == nil { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "batch delete request is nil", }, } } if req.Provider == "" { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "provider is required for batch delete request", }, } } if req.BatchID == "" { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "batch_id is required for batch delete request", }, } } if ctx == nil { ctx = bifrost.ctx } bifrostReq := bifrost.getBifrostRequest() bifrostReq.RequestType = schemas.BatchDeleteRequest bifrostReq.BatchDeleteRequest = req response, err := bifrost.handleRequest(ctx, bifrostReq) if err != nil { return nil, err } return response.BatchDeleteResponse, nil } // BatchResultsRequest retrieves results from a completed batch job. 
func (bifrost *Bifrost) BatchResultsRequest(ctx *schemas.BifrostContext, req *schemas.BifrostBatchResultsRequest) (*schemas.BifrostBatchResultsResponse, *schemas.BifrostError) { if req == nil { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "batch results request is nil", }, ExtraFields: schemas.BifrostErrorExtraFields{ RequestType: schemas.BatchResultsRequest, }, } } if req.Provider == "" { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "provider is required for batch results request", }, ExtraFields: schemas.BifrostErrorExtraFields{ RequestType: schemas.BatchResultsRequest, }, } } if req.BatchID == "" { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "batch_id is required for batch results request", }, ExtraFields: schemas.BifrostErrorExtraFields{ RequestType: schemas.BatchResultsRequest, Provider: req.Provider, }, } } if ctx == nil { ctx = bifrost.ctx } bifrostReq := bifrost.getBifrostRequest() bifrostReq.RequestType = schemas.BatchResultsRequest bifrostReq.BatchResultsRequest = req response, err := bifrost.handleRequest(ctx, bifrostReq) if err != nil { return nil, err } return response.BatchResultsResponse, nil } // FileUploadRequest uploads a file to the specified provider. 
func (bifrost *Bifrost) FileUploadRequest(ctx *schemas.BifrostContext, req *schemas.BifrostFileUploadRequest) (*schemas.BifrostFileUploadResponse, *schemas.BifrostError) { if req == nil { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "file upload request is nil", }, ExtraFields: schemas.BifrostErrorExtraFields{ RequestType: schemas.FileUploadRequest, }, } } if req.Provider == "" { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "provider is required for file upload request", }, ExtraFields: schemas.BifrostErrorExtraFields{ RequestType: schemas.FileUploadRequest, }, } } if len(req.File) == 0 { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "file content is required for file upload request", }, ExtraFields: schemas.BifrostErrorExtraFields{ RequestType: schemas.FileUploadRequest, Provider: req.Provider, }, } } if ctx == nil { ctx = bifrost.ctx } bifrostReq := bifrost.getBifrostRequest() bifrostReq.RequestType = schemas.FileUploadRequest bifrostReq.FileUploadRequest = req response, err := bifrost.handleRequest(ctx, bifrostReq) if err != nil { return nil, err } return response.FileUploadResponse, nil } // FileListRequest lists files from the specified provider. 
func (bifrost *Bifrost) FileListRequest(ctx *schemas.BifrostContext, req *schemas.BifrostFileListRequest) (*schemas.BifrostFileListResponse, *schemas.BifrostError) { if req == nil { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "file list request is nil", }, ExtraFields: schemas.BifrostErrorExtraFields{ RequestType: schemas.FileListRequest, }, } } if req.Provider == "" { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "provider is required for file list request", }, ExtraFields: schemas.BifrostErrorExtraFields{ RequestType: schemas.FileListRequest, }, } } if ctx == nil { ctx = bifrost.ctx } bifrostReq := bifrost.getBifrostRequest() bifrostReq.RequestType = schemas.FileListRequest bifrostReq.FileListRequest = req response, err := bifrost.handleRequest(ctx, bifrostReq) if err != nil { return nil, err } return response.FileListResponse, nil } // FileRetrieveRequest retrieves file metadata from the specified provider. 
func (bifrost *Bifrost) FileRetrieveRequest(ctx *schemas.BifrostContext, req *schemas.BifrostFileRetrieveRequest) (*schemas.BifrostFileRetrieveResponse, *schemas.BifrostError) { if req == nil { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "file retrieve request is nil", }, } } if req.Provider == "" { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "provider is required for file retrieve request", }, } } if req.FileID == "" { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "file_id is required for file retrieve request", }, } } if ctx == nil { ctx = bifrost.ctx } bifrostReq := bifrost.getBifrostRequest() bifrostReq.RequestType = schemas.FileRetrieveRequest bifrostReq.FileRetrieveRequest = req response, err := bifrost.handleRequest(ctx, bifrostReq) if err != nil { return nil, err } return response.FileRetrieveResponse, nil } // FileDeleteRequest deletes a file from the specified provider. 
func (bifrost *Bifrost) FileDeleteRequest(ctx *schemas.BifrostContext, req *schemas.BifrostFileDeleteRequest) (*schemas.BifrostFileDeleteResponse, *schemas.BifrostError) { if req == nil { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "file delete request is nil", }, } } if req.Provider == "" { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "provider is required for file delete request", }, } } if req.FileID == "" { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "file_id is required for file delete request", }, } } if ctx == nil { ctx = bifrost.ctx } bifrostReq := bifrost.getBifrostRequest() bifrostReq.RequestType = schemas.FileDeleteRequest bifrostReq.FileDeleteRequest = req response, err := bifrost.handleRequest(ctx, bifrostReq) if err != nil { return nil, err } return response.FileDeleteResponse, nil } // FileContentRequest downloads file content from the specified provider. 
func (bifrost *Bifrost) FileContentRequest(ctx *schemas.BifrostContext, req *schemas.BifrostFileContentRequest) (*schemas.BifrostFileContentResponse, *schemas.BifrostError) { if req == nil { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "file content request is nil", }, } } if req.Provider == "" { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "provider is required for file content request", }, } } if req.FileID == "" { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "file_id is required for file content request", }, } } if ctx == nil { ctx = bifrost.ctx } bifrostReq := bifrost.getBifrostRequest() bifrostReq.RequestType = schemas.FileContentRequest bifrostReq.FileContentRequest = req response, err := bifrost.handleRequest(ctx, bifrostReq) if err != nil { return nil, err } return response.FileContentResponse, nil } func (bifrost *Bifrost) Passthrough( ctx *schemas.BifrostContext, provider schemas.ModelProvider, req *schemas.BifrostPassthroughRequest, ) (*schemas.BifrostPassthroughResponse, *schemas.BifrostError) { if req == nil { sc := fasthttp.StatusBadRequest return nil, &schemas.BifrostError{ IsBifrostError: false, StatusCode: &sc, Error: &schemas.ErrorField{Message: "passthrough request is nil"}, } } req.Provider = provider bifrostReq := bifrost.getBifrostRequest() bifrostReq.RequestType = schemas.PassthroughRequest bifrostReq.PassthroughRequest = req resp, bifrostErr := bifrost.handleRequest(ctx, bifrostReq) if bifrostErr != nil { return nil, bifrostErr } if resp == nil || resp.PassthroughResponse == nil { sc := fasthttp.StatusBadGateway return nil, &schemas.BifrostError{ IsBifrostError: false, StatusCode: &sc, Error: &schemas.ErrorField{Message: "provider returned nil passthrough response"}, } } return resp.PassthroughResponse, nil } func (bifrost *Bifrost) PassthroughStream( ctx *schemas.BifrostContext, provider 
schemas.ModelProvider, req *schemas.BifrostPassthroughRequest, ) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { if req == nil { sc := fasthttp.StatusBadRequest return nil, &schemas.BifrostError{ IsBifrostError: false, StatusCode: &sc, Error: &schemas.ErrorField{Message: "passthrough request is nil"}, } } req.Provider = provider bifrostReq := bifrost.getBifrostRequest() bifrostReq.RequestType = schemas.PassthroughStreamRequest bifrostReq.PassthroughRequest = req return bifrost.handleStreamRequest(ctx, bifrostReq) } // ExecuteChatMCPTool executes an MCP tool call and returns the result as a chat message. // This is the main public API for manual MCP tool execution in Chat format. // // Parameters: // - ctx: Execution context // - toolCall: The tool call to execute (from assistant message) // // Returns: // - *schemas.ChatMessage: Tool message with execution result // - *schemas.BifrostError: Any execution error func (bifrost *Bifrost) ExecuteChatMCPTool(ctx *schemas.BifrostContext, toolCall *schemas.ChatAssistantMessageToolCall) (*schemas.ChatMessage, *schemas.BifrostError) { // Handle nil context early to prevent issues downstream if ctx == nil { ctx = bifrost.ctx } // Validate toolCall is not nil if toolCall == nil { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "toolCall cannot be nil", }, ExtraFields: schemas.BifrostErrorExtraFields{ RequestType: schemas.ChatCompletionRequest, }, } } // Get MCP request from pool and populate mcpRequest := bifrost.getMCPRequest() mcpRequest.RequestType = schemas.MCPRequestTypeChatToolCall mcpRequest.ChatAssistantMessageToolCall = toolCall defer bifrost.releaseMCPRequest(mcpRequest) // Execute with common handler result, err := bifrost.handleMCPToolExecution(ctx, mcpRequest, schemas.ChatCompletionRequest) if err != nil { return nil, err } // Validate and extract chat message from result if result == nil || result.ChatMessage == nil { return nil, &schemas.BifrostError{ 
IsBifrostError: false, Error: &schemas.ErrorField{ Message: "MCP tool execution returned nil chat message", }, ExtraFields: schemas.BifrostErrorExtraFields{ RequestType: schemas.ChatCompletionRequest, }, } } return result.ChatMessage, nil } // ExecuteResponsesMCPTool executes an MCP tool call and returns the result as a responses message. // This is the main public API for manual MCP tool execution in Responses format. // // Parameters: // - ctx: Execution context // - toolCall: The tool call to execute (from assistant message) // // Returns: // - *schemas.ResponsesMessage: Tool message with execution result // - *schemas.BifrostError: Any execution error func (bifrost *Bifrost) ExecuteResponsesMCPTool(ctx *schemas.BifrostContext, toolCall *schemas.ResponsesToolMessage) (*schemas.ResponsesMessage, *schemas.BifrostError) { // Handle nil context early to prevent issues downstream if ctx == nil { ctx = bifrost.ctx } // Validate toolCall is not nil if toolCall == nil { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "toolCall cannot be nil", }, ExtraFields: schemas.BifrostErrorExtraFields{ RequestType: schemas.ResponsesRequest, }, } } // Get MCP request from pool and populate mcpRequest := bifrost.getMCPRequest() mcpRequest.RequestType = schemas.MCPRequestTypeResponsesToolCall mcpRequest.ResponsesToolMessage = toolCall defer bifrost.releaseMCPRequest(mcpRequest) // Execute with common handler result, err := bifrost.handleMCPToolExecution(ctx, mcpRequest, schemas.ResponsesRequest) if err != nil { return nil, err } // Validate and extract responses message from result if result == nil || result.ResponsesMessage == nil { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "MCP tool execution returned nil responses message", }, ExtraFields: schemas.BifrostErrorExtraFields{ RequestType: schemas.ResponsesRequest, }, } } return result.ResponsesMessage, nil } // ContainerCreateRequest 
creates a new container. func (bifrost *Bifrost) ContainerCreateRequest(ctx *schemas.BifrostContext, req *schemas.BifrostContainerCreateRequest) (*schemas.BifrostContainerCreateResponse, *schemas.BifrostError) { if req == nil { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "container create request is nil", }, } } if req.Provider == "" { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "provider is required for container create request", }, } } if req.Name == "" { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "name is required for container create request", }, } } if ctx == nil { ctx = bifrost.ctx } bifrostReq := bifrost.getBifrostRequest() bifrostReq.RequestType = schemas.ContainerCreateRequest bifrostReq.ContainerCreateRequest = req response, err := bifrost.handleRequest(ctx, bifrostReq) if err != nil { return nil, err } return response.ContainerCreateResponse, nil } // ContainerListRequest lists containers. func (bifrost *Bifrost) ContainerListRequest(ctx *schemas.BifrostContext, req *schemas.BifrostContainerListRequest) (*schemas.BifrostContainerListResponse, *schemas.BifrostError) { if req == nil { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "container list request is nil", }, } } if req.Provider == "" { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "provider is required for container list request", }, } } if ctx == nil { ctx = bifrost.ctx } bifrostReq := bifrost.getBifrostRequest() bifrostReq.RequestType = schemas.ContainerListRequest bifrostReq.ContainerListRequest = req response, err := bifrost.handleRequest(ctx, bifrostReq) if err != nil { return nil, err } return response.ContainerListResponse, nil } // ContainerRetrieveRequest retrieves a specific container. 
func (bifrost *Bifrost) ContainerRetrieveRequest(ctx *schemas.BifrostContext, req *schemas.BifrostContainerRetrieveRequest) (*schemas.BifrostContainerRetrieveResponse, *schemas.BifrostError) { if req == nil { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "container retrieve request is nil", }, } } if req.Provider == "" { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "provider is required for container retrieve request", }, } } if req.ContainerID == "" { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "container_id is required for container retrieve request", }, } } if ctx == nil { ctx = bifrost.ctx } bifrostReq := bifrost.getBifrostRequest() bifrostReq.RequestType = schemas.ContainerRetrieveRequest bifrostReq.ContainerRetrieveRequest = req response, err := bifrost.handleRequest(ctx, bifrostReq) if err != nil { return nil, err } return response.ContainerRetrieveResponse, nil } // ContainerDeleteRequest deletes a container. 
func (bifrost *Bifrost) ContainerDeleteRequest(ctx *schemas.BifrostContext, req *schemas.BifrostContainerDeleteRequest) (*schemas.BifrostContainerDeleteResponse, *schemas.BifrostError) { if req == nil { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "container delete request is nil", }, } } if req.Provider == "" { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "provider is required for container delete request", }, } } if req.ContainerID == "" { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "container_id is required for container delete request", }, } } if ctx == nil { ctx = bifrost.ctx } bifrostReq := bifrost.getBifrostRequest() bifrostReq.RequestType = schemas.ContainerDeleteRequest bifrostReq.ContainerDeleteRequest = req response, err := bifrost.handleRequest(ctx, bifrostReq) if err != nil { return nil, err } return response.ContainerDeleteResponse, nil } // ContainerFileCreateRequest creates a file in a container. 
func (bifrost *Bifrost) ContainerFileCreateRequest(ctx *schemas.BifrostContext, req *schemas.BifrostContainerFileCreateRequest) (*schemas.BifrostContainerFileCreateResponse, *schemas.BifrostError) { if req == nil { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "container file create request is nil", }, } } if req.Provider == "" { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "provider is required for container file create request", }, } } if req.ContainerID == "" { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "container_id is required for container file create request", }, } } if len(req.File) == 0 && (req.FileID == nil || strings.TrimSpace(*req.FileID) == "") && (req.Path == nil || strings.TrimSpace(*req.Path) == "") { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "one of file, file_id, or path is required for container file create request", }, } } if ctx == nil { ctx = bifrost.ctx } bifrostReq := bifrost.getBifrostRequest() bifrostReq.RequestType = schemas.ContainerFileCreateRequest bifrostReq.ContainerFileCreateRequest = req response, err := bifrost.handleRequest(ctx, bifrostReq) if err != nil { return nil, err } return response.ContainerFileCreateResponse, nil } // ContainerFileListRequest lists files in a container. 
func (bifrost *Bifrost) ContainerFileListRequest(ctx *schemas.BifrostContext, req *schemas.BifrostContainerFileListRequest) (*schemas.BifrostContainerFileListResponse, *schemas.BifrostError) { if req == nil { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "container file list request is nil", }, } } if req.Provider == "" { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "provider is required for container file list request", }, } } if req.ContainerID == "" { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "container_id is required for container file list request", }, } } if ctx == nil { ctx = bifrost.ctx } bifrostReq := bifrost.getBifrostRequest() bifrostReq.RequestType = schemas.ContainerFileListRequest bifrostReq.ContainerFileListRequest = req response, err := bifrost.handleRequest(ctx, bifrostReq) if err != nil { return nil, err } return response.ContainerFileListResponse, nil } // ContainerFileRetrieveRequest retrieves a file from a container. 
func (bifrost *Bifrost) ContainerFileRetrieveRequest(ctx *schemas.BifrostContext, req *schemas.BifrostContainerFileRetrieveRequest) (*schemas.BifrostContainerFileRetrieveResponse, *schemas.BifrostError) { if req == nil { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "container file retrieve request is nil", }, } } if req.Provider == "" { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "provider is required for container file retrieve request", }, } } if req.ContainerID == "" { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "container_id is required for container file retrieve request", }, } } if req.FileID == "" { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "file_id is required for container file retrieve request", }, } } if ctx == nil { ctx = bifrost.ctx } bifrostReq := bifrost.getBifrostRequest() bifrostReq.RequestType = schemas.ContainerFileRetrieveRequest bifrostReq.ContainerFileRetrieveRequest = req response, err := bifrost.handleRequest(ctx, bifrostReq) if err != nil { return nil, err } return response.ContainerFileRetrieveResponse, nil } // ContainerFileContentRequest retrieves the content of a file from a container. 
func (bifrost *Bifrost) ContainerFileContentRequest(ctx *schemas.BifrostContext, req *schemas.BifrostContainerFileContentRequest) (*schemas.BifrostContainerFileContentResponse, *schemas.BifrostError) { if req == nil { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "container file content request is nil", }, } } if req.Provider == "" { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "provider is required for container file content request", }, } } if req.ContainerID == "" { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "container_id is required for container file content request", }, } } if req.FileID == "" { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "file_id is required for container file content request", }, } } if ctx == nil { ctx = bifrost.ctx } bifrostReq := bifrost.getBifrostRequest() bifrostReq.RequestType = schemas.ContainerFileContentRequest bifrostReq.ContainerFileContentRequest = req response, err := bifrost.handleRequest(ctx, bifrostReq) if err != nil { return nil, err } return response.ContainerFileContentResponse, nil } // ContainerFileDeleteRequest deletes a file from a container. 
func (bifrost *Bifrost) ContainerFileDeleteRequest(ctx *schemas.BifrostContext, req *schemas.BifrostContainerFileDeleteRequest) (*schemas.BifrostContainerFileDeleteResponse, *schemas.BifrostError) { if req == nil { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "container file delete request is nil", }, } } if req.Provider == "" { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "provider is required for container file delete request", }, } } if req.ContainerID == "" { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "container_id is required for container file delete request", }, } } if req.FileID == "" { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "file_id is required for container file delete request", }, } } if ctx == nil { ctx = bifrost.ctx } bifrostReq := bifrost.getBifrostRequest() bifrostReq.RequestType = schemas.ContainerFileDeleteRequest bifrostReq.ContainerFileDeleteRequest = req response, err := bifrost.handleRequest(ctx, bifrostReq) if err != nil { return nil, err } return response.ContainerFileDeleteResponse, nil } // RemovePlugin removes a plugin from the server. func (bifrost *Bifrost) RemovePlugin(name string, pluginTypes []schemas.PluginType) error { for _, pluginType := range pluginTypes { switch pluginType { case schemas.PluginTypeLLM: err := bifrost.removeLLMPlugin(name) if err != nil { return err } case schemas.PluginTypeMCP: err := bifrost.removeMCPPlugin(name) if err != nil { return err } } } return nil } // removeLLMPlugin removes an LLM plugin from the server. 
func (bifrost *Bifrost) removeLLMPlugin(name string) error { for { oldPlugins := bifrost.llmPlugins.Load() if oldPlugins == nil { return nil } var pluginToCleanup schemas.LLMPlugin found := false // Create new slice without the plugin to remove newPlugins := make([]schemas.LLMPlugin, 0, len(*oldPlugins)) for _, p := range *oldPlugins { if p.GetName() == name { pluginToCleanup = p bifrost.logger.Debug("removing LLM plugin %s", name) found = true } else { newPlugins = append(newPlugins, p) } } if !found { return nil } // Atomic compare-and-swap if bifrost.llmPlugins.CompareAndSwap(oldPlugins, &newPlugins) { // Cleanup the old plugin err := pluginToCleanup.Cleanup() if err != nil { bifrost.logger.Warn("failed to cleanup old LLM plugin %s: %v", pluginToCleanup.GetName(), err) } return nil } // Retrying as swapping did not work } } // removeMCPPlugin removes an MCP plugin from the server. func (bifrost *Bifrost) removeMCPPlugin(name string) error { for { oldPlugins := bifrost.mcpPlugins.Load() if oldPlugins == nil { return nil } var pluginToCleanup schemas.MCPPlugin found := false // Create new slice without the plugin to remove newPlugins := make([]schemas.MCPPlugin, 0, len(*oldPlugins)) for _, p := range *oldPlugins { if p.GetName() == name { pluginToCleanup = p bifrost.logger.Debug("removing MCP plugin %s", name) found = true } else { newPlugins = append(newPlugins, p) } } if !found { return nil } // Atomic compare-and-swap if bifrost.mcpPlugins.CompareAndSwap(oldPlugins, &newPlugins) { // Cleanup the old plugin err := pluginToCleanup.Cleanup() if err != nil { bifrost.logger.Warn("failed to cleanup old MCP plugin %s: %v", pluginToCleanup.GetName(), err) } return nil } // Retrying as swapping did not work } } // ReloadPlugin reloads a plugin with new instance // During the reload - it's stop the world phase where we take a global lock on the plugin mutex func (bifrost *Bifrost) ReloadPlugin(plugin schemas.BasePlugin, pluginTypes []schemas.PluginType) error { for _, 
pluginType := range pluginTypes { switch pluginType { case schemas.PluginTypeLLM: llmPlugin, ok := plugin.(schemas.LLMPlugin) if !ok { return fmt.Errorf("plugin %s is not an LLMPlugin", plugin.GetName()) } err := bifrost.reloadLLMPlugin(llmPlugin) if err != nil { return err } case schemas.PluginTypeMCP: mcpPlugin, ok := plugin.(schemas.MCPPlugin) if !ok { return fmt.Errorf("plugin %s is not an MCPPlugin", plugin.GetName()) } err := bifrost.reloadMCPPlugin(mcpPlugin) if err != nil { return err } } } return nil } // reloadLLMPlugin reloads an LLM plugin with new instance func (bifrost *Bifrost) reloadLLMPlugin(plugin schemas.LLMPlugin) error { for { var pluginToCleanup schemas.LLMPlugin found := false oldPlugins := bifrost.llmPlugins.Load() // Create new slice with replaced plugin or initialize empty slice var newPlugins []schemas.LLMPlugin if oldPlugins == nil { // Initialize new empty slice for the first plugin newPlugins = make([]schemas.LLMPlugin, 0) } else { newPlugins = make([]schemas.LLMPlugin, len(*oldPlugins)) copy(newPlugins, *oldPlugins) } for i, p := range newPlugins { if p.GetName() == plugin.GetName() { // Cleaning up old plugin before replacing it pluginToCleanup = p bifrost.logger.Debug("replacing LLM plugin %s with new instance", plugin.GetName()) newPlugins[i] = plugin found = true break } } if !found { // This means that user is adding a new plugin bifrost.logger.Debug("adding new LLM plugin %s", plugin.GetName()) newPlugins = append(newPlugins, plugin) } // Atomic compare-and-swap if bifrost.llmPlugins.CompareAndSwap(oldPlugins, &newPlugins) { // Cleanup the old plugin if found && pluginToCleanup != nil { err := pluginToCleanup.Cleanup() if err != nil { bifrost.logger.Warn("failed to cleanup old LLM plugin %s: %v", pluginToCleanup.GetName(), err) } } return nil } // Retrying as swapping did not work } } // reloadMCPPlugin reloads an MCP plugin with new instance func (bifrost *Bifrost) reloadMCPPlugin(plugin schemas.MCPPlugin) error { for { var 
pluginToCleanup schemas.MCPPlugin found := false oldPlugins := bifrost.mcpPlugins.Load() if oldPlugins == nil { return nil } // Create new slice with replaced plugin newPlugins := make([]schemas.MCPPlugin, len(*oldPlugins)) copy(newPlugins, *oldPlugins) for i, p := range newPlugins { if p.GetName() == plugin.GetName() { // Cleaning up old plugin before replacing it pluginToCleanup = p bifrost.logger.Debug("replacing MCP plugin %s with new instance", plugin.GetName()) newPlugins[i] = plugin found = true break } } if !found { // This means that user is adding a new plugin bifrost.logger.Debug("adding new MCP plugin %s", plugin.GetName()) newPlugins = append(newPlugins, plugin) } // Atomic compare-and-swap if bifrost.mcpPlugins.CompareAndSwap(oldPlugins, &newPlugins) { // Cleanup the old plugin if found && pluginToCleanup != nil { err := pluginToCleanup.Cleanup() if err != nil { bifrost.logger.Warn("failed to cleanup old MCP plugin %s: %v", pluginToCleanup.GetName(), err) } } return nil } // Retrying as swapping did not work } } // ReorderPlugins reorders all plugin slices (LLM, MCP) to match the given // base plugin name ordering. This should be called after SortAndRebuildPlugins // on the config layer to sync the core's execution order. // Plugins not in the ordering are appended at the end (defensive). func (bifrost *Bifrost) ReorderPlugins(orderedNames []string) { pos := make(map[string]int, len(orderedNames)) for i, name := range orderedNames { pos[name] = i } reorderAtomicSlice(&bifrost.llmPlugins, pos) reorderAtomicSlice(&bifrost.mcpPlugins, pos) } // pluginWithName is satisfied by both LLMPlugin and MCPPlugin. type pluginWithName interface { GetName() string } // reorderAtomicSlice atomically reorders the plugin slice stored behind ptr // so that plugins appear in the order given by pos (name → position). // Uses CAS retry for lock-free safety. 
func reorderAtomicSlice[T pluginWithName](ptr *atomic.Pointer[[]T], pos map[string]int) { for { old := ptr.Load() if old == nil || len(*old) == 0 { return } reordered := make([]T, len(*old)) copy(reordered, *old) sort.SliceStable(reordered, func(i, j int) bool { iPos, iOk := pos[reordered[i].GetName()] jPos, jOk := pos[reordered[j].GetName()] if !iOk && !jOk { return false } if !iOk { return false } if !jOk { return true } return iPos < jPos }) if ptr.CompareAndSwap(old, &reordered) { return } } } // GetConfiguredProviders returns the configured providers. // // Returns: // - []schemas.ModelProvider: List of configured providers // - error: Any error that occurred during the retrieval process // // Example: // // providers, err := bifrost.GetConfiguredProviders() // if err != nil { // return nil, err // } // fmt.Println(providers) func (bifrost *Bifrost) GetConfiguredProviders() ([]schemas.ModelProvider, error) { providers := bifrost.providers.Load() if providers == nil { return nil, fmt.Errorf("no providers configured") } modelProviders := make([]schemas.ModelProvider, len(*providers)) for i, provider := range *providers { modelProviders[i] = provider.GetProviderKey() } return modelProviders, nil } // RemoveProvider removes a provider from the server. // This method gracefully stops all workers for the provider, // closes the request queue, and removes the provider from the providers slice. 
// // Parameters: // - providerKey: The provider to remove // // Returns: // - error: Any error that occurred during the removal process func (bifrost *Bifrost) RemoveProvider(providerKey schemas.ModelProvider) error { bifrost.logger.Info("Removing provider %s", providerKey) providerMutex := bifrost.getProviderMutex(providerKey) providerMutex.Lock() defer providerMutex.Unlock() // Step 1: Load the ProviderQueue and verify provider exists pqValue, exists := bifrost.requestQueues.Load(providerKey) if !exists { return fmt.Errorf("provider %s not found in request queues", providerKey) } pq := pqValue.(*ProviderQueue) // Step 2: Signal closing. Blocks new producers (isClosing() returns true) and // causes idle workers to drain remaining buffered requests with errors then exit. pq.signalClosing() bifrost.logger.Debug("signaled closing for provider %s", providerKey) // Step 3: Wait for all workers to finish in-flight requests and exit. waitGroup, exists := bifrost.waitGroups.Load(providerKey) if exists { waitGroup.(*sync.WaitGroup).Wait() bifrost.logger.Debug("all workers for provider %s have stopped", providerKey) } // Step 3b: Final drain sweep — see drainQueueWithErrors for full explanation. bifrost.drainQueueWithErrors(pq) // Step 4: Remove the provider from the request queues. bifrost.requestQueues.Delete(providerKey) // Step 5: Remove the provider from the wait groups. bifrost.waitGroups.Delete(providerKey) // Step 6: Remove the provider from the providers slice. if err := bifrost.removeProviderFromSlice(providerKey); err != nil { bifrost.logger.Error( "provider %s was removed from queues but could not be removed from the providers slice — "+ "bifrost.providers is now inconsistent. 
"+ "To recover: retry RemoveProvider(%s), or restart Bifrost if that fails.", providerKey, providerKey, ) return err } bifrost.logger.Info("successfully removed provider %s", providerKey) schemas.UnregisterKnownProvider(providerKey) return nil } // UpdateProvider dynamically updates a provider with new configuration. // This method gracefully recreates the provider instance with updated settings, // stops existing workers, creates a new queue with updated settings, // and starts new workers with the updated provider and concurrency configuration. // // Parameters: // - providerKey: The provider to update // // Returns: // - error: Any error that occurred during the update process // // Note: This operation will temporarily pause request processing for the specified provider // while the transition occurs. In-flight requests will complete before workers are stopped. // Buffered requests in the old queue will be transferred to the new queue to prevent loss. // // Concurrency safety — no-worker window: // UpdateProvider holds a per-provider write lock (providerMutex.Lock) for its entire // duration. All producer paths (tryRequest, tryStreamRequest) acquire the corresponding // read lock inside getProviderQueue before they can look up or enqueue into any queue. // This means no producer can observe or enqueue into newPq until UpdateProvider returns // and releases the write lock — at which point new workers are already running and // consuming newPq. There is therefore no window where newPq is visible to producers // but has zero workers. 
func (bifrost *Bifrost) UpdateProvider(providerKey schemas.ModelProvider) error {
	bifrost.logger.Info(fmt.Sprintf("Updating provider configuration for provider %s", providerKey))

	// Get the updated configuration from the account.
	// This happens before the provider lock is taken, so a failing config lookup
	// never blocks producers for this provider.
	providerConfig, err := bifrost.account.GetConfigForProvider(providerKey)
	if err != nil {
		return fmt.Errorf("failed to get updated config for provider %s: %v", providerKey, err)
	}
	if providerConfig == nil {
		return fmt.Errorf("config is nil for provider %s", providerKey)
	}

	// Lock the provider to prevent concurrent access during update.
	// The write lock is held until this function returns.
	providerMutex := bifrost.getProviderMutex(providerKey)
	providerMutex.Lock()
	defer providerMutex.Unlock()

	// Check if provider currently exists
	oldPqValue, exists := bifrost.requestQueues.Load(providerKey)
	if !exists {
		bifrost.logger.Debug("provider %s not currently active, initializing with new configuration", providerKey)
		// If provider doesn't exist, just prepare it with new configuration
		return bifrost.prepareProvider(providerKey, providerConfig)
	}

	oldPq := oldPqValue.(*ProviderQueue)

	bifrost.logger.Debug("gracefully stopping existing workers for provider %s", providerKey)

	// Step 1: Create new ProviderQueue with updated buffer size
	newPq := &ProviderQueue{
		queue:      make(chan *ChannelMessage, providerConfig.ConcurrencyAndBufferSize.BufferSize),
		done:       make(chan struct{}),
		signalOnce: sync.Once{},
	}

	// Step 2: Atomically replace the queue so new producers immediately use newPq.
	bifrost.requestQueues.Store(providerKey, newPq)
	bifrost.logger.Debug("stored new queue for provider %s, new producers will use it", providerKey)

	// Step 3: Transfer buffered requests from the old queue to the new queue BEFORE
	// signalling workers to stop. This ensures buffered requests are processed by the
	// new workers rather than being drained with errors.
	// Old workers are still running and may consume some items concurrently — that is
	// fine, they process them normally.
	// If newPq is full during transfer, all remaining buffered requests are cancelled
	// immediately rather than blocking — this avoids the deadlock where transfer goroutines
	// wait for space that only opens once new workers start (which can't happen until
	// the transfer completes).
	transferredCount := 0
	cancelledCount := 0
	for {
		select {
		case msg := <-oldPq.queue:
			select {
			case newPq.queue <- msg:
				transferredCount++
			default:
				// newPq is full — cancel this message and all remaining in oldPq.
				// cancelMsg delivers a "queue full" error on the message's Err channel,
				// or gives up if the message's context is done first.
				cancelMsg := func(r *ChannelMessage) {
					prov, mod, _ := r.BifrostRequest.GetRequestFields()
					select {
					case r.Err <- schemas.BifrostError{
						IsBifrostError: false,
						Error:          &schemas.ErrorField{Message: "request failed during provider concurrency update: queue full"},
						ExtraFields: schemas.BifrostErrorExtraFields{
							RequestType:            r.RequestType,
							Provider:               prov,
							OriginalModelRequested: mod,
						},
					}:
					case <-r.Context.Done():
					}
				}
				cancelMsg(msg)
				cancelledCount++
				// Non-blocking sweep of whatever is still buffered in the old queue.
				for {
					select {
					case r := <-oldPq.queue:
						cancelMsg(r)
						cancelledCount++
					default:
						goto transferComplete
					}
				}
			}
		default:
			// No more buffered messages
			goto transferComplete
		}
	}

transferComplete:
	if transferredCount > 0 {
		bifrost.logger.Info("transferred %d buffered requests to new queue for provider %s", transferredCount, providerKey)
	}
	if cancelledCount > 0 {
		bifrost.logger.Warn("cancelled %d buffered requests during transfer for provider %s: new queue was full", cancelledCount, providerKey)
	}

	// Step 4: Signal the old queue is closing. Producers that still hold a reference to
	// oldPq will detect this via isClosing() and transparently re-route to newPq.
	// This happens after the transfer so the new queue is already populated before
	// stale producers attempt their re-route.
	oldPq.signalClosing()
	bifrost.logger.Debug("signaled closing for old queue of provider %s", providerKey)

	// Step 5: Wait for all existing workers to finish processing in-flight requests.
	// Workers exit via oldPq.done (signalled above).
	waitGroup, exists := bifrost.waitGroups.Load(providerKey)
	if exists {
		waitGroup.(*sync.WaitGroup).Wait()
		bifrost.logger.Debug("all workers for provider %s have stopped", providerKey)
	}

	// Step 5b: Final drain sweep — see drainQueueWithErrors for full explanation.
	bifrost.drainQueueWithErrors(oldPq)

	// Step 6: Create new wait group for the updated workers.
	bifrost.waitGroups.Store(providerKey, &sync.WaitGroup{})

	// Step 7: Create provider instance.
	provider, err := bifrost.createBaseProvider(providerKey, providerConfig)
	if err != nil {
		// Roll back: signal closing, remove from map, then drain.
		// Order matters: Delete before drainQueueWithErrors so that producers
		// re-routing via requestQueues.Load find nothing and return "provider
		// shutting down" immediately, narrowing the TOCTOU window before the sweep.
		newPq.signalClosing()
		bifrost.requestQueues.Delete(providerKey)
		bifrost.waitGroups.Delete(providerKey)
		bifrost.drainQueueWithErrors(newPq)
		if sliceErr := bifrost.removeProviderFromSlice(providerKey); sliceErr != nil {
			bifrost.logger.Error(
				"UpdateProvider rollback for %s is incomplete — provider was removed from queues "+
					"but could not be removed from the providers slice: %v. "+
					"bifrost.providers is now inconsistent. "+
					"To recover: call RemoveProvider(%s) then AddProvider to re-register it, "+
					"or restart Bifrost if that fails.",
				providerKey, sliceErr, providerKey,
			)
		}
		return fmt.Errorf("provider update for %s failed during initialization; provider has been removed — re-add or retry UpdateProvider to restore it: %v", providerKey, err)
	}

	// Step 8: Atomically replace the provider in the providers slice.
	// This must happen before starting new workers to prevent stale reads
	bifrost.logger.Debug("atomically replacing provider instance in providers slice for %s", providerKey)

	replacementAttempts := 0
	maxReplacementAttempts := 100 // Prevent infinite loops in high-contention scenarios
	for {
		replacementAttempts++
		if replacementAttempts > maxReplacementAttempts {
			// Same rollback sequence as the Step 7 failure path above.
			newPq.signalClosing()
			bifrost.requestQueues.Delete(providerKey)
			bifrost.waitGroups.Delete(providerKey)
			bifrost.drainQueueWithErrors(newPq)
			if sliceErr := bifrost.removeProviderFromSlice(providerKey); sliceErr != nil {
				bifrost.logger.Error(
					"UpdateProvider rollback for %s is incomplete — provider was removed from queues "+
						"but could not be removed from the providers slice: %v. "+
						"bifrost.providers is now inconsistent. "+
						"To recover: call RemoveProvider(%s) then AddProvider to re-register it, "+
						"or restart Bifrost if that fails.",
					providerKey, sliceErr, providerKey,
				)
			}
			return fmt.Errorf("failed to replace provider %s in providers slice after %d attempts; provider has been removed — re-add or retry UpdateProvider to restore it", providerKey, maxReplacementAttempts)
		}

		oldPtr := bifrost.providers.Load()
		var oldSlice []schemas.Provider
		if oldPtr != nil {
			oldSlice = *oldPtr
		}

		// Create new slice without the old provider of this key.
		// The capacity len(oldSlice) is exact when an existing entry for this key is
		// being replaced; if no entry existed, the append below grows the slice once.
		newSlice := make([]schemas.Provider, 0, len(oldSlice))
		oldProviderFound := false

		for _, existingProvider := range oldSlice {
			if existingProvider.GetProviderKey() != providerKey {
				newSlice = append(newSlice, existingProvider)
			} else {
				oldProviderFound = true
			}
		}

		// Add the new provider
		newSlice = append(newSlice, provider)

		if bifrost.providers.CompareAndSwap(oldPtr, &newSlice) {
			if oldProviderFound {
				bifrost.logger.Debug("successfully replaced existing provider instance for %s in providers slice", providerKey)
			} else {
				bifrost.logger.Debug("successfully added new provider instance for %s to providers slice", providerKey)
			}
			break
		}
		// Retrying as swapping did not work (likely due to concurrent modification)
	}

	// Step 9: Start new workers with updated concurrency.
	bifrost.logger.Debug("starting %d new workers for provider %s with buffer size %d",
		providerConfig.ConcurrencyAndBufferSize.Concurrency,
		providerKey,
		providerConfig.ConcurrencyAndBufferSize.BufferSize)

	// This is the wait group stored in Step 6; the provider write lock is still held.
	waitGroupValue, _ := bifrost.waitGroups.Load(providerKey)
	currentWaitGroup := waitGroupValue.(*sync.WaitGroup)

	for range providerConfig.ConcurrencyAndBufferSize.Concurrency {
		currentWaitGroup.Add(1)
		go bifrost.requestWorker(provider, providerConfig, newPq)
	}

	bifrost.logger.Info("successfully updated provider configuration for provider %s", providerKey)
	return nil
}

// GetDropExcessRequests returns the current value of DropExcessRequests
func (bifrost *Bifrost) GetDropExcessRequests() bool {
	return bifrost.dropExcessRequests.Load()
}

// UpdateDropExcessRequests updates the DropExcessRequests setting at runtime.
// This allows for hot-reloading of this configuration value.
func (bifrost *Bifrost) UpdateDropExcessRequests(value bool) {
	bifrost.dropExcessRequests.Store(value)
	bifrost.logger.Info("drop_excess_requests updated to: %v", value)
}

// getProviderMutex gets or creates a mutex for the given provider.
// LoadOrStore returns the already-stored mutex when one exists, so every caller
// for the same providerKey receives the same *sync.RWMutex.
func (bifrost *Bifrost) getProviderMutex(providerKey schemas.ModelProvider) *sync.RWMutex {
	mutexValue, _ := bifrost.providerMutexes.LoadOrStore(providerKey, &sync.RWMutex{})
	return mutexValue.(*sync.RWMutex)
}

// removeProviderFromSlice atomically removes the provider with the given key
// from bifrost.providers using a CAS retry loop. Callers hold the per-provider
// write mutex so no concurrent goroutine can re-add this key — contention is
// only from other providers' CAS operations, so the loop converges in at most
// a few iterations under any concurrency level.
// Returns an error if the limit is hit (state will be inconsistent).
func (bifrost *Bifrost) removeProviderFromSlice(providerKey schemas.ModelProvider) error { const maxAttempts = 100 for range maxAttempts { oldPtr := bifrost.providers.Load() if oldPtr == nil { return nil } oldSlice := *oldPtr newSlice := make([]schemas.Provider, 0, len(oldSlice)) for _, p := range oldSlice { if p.GetProviderKey() != providerKey { newSlice = append(newSlice, p) } } if bifrost.providers.CompareAndSwap(oldPtr, &newSlice) { return nil } } return fmt.Errorf("failed to remove provider %s from providers slice after %d attempts", providerKey, maxAttempts) } // MCP PUBLIC API // RegisterMCPTool registers a typed tool handler with the MCP integration. // This allows developers to easily add custom tools that will be available // to all LLM requests processed by this Bifrost instance. // // Parameters: // - name: Unique tool name // - description: Human-readable tool description // - handler: Function that handles tool execution // - toolSchema: Bifrost tool schema for function calling // // Returns: // - error: Any registration error // // Example: // // type EchoArgs struct { // Message string `json:"message"` // } // // err := bifrost.RegisterMCPTool("echo", "Echo a message", // func(args EchoArgs) (string, error) { // return args.Message, nil // }, toolSchema) func (bifrost *Bifrost) RegisterMCPTool(name, description string, handler func(args any) (string, error), toolSchema schemas.ChatTool) error { if bifrost.MCPManager == nil { return fmt.Errorf("mcp is not configured in this bifrost instance") } return bifrost.MCPManager.RegisterTool(name, description, handler, toolSchema) } // IMPORTANT: Running the MCP client management operations (GetMCPClients, AddMCPClient, RemoveMCPClient, EditMCPClientTools) // may temporarily increase latency for incoming requests while the operations are being processed. // These operations involve network I/O and connection management that require mutex locks // which can block briefly during execution. 
// GetMCPClients returns all MCP clients managed by the Bifrost instance. // // Returns: // - []schemas.MCPClient: List of all MCP clients // - error: Any retrieval error func (bifrost *Bifrost) GetMCPClients() ([]schemas.MCPClient, error) { if bifrost.MCPManager == nil { return nil, fmt.Errorf("mcp is not configured in this bifrost instance") } clients := bifrost.MCPManager.GetClients() clientsInConfig := make([]schemas.MCPClient, 0, len(clients)) for _, client := range clients { tools := make([]schemas.ChatToolFunction, 0, len(client.ToolMap)) for _, tool := range client.ToolMap { if tool.Function != nil { // Create a deep copy (for name) of the tool function to avoid modifying the original toolFunction := schemas.ChatToolFunction{} toolFunction.Name = tool.Function.Name toolFunction.Description = tool.Function.Description toolFunction.Parameters = tool.Function.Parameters toolFunction.Strict = tool.Function.Strict // Remove the client prefix from the tool name toolFunction.Name = strings.TrimPrefix(toolFunction.Name, client.ExecutionConfig.Name+"-") tools = append(tools, toolFunction) } } sort.Slice(tools, func(i, j int) bool { return tools[i].Name < tools[j].Name }) clientsInConfig = append(clientsInConfig, schemas.MCPClient{ Config: client.ExecutionConfig, Tools: tools, State: client.State, }) } return clientsInConfig, nil } // GetAvailableTools returns the available tools for the given context. // // Returns: // - []schemas.ChatTool: List of available tools func (bifrost *Bifrost) GetAvailableMCPTools(ctx *schemas.BifrostContext) []schemas.ChatTool { if bifrost.MCPManager == nil { return nil } return bifrost.MCPManager.GetAvailableTools(ctx) } // AddMCPClient adds a new MCP client to the Bifrost instance. // This allows for dynamic MCP client management at runtime. 
// // Parameters: // - config: MCP client configuration // // Returns: // - error: Any registration error // // Example: // // err := bifrost.AddMCPClient(schemas.MCPClientConfig{ // Name: "my-mcp-client", // ConnectionType: schemas.MCPConnectionTypeHTTP, // ConnectionString: &url, // }) func (bifrost *Bifrost) AddMCPClient(config *schemas.MCPClientConfig) error { if bifrost.MCPManager == nil { // Use sync.Once to ensure thread-safe initialization bifrost.mcpInitOnce.Do(func() { // Initialize with empty config - client will be added via AddClient below mcpConfig := schemas.MCPConfig{ ClientConfigs: []*schemas.MCPClientConfig{}, } // Set up plugin pipeline provider functions for executeCode tool hooks mcpConfig.PluginPipelineProvider = func() interface{} { return bifrost.getPluginPipeline() } mcpConfig.ReleasePluginPipeline = func(pipeline interface{}) { if pp, ok := pipeline.(*PluginPipeline); ok { bifrost.releasePluginPipeline(pp) } } // Create Starlark CodeMode for code execution (with default config) codeMode := starlark.NewStarlarkCodeMode(nil, bifrost.logger) bifrost.MCPManager = mcp.NewMCPManager(bifrost.ctx, mcpConfig, bifrost.oauth2Provider, bifrost.logger, codeMode) }) } // Handle case where initialization succeeded elsewhere but manager is still nil if bifrost.MCPManager == nil { return fmt.Errorf("MCP manager is not initialized") } return bifrost.MCPManager.AddClient(config) } // RemoveMCPClient removes an MCP client from the Bifrost instance. // This allows for dynamic MCP client management at runtime. 
// // Parameters: // - id: ID of the client to remove // // Returns: // - error: Any removal error // // Example: // // err := bifrost.RemoveMCPClient("my-mcp-client-id") // if err != nil { // log.Fatalf("Failed to remove MCP client: %v", err) // } func (bifrost *Bifrost) RemoveMCPClient(id string) error { if bifrost.MCPManager == nil { return fmt.Errorf("mcp is not configured in this bifrost instance") } return bifrost.MCPManager.RemoveClient(id) } // SetMCPManager sets the MCP manager for this Bifrost instance. // This allows injecting a custom MCP manager implementation (e.g., for enterprise features). // If the provided manager is a concrete *mcp.MCPManager, Bifrost's plugin pipeline is injected // into the manager's CodeMode so that nested tool calls run through the plugin hooks. // // Parameters: // - manager: The MCP manager to set (must implement MCPManagerInterface) func (bifrost *Bifrost) SetMCPManager(manager mcp.MCPManagerInterface) { bifrost.MCPManager = manager // Inject Bifrost's plugin pipeline into the manager's CodeMode so that // nested tool calls (e.g. via Starlark executeCode) run through plugin hooks. if m, ok := manager.(*mcp.MCPManager); ok { m.SetPluginPipeline( func() mcp.PluginPipeline { pipeline := bifrost.getPluginPipeline() if pp, ok := any(pipeline).(mcp.PluginPipeline); ok { return pp } return nil }, func(pipeline mcp.PluginPipeline) { if pp, ok := pipeline.(*PluginPipeline); ok { bifrost.releasePluginPipeline(pp) } }, ) } } // UpdateMCPClient updates the MCP client. // This allows for dynamic MCP client tool management at runtime. 
// // Parameters: // - id: ID of the client to edit // - updatedConfig: Updated MCP client configuration // // Returns: // - error: Any edit error // // Example: // // err := bifrost.UpdateMCPClient("my-mcp-client-id", schemas.MCPClientConfig{ // Name: "my-mcp-client-name", // ToolsToExecute: []string{"tool1", "tool2"}, // }) func (bifrost *Bifrost) UpdateMCPClient(id string, updatedConfig *schemas.MCPClientConfig) error { if bifrost.MCPManager == nil { return fmt.Errorf("mcp is not configured in this bifrost instance") } return bifrost.MCPManager.UpdateClient(id, updatedConfig) } // ReconnectMCPClient attempts to reconnect an MCP client if it is disconnected. // // Parameters: // - id: ID of the client to reconnect // // Returns: // - error: Any reconnection error func (bifrost *Bifrost) ReconnectMCPClient(id string) error { if bifrost.MCPManager == nil { return fmt.Errorf("mcp is not configured in this bifrost instance") } return bifrost.MCPManager.ReconnectClient(id) } // VerifyPerUserOAuthConnection delegates to the MCP manager to verify an MCP // server using a temporary access token and discover available tools. The // connection is closed after verification. If the MCP manager is not yet // initialized, it is lazily created (same as AddMCPClient). 
func (bifrost *Bifrost) VerifyPerUserOAuthConnection(ctx context.Context, config *schemas.MCPClientConfig, accessToken string) (map[string]schemas.ChatTool, map[string]string, error) { // Ensure MCP manager is initialized (lazy init, same pattern as AddMCPClient) if bifrost.MCPManager == nil { bifrost.mcpInitOnce.Do(func() { mcpConfig := schemas.MCPConfig{ ClientConfigs: []*schemas.MCPClientConfig{}, } mcpConfig.PluginPipelineProvider = func() interface{} { return bifrost.getPluginPipeline() } mcpConfig.ReleasePluginPipeline = func(pipeline interface{}) { if pp, ok := pipeline.(*PluginPipeline); ok { bifrost.releasePluginPipeline(pp) } } codeMode := starlark.NewStarlarkCodeMode(nil, bifrost.logger) bifrost.MCPManager = mcp.NewMCPManager(bifrost.ctx, mcpConfig, bifrost.oauth2Provider, bifrost.logger, codeMode) }) } if bifrost.MCPManager == nil { return nil, nil, fmt.Errorf("MCP manager is not initialized") } return bifrost.MCPManager.VerifyPerUserOAuthConnection(ctx, config, accessToken) } // SetClientTools delegates to the MCP manager to update the tool map for an // existing MCP client. func (bifrost *Bifrost) SetClientTools(clientID string, tools map[string]schemas.ChatTool, toolNameMapping map[string]string) { if bifrost.MCPManager != nil { bifrost.MCPManager.SetClientTools(clientID, tools, toolNameMapping) } } // UpdateToolManagerConfig updates the tool manager config for the MCP manager. // This allows for hot-reloading of the tool manager config at runtime. // Pass the current value of disableAutoToolInject whenever only other fields // change so the flag is never silently reset to its zero value. 
func (bifrost *Bifrost) UpdateToolManagerConfig(maxAgentDepth int, toolExecutionTimeoutInSeconds int, codeModeBindingLevel string, disableAutoToolInject bool) error { if bifrost.MCPManager == nil { return fmt.Errorf("mcp is not configured in this bifrost instance") } bifrost.MCPManager.UpdateToolManagerConfig(&schemas.MCPToolManagerConfig{ MaxAgentDepth: maxAgentDepth, ToolExecutionTimeout: time.Duration(toolExecutionTimeoutInSeconds) * time.Second, CodeModeBindingLevel: schemas.CodeModeBindingLevel(codeModeBindingLevel), DisableAutoToolInject: disableAutoToolInject, }) return nil } // PROVIDER MANAGEMENT // createBaseProvider creates a provider based on the base provider type func (bifrost *Bifrost) createBaseProvider(providerKey schemas.ModelProvider, config *schemas.ProviderConfig) (schemas.Provider, error) { // Determine which provider type to create targetProviderKey := providerKey if config.CustomProviderConfig != nil { // Validate custom provider config if config.CustomProviderConfig.BaseProviderType == "" { return nil, fmt.Errorf("custom provider config missing base provider type") } // Validate that base provider type is supported if !IsSupportedBaseProvider(config.CustomProviderConfig.BaseProviderType) { return nil, fmt.Errorf("unsupported base provider type: %s", config.CustomProviderConfig.BaseProviderType) } // Automatically set the custom provider key to the provider name config.CustomProviderConfig.CustomProviderKey = string(providerKey) targetProviderKey = config.CustomProviderConfig.BaseProviderType } switch targetProviderKey { case schemas.OpenAI: return openai.NewOpenAIProvider(config, bifrost.logger), nil case schemas.Anthropic: return anthropic.NewAnthropicProvider(config, bifrost.logger), nil case schemas.Bedrock: return bedrock.NewBedrockProvider(config, bifrost.logger) case schemas.Cohere: return cohere.NewCohereProvider(config, bifrost.logger) case schemas.Azure: return azure.NewAzureProvider(config, bifrost.logger) case schemas.Vertex: 
return vertex.NewVertexProvider(config, bifrost.logger) case schemas.Mistral: return mistral.NewMistralProvider(config, bifrost.logger), nil case schemas.Ollama: return ollama.NewOllamaProvider(config, bifrost.logger) case schemas.Groq: return groq.NewGroqProvider(config, bifrost.logger) case schemas.SGL: return sgl.NewSGLProvider(config, bifrost.logger) case schemas.Parasail: return parasail.NewParasailProvider(config, bifrost.logger) case schemas.Perplexity: return perplexity.NewPerplexityProvider(config, bifrost.logger) case schemas.Cerebras: return cerebras.NewCerebrasProvider(config, bifrost.logger) case schemas.Gemini: return gemini.NewGeminiProvider(config, bifrost.logger), nil case schemas.OpenRouter: return openrouter.NewOpenRouterProvider(config, bifrost.logger), nil case schemas.Elevenlabs: return elevenlabs.NewElevenlabsProvider(config, bifrost.logger), nil case schemas.Nebius: return nebius.NewNebiusProvider(config, bifrost.logger) case schemas.HuggingFace: return huggingface.NewHuggingFaceProvider(config, bifrost.logger), nil case schemas.XAI: return xai.NewXAIProvider(config, bifrost.logger) case schemas.Replicate: return replicate.NewReplicateProvider(config, bifrost.logger) case schemas.VLLM: return vllm.NewVLLMProvider(config, bifrost.logger) case schemas.Runway: return runway.NewRunwayProvider(config, bifrost.logger) case schemas.Fireworks: return fireworks.NewFireworksProvider(config, bifrost.logger) default: return nil, fmt.Errorf("unsupported provider: %s", targetProviderKey) } } // prepareProvider sets up a provider with its configuration, keys, and worker channels. // It initializes the request queue and starts worker goroutines for processing requests. // Note: This function assumes the caller has already acquired the appropriate mutex for the provider. 
func (bifrost *Bifrost) prepareProvider(providerKey schemas.ModelProvider, config *schemas.ProviderConfig) error { // Create ProviderQueue with lifecycle management pq := &ProviderQueue{ queue: make(chan *ChannelMessage, config.ConcurrencyAndBufferSize.BufferSize), done: make(chan struct{}), signalOnce: sync.Once{}, } bifrost.requestQueues.Store(providerKey, pq) // Start specified number of workers bifrost.waitGroups.Store(providerKey, &sync.WaitGroup{}) provider, err := bifrost.createBaseProvider(providerKey, config) if err != nil { return fmt.Errorf("failed to create provider for the given key: %v", err) } waitGroupValue, _ := bifrost.waitGroups.Load(providerKey) currentWaitGroup := waitGroupValue.(*sync.WaitGroup) // Atomically append provider to the providers slice for { oldPtr := bifrost.providers.Load() var oldSlice []schemas.Provider if oldPtr != nil { oldSlice = *oldPtr } newSlice := make([]schemas.Provider, len(oldSlice)+1) copy(newSlice, oldSlice) newSlice[len(oldSlice)] = provider if bifrost.providers.CompareAndSwap(oldPtr, &newSlice) { break } } schemas.RegisterKnownProvider(providerKey) for range config.ConcurrencyAndBufferSize.Concurrency { currentWaitGroup.Add(1) go bifrost.requestWorker(provider, config, pq) } return nil } // getProviderQueue returns the ProviderQueue for a given provider key. // If the queue doesn't exist, it creates one at runtime and initializes the provider, // given the provider config is provided in the account interface implementation. // This function uses read locks to prevent race conditions during provider updates. // Callers must check the closing flag or select on the done channel before sending. 
func (bifrost *Bifrost) getProviderQueue(providerKey schemas.ModelProvider) (*ProviderQueue, error) { // Use read lock to allow concurrent reads but prevent concurrent updates providerMutex := bifrost.getProviderMutex(providerKey) providerMutex.RLock() if pqValue, exists := bifrost.requestQueues.Load(providerKey); exists { pq := pqValue.(*ProviderQueue) providerMutex.RUnlock() return pq, nil } // Provider doesn't exist, need to create it // Upgrade to write lock for creation providerMutex.RUnlock() providerMutex.Lock() defer providerMutex.Unlock() // Double-check after acquiring write lock (another goroutine might have created it) if pqValue, exists := bifrost.requestQueues.Load(providerKey); exists { pq := pqValue.(*ProviderQueue) return pq, nil } bifrost.logger.Debug(fmt.Sprintf("Creating new request queue for provider %s at runtime", providerKey)) config, err := bifrost.account.GetConfigForProvider(providerKey) if err != nil { return nil, fmt.Errorf("failed to get config for provider: %v", err) } if config == nil { return nil, fmt.Errorf("config is nil for provider %s", providerKey) } if err := bifrost.prepareProvider(providerKey, config); err != nil { return nil, err } pqValue, ok := bifrost.requestQueues.Load(providerKey) if !ok { return nil, fmt.Errorf("request queue not found for provider %s", providerKey) } pq := pqValue.(*ProviderQueue) return pq, nil } // GetProviderByKey returns the provider instance for the given provider key. // Returns nil if no provider with the given key exists. func (bifrost *Bifrost) GetProviderByKey(providerKey schemas.ModelProvider) schemas.Provider { return bifrost.getProviderByKey(providerKey) } // SelectKeyForProviderRequestType selects an API key for the given provider, request type, and model. // Used by WebSocket handlers that need a key for upstream connections while honoring request-specific // AllowedRequests gates such as realtime-only support. 
func (bifrost *Bifrost) SelectKeyForProviderRequestType(ctx *schemas.BifrostContext, requestType schemas.RequestType, providerKey schemas.ModelProvider, model string) (schemas.Key, error) { if ctx == nil { ctx = bifrost.ctx } baseProvider := providerKey if config, err := bifrost.account.GetConfigForProvider(providerKey); err == nil && config != nil && config.CustomProviderConfig != nil && config.CustomProviderConfig.BaseProviderType != "" { baseProvider = config.CustomProviderConfig.BaseProviderType } supportedKeys, _, err := bifrost.selectKeyFromProviderForModelWithPool(ctx, requestType, providerKey, model, baseProvider) if err != nil { return schemas.Key{}, err } if len(supportedKeys) == 0 { return schemas.Key{}, nil } if len(supportedKeys) == 1 { return supportedKeys[0], nil } return bifrost.keySelector(ctx, supportedKeys, providerKey, model) } // WSStreamHooks holds the post-hook runner and cleanup function returned by RunStreamPreHooks. // Call PostHookRunner for each streaming chunk, setting StreamEndIndicator on the final chunk. // Call Cleanup when done to release the pipeline back to the pool. // If ShortCircuitResponse is non-nil, a plugin short-circuited with a cached response — // the caller should write this response to the client and skip the upstream call. type WSStreamHooks struct { PostHookRunner schemas.PostHookRunner Cleanup func() ShortCircuitResponse *schemas.BifrostResponse } // RealtimeTurnHooks mirrors RunStreamPreHooks but is explicitly scoped to a // single realtime turn rather than one long-lived transport connection. type RealtimeTurnHooks struct { PostHookRunner schemas.PostHookRunner Cleanup func() } // RunStreamPreHooks acquires a plugin pipeline, sets up tracing context, runs PreLLMHooks, // and returns a PostHookRunner for per-chunk post-processing. // Used by WebSocket handlers that bypass the normal inference path but still need plugin hooks. 
func (bifrost *Bifrost) RunStreamPreHooks(ctx *schemas.BifrostContext, req *schemas.BifrostRequest) (*WSStreamHooks, *schemas.BifrostError) { if ctx == nil { ctx = bifrost.ctx } if _, ok := ctx.Value(schemas.BifrostContextKeyRequestID).(string); !ok { ctx.SetValue(schemas.BifrostContextKeyRequestID, uuid.New().String()) } tracer := bifrost.getTracer() ctx.SetValue(schemas.BifrostContextKeyTracer, tracer) // Create a trace so the logging plugin can accumulate streaming chunks. // The traceID is used as the accumulator key in ProcessStreamingChunk. if _, ok := ctx.Value(schemas.BifrostContextKeyTraceID).(string); !ok { traceID := tracer.CreateTrace("") if traceID != "" { ctx.SetValue(schemas.BifrostContextKeyTraceID, traceID) } } // Mark as streaming context so RunPostLLMHooks uses accumulated timing ctx.SetValue(schemas.BifrostContextKeyStreamStartTime, time.Now()) pipeline := bifrost.getPluginPipeline() cleanup := func() { if traceID, ok := ctx.Value(schemas.BifrostContextKeyTraceID).(string); ok && traceID != "" { tracer.CleanupStreamAccumulator(traceID) } bifrost.releasePluginPipeline(pipeline) } // Capture provider/model from the original request for early-exit paths below. // RequestType, Provider, OriginalModelRequested, and ResolvedModelUsed are always // overwritten around RunPostLLMHooks — plugin modifications to these 4 fields are // no-ops by design; proper request metadata is preserved and tampering is discouraged. 
reqProvider, reqModel, _ := req.GetRequestFields() preReq, shortCircuit, preCount := pipeline.RunLLMPreHooks(ctx, req) if preReq == nil && shortCircuit == nil { bifrostErr := newBifrostErrorFromMsg("bifrost request after plugin hooks cannot be nil") bifrostErr.PopulateExtraFields(req.RequestType, reqProvider, reqModel, reqModel) _, bifrostErr = pipeline.RunPostLLMHooks(ctx, nil, bifrostErr, preCount) if bifrostErr != nil { bifrostErr.PopulateExtraFields(req.RequestType, reqProvider, reqModel, reqModel) } drainAndAttachPluginLogs(ctx) if traceID, ok := ctx.Value(schemas.BifrostContextKeyTraceID).(string); ok && strings.TrimSpace(traceID) != "" { tracer.CompleteAndFlushTrace(strings.TrimSpace(traceID)) } cleanup() return nil, bifrostErr } if shortCircuit != nil { if shortCircuit.Error != nil { shortCircuit.Error.PopulateExtraFields(req.RequestType, reqProvider, reqModel, reqModel) _, bifrostErr := pipeline.RunPostLLMHooks(ctx, nil, shortCircuit.Error, preCount) if bifrostErr != nil { bifrostErr.PopulateExtraFields(req.RequestType, reqProvider, reqModel, reqModel) } drainAndAttachPluginLogs(ctx) if traceID, ok := ctx.Value(schemas.BifrostContextKeyTraceID).(string); ok && strings.TrimSpace(traceID) != "" { tracer.CompleteAndFlushTrace(strings.TrimSpace(traceID)) } cleanup() if bifrostErr != nil { return nil, bifrostErr } return nil, shortCircuit.Error } if shortCircuit.Response != nil { shortCircuit.Response.PopulateExtraFields(req.RequestType, reqProvider, reqModel, reqModel) resp, bifrostErr := pipeline.RunPostLLMHooks(ctx, shortCircuit.Response, nil, preCount) if bifrostErr != nil { bifrostErr.PopulateExtraFields(req.RequestType, reqProvider, reqModel, reqModel) } else if resp != nil { resp.PopulateExtraFields(req.RequestType, reqProvider, reqModel, reqModel) } drainAndAttachPluginLogs(ctx) if traceID, ok := ctx.Value(schemas.BifrostContextKeyTraceID).(string); ok && strings.TrimSpace(traceID) != "" { tracer.CompleteAndFlushTrace(strings.TrimSpace(traceID)) } 
cleanup() if bifrostErr != nil { return nil, bifrostErr } return &WSStreamHooks{ Cleanup: func() {}, ShortCircuitResponse: resp, }, nil } } wsProvider, wsModel, _ := preReq.GetRequestFields() postHookRunner := func(ctx *schemas.BifrostContext, result *schemas.BifrostResponse, err *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError) { // Populate extra fields before RunPostLLMHooks so plugins (e.g. logging) // can read requestType/provider/model from the chunk or error. if result != nil { result.PopulateExtraFields(req.RequestType, wsProvider, wsModel, wsModel) } if err != nil { err.PopulateExtraFields(req.RequestType, wsProvider, wsModel, wsModel) } resp, bifrostErr := pipeline.RunPostLLMHooks(ctx, result, err, preCount) if IsFinalChunk(ctx) { drainAndAttachPluginLogs(ctx) } if bifrostErr != nil { bifrostErr.PopulateExtraFields(req.RequestType, wsProvider, wsModel, wsModel) return nil, bifrostErr } else if resp != nil { resp.PopulateExtraFields(req.RequestType, wsProvider, wsModel, wsModel) } return resp, nil } return &WSStreamHooks{ PostHookRunner: postHookRunner, Cleanup: cleanup, }, nil } // RunRealtimeTurnPreHooks acquires a plugin pipeline and runs LLM pre-hooks for // a single realtime turn. Unlike generic stream hooks, realtime turns do not // support short-circuit responses in v1 because the transports cannot yet emit a // fully synthetic assistant turn without an upstream generation. 
func (bifrost *Bifrost) RunRealtimeTurnPreHooks(ctx *schemas.BifrostContext, req *schemas.BifrostRequest) (*RealtimeTurnHooks, *schemas.BifrostError) { if req == nil { bifrostErr := newBifrostErrorFromMsg("realtime turn request is nil") bifrostErr.ExtraFields.RequestType = schemas.RealtimeRequest return nil, bifrostErr } if ctx == nil { ctx = bifrost.ctx } if _, ok := ctx.Value(schemas.BifrostContextKeyRequestID).(string); !ok { ctx.SetValue(schemas.BifrostContextKeyRequestID, uuid.New().String()) } tracer := bifrost.getTracer() ctx.SetValue(schemas.BifrostContextKeyTracer, tracer) if _, ok := ctx.Value(schemas.BifrostContextKeyTraceID).(string); !ok { traceID := tracer.CreateTrace("") if traceID != "" { ctx.SetValue(schemas.BifrostContextKeyTraceID, traceID) } } pipeline := bifrost.getPluginPipeline() cleanup := func() { if traceID, ok := ctx.Value(schemas.BifrostContextKeyTraceID).(string); ok && traceID != "" { tracer.CleanupStreamAccumulator(traceID) } bifrost.releasePluginPipeline(pipeline) } provider, model, _ := req.GetRequestFields() preReq, shortCircuit, preCount := pipeline.RunLLMPreHooks(ctx, req) if preReq == nil && shortCircuit == nil { bifrostErr := newBifrostErrorFromMsg("bifrost request after plugin hooks cannot be nil") bifrostErr.PopulateExtraFields(schemas.RealtimeRequest, provider, model, model) _, bifrostErr = pipeline.RunPostLLMHooks(ctx, nil, bifrostErr, preCount) drainAndAttachPluginLogs(ctx) if traceID, ok := ctx.Value(schemas.BifrostContextKeyTraceID).(string); ok && strings.TrimSpace(traceID) != "" { tracer.CompleteAndFlushTrace(strings.TrimSpace(traceID)) } cleanup() return nil, bifrostErr } if shortCircuit != nil { if shortCircuit.Error != nil { shortCircuit.Error.PopulateExtraFields(schemas.RealtimeRequest, provider, model, model) _, bifrostErr := pipeline.RunPostLLMHooks(ctx, nil, shortCircuit.Error, preCount) drainAndAttachPluginLogs(ctx) if traceID, ok := ctx.Value(schemas.BifrostContextKeyTraceID).(string); ok && 
strings.TrimSpace(traceID) != "" { tracer.CompleteAndFlushTrace(strings.TrimSpace(traceID)) } cleanup() if bifrostErr != nil { return nil, bifrostErr } return nil, shortCircuit.Error } if shortCircuit.Response != nil { // Short-circuit responses are not supported for realtime turns (v1). // Treat this like an error turn so plugins can close pending state cleanly. bifrostErr := newBifrostErrorFromMsg("realtime turn short-circuit responses are not supported") bifrostErr.PopulateExtraFields(schemas.RealtimeRequest, provider, model, model) _, bifrostErr = pipeline.RunPostLLMHooks(ctx, nil, bifrostErr, preCount) drainAndAttachPluginLogs(ctx) if traceID, ok := ctx.Value(schemas.BifrostContextKeyTraceID).(string); ok && strings.TrimSpace(traceID) != "" { tracer.CompleteAndFlushTrace(strings.TrimSpace(traceID)) } cleanup() return nil, bifrostErr } } provider, model, _ = preReq.GetRequestFields() return &RealtimeTurnHooks{ PostHookRunner: func(ctx *schemas.BifrostContext, result *schemas.BifrostResponse, err *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError) { if result != nil { result.PopulateExtraFields(schemas.RealtimeRequest, provider, model, model) } if err != nil { err.PopulateExtraFields(schemas.RealtimeRequest, provider, model, model) } resp, bifrostErr := pipeline.RunPostLLMHooks(ctx, result, err, preCount) drainAndAttachPluginLogs(ctx) if bifrostErr != nil { bifrostErr.PopulateExtraFields(schemas.RealtimeRequest, provider, model, model) return nil, bifrostErr } else if resp != nil { resp.PopulateExtraFields(schemas.RealtimeRequest, provider, model, model) } return resp, nil }, Cleanup: cleanup, }, nil } // getProviderByKey retrieves a provider instance from the providers array by its provider key. // Returns the provider if found, or nil if no provider with the given key exists. 
func (bifrost *Bifrost) getProviderByKey(providerKey schemas.ModelProvider) schemas.Provider { providers := bifrost.providers.Load() if providers == nil { return nil } // Checking if provider is in the memory for _, provider := range *providers { if provider.GetProviderKey() == providerKey { return provider } } // Could happen when provider is not initialized yet, check if provider config exists in account and if so, initialize it config, err := bifrost.account.GetConfigForProvider(providerKey) if err != nil || config == nil { if slices.Contains(dynamicallyConfigurableProviders, providerKey) { bifrost.logger.Info(fmt.Sprintf("initializing provider %s with default config", providerKey)) // If no config found, use default config config = &schemas.ProviderConfig{ NetworkConfig: schemas.DefaultNetworkConfig, ConcurrencyAndBufferSize: schemas.DefaultConcurrencyAndBufferSize, } } else { return nil } } // Lock the provider mutex to avoid races providerMutex := bifrost.getProviderMutex(providerKey) providerMutex.Lock() defer providerMutex.Unlock() // Double-check after acquiring the lock providers = bifrost.providers.Load() if providers != nil { for _, p := range *providers { if p.GetProviderKey() == providerKey { return p } } } // Preparing provider if err := bifrost.prepareProvider(providerKey, config); err != nil { return nil } // Return newly prepared provider without recursion providers = bifrost.providers.Load() if providers != nil { for _, p := range *providers { if p.GetProviderKey() == providerKey { return p } } } return nil } // CORE INTERNAL LOGIC // shouldTryFallbacks handles the primary error and returns true if we should proceed with fallbacks, false if we should return immediately func (bifrost *Bifrost) shouldTryFallbacks(req *schemas.BifrostRequest, primaryErr *schemas.BifrostError) bool { // If no primary error, we succeeded if primaryErr == nil { bifrost.logger.Debug("no primary error, we should not try fallbacks") return false } // Handle request 
cancellation if primaryErr.Error != nil && primaryErr.Error.Type != nil && *primaryErr.Error.Type == schemas.RequestCancelled { bifrost.logger.Debug("request cancelled, we should not try fallbacks") return false } // Check if this is a short-circuit error that doesn't allow fallbacks // Note: AllowFallbacks = nil is treated as true (allow fallbacks by default) if primaryErr.AllowFallbacks != nil && !*primaryErr.AllowFallbacks { bifrost.logger.Debug("allowFallbacks is false, we should not try fallbacks") return false } // If no fallbacks configured, return primary error _, _, fallbacks := req.GetRequestFields() if len(fallbacks) == 0 { bifrost.logger.Debug("no fallbacks configured, we should not try fallbacks") return false } // Should proceed with fallbacks return true } // prepareFallbackRequest creates a fallback request and validates the provider config // Returns the fallback request or nil if this fallback should be skipped func (bifrost *Bifrost) prepareFallbackRequest(req *schemas.BifrostRequest, fallback schemas.Fallback) *schemas.BifrostRequest { // Check if we have config for this fallback provider _, err := bifrost.account.GetConfigForProvider(fallback.Provider) if err != nil { bifrost.logger.Warn("config not found for provider %s, skipping fallback: %v", fallback.Provider, err) return nil } // Create a new request with the fallback provider and model fallbackReq := *req if req.TextCompletionRequest != nil { tmp := *req.TextCompletionRequest tmp.Provider = fallback.Provider tmp.Model = fallback.Model fallbackReq.TextCompletionRequest = &tmp } if req.ChatRequest != nil { tmp := *req.ChatRequest tmp.Provider = fallback.Provider tmp.Model = fallback.Model fallbackReq.ChatRequest = &tmp } if req.ResponsesRequest != nil { tmp := *req.ResponsesRequest tmp.Provider = fallback.Provider tmp.Model = fallback.Model fallbackReq.ResponsesRequest = &tmp } if req.CountTokensRequest != nil { tmp := *req.CountTokensRequest tmp.Provider = fallback.Provider tmp.Model = 
fallback.Model fallbackReq.CountTokensRequest = &tmp } if req.EmbeddingRequest != nil { tmp := *req.EmbeddingRequest tmp.Provider = fallback.Provider tmp.Model = fallback.Model fallbackReq.EmbeddingRequest = &tmp } if req.RerankRequest != nil { tmp := *req.RerankRequest tmp.Provider = fallback.Provider tmp.Model = fallback.Model fallbackReq.RerankRequest = &tmp } if req.OCRRequest != nil { tmp := *req.OCRRequest tmp.Provider = fallback.Provider tmp.Model = fallback.Model fallbackReq.OCRRequest = &tmp } if req.SpeechRequest != nil { tmp := *req.SpeechRequest tmp.Provider = fallback.Provider tmp.Model = fallback.Model fallbackReq.SpeechRequest = &tmp } if req.TranscriptionRequest != nil { tmp := *req.TranscriptionRequest tmp.Provider = fallback.Provider tmp.Model = fallback.Model fallbackReq.TranscriptionRequest = &tmp } if req.ImageGenerationRequest != nil { tmp := *req.ImageGenerationRequest tmp.Provider = fallback.Provider tmp.Model = fallback.Model fallbackReq.ImageGenerationRequest = &tmp } if req.VideoGenerationRequest != nil { tmp := *req.VideoGenerationRequest tmp.Provider = fallback.Provider tmp.Model = fallback.Model fallbackReq.VideoGenerationRequest = &tmp } return &fallbackReq } // shouldContinueWithFallbacks processes errors from fallback attempts // Returns true if we should continue with more fallbacks, false if we should stop func (bifrost *Bifrost) shouldContinueWithFallbacks(fallback schemas.Fallback, fallbackErr *schemas.BifrostError) bool { if fallbackErr.Error.Type != nil && *fallbackErr.Error.Type == schemas.RequestCancelled { return false } // Check if it was a short-circuit error that doesn't allow fallbacks if fallbackErr.AllowFallbacks != nil && !*fallbackErr.AllowFallbacks { return false } bifrost.logger.Debug(fmt.Sprintf("Fallback provider %s failed: %s", fallback.Provider, fallbackErr.Error.Message)) return true } // handleRequest handles the request to the provider based on the request type // It handles plugin hooks, request 
validation, response processing, and fallback providers. // If the primary provider fails, it will try each fallback provider in order until one succeeds. // It is the wrapper for all non-streaming public API methods. func (bifrost *Bifrost) handleRequest(ctx *schemas.BifrostContext, req *schemas.BifrostRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { defer bifrost.releaseBifrostRequest(req) provider, model, fallbacks := req.GetRequestFields() if err := validateRequest(req); err != nil { err.PopulateExtraFields(req.RequestType, provider, model, model) return nil, err } // Handle nil context early to prevent blocking if ctx == nil { ctx = bifrost.ctx } bifrost.logger.Debug(fmt.Sprintf("primary provider %s with model %s and %d fallbacks", provider, model, len(fallbacks))) // Try the primary provider first ctx.SetValue(schemas.BifrostContextKeyFallbackIndex, 0) // Ensure request ID is set in context before PreHooks if _, ok := ctx.Value(schemas.BifrostContextKeyRequestID).(string); !ok { requestID := uuid.New().String() ctx.SetValue(schemas.BifrostContextKeyRequestID, requestID) } primaryResult, primaryErr := bifrost.tryRequest(ctx, req) if primaryErr != nil { if primaryErr.Error != nil { bifrost.logger.Debug(fmt.Sprintf("primary provider %s with model %s returned error: %s", provider, model, primaryErr.Error.Message)) } else { bifrost.logger.Debug(fmt.Sprintf("primary provider %s with model %s returned error: %v", provider, model, primaryErr)) } if len(fallbacks) > 0 { bifrost.logger.Debug(fmt.Sprintf("check if we should try %d fallbacks", len(fallbacks))) } } // Check if we should proceed with fallbacks shouldTryFallbacks := bifrost.shouldTryFallbacks(req, primaryErr) if !shouldTryFallbacks { return primaryResult, primaryErr } // Try fallbacks in order for i, fallback := range fallbacks { ctx.SetValue(schemas.BifrostContextKeyFallbackIndex, i+1) bifrost.logger.Debug(fmt.Sprintf("trying fallback provider %s with model %s", fallback.Provider, 
fallback.Model)) ctx.SetValue(schemas.BifrostContextKeyFallbackRequestID, uuid.New().String()) clearCtxForFallback(ctx) // Start span for fallback attempt tracer := bifrost.getTracer() spanCtx, handle := tracer.StartSpan(ctx, fmt.Sprintf("fallback.%s.%s", fallback.Provider, fallback.Model), schemas.SpanKindFallback) tracer.SetAttribute(handle, schemas.AttrProviderName, string(fallback.Provider)) tracer.SetAttribute(handle, schemas.AttrRequestModel, fallback.Model) tracer.SetAttribute(handle, "fallback.index", i+1) ctx.SetValue(schemas.BifrostContextKeySpanID, spanCtx.Value(schemas.BifrostContextKeySpanID)) fallbackReq := bifrost.prepareFallbackRequest(req, fallback) if fallbackReq == nil { bifrost.logger.Debug(fmt.Sprintf("fallback provider %s with model %s is nil", fallback.Provider, fallback.Model)) tracer.SetAttribute(handle, "error", "fallback request preparation failed") tracer.EndSpan(handle, schemas.SpanStatusError, "fallback request preparation failed") continue } // Try the fallback provider result, fallbackErr := bifrost.tryRequest(ctx, fallbackReq) if fallbackErr == nil { bifrost.logger.Debug(fmt.Sprintf("successfully used fallback provider %s with model %s", fallback.Provider, fallback.Model)) tracer.EndSpan(handle, schemas.SpanStatusOk, "") return result, nil } // End span with error status if fallbackErr.Error != nil { tracer.SetAttribute(handle, "error", fallbackErr.Error.Message) } tracer.EndSpan(handle, schemas.SpanStatusError, "fallback failed") // Check if we should continue with more fallbacks if !bifrost.shouldContinueWithFallbacks(fallback, fallbackErr) { return nil, fallbackErr } } // All providers failed, return the original error return nil, primaryErr } // handleStreamRequest handles the stream request to the provider based on the request type // It handles plugin hooks, request validation, response processing, and fallback providers. // If the primary provider fails, it will try each fallback provider in order until one succeeds. 
// It is the wrapper for all streaming public API methods. func (bifrost *Bifrost) handleStreamRequest(ctx *schemas.BifrostContext, req *schemas.BifrostRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { defer bifrost.releaseBifrostRequest(req) provider, model, fallbacks := req.GetRequestFields() if err := validateRequest(req); err != nil { err.PopulateExtraFields(req.RequestType, provider, model, model) err.StatusCode = schemas.Ptr(fasthttp.StatusBadRequest) return nil, err } // Handle nil context early to prevent blocking if ctx == nil { ctx = bifrost.ctx } // Try the primary provider first ctx.SetValue(schemas.BifrostContextKeyFallbackIndex, 0) // Ensure request ID is set in context before PreHooks if _, ok := ctx.Value(schemas.BifrostContextKeyRequestID).(string); !ok { requestID := uuid.New().String() ctx.SetValue(schemas.BifrostContextKeyRequestID, requestID) } primaryResult, primaryErr := bifrost.tryStreamRequest(ctx, req) // Check if we should proceed with fallbacks shouldTryFallbacks := bifrost.shouldTryFallbacks(req, primaryErr) if !shouldTryFallbacks { return primaryResult, primaryErr } // Try fallbacks in order for i, fallback := range fallbacks { ctx.SetValue(schemas.BifrostContextKeyFallbackIndex, i+1) ctx.SetValue(schemas.BifrostContextKeyFallbackRequestID, uuid.New().String()) clearCtxForFallback(ctx) // Start span for fallback attempt tracer := bifrost.getTracer() spanCtx, handle := tracer.StartSpan(ctx, fmt.Sprintf("fallback.%s.%s", fallback.Provider, fallback.Model), schemas.SpanKindFallback) tracer.SetAttribute(handle, schemas.AttrProviderName, string(fallback.Provider)) tracer.SetAttribute(handle, schemas.AttrRequestModel, fallback.Model) tracer.SetAttribute(handle, "fallback.index", i+1) ctx.SetValue(schemas.BifrostContextKeySpanID, spanCtx.Value(schemas.BifrostContextKeySpanID)) fallbackReq := bifrost.prepareFallbackRequest(req, fallback) if fallbackReq == nil { tracer.SetAttribute(handle, "error", "fallback request 
preparation failed") tracer.EndSpan(handle, schemas.SpanStatusError, "fallback request preparation failed") continue } // Try the fallback provider result, fallbackErr := bifrost.tryStreamRequest(ctx, fallbackReq) if fallbackErr == nil { bifrost.logger.Debug(fmt.Sprintf("successfully used fallback provider %s with model %s", fallback.Provider, fallback.Model)) tracer.EndSpan(handle, schemas.SpanStatusOk, "") return result, nil } // End span with error status if fallbackErr.Error != nil { tracer.SetAttribute(handle, "error", fallbackErr.Error.Message) } tracer.EndSpan(handle, schemas.SpanStatusError, "fallback failed") // Check if we should continue with more fallbacks if !bifrost.shouldContinueWithFallbacks(fallback, fallbackErr) { return nil, fallbackErr } } // All providers failed, return the original error return nil, primaryErr } // tryRequest is a generic function that handles common request processing logic // It consolidates queue setup, plugin pipeline execution, enqueue logic, and response handling func (bifrost *Bifrost) tryRequest(ctx *schemas.BifrostContext, req *schemas.BifrostRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { provider, model, _ := req.GetRequestFields() pq, err := bifrost.getProviderQueue(provider) if err != nil { bifrostErr := newBifrostError(err) bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model) return nil, bifrostErr } // Add MCP tools to request if MCP is configured and requested if bifrost.MCPManager != nil { req = bifrost.MCPManager.AddToolsToRequest(ctx, req) } tracer := bifrost.getTracer() if tracer == nil { bifrostErr := newBifrostErrorFromMsg("tracer not found in context") bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model) return nil, bifrostErr } // Store tracer in context BEFORE calling requestHandler, so streaming goroutines // have access to it for completing deferred spans when the stream ends. 
// The streaming goroutine captures the context when it starts, so these values // must be set before requestHandler() is called. ctx.SetValue(schemas.BifrostContextKeyTracer, tracer) pipeline := bifrost.getPluginPipeline() defer bifrost.releasePluginPipeline(pipeline) // RequestType, Provider, OriginalModelRequested, and ResolvedModelUsed are always // overwritten around RunPostLLMHooks — plugin modifications to these 4 fields are // no-ops by design; proper request metadata is preserved and tampering is discouraged. preReq, shortCircuit, preCount := pipeline.RunLLMPreHooks(ctx, req) if shortCircuit != nil { // Handle short-circuit with response (success case) if shortCircuit.Response != nil { shortCircuit.Response.PopulateExtraFields(req.RequestType, provider, model, model) resp, bifrostErr := pipeline.RunPostLLMHooks(ctx, shortCircuit.Response, nil, preCount) if bifrostErr != nil { bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model) } else if resp != nil { resp.PopulateExtraFields(req.RequestType, provider, model, model) } drainAndAttachPluginLogs(ctx) if bifrostErr != nil { return nil, bifrostErr } return resp, nil } // Handle short-circuit with error if shortCircuit.Error != nil { shortCircuit.Error.PopulateExtraFields(req.RequestType, provider, model, model) resp, bifrostErr := pipeline.RunPostLLMHooks(ctx, nil, shortCircuit.Error, preCount) if bifrostErr != nil { bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model) } else if resp != nil { resp.PopulateExtraFields(req.RequestType, provider, model, model) } drainAndAttachPluginLogs(ctx) if bifrostErr != nil { return nil, bifrostErr } return resp, nil } } if preReq == nil { bifrostErr := newBifrostErrorFromMsg("bifrost request after plugin hooks cannot be nil") bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model) return nil, bifrostErr } provider, model, _ = preReq.GetRequestFields() msg := bifrost.getChannelMessage(*preReq) msg.Context = ctx // If the 
queue is closing, check whether the provider was updated (new queue // available) or removed. On update, transparently re-route to the new queue // so in-flight producers don't get spurious errors. On removal, error out. // // Use a direct sync.Map lookup instead of getProviderQueue to avoid the // lazy-creation path: getProviderQueue can resurrect a provider that was // just removed by RemoveProvider if the account config still exists. if pq.isClosing() { var reroutedPq *ProviderQueue if val, ok := bifrost.requestQueues.Load(provider); ok { if candidate := val.(*ProviderQueue); candidate != pq && !candidate.isClosing() { reroutedPq = candidate } } if reroutedPq == nil { bifrost.releaseChannelMessage(msg) bifrostErr := newBifrostErrorFromMsg("provider is shutting down") bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ RequestType: req.RequestType, Provider: provider, OriginalModelRequested: model, ResolvedModelUsed: model, } return nil, bifrostErr } pq = reroutedPq } // Use select with done channel to detect shutdown during send select { case pq.queue <- msg: // Message was sent successfully case <-pq.done: bifrost.releaseChannelMessage(msg) bifrostErr := newBifrostErrorFromMsg("provider is shutting down") bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model) return nil, bifrostErr case <-ctx.Done(): bifrost.releaseChannelMessage(msg) bifrostErr := newBifrostCtxDoneError(ctx, "while waiting for queue space") bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model) return nil, bifrostErr default: if bifrost.dropExcessRequests.Load() { bifrost.releaseChannelMessage(msg) bifrost.logger.Warn("request dropped: queue is full, please increase the queue size or set dropExcessRequests to false") bifrostErr := newBifrostErrorFromMsg("request dropped: queue is full") bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model) return nil, bifrostErr } // Re-check closing flag before blocking send (lock-free atomic check) if 
pq.isClosing() { bifrost.releaseChannelMessage(msg) bifrostErr := newBifrostErrorFromMsg("provider is shutting down") bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model) return nil, bifrostErr } select { case pq.queue <- msg: // Message was sent successfully case <-pq.done: bifrost.releaseChannelMessage(msg) bifrostErr := newBifrostErrorFromMsg("provider is shutting down") bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model) return nil, bifrostErr case <-ctx.Done(): bifrost.releaseChannelMessage(msg) bifrostErr := newBifrostCtxDoneError(ctx, "while waiting for queue space") bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model) return nil, bifrostErr } } var result *schemas.BifrostResponse var resp *schemas.BifrostResponse pluginCount := len(*bifrost.llmPlugins.Load()) select { case result = <-msg.Response: resp, bifrostErr := pipeline.RunPostLLMHooks(msg.Context, result, nil, pluginCount) if bifrostErr != nil { bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model) } else if resp != nil { resp.PopulateExtraFields(req.RequestType, provider, model, model) } drainAndAttachPluginLogs(msg.Context) if bifrostErr != nil { bifrost.releaseChannelMessage(msg) return nil, bifrostErr } bifrost.releaseChannelMessage(msg) // Strip raw fields that were captured for logging but should not reach the client. 
if resp != nil { dropReq, _ := ctx.Value(schemas.BifrostContextKeyDropRawRequestFromClient).(bool) dropResp, _ := ctx.Value(schemas.BifrostContextKeyDropRawResponseFromClient).(bool) if dropReq || dropResp { extraField := resp.GetExtraFields() if dropReq { extraField.RawRequest = nil } if dropResp { extraField.RawResponse = nil } } } return resp, nil case bifrostErrVal := <-msg.Err: bifrostErrPtr := &bifrostErrVal resp, bifrostErrPtr = pipeline.RunPostLLMHooks(msg.Context, nil, bifrostErrPtr, pluginCount) if bifrostErrPtr != nil { bifrostErrPtr.PopulateExtraFields(req.RequestType, provider, model, model) } else if resp != nil { resp.PopulateExtraFields(req.RequestType, provider, model, model) } drainAndAttachPluginLogs(msg.Context) bifrost.releaseChannelMessage(msg) // Strip raw fields on error path too. dropReq, _ := ctx.Value(schemas.BifrostContextKeyDropRawRequestFromClient).(bool) dropResp, _ := ctx.Value(schemas.BifrostContextKeyDropRawResponseFromClient).(bool) if dropReq || dropResp { if bifrostErrPtr != nil { if dropReq { bifrostErrPtr.ExtraFields.RawRequest = nil } if dropResp { bifrostErrPtr.ExtraFields.RawResponse = nil } } if resp != nil { extraField := resp.GetExtraFields() if dropReq { extraField.RawRequest = nil } if dropResp { extraField.RawResponse = nil } } } if bifrostErrPtr != nil { return nil, bifrostErrPtr } return resp, nil case <-ctx.Done(): // Do NOT releaseChannelMessage here. The message is already enqueued and // the worker still holds a reference to msg.Response and msg.Err. Returning // those channels to the pool now would let the next request reuse them while // the worker is still writing to them — stale data corruption. The worker // never calls releaseChannelMessage itself, so this message leaks from the // pool and is GC'd. That is intentional: a small pool leak on cancellation // is far safer than corrupting another request's channels. 
provider, model, _ := req.GetRequestFields() bifrostErr := newBifrostCtxDoneError(ctx, "waiting for provider response") bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model) return nil, bifrostErr } } // tryStreamRequest is a generic function that handles common request processing logic // It consolidates queue setup, plugin pipeline execution, enqueue logic, and response handling func (bifrost *Bifrost) tryStreamRequest(ctx *schemas.BifrostContext, req *schemas.BifrostRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { provider, model, _ := req.GetRequestFields() pq, err := bifrost.getProviderQueue(provider) if err != nil { bifrostErr := newBifrostError(err) bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model) return nil, bifrostErr } // Add MCP tools to request if MCP is configured and requested if req.RequestType != schemas.SpeechStreamRequest && req.RequestType != schemas.TranscriptionStreamRequest && bifrost.MCPManager != nil { req = bifrost.MCPManager.AddToolsToRequest(ctx, req) } tracer := bifrost.getTracer() if tracer == nil { bifrostErr := newBifrostErrorFromMsg("tracer not found in context") bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model) return nil, bifrostErr } // Store tracer in context BEFORE calling RunLLMPreHooks, so plugins and streaming goroutines // have access to it for completing deferred spans when the stream ends. // The streaming goroutine captures the context when it starts, so these values // must be set before requestHandler() is called. ctx.SetValue(schemas.BifrostContextKeyTracer, tracer) // Ensure traceID exists so the logging plugin can create a stream accumulator // in PreLLMHook and accumulate chunks in PostLLMHook. For HTTP handler requests the // tracing middleware already sets this; for WebSocket bridge and Go SDK callers it // may be absent. 
if _, ok := ctx.Value(schemas.BifrostContextKeyTraceID).(string); !ok { traceID := tracer.CreateTrace("") if traceID != "" { ctx.SetValue(schemas.BifrostContextKeyTraceID, traceID) } } pipeline := bifrost.getPluginPipeline() releasePipeline := true defer func() { if releasePipeline { bifrost.releasePluginPipeline(pipeline) } }() // RequestType, Provider, OriginalModelRequested, and ResolvedModelUsed are always // overwritten around RunPostLLMHooks — plugin modifications to these 4 fields are // no-ops by design; proper request metadata is preserved and tampering is discouraged. preReq, shortCircuit, preCount := pipeline.RunLLMPreHooks(ctx, req) if shortCircuit != nil { // Handle short-circuit with response (success case) if shortCircuit.Response != nil { shortCircuit.Response.PopulateExtraFields(req.RequestType, provider, model, model) resp, bifrostErr := pipeline.RunPostLLMHooks(ctx, shortCircuit.Response, nil, preCount) if bifrostErr != nil { bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model) } else if resp != nil { resp.PopulateExtraFields(req.RequestType, provider, model, model) } drainAndAttachPluginLogs(ctx) if bifrostErr != nil { return nil, bifrostErr } return newBifrostMessageChan(resp), nil } // Handle short-circuit with stream if shortCircuit.Stream != nil { outputStream := make(chan *schemas.BifrostStreamChunk) releasePipeline = false // pipeline is released inside the goroutine after stream drains // Create a post hook runner cause pipeline object is put back in the pool on defer pipelinePostHookRunner := func(ctx *schemas.BifrostContext, result *schemas.BifrostResponse, err *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError) { if result != nil { result.PopulateExtraFields(req.RequestType, provider, model, model) } if err != nil { err.PopulateExtraFields(req.RequestType, provider, model, model) } resp, bifrostErr := pipeline.RunPostLLMHooks(ctx, result, err, preCount) if IsFinalChunk(ctx) { 
drainAndAttachPluginLogs(ctx) } if bifrostErr != nil { bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model) return nil, bifrostErr } else if resp != nil { resp.PopulateExtraFields(req.RequestType, provider, model, model) } return resp, nil } go func() { defer func() { drainAndAttachPluginLogs(ctx) // ensure logs are drained even if stream closes without a final chunk pipeline.FinalizeStreamingPostHookSpans(ctx) bifrost.releasePluginPipeline(pipeline) }() defer close(outputStream) for streamMsg := range shortCircuit.Stream { if streamMsg == nil { continue } bifrostResponse := &schemas.BifrostResponse{} if streamMsg.BifrostTextCompletionResponse != nil { bifrostResponse.TextCompletionResponse = streamMsg.BifrostTextCompletionResponse } if streamMsg.BifrostChatResponse != nil { bifrostResponse.ChatResponse = streamMsg.BifrostChatResponse } if streamMsg.BifrostResponsesStreamResponse != nil { bifrostResponse.ResponsesStreamResponse = streamMsg.BifrostResponsesStreamResponse } if streamMsg.BifrostSpeechStreamResponse != nil { bifrostResponse.SpeechStreamResponse = streamMsg.BifrostSpeechStreamResponse } if streamMsg.BifrostTranscriptionStreamResponse != nil { bifrostResponse.TranscriptionStreamResponse = streamMsg.BifrostTranscriptionStreamResponse } if streamMsg.BifrostImageGenerationStreamResponse != nil { bifrostResponse.ImageGenerationStreamResponse = streamMsg.BifrostImageGenerationStreamResponse } // Run post hooks on the stream message processedResponse, processedError := pipelinePostHookRunner(ctx, bifrostResponse, streamMsg.BifrostError) // Build the client-facing chunk via the shared helper, which strips raw // request/response fields when in logging-only mode without mutating the // shared processedResponse or processedError objects. 
streamResponse := providerUtils.BuildClientStreamChunk(ctx, processedResponse, processedError) // Guarded send: if the consumer abandons outputStream (client // disconnect, ctx cancel), drain the upstream shortCircuit.Stream // so its producer can exit cleanly instead of blocking on its send. select { case outputStream <- streamResponse: case <-ctx.Done(): for range shortCircuit.Stream { } return } // TODO: Release the processed response immediately after use } }() return outputStream, nil } // Handle short-circuit with error if shortCircuit.Error != nil { shortCircuit.Error.PopulateExtraFields(req.RequestType, provider, model, model) resp, bifrostErr := pipeline.RunPostLLMHooks(ctx, nil, shortCircuit.Error, preCount) if bifrostErr != nil { bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model) } else if resp != nil { resp.PopulateExtraFields(req.RequestType, provider, model, model) } drainAndAttachPluginLogs(ctx) if bifrostErr != nil { return nil, bifrostErr } return newBifrostMessageChan(resp), nil } } if preReq == nil { bifrostErr := newBifrostErrorFromMsg("bifrost request after plugin hooks cannot be nil") bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model) return nil, bifrostErr } provider, model, _ = preReq.GetRequestFields() msg := bifrost.getChannelMessage(*preReq) msg.Context = ctx // If the queue is closing, check whether the provider was updated (new queue // available) or removed. On update, transparently re-route to the new queue // so in-flight producers don't get spurious errors. On removal, error out. // // Use a direct sync.Map lookup instead of getProviderQueue to avoid the // lazy-creation path: getProviderQueue can resurrect a provider that was // just removed by RemoveProvider if the account config still exists. 
if pq.isClosing() { var reroutedPq *ProviderQueue if val, ok := bifrost.requestQueues.Load(provider); ok { if candidate := val.(*ProviderQueue); candidate != pq && !candidate.isClosing() { reroutedPq = candidate } } if reroutedPq == nil { bifrost.releaseChannelMessage(msg) bifrostErr := newBifrostErrorFromMsg("provider is shutting down") bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ RequestType: req.RequestType, Provider: provider, OriginalModelRequested: model, } return nil, bifrostErr } pq = reroutedPq } // Use select with done channel to detect shutdown during send select { case pq.queue <- msg: // Message was sent successfully case <-pq.done: bifrost.releaseChannelMessage(msg) bifrostErr := newBifrostErrorFromMsg("provider is shutting down") bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model) return nil, bifrostErr case <-ctx.Done(): bifrost.releaseChannelMessage(msg) bifrostErr := newBifrostCtxDoneError(ctx, "while waiting for queue space") bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model) return nil, bifrostErr default: if bifrost.dropExcessRequests.Load() { bifrost.releaseChannelMessage(msg) bifrost.logger.Warn("request dropped: queue is full, please increase the queue size or set dropExcessRequests to false") bifrostErr := newBifrostErrorFromMsg("request dropped: queue is full") bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model) return nil, bifrostErr } // Re-check closing flag before blocking send (lock-free atomic check) if pq.isClosing() { bifrost.releaseChannelMessage(msg) bifrostErr := newBifrostErrorFromMsg("provider is shutting down") bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model) return nil, bifrostErr } select { case pq.queue <- msg: // Message was sent successfully case <-pq.done: bifrost.releaseChannelMessage(msg) bifrostErr := newBifrostErrorFromMsg("provider is shutting down") bifrostErr.PopulateExtraFields(req.RequestType, provider, model, 
model) return nil, bifrostErr case <-ctx.Done(): bifrost.releaseChannelMessage(msg) bifrostErr := newBifrostCtxDoneError(ctx, "while waiting for queue space") bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model) return nil, bifrostErr } } select { case stream := <-msg.ResponseStream: bifrost.releaseChannelMessage(msg) return stream, nil case bifrostErrVal := <-msg.Err: if bifrostErrVal.Error != nil { bifrost.logger.Debug("error while executing stream request: %s", bifrostErrVal.Error.Message) } else { bifrost.logger.Debug("error while executing stream request: %+v", bifrostErrVal) } // Marking final chunk ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) // On error we will complete post-hooks recoveredResp, recoveredErr := pipeline.RunPostLLMHooks(ctx, nil, &bifrostErrVal, len(*bifrost.llmPlugins.Load())) if recoveredErr != nil { recoveredErr.PopulateExtraFields(req.RequestType, provider, model, model) } else if recoveredResp != nil { recoveredResp.PopulateExtraFields(req.RequestType, provider, model, model) } drainAndAttachPluginLogs(ctx) bifrost.releaseChannelMessage(msg) if recoveredErr != nil { return nil, recoveredErr } if recoveredResp != nil { return newBifrostMessageChan(recoveredResp), nil } return nil, &bifrostErrVal case <-ctx.Done(): // Do NOT releaseChannelMessage here — see the identical note in tryRequest. // Worker still holds msg.ResponseStream/msg.Err; releasing now corrupts the // next request that reuses those pooled channels. return nil, newBifrostCtxDoneError(ctx, "while waiting for stream response") } } // executeRequestWithRetries is a generic function that handles common request processing logic. // It consolidates retry logic, backoff calculation, error handling, and key rotation. // It is not a bifrost method because interface methods in go cannot be generic. // // keyProvider, when non-nil, is called on the first attempt and again whenever a rate-limit error // triggers a key rotation. 
It receives the set of key IDs already used in the current rotation // cycle so it can exclude them; when the pool is exhausted the provider resets the set and starts // a fresh weighted round. Network errors (5xx) reuse the same key since they are transient server // issues rather than per-key capacity problems. func executeRequestWithRetries[T any]( ctx *schemas.BifrostContext, config *schemas.ProviderConfig, requestHandler func(key schemas.Key) (T, *schemas.BifrostError), keyProvider func(usedKeyIDs map[string]bool) (schemas.Key, error), requestType schemas.RequestType, providerKey schemas.ModelProvider, model string, req *schemas.BifrostRequest, logger schemas.Logger, ) (T, *schemas.BifrostError) { var result T var bifrostError *schemas.BifrostError var attempts int var currentKey schemas.Key var usedKeyIDs map[string]bool lastWasRateLimit := false for attempts = 0; attempts <= config.NetworkConfig.MaxRetries; attempts++ { ctx.SetValue(schemas.BifrostContextKeyNumberOfRetries, attempts) // Reset the trail on the first attempt so a reused or shared context (bifrost.ctx) // doesn't carry over records from a previous request. if keyProvider != nil && attempts == 0 { ctx.SetValue(schemas.BifrostContextKeyAttemptTrail, []schemas.KeyAttemptRecord{}) } // Select / rotate key: always on attempt 0, and again when the previous failure was a // rate-limit (different key may have remaining capacity). Network errors keep the same key. if keyProvider != nil && (attempts == 0 || lastWasRateLimit) { if usedKeyIDs == nil { usedKeyIDs = make(map[string]bool) } // Wrap key selection in a dedicated span so traces show which key was chosen // (and when rotation happened). The span is opened before keyProvider is called // so selection errors are captured too. 
keyTracer, _ := ctx.Value(schemas.BifrostContextKeyTracer).(schemas.Tracer) var keySpanCtx context.Context var keyHandle schemas.SpanHandle if keyTracer != nil { keySpanCtx, keyHandle = keyTracer.StartSpan(ctx, "key.selection", schemas.SpanKindInternal) keyTracer.SetAttribute(keyHandle, schemas.AttrProviderName, string(providerKey)) keyTracer.SetAttribute(keyHandle, schemas.AttrRequestModel, model) if attempts > 0 { keyTracer.SetAttribute(keyHandle, "retry.count", attempts) } } selectedKey, err := keyProvider(usedKeyIDs) if keyTracer != nil { if err != nil { keyTracer.SetAttribute(keyHandle, "error", err.Error()) keyTracer.EndSpan(keyHandle, schemas.SpanStatusError, err.Error()) } else { keyTracer.SetAttribute(keyHandle, "key.id", selectedKey.ID) keyTracer.SetAttribute(keyHandle, "key.name", selectedKey.Name) keyTracer.EndSpan(keyHandle, schemas.SpanStatusOk, "") // Propagate the span context so subsequent spans (llm.call / retry.attempt.N) // are correctly linked in the trace hierarchy. ctx.SetValue(schemas.BifrostContextKeySpanID, keySpanCtx.Value(schemas.BifrostContextKeySpanID)) } } if err != nil { var zero T return zero, newBifrostErrorFromMsg(err.Error()) } currentKey = selectedKey ctx.SetValue(schemas.BifrostContextKeySelectedKeyID, currentKey.ID) ctx.SetValue(schemas.BifrostContextKeySelectedKeyName, currentKey.Name) } // Append a trail record for every attempt (key rotation and same-key retries alike). // Skipped when keyProvider is nil (keyless providers have no key to track). // FailReason is populated below once the attempt outcome is known. 
if keyProvider != nil { schemas.AppendToContextList(ctx, schemas.BifrostContextKeyAttemptTrail, schemas.KeyAttemptRecord{ Attempt: attempts, KeyID: currentKey.ID, KeyName: currentKey.Name, }) } if attempts > 0 { // Log retry attempt var retryMsg string if bifrostError != nil && bifrostError.Error != nil { retryMsg = bifrostError.Error.Message } else if bifrostError != nil && bifrostError.StatusCode != nil { retryMsg = fmt.Sprintf("status=%d", *bifrostError.StatusCode) if bifrostError.Type != nil { retryMsg += ", type=" + *bifrostError.Type } } logger.Debug("retrying request (attempt %d/%d) for model %s: %s", attempts, config.NetworkConfig.MaxRetries, model, retryMsg) // Calculate and apply backoff backoff := calculateBackoff(attempts-1, config) logger.Debug("sleeping for %s before retry", backoff) time.Sleep(backoff) } logger.Debug("attempting %s request for provider %s", requestType, providerKey) // Start span for LLM call (or retry attempt) tracer, ok := ctx.Value(schemas.BifrostContextKeyTracer).(schemas.Tracer) if !ok || tracer == nil { logger.Error("tracer not found in context of executeRequestWithRetries") return result, newBifrostErrorFromMsg("tracer not found in context") } var spanName string var spanKind schemas.SpanKind if attempts > 0 { spanName = fmt.Sprintf("retry.attempt.%d", attempts) spanKind = schemas.SpanKindRetry } else { spanName = "llm.call" spanKind = schemas.SpanKindLLMCall } spanCtx, handle := tracer.StartSpan(ctx, spanName, spanKind) tracer.SetAttribute(handle, schemas.AttrProviderName, string(providerKey)) tracer.SetAttribute(handle, schemas.AttrRequestModel, model) tracer.SetAttribute(handle, "request.type", string(requestType)) if attempts > 0 { tracer.SetAttribute(handle, "retry.count", attempts) } // Add context-related attributes (selected key, virtual key, team, customer, etc.) 
if selectedKeyID, ok := ctx.Value(schemas.BifrostContextKeySelectedKeyID).(string); ok && selectedKeyID != "" { tracer.SetAttribute(handle, schemas.AttrSelectedKeyID, selectedKeyID) } if selectedKeyName, ok := ctx.Value(schemas.BifrostContextKeySelectedKeyName).(string); ok && selectedKeyName != "" { tracer.SetAttribute(handle, schemas.AttrSelectedKeyName, selectedKeyName) } if virtualKeyID, ok := ctx.Value(schemas.BifrostContextKeyGovernanceVirtualKeyID).(string); ok && virtualKeyID != "" { tracer.SetAttribute(handle, schemas.AttrVirtualKeyID, virtualKeyID) } if virtualKeyName, ok := ctx.Value(schemas.BifrostContextKeyGovernanceVirtualKeyName).(string); ok && virtualKeyName != "" { tracer.SetAttribute(handle, schemas.AttrVirtualKeyName, virtualKeyName) } if teamID, ok := ctx.Value(schemas.BifrostContextKeyGovernanceTeamID).(string); ok && teamID != "" { tracer.SetAttribute(handle, schemas.AttrTeamID, teamID) } if teamName, ok := ctx.Value(schemas.BifrostContextKeyGovernanceTeamName).(string); ok && teamName != "" { tracer.SetAttribute(handle, schemas.AttrTeamName, teamName) } if customerID, ok := ctx.Value(schemas.BifrostContextKeyGovernanceCustomerID).(string); ok && customerID != "" { tracer.SetAttribute(handle, schemas.AttrCustomerID, customerID) } if customerName, ok := ctx.Value(schemas.BifrostContextKeyGovernanceCustomerName).(string); ok && customerName != "" { tracer.SetAttribute(handle, schemas.AttrCustomerName, customerName) } if fallbackIndex, ok := ctx.Value(schemas.BifrostContextKeyFallbackIndex).(int); ok { tracer.SetAttribute(handle, schemas.AttrFallbackIndex, fallbackIndex) } tracer.SetAttribute(handle, schemas.AttrNumberOfRetries, attempts) // Populate LLM request attributes (messages, parameters, etc.) 
if req != nil { tracer.PopulateLLMRequestAttributes(handle, req) } // Update context with span ID ctx.SetValue(schemas.BifrostContextKeySpanID, spanCtx.Value(schemas.BifrostContextKeySpanID)) // Record stream start time for TTFT calculation (only for streaming requests) // This is also used by RunPostLLMHooks to detect streaming mode if IsStreamRequestType(requestType) { streamStartTime := time.Now() ctx.SetValue(schemas.BifrostContextKeyStreamStartTime, streamStartTime) } // Attempt the request result, bifrostError = requestHandler(currentKey) // For streaming requests that returned success, check if the first chunk // is actually an error (e.g., rate limits sent as SSE events in HTTP 200). // This enables retries and fallbacks for providers that embed errors in // the SSE stream instead of returning proper HTTP error status codes. if bifrostError == nil { if streamChan, ok := any(result).(chan *schemas.BifrostStreamChunk); ok { checkedStream, drainDone, firstChunkErr := providerUtils.CheckFirstStreamChunkForError(ctx, streamChan) if firstChunkErr != nil { <-drainDone bifrostError = firstChunkErr } else { result = any(checkedStream).(T) } } } // Check if result is a streaming channel - if so, defer span completion // Only defer for successful stream setup; error paths must end the span synchronously isStreamChan := false if bifrostError == nil { if ch, ok := any(result).(chan *schemas.BifrostStreamChunk); ok && ch != nil { isStreamChan = true } } if isStreamChan { // For streaming requests, store the span handle in TraceStore keyed by trace ID // This allows the provider's streaming goroutine to retrieve it later if traceID, ok := ctx.Value(schemas.BifrostContextKeyTraceID).(string); ok && traceID != "" { tracer.StoreDeferredSpan(traceID, handle) } // Don't end the span here - it will be ended when streaming completes } else { // Populate LLM response attributes for non-streaming responses if resp, ok := any(result).(*schemas.BifrostResponse); ok { 
tracer.PopulateLLMResponseAttributes(ctx, handle, resp, bifrostError) } // End span with appropriate status if bifrostError != nil { if bifrostError.Error != nil { tracer.SetAttribute(handle, "error", bifrostError.Error.Message) } if bifrostError.StatusCode != nil { tracer.SetAttribute(handle, "status_code", *bifrostError.StatusCode) } tracer.EndSpan(handle, schemas.SpanStatusError, "request failed") } else { tracer.EndSpan(handle, schemas.SpanStatusOk, "") } } logger.Debug("request %s for provider %s completed", requestType, providerKey) // Check if successful or if we should retry if bifrostError == nil || bifrostError.IsBifrostError || (bifrostError.Error != nil && bifrostError.Error.Type != nil && *bifrostError.Error.Type == schemas.RequestCancelled) { break } // Check if we should retry based on status code or error message shouldRetry := false isRateLimit := (bifrostError.StatusCode != nil && *bifrostError.StatusCode == 429) || (bifrostError.Error != nil && (IsRateLimitErrorMessage(bifrostError.Error.Message) || (bifrostError.Error.Type != nil && IsRateLimitErrorMessage(*bifrostError.Error.Type)) || (bifrostError.Error.Code != nil && IsRateLimitErrorMessage(*bifrostError.Error.Code)))) errMessage := GetErrorMessage(bifrostError) if bifrostError.Error != nil && (bifrostError.Error.Message == schemas.ErrProviderDoRequest || bifrostError.Error.Message == schemas.ErrProviderNetworkError) { shouldRetry = true logger.Debug("detected request HTTP/network error, will retry: %s", errMessage) } else if (bifrostError.StatusCode != nil && retryableStatusCodes[*bifrostError.StatusCode]) || isRateLimit { shouldRetry = true logger.Debug("encountered error that should be retried: %s", errMessage) } // Fill FailReason on any failed attempt (retryable or terminal). // Use the provider error type when present; fall back to "unknown". 
if trail, ok := ctx.Value(schemas.BifrostContextKeyAttemptTrail).([]schemas.KeyAttemptRecord); ok && len(trail) > 0 { reason := "unknown" if bifrostError.Error != nil && bifrostError.Error.Type != nil && *bifrostError.Error.Type != "" { reason = *bifrostError.Error.Type } else if isRateLimit { reason = "rate_limit_error" } trail[len(trail)-1].FailReason = &reason ctx.SetValue(schemas.BifrostContextKeyAttemptTrail, trail) } if !shouldRetry { break } // Mark current key as used so the next selection excludes it (rate-limit only). // Network errors keep the same key — they are transient server issues, not per-key. if isRateLimit && keyProvider != nil { if usedKeyIDs == nil { usedKeyIDs = make(map[string]bool) } usedKeyIDs[currentKey.ID] = true } lastWasRateLimit = isRateLimit } // Add retry information to error if attempts > 0 { logger.Debug("request failed after %d %s", attempts, map[bool]string{true: "attempts", false: "attempt"}[attempts > 1]) } // On final error, clear selected_key so it only reflects a key that actually served a successful response. // The attempt trail is the authoritative record of which keys were tried. if bifrostError != nil && keyProvider != nil { ctx.SetValue(schemas.BifrostContextKeySelectedKeyID, "") ctx.SetValue(schemas.BifrostContextKeySelectedKeyName, "") } return result, bifrostError } // requestWorker handles incoming requests from the queue for a specific provider. // It manages retries, error handling, and response processing. func (bifrost *Bifrost) requestWorker(provider schemas.Provider, config *schemas.ProviderConfig, pq *ProviderQueue) { defer func() { if waitGroupValue, ok := bifrost.waitGroups.Load(provider.GetProviderKey()); ok { waitGroup := waitGroupValue.(*sync.WaitGroup) waitGroup.Done() } }() for { var req *ChannelMessage select { case r := <-pq.queue: req = r case <-pq.done: // Provider is shutting down. Drain any buffered requests and send // back errors so callers are not left blocked on their response channel. 
for { select { case r := <-pq.queue: provKey, mod, _ := r.GetRequestFields() select { case r.Err <- schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "provider is shutting down", }, ExtraFields: schemas.BifrostErrorExtraFields{ RequestType: r.RequestType, Provider: provKey, OriginalModelRequested: mod, }, }: case <-r.Context.Done(): } default: return } } } _, model, _ := req.BifrostRequest.GetRequestFields() var result *schemas.BifrostResponse var stream chan *schemas.BifrostStreamChunk var bifrostError *schemas.BifrostError var err error // Determine the base provider type for key requirement checks baseProvider := provider.GetProviderKey() if cfg := config.CustomProviderConfig; cfg != nil && cfg.BaseProviderType != "" { baseProvider = cfg.BaseProviderType } req.Context.SetValue(schemas.BifrostContextKeyIsCustomProvider, !IsStandardProvider(baseProvider)) // Determine whether this provider attempt should capture raw payloads. // // Effective values are computed by merging provider config with any per-request // context overrides (BifrostContextKeySendBackRawRequest/Response and // BifrostContextKeyStoreRawRequestResponse). A context value set to either true // or false fully overrides the provider config for that flag. // // Each flag is independent: // send_back_raw_request — include raw request bytes in the client response. // send_back_raw_response — include raw response bytes in the client response. // store_raw_request_response — persist raw bytes in log records (logging plugin only). // // Capture is enabled per-side whenever send-back OR store is requested for that side. // Strip flags tell the response path to remove that side's bytes before the payload // reaches the caller (used when store=true but send-back=false for that side). // // All internal signals are always written explicitly on every attempt so stale values // from a previous provider attempt (e.g. 
different fallback provider config) cannot // leak into the new attempt on a reused context. The user override keys // (BifrostContextKeySendBackRaw*, BifrostContextKeyStoreRawRequestResponse) are // never overwritten — they are read-only from bifrost.go's perspective. // Step 1: compute effective value for each flag (provider config ← per-request override). effectiveSendBackReq := config.SendBackRawRequest if override, ok := req.Context.Value(schemas.BifrostContextKeySendBackRawRequest).(bool); ok { effectiveSendBackReq = override } effectiveSendBackResp := config.SendBackRawResponse if override, ok := req.Context.Value(schemas.BifrostContextKeySendBackRawResponse).(bool); ok { effectiveSendBackResp = override } effectiveStore := config.StoreRawRequestResponse if override, ok := req.Context.Value(schemas.BifrostContextKeyStoreRawRequestResponse).(bool); ok { effectiveStore = override } // Step 2: derive per-side capture and strip flags. // Capture if we need to send the data back OR store it — independent per side. captureReq := effectiveSendBackReq || effectiveStore captureResp := effectiveSendBackResp || effectiveStore // Strip from client response if we captured for storage but not for send-back. dropReq := effectiveStore && !effectiveSendBackReq dropResp := effectiveStore && !effectiveSendBackResp // Step 3: write all internal signals explicitly (never touch the user override keys). req.Context.SetValue(schemas.BifrostContextKeyCaptureRawRequest, captureReq) req.Context.SetValue(schemas.BifrostContextKeyCaptureRawResponse, captureResp) req.Context.SetValue(schemas.BifrostContextKeyDropRawRequestFromClient, dropReq) req.Context.SetValue(schemas.BifrostContextKeyDropRawResponseFromClient, dropResp) // Tells the logging plugin whether to persist raw bytes in log records. 
req.Context.SetValue(schemas.BifrostContextKeyShouldStoreRawInLogs, effectiveStore) var keys []schemas.Key // keyProvider is passed to executeRequestWithRetries to manage key selection and rotation. // It is nil when no key is required (e.g. providerRequiresKey=false) or for multi-key // batch/file/container operations that manage their own key lists. var keyProvider func(usedKeyIDs map[string]bool) (schemas.Key, error) if providerRequiresKey(config.CustomProviderConfig) { // ListModels needs all enabled/supported keys so providers can aggregate // and report per-key statuses (KeyStatuses). if req.RequestType == schemas.ListModelsRequest { keys, err = bifrost.getAllSupportedKeys(req.Context, provider.GetProviderKey(), baseProvider) if err != nil { bifrost.logger.Debug("error getting supported keys for list models: %v", err) req.Err <- schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: err.Error(), Error: err, }, ExtraFields: schemas.BifrostErrorExtraFields{ Provider: provider.GetProviderKey(), RequestType: req.RequestType, OriginalModelRequested: model, ResolvedModelUsed: model, }, } continue } } else { // Determine if this is a multi-key batch/file/container operation // BatchCreate, FileUpload, ContainerCreate, ContainerFileCreate use single key; other batch/file/container ops use multiple keys isMultiKeyBatchOp := isBatchRequestType(req.RequestType) && req.RequestType != schemas.BatchCreateRequest isMultiKeyFileOp := isFileRequestType(req.RequestType) && req.RequestType != schemas.FileUploadRequest isMultiKeyContainerOp := isContainerRequestType(req.RequestType) && req.RequestType != schemas.ContainerCreateRequest && req.RequestType != schemas.ContainerFileCreateRequest if isMultiKeyBatchOp || isMultiKeyFileOp || isMultiKeyContainerOp { var modelPtr *string if model != "" { modelPtr = &model } keys, err = bifrost.getKeysForBatchAndFileOps(req.Context, provider.GetProviderKey(), baseProvider, modelPtr, isMultiKeyBatchOp) if err != 
nil { bifrost.logger.Debug("error getting keys for batch/file operation: %v", err) req.Err <- schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: err.Error(), Error: err, }, ExtraFields: schemas.BifrostErrorExtraFields{ Provider: provider.GetProviderKey(), RequestType: req.RequestType, OriginalModelRequested: model, ResolvedModelUsed: model, }, } continue } } else { // Build the key pool for this request. Selection and rotation are deferred to // executeRequestWithRetries via keyProvider so that each retry attempt can use // a different key (on rate-limit errors) without re-running the full filtering. supportedKeys, canRotate, keyPoolErr := bifrost.selectKeyFromProviderForModelWithPool(req.Context, req.RequestType, provider.GetProviderKey(), model, baseProvider) if keyPoolErr != nil { bifrost.logger.Debug("error building key pool for model %s: %v", model, keyPoolErr) req.Err <- schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: keyPoolErr.Error(), Error: keyPoolErr, }, ExtraFields: schemas.BifrostErrorExtraFields{ Provider: provider.GetProviderKey(), RequestType: req.RequestType, OriginalModelRequested: model, ResolvedModelUsed: model, }, } continue } if len(supportedKeys) == 0 { // SkipKeySelection path — keyProvider stays nil, zero Key is used. } else if !canRotate { // Fixed key (DirectKey, explicit ID/name, session stickiness): always // return the same key regardless of usedKeyIDs. fixedKey := supportedKeys[0] keyProvider = func(_ map[string]bool) (schemas.Key, error) { return fixedKey, nil } } else { // Rotating pool: weighted selection with per-cycle exclusion. // Captures supportedKeys, bifrost.keySelector, provider/model by value. 
pool := supportedKeys provKey := provider.GetProviderKey() mdl := model keyProvider = func(usedKeyIDs map[string]bool) (schemas.Key, error) { available := make([]schemas.Key, 0, len(pool)) for _, k := range pool { if !usedKeyIDs[k.ID] { available = append(available, k) } } if len(available) == 0 { // All keys exhausted — start a fresh weighted round. for id := range usedKeyIDs { delete(usedKeyIDs, id) } available = pool } return bifrost.keySelector(req.Context, available, provKey, mdl) } } } } } originalModelRequested := model // resolvedModel is set inside the handler closures below on every attempt so that each // key's own alias mapping is applied. The outer var holds the LAST attempt's value and is // read single-threaded by the worker after retries finish (e.g. the error-fallback at // line 5653). Streaming postHookRunner must NOT capture this var by reference — it // snapshots its own attemptResolvedModel inside the per-attempt closure. var resolvedModel string // lastAttemptFinalizer captures the LAST attempt's postHookSpanFinalizer for the // worker-level error fallback below. Single-threaded write (assigned by the retry // loop's per-attempt closure) and single-threaded read (after retries finish), so // no synchronization needed. Earlier attempts' finalizers fire via their provider // goroutines' defers — passed via the postHookSpanFinalizer parameter directly to // handleProviderStreamRequest, never via the shared req.Context. var lastAttemptFinalizer func(context.Context) // Execute request with retries. For streaming, the plugin pipeline, // postHookRunner, and finalizer are allocated per-attempt inside the // request handler closure. If they were request-scoped, a retry // triggered by CheckFirstStreamChunkForError could run against a // pipeline the previous attempt's provider goroutine has already // returned to the pool via its deferred finalizer. 
if IsStreamRequestType(req.RequestType) { stream, bifrostError = executeRequestWithRetries(req.Context, config, func(k schemas.Key) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { resolvedModel = k.Aliases.Resolve(originalModelRequested) req.SetModel(resolvedModel) // Snapshot per-attempt so postHookRunner doesn't observe a later retry's // alias while this attempt's provider goroutine is still emitting chunks. attemptResolvedModel := resolvedModel pipeline := bifrost.getPluginPipeline() postHookRunner := func(ctx *schemas.BifrostContext, result *schemas.BifrostResponse, err *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError) { // Populate extra fields before RunPostLLMHooks so plugins (e.g. logging) // can read requestType/provider/model from the chunk or error. // Uses the per-attempt snapshot — capturing the outer resolvedModel by // reference would let a later retry's alias bleed into this attempt's chunks. if result != nil { result.PopulateExtraFields(req.RequestType, provider.GetProviderKey(), originalModelRequested, attemptResolvedModel) } if err != nil { err.PopulateExtraFields(req.RequestType, provider.GetProviderKey(), originalModelRequested, attemptResolvedModel) } resp, bifrostErr := pipeline.RunPostLLMHooks(ctx, result, err, len(*bifrost.llmPlugins.Load())) if IsFinalChunk(ctx) { drainAndAttachPluginLogs(ctx) } if bifrostErr != nil { bifrostErr.PopulateExtraFields(req.RequestType, provider.GetProviderKey(), originalModelRequested, attemptResolvedModel) return nil, bifrostErr } else if resp != nil { resp.PopulateExtraFields(req.RequestType, provider.GetProviderKey(), originalModelRequested, attemptResolvedModel) } return resp, nil } // Store a finalizer callback to create aggregated post-hook spans at stream end. // Wrapped in sync.Once so the normal end-of-stream invocation and a deferred // safety-net invocation (e.g. from a provider goroutine's panic path) cannot // double-release the pipeline. 
var finalizerOnce sync.Once postHookSpanFinalizer := func(ctx context.Context) { finalizerOnce.Do(func() { pipeline.FinalizeStreamingPostHookSpans(ctx) bifrost.releasePluginPipeline(pipeline) }) } lastAttemptFinalizer = postHookSpanFinalizer streamCh, streamErr := bifrost.handleProviderStreamRequest(provider, req, k, postHookRunner, postHookSpanFinalizer) // If stream setup failed before any provider goroutine started, // no deferred finalizer will run — release the pipeline directly // so a retry doesn't inherit a leaked pool entry. if streamErr != nil && streamCh == nil { finalizerOnce.Do(func() { bifrost.releasePluginPipeline(pipeline) }) } return streamCh, streamErr }, keyProvider, req.RequestType, provider.GetProviderKey(), model, &req.BifrostRequest, bifrost.logger) } else { result, bifrostError = executeRequestWithRetries(req.Context, config, func(k schemas.Key) (*schemas.BifrostResponse, *schemas.BifrostError) { resolvedModel = k.Aliases.Resolve(originalModelRequested) req.SetModel(resolvedModel) return bifrost.handleProviderRequest(provider, config, req, k, keys) }, keyProvider, req.RequestType, provider.GetProviderKey(), model, &req.BifrostRequest, bifrost.logger) } // For streaming with an error, route release through the LAST attempt's // finalizer (wrapped in sync.Once) so we don't double-Put into the pool // or race the provider goroutine's deferred FinalizeStreamingPostHookSpans // call. lastAttemptFinalizer is set inside the per-attempt closure on every // iteration; after retries finish, it holds the LAST attempt's finalizer. // Earlier attempts' finalizers have already fired via their provider // goroutines' defers (passed via the postHookSpanFinalizer parameter // directly to handleProviderStreamRequest). For streaming without error, // the finalizer is invoked by completeDeferredSpan / the provider // goroutine's defer. 
if IsStreamRequestType(req.RequestType) && bifrostError != nil { if lastAttemptFinalizer != nil { lastAttemptFinalizer(req.Context) } } if bifrostError != nil { bifrostError.PopulateExtraFields(req.RequestType, provider.GetProviderKey(), originalModelRequested, resolvedModel) // Send error with context awareness to prevent deadlock select { case req.Err <- *bifrostError: // Error sent successfully case <-req.Context.Done(): // Client no longer listening, log and continue bifrost.logger.Debug("Client context cancelled while sending error response") case <-time.After(5 * time.Second): // Timeout to prevent indefinite blocking bifrost.logger.Warn("Timeout while sending error response, client may have disconnected") } } else { if result != nil { result.PopulateExtraFields(req.RequestType, provider.GetProviderKey(), originalModelRequested, resolvedModel) } if IsStreamRequestType(req.RequestType) { // Send stream with context awareness to prevent deadlock select { case req.ResponseStream <- stream: // Stream sent successfully case <-req.Context.Done(): // Client no longer listening, log and continue bifrost.logger.Debug("Client context cancelled while sending stream response") case <-time.After(5 * time.Second): // Timeout to prevent indefinite blocking bifrost.logger.Warn("Timeout while sending stream response, client may have disconnected") } } else { // Send response with context awareness to prevent deadlock select { case req.Response <- result: // Response sent successfully case <-req.Context.Done(): // Client no longer listening, log and continue bifrost.logger.Debug("Client context cancelled while sending response") case <-time.After(5 * time.Second): // Timeout to prevent indefinite blocking bifrost.logger.Warn("Timeout while sending response, client may have disconnected") } } } } // bifrost.logger.Debug("worker for provider %s exiting...", provider.GetProviderKey()) } // handleProviderRequest handles the request to the provider based on the request type // key 
is used for single-key operations, keys is used for batch/file operations that need multiple keys func (bifrost *Bifrost) handleProviderRequest(provider schemas.Provider, config *schemas.ProviderConfig, req *ChannelMessage, key schemas.Key, keys []schemas.Key) (*schemas.BifrostResponse, *schemas.BifrostError) { response := &schemas.BifrostResponse{} switch req.RequestType { case schemas.ListModelsRequest: listModelsResponse, bifrostError := provider.ListModels(req.Context, keys, req.BifrostRequest.ListModelsRequest) if bifrostError != nil { return nil, bifrostError } response.ListModelsResponse = listModelsResponse case schemas.TextCompletionRequest: if changeType, ok := req.Context.Value(schemas.BifrostContextKeyChangeRequestType).(schemas.RequestType); ok && changeType == schemas.ChatCompletionRequest { chatRequest := req.BifrostRequest.TextCompletionRequest.ToBifrostChatRequest() if chatRequest != nil { chatCompletionResponse, bifrostError := provider.ChatCompletion(req.Context, key, chatRequest) if bifrostError != nil { return nil, bifrostError } response.TextCompletionResponse = chatCompletionResponse.ToBifrostTextCompletionResponse() break } } textCompletionResponse, bifrostError := provider.TextCompletion(req.Context, key, req.BifrostRequest.TextCompletionRequest) if bifrostError != nil { return nil, bifrostError } response.TextCompletionResponse = textCompletionResponse case schemas.ChatCompletionRequest: if changeType, ok := req.Context.Value(schemas.BifrostContextKeyChangeRequestType).(schemas.RequestType); ok && changeType == schemas.ResponsesRequest { responsesRequest := req.BifrostRequest.ChatRequest.ToResponsesRequest() if responsesRequest != nil { responsesResponse, bifrostError := provider.Responses(req.Context, key, responsesRequest) if bifrostError != nil { return nil, bifrostError } response.ChatResponse = responsesResponse.ToBifrostChatResponse() break } } chatCompletionResponse, bifrostError := provider.ChatCompletion(req.Context, key, 
req.BifrostRequest.ChatRequest) if bifrostError != nil { return nil, bifrostError } chatCompletionResponse.BackfillParams(req.BifrostRequest.ChatRequest) response.ChatResponse = chatCompletionResponse case schemas.ResponsesRequest: responsesResponse, bifrostError := provider.Responses(req.Context, key, req.BifrostRequest.ResponsesRequest) if bifrostError != nil { return nil, bifrostError } responsesResponse.BackfillParams(req.BifrostRequest.ResponsesRequest) response.ResponsesResponse = responsesResponse case schemas.CountTokensRequest: countTokensResponse, bifrostError := provider.CountTokens(req.Context, key, req.BifrostRequest.CountTokensRequest) if bifrostError != nil { return nil, bifrostError } response.CountTokensResponse = countTokensResponse case schemas.EmbeddingRequest: embeddingResponse, bifrostError := provider.Embedding(req.Context, key, req.BifrostRequest.EmbeddingRequest) if bifrostError != nil { return nil, bifrostError } response.EmbeddingResponse = embeddingResponse case schemas.RerankRequest: rerankResponse, bifrostError := provider.Rerank(req.Context, key, req.BifrostRequest.RerankRequest) if bifrostError != nil { return nil, bifrostError } response.RerankResponse = rerankResponse case schemas.OCRRequest: var customProviderConfig *schemas.CustomProviderConfig if config != nil { customProviderConfig = config.CustomProviderConfig } if bifrostError := providerUtils.CheckOperationAllowed(provider.GetProviderKey(), customProviderConfig, schemas.OCRRequest); bifrostError != nil { if req.BifrostRequest.OCRRequest != nil { bifrostError.ExtraFields.OriginalModelRequested = req.BifrostRequest.OCRRequest.Model } return nil, bifrostError } ocrResponse, bifrostError := provider.OCR(req.Context, key, req.BifrostRequest.OCRRequest) if bifrostError != nil { return nil, bifrostError } response.OCRResponse = ocrResponse case schemas.SpeechRequest: speechResponse, bifrostError := provider.Speech(req.Context, key, req.BifrostRequest.SpeechRequest) if bifrostError 
!= nil { return nil, bifrostError } speechResponse.BackfillParams(req.BifrostRequest.SpeechRequest) response.SpeechResponse = speechResponse case schemas.TranscriptionRequest: transcriptionResponse, bifrostError := provider.Transcription(req.Context, key, req.BifrostRequest.TranscriptionRequest) if bifrostError != nil { return nil, bifrostError } transcriptionResponse.BackfillParams(req.BifrostRequest.TranscriptionRequest) response.TranscriptionResponse = transcriptionResponse case schemas.ImageGenerationRequest: imageResponse, bifrostError := provider.ImageGeneration(req.Context, key, req.BifrostRequest.ImageGenerationRequest) if bifrostError != nil { return nil, bifrostError } imageResponse.BackfillParams(&req.BifrostRequest) response.ImageGenerationResponse = imageResponse case schemas.ImageEditRequest: imageEditResponse, bifrostError := provider.ImageEdit(req.Context, key, req.BifrostRequest.ImageEditRequest) if bifrostError != nil { return nil, bifrostError } imageEditResponse.BackfillParams(&req.BifrostRequest) response.ImageGenerationResponse = imageEditResponse case schemas.ImageVariationRequest: imageVariationResponse, bifrostError := provider.ImageVariation(req.Context, key, req.BifrostRequest.ImageVariationRequest) if bifrostError != nil { return nil, bifrostError } imageVariationResponse.BackfillParams(&req.BifrostRequest) response.ImageGenerationResponse = imageVariationResponse case schemas.VideoGenerationRequest: videoGenerationResponse, bifrostError := provider.VideoGeneration(req.Context, key, req.BifrostRequest.VideoGenerationRequest) if bifrostError != nil { return nil, bifrostError } videoGenerationResponse.BackfillParams(&req.BifrostRequest) response.VideoGenerationResponse = videoGenerationResponse case schemas.VideoRetrieveRequest: videoRetrieveResponse, bifrostError := provider.VideoRetrieve(req.Context, key, req.BifrostRequest.VideoRetrieveRequest) if bifrostError != nil { return nil, bifrostError } response.VideoGenerationResponse = 
videoRetrieveResponse case schemas.VideoDownloadRequest: videoDownloadResponse, bifrostError := provider.VideoDownload(req.Context, key, req.BifrostRequest.VideoDownloadRequest) if bifrostError != nil { return nil, bifrostError } response.VideoDownloadResponse = videoDownloadResponse case schemas.VideoListRequest: videoListResponse, bifrostError := provider.VideoList(req.Context, key, req.BifrostRequest.VideoListRequest) if bifrostError != nil { return nil, bifrostError } response.VideoListResponse = videoListResponse case schemas.VideoDeleteRequest: videoDeleteResponse, bifrostError := provider.VideoDelete(req.Context, key, req.BifrostRequest.VideoDeleteRequest) if bifrostError != nil { return nil, bifrostError } response.VideoDeleteResponse = videoDeleteResponse case schemas.VideoRemixRequest: videoRemixResponse, bifrostError := provider.VideoRemix(req.Context, key, req.BifrostRequest.VideoRemixRequest) if bifrostError != nil { return nil, bifrostError } response.VideoGenerationResponse = videoRemixResponse case schemas.FileUploadRequest: fileUploadResponse, bifrostError := provider.FileUpload(req.Context, key, req.BifrostRequest.FileUploadRequest) if bifrostError != nil { return nil, bifrostError } response.FileUploadResponse = fileUploadResponse case schemas.FileListRequest: fileListResponse, bifrostError := provider.FileList(req.Context, keys, req.BifrostRequest.FileListRequest) if bifrostError != nil { return nil, bifrostError } response.FileListResponse = fileListResponse case schemas.FileRetrieveRequest: fileRetrieveResponse, bifrostError := provider.FileRetrieve(req.Context, keys, req.BifrostRequest.FileRetrieveRequest) if bifrostError != nil { return nil, bifrostError } response.FileRetrieveResponse = fileRetrieveResponse case schemas.FileDeleteRequest: fileDeleteResponse, bifrostError := provider.FileDelete(req.Context, keys, req.BifrostRequest.FileDeleteRequest) if bifrostError != nil { return nil, bifrostError } response.FileDeleteResponse = 
fileDeleteResponse case schemas.FileContentRequest: fileContentResponse, bifrostError := provider.FileContent(req.Context, keys, req.BifrostRequest.FileContentRequest) if bifrostError != nil { return nil, bifrostError } response.FileContentResponse = fileContentResponse case schemas.BatchCreateRequest: batchCreateResponse, bifrostError := provider.BatchCreate(req.Context, key, req.BifrostRequest.BatchCreateRequest) if bifrostError != nil { return nil, bifrostError } response.BatchCreateResponse = batchCreateResponse case schemas.BatchListRequest: batchListResponse, bifrostError := provider.BatchList(req.Context, keys, req.BifrostRequest.BatchListRequest) if bifrostError != nil { return nil, bifrostError } response.BatchListResponse = batchListResponse case schemas.BatchRetrieveRequest: batchRetrieveResponse, bifrostError := provider.BatchRetrieve(req.Context, keys, req.BifrostRequest.BatchRetrieveRequest) if bifrostError != nil { return nil, bifrostError } response.BatchRetrieveResponse = batchRetrieveResponse case schemas.BatchCancelRequest: batchCancelResponse, bifrostError := provider.BatchCancel(req.Context, keys, req.BifrostRequest.BatchCancelRequest) if bifrostError != nil { return nil, bifrostError } response.BatchCancelResponse = batchCancelResponse case schemas.BatchDeleteRequest: batchDeleteResponse, bifrostError := provider.BatchDelete(req.Context, keys, req.BifrostRequest.BatchDeleteRequest) if bifrostError != nil { return nil, bifrostError } response.BatchDeleteResponse = batchDeleteResponse case schemas.BatchResultsRequest: batchResultsResponse, bifrostError := provider.BatchResults(req.Context, keys, req.BifrostRequest.BatchResultsRequest) if bifrostError != nil { return nil, bifrostError } response.BatchResultsResponse = batchResultsResponse case schemas.ContainerCreateRequest: containerCreateResponse, bifrostError := provider.ContainerCreate(req.Context, key, req.BifrostRequest.ContainerCreateRequest) if bifrostError != nil { return nil, 
bifrostError } response.ContainerCreateResponse = containerCreateResponse case schemas.ContainerListRequest: containerListResponse, bifrostError := provider.ContainerList(req.Context, keys, req.BifrostRequest.ContainerListRequest) if bifrostError != nil { return nil, bifrostError } response.ContainerListResponse = containerListResponse case schemas.ContainerRetrieveRequest: containerRetrieveResponse, bifrostError := provider.ContainerRetrieve(req.Context, keys, req.BifrostRequest.ContainerRetrieveRequest) if bifrostError != nil { return nil, bifrostError } response.ContainerRetrieveResponse = containerRetrieveResponse case schemas.ContainerDeleteRequest: containerDeleteResponse, bifrostError := provider.ContainerDelete(req.Context, keys, req.BifrostRequest.ContainerDeleteRequest) if bifrostError != nil { return nil, bifrostError } response.ContainerDeleteResponse = containerDeleteResponse case schemas.ContainerFileCreateRequest: containerFileCreateResponse, bifrostError := provider.ContainerFileCreate(req.Context, key, req.BifrostRequest.ContainerFileCreateRequest) if bifrostError != nil { return nil, bifrostError } response.ContainerFileCreateResponse = containerFileCreateResponse case schemas.ContainerFileListRequest: containerFileListResponse, bifrostError := provider.ContainerFileList(req.Context, keys, req.BifrostRequest.ContainerFileListRequest) if bifrostError != nil { return nil, bifrostError } response.ContainerFileListResponse = containerFileListResponse case schemas.ContainerFileRetrieveRequest: containerFileRetrieveResponse, bifrostError := provider.ContainerFileRetrieve(req.Context, keys, req.BifrostRequest.ContainerFileRetrieveRequest) if bifrostError != nil { return nil, bifrostError } response.ContainerFileRetrieveResponse = containerFileRetrieveResponse case schemas.ContainerFileContentRequest: containerFileContentResponse, bifrostError := provider.ContainerFileContent(req.Context, keys, req.BifrostRequest.ContainerFileContentRequest) if 
bifrostError != nil { return nil, bifrostError } response.ContainerFileContentResponse = containerFileContentResponse case schemas.ContainerFileDeleteRequest: containerFileDeleteResponse, bifrostError := provider.ContainerFileDelete(req.Context, keys, req.BifrostRequest.ContainerFileDeleteRequest) if bifrostError != nil { return nil, bifrostError } response.ContainerFileDeleteResponse = containerFileDeleteResponse case schemas.PassthroughRequest: passthroughResponse, bifrostError := provider.Passthrough(req.Context, key, req.BifrostRequest.PassthroughRequest) if bifrostError != nil { return nil, bifrostError } response.PassthroughResponse = passthroughResponse default: _, model, _ := req.BifrostRequest.GetRequestFields() return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: fmt.Sprintf("unsupported request type: %s", req.RequestType), }, ExtraFields: schemas.BifrostErrorExtraFields{ RequestType: req.RequestType, Provider: provider.GetProviderKey(), OriginalModelRequested: model, ResolvedModelUsed: model, }, } } return response, nil } // handleProviderStreamRequest handles the stream request to the provider based on the request type func (bifrost *Bifrost) handleProviderStreamRequest(provider schemas.Provider, req *ChannelMessage, key schemas.Key, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context)) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { switch req.RequestType { case schemas.TextCompletionStreamRequest: if changeType, ok := req.Context.Value(schemas.BifrostContextKeyChangeRequestType).(schemas.RequestType); ok && changeType == schemas.ChatCompletionRequest { chatRequest := req.BifrostRequest.TextCompletionRequest.ToBifrostChatRequest() if chatRequest != nil { return provider.ChatCompletionStream(req.Context, wrapConvertedStreamPostHookRunner(postHookRunner, schemas.ChatCompletionRequest), postHookSpanFinalizer, key, chatRequest) } } return 
provider.TextCompletionStream(req.Context, postHookRunner, postHookSpanFinalizer, key, req.BifrostRequest.TextCompletionRequest) case schemas.ChatCompletionStreamRequest: if changeType, ok := req.Context.Value(schemas.BifrostContextKeyChangeRequestType).(schemas.RequestType); ok && changeType == schemas.ResponsesRequest { responsesRequest := req.BifrostRequest.ChatRequest.ToResponsesRequest() if responsesRequest != nil { return provider.ResponsesStream(req.Context, wrapConvertedStreamPostHookRunner(postHookRunner, schemas.ResponsesRequest), postHookSpanFinalizer, key, responsesRequest) } } return provider.ChatCompletionStream(req.Context, postHookRunner, postHookSpanFinalizer, key, req.BifrostRequest.ChatRequest) case schemas.ResponsesStreamRequest: return provider.ResponsesStream(req.Context, postHookRunner, postHookSpanFinalizer, key, req.BifrostRequest.ResponsesRequest) case schemas.SpeechStreamRequest: return provider.SpeechStream(req.Context, postHookRunner, postHookSpanFinalizer, key, req.BifrostRequest.SpeechRequest) case schemas.TranscriptionStreamRequest: return provider.TranscriptionStream(req.Context, postHookRunner, postHookSpanFinalizer, key, req.BifrostRequest.TranscriptionRequest) case schemas.ImageGenerationStreamRequest: return provider.ImageGenerationStream(req.Context, postHookRunner, postHookSpanFinalizer, key, req.BifrostRequest.ImageGenerationRequest) case schemas.ImageEditStreamRequest: return provider.ImageEditStream(req.Context, postHookRunner, postHookSpanFinalizer, key, req.BifrostRequest.ImageEditRequest) case schemas.PassthroughStreamRequest: return provider.PassthroughStream(req.Context, postHookRunner, postHookSpanFinalizer, key, req.BifrostRequest.PassthroughRequest) default: _, model, _ := req.BifrostRequest.GetRequestFields() return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: fmt.Sprintf("unsupported request type: %s", req.RequestType), }, ExtraFields: schemas.BifrostErrorExtraFields{ 
RequestType: req.RequestType, Provider: provider.GetProviderKey(), OriginalModelRequested: model, ResolvedModelUsed: model, }, } } } // handleMCPToolExecution is the common handler for MCP tool execution with plugin pipeline support. // It handles pre-hooks, execution, post-hooks, and error handling for both Chat and Responses formats. // // Parameters: // - ctx: Execution context // - mcpRequest: The MCP request to execute (already populated with tool call) // - requestType: The request type for error reporting (ChatCompletionRequest or ResponsesRequest) // // Returns: // - *schemas.BifrostMCPResponse: The MCP response after all hooks // - *schemas.BifrostError: Any execution error func (bifrost *Bifrost) handleMCPToolExecution(ctx *schemas.BifrostContext, mcpRequest *schemas.BifrostMCPRequest, requestType schemas.RequestType) (*schemas.BifrostMCPResponse, *schemas.BifrostError) { if bifrost.MCPManager == nil { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "mcp is not configured in this bifrost instance", }, ExtraFields: schemas.BifrostErrorExtraFields{ RequestType: requestType, }, } } // Ensure request ID exists for hooks/tracing consistency if _, ok := ctx.Value(schemas.BifrostContextKeyRequestID).(string); !ok { ctx.SetValue(schemas.BifrostContextKeyRequestID, uuid.New().String()) } // Get plugin pipeline for MCP hooks pipeline := bifrost.getPluginPipeline() defer bifrost.releasePluginPipeline(pipeline) // Run pre-hooks preReq, shortCircuit, preCount := pipeline.RunMCPPreHooks(ctx, mcpRequest) // Handle short-circuit cases if shortCircuit != nil { // Handle short-circuit with response (success case) if shortCircuit.Response != nil { finalMcpResp, bifrostErr := pipeline.RunMCPPostHooks(ctx, shortCircuit.Response, nil, preCount) drainAndAttachPluginLogs(ctx) if bifrostErr != nil { return nil, bifrostErr } return finalMcpResp, nil } // Handle short-circuit with error if shortCircuit.Error != nil { // Capture 
post-hook results to respect transformations or recovery finalResp, finalErr := pipeline.RunMCPPostHooks(ctx, nil, shortCircuit.Error, preCount) drainAndAttachPluginLogs(ctx) // Return post-hook error if present (post-hook may have transformed the error) if finalErr != nil { return nil, finalErr } // Return post-hook response if present (post-hook may have recovered from error) if finalResp != nil { return finalResp, nil } // Fall back to original short-circuit error if post-hooks returned nil/nil return nil, shortCircuit.Error } } if preReq == nil { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "MCP request after plugin hooks cannot be nil", }, ExtraFields: schemas.BifrostErrorExtraFields{ RequestType: requestType, }, } } // Execute tool with modified request result, err := bifrost.MCPManager.ExecuteToolCall(ctx, preReq) // Prepare MCP response and error for post-hooks var mcpResp *schemas.BifrostMCPResponse var bifrostErr *schemas.BifrostError if err != nil { bifrostErr = &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: err.Error(), }, ExtraFields: schemas.BifrostErrorExtraFields{ RequestType: requestType, }, } // Preserve MCPUserOAuthRequiredError for downstream detection in agent mode var oauthErr *schemas.MCPUserOAuthRequiredError if errors.As(err, &oauthErr) { bifrostErr.ExtraFields.MCPAuthRequired = oauthErr } } else if result == nil { bifrostErr = &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: "tool execution returned nil result", }, ExtraFields: schemas.BifrostErrorExtraFields{ RequestType: requestType, }, } } else { // Use the MCP response directly mcpResp = result } // Run post-hooks finalResp, finalErr := pipeline.RunMCPPostHooks(ctx, mcpResp, bifrostErr, preCount) drainAndAttachPluginLogs(ctx) if finalErr != nil { return nil, finalErr } return finalResp, nil } // executeMCPToolWithHooks is a wrapper around handleMCPToolExecution that 
matches the signature // expected by the agent's executeToolFunc parameter. It runs MCP plugin hooks before and after // tool execution to enable logging, telemetry, and other plugin functionality. func (bifrost *Bifrost) executeMCPToolWithHooks(ctx *schemas.BifrostContext, request *schemas.BifrostMCPRequest) (*schemas.BifrostMCPResponse, error) { // Defensive check: context must be non-nil to prevent panics in plugin hooks if ctx == nil { return nil, fmt.Errorf("context cannot be nil") } if request == nil { return nil, fmt.Errorf("request cannot be nil") } // Determine request type from the MCP request - explicitly handle all known types var requestType schemas.RequestType switch request.RequestType { case schemas.MCPRequestTypeChatToolCall: requestType = schemas.ChatCompletionRequest case schemas.MCPRequestTypeResponsesToolCall: requestType = schemas.ResponsesRequest default: // Return error for unknown/unsupported request types instead of silently defaulting return nil, fmt.Errorf("unsupported MCP request type: %s", request.RequestType) } resp, bifrostErr := bifrost.handleMCPToolExecution(ctx, request, requestType) if bifrostErr != nil { if bifrostErr.ExtraFields.MCPAuthRequired != nil { return nil, bifrostErr.ExtraFields.MCPAuthRequired } return nil, fmt.Errorf("%s", GetErrorMessage(bifrostErr)) } return resp, nil } // PLUGIN MANAGEMENT // RunLLMPreHooks executes PreHooks in order, tracks how many ran, and returns the final request, any short-circuit decision, and the count. 
func (p *PluginPipeline) RunLLMPreHooks(ctx *schemas.BifrostContext, req *schemas.BifrostRequest) (*schemas.BifrostRequest, *schemas.LLMPluginShortCircuit, int) { // If the skip plugin pipeline flag is set, skip the plugin pipeline if skipPluginPipeline, ok := ctx.Value(schemas.BifrostContextKeySkipPluginPipeline).(bool); ok && skipPluginPipeline { return req, nil, 0 } var shortCircuit *schemas.LLMPluginShortCircuit var err error ctx.BlockRestrictedWrites() defer ctx.UnblockRestrictedWrites() for i, plugin := range p.llmPlugins { pluginName := plugin.GetName() p.logger.Debug("running pre-hook for plugin %s", pluginName) // Start span for this plugin's PreLLMHook spanCtx, handle := p.tracer.StartSpan(ctx, fmt.Sprintf("plugin.%s.prehook", sanitizeSpanName(pluginName)), schemas.SpanKindPlugin) // Update pluginCtx with span context for nested operations if spanCtx != nil { if spanID, ok := spanCtx.Value(schemas.BifrostContextKeySpanID).(string); ok { ctx.SetValue(schemas.BifrostContextKeySpanID, spanID) } } pluginCtx := ctx.WithPluginScope(&pluginName) req, shortCircuit, err = plugin.PreLLMHook(pluginCtx, req) pluginCtx.ReleasePluginScope() // End span with appropriate status if err != nil { p.tracer.SetAttribute(handle, "error", err.Error()) p.tracer.EndSpan(handle, schemas.SpanStatusError, err.Error()) p.preHookErrors = append(p.preHookErrors, err) p.logger.Warn("error in PreLLMHook for plugin %s: %s", pluginName, err.Error()) } else if shortCircuit != nil { p.tracer.SetAttribute(handle, "short_circuit", true) p.tracer.EndSpan(handle, schemas.SpanStatusOk, "short-circuit") } else { p.tracer.EndSpan(handle, schemas.SpanStatusOk, "") } p.executedPreHooks = i + 1 if shortCircuit != nil { return req, shortCircuit, p.executedPreHooks // short-circuit: only plugins up to and including i ran } } return req, nil, p.executedPreHooks } // RunPostLLMHooks executes PostHooks in reverse order for the plugins whose PreLLMHook ran. 
// Accepts the response and error, and allows plugins to transform either (e.g., recover from error, or invalidate a response). // Returns the final response and error after all hooks. If both are set, error takes precedence unless error is nil. // runFrom is the count of plugins whose PreHooks ran; PostHooks will run in reverse from index (runFrom - 1) down to 0 // For streaming requests, it accumulates timing per plugin instead of creating individual spans per chunk. func (p *PluginPipeline) RunPostLLMHooks(ctx *schemas.BifrostContext, resp *schemas.BifrostResponse, bifrostErr *schemas.BifrostError, runFrom int) (*schemas.BifrostResponse, *schemas.BifrostError) { // If the skip plugin pipeline flag is set, skip the plugin pipeline if skipPluginPipeline, ok := ctx.Value(schemas.BifrostContextKeySkipPluginPipeline).(bool); ok && skipPluginPipeline { return resp, bifrostErr } // Defensive: ensure count is within valid bounds if runFrom < 0 { runFrom = 0 } if runFrom > len(p.llmPlugins) { runFrom = len(p.llmPlugins) } requestType, _, _, _ := GetResponseFields(resp, bifrostErr) // Realtime turns carry StreamStartTime for plugin latency/final-chunk context, // but they are finalized as one completed turn, not chunk-by-chunk stream output. 
isStreaming := ctx.Value(schemas.BifrostContextKeyStreamStartTime) != nil && requestType != schemas.RealtimeRequest ctx.BlockRestrictedWrites() defer ctx.UnblockRestrictedWrites() var err error for i := runFrom - 1; i >= 0; i-- { plugin := p.llmPlugins[i] pluginName := plugin.GetName() p.logger.Debug("running post-hook for plugin %s", pluginName) if isStreaming { // For streaming: accumulate timing, don't create individual spans per chunk // Lazily create cached scoped contexts on first chunk (reused across all chunks) if p.streamScopedCtxs == nil { p.streamScopedCtxs = make(map[string]*schemas.BifrostContext, len(p.llmPlugins)) for _, pl := range p.llmPlugins { name := pl.GetName() p.streamScopedCtxs[name] = ctx.WithPluginScope(&name) } } pluginCtx := p.streamScopedCtxs[pluginName] start := time.Now() resp, bifrostErr, err = plugin.PostLLMHook(pluginCtx, resp, bifrostErr) duration := time.Since(start) p.accumulatePluginTiming(pluginName, duration, err != nil) if err != nil { p.postHookErrors = append(p.postHookErrors, err) p.logger.Warn("error in PostLLMHook for plugin %s: %v", pluginName, err) } } else { // For non-streaming: create span per plugin (existing behavior) spanCtx, handle := p.tracer.StartSpan(ctx, fmt.Sprintf("plugin.%s.posthook", sanitizeSpanName(pluginName)), schemas.SpanKindPlugin) // Update pluginCtx with span context for nested operations if spanCtx != nil { if spanID, ok := spanCtx.Value(schemas.BifrostContextKeySpanID).(string); ok { ctx.SetValue(schemas.BifrostContextKeySpanID, spanID) } } pluginCtx := ctx.WithPluginScope(&pluginName) resp, bifrostErr, err = plugin.PostLLMHook(pluginCtx, resp, bifrostErr) pluginCtx.ReleasePluginScope() // End span with appropriate status if err != nil { p.tracer.SetAttribute(handle, "error", err.Error()) p.tracer.EndSpan(handle, schemas.SpanStatusError, err.Error()) p.postHookErrors = append(p.postHookErrors, err) p.logger.Warn("error in PostLLMHook for plugin %s: %v", pluginName, err) } else { 
p.tracer.EndSpan(handle, schemas.SpanStatusOk, "") } } // If a plugin recovers from an error (sets bifrostErr to nil and sets resp), allow that // If a plugin invalidates a response (sets resp to nil and sets bifrostErr), allow that } // Increment chunk count for streaming if isStreaming { p.streamingMu.Lock() p.chunkCount++ p.streamingMu.Unlock() } // Final logic: if both are set, error takes precedence, unless error is nil if bifrostErr != nil { if resp != nil && bifrostErr.StatusCode == nil && bifrostErr.Error != nil && bifrostErr.Error.Type == nil && bifrostErr.Error.Message == "" && bifrostErr.Error.Error == nil { // Defensive: treat as recovery if error is empty return resp, nil } return resp, bifrostErr } return resp, nil } // RunMCPPreHooks executes MCP PreHooks in order for all registered MCP plugins. // Returns the modified request, any short-circuit decision, and the count of hooks that ran. // If a plugin short-circuits, only PostHooks for plugins up to and including that plugin will run. 
func (p *PluginPipeline) RunMCPPreHooks(ctx *schemas.BifrostContext, req *schemas.BifrostMCPRequest) (*schemas.BifrostMCPRequest, *schemas.MCPPluginShortCircuit, int) { // If the skip plugin pipeline flag is set, skip the plugin pipeline if skipPluginPipeline, ok := ctx.Value(schemas.BifrostContextKeySkipPluginPipeline).(bool); ok && skipPluginPipeline { return req, nil, 0 } var shortCircuit *schemas.MCPPluginShortCircuit var err error ctx.BlockRestrictedWrites() defer ctx.UnblockRestrictedWrites() for i, plugin := range p.mcpPlugins { pluginName := plugin.GetName() p.logger.Debug("running MCP pre-hook for plugin %s", pluginName) // Start span for this plugin's PreMCPHook spanCtx, handle := p.tracer.StartSpan(ctx, fmt.Sprintf("plugin.%s.mcp_prehook", sanitizeSpanName(pluginName)), schemas.SpanKindPlugin) // Update pluginCtx with span context for nested operations if spanCtx != nil { if spanID, ok := spanCtx.Value(schemas.BifrostContextKeySpanID).(string); ok { ctx.SetValue(schemas.BifrostContextKeySpanID, spanID) } } pluginCtx := ctx.WithPluginScope(&pluginName) req, shortCircuit, err = plugin.PreMCPHook(pluginCtx, req) pluginCtx.ReleasePluginScope() // End span with appropriate status if err != nil { p.tracer.SetAttribute(handle, "error", err.Error()) p.tracer.EndSpan(handle, schemas.SpanStatusError, err.Error()) p.preHookErrors = append(p.preHookErrors, err) p.logger.Warn("error in PreMCPHook for plugin %s: %s", pluginName, err.Error()) } else if shortCircuit != nil { p.tracer.SetAttribute(handle, "short_circuit", true) p.tracer.EndSpan(handle, schemas.SpanStatusOk, "short-circuit") } else { p.tracer.EndSpan(handle, schemas.SpanStatusOk, "") } p.executedPreHooks = i + 1 if shortCircuit != nil { return req, shortCircuit, p.executedPreHooks // short-circuit: only plugins up to and including i ran } } return req, nil, p.executedPreHooks } // RunMCPPostHooks executes MCP PostHooks in reverse order for the plugins whose PreMCPHook ran. 
// Accepts the MCP response and error, and allows plugins to transform either (e.g., recover from error, or invalidate a response). // Returns the final MCP response and error after all hooks. If both are set, error takes precedence unless error is nil. // runFrom is the count of plugins whose PreHooks ran; PostHooks will run in reverse from index (runFrom - 1) down to 0 func (p *PluginPipeline) RunMCPPostHooks(ctx *schemas.BifrostContext, mcpResp *schemas.BifrostMCPResponse, bifrostErr *schemas.BifrostError, runFrom int) (*schemas.BifrostMCPResponse, *schemas.BifrostError) { // If the skip plugin pipeline flag is set, skip the plugin pipeline if skipPluginPipeline, ok := ctx.Value(schemas.BifrostContextKeySkipPluginPipeline).(bool); ok && skipPluginPipeline { return mcpResp, bifrostErr } // Defensive: ensure count is within valid bounds if runFrom < 0 { runFrom = 0 } if runFrom > len(p.mcpPlugins) { runFrom = len(p.mcpPlugins) } ctx.BlockRestrictedWrites() defer ctx.UnblockRestrictedWrites() var err error for i := runFrom - 1; i >= 0; i-- { plugin := p.mcpPlugins[i] pluginName := plugin.GetName() p.logger.Debug("running MCP post-hook for plugin %s", pluginName) // Create span per plugin spanCtx, handle := p.tracer.StartSpan(ctx, fmt.Sprintf("plugin.%s.mcp_posthook", sanitizeSpanName(pluginName)), schemas.SpanKindPlugin) // Update pluginCtx with span context for nested operations if spanCtx != nil { if spanID, ok := spanCtx.Value(schemas.BifrostContextKeySpanID).(string); ok { ctx.SetValue(schemas.BifrostContextKeySpanID, spanID) } } pluginCtx := ctx.WithPluginScope(&pluginName) mcpResp, bifrostErr, err = plugin.PostMCPHook(pluginCtx, mcpResp, bifrostErr) pluginCtx.ReleasePluginScope() // End span with appropriate status if err != nil { p.tracer.SetAttribute(handle, "error", err.Error()) p.tracer.EndSpan(handle, schemas.SpanStatusError, err.Error()) p.postHookErrors = append(p.postHookErrors, err) p.logger.Warn("error in PostMCPHook for plugin %s: %v", pluginName, 
err) } else { p.tracer.EndSpan(handle, schemas.SpanStatusOk, "") } // If a plugin recovers from an error (sets bifrostErr to nil and sets mcpResp), allow that // If a plugin invalidates a response (sets mcpResp to nil and sets bifrostErr), allow that } // Final logic: if both are set, error takes precedence, unless error is nil if bifrostErr != nil { if mcpResp != nil && bifrostErr.StatusCode == nil && bifrostErr.Error != nil && bifrostErr.Error.Type == nil && bifrostErr.Error.Message == "" && bifrostErr.Error.Error == nil { // Defensive: treat as recovery if error is empty return mcpResp, nil } return mcpResp, bifrostErr } return mcpResp, nil } // resetPluginPipeline resets a PluginPipeline instance for reuse. // IMPORTANT: drainAndAttachPluginLogs must be called on the root BifrostContext // BEFORE this method, because it calls ReleasePluginScope on cached scoped contexts // which nils out their pluginLogs pointer. The drain reads from the shared store // on the root context, so it must happen while the store is still referenced. func (p *PluginPipeline) resetPluginPipeline() { // Drop cross-request references while the object sits in the pool. // getPluginPipeline rebinds all four on acquisition, so nil'ing here // only affects GC hygiene — important when plugins are hot-swapped. p.llmPlugins = nil p.mcpPlugins = nil p.executedPreHooks = 0 clear(p.preHookErrors) p.preHookErrors = p.preHookErrors[:0] clear(p.postHookErrors) p.postHookErrors = p.postHookErrors[:0] // Reset streaming timing accumulation under lock — the provider goroutine's // deferred finalizer may still be iterating these fields when the pipeline // is returned to the pool. logger/tracer are nilled here too so the write // is synchronized with the finalizer's read under the same mutex. 
p.streamingMu.Lock() p.logger = nil p.tracer = nil p.chunkCount = 0 if p.postHookTimings != nil { // clear() drops *pluginTimingAccumulator values (freeing them for GC) // while retaining the map's backing hash table for reuse. clear(p.postHookTimings) } // clear() zeros elements in [0, len) — scrub before [:0] so the backing // array doesn't retain live string references once the slice is truncated. clear(p.postHookPluginOrder) p.postHookPluginOrder = p.postHookPluginOrder[:0] // Release cached scoped contexts for streaming for _, scopedCtx := range p.streamScopedCtxs { scopedCtx.ReleasePluginScope() } p.streamScopedCtxs = nil p.streamingMu.Unlock() } // drainAndAttachPluginLogs drains accumulated plugin logs from the BifrostContext // and attaches them to the trace for later retrieval by observability plugins. func drainAndAttachPluginLogs(ctx *schemas.BifrostContext) { tracer, traceID, err := GetTracerFromContext(ctx) if err != nil || tracer == nil || traceID == "" { return } logs := ctx.DrainPluginLogs() if len(logs) == 0 { return } tracer.AttachPluginLogs(traceID, logs) } // accumulatePluginTiming accumulates timing for a plugin during streaming func (p *PluginPipeline) accumulatePluginTiming(pluginName string, duration time.Duration, hasError bool) { p.streamingMu.Lock() defer p.streamingMu.Unlock() if p.postHookTimings == nil { p.postHookTimings = make(map[string]*pluginTimingAccumulator) } timing, ok := p.postHookTimings[pluginName] if !ok { timing = &pluginTimingAccumulator{} p.postHookTimings[pluginName] = timing // Track order on first occurrence (first chunk) p.postHookPluginOrder = append(p.postHookPluginOrder, pluginName) } timing.totalDuration += duration timing.invocations++ if hasError { timing.errors++ } } // FinalizeStreamingPostHookSpans creates aggregated spans for each plugin after streaming completes. // This should be called once at the end of streaming to create one span per plugin with average timing. 
// Spans are nested to mirror the pre-hook hierarchy (each post-hook is a child of the previous one). func (p *PluginPipeline) FinalizeStreamingPostHookSpans(ctx context.Context) { // Snapshot the accumulators under lock so per-chunk writers in the // provider goroutine can't race with the finalizer. Tracer calls below // run unlocked — we don't want to stall chunk writers on span I/O. type snapshotEntry struct { pluginName string totalDuration time.Duration invocations int errors int } p.streamingMu.Lock() // Capture tracer under the same lock that guards resetPluginPipeline's // writes so the read/write pair on p.tracer is synchronized and the // unlocked tracer calls below use a stable local. tracer := p.tracer if tracer == nil || p.postHookTimings == nil || len(p.postHookPluginOrder) == 0 { p.streamingMu.Unlock() return } snapshot := make([]snapshotEntry, 0, len(p.postHookPluginOrder)) for _, pluginName := range p.postHookPluginOrder { timing, ok := p.postHookTimings[pluginName] if !ok || timing.invocations == 0 { continue } snapshot = append(snapshot, snapshotEntry{ pluginName: pluginName, totalDuration: timing.totalDuration, invocations: timing.invocations, errors: timing.errors, }) } p.streamingMu.Unlock() if len(snapshot) == 0 { return } // Collect handles and timing info to end spans in reverse order type spanInfo struct { handle schemas.SpanHandle hasErrors bool } spans := make([]spanInfo, 0, len(snapshot)) currentCtx := ctx // Start spans in execution order (nested: each is a child of the previous) for _, entry := range snapshot { // Create span as child of the previous span (nested hierarchy) newCtx, handle := tracer.StartSpan(currentCtx, fmt.Sprintf("plugin.%s.posthook", sanitizeSpanName(entry.pluginName)), schemas.SpanKindPlugin) if handle == nil { continue } // Calculate average duration in milliseconds avgMs := float64(entry.totalDuration.Milliseconds()) / float64(entry.invocations) // Set aggregated attributes tracer.SetAttribute(handle, 
schemas.AttrPluginInvocations, entry.invocations) tracer.SetAttribute(handle, schemas.AttrPluginAvgDurationMs, avgMs) tracer.SetAttribute(handle, schemas.AttrPluginTotalDurationMs, entry.totalDuration.Milliseconds()) if entry.errors > 0 { tracer.SetAttribute(handle, schemas.AttrPluginErrorCount, entry.errors) } spans = append(spans, spanInfo{handle: handle, hasErrors: entry.errors > 0}) currentCtx = newCtx } // End spans in reverse order (innermost first, like unwinding a call stack) for i := len(spans) - 1; i >= 0; i-- { if spans[i].hasErrors { tracer.EndSpan(spans[i].handle, schemas.SpanStatusError, "some invocations failed") } else { tracer.EndSpan(spans[i].handle, schemas.SpanStatusOk, "") } } } // GetChunkCount returns the number of chunks processed during streaming func (p *PluginPipeline) GetChunkCount() int { p.streamingMu.Lock() defer p.streamingMu.Unlock() return p.chunkCount } // getPluginPipeline gets a PluginPipeline from the pool and configures it func (bifrost *Bifrost) getPluginPipeline() *PluginPipeline { pipeline := bifrost.pluginPipelinePool.Get().(*PluginPipeline) pipeline.llmPlugins = *bifrost.llmPlugins.Load() pipeline.mcpPlugins = *bifrost.mcpPlugins.Load() pipeline.logger = bifrost.logger pipeline.tracer = bifrost.getTracer() return pipeline } // releasePluginPipeline returns a PluginPipeline to the pool. // Caller must ensure drainAndAttachPluginLogs has already been called on the // associated BifrostContext before calling this method. func (bifrost *Bifrost) releasePluginPipeline(pipeline *PluginPipeline) { pipeline.resetPluginPipeline() bifrost.pluginPipelinePool.Put(pipeline) } // POOL & RESOURCE MANAGEMENT // getChannelMessage gets a ChannelMessage from the pool and configures it with the request. // It also gets response and error channels from their respective pools. 
func (bifrost *Bifrost) getChannelMessage(req schemas.BifrostRequest) *ChannelMessage { // Get channels from pool responseChan := bifrost.responseChannelPool.Get().(chan *schemas.BifrostResponse) errorChan := bifrost.errorChannelPool.Get().(chan schemas.BifrostError) // Clear any previous values to avoid leaking between requests select { case <-responseChan: default: } select { case <-errorChan: default: } // Get message from pool and configure it msg := bifrost.channelMessagePool.Get().(*ChannelMessage) msg.BifrostRequest = req msg.Response = responseChan msg.Err = errorChan // Conditionally allocate ResponseStream for streaming requests only if IsStreamRequestType(req.RequestType) { responseStreamChan := bifrost.responseStreamPool.Get().(chan chan *schemas.BifrostStreamChunk) // Clear any previous values to avoid leaking between requests select { case <-responseStreamChan: default: } msg.ResponseStream = responseStreamChan } return msg } // drainQueueWithErrors drains all buffered messages from pq and sends each a // "provider is shutting down" error. It must be called after all workers for // the queue have exited (i.e. after wg.Wait()) to cover the TOCTOU window: // a producer that passed isClosing() just before signalClosing fired can still // win the `case pq.queue <- msg` branch in tryRequest, landing a message in // the queue after the last worker's drain loop already exited via `default:`. // Without this sweep, those callers block forever on <-msg.Response / <-msg.Err. // // Residual TOCTOU window (known limitation): this sweep runs exactly once via // a non-blocking `select { default: }`. A producer that deposits a message // after the sweep's `default:` branch exits has no worker and no sweep to drain // it — the caller will block until its own context is cancelled. 
Fully closing // this window requires a sender-side reference count (so the last producer can // signal "queue is fully idle"), which is intentionally not implemented because // it would add per-send atomic overhead on the hot path. func (bifrost *Bifrost) drainQueueWithErrors(pq *ProviderQueue) { for { select { case r := <-pq.queue: provKey, mod, _ := r.GetRequestFields() select { case r.Err <- schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{Message: "provider is shutting down"}, ExtraFields: schemas.BifrostErrorExtraFields{ RequestType: r.RequestType, Provider: provKey, OriginalModelRequested: mod, }, }: case <-r.Context.Done(): // No time.After needed: r.Err is a buffered channel of size 1 freshly // allocated per request, so the send always completes immediately unless // the caller already cancelled. ctx.Done() is the only valid escape. } default: return } } } // releaseChannelMessage returns a ChannelMessage and its channels to their respective pools. func (bifrost *Bifrost) releaseChannelMessage(msg *ChannelMessage) { // Put channels back in pools bifrost.responseChannelPool.Put(msg.Response) bifrost.errorChannelPool.Put(msg.Err) // Return ResponseStream to pool if it was used if msg.ResponseStream != nil { // Drain any remaining channels to prevent memory leaks select { case <-msg.ResponseStream: default: } bifrost.responseStreamPool.Put(msg.ResponseStream) } // Release of Bifrost Request is handled in handle methods as they are required for fallbacks // Clear references and return to pool msg.Response = nil msg.ResponseStream = nil msg.Err = nil bifrost.channelMessagePool.Put(msg) } // resetBifrostRequest resets a BifrostRequest instance for reuse func resetBifrostRequest(req *schemas.BifrostRequest) { req.RequestType = "" req.ListModelsRequest = nil req.TextCompletionRequest = nil req.ChatRequest = nil req.ResponsesRequest = nil req.CountTokensRequest = nil req.EmbeddingRequest = nil req.RerankRequest = nil req.OCRRequest = nil 
req.SpeechRequest = nil req.TranscriptionRequest = nil req.ImageGenerationRequest = nil req.ImageEditRequest = nil req.ImageVariationRequest = nil req.VideoGenerationRequest = nil req.VideoRetrieveRequest = nil req.VideoDownloadRequest = nil req.VideoListRequest = nil req.VideoRemixRequest = nil req.VideoDeleteRequest = nil req.FileUploadRequest = nil req.FileListRequest = nil req.FileRetrieveRequest = nil req.FileDeleteRequest = nil req.FileContentRequest = nil req.BatchCreateRequest = nil req.BatchListRequest = nil req.BatchRetrieveRequest = nil req.BatchCancelRequest = nil req.BatchDeleteRequest = nil req.BatchResultsRequest = nil req.ContainerCreateRequest = nil req.ContainerListRequest = nil req.ContainerRetrieveRequest = nil req.ContainerDeleteRequest = nil req.ContainerFileCreateRequest = nil req.ContainerFileListRequest = nil req.ContainerFileRetrieveRequest = nil req.ContainerFileContentRequest = nil req.ContainerFileDeleteRequest = nil req.PassthroughRequest = nil } // getBifrostRequest gets a BifrostRequest from the pool func (bifrost *Bifrost) getBifrostRequest() *schemas.BifrostRequest { req := bifrost.bifrostRequestPool.Get().(*schemas.BifrostRequest) return req } // releaseBifrostRequest returns a BifrostRequest to the pool func (bifrost *Bifrost) releaseBifrostRequest(req *schemas.BifrostRequest) { resetBifrostRequest(req) bifrost.bifrostRequestPool.Put(req) } // resetMCPRequest resets a BifrostMCPRequest instance for reuse func resetMCPRequest(req *schemas.BifrostMCPRequest) { req.RequestType = "" req.ChatAssistantMessageToolCall = nil req.ResponsesToolMessage = nil } // getMCPRequest gets a BifrostMCPRequest from the pool func (bifrost *Bifrost) getMCPRequest() *schemas.BifrostMCPRequest { req := bifrost.mcpRequestPool.Get().(*schemas.BifrostMCPRequest) return req } // releaseMCPRequest returns a BifrostMCPRequest to the pool func (bifrost *Bifrost) releaseMCPRequest(req *schemas.BifrostMCPRequest) { resetMCPRequest(req) 
bifrost.mcpRequestPool.Put(req) } // getAllSupportedKeys retrieves all valid keys for a ListModels request. // allowing the provider to aggregate results from multiple keys. func (bifrost *Bifrost) getAllSupportedKeys(ctx *schemas.BifrostContext, providerKey schemas.ModelProvider, baseProviderType schemas.ModelProvider) ([]schemas.Key, error) { // Check if key has been set in the context explicitly if ctx != nil { key, ok := ctx.Value(schemas.BifrostContextKeyDirectKey).(schemas.Key) if ok { if err := validateKey(baseProviderType, &key); err != nil { return nil, fmt.Errorf("invalid direct key for provider %v: %w", baseProviderType, err) } // If a direct key is specified, return it as a single-element slice return []schemas.Key{key}, nil } } keys, err := bifrost.account.GetKeysForProvider(ctx, providerKey) if err != nil { return nil, err } if len(keys) == 0 { return nil, fmt.Errorf("no keys found for provider: %v", providerKey) } // Filter keys for ListModels - only check if key has a value var supportedKeys []schemas.Key for _, key := range keys { // Skip disabled keys (default enabled when nil) if key.Enabled != nil && !*key.Enabled { continue } if err := validateKey(baseProviderType, &key); err != nil { bifrost.logger.Warn("error validating key %s (%s) for provider %s: %s, skipping key", key.Name, key.ID, providerKey, err.Error()) continue } if strings.TrimSpace(key.Value.GetValue()) != "" || CanProviderKeyValueBeEmpty(baseProviderType) { supportedKeys = append(supportedKeys, key) } } bifrost.logger.Debug("[Bifrost] Provider %s: %d valid keys found", providerKey, len(supportedKeys)) if len(supportedKeys) == 0 { return nil, fmt.Errorf("no valid keys found for provider: %v", providerKey) } return supportedKeys, nil } // getKeysForBatchAndFileOps retrieves keys for batch and file operations with model filtering. // For batch operations, only keys with UseForBatchAPI enabled are included. 
// Model filtering: if model is specified and key has model restrictions, only include if model is in list. func (bifrost *Bifrost) getKeysForBatchAndFileOps(ctx *schemas.BifrostContext, providerKey schemas.ModelProvider, baseProviderType schemas.ModelProvider, model *string, isBatchOp bool) ([]schemas.Key, error) { // Check if key has been set in the context explicitly if ctx != nil { key, ok := ctx.Value(schemas.BifrostContextKeyDirectKey).(schemas.Key) if ok { if err := validateKey(baseProviderType, &key); err != nil { return nil, fmt.Errorf("invalid direct key for provider %v: %w", baseProviderType, err) } // If a direct key is specified, return it as a single-element slice return []schemas.Key{key}, nil } } keys, err := bifrost.account.GetKeysForProvider(ctx, providerKey) if err != nil { return nil, err } if len(keys) == 0 { return nil, fmt.Errorf("no keys found for provider: %v", providerKey) } var filteredKeys []schemas.Key for _, k := range keys { // Skip disabled keys if k.Enabled != nil && !*k.Enabled { continue } // For batch operations, only include keys with UseForBatchAPI enabled if isBatchOp && (k.UseForBatchAPI == nil || !*k.UseForBatchAPI) { continue } if err := validateKey(baseProviderType, &k); err != nil { bifrost.logger.Warn("error validating key %s (%s) for provider %s: %s, skipping key", k.Name, k.ID, providerKey, err.Error()) continue } // Model filtering logic: // - If model is nil or empty → include all keys (no model filter) // - If model is specified: // - If model is in key.BlacklistedModels → exclude (wins over Models allow list) // - If key.Models is ["*"] → include key (supports all non-blacklisted models) // - If key.Models is empty → exclude key (deny-by-default) // - If key.Models is non-empty → only include if model is in list // Blacklist wins over allowlist if model != nil && *model != "" { if k.BlacklistedModels.IsBlocked(*model) || !k.Models.IsAllowed(*model) { continue } } // Check key value (or if provider allows empty keys 
or has Azure Entra ID credentials) if strings.TrimSpace(k.Value.GetValue()) != "" || CanProviderKeyValueBeEmpty(baseProviderType) { filteredKeys = append(filteredKeys, k) } } if len(filteredKeys) == 0 { modelStr := "" if model != nil { modelStr = *model } if isBatchOp { return nil, fmt.Errorf("no batch-enabled keys found for provider: %v and model: %s", providerKey, modelStr) } return nil, fmt.Errorf("no keys found for provider: %v and model: %s", providerKey, modelStr) } // Sort keys by ID for deterministic pagination order across requests sort.Slice(filteredKeys, func(i, j int) bool { return filteredKeys[i].ID < filteredKeys[j].ID }) return filteredKeys, nil } // selectKeyFromProviderForModelWithPool returns the filtered pool of eligible keys for the given // provider/model, along with a canRotate flag indicating whether key rotation across retries is // permitted. Key selection (choosing which key to use) is deferred to executeRequestWithRetries // via the keyProvider closure built by the caller. // // canRotate=false is returned for cases where the caller must always use the same key: // - DirectKey (caller-supplied key bypasses all selection) // - SkipKeySelection (provider allows keyless requests; empty slice returned) // - Explicit BifrostContextKeyAPIKeyID / APIKeyName (user pinned a specific key) // - Session stickiness (key persisted in KV store for the session lifetime) // - Single-key pool (only one eligible key — rotation is a no-op, KV write skipped) // // canRotate=true is returned when there are two or more eligible keys and no pinning // or stickiness constraint is in effect. func (bifrost *Bifrost) selectKeyFromProviderForModelWithPool(ctx *schemas.BifrostContext, requestType schemas.RequestType, providerKey schemas.ModelProvider, model string, baseProviderType schemas.ModelProvider) ([]schemas.Key, bool, error) { // DirectKey: caller supplied a key directly — no pool, no rotation. 
if ctx != nil { if key, ok := ctx.Value(schemas.BifrostContextKeyDirectKey).(schemas.Key); ok { if err := validateKey(baseProviderType, &key); err != nil { return nil, false, fmt.Errorf("invalid direct key for provider %v: %w", baseProviderType, err) } return []schemas.Key{key}, false, nil } } // SkipKeySelection: provider allows keyless requests — return empty pool, no rotation. if skipKeySelection, ok := ctx.Value(schemas.BifrostContextKeySkipKeySelection).(bool); ok && skipKeySelection && isKeySkippingAllowed(providerKey) { return []schemas.Key{}, false, nil } // Get keys for provider keys, err := bifrost.account.GetKeysForProvider(ctx, providerKey) if err != nil { return nil, false, err } if len(keys) == 0 { return nil, false, fmt.Errorf("no keys found for provider: %v and model: %s", providerKey, model) } // For batch API operations, filter keys to only include those with UseForBatchAPI enabled if isBatchRequestType(requestType) || isFileRequestType(requestType) { var batchEnabledKeys []schemas.Key for _, k := range keys { if k.UseForBatchAPI != nil && *k.UseForBatchAPI { batchEnabledKeys = append(batchEnabledKeys, k) } } if len(batchEnabledKeys) == 0 { return nil, false, fmt.Errorf("no config found for batch apis; enable 'Use for Batch APIs' on at least one key for provider: %v", providerKey) } keys = batchEnabledKeys } // Filter out keys that don't support the model: blacklisted_models wins over models allow list; // if the key has no models list, it supports all models except those blacklisted. 
var supportedKeys []schemas.Key // Skip model check conditions // We can improve these conditions in the future skipModelCheck := (model == "" && (isFileRequestType(requestType) || isBatchRequestType(requestType) || isContainerRequestType(requestType) || isModellessVideoRequestType(requestType) || isPassthroughRequestType(requestType))) || requestType == schemas.ListModelsRequest if skipModelCheck { // When skipping model check: just verify keys are enabled and have values for _, key := range keys { // Skip disabled keys if key.Enabled != nil && !*key.Enabled { continue } if err := validateKey(baseProviderType, &key); err != nil { bifrost.logger.Warn("error validating key %s (%s) for provider %s: %s, skipping key", key.Name, key.ID, providerKey, err.Error()) continue } if strings.TrimSpace(key.Value.GetValue()) != "" || CanProviderKeyValueBeEmpty(baseProviderType) { supportedKeys = append(supportedKeys, key) } } } else { // When NOT skipping model check: do full model filtering for _, key := range keys { // Skip disabled keys if key.Enabled != nil && !*key.Enabled { continue } if err := validateKey(baseProviderType, &key); err != nil { bifrost.logger.Warn("error validating key %s (%s) for provider %s: %s, skipping key", key.Name, key.ID, providerKey, err.Error()) continue } hasValue := strings.TrimSpace(key.Value.GetValue()) != "" || CanProviderKeyValueBeEmpty(baseProviderType) // ["*"] = allow all models; [] = deny all; specific list = allow only listed // NOTE: Model filtering uses the original requested model (which may be an alias). // key.Models and key.BlacklistedModels must therefore be expressed in alias keys. // The provider-specific identifier is resolved later in the handler closure via key.Aliases.Resolve(model). 
modelSupported := hasValue && key.Models.IsAllowed(model) && !key.BlacklistedModels.IsBlocked(model) if baseProviderType == schemas.VLLM && key.VLLMKeyConfig != nil { if key.VLLMKeyConfig.ModelName != "" { modelSupported = modelSupported && (key.VLLMKeyConfig.ModelName == model) } } if modelSupported { supportedKeys = append(supportedKeys, key) } } } if len(supportedKeys) == 0 { return nil, false, fmt.Errorf("no keys found that support model: %s", model) } // Explicit key ID takes priority over key name — pin to that key, no rotation. if ctx != nil { if keyID, ok := ctx.Value(schemas.BifrostContextKeyAPIKeyID).(string); ok { if keyID = strings.TrimSpace(keyID); keyID != "" { for _, key := range supportedKeys { if key.ID == keyID { return []schemas.Key{key}, false, nil } } return nil, false, fmt.Errorf("no supported key found with id %q for provider: %v and model: %s", keyID, providerKey, model) } } if keyName, ok := ctx.Value(schemas.BifrostContextKeyAPIKeyName).(string); ok { if keyName = strings.TrimSpace(keyName); keyName != "" { for _, key := range supportedKeys { if key.Name == keyName { return []schemas.Key{key}, false, nil } } return nil, false, fmt.Errorf("no supported key found with name %q for provider: %v and model: %s", keyName, providerKey, model) } } } // Single key: no rotation possible, skip session stickiness (no KV write needed). if len(supportedKeys) == 1 { return []schemas.Key{supportedKeys[0]}, false, nil } // Session stickiness: on the first request for a session ID, the randomly selected key is // persisted in the KV store. Subsequent requests reuse it for the session lifetime. The sticky // key is intentionally kept fixed across all retry attempts — return it as a single-element // pool with canRotate=false so rate-limit retries also stay on the same key. 
sessionID := "" if ctx != nil { if id, ok := ctx.Value(schemas.BifrostContextKeySessionID).(string); ok && id != "" { sessionID = id } } fallbackIndex := 0 if ctx != nil { fallbackIndex, _ = ctx.Value(schemas.BifrostContextKeyFallbackIndex).(int) } stickinessActive := sessionID != "" && bifrost.kvStore != nil && fallbackIndex == 0 if stickinessActive { kvKey := buildSessionKey(providerKey, sessionID, model) ttl, _ := ctx.Value(schemas.BifrostContextKeySessionTTL).(time.Duration) if ttl <= 0 { ttl = schemas.DefaultSessionStickyTTL } if cachedKey, found, stale := getCachedKeyFromStore(bifrost.kvStore, kvKey, supportedKeys); found { if err := bifrost.kvStore.SetWithTTL(kvKey, cachedKey.ID, ttl); err != nil { bifrost.logger.Warn("error setting session cache for provider=%s key_id=%s: %s", providerKey, cachedKey.ID, err.Error()) } return []schemas.Key{cachedKey}, false, nil } else if stale { if _, err := bifrost.kvStore.Delete(kvKey); err != nil { bifrost.logger.Warn("error deleting stale session cache for provider=%s: %s", providerKey, err.Error()) } } selectedKey, err := bifrost.keySelector(ctx, supportedKeys, providerKey, model) if err != nil { return nil, false, err } wasSet, err := bifrost.kvStore.SetNXWithTTL(kvKey, selectedKey.ID, ttl) if err != nil { bifrost.logger.Warn("error setting session cache for provider=%s key_id=%s: %s", providerKey, selectedKey.ID, err.Error()) return []schemas.Key{selectedKey}, false, nil } if wasSet { return []schemas.Key{selectedKey}, false, nil } // Another concurrent request won the race — re-read the persisted key. 
if currentKey, found, stale := getCachedKeyFromStore(bifrost.kvStore, kvKey, supportedKeys); found { return []schemas.Key{currentKey}, false, nil } else if stale { if _, err := bifrost.kvStore.Delete(kvKey); err != nil { bifrost.logger.Warn("error deleting stale session cache for provider=%s: %s", providerKey, err.Error()) } return []schemas.Key{selectedKey}, false, nil } return []schemas.Key{selectedKey}, false, nil } // Normal case: return the full filtered pool with rotation enabled. return supportedKeys, true, nil } // getCachedKeyFromStore retrieves a key ID from the KV store and looks it up in supportedKeys. // Returns the matching Key, found (true if key exists in supportedKeys), and stale (true if // KV contains an ID but it is not in supportedKeys—caller should delete before SetNXWithTTL). func getCachedKeyFromStore(kvStore schemas.KVStore, kvKey string, supportedKeys []schemas.Key) (schemas.Key, bool, bool) { raw, err := kvStore.Get(kvKey) if err != nil { return schemas.Key{}, false, false } var cachedKeyID string switch v := raw.(type) { case string: cachedKeyID = v case []byte: var s string if err := sonic.Unmarshal(v, &s); err == nil { cachedKeyID = s } else { cachedKeyID = string(v) } } if cachedKeyID != "" { for _, k := range supportedKeys { if k.ID == cachedKeyID { return k, true, false } } return schemas.Key{}, false, true } return schemas.Key{}, false, false } // Shutdown gracefully stops all workers when triggered. // It closes all request channels and waits for workers to exit. func (bifrost *Bifrost) Shutdown() { bifrost.logger.Info("closing all request channels...") // Cancel the context if not already done if bifrost.ctx.Err() == nil && bifrost.cancel != nil { bifrost.cancel() } // Signal all provider queues to close. Workers exit via pq.done; // we never close pq.queue to avoid "send on closed channel" panics in // producers that are concurrently in tryRequest. 
bifrost.requestQueues.Range(func(key, value interface{}) bool { pq := value.(*ProviderQueue) pq.signalClosing() return true }) // Wait for all workers to exit bifrost.waitGroups.Range(func(key, value interface{}) bool { waitGroup := value.(*sync.WaitGroup) waitGroup.Wait() return true }) // Final drain sweep — same reasoning as RemoveProvider's Step 3b. bifrost.requestQueues.Range(func(key, value interface{}) bool { bifrost.drainQueueWithErrors(value.(*ProviderQueue)) return true }) // Cleanup MCP manager if bifrost.MCPManager != nil { err := bifrost.MCPManager.Cleanup() if err != nil { bifrost.logger.Warn("Error cleaning up MCP manager: %s", err.Error()) } } // Stop the tracerWrapper to clean up background goroutines if tracerWrapper := bifrost.tracer.Load().(*tracerWrapper); tracerWrapper != nil && tracerWrapper.tracer != nil { tracerWrapper.tracer.Stop() } // Cleanup plugins if llmPlugins := bifrost.llmPlugins.Load(); llmPlugins != nil { for _, plugin := range *llmPlugins { err := plugin.Cleanup() if err != nil { bifrost.logger.Warn(fmt.Sprintf("Error cleaning up LLM plugin: %s", err.Error())) } } } if mcpPlugins := bifrost.mcpPlugins.Load(); mcpPlugins != nil { for _, plugin := range *mcpPlugins { err := plugin.Cleanup() if err != nil { bifrost.logger.Warn(fmt.Sprintf("Error cleaning up MCP plugin: %s", err.Error())) } } } bifrost.logger.Info("all request channels closed") }