Files
bifrost/core/bifrost.go
Beyhan Oğur 880f412e2c first commit
2026-04-26 21:52:23 +03:00

7365 lines
265 KiB
Go

// Package bifrost provides the core implementation of the Bifrost system.
// Bifrost is a unified interface for interacting with various AI model providers,
// managing concurrent requests, and handling provider-specific configurations.
package bifrost
import (
"context"
"errors"
"fmt"
"slices"
"sort"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/bytedance/sonic"
"github.com/google/uuid"
"github.com/maximhq/bifrost/core/keyselectors"
"github.com/maximhq/bifrost/core/mcp"
"github.com/maximhq/bifrost/core/mcp/codemode/starlark"
"github.com/maximhq/bifrost/core/providers/anthropic"
"github.com/maximhq/bifrost/core/providers/azure"
"github.com/maximhq/bifrost/core/providers/bedrock"
"github.com/maximhq/bifrost/core/providers/cerebras"
"github.com/maximhq/bifrost/core/providers/cohere"
"github.com/maximhq/bifrost/core/providers/elevenlabs"
"github.com/maximhq/bifrost/core/providers/fireworks"
"github.com/maximhq/bifrost/core/providers/gemini"
"github.com/maximhq/bifrost/core/providers/groq"
"github.com/maximhq/bifrost/core/providers/huggingface"
"github.com/maximhq/bifrost/core/providers/mistral"
"github.com/maximhq/bifrost/core/providers/nebius"
"github.com/maximhq/bifrost/core/providers/ollama"
"github.com/maximhq/bifrost/core/providers/openai"
"github.com/maximhq/bifrost/core/providers/openrouter"
"github.com/maximhq/bifrost/core/providers/parasail"
"github.com/maximhq/bifrost/core/providers/perplexity"
"github.com/maximhq/bifrost/core/providers/replicate"
"github.com/maximhq/bifrost/core/providers/runway"
"github.com/maximhq/bifrost/core/providers/sgl"
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
"github.com/maximhq/bifrost/core/providers/vertex"
"github.com/maximhq/bifrost/core/providers/vllm"
"github.com/maximhq/bifrost/core/providers/xai"
schemas "github.com/maximhq/bifrost/core/schemas"
"github.com/valyala/fasthttp"
)
// ChannelMessage represents a message passed through the request channel.
// It embeds the request (including its request type) and carries the request
// context plus the response, stream and error channels that travel with it.
type ChannelMessage struct {
schemas.BifrostRequest
// Context is the Bifrost context associated with this request.
Context *schemas.BifrostContext
// Response carries a single *schemas.BifrostResponse.
Response chan *schemas.BifrostResponse
// ResponseStream carries a channel of stream chunks
// (presumably used only by streaming requests — confirm against the workers).
ResponseStream chan chan *schemas.BifrostStreamChunk
// Err carries a schemas.BifrostError by value.
Err chan schemas.BifrostError
}
// Bifrost manages providers and maintains specified open channels for concurrent processing.
// It handles request routing, provider management, and response processing.
//
// Instances are created by Init, which fills in the context, pools, plugin
// lists and (optionally) the MCP manager.
type Bifrost struct {
ctx *schemas.BifrostContext // root context derived in Init from the caller-supplied context
cancel context.CancelFunc // cancel function returned together with ctx
account schemas.Account // account interface
llmPlugins atomic.Pointer[[]schemas.LLMPlugin] // list of llm plugins
mcpPlugins atomic.Pointer[[]schemas.MCPPlugin] // list of mcp plugins
providers atomic.Pointer[[]schemas.Provider] // list of providers
requestQueues sync.Map // provider request queues (thread-safe), stores *ProviderQueue
waitGroups sync.Map // wait groups for each provider (thread-safe)
providerMutexes sync.Map // mutexes for each provider to prevent concurrent updates (thread-safe)
channelMessagePool sync.Pool // Pool for ChannelMessage objects, initial pool size is set in Init
responseChannelPool sync.Pool // Pool for response channels, initial pool size is set in Init
errorChannelPool sync.Pool // Pool for error channels, initial pool size is set in Init
responseStreamPool sync.Pool // Pool for response stream channels, initial pool size is set in Init
pluginPipelinePool sync.Pool // Pool for PluginPipeline objects
bifrostRequestPool sync.Pool // Pool for BifrostRequest objects
mcpRequestPool sync.Pool // Pool for BifrostMCPRequest objects
oauth2Provider schemas.OAuth2Provider // OAuth provider instance
logger schemas.Logger // logger instance, default logger is used if not provided
tracer atomic.Value // tracer for distributed tracing; always holds a *tracerWrapper (wrapping schemas.DefaultTracer() if none is configured)
MCPManager mcp.MCPManagerInterface // MCP integration manager (nil if MCP not configured)
mcpInitOnce sync.Once // Ensures MCP manager is initialized only once
dropExcessRequests atomic.Bool // If true, in cases where the queue is full, requests will not wait for the queue to be empty and will be dropped instead.
keySelector schemas.KeySelector // key selector function; Init falls back to keyselectors.WeightedRandom when none is configured
kvStore schemas.KVStore // optional KV store for session stickiness (nil = disabled)
}
// ProviderQueue wraps a provider's request channel with lifecycle management
// to prevent "send on closed channel" panics during provider removal/update.
// Producers must check the closing flag or select on the done channel before sending.
//
// Why pq.queue is NEVER closed:
//
// Closing a channel in Go causes any concurrent send to that channel to panic
// ("send on closed channel"). There is always a TOCTOU window between a
// producer's isClosing() check and its select { case pq.queue <- msg: ... }:
// the producer could pass isClosing() while the queue is open, get preempted,
// and resume only after the queue is closed. Go's selectgo evaluates select
// cases in a random order, so even having case <-pq.done: in the same select
// does not protect against this — if selectgo evaluates the send case first on
// a closed channel it panics immediately via goto sclose, before reaching done.
//
// To close pq.queue safely you would need a sender-side WaitGroup so that
// signalClosing could wait for every in-flight producer to finish. That adds
// non-trivial overhead on the hot request path.
//
// Instead, pq.done is the sole shutdown signal. Receiving from a closed channel
// is always safe (returns the zero value immediately), so:
//   - Workers exit via case <-pq.done: — safe
//   - Producers bail via case <-pq.done: — safe
//   - drainQueueWithErrors handles any messages that slip through the TOCTOU window
//
// pq.queue is garbage collected automatically:
//   - RemoveProvider calls requestQueues.Delete, dropping the map's reference.
//   - UpdateProvider calls requestQueues.Store with a new queue, dropping the
//     map's reference to oldPq.
//   - Shutdown does not Delete at all — the whole Bifrost instance is torn down.
//
// In all cases, once no producer goroutine holds a reference to the
// ProviderQueue, both the struct and pq.queue are eligible for GC.
// No explicit close is needed.
type ProviderQueue struct {
queue chan *ChannelMessage // the actual request queue channel — never closed, see above
done chan struct{} // closed by signalClosing() to signal shutdown; never written to otherwise
closing uint32 // atomic: 0 = open, 1 = closing
signalOnce sync.Once // makes signalClosing run its flag store and close(done) at most once
}
// isLargePayloadPassthrough reports whether ctx marks the request as a large
// payload passthrough: the large-payload flag must be true and a payload
// reader must be stored in the context. A nil ctx yields false.
func isLargePayloadPassthrough(ctx *schemas.BifrostContext) bool {
	if ctx == nil {
		return false
	}
	// Large payload mode deliberately skips turning the JSON body into a Bifrost
	// input (e.g. a 400MB multipart/audio upload arrives with Input=nil), so
	// callers use this check to avoid rejecting such requests for a nil input.
	if enabled, _ := ctx.Value(schemas.BifrostContextKeyLargePayloadMode).(bool); !enabled {
		return false
	}
	// The middleware sets flag and reader together; still insist on the reader.
	return ctx.Value(schemas.BifrostContextKeyLargePayloadReader) != nil
}
// signalClosing marks the provider queue as closing and closes pq.done.
// It takes no lock: the flag is written atomically and sync.Once guarantees
// the done channel is closed a single time, however often this is called.
func (pq *ProviderQueue) signalClosing() {
	markClosed := func() {
		atomic.StoreUint32(&pq.closing, 1)
		close(pq.done)
	}
	pq.signalOnce.Do(markClosed)
}
// isClosing reports whether signalClosing has already flagged this queue.
// The flag is read with an atomic load, so no lock is required.
func (pq *ProviderQueue) isClosing() bool {
	state := atomic.LoadUint32(&pq.closing)
	return state == 1
}
// PluginPipeline encapsulates the execution of plugin PreHooks and PostHooks,
// tracks how many plugins ran, and manages short-circuiting and error aggregation.
// Instances are recycled through Bifrost.pluginPipelinePool.
type PluginPipeline struct {
// LLM and MCP plugins this pipeline works with.
llmPlugins []schemas.LLMPlugin
mcpPlugins []schemas.MCPPlugin
// Logger and tracer available to the pipeline.
logger schemas.Logger
tracer schemas.Tracer
// Number of PreHooks that were executed (used to determine which PostHooks to run in reverse order)
executedPreHooks int
// Errors from PreHooks and PostHooks
preHookErrors []error
postHookErrors []error
// streamingMu guards the streaming post-hook accumulators below. Per-chunk
// writes (accumulatePluginTiming) run in the provider goroutine while the
// end-of-stream finalizer (FinalizeStreamingPostHookSpans) and
// resetPluginPipeline can run in a different goroutine, so unsynchronised
// access triggers "concurrent map read and map write" panics.
streamingMu sync.Mutex
postHookTimings map[string]*pluginTimingAccumulator // keyed by plugin name
postHookPluginOrder []string // order in which post-hooks ran (for nested span creation)
chunkCount int // presumably the number of stream chunks seen so far — confirm against accumulatePluginTiming
// Plugin logging: cached scoped contexts for streaming post-hooks (reused across chunks)
streamScopedCtxs map[string]*schemas.BifrostContext
}
// pluginTimingAccumulator accumulates timing information for a plugin across streaming chunks.
// PluginPipeline.postHookTimings holds one accumulator per plugin name.
type pluginTimingAccumulator struct {
totalDuration time.Duration // summed duration across invocations
invocations int // number of invocations recorded
errors int // number of errors recorded
}
// tracerWrapper wraps a Tracer to ensure atomic.Value stores consistent types.
// This is necessary because atomic.Value.Store() panics if called with values
// of different concrete types, even if they implement the same interface.
// Bifrost.tracer therefore always holds a *tracerWrapper.
type tracerWrapper struct {
tracer schemas.Tracer // the wrapped tracer implementation
}
// INITIALIZATION
// Init initializes a new Bifrost instance with the given configuration.
// It sets up the account, plugins, object pools, the optional MCP manager, and
// prepares every provider reported by the account.
//
// ctx is the parent of the instance's root context. config.Account is
// required; a nil Logger, Tracer or KeySelector is replaced by a default.
//
// An error is returned when the account is missing or the configured
// providers cannot be listed. A provider that fails to initialize is logged
// and skipped; it does not fail Init.
// Initial memory allocations happen here according to config.InitialPoolSize.
func Init(ctx context.Context, config schemas.BifrostConfig) (*Bifrost, error) {
	if config.Account == nil {
		return nil, fmt.Errorf("account is required to initialize Bifrost")
	}
	if config.Logger == nil {
		config.Logger = NewDefaultLogger(schemas.LogLevelInfo)
	}
	providerUtils.SetLogger(config.Logger)

	// Initialize tracer (use the default tracer if not provided)
	tracer := config.Tracer
	if tracer == nil {
		tracer = schemas.DefaultTracer()
	}

	bifrostCtx, cancel := schemas.NewBifrostContextWithCancel(ctx)
	bifrost := &Bifrost{
		ctx:            bifrostCtx,
		cancel:         cancel,
		account:        config.Account,
		llmPlugins:     atomic.Pointer[[]schemas.LLMPlugin]{},
		mcpPlugins:     atomic.Pointer[[]schemas.MCPPlugin]{},
		requestQueues:  sync.Map{},
		waitGroups:     sync.Map{},
		keySelector:    config.KeySelector,
		oauth2Provider: config.OAuth2Provider,
		logger:         config.Logger,
		kvStore:        config.KVStore,
	}
	// Always store the wrapper type so later Store calls never change the
	// concrete type held by the atomic.Value.
	bifrost.tracer.Store(&tracerWrapper{tracer: tracer})

	if config.LLMPlugins == nil {
		config.LLMPlugins = make([]schemas.LLMPlugin, 0)
	}
	if config.MCPPlugins == nil {
		config.MCPPlugins = make([]schemas.MCPPlugin, 0)
	}
	bifrost.llmPlugins.Store(&config.LLMPlugins)
	bifrost.mcpPlugins.Store(&config.MCPPlugins)

	// Initialize providers slice
	bifrost.providers.Store(&[]schemas.Provider{})
	bifrost.dropExcessRequests.Store(config.DropExcessRequests)

	if bifrost.keySelector == nil {
		bifrost.keySelector = keyselectors.WeightedRandom
	}

	// Initialize object pools
	bifrost.channelMessagePool = sync.Pool{
		New: func() any {
			return &ChannelMessage{}
		},
	}
	bifrost.responseChannelPool = sync.Pool{
		New: func() any {
			return make(chan *schemas.BifrostResponse, 1)
		},
	}
	bifrost.errorChannelPool = sync.Pool{
		New: func() any {
			return make(chan schemas.BifrostError, 1)
		},
	}
	bifrost.responseStreamPool = sync.Pool{
		New: func() any {
			return make(chan chan *schemas.BifrostStreamChunk, 1)
		},
	}
	bifrost.pluginPipelinePool = sync.Pool{
		New: func() any {
			return &PluginPipeline{
				preHookErrors:  make([]error, 0),
				postHookErrors: make([]error, 0),
			}
		},
	}
	bifrost.bifrostRequestPool = sync.Pool{
		New: func() any {
			return &schemas.BifrostRequest{}
		},
	}
	bifrost.mcpRequestPool = sync.Pool{
		New: func() any {
			return &schemas.BifrostMCPRequest{}
		},
	}

	// Prewarm pools with multiple objects
	for range config.InitialPoolSize {
		// Create and put new objects directly into pools
		bifrost.channelMessagePool.Put(&ChannelMessage{})
		bifrost.responseChannelPool.Put(make(chan *schemas.BifrostResponse, 1))
		bifrost.errorChannelPool.Put(make(chan schemas.BifrostError, 1))
		bifrost.responseStreamPool.Put(make(chan chan *schemas.BifrostStreamChunk, 1))
		bifrost.pluginPipelinePool.Put(&PluginPipeline{
			preHookErrors:  make([]error, 0),
			postHookErrors: make([]error, 0),
		})
		bifrost.bifrostRequestPool.Put(&schemas.BifrostRequest{})
		bifrost.mcpRequestPool.Put(&schemas.BifrostMCPRequest{})
	}

	providerKeys, err := bifrost.account.GetConfiguredProviders()
	if err != nil {
		// The instance is discarded on this path, so nobody else can ever call
		// cancel; release the derived context here instead of leaking it.
		cancel()
		return nil, err
	}

	// Initialize MCP manager if configured
	if config.MCPConfig != nil {
		bifrost.mcpInitOnce.Do(func() {
			// Set up plugin pipeline provider functions for executeCode tool hooks
			mcpConfig := *config.MCPConfig
			mcpConfig.PluginPipelineProvider = func() any {
				return bifrost.getPluginPipeline()
			}
			mcpConfig.ReleasePluginPipeline = func(pipeline any) {
				if pp, ok := pipeline.(*PluginPipeline); ok {
					bifrost.releasePluginPipeline(pp)
				}
			}
			// Create Starlark CodeMode for code execution
			var codeModeConfig *mcp.CodeModeConfig
			if mcpConfig.ToolManagerConfig != nil {
				codeModeConfig = &mcp.CodeModeConfig{
					BindingLevel:         mcpConfig.ToolManagerConfig.CodeModeBindingLevel,
					ToolExecutionTimeout: mcpConfig.ToolManagerConfig.ToolExecutionTimeout,
				}
			}
			codeMode := starlark.NewStarlarkCodeMode(codeModeConfig, bifrost.logger)
			bifrost.MCPManager = mcp.NewMCPManager(bifrostCtx, mcpConfig, bifrost.oauth2Provider, bifrost.logger, codeMode)
			bifrost.logger.Info("MCP integration initialized successfully")
		})
	}

	// Create buffered channels for each provider and start workers
	for _, providerKey := range providerKeys {
		if strings.TrimSpace(string(providerKey)) == "" {
			bifrost.logger.Warn("provider key is empty, skipping init")
			continue
		}
		providerConfig, err := bifrost.account.GetConfigForProvider(providerKey)
		if err != nil {
			bifrost.logger.Warn("failed to get config for provider, skipping init: %v", err)
			continue
		}
		if providerConfig == nil {
			bifrost.logger.Warn("config is nil for provider %s, skipping init", providerKey)
			continue
		}
		// Lock the provider mutex during initialization
		providerMutex := bifrost.getProviderMutex(providerKey)
		providerMutex.Lock()
		err = bifrost.prepareProvider(providerKey, providerConfig)
		providerMutex.Unlock()
		if err != nil {
			// A single broken provider must not prevent the others from starting.
			bifrost.logger.Warn("failed to prepare provider %s: %v", providerKey, err)
		}
	}
	return bifrost, nil
}
// SetTracer replaces the tracer used by this Bifrost instance.
// Passing nil installs schemas.DefaultTracer() instead.
func (bifrost *Bifrost) SetTracer(tracer schemas.Tracer) {
	wrapped := &tracerWrapper{tracer: tracer}
	if tracer == nil {
		// No tracer supplied: fall back to the default one.
		wrapped.tracer = schemas.DefaultTracer()
	}
	bifrost.tracer.Store(wrapped)
}
// getTracer loads the current tracer from atomic storage, unwrapping the
// *tracerWrapper that Init and SetTracer store there.
func (bifrost *Bifrost) getTracer() schemas.Tracer {
	wrapper := bifrost.tracer.Load().(*tracerWrapper)
	return wrapper.tracer
}
// ReloadConfig applies a freshly loaded configuration to a running instance.
// At present only config.DropExcessRequests is applied; every other field of
// config is ignored. It always returns nil.
func (bifrost *Bifrost) ReloadConfig(config schemas.BifrostConfig) error {
	dropExcess := config.DropExcessRequests
	bifrost.dropExcessRequests.Store(dropExcess)
	return nil
}
// PUBLIC API METHODS
// ListModelsRequest sends a list models request to the specified provider.
// A nil request or an empty provider is rejected up front; a nil ctx is
// replaced by the instance's root context.
func (bifrost *Bifrost) ListModelsRequest(ctx *schemas.BifrostContext, req *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) {
	// Both validation failures share the same error shape and differ only in
	// their message.
	var invalidMsg string
	switch {
	case req == nil:
		invalidMsg = "list models request is nil"
	case req.Provider == "":
		invalidMsg = "provider is required for list models request"
	}
	if invalidMsg != "" {
		return nil, &schemas.BifrostError{
			IsBifrostError: false,
			Error:          &schemas.ErrorField{Message: invalidMsg},
			ExtraFields: schemas.BifrostErrorExtraFields{
				RequestType: schemas.ListModelsRequest,
			},
		}
	}

	if ctx == nil {
		ctx = bifrost.ctx
	}

	pooledReq := bifrost.getBifrostRequest()
	pooledReq.RequestType = schemas.ListModelsRequest
	pooledReq.ListModelsRequest = req

	result, bifrostErr := bifrost.handleRequest(ctx, pooledReq)
	if bifrostErr != nil {
		return nil, bifrostErr
	}
	return result.ListModelsResponse, nil
}
// ListAllModels lists all models from all configured providers.
//
// Every provider is queried concurrently and paged through with
// schemas.DefaultPageSize per page, for at most schemas.MaxPaginationRequests
// pages. The merged list is sorted by model ID and then paginated with
// req.PageSize and req.PageToken. A nil req is treated as an empty request and
// a nil ctx falls back to the instance's root context.
//
// Provider errors whose message contains "no keys found" or "not supported"
// are not reported. If no provider returned any model and at least one
// reportable error occurred, the first such error is returned with the key
// statuses of all providers attached.
func (bifrost *Bifrost) ListAllModels(ctx *schemas.BifrostContext, req *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) {
	if req == nil {
		req = &schemas.BifrostListModelsRequest{}
	}
	if ctx == nil {
		ctx = bifrost.ctx
	}

	providerKeys, err := bifrost.GetConfiguredProviders()
	if err != nil {
		return nil, &schemas.BifrostError{
			IsBifrostError: false,
			Error: &schemas.ErrorField{
				Message: err.Error(),
				Error:   err,
			},
			ExtraFields: schemas.BifrostErrorExtraFields{
				RequestType: schemas.ListModelsRequest,
			},
		}
	}

	startTime := time.Now()

	// Result structure for collecting provider responses
	type providerResult struct {
		provider    schemas.ModelProvider
		models      []schemas.Model
		keyStatuses []schemas.KeyStatus
		err         *schemas.BifrostError
	}

	// Buffered to one slot per provider so workers never block on send, even
	// though results are only drained after wg.Wait().
	results := make(chan providerResult, len(providerKeys))
	var wg sync.WaitGroup

	// Launch concurrent requests for all providers
	for _, providerKey := range providerKeys {
		if strings.TrimSpace(string(providerKey)) == "" {
			continue
		}
		wg.Add(1)
		go func(providerKey schemas.ModelProvider) {
			defer wg.Done()

			providerCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
			providerCtx.SetValue(schemas.BifrostContextKeyRequestID, uuid.New().String())

			providerModels := make([]schemas.Model, 0)
			var providerKeyStatuses []schemas.KeyStatus
			var providerErr *schemas.BifrostError

			// Create the paged request for this provider
			providerRequest := &schemas.BifrostListModelsRequest{
				Provider:   providerKey,
				PageSize:   schemas.DefaultPageSize,
				Unfiltered: req.Unfiltered,
			}

			iterations := 0
			for {
				// Check for context cancellation; a cancelled provider
				// contributes no result at all.
				select {
				case <-ctx.Done():
					bifrost.logger.Warn("context cancelled for provider %s", providerKey)
					return
				default:
				}

				iterations++
				if iterations > schemas.MaxPaginationRequests {
					bifrost.logger.Warn("reached maximum pagination requests (%d) for provider %s, please increase the page size", schemas.MaxPaginationRequests, providerKey)
					break
				}

				response, bifrostErr := bifrost.ListModelsRequest(providerCtx, providerRequest)
				if bifrostErr != nil {
					// "no keys found" and "not supported" are expected when a
					// provider is not configured, so they are neither logged
					// nor reported. An error without an Error field cannot be
					// classified and is therefore always reported.
					expected := bifrostErr.Error != nil &&
						(strings.Contains(bifrostErr.Error.Message, "no keys found") ||
							strings.Contains(bifrostErr.Error.Message, "not supported"))
					if !expected {
						providerErr = bifrostErr
						bifrost.logger.Warn("failed to list models for provider %s: %s", providerKey, GetErrorMessage(bifrostErr))
					}
					// Collect key statuses from error (failure case)
					if len(bifrostErr.ExtraFields.KeyStatuses) > 0 {
						providerKeyStatuses = append(providerKeyStatuses, bifrostErr.ExtraFields.KeyStatuses...)
					}
					break
				}

				if response == nil || len(response.Data) == 0 {
					break
				}

				providerModels = append(providerModels, response.Data...)
				if len(response.KeyStatuses) > 0 {
					providerKeyStatuses = append(providerKeyStatuses, response.KeyStatuses...)
				}

				// Check if there are more pages
				if response.NextPageToken == "" {
					break
				}
				// Set the page token for the next request
				providerRequest.PageToken = response.NextPageToken
			}

			results <- providerResult{
				provider:    providerKey,
				models:      providerModels,
				keyStatuses: providerKeyStatuses,
				err:         providerErr,
			}
		}(providerKey)
	}

	// Wait for all goroutines to complete
	wg.Wait()
	close(results)

	// Accumulate all models and key statuses from all providers
	allModels := make([]schemas.Model, 0)
	allKeyStatuses := make([]schemas.KeyStatus, 0)
	var firstError *schemas.BifrostError
	for result := range results {
		if len(result.models) > 0 {
			allModels = append(allModels, result.models...)
		}
		if len(result.keyStatuses) > 0 {
			allKeyStatuses = append(allKeyStatuses, result.keyStatuses...)
		}
		if result.err != nil && firstError == nil {
			firstError = result.err
		}
	}

	// If we couldn't get any models from any provider, return the first error
	if len(allModels) == 0 && firstError != nil {
		// Attach all key statuses to the error
		firstError.ExtraFields.KeyStatuses = allKeyStatuses
		return nil, firstError
	}

	// Sort models alphabetically by ID
	sort.Slice(allModels, func(i, j int) bool {
		return allModels[i].ID < allModels[j].ID
	})

	// Return aggregated response; Latency is the wall-clock time of the whole fan-out
	response := &schemas.BifrostListModelsResponse{
		Data:        allModels,
		KeyStatuses: allKeyStatuses,
		ExtraFields: schemas.BifrostResponseExtraFields{
			RequestType: schemas.ListModelsRequest,
			Latency:     time.Since(startTime).Milliseconds(),
		},
	}
	response = response.ApplyPagination(req.PageSize, req.PageToken)
	return response, nil
}
// TextCompletionRequest sends a text completion request to the specified provider.
// A nil request is rejected, as is a request without a prompt unless ctx
// marks it as a large payload passthrough.
func (bifrost *Bifrost) TextCompletionRequest(ctx *schemas.BifrostContext, req *schemas.BifrostTextCompletionRequest) (*schemas.BifrostTextCompletionResponse, *schemas.BifrostError) {
	requestType := schemas.TextCompletionRequest

	if req == nil {
		nilReqErr := &schemas.BifrostError{
			IsBifrostError: false,
			Error:          &schemas.ErrorField{Message: "text completion request is nil"},
			ExtraFields:    schemas.BifrostErrorExtraFields{RequestType: requestType},
		}
		return nil, nilReqErr
	}

	promptMissing := req.Input == nil || (req.Input.PromptStr == nil && req.Input.PromptArray == nil)
	if promptMissing && !isLargePayloadPassthrough(ctx) {
		missingPromptErr := &schemas.BifrostError{
			IsBifrostError: false,
			Error:          &schemas.ErrorField{Message: "prompt not provided for text completion request"},
			ExtraFields: schemas.BifrostErrorExtraFields{
				RequestType:            requestType,
				Provider:               req.Provider,
				OriginalModelRequested: req.Model,
				ResolvedModelUsed:      req.Model,
			},
		}
		return nil, missingPromptErr
	}

	// Wrap the request in a pooled BifrostRequest and dispatch it.
	pooledReq := bifrost.getBifrostRequest()
	pooledReq.RequestType = requestType
	pooledReq.TextCompletionRequest = req

	result, bifrostErr := bifrost.handleRequest(ctx, pooledReq)
	if bifrostErr != nil {
		return nil, bifrostErr
	}
	// TODO: Release the response
	return result.TextCompletionResponse, nil
}
// TextCompletionStreamRequest sends a streaming text completion request to the
// specified provider and returns the channel of stream chunks.
// Validation mirrors the non-streaming variant: nil requests and requests
// without a prompt (outside large payload passthrough) are rejected.
func (bifrost *Bifrost) TextCompletionStreamRequest(ctx *schemas.BifrostContext, req *schemas.BifrostTextCompletionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
	requestType := schemas.TextCompletionStreamRequest

	if req == nil {
		nilReqErr := &schemas.BifrostError{
			IsBifrostError: false,
			Error:          &schemas.ErrorField{Message: "text completion stream request is nil"},
			ExtraFields:    schemas.BifrostErrorExtraFields{RequestType: requestType},
		}
		return nil, nilReqErr
	}

	promptMissing := req.Input == nil || (req.Input.PromptStr == nil && req.Input.PromptArray == nil)
	if promptMissing && !isLargePayloadPassthrough(ctx) {
		missingPromptErr := &schemas.BifrostError{
			IsBifrostError: false,
			Error:          &schemas.ErrorField{Message: "text not provided for text completion stream request"},
			ExtraFields: schemas.BifrostErrorExtraFields{
				RequestType:            requestType,
				Provider:               req.Provider,
				OriginalModelRequested: req.Model,
				ResolvedModelUsed:      req.Model,
			},
		}
		return nil, missingPromptErr
	}

	pooledReq := bifrost.getBifrostRequest()
	pooledReq.RequestType = requestType
	pooledReq.TextCompletionRequest = req
	return bifrost.handleStreamRequest(ctx, pooledReq)
}
// makeChatCompletionRequest validates req and performs a single chat
// completion round trip through handleRequest. It is also handed to the MCP
// manager as the callback for follow-up requests in agent mode.
func (bifrost *Bifrost) makeChatCompletionRequest(ctx *schemas.BifrostContext, req *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError) {
	requestType := schemas.ChatCompletionRequest

	if req == nil {
		nilReqErr := &schemas.BifrostError{
			IsBifrostError: false,
			Error:          &schemas.ErrorField{Message: "chat completion request is nil"},
			ExtraFields:    schemas.BifrostErrorExtraFields{RequestType: requestType},
		}
		return nil, nilReqErr
	}

	// A nil Input is only acceptable in large payload passthrough mode.
	inputMissing := req.Input == nil
	if inputMissing && !isLargePayloadPassthrough(ctx) {
		missingInputErr := &schemas.BifrostError{
			IsBifrostError: false,
			Error:          &schemas.ErrorField{Message: "chats not provided for chat completion request"},
			ExtraFields: schemas.BifrostErrorExtraFields{
				RequestType:            requestType,
				Provider:               req.Provider,
				OriginalModelRequested: req.Model,
				ResolvedModelUsed:      req.Model,
			},
		}
		return nil, missingInputErr
	}

	pooledReq := bifrost.getBifrostRequest()
	pooledReq.RequestType = requestType
	pooledReq.ChatRequest = req

	result, bifrostErr := bifrost.handleRequest(ctx, pooledReq)
	if bifrostErr != nil {
		return nil, bifrostErr
	}
	return result.ChatResponse, nil
}
// ChatCompletionRequest sends a chat completion request to the specified provider.
// When an MCP manager is configured, the first response is passed to it so it
// can decide whether to run agent mode; otherwise the response is returned as is.
func (bifrost *Bifrost) ChatCompletionRequest(ctx *schemas.BifrostContext, req *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError) {
	// Defensive default for mcp agent mode: never continue with a nil ctx.
	if ctx == nil {
		ctx = bifrost.ctx
	}

	initial, bifrostErr := bifrost.makeChatCompletionRequest(ctx, req)
	if bifrostErr != nil {
		return nil, bifrostErr
	}

	// Without MCP there is no agent mode to consider.
	manager := bifrost.MCPManager
	if manager == nil {
		return initial, nil
	}
	return manager.CheckAndExecuteAgentForChatRequest(
		ctx,
		req,
		initial,
		bifrost.makeChatCompletionRequest,
		bifrost.executeMCPToolWithHooks,
	)
}
// ChatCompletionStreamRequest sends a chat completion stream request to the
// specified provider and returns the channel of stream chunks.
// Nil requests, and requests with a nil Input outside large payload
// passthrough, are rejected.
func (bifrost *Bifrost) ChatCompletionStreamRequest(ctx *schemas.BifrostContext, req *schemas.BifrostChatRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
	requestType := schemas.ChatCompletionStreamRequest

	if req == nil {
		nilReqErr := &schemas.BifrostError{
			IsBifrostError: false,
			Error:          &schemas.ErrorField{Message: "chat completion stream request is nil"},
			ExtraFields:    schemas.BifrostErrorExtraFields{RequestType: requestType},
		}
		return nil, nilReqErr
	}

	inputMissing := req.Input == nil
	if inputMissing && !isLargePayloadPassthrough(ctx) {
		missingInputErr := &schemas.BifrostError{
			IsBifrostError: false,
			Error:          &schemas.ErrorField{Message: "chats not provided for chat completion request"},
			ExtraFields: schemas.BifrostErrorExtraFields{
				RequestType:            requestType,
				Provider:               req.Provider,
				OriginalModelRequested: req.Model,
				ResolvedModelUsed:      req.Model,
			},
		}
		return nil, missingInputErr
	}

	pooledReq := bifrost.getBifrostRequest()
	pooledReq.RequestType = requestType
	pooledReq.ChatRequest = req
	return bifrost.handleStreamRequest(ctx, pooledReq)
}
// makeResponsesRequest validates req and performs a single responses round
// trip through handleRequest. It is also handed to the MCP manager as the
// callback for follow-up requests in agent mode.
//
// A nil Input is accepted only when ctx carries the large payload mode flag;
// a nil ctx is treated as "flag not set" instead of being dereferenced.
func (bifrost *Bifrost) makeResponsesRequest(ctx *schemas.BifrostContext, req *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) {
	if req == nil {
		return nil, &schemas.BifrostError{
			IsBifrostError: false,
			Error: &schemas.ErrorField{
				Message: "responses request is nil",
			},
			ExtraFields: schemas.BifrostErrorExtraFields{
				RequestType: schemas.ResponsesRequest,
			},
		}
	}

	// In large payload mode, Input is intentionally nil — body streams directly to upstream
	if req.Input == nil {
		isLargePayload := false
		if ctx != nil {
			// Guarded: this function does not itself guarantee a non-nil ctx.
			isLargePayload, _ = ctx.Value(schemas.BifrostContextKeyLargePayloadMode).(bool)
		}
		if !isLargePayload {
			return nil, &schemas.BifrostError{
				IsBifrostError: false,
				Error: &schemas.ErrorField{
					Message: "responses not provided for responses request",
				},
				ExtraFields: schemas.BifrostErrorExtraFields{
					RequestType:            schemas.ResponsesRequest,
					Provider:               req.Provider,
					OriginalModelRequested: req.Model,
					ResolvedModelUsed:      req.Model,
				},
			}
		}
	}

	bifrostReq := bifrost.getBifrostRequest()
	bifrostReq.RequestType = schemas.ResponsesRequest
	bifrostReq.ResponsesRequest = req

	response, err := bifrost.handleRequest(ctx, bifrostReq)
	if err != nil {
		return nil, err
	}
	return response.ResponsesResponse, nil
}
// ResponsesRequest sends a responses request to the specified provider.
// When an MCP manager is configured, the first response is passed to it so it
// can decide whether to run agent mode; otherwise the response is returned as is.
func (bifrost *Bifrost) ResponsesRequest(ctx *schemas.BifrostContext, req *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) {
	// Defensive default for mcp agent mode: never continue with a nil ctx.
	if ctx == nil {
		ctx = bifrost.ctx
	}

	initial, bifrostErr := bifrost.makeResponsesRequest(ctx, req)
	if bifrostErr != nil {
		return nil, bifrostErr
	}

	// Without MCP there is no agent mode to consider.
	manager := bifrost.MCPManager
	if manager == nil {
		return initial, nil
	}
	return manager.CheckAndExecuteAgentForResponsesRequest(
		ctx,
		req,
		initial,
		bifrost.makeResponsesRequest,
		bifrost.executeMCPToolWithHooks,
	)
}
// ResponsesStreamRequest sends a responses stream request to the specified
// provider and returns the channel of stream chunks.
//
// A nil request is rejected. A nil Input is accepted only when ctx carries the
// large payload mode flag; a nil ctx is treated as "flag not set" instead of
// being dereferenced.
func (bifrost *Bifrost) ResponsesStreamRequest(ctx *schemas.BifrostContext, req *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
	if req == nil {
		return nil, &schemas.BifrostError{
			IsBifrostError: false,
			Error: &schemas.ErrorField{
				Message: "responses stream request is nil",
			},
			ExtraFields: schemas.BifrostErrorExtraFields{
				RequestType: schemas.ResponsesStreamRequest,
			},
		}
	}

	// In large payload mode, Input is intentionally nil — body streams directly to upstream
	if req.Input == nil {
		isLargePayload := false
		if ctx != nil {
			// Guarded: ctx is not validated or defaulted anywhere in this method.
			isLargePayload, _ = ctx.Value(schemas.BifrostContextKeyLargePayloadMode).(bool)
		}
		if !isLargePayload {
			return nil, &schemas.BifrostError{
				IsBifrostError: false,
				Error: &schemas.ErrorField{
					Message: "responses not provided for responses stream request",
				},
				ExtraFields: schemas.BifrostErrorExtraFields{
					RequestType:            schemas.ResponsesStreamRequest,
					Provider:               req.Provider,
					OriginalModelRequested: req.Model,
					ResolvedModelUsed:      req.Model,
				},
			}
		}
	}

	bifrostReq := bifrost.getBifrostRequest()
	bifrostReq.RequestType = schemas.ResponsesStreamRequest
	bifrostReq.ResponsesRequest = req
	return bifrost.handleStreamRequest(ctx, bifrostReq)
}
// CountTokensRequest sends a count tokens request to the specified provider.
// The request reuses the responses request shape; nil requests and requests
// with a nil Input (outside large payload passthrough) are rejected.
func (bifrost *Bifrost) CountTokensRequest(ctx *schemas.BifrostContext, req *schemas.BifrostResponsesRequest) (*schemas.BifrostCountTokensResponse, *schemas.BifrostError) {
	requestType := schemas.CountTokensRequest

	if req == nil {
		nilReqErr := &schemas.BifrostError{
			IsBifrostError: false,
			Error:          &schemas.ErrorField{Message: "count tokens request is nil"},
			ExtraFields:    schemas.BifrostErrorExtraFields{RequestType: requestType},
		}
		return nil, nilReqErr
	}

	inputMissing := req.Input == nil
	if inputMissing && !isLargePayloadPassthrough(ctx) {
		missingInputErr := &schemas.BifrostError{
			IsBifrostError: false,
			Error:          &schemas.ErrorField{Message: "input not provided for count tokens request"},
			ExtraFields: schemas.BifrostErrorExtraFields{
				RequestType:            requestType,
				Provider:               req.Provider,
				OriginalModelRequested: req.Model,
				ResolvedModelUsed:      req.Model,
			},
		}
		return nil, missingInputErr
	}

	pooledReq := bifrost.getBifrostRequest()
	pooledReq.RequestType = requestType
	pooledReq.CountTokensRequest = req

	result, bifrostErr := bifrost.handleRequest(ctx, pooledReq)
	if bifrostErr != nil {
		return nil, bifrostErr
	}
	return result.CountTokensResponse, nil
}
// EmbeddingRequest sends an embedding request to the specified provider.
// The request must carry input in at least one form: a populated Input, an
// "inputs" or "images" entry in Params.ExtraParams, or a large payload
// passthrough context. Nil requests are rejected.
func (bifrost *Bifrost) EmbeddingRequest(ctx *schemas.BifrostContext, req *schemas.BifrostEmbeddingRequest) (*schemas.BifrostEmbeddingResponse, *schemas.BifrostError) {
	requestType := schemas.EmbeddingRequest

	if req == nil {
		nilReqErr := &schemas.BifrostError{
			IsBifrostError: false,
			Error:          &schemas.ErrorField{Message: "embedding request is nil"},
			ExtraFields:    schemas.BifrostErrorExtraFields{RequestType: requestType},
		}
		return nil, nilReqErr
	}

	// Input may also arrive through extra params instead of req.Input.
	hasExtraInputs := false
	if params := req.Params; params != nil && params.ExtraParams != nil {
		extra := params.ExtraParams
		hasExtraInputs = extra["inputs"] != nil || extra["images"] != nil
	}

	inputMissing := req.Input == nil ||
		(req.Input.Text == nil && req.Input.Texts == nil && req.Input.Embedding == nil && req.Input.Embeddings == nil)
	if inputMissing && !hasExtraInputs && !isLargePayloadPassthrough(ctx) {
		missingInputErr := &schemas.BifrostError{
			IsBifrostError: false,
			Error:          &schemas.ErrorField{Message: "embedding input not provided for embedding request"},
			ExtraFields: schemas.BifrostErrorExtraFields{
				RequestType:            requestType,
				Provider:               req.Provider,
				OriginalModelRequested: req.Model,
				ResolvedModelUsed:      req.Model,
			},
		}
		return nil, missingInputErr
	}

	pooledReq := bifrost.getBifrostRequest()
	pooledReq.RequestType = requestType
	pooledReq.EmbeddingRequest = req

	result, bifrostErr := bifrost.handleRequest(ctx, pooledReq)
	if bifrostErr != nil {
		return nil, bifrostErr
	}
	// TODO: Release the response
	return result.EmbeddingResponse, nil
}
// RerankRequest sends a rerank request to the specified provider.
// Validation runs in this order: the request must be non-nil, the query must
// contain non-whitespace text, at least one document must be present, and
// every document's text must contain non-whitespace text.
func (bifrost *Bifrost) RerankRequest(ctx *schemas.BifrostContext, req *schemas.BifrostRerankRequest) (*schemas.BifrostRerankResponse, *schemas.BifrostError) {
	if req == nil {
		return nil, &schemas.BifrostError{
			IsBifrostError: false,
			Error:          &schemas.ErrorField{Message: "rerank request is nil"},
			ExtraFields: schemas.BifrostErrorExtraFields{
				RequestType: schemas.RerankRequest,
			},
		}
	}

	// reject builds the validation error shared by all checks below; only the
	// message differs between them.
	reject := func(message string) *schemas.BifrostError {
		return &schemas.BifrostError{
			IsBifrostError: false,
			Error:          &schemas.ErrorField{Message: message},
			ExtraFields: schemas.BifrostErrorExtraFields{
				RequestType:            schemas.RerankRequest,
				Provider:               req.Provider,
				OriginalModelRequested: req.Model,
				ResolvedModelUsed:      req.Model,
			},
		}
	}

	if strings.TrimSpace(req.Query) == "" {
		return nil, reject("query not provided for rerank request")
	}
	if len(req.Documents) == 0 {
		return nil, reject("documents not provided for rerank request")
	}
	for idx, document := range req.Documents {
		if strings.TrimSpace(document.Text) == "" {
			return nil, reject(fmt.Sprintf("document text is empty at index %d", idx))
		}
	}

	pooledReq := bifrost.getBifrostRequest()
	pooledReq.RequestType = schemas.RerankRequest
	pooledReq.RerankRequest = req

	result, bifrostErr := bifrost.handleRequest(ctx, pooledReq)
	if bifrostErr != nil {
		return nil, bifrostErr
	}
	return result.RerankResponse, nil
}
// OCRRequest validates req and dispatches it as an OCR request.
//
// The request is rejected when req is nil, when req.Document.Type is empty or
// only whitespace, when the type is OCRDocumentTypeDocumentURL but DocumentURL
// is nil or blank, or when the type is OCRDocumentTypeImageURL but ImageURL is
// nil or blank. Other document types are forwarded without further checks.
//
// On success it returns the OCRResponse field of the response produced by
// handleRequest. An error from handleRequest is returned unchanged.
func (bifrost *Bifrost) OCRRequest(ctx *schemas.BifrostContext, req *schemas.BifrostOCRRequest) (*schemas.BifrostOCRResponse, *schemas.BifrostError) {
	if req == nil {
		return nil, &schemas.BifrostError{
			IsBifrostError: false,
			Error:          &schemas.ErrorField{Message: "ocr request is nil"},
			ExtraFields:    schemas.BifrostErrorExtraFields{RequestType: schemas.OCRRequest},
		}
	}

	// invalid builds a validation error tagged with the request's provider and model.
	invalid := func(msg string) *schemas.BifrostError {
		return &schemas.BifrostError{
			IsBifrostError: false,
			Error:          &schemas.ErrorField{Message: msg},
			ExtraFields: schemas.BifrostErrorExtraFields{
				RequestType:            schemas.OCRRequest,
				Provider:               req.Provider,
				OriginalModelRequested: req.Model,
				ResolvedModelUsed:      req.Model,
			},
		}
	}

	docType := req.Document.Type
	switch {
	case strings.TrimSpace(string(docType)) == "":
		return nil, invalid("document type not provided for ocr request")
	case docType == schemas.OCRDocumentTypeDocumentURL &&
		(req.Document.DocumentURL == nil || strings.TrimSpace(*req.Document.DocumentURL) == ""):
		return nil, invalid("document_url not provided for document_url type ocr request")
	case docType == schemas.OCRDocumentTypeImageURL &&
		(req.Document.ImageURL == nil || strings.TrimSpace(*req.Document.ImageURL) == ""):
		return nil, invalid("image_url not provided for image_url type ocr request")
	}

	request := bifrost.getBifrostRequest()
	request.RequestType = schemas.OCRRequest
	request.OCRRequest = req
	resp, bifrostErr := bifrost.handleRequest(ctx, request)
	if bifrostErr != nil {
		return nil, bifrostErr
	}
	return resp.OCRResponse, nil
}
// SpeechRequest validates req and dispatches it as a speech request.
//
// A nil req is rejected. A request whose Input is nil or whose Input.Input is
// empty is rejected unless isLargePayloadPassthrough(ctx) reports true.
//
// On success it returns the SpeechResponse field of the response produced by
// handleRequest. An error from handleRequest is returned unchanged.
func (bifrost *Bifrost) SpeechRequest(ctx *schemas.BifrostContext, req *schemas.BifrostSpeechRequest) (*schemas.BifrostSpeechResponse, *schemas.BifrostError) {
	if req == nil {
		return nil, &schemas.BifrostError{
			IsBifrostError: false,
			Error:          &schemas.ErrorField{Message: "speech request is nil"},
			ExtraFields:    schemas.BifrostErrorExtraFields{RequestType: schemas.SpeechRequest},
		}
	}

	hasText := req.Input != nil && req.Input.Input != ""
	if !hasText && !isLargePayloadPassthrough(ctx) {
		return nil, &schemas.BifrostError{
			IsBifrostError: false,
			Error:          &schemas.ErrorField{Message: "speech input not provided for speech request"},
			ExtraFields: schemas.BifrostErrorExtraFields{
				RequestType:            schemas.SpeechRequest,
				Provider:               req.Provider,
				OriginalModelRequested: req.Model,
				ResolvedModelUsed:      req.Model,
			},
		}
	}

	request := bifrost.getBifrostRequest()
	request.RequestType = schemas.SpeechRequest
	request.SpeechRequest = req
	resp, bifrostErr := bifrost.handleRequest(ctx, request)
	if bifrostErr != nil {
		return nil, bifrostErr
	}
	// TODO: Release the response
	return resp.SpeechResponse, nil
}
// SpeechStreamRequest validates req and dispatches it as a streaming speech
// request.
//
// A nil req is rejected. A request whose Input is nil or whose Input.Input is
// empty is rejected unless isLargePayloadPassthrough(ctx) reports true.
//
// On success the result of handleStreamRequest (a chunk channel and/or error)
// is returned as-is.
func (bifrost *Bifrost) SpeechStreamRequest(ctx *schemas.BifrostContext, req *schemas.BifrostSpeechRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
	if req == nil {
		return nil, &schemas.BifrostError{
			IsBifrostError: false,
			Error:          &schemas.ErrorField{Message: "speech stream request is nil"},
			ExtraFields:    schemas.BifrostErrorExtraFields{RequestType: schemas.SpeechStreamRequest},
		}
	}

	hasText := req.Input != nil && req.Input.Input != ""
	if !hasText && !isLargePayloadPassthrough(ctx) {
		return nil, &schemas.BifrostError{
			IsBifrostError: false,
			Error:          &schemas.ErrorField{Message: "speech input not provided for speech stream request"},
			ExtraFields: schemas.BifrostErrorExtraFields{
				RequestType:            schemas.SpeechStreamRequest,
				Provider:               req.Provider,
				OriginalModelRequested: req.Model,
				ResolvedModelUsed:      req.Model,
			},
		}
	}

	request := bifrost.getBifrostRequest()
	request.RequestType = schemas.SpeechStreamRequest
	request.SpeechRequest = req
	return bifrost.handleStreamRequest(ctx, request)
}
// TranscriptionRequest validates req and dispatches it as a transcription
// request.
//
// A nil req is rejected. A request whose Input is nil or whose Input.File is
// nil is rejected unless isLargePayloadPassthrough(ctx) reports true.
//
// On success it returns the TranscriptionResponse field of the response
// produced by handleRequest. An error from handleRequest is returned unchanged.
func (bifrost *Bifrost) TranscriptionRequest(ctx *schemas.BifrostContext, req *schemas.BifrostTranscriptionRequest) (*schemas.BifrostTranscriptionResponse, *schemas.BifrostError) {
	if req == nil {
		return nil, &schemas.BifrostError{
			IsBifrostError: false,
			Error:          &schemas.ErrorField{Message: "transcription request is nil"},
			ExtraFields:    schemas.BifrostErrorExtraFields{RequestType: schemas.TranscriptionRequest},
		}
	}

	hasFile := req.Input != nil && req.Input.File != nil
	if !hasFile && !isLargePayloadPassthrough(ctx) {
		return nil, &schemas.BifrostError{
			IsBifrostError: false,
			Error:          &schemas.ErrorField{Message: "transcription input not provided for transcription request"},
			ExtraFields: schemas.BifrostErrorExtraFields{
				RequestType:            schemas.TranscriptionRequest,
				Provider:               req.Provider,
				OriginalModelRequested: req.Model,
				ResolvedModelUsed:      req.Model,
			},
		}
	}

	request := bifrost.getBifrostRequest()
	request.RequestType = schemas.TranscriptionRequest
	request.TranscriptionRequest = req
	resp, bifrostErr := bifrost.handleRequest(ctx, request)
	if bifrostErr != nil {
		return nil, bifrostErr
	}
	// TODO: Release the response
	return resp.TranscriptionResponse, nil
}
// TranscriptionStreamRequest validates req and dispatches it as a streaming
// transcription request.
//
// A nil req is rejected. A request whose Input is nil or whose Input.File is
// nil is rejected unless isLargePayloadPassthrough(ctx) reports true.
//
// On success the result of handleStreamRequest (a chunk channel and/or error)
// is returned as-is.
func (bifrost *Bifrost) TranscriptionStreamRequest(ctx *schemas.BifrostContext, req *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
	if req == nil {
		return nil, &schemas.BifrostError{
			IsBifrostError: false,
			Error:          &schemas.ErrorField{Message: "transcription stream request is nil"},
			ExtraFields:    schemas.BifrostErrorExtraFields{RequestType: schemas.TranscriptionStreamRequest},
		}
	}

	hasFile := req.Input != nil && req.Input.File != nil
	if !hasFile && !isLargePayloadPassthrough(ctx) {
		return nil, &schemas.BifrostError{
			IsBifrostError: false,
			Error:          &schemas.ErrorField{Message: "transcription input not provided for transcription stream request"},
			ExtraFields: schemas.BifrostErrorExtraFields{
				RequestType:            schemas.TranscriptionStreamRequest,
				Provider:               req.Provider,
				OriginalModelRequested: req.Model,
				ResolvedModelUsed:      req.Model,
			},
		}
	}

	request := bifrost.getBifrostRequest()
	request.RequestType = schemas.TranscriptionStreamRequest
	request.TranscriptionRequest = req
	return bifrost.handleStreamRequest(ctx, request)
}
// ImageGenerationRequest validates req and dispatches it as an image
// generation request.
//
// A nil req is rejected. A request whose Input is nil or whose Input.Prompt is
// empty is rejected unless isLargePayloadPassthrough(ctx) reports true.
//
// An error from handleRequest is returned unchanged. If handleRequest yields a
// nil response, or a response without an ImageGenerationResponse, an error is
// returned instead of a nil result.
func (bifrost *Bifrost) ImageGenerationRequest(ctx *schemas.BifrostContext,
	req *schemas.BifrostImageGenerationRequest,
) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) {
	if req == nil {
		return nil, &schemas.BifrostError{
			IsBifrostError: false,
			Error:          &schemas.ErrorField{Message: "image generation request is nil"},
			ExtraFields:    schemas.BifrostErrorExtraFields{RequestType: schemas.ImageGenerationRequest},
		}
	}

	// failure builds an error tagged with the request's provider and model.
	failure := func(msg string) *schemas.BifrostError {
		return &schemas.BifrostError{
			IsBifrostError: false,
			Error:          &schemas.ErrorField{Message: msg},
			ExtraFields: schemas.BifrostErrorExtraFields{
				RequestType:            schemas.ImageGenerationRequest,
				Provider:               req.Provider,
				OriginalModelRequested: req.Model,
				ResolvedModelUsed:      req.Model,
			},
		}
	}

	hasPrompt := req.Input != nil && req.Input.Prompt != ""
	if !hasPrompt && !isLargePayloadPassthrough(ctx) {
		return nil, failure("prompt not provided for image generation request")
	}

	request := bifrost.getBifrostRequest()
	request.RequestType = schemas.ImageGenerationRequest
	request.ImageGenerationRequest = req
	resp, bifrostErr := bifrost.handleRequest(ctx, request)
	if bifrostErr != nil {
		return nil, bifrostErr
	}
	if resp == nil || resp.ImageGenerationResponse == nil {
		return nil, failure("received nil response from provider")
	}
	return resp.ImageGenerationResponse, nil
}
// ImageGenerationStreamRequest validates req and dispatches it as a streaming
// image generation request.
//
// A nil req is rejected. A request whose Input is nil or whose Input.Prompt is
// empty is rejected unless isLargePayloadPassthrough(ctx) reports true.
//
// On success the result of handleStreamRequest (a chunk channel and/or error)
// is returned as-is.
func (bifrost *Bifrost) ImageGenerationStreamRequest(ctx *schemas.BifrostContext,
	req *schemas.BifrostImageGenerationRequest,
) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
	if req == nil {
		return nil, &schemas.BifrostError{
			IsBifrostError: false,
			Error:          &schemas.ErrorField{Message: "image generation stream request is nil"},
			ExtraFields:    schemas.BifrostErrorExtraFields{RequestType: schemas.ImageGenerationStreamRequest},
		}
	}

	hasPrompt := req.Input != nil && req.Input.Prompt != ""
	if !hasPrompt && !isLargePayloadPassthrough(ctx) {
		return nil, &schemas.BifrostError{
			IsBifrostError: false,
			Error:          &schemas.ErrorField{Message: "prompt not provided for image generation stream request"},
			ExtraFields: schemas.BifrostErrorExtraFields{
				RequestType:            schemas.ImageGenerationStreamRequest,
				Provider:               req.Provider,
				OriginalModelRequested: req.Model,
				ResolvedModelUsed:      req.Model,
			},
		}
	}

	request := bifrost.getBifrostRequest()
	request.RequestType = schemas.ImageGenerationStreamRequest
	request.ImageGenerationRequest = req
	return bifrost.handleStreamRequest(ctx, request)
}
// ImageEditRequest validates req and dispatches it as an image edit request.
//
// A nil req is rejected. Unless isLargePayloadPassthrough(ctx) reports true,
// the request is also rejected when it has no images, or when it has no prompt
// and isPromptOptionalImageEditType reports false for req.Params.Type (nil is
// passed when req.Params is nil).
//
// An error from handleRequest is returned unchanged. If handleRequest yields a
// nil response, or a response without an ImageGenerationResponse, an error is
// returned instead of a nil result.
func (bifrost *Bifrost) ImageEditRequest(ctx *schemas.BifrostContext, req *schemas.BifrostImageEditRequest) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) {
	if req == nil {
		return nil, &schemas.BifrostError{
			IsBifrostError: false,
			Error:          &schemas.ErrorField{Message: "image edit request is nil"},
			ExtraFields:    schemas.BifrostErrorExtraFields{RequestType: schemas.ImageEditRequest},
		}
	}

	// failure builds an error tagged with the request's provider and model.
	failure := func(msg string) *schemas.BifrostError {
		return &schemas.BifrostError{
			IsBifrostError: false,
			Error:          &schemas.ErrorField{Message: msg},
			ExtraFields: schemas.BifrostErrorExtraFields{
				RequestType:            schemas.ImageEditRequest,
				Provider:               req.Provider,
				OriginalModelRequested: req.Model,
				ResolvedModelUsed:      req.Model,
			},
		}
	}

	// len of a nil slice is 0, so a single length check covers both cases.
	noImages := req.Input == nil || len(req.Input.Images) == 0
	if noImages && !isLargePayloadPassthrough(ctx) {
		return nil, failure("images not provided for image edit request")
	}

	// Some operation types work without a text prompt, so the prompt check is
	// skipped for them.
	var editType *string
	if req.Params != nil {
		editType = req.Params.Type
	}
	if !isPromptOptionalImageEditType(editType) &&
		(req.Input == nil || req.Input.Prompt == "") &&
		!isLargePayloadPassthrough(ctx) {
		return nil, failure("prompt not provided for image edit request")
	}

	request := bifrost.getBifrostRequest()
	request.RequestType = schemas.ImageEditRequest
	request.ImageEditRequest = req
	resp, bifrostErr := bifrost.handleRequest(ctx, request)
	if bifrostErr != nil {
		return nil, bifrostErr
	}
	if resp == nil || resp.ImageGenerationResponse == nil {
		return nil, failure("received nil response from provider")
	}
	return resp.ImageGenerationResponse, nil
}
// ImageEditStreamRequest validates req and dispatches it as a streaming image
// edit request.
//
// A nil req is rejected. Unless isLargePayloadPassthrough(ctx) reports true,
// the request is also rejected when it has no images, or when it has no prompt
// and isPromptOptionalImageEditType reports false for req.Params.Type (nil is
// passed when req.Params is nil).
//
// On success the result of handleStreamRequest (a chunk channel and/or error)
// is returned as-is.
func (bifrost *Bifrost) ImageEditStreamRequest(ctx *schemas.BifrostContext, req *schemas.BifrostImageEditRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
	if req == nil {
		return nil, &schemas.BifrostError{
			IsBifrostError: false,
			Error:          &schemas.ErrorField{Message: "image edit stream request is nil"},
			ExtraFields:    schemas.BifrostErrorExtraFields{RequestType: schemas.ImageEditStreamRequest},
		}
	}

	// invalid builds a validation error tagged with the request's provider and model.
	invalid := func(msg string) *schemas.BifrostError {
		return &schemas.BifrostError{
			IsBifrostError: false,
			Error:          &schemas.ErrorField{Message: msg},
			ExtraFields: schemas.BifrostErrorExtraFields{
				RequestType:            schemas.ImageEditStreamRequest,
				Provider:               req.Provider,
				OriginalModelRequested: req.Model,
				ResolvedModelUsed:      req.Model,
			},
		}
	}

	// len of a nil slice is 0, so a single length check covers both cases.
	noImages := req.Input == nil || len(req.Input.Images) == 0
	if noImages && !isLargePayloadPassthrough(ctx) {
		return nil, invalid("images not provided for image edit stream request")
	}

	// Some operation types work without a text prompt, so the prompt check is
	// skipped for them.
	var editType *string
	if req.Params != nil {
		editType = req.Params.Type
	}
	if !isPromptOptionalImageEditType(editType) &&
		(req.Input == nil || req.Input.Prompt == "") &&
		!isLargePayloadPassthrough(ctx) {
		return nil, invalid("prompt not provided for image edit stream request")
	}

	request := bifrost.getBifrostRequest()
	request.RequestType = schemas.ImageEditStreamRequest
	request.ImageEditRequest = req
	return bifrost.handleStreamRequest(ctx, request)
}
// ImageVariationRequest validates req and dispatches it as an image variation
// request.
//
// A nil req is rejected. A request whose Input is nil or whose
// Input.Image.Image is empty is rejected unless isLargePayloadPassthrough(ctx)
// reports true.
//
// An error from handleRequest is returned unchanged. If handleRequest yields a
// nil response, or a response without an ImageGenerationResponse, an error is
// returned instead of a nil result.
func (bifrost *Bifrost) ImageVariationRequest(ctx *schemas.BifrostContext, req *schemas.BifrostImageVariationRequest) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) {
	if req == nil {
		return nil, &schemas.BifrostError{
			IsBifrostError: false,
			Error:          &schemas.ErrorField{Message: "image variation request is nil"},
			ExtraFields:    schemas.BifrostErrorExtraFields{RequestType: schemas.ImageVariationRequest},
		}
	}

	// failure builds an error tagged with the request's provider and model.
	failure := func(msg string) *schemas.BifrostError {
		return &schemas.BifrostError{
			IsBifrostError: false,
			Error:          &schemas.ErrorField{Message: msg},
			ExtraFields: schemas.BifrostErrorExtraFields{
				RequestType:            schemas.ImageVariationRequest,
				Provider:               req.Provider,
				OriginalModelRequested: req.Model,
				ResolvedModelUsed:      req.Model,
			},
		}
	}

	// len of a nil slice is 0, so a single length check covers both cases.
	noImage := req.Input == nil || len(req.Input.Image.Image) == 0
	if noImage && !isLargePayloadPassthrough(ctx) {
		return nil, failure("image not provided for image variation request")
	}

	request := bifrost.getBifrostRequest()
	request.RequestType = schemas.ImageVariationRequest
	request.ImageVariationRequest = req
	resp, bifrostErr := bifrost.handleRequest(ctx, request)
	if bifrostErr != nil {
		return nil, bifrostErr
	}
	if resp == nil || resp.ImageGenerationResponse == nil {
		return nil, failure("received nil response from provider")
	}
	return resp.ImageGenerationResponse, nil
}
// VideoGenerationRequest validates req and dispatches it as a video generation
// request.
//
// The request is rejected when req is nil, or when req.Input is nil or its
// Prompt is empty.
//
// An error from handleRequest is returned unchanged. If handleRequest yields a
// nil response, or a response without a VideoGenerationResponse, an error is
// returned instead of a nil result.
func (bifrost *Bifrost) VideoGenerationRequest(ctx *schemas.BifrostContext,
	req *schemas.BifrostVideoGenerationRequest,
) (*schemas.BifrostVideoGenerationResponse, *schemas.BifrostError) {
	if req == nil {
		return nil, &schemas.BifrostError{
			IsBifrostError: false,
			Error:          &schemas.ErrorField{Message: "video generation request is nil"},
			ExtraFields:    schemas.BifrostErrorExtraFields{RequestType: schemas.VideoGenerationRequest},
		}
	}

	// failure builds an error tagged with the request's provider and model.
	failure := func(msg string) *schemas.BifrostError {
		return &schemas.BifrostError{
			IsBifrostError: false,
			Error:          &schemas.ErrorField{Message: msg},
			ExtraFields: schemas.BifrostErrorExtraFields{
				RequestType:            schemas.VideoGenerationRequest,
				Provider:               req.Provider,
				OriginalModelRequested: req.Model,
				ResolvedModelUsed:      req.Model,
			},
		}
	}

	if hasPrompt := req.Input != nil && req.Input.Prompt != ""; !hasPrompt {
		return nil, failure("prompt not provided for video generation request")
	}

	request := bifrost.getBifrostRequest()
	request.RequestType = schemas.VideoGenerationRequest
	request.VideoGenerationRequest = req
	resp, bifrostErr := bifrost.handleRequest(ctx, request)
	if bifrostErr != nil {
		return nil, bifrostErr
	}
	if resp == nil || resp.VideoGenerationResponse == nil {
		return nil, failure("received nil response from provider")
	}
	return resp.VideoGenerationResponse, nil
}
// VideoRetrieveRequest validates req and dispatches it as a video retrieve
// request.
//
// The request is rejected when req is nil, when req.Provider is empty, or when
// req.ID is empty.
//
// An error from handleRequest is returned unchanged. If handleRequest yields a
// nil response, or a response without a VideoGenerationResponse, an error is
// returned instead of a nil result.
func (bifrost *Bifrost) VideoRetrieveRequest(ctx *schemas.BifrostContext, req *schemas.BifrostVideoRetrieveRequest) (*schemas.BifrostVideoGenerationResponse, *schemas.BifrostError) {
	if req == nil {
		return nil, &schemas.BifrostError{
			IsBifrostError: false,
			Error:          &schemas.ErrorField{Message: "video retrieve request is nil"},
			ExtraFields:    schemas.BifrostErrorExtraFields{RequestType: schemas.VideoRetrieveRequest},
		}
	}

	// failure builds an error tagged with the request's provider. In the
	// missing-provider case req.Provider is empty, so the field stays at its
	// zero value there.
	failure := func(msg string) *schemas.BifrostError {
		return &schemas.BifrostError{
			IsBifrostError: false,
			Error:          &schemas.ErrorField{Message: msg},
			ExtraFields: schemas.BifrostErrorExtraFields{
				RequestType: schemas.VideoRetrieveRequest,
				Provider:    req.Provider,
			},
		}
	}

	switch {
	case req.Provider == "":
		return nil, failure("provider is required for video retrieve request")
	case req.ID == "":
		return nil, failure("video_id is required for video retrieve request")
	}

	request := bifrost.getBifrostRequest()
	request.RequestType = schemas.VideoRetrieveRequest
	request.VideoRetrieveRequest = req
	resp, bifrostErr := bifrost.handleRequest(ctx, request)
	if bifrostErr != nil {
		return nil, bifrostErr
	}
	if resp == nil || resp.VideoGenerationResponse == nil {
		return nil, failure("received nil response from provider")
	}
	return resp.VideoGenerationResponse, nil
}
// VideoDownloadRequest validates req and dispatches it as a video download
// request.
//
// The request is rejected when req is nil, when req.Provider is empty, or when
// req.ID is empty.
//
// On success it returns the VideoDownloadResponse field of the response
// produced by handleRequest. An error from handleRequest is returned unchanged.
func (bifrost *Bifrost) VideoDownloadRequest(ctx *schemas.BifrostContext, req *schemas.BifrostVideoDownloadRequest) (*schemas.BifrostVideoDownloadResponse, *schemas.BifrostError) {
	if req == nil {
		return nil, &schemas.BifrostError{
			IsBifrostError: false,
			Error:          &schemas.ErrorField{Message: "video download request is nil"},
			ExtraFields:    schemas.BifrostErrorExtraFields{RequestType: schemas.VideoDownloadRequest},
		}
	}

	// invalid builds a validation error tagged with the request's provider. In
	// the missing-provider case req.Provider is empty, so the field stays at
	// its zero value there.
	invalid := func(msg string) *schemas.BifrostError {
		return &schemas.BifrostError{
			IsBifrostError: false,
			Error:          &schemas.ErrorField{Message: msg},
			ExtraFields: schemas.BifrostErrorExtraFields{
				RequestType: schemas.VideoDownloadRequest,
				Provider:    req.Provider,
			},
		}
	}

	switch {
	case req.Provider == "":
		return nil, invalid("provider is required for video download request")
	case req.ID == "":
		return nil, invalid("video_id is required for video download request")
	}

	request := bifrost.getBifrostRequest()
	request.RequestType = schemas.VideoDownloadRequest
	request.VideoDownloadRequest = req
	resp, bifrostErr := bifrost.handleRequest(ctx, request)
	if bifrostErr != nil {
		return nil, bifrostErr
	}
	return resp.VideoDownloadResponse, nil
}
// VideoRemixRequest validates req and dispatches it as a video remix request.
//
// The request is rejected when req is nil, when req.Provider or req.ID is
// empty, or when req.Input is nil or its Prompt is empty.
//
// An error from handleRequest is returned unchanged. If handleRequest yields a
// nil response, or a response without a VideoGenerationResponse, an error is
// returned instead of a nil result.
func (bifrost *Bifrost) VideoRemixRequest(ctx *schemas.BifrostContext, req *schemas.BifrostVideoRemixRequest) (*schemas.BifrostVideoGenerationResponse, *schemas.BifrostError) {
	if req == nil {
		return nil, &schemas.BifrostError{
			IsBifrostError: false,
			Error:          &schemas.ErrorField{Message: "video remix request is nil"},
			ExtraFields:    schemas.BifrostErrorExtraFields{RequestType: schemas.VideoRemixRequest},
		}
	}

	// failure builds an error tagged with the request's provider. In the
	// missing-provider case req.Provider is empty, so the field stays at its
	// zero value there.
	failure := func(msg string) *schemas.BifrostError {
		return &schemas.BifrostError{
			IsBifrostError: false,
			Error:          &schemas.ErrorField{Message: msg},
			ExtraFields: schemas.BifrostErrorExtraFields{
				RequestType: schemas.VideoRemixRequest,
				Provider:    req.Provider,
			},
		}
	}

	switch {
	case req.Provider == "":
		return nil, failure("provider is required for video remix request")
	case req.ID == "":
		return nil, failure("video_id is required for video remix request")
	case req.Input == nil || req.Input.Prompt == "":
		return nil, failure("prompt is required for video remix request")
	}

	request := bifrost.getBifrostRequest()
	request.RequestType = schemas.VideoRemixRequest
	request.VideoRemixRequest = req
	resp, bifrostErr := bifrost.handleRequest(ctx, request)
	if bifrostErr != nil {
		return nil, bifrostErr
	}
	if resp == nil || resp.VideoGenerationResponse == nil {
		return nil, failure("received nil response from provider")
	}
	return resp.VideoGenerationResponse, nil
}
// VideoListRequest validates req and dispatches it as a video list request.
//
// The request is rejected when req is nil or when req.Provider is empty.
//
// On success it returns the VideoListResponse field of the response produced by
// handleRequest. An error from handleRequest is returned unchanged.
func (bifrost *Bifrost) VideoListRequest(ctx *schemas.BifrostContext, req *schemas.BifrostVideoListRequest) (*schemas.BifrostVideoListResponse, *schemas.BifrostError) {
	// Both validation failures share the same error shape; only the message differs.
	var problem string
	switch {
	case req == nil:
		problem = "video list request is nil"
	case req.Provider == "":
		problem = "provider is required for video list request"
	}
	if problem != "" {
		return nil, &schemas.BifrostError{
			IsBifrostError: false,
			Error:          &schemas.ErrorField{Message: problem},
			ExtraFields:    schemas.BifrostErrorExtraFields{RequestType: schemas.VideoListRequest},
		}
	}

	request := bifrost.getBifrostRequest()
	request.RequestType = schemas.VideoListRequest
	request.VideoListRequest = req
	resp, bifrostErr := bifrost.handleRequest(ctx, request)
	if bifrostErr != nil {
		return nil, bifrostErr
	}
	return resp.VideoListResponse, nil
}
// VideoDeleteRequest validates req and dispatches it as a video delete request.
//
// The request is rejected when req is nil, when req.Provider is empty, or when
// req.ID is empty.
//
// On success it returns the VideoDeleteResponse field of the response produced
// by handleRequest. An error from handleRequest is returned unchanged.
func (bifrost *Bifrost) VideoDeleteRequest(ctx *schemas.BifrostContext, req *schemas.BifrostVideoDeleteRequest) (*schemas.BifrostVideoDeleteResponse, *schemas.BifrostError) {
	if req == nil {
		return nil, &schemas.BifrostError{
			IsBifrostError: false,
			Error:          &schemas.ErrorField{Message: "video delete request is nil"},
			ExtraFields:    schemas.BifrostErrorExtraFields{RequestType: schemas.VideoDeleteRequest},
		}
	}

	// invalid builds a validation error tagged with the request's provider. In
	// the missing-provider case req.Provider is empty, so the field stays at
	// its zero value there.
	invalid := func(msg string) *schemas.BifrostError {
		return &schemas.BifrostError{
			IsBifrostError: false,
			Error:          &schemas.ErrorField{Message: msg},
			ExtraFields: schemas.BifrostErrorExtraFields{
				RequestType: schemas.VideoDeleteRequest,
				Provider:    req.Provider,
			},
		}
	}

	switch {
	case req.Provider == "":
		return nil, invalid("provider is required for video delete request")
	case req.ID == "":
		return nil, invalid("video_id is required for video delete request")
	}

	request := bifrost.getBifrostRequest()
	request.RequestType = schemas.VideoDeleteRequest
	request.VideoDeleteRequest = req
	resp, bifrostErr := bifrost.handleRequest(ctx, request)
	if bifrostErr != nil {
		return nil, bifrostErr
	}
	return resp.VideoDeleteResponse, nil
}
// BatchCreateRequest creates a new batch job for asynchronous processing.
//
// The request is rejected when req is nil, when req.Provider is empty, when
// neither req.InputFileID nor req.Requests is supplied, or when
// getProviderByKey returns nil for req.Provider. Every validation error carries
// the BatchCreateRequest request type in ExtraFields, plus the provider once it
// is known, matching the errors produced by the other request methods.
//
// A nil ctx is replaced with bifrost.ctx before the provider lookup and
// dispatch. On success it returns the BatchCreateResponse field of the response
// produced by handleRequest; an error from handleRequest is returned unchanged.
func (bifrost *Bifrost) BatchCreateRequest(ctx *schemas.BifrostContext, req *schemas.BifrostBatchCreateRequest) (*schemas.BifrostBatchCreateResponse, *schemas.BifrostError) {
	if req == nil {
		return nil, &schemas.BifrostError{
			IsBifrostError: false,
			Error: &schemas.ErrorField{
				Message: "batch create request is nil",
			},
			ExtraFields: schemas.BifrostErrorExtraFields{
				RequestType: schemas.BatchCreateRequest,
			},
		}
	}
	if req.Provider == "" {
		return nil, &schemas.BifrostError{
			IsBifrostError: false,
			Error: &schemas.ErrorField{
				Message: "provider is required for batch create request",
			},
			ExtraFields: schemas.BifrostErrorExtraFields{
				RequestType: schemas.BatchCreateRequest,
			},
		}
	}
	if req.InputFileID == "" && len(req.Requests) == 0 {
		return nil, &schemas.BifrostError{
			IsBifrostError: false,
			Error: &schemas.ErrorField{
				Message: "either input_file_id or requests is required for batch create request",
			},
			ExtraFields: schemas.BifrostErrorExtraFields{
				RequestType: schemas.BatchCreateRequest,
				Provider:    req.Provider,
			},
		}
	}
	if ctx == nil {
		ctx = bifrost.ctx
	}
	// Fail fast when no provider is registered under the requested key.
	provider := bifrost.getProviderByKey(req.Provider)
	if provider == nil {
		return nil, &schemas.BifrostError{
			IsBifrostError: false,
			Error: &schemas.ErrorField{
				Message: "provider not found for batch create request",
			},
			ExtraFields: schemas.BifrostErrorExtraFields{
				RequestType: schemas.BatchCreateRequest,
				Provider:    req.Provider,
			},
		}
	}
	bifrostReq := bifrost.getBifrostRequest()
	bifrostReq.RequestType = schemas.BatchCreateRequest
	bifrostReq.BatchCreateRequest = req
	response, err := bifrost.handleRequest(ctx, bifrostReq)
	if err != nil {
		return nil, err
	}
	return response.BatchCreateResponse, nil
}
// BatchListRequest lists batch jobs for the specified provider.
//
// The request is rejected when req is nil or when req.Provider is empty. Both
// validation errors carry the BatchListRequest request type in ExtraFields,
// matching the errors produced by the other request methods.
//
// A nil ctx is replaced with bifrost.ctx before dispatch. On success it returns
// the BatchListResponse field of the response produced by handleRequest; an
// error from handleRequest is returned unchanged.
func (bifrost *Bifrost) BatchListRequest(ctx *schemas.BifrostContext, req *schemas.BifrostBatchListRequest) (*schemas.BifrostBatchListResponse, *schemas.BifrostError) {
	if req == nil {
		return nil, &schemas.BifrostError{
			IsBifrostError: false,
			Error: &schemas.ErrorField{
				Message: "batch list request is nil",
			},
			ExtraFields: schemas.BifrostErrorExtraFields{
				RequestType: schemas.BatchListRequest,
			},
		}
	}
	if req.Provider == "" {
		return nil, &schemas.BifrostError{
			IsBifrostError: false,
			Error: &schemas.ErrorField{
				Message: "provider is required for batch list request",
			},
			ExtraFields: schemas.BifrostErrorExtraFields{
				RequestType: schemas.BatchListRequest,
			},
		}
	}
	if ctx == nil {
		ctx = bifrost.ctx
	}
	bifrostReq := bifrost.getBifrostRequest()
	bifrostReq.RequestType = schemas.BatchListRequest
	bifrostReq.BatchListRequest = req
	response, err := bifrost.handleRequest(ctx, bifrostReq)
	if err != nil {
		return nil, err
	}
	return response.BatchListResponse, nil
}
// BatchRetrieveRequest retrieves a specific batch job.
//
// The request is rejected when req is nil, when req.Provider is empty, or when
// req.BatchID is empty. Every validation error carries the BatchRetrieveRequest
// request type in ExtraFields, plus the provider once it is known, matching the
// errors produced by the other request methods.
//
// A nil ctx is replaced with bifrost.ctx before dispatch. On success it returns
// the BatchRetrieveResponse field of the response produced by handleRequest; an
// error from handleRequest is returned unchanged.
func (bifrost *Bifrost) BatchRetrieveRequest(ctx *schemas.BifrostContext, req *schemas.BifrostBatchRetrieveRequest) (*schemas.BifrostBatchRetrieveResponse, *schemas.BifrostError) {
	if req == nil {
		return nil, &schemas.BifrostError{
			IsBifrostError: false,
			Error: &schemas.ErrorField{
				Message: "batch retrieve request is nil",
			},
			ExtraFields: schemas.BifrostErrorExtraFields{
				RequestType: schemas.BatchRetrieveRequest,
			},
		}
	}
	if req.Provider == "" {
		return nil, &schemas.BifrostError{
			IsBifrostError: false,
			Error: &schemas.ErrorField{
				Message: "provider is required for batch retrieve request",
			},
			ExtraFields: schemas.BifrostErrorExtraFields{
				RequestType: schemas.BatchRetrieveRequest,
			},
		}
	}
	if req.BatchID == "" {
		return nil, &schemas.BifrostError{
			IsBifrostError: false,
			Error: &schemas.ErrorField{
				Message: "batch_id is required for batch retrieve request",
			},
			ExtraFields: schemas.BifrostErrorExtraFields{
				RequestType: schemas.BatchRetrieveRequest,
				Provider:    req.Provider,
			},
		}
	}
	if ctx == nil {
		ctx = bifrost.ctx
	}
	bifrostReq := bifrost.getBifrostRequest()
	bifrostReq.RequestType = schemas.BatchRetrieveRequest
	bifrostReq.BatchRetrieveRequest = req
	response, err := bifrost.handleRequest(ctx, bifrostReq)
	if err != nil {
		return nil, err
	}
	return response.BatchRetrieveResponse, nil
}
// BatchCancelRequest cancels a batch job.
//
// The request is rejected when req is nil, when req.Provider is empty, or when
// req.BatchID is empty. Every validation error carries the BatchCancelRequest
// request type in ExtraFields, plus the provider once it is known, matching the
// errors produced by the other request methods.
//
// A nil ctx is replaced with bifrost.ctx before dispatch. On success it returns
// the BatchCancelResponse field of the response produced by handleRequest; an
// error from handleRequest is returned unchanged.
func (bifrost *Bifrost) BatchCancelRequest(ctx *schemas.BifrostContext, req *schemas.BifrostBatchCancelRequest) (*schemas.BifrostBatchCancelResponse, *schemas.BifrostError) {
	if req == nil {
		return nil, &schemas.BifrostError{
			IsBifrostError: false,
			Error: &schemas.ErrorField{
				Message: "batch cancel request is nil",
			},
			ExtraFields: schemas.BifrostErrorExtraFields{
				RequestType: schemas.BatchCancelRequest,
			},
		}
	}
	if req.Provider == "" {
		return nil, &schemas.BifrostError{
			IsBifrostError: false,
			Error: &schemas.ErrorField{
				Message: "provider is required for batch cancel request",
			},
			ExtraFields: schemas.BifrostErrorExtraFields{
				RequestType: schemas.BatchCancelRequest,
			},
		}
	}
	if req.BatchID == "" {
		return nil, &schemas.BifrostError{
			IsBifrostError: false,
			Error: &schemas.ErrorField{
				Message: "batch_id is required for batch cancel request",
			},
			ExtraFields: schemas.BifrostErrorExtraFields{
				RequestType: schemas.BatchCancelRequest,
				Provider:    req.Provider,
			},
		}
	}
	if ctx == nil {
		ctx = bifrost.ctx
	}
	bifrostReq := bifrost.getBifrostRequest()
	bifrostReq.RequestType = schemas.BatchCancelRequest
	bifrostReq.BatchCancelRequest = req
	response, err := bifrost.handleRequest(ctx, bifrostReq)
	if err != nil {
		return nil, err
	}
	return response.BatchCancelResponse, nil
}
// BatchDeleteRequest deletes a batch job.
//
// The request is rejected when req is nil, when req.Provider is empty, or when
// req.BatchID is empty. Every validation error carries the BatchDeleteRequest
// request type in ExtraFields, plus the provider once it is known, matching the
// errors produced by the other request methods.
//
// A nil ctx is replaced with bifrost.ctx before dispatch. On success it returns
// the BatchDeleteResponse field of the response produced by handleRequest; an
// error from handleRequest is returned unchanged.
func (bifrost *Bifrost) BatchDeleteRequest(ctx *schemas.BifrostContext, req *schemas.BifrostBatchDeleteRequest) (*schemas.BifrostBatchDeleteResponse, *schemas.BifrostError) {
	if req == nil {
		return nil, &schemas.BifrostError{
			IsBifrostError: false,
			Error: &schemas.ErrorField{
				Message: "batch delete request is nil",
			},
			ExtraFields: schemas.BifrostErrorExtraFields{
				RequestType: schemas.BatchDeleteRequest,
			},
		}
	}
	if req.Provider == "" {
		return nil, &schemas.BifrostError{
			IsBifrostError: false,
			Error: &schemas.ErrorField{
				Message: "provider is required for batch delete request",
			},
			ExtraFields: schemas.BifrostErrorExtraFields{
				RequestType: schemas.BatchDeleteRequest,
			},
		}
	}
	if req.BatchID == "" {
		return nil, &schemas.BifrostError{
			IsBifrostError: false,
			Error: &schemas.ErrorField{
				Message: "batch_id is required for batch delete request",
			},
			ExtraFields: schemas.BifrostErrorExtraFields{
				RequestType: schemas.BatchDeleteRequest,
				Provider:    req.Provider,
			},
		}
	}
	if ctx == nil {
		ctx = bifrost.ctx
	}
	bifrostReq := bifrost.getBifrostRequest()
	bifrostReq.RequestType = schemas.BatchDeleteRequest
	bifrostReq.BatchDeleteRequest = req
	response, err := bifrost.handleRequest(ctx, bifrostReq)
	if err != nil {
		return nil, err
	}
	return response.BatchDeleteResponse, nil
}
// BatchResultsRequest validates req and dispatches it as a batch results
// request.
//
// The request is rejected when req is nil, when req.Provider is empty, or when
// req.BatchID is empty. A nil ctx is replaced with bifrost.ctx before dispatch.
//
// On success it returns the BatchResultsResponse field of the response produced
// by handleRequest. An error from handleRequest is returned unchanged.
func (bifrost *Bifrost) BatchResultsRequest(ctx *schemas.BifrostContext, req *schemas.BifrostBatchResultsRequest) (*schemas.BifrostBatchResultsResponse, *schemas.BifrostError) {
	if req == nil {
		return nil, &schemas.BifrostError{
			IsBifrostError: false,
			Error:          &schemas.ErrorField{Message: "batch results request is nil"},
			ExtraFields:    schemas.BifrostErrorExtraFields{RequestType: schemas.BatchResultsRequest},
		}
	}

	// invalid builds a validation error tagged with the request's provider. In
	// the missing-provider case req.Provider is empty, so the field stays at
	// its zero value there.
	invalid := func(msg string) *schemas.BifrostError {
		return &schemas.BifrostError{
			IsBifrostError: false,
			Error:          &schemas.ErrorField{Message: msg},
			ExtraFields: schemas.BifrostErrorExtraFields{
				RequestType: schemas.BatchResultsRequest,
				Provider:    req.Provider,
			},
		}
	}

	switch {
	case req.Provider == "":
		return nil, invalid("provider is required for batch results request")
	case req.BatchID == "":
		return nil, invalid("batch_id is required for batch results request")
	}

	if ctx == nil {
		ctx = bifrost.ctx
	}
	request := bifrost.getBifrostRequest()
	request.RequestType = schemas.BatchResultsRequest
	request.BatchResultsRequest = req
	resp, bifrostErr := bifrost.handleRequest(ctx, request)
	if bifrostErr != nil {
		return nil, bifrostErr
	}
	return resp.BatchResultsResponse, nil
}
// FileUploadRequest validates req and dispatches it as a file upload request.
//
// The request is rejected when req is nil, when req.Provider is empty, or when
// req.File is empty. A nil ctx is replaced with bifrost.ctx before dispatch.
//
// On success it returns the FileUploadResponse field of the response produced
// by handleRequest. An error from handleRequest is returned unchanged.
func (bifrost *Bifrost) FileUploadRequest(ctx *schemas.BifrostContext, req *schemas.BifrostFileUploadRequest) (*schemas.BifrostFileUploadResponse, *schemas.BifrostError) {
	if req == nil {
		return nil, &schemas.BifrostError{
			IsBifrostError: false,
			Error:          &schemas.ErrorField{Message: "file upload request is nil"},
			ExtraFields:    schemas.BifrostErrorExtraFields{RequestType: schemas.FileUploadRequest},
		}
	}

	// invalid builds a validation error tagged with the request's provider. In
	// the missing-provider case req.Provider is empty, so the field stays at
	// its zero value there.
	invalid := func(msg string) *schemas.BifrostError {
		return &schemas.BifrostError{
			IsBifrostError: false,
			Error:          &schemas.ErrorField{Message: msg},
			ExtraFields: schemas.BifrostErrorExtraFields{
				RequestType: schemas.FileUploadRequest,
				Provider:    req.Provider,
			},
		}
	}

	switch {
	case req.Provider == "":
		return nil, invalid("provider is required for file upload request")
	case len(req.File) == 0:
		return nil, invalid("file content is required for file upload request")
	}

	if ctx == nil {
		ctx = bifrost.ctx
	}
	request := bifrost.getBifrostRequest()
	request.RequestType = schemas.FileUploadRequest
	request.FileUploadRequest = req
	resp, bifrostErr := bifrost.handleRequest(ctx, request)
	if bifrostErr != nil {
		return nil, bifrostErr
	}
	return resp.FileUploadResponse, nil
}
// FileListRequest lists files from the specified provider.
//
// Parameters:
//   - ctx: Request context; bifrost.ctx is used when nil
//   - req: List request; must be non-nil with Provider set
//
// Returns:
//   - *schemas.BifrostFileListResponse: The FileListResponse of the handled request
//   - *schemas.BifrostError: A validation error, or the error returned by handleRequest
func (bifrost *Bifrost) FileListRequest(ctx *schemas.BifrostContext, req *schemas.BifrostFileListRequest) (*schemas.BifrostFileListResponse, *schemas.BifrostError) {
	// invalid builds a validation error tagged with this method's request type.
	invalid := func(msg string) *schemas.BifrostError {
		return &schemas.BifrostError{
			IsBifrostError: false,
			Error:          &schemas.ErrorField{Message: msg},
			ExtraFields: schemas.BifrostErrorExtraFields{
				RequestType: schemas.FileListRequest,
			},
		}
	}

	switch {
	case req == nil:
		return nil, invalid("file list request is nil")
	case req.Provider == "":
		return nil, invalid("provider is required for file list request")
	}

	if ctx == nil {
		ctx = bifrost.ctx
	}

	pooled := bifrost.getBifrostRequest()
	pooled.RequestType = schemas.FileListRequest
	pooled.FileListRequest = req

	resp, bifrostErr := bifrost.handleRequest(ctx, pooled)
	if bifrostErr != nil {
		return nil, bifrostErr
	}
	return resp.FileListResponse, nil
}
// FileRetrieveRequest retrieves file metadata from the specified provider.
//
// Validation errors are tagged with the request type (and with the provider
// once it is known to be non-empty), matching the sibling batch and file
// upload/list methods.
//
// Parameters:
//   - ctx: Request context; bifrost.ctx is used when nil
//   - req: Retrieve request; must be non-nil with Provider and FileID set
//
// Returns:
//   - *schemas.BifrostFileRetrieveResponse: The FileRetrieveResponse of the handled request
//   - *schemas.BifrostError: A validation error, or the error returned by handleRequest
func (bifrost *Bifrost) FileRetrieveRequest(ctx *schemas.BifrostContext, req *schemas.BifrostFileRetrieveRequest) (*schemas.BifrostFileRetrieveResponse, *schemas.BifrostError) {
	// invalid builds a validation error tagged with this method's request type.
	// req is only dereferenced when withProvider is true, i.e. after the nil check.
	invalid := func(msg string, withProvider bool) *schemas.BifrostError {
		bifrostErr := &schemas.BifrostError{
			IsBifrostError: false,
			Error:          &schemas.ErrorField{Message: msg},
			ExtraFields: schemas.BifrostErrorExtraFields{
				RequestType: schemas.FileRetrieveRequest,
			},
		}
		if withProvider {
			bifrostErr.ExtraFields.Provider = req.Provider
		}
		return bifrostErr
	}

	switch {
	case req == nil:
		return nil, invalid("file retrieve request is nil", false)
	case req.Provider == "":
		return nil, invalid("provider is required for file retrieve request", false)
	case req.FileID == "":
		return nil, invalid("file_id is required for file retrieve request", true)
	}

	if ctx == nil {
		ctx = bifrost.ctx
	}

	bifrostReq := bifrost.getBifrostRequest()
	bifrostReq.RequestType = schemas.FileRetrieveRequest
	bifrostReq.FileRetrieveRequest = req

	response, err := bifrost.handleRequest(ctx, bifrostReq)
	if err != nil {
		return nil, err
	}
	return response.FileRetrieveResponse, nil
}
// FileDeleteRequest deletes a file from the specified provider.
//
// Validation errors are tagged with the request type (and with the provider
// once it is known to be non-empty), matching the sibling batch and file
// upload/list methods.
//
// Parameters:
//   - ctx: Request context; bifrost.ctx is used when nil
//   - req: Delete request; must be non-nil with Provider and FileID set
//
// Returns:
//   - *schemas.BifrostFileDeleteResponse: The FileDeleteResponse of the handled request
//   - *schemas.BifrostError: A validation error, or the error returned by handleRequest
func (bifrost *Bifrost) FileDeleteRequest(ctx *schemas.BifrostContext, req *schemas.BifrostFileDeleteRequest) (*schemas.BifrostFileDeleteResponse, *schemas.BifrostError) {
	// invalid builds a validation error tagged with this method's request type.
	// req is only dereferenced when withProvider is true, i.e. after the nil check.
	invalid := func(msg string, withProvider bool) *schemas.BifrostError {
		bifrostErr := &schemas.BifrostError{
			IsBifrostError: false,
			Error:          &schemas.ErrorField{Message: msg},
			ExtraFields: schemas.BifrostErrorExtraFields{
				RequestType: schemas.FileDeleteRequest,
			},
		}
		if withProvider {
			bifrostErr.ExtraFields.Provider = req.Provider
		}
		return bifrostErr
	}

	switch {
	case req == nil:
		return nil, invalid("file delete request is nil", false)
	case req.Provider == "":
		return nil, invalid("provider is required for file delete request", false)
	case req.FileID == "":
		return nil, invalid("file_id is required for file delete request", true)
	}

	if ctx == nil {
		ctx = bifrost.ctx
	}

	bifrostReq := bifrost.getBifrostRequest()
	bifrostReq.RequestType = schemas.FileDeleteRequest
	bifrostReq.FileDeleteRequest = req

	response, err := bifrost.handleRequest(ctx, bifrostReq)
	if err != nil {
		return nil, err
	}
	return response.FileDeleteResponse, nil
}
// FileContentRequest downloads file content from the specified provider.
//
// Validation errors are tagged with the request type (and with the provider
// once it is known to be non-empty), matching the sibling batch and file
// upload/list methods.
//
// Parameters:
//   - ctx: Request context; bifrost.ctx is used when nil
//   - req: Content request; must be non-nil with Provider and FileID set
//
// Returns:
//   - *schemas.BifrostFileContentResponse: The FileContentResponse of the handled request
//   - *schemas.BifrostError: A validation error, or the error returned by handleRequest
func (bifrost *Bifrost) FileContentRequest(ctx *schemas.BifrostContext, req *schemas.BifrostFileContentRequest) (*schemas.BifrostFileContentResponse, *schemas.BifrostError) {
	// invalid builds a validation error tagged with this method's request type.
	// req is only dereferenced when withProvider is true, i.e. after the nil check.
	invalid := func(msg string, withProvider bool) *schemas.BifrostError {
		bifrostErr := &schemas.BifrostError{
			IsBifrostError: false,
			Error:          &schemas.ErrorField{Message: msg},
			ExtraFields: schemas.BifrostErrorExtraFields{
				RequestType: schemas.FileContentRequest,
			},
		}
		if withProvider {
			bifrostErr.ExtraFields.Provider = req.Provider
		}
		return bifrostErr
	}

	switch {
	case req == nil:
		return nil, invalid("file content request is nil", false)
	case req.Provider == "":
		return nil, invalid("provider is required for file content request", false)
	case req.FileID == "":
		return nil, invalid("file_id is required for file content request", true)
	}

	if ctx == nil {
		ctx = bifrost.ctx
	}

	bifrostReq := bifrost.getBifrostRequest()
	bifrostReq.RequestType = schemas.FileContentRequest
	bifrostReq.FileContentRequest = req

	response, err := bifrost.handleRequest(ctx, bifrostReq)
	if err != nil {
		return nil, err
	}
	return response.FileContentResponse, nil
}
// Passthrough dispatches a passthrough request to the given provider and
// returns the provider's passthrough response.
//
// The provider argument overwrites req.Provider before dispatch, so the caller's
// request value is mutated.
//
// Parameters:
//   - ctx: Request context; bifrost.ctx is used when nil, as in the other request methods
//   - provider: Provider the request is addressed to
//   - req: Passthrough request; must be non-nil
//
// Returns:
//   - *schemas.BifrostPassthroughResponse: The PassthroughResponse of the handled request
//   - *schemas.BifrostError: A 400-status error when req is nil, the error returned by
//     handleRequest, or a 502-status error when no passthrough response was produced
func (bifrost *Bifrost) Passthrough(
	ctx *schemas.BifrostContext,
	provider schemas.ModelProvider,
	req *schemas.BifrostPassthroughRequest,
) (*schemas.BifrostPassthroughResponse, *schemas.BifrostError) {
	if req == nil {
		sc := fasthttp.StatusBadRequest
		return nil, &schemas.BifrostError{
			IsBifrostError: false,
			StatusCode:     &sc,
			Error:          &schemas.ErrorField{Message: "passthrough request is nil"},
		}
	}
	// Fall back to the instance context so a nil ctx is never handed downstream.
	if ctx == nil {
		ctx = bifrost.ctx
	}
	req.Provider = provider

	bifrostReq := bifrost.getBifrostRequest()
	bifrostReq.RequestType = schemas.PassthroughRequest
	bifrostReq.PassthroughRequest = req

	resp, bifrostErr := bifrost.handleRequest(ctx, bifrostReq)
	if bifrostErr != nil {
		return nil, bifrostErr
	}
	// A nil error with no passthrough payload is reported as a bad gateway
	// rather than returned as a nil response.
	if resp == nil || resp.PassthroughResponse == nil {
		sc := fasthttp.StatusBadGateway
		return nil, &schemas.BifrostError{
			IsBifrostError: false,
			StatusCode:     &sc,
			Error:          &schemas.ErrorField{Message: "provider returned nil passthrough response"},
		}
	}
	return resp.PassthroughResponse, nil
}
// PassthroughStream dispatches a streaming passthrough request to the given
// provider and returns the channel of stream chunks produced by handleStreamRequest.
//
// The provider argument overwrites req.Provider before dispatch, so the caller's
// request value is mutated.
//
// Parameters:
//   - ctx: Request context; bifrost.ctx is used when nil, as in the other request methods
//   - provider: Provider the request is addressed to
//   - req: Passthrough request; must be non-nil
//
// Returns:
//   - chan *schemas.BifrostStreamChunk: Stream of chunks for the request
//   - *schemas.BifrostError: A 400-status error when req is nil, or the error
//     returned by handleStreamRequest
func (bifrost *Bifrost) PassthroughStream(
	ctx *schemas.BifrostContext,
	provider schemas.ModelProvider,
	req *schemas.BifrostPassthroughRequest,
) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
	if req == nil {
		sc := fasthttp.StatusBadRequest
		return nil, &schemas.BifrostError{
			IsBifrostError: false,
			StatusCode:     &sc,
			Error:          &schemas.ErrorField{Message: "passthrough request is nil"},
		}
	}
	// Fall back to the instance context so a nil ctx is never handed downstream.
	if ctx == nil {
		ctx = bifrost.ctx
	}
	req.Provider = provider

	bifrostReq := bifrost.getBifrostRequest()
	bifrostReq.RequestType = schemas.PassthroughStreamRequest
	bifrostReq.PassthroughRequest = req

	return bifrost.handleStreamRequest(ctx, bifrostReq)
}
// ExecuteChatMCPTool runs a single MCP tool call and returns its result as a
// Chat-format tool message. It is the public entry point for manual MCP tool
// execution in Chat format.
//
// Parameters:
//   - ctx: Execution context; bifrost.ctx is used when nil
//   - toolCall: The tool call to execute (from an assistant message); must be non-nil
//
// Returns:
//   - *schemas.ChatMessage: Tool message carrying the execution result
//   - *schemas.BifrostError: A validation error, the error returned by the MCP
//     handler, or an error when the handler yields no chat message
func (bifrost *Bifrost) ExecuteChatMCPTool(ctx *schemas.BifrostContext, toolCall *schemas.ChatAssistantMessageToolCall) (*schemas.ChatMessage, *schemas.BifrostError) {
	// chatToolError builds an error tagged as a chat completion request.
	chatToolError := func(msg string) *schemas.BifrostError {
		return &schemas.BifrostError{
			IsBifrostError: false,
			Error:          &schemas.ErrorField{Message: msg},
			ExtraFields: schemas.BifrostErrorExtraFields{
				RequestType: schemas.ChatCompletionRequest,
			},
		}
	}

	if toolCall == nil {
		return nil, chatToolError("toolCall cannot be nil")
	}
	if ctx == nil {
		ctx = bifrost.ctx
	}

	// Borrow a pooled MCP request and hand it back when this call returns.
	pooled := bifrost.getMCPRequest()
	defer bifrost.releaseMCPRequest(pooled)
	pooled.RequestType = schemas.MCPRequestTypeChatToolCall
	pooled.ChatAssistantMessageToolCall = toolCall

	outcome, execErr := bifrost.handleMCPToolExecution(ctx, pooled, schemas.ChatCompletionRequest)
	if execErr != nil {
		return nil, execErr
	}
	if outcome != nil && outcome.ChatMessage != nil {
		return outcome.ChatMessage, nil
	}
	return nil, chatToolError("MCP tool execution returned nil chat message")
}
// ExecuteResponsesMCPTool runs a single MCP tool call and returns its result as
// a Responses-format message. It is the public entry point for manual MCP tool
// execution in Responses format.
//
// Parameters:
//   - ctx: Execution context; bifrost.ctx is used when nil
//   - toolCall: The tool call to execute (from an assistant message); must be non-nil
//
// Returns:
//   - *schemas.ResponsesMessage: Tool message carrying the execution result
//   - *schemas.BifrostError: A validation error, the error returned by the MCP
//     handler, or an error when the handler yields no responses message
func (bifrost *Bifrost) ExecuteResponsesMCPTool(ctx *schemas.BifrostContext, toolCall *schemas.ResponsesToolMessage) (*schemas.ResponsesMessage, *schemas.BifrostError) {
	// responsesToolError builds an error tagged as a responses request.
	responsesToolError := func(msg string) *schemas.BifrostError {
		return &schemas.BifrostError{
			IsBifrostError: false,
			Error:          &schemas.ErrorField{Message: msg},
			ExtraFields: schemas.BifrostErrorExtraFields{
				RequestType: schemas.ResponsesRequest,
			},
		}
	}

	if toolCall == nil {
		return nil, responsesToolError("toolCall cannot be nil")
	}
	if ctx == nil {
		ctx = bifrost.ctx
	}

	// Borrow a pooled MCP request and hand it back when this call returns.
	pooled := bifrost.getMCPRequest()
	defer bifrost.releaseMCPRequest(pooled)
	pooled.RequestType = schemas.MCPRequestTypeResponsesToolCall
	pooled.ResponsesToolMessage = toolCall

	outcome, execErr := bifrost.handleMCPToolExecution(ctx, pooled, schemas.ResponsesRequest)
	if execErr != nil {
		return nil, execErr
	}
	if outcome != nil && outcome.ResponsesMessage != nil {
		return outcome.ResponsesMessage, nil
	}
	return nil, responsesToolError("MCP tool execution returned nil responses message")
}
// ContainerCreateRequest creates a new container.
//
// Validation errors are tagged with the request type (and with the provider
// once it is known to be non-empty), matching the batch and file methods.
//
// Parameters:
//   - ctx: Request context; bifrost.ctx is used when nil
//   - req: Create request; must be non-nil with Provider and Name set
//
// Returns:
//   - *schemas.BifrostContainerCreateResponse: The ContainerCreateResponse of the handled request
//   - *schemas.BifrostError: A validation error, or the error returned by handleRequest
func (bifrost *Bifrost) ContainerCreateRequest(ctx *schemas.BifrostContext, req *schemas.BifrostContainerCreateRequest) (*schemas.BifrostContainerCreateResponse, *schemas.BifrostError) {
	// invalid builds a validation error tagged with this method's request type.
	// req is only dereferenced when withProvider is true, i.e. after the nil check.
	invalid := func(msg string, withProvider bool) *schemas.BifrostError {
		bifrostErr := &schemas.BifrostError{
			IsBifrostError: false,
			Error:          &schemas.ErrorField{Message: msg},
			ExtraFields: schemas.BifrostErrorExtraFields{
				RequestType: schemas.ContainerCreateRequest,
			},
		}
		if withProvider {
			bifrostErr.ExtraFields.Provider = req.Provider
		}
		return bifrostErr
	}

	switch {
	case req == nil:
		return nil, invalid("container create request is nil", false)
	case req.Provider == "":
		return nil, invalid("provider is required for container create request", false)
	case req.Name == "":
		return nil, invalid("name is required for container create request", true)
	}

	if ctx == nil {
		ctx = bifrost.ctx
	}

	bifrostReq := bifrost.getBifrostRequest()
	bifrostReq.RequestType = schemas.ContainerCreateRequest
	bifrostReq.ContainerCreateRequest = req

	response, err := bifrost.handleRequest(ctx, bifrostReq)
	if err != nil {
		return nil, err
	}
	return response.ContainerCreateResponse, nil
}
// ContainerListRequest lists containers.
//
// Validation errors are tagged with the request type, matching the batch and
// file methods.
//
// Parameters:
//   - ctx: Request context; bifrost.ctx is used when nil
//   - req: List request; must be non-nil with Provider set
//
// Returns:
//   - *schemas.BifrostContainerListResponse: The ContainerListResponse of the handled request
//   - *schemas.BifrostError: A validation error, or the error returned by handleRequest
func (bifrost *Bifrost) ContainerListRequest(ctx *schemas.BifrostContext, req *schemas.BifrostContainerListRequest) (*schemas.BifrostContainerListResponse, *schemas.BifrostError) {
	// invalid builds a validation error tagged with this method's request type.
	invalid := func(msg string) *schemas.BifrostError {
		return &schemas.BifrostError{
			IsBifrostError: false,
			Error:          &schemas.ErrorField{Message: msg},
			ExtraFields: schemas.BifrostErrorExtraFields{
				RequestType: schemas.ContainerListRequest,
			},
		}
	}

	switch {
	case req == nil:
		return nil, invalid("container list request is nil")
	case req.Provider == "":
		return nil, invalid("provider is required for container list request")
	}

	if ctx == nil {
		ctx = bifrost.ctx
	}

	bifrostReq := bifrost.getBifrostRequest()
	bifrostReq.RequestType = schemas.ContainerListRequest
	bifrostReq.ContainerListRequest = req

	response, err := bifrost.handleRequest(ctx, bifrostReq)
	if err != nil {
		return nil, err
	}
	return response.ContainerListResponse, nil
}
// ContainerRetrieveRequest retrieves a specific container.
//
// Validation errors are tagged with the request type (and with the provider
// once it is known to be non-empty), matching the batch and file methods.
//
// Parameters:
//   - ctx: Request context; bifrost.ctx is used when nil
//   - req: Retrieve request; must be non-nil with Provider and ContainerID set
//
// Returns:
//   - *schemas.BifrostContainerRetrieveResponse: The ContainerRetrieveResponse of the handled request
//   - *schemas.BifrostError: A validation error, or the error returned by handleRequest
func (bifrost *Bifrost) ContainerRetrieveRequest(ctx *schemas.BifrostContext, req *schemas.BifrostContainerRetrieveRequest) (*schemas.BifrostContainerRetrieveResponse, *schemas.BifrostError) {
	// invalid builds a validation error tagged with this method's request type.
	// req is only dereferenced when withProvider is true, i.e. after the nil check.
	invalid := func(msg string, withProvider bool) *schemas.BifrostError {
		bifrostErr := &schemas.BifrostError{
			IsBifrostError: false,
			Error:          &schemas.ErrorField{Message: msg},
			ExtraFields: schemas.BifrostErrorExtraFields{
				RequestType: schemas.ContainerRetrieveRequest,
			},
		}
		if withProvider {
			bifrostErr.ExtraFields.Provider = req.Provider
		}
		return bifrostErr
	}

	switch {
	case req == nil:
		return nil, invalid("container retrieve request is nil", false)
	case req.Provider == "":
		return nil, invalid("provider is required for container retrieve request", false)
	case req.ContainerID == "":
		return nil, invalid("container_id is required for container retrieve request", true)
	}

	if ctx == nil {
		ctx = bifrost.ctx
	}

	bifrostReq := bifrost.getBifrostRequest()
	bifrostReq.RequestType = schemas.ContainerRetrieveRequest
	bifrostReq.ContainerRetrieveRequest = req

	response, err := bifrost.handleRequest(ctx, bifrostReq)
	if err != nil {
		return nil, err
	}
	return response.ContainerRetrieveResponse, nil
}
// ContainerDeleteRequest deletes a container.
//
// Validation errors are tagged with the request type (and with the provider
// once it is known to be non-empty), matching the batch and file methods.
//
// Parameters:
//   - ctx: Request context; bifrost.ctx is used when nil
//   - req: Delete request; must be non-nil with Provider and ContainerID set
//
// Returns:
//   - *schemas.BifrostContainerDeleteResponse: The ContainerDeleteResponse of the handled request
//   - *schemas.BifrostError: A validation error, or the error returned by handleRequest
func (bifrost *Bifrost) ContainerDeleteRequest(ctx *schemas.BifrostContext, req *schemas.BifrostContainerDeleteRequest) (*schemas.BifrostContainerDeleteResponse, *schemas.BifrostError) {
	// invalid builds a validation error tagged with this method's request type.
	// req is only dereferenced when withProvider is true, i.e. after the nil check.
	invalid := func(msg string, withProvider bool) *schemas.BifrostError {
		bifrostErr := &schemas.BifrostError{
			IsBifrostError: false,
			Error:          &schemas.ErrorField{Message: msg},
			ExtraFields: schemas.BifrostErrorExtraFields{
				RequestType: schemas.ContainerDeleteRequest,
			},
		}
		if withProvider {
			bifrostErr.ExtraFields.Provider = req.Provider
		}
		return bifrostErr
	}

	switch {
	case req == nil:
		return nil, invalid("container delete request is nil", false)
	case req.Provider == "":
		return nil, invalid("provider is required for container delete request", false)
	case req.ContainerID == "":
		return nil, invalid("container_id is required for container delete request", true)
	}

	if ctx == nil {
		ctx = bifrost.ctx
	}

	bifrostReq := bifrost.getBifrostRequest()
	bifrostReq.RequestType = schemas.ContainerDeleteRequest
	bifrostReq.ContainerDeleteRequest = req

	response, err := bifrost.handleRequest(ctx, bifrostReq)
	if err != nil {
		return nil, err
	}
	return response.ContainerDeleteResponse, nil
}
// ContainerFileCreateRequest creates a file in a container.
//
// The file source may be given as inline content (File), a file ID, or a path;
// at least one must be present, where a file ID or path consisting only of
// whitespace counts as absent. Validation errors are tagged with the request
// type (and with the provider once it is known to be non-empty), matching the
// batch and file methods.
//
// Parameters:
//   - ctx: Request context; bifrost.ctx is used when nil
//   - req: Create request; must be non-nil with Provider, ContainerID and a file source set
//
// Returns:
//   - *schemas.BifrostContainerFileCreateResponse: The ContainerFileCreateResponse of the handled request
//   - *schemas.BifrostError: A validation error, or the error returned by handleRequest
func (bifrost *Bifrost) ContainerFileCreateRequest(ctx *schemas.BifrostContext, req *schemas.BifrostContainerFileCreateRequest) (*schemas.BifrostContainerFileCreateResponse, *schemas.BifrostError) {
	// invalid builds a validation error tagged with this method's request type.
	// req is only dereferenced when withProvider is true, i.e. after the nil check.
	invalid := func(msg string, withProvider bool) *schemas.BifrostError {
		bifrostErr := &schemas.BifrostError{
			IsBifrostError: false,
			Error:          &schemas.ErrorField{Message: msg},
			ExtraFields: schemas.BifrostErrorExtraFields{
				RequestType: schemas.ContainerFileCreateRequest,
			},
		}
		if withProvider {
			bifrostErr.ExtraFields.Provider = req.Provider
		}
		return bifrostErr
	}

	switch {
	case req == nil:
		return nil, invalid("container file create request is nil", false)
	case req.Provider == "":
		return nil, invalid("provider is required for container file create request", false)
	case req.ContainerID == "":
		return nil, invalid("container_id is required for container file create request", true)
	}

	hasFile := len(req.File) > 0
	hasFileID := req.FileID != nil && strings.TrimSpace(*req.FileID) != ""
	hasPath := req.Path != nil && strings.TrimSpace(*req.Path) != ""
	if !hasFile && !hasFileID && !hasPath {
		return nil, invalid("one of file, file_id, or path is required for container file create request", true)
	}

	if ctx == nil {
		ctx = bifrost.ctx
	}

	bifrostReq := bifrost.getBifrostRequest()
	bifrostReq.RequestType = schemas.ContainerFileCreateRequest
	bifrostReq.ContainerFileCreateRequest = req

	response, err := bifrost.handleRequest(ctx, bifrostReq)
	if err != nil {
		return nil, err
	}
	return response.ContainerFileCreateResponse, nil
}
// ContainerFileListRequest lists files in a container.
//
// Validation errors are tagged with the request type (and with the provider
// once it is known to be non-empty), matching the batch and file methods.
//
// Parameters:
//   - ctx: Request context; bifrost.ctx is used when nil
//   - req: List request; must be non-nil with Provider and ContainerID set
//
// Returns:
//   - *schemas.BifrostContainerFileListResponse: The ContainerFileListResponse of the handled request
//   - *schemas.BifrostError: A validation error, or the error returned by handleRequest
func (bifrost *Bifrost) ContainerFileListRequest(ctx *schemas.BifrostContext, req *schemas.BifrostContainerFileListRequest) (*schemas.BifrostContainerFileListResponse, *schemas.BifrostError) {
	// invalid builds a validation error tagged with this method's request type.
	// req is only dereferenced when withProvider is true, i.e. after the nil check.
	invalid := func(msg string, withProvider bool) *schemas.BifrostError {
		bifrostErr := &schemas.BifrostError{
			IsBifrostError: false,
			Error:          &schemas.ErrorField{Message: msg},
			ExtraFields: schemas.BifrostErrorExtraFields{
				RequestType: schemas.ContainerFileListRequest,
			},
		}
		if withProvider {
			bifrostErr.ExtraFields.Provider = req.Provider
		}
		return bifrostErr
	}

	switch {
	case req == nil:
		return nil, invalid("container file list request is nil", false)
	case req.Provider == "":
		return nil, invalid("provider is required for container file list request", false)
	case req.ContainerID == "":
		return nil, invalid("container_id is required for container file list request", true)
	}

	if ctx == nil {
		ctx = bifrost.ctx
	}

	bifrostReq := bifrost.getBifrostRequest()
	bifrostReq.RequestType = schemas.ContainerFileListRequest
	bifrostReq.ContainerFileListRequest = req

	response, err := bifrost.handleRequest(ctx, bifrostReq)
	if err != nil {
		return nil, err
	}
	return response.ContainerFileListResponse, nil
}
// ContainerFileRetrieveRequest retrieves a file from a container.
//
// Validation errors are tagged with the request type (and with the provider
// once it is known to be non-empty), matching the batch and file methods.
//
// Parameters:
//   - ctx: Request context; bifrost.ctx is used when nil
//   - req: Retrieve request; must be non-nil with Provider, ContainerID and FileID set
//
// Returns:
//   - *schemas.BifrostContainerFileRetrieveResponse: The ContainerFileRetrieveResponse of the handled request
//   - *schemas.BifrostError: A validation error, or the error returned by handleRequest
func (bifrost *Bifrost) ContainerFileRetrieveRequest(ctx *schemas.BifrostContext, req *schemas.BifrostContainerFileRetrieveRequest) (*schemas.BifrostContainerFileRetrieveResponse, *schemas.BifrostError) {
	// invalid builds a validation error tagged with this method's request type.
	// req is only dereferenced when withProvider is true, i.e. after the nil check.
	invalid := func(msg string, withProvider bool) *schemas.BifrostError {
		bifrostErr := &schemas.BifrostError{
			IsBifrostError: false,
			Error:          &schemas.ErrorField{Message: msg},
			ExtraFields: schemas.BifrostErrorExtraFields{
				RequestType: schemas.ContainerFileRetrieveRequest,
			},
		}
		if withProvider {
			bifrostErr.ExtraFields.Provider = req.Provider
		}
		return bifrostErr
	}

	switch {
	case req == nil:
		return nil, invalid("container file retrieve request is nil", false)
	case req.Provider == "":
		return nil, invalid("provider is required for container file retrieve request", false)
	case req.ContainerID == "":
		return nil, invalid("container_id is required for container file retrieve request", true)
	case req.FileID == "":
		return nil, invalid("file_id is required for container file retrieve request", true)
	}

	if ctx == nil {
		ctx = bifrost.ctx
	}

	bifrostReq := bifrost.getBifrostRequest()
	bifrostReq.RequestType = schemas.ContainerFileRetrieveRequest
	bifrostReq.ContainerFileRetrieveRequest = req

	response, err := bifrost.handleRequest(ctx, bifrostReq)
	if err != nil {
		return nil, err
	}
	return response.ContainerFileRetrieveResponse, nil
}
// ContainerFileContentRequest retrieves the content of a file from a container.
//
// Validation errors are tagged with the request type (and with the provider
// once it is known to be non-empty), matching the batch and file methods.
//
// Parameters:
//   - ctx: Request context; bifrost.ctx is used when nil
//   - req: Content request; must be non-nil with Provider, ContainerID and FileID set
//
// Returns:
//   - *schemas.BifrostContainerFileContentResponse: The ContainerFileContentResponse of the handled request
//   - *schemas.BifrostError: A validation error, or the error returned by handleRequest
func (bifrost *Bifrost) ContainerFileContentRequest(ctx *schemas.BifrostContext, req *schemas.BifrostContainerFileContentRequest) (*schemas.BifrostContainerFileContentResponse, *schemas.BifrostError) {
	// invalid builds a validation error tagged with this method's request type.
	// req is only dereferenced when withProvider is true, i.e. after the nil check.
	invalid := func(msg string, withProvider bool) *schemas.BifrostError {
		bifrostErr := &schemas.BifrostError{
			IsBifrostError: false,
			Error:          &schemas.ErrorField{Message: msg},
			ExtraFields: schemas.BifrostErrorExtraFields{
				RequestType: schemas.ContainerFileContentRequest,
			},
		}
		if withProvider {
			bifrostErr.ExtraFields.Provider = req.Provider
		}
		return bifrostErr
	}

	switch {
	case req == nil:
		return nil, invalid("container file content request is nil", false)
	case req.Provider == "":
		return nil, invalid("provider is required for container file content request", false)
	case req.ContainerID == "":
		return nil, invalid("container_id is required for container file content request", true)
	case req.FileID == "":
		return nil, invalid("file_id is required for container file content request", true)
	}

	if ctx == nil {
		ctx = bifrost.ctx
	}

	bifrostReq := bifrost.getBifrostRequest()
	bifrostReq.RequestType = schemas.ContainerFileContentRequest
	bifrostReq.ContainerFileContentRequest = req

	response, err := bifrost.handleRequest(ctx, bifrostReq)
	if err != nil {
		return nil, err
	}
	return response.ContainerFileContentResponse, nil
}
// ContainerFileDeleteRequest deletes a file from a container.
//
// Validation errors are tagged with the request type (and with the provider
// once it is known to be non-empty), matching the batch and file methods.
//
// Parameters:
//   - ctx: Request context; bifrost.ctx is used when nil
//   - req: Delete request; must be non-nil with Provider, ContainerID and FileID set
//
// Returns:
//   - *schemas.BifrostContainerFileDeleteResponse: The ContainerFileDeleteResponse of the handled request
//   - *schemas.BifrostError: A validation error, or the error returned by handleRequest
func (bifrost *Bifrost) ContainerFileDeleteRequest(ctx *schemas.BifrostContext, req *schemas.BifrostContainerFileDeleteRequest) (*schemas.BifrostContainerFileDeleteResponse, *schemas.BifrostError) {
	// invalid builds a validation error tagged with this method's request type.
	// req is only dereferenced when withProvider is true, i.e. after the nil check.
	invalid := func(msg string, withProvider bool) *schemas.BifrostError {
		bifrostErr := &schemas.BifrostError{
			IsBifrostError: false,
			Error:          &schemas.ErrorField{Message: msg},
			ExtraFields: schemas.BifrostErrorExtraFields{
				RequestType: schemas.ContainerFileDeleteRequest,
			},
		}
		if withProvider {
			bifrostErr.ExtraFields.Provider = req.Provider
		}
		return bifrostErr
	}

	switch {
	case req == nil:
		return nil, invalid("container file delete request is nil", false)
	case req.Provider == "":
		return nil, invalid("provider is required for container file delete request", false)
	case req.ContainerID == "":
		return nil, invalid("container_id is required for container file delete request", true)
	case req.FileID == "":
		return nil, invalid("file_id is required for container file delete request", true)
	}

	if ctx == nil {
		ctx = bifrost.ctx
	}

	bifrostReq := bifrost.getBifrostRequest()
	bifrostReq.RequestType = schemas.ContainerFileDeleteRequest
	bifrostReq.ContainerFileDeleteRequest = req

	response, err := bifrost.handleRequest(ctx, bifrostReq)
	if err != nil {
		return nil, err
	}
	return response.ContainerFileDeleteResponse, nil
}
// RemovePlugin removes the plugin with the given name from each of the listed
// plugin kinds. Kinds are processed in order and the first removal error is
// returned immediately; kinds other than LLM and MCP are ignored.
func (bifrost *Bifrost) RemovePlugin(name string, pluginTypes []schemas.PluginType) error {
	for _, kind := range pluginTypes {
		var removeErr error
		switch kind {
		case schemas.PluginTypeLLM:
			removeErr = bifrost.removeLLMPlugin(name)
		case schemas.PluginTypeMCP:
			removeErr = bifrost.removeMCPPlugin(name)
		}
		if removeErr != nil {
			return removeErr
		}
	}
	return nil
}
// removeLLMPlugin drops every LLM plugin named name from the plugin list.
//
// The list is rebuilt without the matching entries and published with a
// compare-and-swap; the rebuild is retried whenever the swap fails. After a
// successful swap the removed plugin's Cleanup is invoked, and a Cleanup failure
// is logged rather than returned. When no list is stored or no plugin matches,
// nothing changes. The returned error is always nil.
func (bifrost *Bifrost) removeLLMPlugin(name string) error {
	for {
		current := bifrost.llmPlugins.Load()
		if current == nil {
			return nil
		}

		remaining := make([]schemas.LLMPlugin, 0, len(*current))
		var removed schemas.LLMPlugin
		matched := false
		for _, candidate := range *current {
			if candidate.GetName() != name {
				remaining = append(remaining, candidate)
				continue
			}
			removed = candidate
			bifrost.logger.Debug("removing LLM plugin %s", name)
			matched = true
		}
		if !matched {
			return nil
		}

		if !bifrost.llmPlugins.CompareAndSwap(current, &remaining) {
			// The stored list changed since it was loaded; rebuild from a fresh snapshot.
			continue
		}

		if cleanupErr := removed.Cleanup(); cleanupErr != nil {
			bifrost.logger.Warn("failed to cleanup old LLM plugin %s: %v", removed.GetName(), cleanupErr)
		}
		return nil
	}
}
// removeMCPPlugin drops every MCP plugin named name from the plugin list.
//
// The list is rebuilt without the matching entries and published with a
// compare-and-swap; the rebuild is retried whenever the swap fails. After a
// successful swap the removed plugin's Cleanup is invoked, and a Cleanup failure
// is logged rather than returned. When no list is stored or no plugin matches,
// nothing changes. The returned error is always nil.
func (bifrost *Bifrost) removeMCPPlugin(name string) error {
	for {
		current := bifrost.mcpPlugins.Load()
		if current == nil {
			return nil
		}

		remaining := make([]schemas.MCPPlugin, 0, len(*current))
		var removed schemas.MCPPlugin
		matched := false
		for _, candidate := range *current {
			if candidate.GetName() != name {
				remaining = append(remaining, candidate)
				continue
			}
			removed = candidate
			bifrost.logger.Debug("removing MCP plugin %s", name)
			matched = true
		}
		if !matched {
			return nil
		}

		if !bifrost.mcpPlugins.CompareAndSwap(current, &remaining) {
			// The stored list changed since it was loaded; rebuild from a fresh snapshot.
			continue
		}

		if cleanupErr := removed.Cleanup(); cleanupErr != nil {
			bifrost.logger.Warn("failed to cleanup old MCP plugin %s: %v", removed.GetName(), cleanupErr)
		}
		return nil
	}
}
// ReloadPlugin installs a new plugin instance for each of the listed plugin kinds.
//
// For every kind the plugin must implement the matching interface
// (schemas.LLMPlugin or schemas.MCPPlugin); otherwise an error is returned.
// Kinds are processed in order and processing stops at the first error, so
// kinds handled before the failure stay reloaded. Kinds other than LLM and MCP
// are ignored. The actual list update is delegated to reloadLLMPlugin and
// reloadMCPPlugin.
func (bifrost *Bifrost) ReloadPlugin(plugin schemas.BasePlugin, pluginTypes []schemas.PluginType) error {
	for _, kind := range pluginTypes {
		if kind == schemas.PluginTypeLLM {
			asLLM, isLLM := plugin.(schemas.LLMPlugin)
			if !isLLM {
				return fmt.Errorf("plugin %s is not an LLMPlugin", plugin.GetName())
			}
			if reloadErr := bifrost.reloadLLMPlugin(asLLM); reloadErr != nil {
				return reloadErr
			}
			continue
		}
		if kind == schemas.PluginTypeMCP {
			asMCP, isMCP := plugin.(schemas.MCPPlugin)
			if !isMCP {
				return fmt.Errorf("plugin %s is not an MCPPlugin", plugin.GetName())
			}
			if reloadErr := bifrost.reloadMCPPlugin(asMCP); reloadErr != nil {
				return reloadErr
			}
		}
	}
	return nil
}
// reloadLLMPlugin installs plugin into the LLM plugin list.
//
// A copy of the current list is taken (an empty list when none is stored). The
// first entry with the same name is replaced; if there is none, plugin is
// appended. The copy is published with a compare-and-swap, and the whole step is
// retried whenever the swap fails. After a successful swap the replaced
// instance's Cleanup is invoked; a Cleanup failure is logged rather than
// returned. The returned error is always nil.
func (bifrost *Bifrost) reloadLLMPlugin(plugin schemas.LLMPlugin) error {
	for {
		current := bifrost.llmPlugins.Load()

		var next []schemas.LLMPlugin
		if current != nil {
			next = make([]schemas.LLMPlugin, len(*current))
			copy(next, *current)
		} else {
			// No list stored yet: start from an empty one so the plugin can be added.
			next = make([]schemas.LLMPlugin, 0)
		}

		// Locate the first plugin sharing the incoming plugin's name.
		idx := slices.IndexFunc(next, func(p schemas.LLMPlugin) bool {
			return p.GetName() == plugin.GetName()
		})

		var previous schemas.LLMPlugin
		if idx >= 0 {
			previous = next[idx]
			bifrost.logger.Debug("replacing LLM plugin %s with new instance", plugin.GetName())
			next[idx] = plugin
		} else {
			bifrost.logger.Debug("adding new LLM plugin %s", plugin.GetName())
			next = append(next, plugin)
		}

		if !bifrost.llmPlugins.CompareAndSwap(current, &next) {
			// The stored list changed since it was loaded; redo the work on a fresh snapshot.
			continue
		}

		// The old instance is cleaned up only after the new list is visible.
		if idx >= 0 && previous != nil {
			if cleanupErr := previous.Cleanup(); cleanupErr != nil {
				bifrost.logger.Warn("failed to cleanup old LLM plugin %s: %v", previous.GetName(), cleanupErr)
			}
		}
		return nil
	}
}
// reloadMCPPlugin installs plugin into the MCP plugin list.
//
// A copy of the current list is taken. The first entry with the same name is
// replaced; if there is none, plugin is appended. The copy is published with a
// compare-and-swap, and the whole step is retried whenever the swap fails. After
// a successful swap the replaced instance's Cleanup is invoked; a Cleanup failure
// is logged rather than returned. The returned error is always nil.
//
// When no list has been stored yet, the plugin is added to a fresh empty list
// (the same handling reloadLLMPlugin uses) instead of being silently discarded.
func (bifrost *Bifrost) reloadMCPPlugin(plugin schemas.MCPPlugin) error {
	for {
		var pluginToCleanup schemas.MCPPlugin
		found := false
		oldPlugins := bifrost.mcpPlugins.Load()

		// Create a new slice with the replaced plugin, or start from an empty
		// slice when nothing is stored yet so the first plugin is still added.
		var newPlugins []schemas.MCPPlugin
		if oldPlugins == nil {
			newPlugins = make([]schemas.MCPPlugin, 0)
		} else {
			newPlugins = make([]schemas.MCPPlugin, len(*oldPlugins))
			copy(newPlugins, *oldPlugins)
		}

		for i, p := range newPlugins {
			if p.GetName() == plugin.GetName() {
				// Remember the old instance; it is cleaned up after the swap succeeds.
				pluginToCleanup = p
				bifrost.logger.Debug("replacing MCP plugin %s with new instance", plugin.GetName())
				newPlugins[i] = plugin
				found = true
				break
			}
		}
		if !found {
			// No plugin with this name exists, so this is an addition.
			bifrost.logger.Debug("adding new MCP plugin %s", plugin.GetName())
			newPlugins = append(newPlugins, plugin)
		}

		// Atomic compare-and-swap; oldPlugins may be nil, in which case the swap
		// only succeeds if the pointer is still unset.
		if bifrost.mcpPlugins.CompareAndSwap(oldPlugins, &newPlugins) {
			if found && pluginToCleanup != nil {
				if err := pluginToCleanup.Cleanup(); err != nil {
					bifrost.logger.Warn("failed to cleanup old MCP plugin %s: %v", pluginToCleanup.GetName(), err)
				}
			}
			return nil
		}
		// The swap failed because the stored list changed; retry on a fresh snapshot.
	}
}
// ReorderPlugins rearranges both plugin lists (LLM and MCP) to follow the given
// ordering of base plugin names. It is meant to be called after
// SortAndRebuildPlugins on the config layer so the core's execution order stays
// in sync. Plugins whose name is missing from orderedNames end up after all
// listed ones (defensive). If a name appears more than once, its last index is
// the one used.
func (bifrost *Bifrost) ReorderPlugins(orderedNames []string) {
	rank := make(map[string]int, len(orderedNames))
	for idx := range orderedNames {
		rank[orderedNames[idx]] = idx
	}

	reorderAtomicSlice(&bifrost.llmPlugins, rank)
	reorderAtomicSlice(&bifrost.mcpPlugins, rank)
}
// pluginWithName is the constraint used by reorderAtomicSlice: any type that
// can report its name. Both LLMPlugin and MCPPlugin satisfy it.
type pluginWithName interface {
	GetName() string
}

// reorderAtomicSlice replaces the slice stored behind ptr with a copy sorted by
// the positions in pos (name → position).
//
// Elements whose name is absent from pos are placed after all known ones; the
// sort is stable, so ties and unknown elements keep their relative order. The
// stored slice itself is never mutated: a sorted copy is published with a
// compare-and-swap, and the work is redone if the swap fails. A nil pointer or
// an empty slice is left untouched.
func reorderAtomicSlice[T pluginWithName](ptr *atomic.Pointer[[]T], pos map[string]int) {
	for {
		current := ptr.Load()
		if current == nil || len(*current) == 0 {
			return
		}

		sorted := make([]T, len(*current))
		copy(sorted, *current)

		sort.SliceStable(sorted, func(a, b int) bool {
			aPos, aKnown := pos[sorted[a].GetName()]
			bPos, bKnown := pos[sorted[b].GetName()]
			switch {
			case !aKnown:
				// An unknown element never sorts ahead of anything.
				return false
			case !bKnown:
				// A known element always precedes an unknown one.
				return true
			default:
				return aPos < bPos
			}
		})

		if ptr.CompareAndSwap(current, &sorted) {
			return
		}
	}
}
// GetConfiguredProviders returns the keys of all currently configured providers.
//
// Returns:
//   - []schemas.ModelProvider: One key per configured provider, in stored order
//   - error: Non-nil when no provider list has been stored
//
// Example:
//
//	providers, err := bifrost.GetConfiguredProviders()
//	if err != nil {
//		return nil, err
//	}
//	fmt.Println(providers)
func (bifrost *Bifrost) GetConfiguredProviders() ([]schemas.ModelProvider, error) {
	loaded := bifrost.providers.Load()
	if loaded == nil {
		return nil, fmt.Errorf("no providers configured")
	}

	keys := make([]schemas.ModelProvider, 0, len(*loaded))
	for _, p := range *loaded {
		keys = append(keys, p.GetProviderKey())
	}
	return keys, nil
}
// RemoveProvider removes a provider from the server.
// This method signals the provider's queue to close, waits for the provider's
// workers to finish, drains whatever is left in the queue with errors, and then
// removes the provider from the request queues, the wait groups and the
// providers slice. The per-provider mutex is held for the whole operation.
//
// Parameters:
// - providerKey: The provider to remove
//
// Returns:
// - error: Non-nil when the provider has no request queue, or when it could not
// be removed from the providers slice
func (bifrost *Bifrost) RemoveProvider(providerKey schemas.ModelProvider) error {
bifrost.logger.Info("Removing provider %s", providerKey)
// Hold the per-provider write lock until this function returns.
providerMutex := bifrost.getProviderMutex(providerKey)
providerMutex.Lock()
defer providerMutex.Unlock()
// Step 1: Load the ProviderQueue and verify provider exists
pqValue, exists := bifrost.requestQueues.Load(providerKey)
if !exists {
return fmt.Errorf("provider %s not found in request queues", providerKey)
}
// Unchecked type assertion: this panics if the stored value is not a *ProviderQueue.
pq := pqValue.(*ProviderQueue)
// Step 2: Signal closing. Blocks new producers (isClosing() returns true) and
// causes idle workers to drain remaining buffered requests with errors then exit.
pq.signalClosing()
bifrost.logger.Debug("signaled closing for provider %s", providerKey)
// Step 3: Wait for all workers to finish in-flight requests and exit.
// A missing wait group entry is tolerated and simply skips the wait.
waitGroup, exists := bifrost.waitGroups.Load(providerKey)
if exists {
waitGroup.(*sync.WaitGroup).Wait()
bifrost.logger.Debug("all workers for provider %s have stopped", providerKey)
}
// Step 3b: Final drain sweep — see drainQueueWithErrors for full explanation.
bifrost.drainQueueWithErrors(pq)
// Step 4: Remove the provider from the request queues.
bifrost.requestQueues.Delete(providerKey)
// Step 5: Remove the provider from the wait groups.
bifrost.waitGroups.Delete(providerKey)
// Step 6: Remove the provider from the providers slice.
// NOTE(review): the request queue entry was already deleted in Step 4, so a
// retried RemoveProvider would return at the "not found in request queues"
// check in Step 1 — confirm the recovery advice in the log message below.
if err := bifrost.removeProviderFromSlice(providerKey); err != nil {
bifrost.logger.Error(
"provider %s was removed from queues but could not be removed from the providers slice — "+
"bifrost.providers is now inconsistent. "+
"To recover: retry RemoveProvider(%s), or restart Bifrost if that fails.",
providerKey, providerKey,
)
return err
}
bifrost.logger.Info("successfully removed provider %s", providerKey)
// Only reached on success: the provider is unregistered from the known providers.
schemas.UnregisterKnownProvider(providerKey)
return nil
}
// UpdateProvider dynamically updates a provider with new configuration.
// This method gracefully recreates the provider instance with updated settings,
// stops existing workers, creates a new queue with updated settings,
// and starts new workers with the updated provider and concurrency configuration.
//
// Parameters:
//   - providerKey: The provider to update
//
// Returns:
//   - error: Any error that occurred during the update process. Errors coming
//     from the account or from provider construction are wrapped with %w so
//     callers can inspect them with errors.Is / errors.As.
//
// Note: This operation will temporarily pause request processing for the specified provider
// while the transition occurs. In-flight requests will complete before workers are stopped.
// Buffered requests in the old queue are transferred to the new queue; if the new queue
// fills up during the transfer, the remaining buffered requests are failed with an error
// instead of being transferred.
//
// Failure handling: if the new provider instance cannot be created, or the providers
// slice cannot be swapped within the retry limit, the provider is rolled back by removing
// it from the request queues, the wait groups and the providers slice. The caller must
// re-add the provider or retry UpdateProvider to restore it.
//
// Concurrency safety — no-worker window:
// UpdateProvider holds a per-provider write lock (providerMutex.Lock) for its entire
// duration. All producer paths (tryRequest, tryStreamRequest) acquire the corresponding
// read lock inside getProviderQueue before they can look up or enqueue into any queue.
// This means no producer can observe or enqueue into newPq until UpdateProvider returns
// and releases the write lock — at which point new workers are already running and
// consuming newPq. There is therefore no window where newPq is visible to producers
// but has zero workers.
func (bifrost *Bifrost) UpdateProvider(providerKey schemas.ModelProvider) error {
	// Pass the key as a format argument (like every other log call in this function)
	// rather than pre-formatting, so a '%' in the key cannot be re-interpreted.
	bifrost.logger.Info("Updating provider configuration for provider %s", providerKey)

	// Get the updated configuration from the account
	providerConfig, err := bifrost.account.GetConfigForProvider(providerKey)
	if err != nil {
		return fmt.Errorf("failed to get updated config for provider %s: %w", providerKey, err)
	}
	if providerConfig == nil {
		return fmt.Errorf("config is nil for provider %s", providerKey)
	}

	// Lock the provider to prevent concurrent access during update
	providerMutex := bifrost.getProviderMutex(providerKey)
	providerMutex.Lock()
	defer providerMutex.Unlock()

	// Check if provider currently exists
	oldPqValue, exists := bifrost.requestQueues.Load(providerKey)
	if !exists {
		bifrost.logger.Debug("provider %s not currently active, initializing with new configuration", providerKey)
		// If provider doesn't exist, just prepare it with new configuration
		return bifrost.prepareProvider(providerKey, providerConfig)
	}
	oldPq := oldPqValue.(*ProviderQueue)
	bifrost.logger.Debug("gracefully stopping existing workers for provider %s", providerKey)

	// Step 1: Create new ProviderQueue with updated buffer size
	newPq := &ProviderQueue{
		queue:      make(chan *ChannelMessage, providerConfig.ConcurrencyAndBufferSize.BufferSize),
		done:       make(chan struct{}),
		signalOnce: sync.Once{},
	}

	// Step 2: Atomically replace the queue so new producers immediately use newPq.
	bifrost.requestQueues.Store(providerKey, newPq)
	bifrost.logger.Debug("stored new queue for provider %s, new producers will use it", providerKey)

	// Step 3: Transfer buffered requests from the old queue to the new queue BEFORE
	// signalling workers to stop. This ensures buffered requests are processed by the
	// new workers rather than being drained with errors.
	// Old workers are still running and may consume some items concurrently — that is
	// fine, they process them normally.
	// If newPq is full during transfer, all remaining buffered requests are cancelled
	// immediately rather than blocking — this avoids the deadlock where transfer goroutines
	// wait for space that only opens once new workers start (which can't happen until
	// the transfer completes).
	cancelMsg := func(r *ChannelMessage) {
		prov, mod, _ := r.BifrostRequest.GetRequestFields()
		select {
		case r.Err <- schemas.BifrostError{
			IsBifrostError: false,
			Error:          &schemas.ErrorField{Message: "request failed during provider concurrency update: queue full"},
			ExtraFields: schemas.BifrostErrorExtraFields{
				RequestType:            r.RequestType,
				Provider:               prov,
				OriginalModelRequested: mod,
			},
		}:
		case <-r.Context.Done():
			// The requester has already gone away; nothing to deliver.
		}
	}
	transferredCount := 0
	cancelledCount := 0
transfer:
	for {
		select {
		case msg := <-oldPq.queue:
			select {
			case newPq.queue <- msg:
				transferredCount++
			default:
				// newPq is full — cancel this message and all remaining in oldPq.
				cancelMsg(msg)
				cancelledCount++
				for {
					select {
					case r := <-oldPq.queue:
						cancelMsg(r)
						cancelledCount++
					default:
						break transfer
					}
				}
			}
		default:
			// No more buffered messages
			break transfer
		}
	}
	if transferredCount > 0 {
		bifrost.logger.Info("transferred %d buffered requests to new queue for provider %s", transferredCount, providerKey)
	}
	if cancelledCount > 0 {
		bifrost.logger.Warn("cancelled %d buffered requests during transfer for provider %s: new queue was full", cancelledCount, providerKey)
	}

	// Step 4: Signal the old queue is closing. Producers that still hold a reference to
	// oldPq will detect this via isClosing() and transparently re-route to newPq.
	// This happens after the transfer so the new queue is already populated before
	// stale producers attempt their re-route.
	oldPq.signalClosing()
	bifrost.logger.Debug("signaled closing for old queue of provider %s", providerKey)

	// Step 5: Wait for all existing workers to finish processing in-flight requests.
	// Workers exit via oldPq.done (signalled above).
	if oldWaitGroup, ok := bifrost.waitGroups.Load(providerKey); ok {
		oldWaitGroup.(*sync.WaitGroup).Wait()
		bifrost.logger.Debug("all workers for provider %s have stopped", providerKey)
	}

	// Step 5b: Final drain sweep — see drainQueueWithErrors for full explanation.
	bifrost.drainQueueWithErrors(oldPq)

	// Step 6: Create new wait group for the updated workers. Keep a local reference so
	// Step 9 does not have to re-load it from the map and type-assert a possibly
	// missing entry.
	newWaitGroup := &sync.WaitGroup{}
	bifrost.waitGroups.Store(providerKey, newWaitGroup)

	// rollback removes the half-updated provider. It is shared by every failure path
	// after newPq has been published.
	// Order matters: Delete before drainQueueWithErrors so that producers
	// re-routing via requestQueues.Load find nothing and return "provider
	// shutting down" immediately, narrowing the TOCTOU window before the sweep.
	rollback := func() {
		newPq.signalClosing()
		bifrost.requestQueues.Delete(providerKey)
		bifrost.waitGroups.Delete(providerKey)
		bifrost.drainQueueWithErrors(newPq)
		if sliceErr := bifrost.removeProviderFromSlice(providerKey); sliceErr != nil {
			bifrost.logger.Error(
				"UpdateProvider rollback for %s is incomplete — provider was removed from queues "+
					"but could not be removed from the providers slice: %v. "+
					"bifrost.providers is now inconsistent. "+
					"To recover: call RemoveProvider(%s) then AddProvider to re-register it, "+
					"or restart Bifrost if that fails.",
				providerKey, sliceErr, providerKey,
			)
		}
	}

	// Step 7: Create provider instance.
	provider, err := bifrost.createBaseProvider(providerKey, providerConfig)
	if err != nil {
		rollback()
		return fmt.Errorf("provider update for %s failed during initialization; provider has been removed — re-add or retry UpdateProvider to restore it: %w", providerKey, err)
	}

	// Step 8: Atomically replace the provider in the providers slice.
	// This must happen before starting new workers to prevent stale reads
	bifrost.logger.Debug("atomically replacing provider instance in providers slice for %s", providerKey)
	replacementAttempts := 0
	maxReplacementAttempts := 100 // Prevent infinite loops in high-contention scenarios
	for {
		replacementAttempts++
		if replacementAttempts > maxReplacementAttempts {
			rollback()
			return fmt.Errorf("failed to replace provider %s in providers slice after %d attempts; provider has been removed — re-add or retry UpdateProvider to restore it", providerKey, maxReplacementAttempts)
		}
		oldPtr := bifrost.providers.Load()
		var oldSlice []schemas.Provider
		if oldPtr != nil {
			oldSlice = *oldPtr
		}
		// Create new slice without the old provider of this key
		newSlice := make([]schemas.Provider, 0, len(oldSlice))
		oldProviderFound := false
		for _, existingProvider := range oldSlice {
			if existingProvider.GetProviderKey() != providerKey {
				newSlice = append(newSlice, existingProvider)
			} else {
				oldProviderFound = true
			}
		}
		// Add the new provider
		newSlice = append(newSlice, provider)
		if bifrost.providers.CompareAndSwap(oldPtr, &newSlice) {
			if oldProviderFound {
				bifrost.logger.Debug("successfully replaced existing provider instance for %s in providers slice", providerKey)
			} else {
				bifrost.logger.Debug("successfully added new provider instance for %s to providers slice", providerKey)
			}
			break
		}
		// Retrying as swapping did not work (likely due to concurrent modification)
	}

	// Step 9: Start new workers with updated concurrency.
	bifrost.logger.Debug("starting %d new workers for provider %s with buffer size %d",
		providerConfig.ConcurrencyAndBufferSize.Concurrency,
		providerKey,
		providerConfig.ConcurrencyAndBufferSize.BufferSize)
	for range providerConfig.ConcurrencyAndBufferSize.Concurrency {
		newWaitGroup.Add(1)
		go bifrost.requestWorker(provider, providerConfig, newPq)
	}
	bifrost.logger.Info("successfully updated provider configuration for provider %s", providerKey)
	return nil
}
// GetDropExcessRequests reports the DropExcessRequests setting currently in
// effect. The value is read with an atomic load.
func (bifrost *Bifrost) GetDropExcessRequests() bool {
	dropExcess := bifrost.dropExcessRequests.Load()
	return dropExcess
}
// UpdateDropExcessRequests changes the DropExcessRequests setting while
// Bifrost is running, so the value can be hot-reloaded without a restart.
// The new value is stored atomically and then logged.
func (bifrost *Bifrost) UpdateDropExcessRequests(value bool) {
	// Publish the new value first so the log line reflects the active setting.
	bifrost.dropExcessRequests.Store(value)

	bifrost.logger.Info("drop_excess_requests updated to: %v", value)
}
// getProviderMutex gets or creates the RWMutex for the given provider.
//
// getProviderQueue calls this on every lookup, so the common case (the mutex
// already exists) is served by a plain Load. Only a miss falls through to
// LoadOrStore; previously every call allocated a fresh RWMutex that was
// discarded whenever the key was already present.
//
// LoadOrStore still guarantees that concurrent first-time callers all receive
// the same mutex instance.
func (bifrost *Bifrost) getProviderMutex(providerKey schemas.ModelProvider) *sync.RWMutex {
	// Fast path: no allocation when the mutex is already registered.
	if existing, ok := bifrost.providerMutexes.Load(providerKey); ok {
		return existing.(*sync.RWMutex)
	}
	// Slow path: first use of this provider key.
	mutexValue, _ := bifrost.providerMutexes.LoadOrStore(providerKey, &sync.RWMutex{})
	return mutexValue.(*sync.RWMutex)
}
// removeProviderFromSlice drops every entry whose key equals providerKey from
// bifrost.providers. Each attempt builds a filtered copy of the current slice
// and publishes it with CompareAndSwap, retrying when another goroutine swapped
// the slice in between.
//
// Callers hold the per-provider write mutex, so no concurrent goroutine can
// re-add this key — contention only comes from CAS operations on behalf of
// other providers, and the loop is expected to converge after a few iterations.
//
// Returns nil on success, or when no providers slice has been published yet.
// Returns an error once the attempt limit is exhausted, in which case the
// slice still contains the provider and state is inconsistent.
func (bifrost *Bifrost) removeProviderFromSlice(providerKey schemas.ModelProvider) error {
	const maxAttempts = 100
	for attempt := 0; attempt < maxAttempts; attempt++ {
		current := bifrost.providers.Load()
		if current == nil {
			// Nothing published yet, so there is nothing to remove.
			return nil
		}
		remaining := make([]schemas.Provider, 0, len(*current))
		for _, candidate := range *current {
			if candidate.GetProviderKey() == providerKey {
				continue
			}
			remaining = append(remaining, candidate)
		}
		if swapped := bifrost.providers.CompareAndSwap(current, &remaining); swapped {
			return nil
		}
		// Lost the race against another writer; rebuild from the fresh snapshot.
	}
	return fmt.Errorf("failed to remove provider %s from providers slice after %d attempts", providerKey, maxAttempts)
}
// MCP PUBLIC API
// RegisterMCPTool registers a tool handler with the MCP integration by
// delegating to the MCP manager's RegisterTool.
//
// Parameters:
//   - name: Unique tool name
//   - description: Human-readable tool description
//   - handler: Function that handles tool execution; it receives the tool
//     arguments as an untyped value and returns the tool output
//   - toolSchema: Bifrost tool schema for function calling
//
// Returns:
//   - error: An error when MCP is not configured on this instance, otherwise
//     whatever the MCP manager returns
//
// Example:
//
//	err := bifrost.RegisterMCPTool("echo", "Echo a message",
//	    func(args any) (string, error) {
//	        return fmt.Sprint(args), nil
//	    }, toolSchema)
func (bifrost *Bifrost) RegisterMCPTool(name, description string, handler func(args any) (string, error), toolSchema schemas.ChatTool) error {
	if bifrost.MCPManager != nil {
		return bifrost.MCPManager.RegisterTool(name, description, handler, toolSchema)
	}
	return fmt.Errorf("mcp is not configured in this bifrost instance")
}
// IMPORTANT: Running the MCP client management operations (GetMCPClients, AddMCPClient, RemoveMCPClient,
// UpdateMCPClient, ReconnectMCPClient) may temporarily increase latency for incoming requests while the
// operations are being processed. These operations involve network I/O and connection management that
// require mutex locks, which can block briefly during execution.

// GetMCPClients returns a snapshot of all MCP clients managed by the Bifrost
// instance.
//
// For every client the tool list is rebuilt from the client's tool map: tools
// without a function definition are skipped, the "<client name>-" prefix is
// stripped from each tool name, and the result is sorted by name. Tool
// functions are copied field by field, so renaming does not touch the
// originals held by the manager.
//
// Returns:
//   - []schemas.MCPClient: List of all MCP clients
//   - error: An error when MCP is not configured on this instance
func (bifrost *Bifrost) GetMCPClients() ([]schemas.MCPClient, error) {
	if bifrost.MCPManager == nil {
		return nil, fmt.Errorf("mcp is not configured in this bifrost instance")
	}

	managed := bifrost.MCPManager.GetClients()
	result := make([]schemas.MCPClient, 0, len(managed))
	for _, client := range managed {
		tools := make([]schemas.ChatToolFunction, 0, len(client.ToolMap))
		for _, tool := range client.ToolMap {
			fn := tool.Function
			if fn == nil {
				continue
			}
			// Copy into a fresh value and drop the client prefix from the name.
			tools = append(tools, schemas.ChatToolFunction{
				Name:        strings.TrimPrefix(fn.Name, client.ExecutionConfig.Name+"-"),
				Description: fn.Description,
				Parameters:  fn.Parameters,
				Strict:      fn.Strict,
			})
		}
		// Map iteration order is random; sort so callers see a stable listing.
		sort.Slice(tools, func(a, b int) bool {
			return tools[a].Name < tools[b].Name
		})
		result = append(result, schemas.MCPClient{
			Config: client.ExecutionConfig,
			Tools:  tools,
			State:  client.State,
		})
	}
	return result, nil
}
// GetAvailableMCPTools returns the MCP tools available for the given context
// by delegating to the MCP manager.
//
// Parameters:
//   - ctx: Bifrost context passed through to the MCP manager
//
// Returns:
//   - []schemas.ChatTool: List of available tools, or nil when MCP is not
//     configured on this instance
func (bifrost *Bifrost) GetAvailableMCPTools(ctx *schemas.BifrostContext) []schemas.ChatTool {
	if bifrost.MCPManager != nil {
		return bifrost.MCPManager.GetAvailableTools(ctx)
	}
	return nil
}
// AddMCPClient adds a new MCP client to the Bifrost instance.
// This allows for dynamic MCP client management at runtime.
//
// When no MCP manager exists yet, one is created lazily (at most once, guarded
// by mcpInitOnce) with an empty client list, plugin-pipeline acquire/release
// callbacks, and a Starlark CodeMode built with default configuration. The
// client is then registered through the manager's AddClient.
//
// Parameters:
//   - config: MCP client configuration
//
// Returns:
//   - error: Any registration error, or an error when the manager is still
//     nil after the lazy initialization step
//
// Example:
//
//	err := bifrost.AddMCPClient(&schemas.MCPClientConfig{
//	    Name: "my-mcp-client",
//	    ConnectionType: schemas.MCPConnectionTypeHTTP,
//	    ConnectionString: &url,
//	})
func (bifrost *Bifrost) AddMCPClient(config *schemas.MCPClientConfig) error {
	if bifrost.MCPManager == nil {
		// sync.Once keeps concurrent first callers from building two managers.
		bifrost.mcpInitOnce.Do(func() {
			// Start with no clients; the requested one is added via AddClient below.
			emptyConfig := schemas.MCPConfig{
				ClientConfigs: []*schemas.MCPClientConfig{},
				// Plugin pipeline callbacks used by the executeCode tool hooks.
				PluginPipelineProvider: func() any {
					return bifrost.getPluginPipeline()
				},
				ReleasePluginPipeline: func(pipeline any) {
					if pp, ok := pipeline.(*PluginPipeline); ok {
						bifrost.releasePluginPipeline(pp)
					}
				},
			}
			// Starlark CodeMode for code execution (default config).
			starlarkMode := starlark.NewStarlarkCodeMode(nil, bifrost.logger)
			bifrost.MCPManager = mcp.NewMCPManager(bifrost.ctx, emptyConfig, bifrost.oauth2Provider, bifrost.logger, starlarkMode)
		})
	}

	// The Once may already have been consumed while the manager is still nil.
	if bifrost.MCPManager == nil {
		return fmt.Errorf("MCP manager is not initialized")
	}
	return bifrost.MCPManager.AddClient(config)
}
// RemoveMCPClient removes an MCP client from the Bifrost instance by
// delegating to the MCP manager's RemoveClient.
// This allows for dynamic MCP client management at runtime.
//
// Parameters:
//   - id: ID of the client to remove
//
// Returns:
//   - error: An error when MCP is not configured on this instance, otherwise
//     any removal error
//
// Example:
//
//	err := bifrost.RemoveMCPClient("my-mcp-client-id")
//	if err != nil {
//	    log.Fatalf("Failed to remove MCP client: %v", err)
//	}
func (bifrost *Bifrost) RemoveMCPClient(id string) error {
	if bifrost.MCPManager != nil {
		return bifrost.MCPManager.RemoveClient(id)
	}
	return fmt.Errorf("mcp is not configured in this bifrost instance")
}
// SetMCPManager sets the MCP manager for this Bifrost instance.
// This allows injecting a custom MCP manager implementation (e.g., for enterprise features).
// If the provided manager is a concrete *mcp.MCPManager, Bifrost's plugin pipeline is injected
// into the manager's CodeMode so that nested tool calls run through the plugin hooks.
//
// The injected provider callback acquires a pipeline via getPluginPipeline. If that
// pipeline cannot be exposed as an mcp.PluginPipeline, it is released again before nil
// is returned, so the failure path does not leak an acquired pipeline.
//
// Parameters:
//   - manager: The MCP manager to set (must implement MCPManagerInterface)
func (bifrost *Bifrost) SetMCPManager(manager mcp.MCPManagerInterface) {
	bifrost.MCPManager = manager

	// Only the concrete manager exposes SetPluginPipeline; other implementations
	// are stored as-is.
	concrete, ok := manager.(*mcp.MCPManager)
	if !ok {
		return
	}

	// Inject Bifrost's plugin pipeline into the manager's CodeMode so that
	// nested tool calls (e.g. via Starlark executeCode) run through plugin hooks.
	concrete.SetPluginPipeline(
		func() mcp.PluginPipeline {
			pipeline := bifrost.getPluginPipeline()
			if pp, ok := any(pipeline).(mcp.PluginPipeline); ok {
				return pp
			}
			// The caller receives nil and therefore can never hand this pipeline
			// back through the release callback — release it here instead.
			if pipeline != nil {
				bifrost.releasePluginPipeline(pipeline)
			}
			return nil
		},
		func(pipeline mcp.PluginPipeline) {
			if pp, ok := pipeline.(*PluginPipeline); ok {
				bifrost.releasePluginPipeline(pp)
			}
		},
	)
}
// UpdateMCPClient updates an existing MCP client by delegating to the MCP
// manager's UpdateClient.
// This allows for dynamic MCP client tool management at runtime.
//
// Parameters:
//   - id: ID of the client to edit
//   - updatedConfig: Updated MCP client configuration
//
// Returns:
//   - error: An error when MCP is not configured on this instance, otherwise
//     any edit error
//
// Example:
//
//	err := bifrost.UpdateMCPClient("my-mcp-client-id", &schemas.MCPClientConfig{
//	    Name: "my-mcp-client-name",
//	    ToolsToExecute: []string{"tool1", "tool2"},
//	})
func (bifrost *Bifrost) UpdateMCPClient(id string, updatedConfig *schemas.MCPClientConfig) error {
	if bifrost.MCPManager != nil {
		return bifrost.MCPManager.UpdateClient(id, updatedConfig)
	}
	return fmt.Errorf("mcp is not configured in this bifrost instance")
}
// ReconnectMCPClient asks the MCP manager to reconnect the client with the
// given ID.
//
// Parameters:
//   - id: ID of the client to reconnect
//
// Returns:
//   - error: An error when MCP is not configured on this instance, otherwise
//     any reconnection error
func (bifrost *Bifrost) ReconnectMCPClient(id string) error {
	if bifrost.MCPManager != nil {
		return bifrost.MCPManager.ReconnectClient(id)
	}
	return fmt.Errorf("mcp is not configured in this bifrost instance")
}
// VerifyPerUserOAuthConnection delegates to the MCP manager to verify an MCP
// server using a temporary access token and discover available tools. The
// connection is closed after verification.
//
// If the MCP manager is not yet initialized, it is lazily created with the
// same settings AddMCPClient uses: an empty client list, plugin-pipeline
// acquire/release callbacks, and a default Starlark CodeMode.
//
// Parameters:
//   - ctx: Context forwarded to the MCP manager
//   - config: MCP client configuration describing the server to verify
//   - accessToken: Temporary access token used for the verification
//
// Returns the two maps and the error produced by the MCP manager, or an error
// when the manager is still nil after the lazy initialization step.
func (bifrost *Bifrost) VerifyPerUserOAuthConnection(ctx context.Context, config *schemas.MCPClientConfig, accessToken string) (map[string]schemas.ChatTool, map[string]string, error) {
	// Lazy init, guarded by the same Once as AddMCPClient.
	if bifrost.MCPManager == nil {
		bifrost.mcpInitOnce.Do(func() {
			acquire := func() any {
				return bifrost.getPluginPipeline()
			}
			release := func(pipeline any) {
				if pp, ok := pipeline.(*PluginPipeline); ok {
					bifrost.releasePluginPipeline(pp)
				}
			}
			bootstrap := schemas.MCPConfig{
				ClientConfigs:          []*schemas.MCPClientConfig{},
				PluginPipelineProvider: acquire,
				ReleasePluginPipeline:  release,
			}
			starlarkMode := starlark.NewStarlarkCodeMode(nil, bifrost.logger)
			bifrost.MCPManager = mcp.NewMCPManager(bifrost.ctx, bootstrap, bifrost.oauth2Provider, bifrost.logger, starlarkMode)
		})
	}

	if bifrost.MCPManager == nil {
		return nil, nil, fmt.Errorf("MCP manager is not initialized")
	}
	return bifrost.MCPManager.VerifyPerUserOAuthConnection(ctx, config, accessToken)
}
// SetClientTools forwards the given tool map and tool-name mapping to the MCP
// manager for the client identified by clientID. It does nothing when MCP is
// not configured on this instance.
func (bifrost *Bifrost) SetClientTools(clientID string, tools map[string]schemas.ChatTool, toolNameMapping map[string]string) {
	if bifrost.MCPManager == nil {
		return
	}
	bifrost.MCPManager.SetClientTools(clientID, tools, toolNameMapping)
}
// UpdateToolManagerConfig replaces the MCP manager's tool manager config.
// This allows for hot-reloading of the tool manager config at runtime.
//
// All four values are written on every call. Pass the current value of
// disableAutoToolInject whenever only other fields change so the flag is never
// silently reset to its zero value.
//
// Parameters:
//   - maxAgentDepth: Stored as MaxAgentDepth
//   - toolExecutionTimeoutInSeconds: Converted to a time.Duration in seconds
//   - codeModeBindingLevel: Converted to schemas.CodeModeBindingLevel
//   - disableAutoToolInject: Stored as DisableAutoToolInject
//
// Returns:
//   - error: An error when MCP is not configured on this instance
func (bifrost *Bifrost) UpdateToolManagerConfig(maxAgentDepth int, toolExecutionTimeoutInSeconds int, codeModeBindingLevel string, disableAutoToolInject bool) error {
	if bifrost.MCPManager == nil {
		return fmt.Errorf("mcp is not configured in this bifrost instance")
	}

	timeout := time.Duration(toolExecutionTimeoutInSeconds) * time.Second
	toolManagerConfig := &schemas.MCPToolManagerConfig{
		MaxAgentDepth:         maxAgentDepth,
		ToolExecutionTimeout:  timeout,
		CodeModeBindingLevel:  schemas.CodeModeBindingLevel(codeModeBindingLevel),
		DisableAutoToolInject: disableAutoToolInject,
	}
	bifrost.MCPManager.UpdateToolManagerConfig(toolManagerConfig)
	return nil
}
// PROVIDER MANAGEMENT
// createBaseProvider constructs the provider implementation for providerKey.
//
// For a custom provider (config.CustomProviderConfig is set) the base provider
// type from that config decides which implementation is built. The base type
// must be non-empty and accepted by IsSupportedBaseProvider. As a side effect,
// CustomProviderKey on the custom config is overwritten with providerKey.
//
// Returns an error when the custom config is invalid, when the resolved
// provider type has no implementation, or when the provider constructor fails.
func (bifrost *Bifrost) createBaseProvider(providerKey schemas.ModelProvider, config *schemas.ProviderConfig) (schemas.Provider, error) {
	// Resolve which concrete implementation to build.
	baseType := providerKey
	if custom := config.CustomProviderConfig; custom != nil {
		switch {
		case custom.BaseProviderType == "":
			return nil, fmt.Errorf("custom provider config missing base provider type")
		case !IsSupportedBaseProvider(custom.BaseProviderType):
			return nil, fmt.Errorf("unsupported base provider type: %s", custom.BaseProviderType)
		}
		// The custom provider key is always the provider's own name.
		custom.CustomProviderKey = string(providerKey)
		baseType = custom.BaseProviderType
	}

	logger := bifrost.logger
	switch baseType {
	// Constructors that cannot fail.
	case schemas.Anthropic:
		return anthropic.NewAnthropicProvider(config, logger), nil
	case schemas.Elevenlabs:
		return elevenlabs.NewElevenlabsProvider(config, logger), nil
	case schemas.Gemini:
		return gemini.NewGeminiProvider(config, logger), nil
	case schemas.HuggingFace:
		return huggingface.NewHuggingFaceProvider(config, logger), nil
	case schemas.Mistral:
		return mistral.NewMistralProvider(config, logger), nil
	case schemas.OpenAI:
		return openai.NewOpenAIProvider(config, logger), nil
	case schemas.OpenRouter:
		return openrouter.NewOpenRouterProvider(config, logger), nil

	// Constructors that report their own error.
	case schemas.Azure:
		return azure.NewAzureProvider(config, logger)
	case schemas.Bedrock:
		return bedrock.NewBedrockProvider(config, logger)
	case schemas.Cerebras:
		return cerebras.NewCerebrasProvider(config, logger)
	case schemas.Cohere:
		return cohere.NewCohereProvider(config, logger)
	case schemas.Fireworks:
		return fireworks.NewFireworksProvider(config, logger)
	case schemas.Groq:
		return groq.NewGroqProvider(config, logger)
	case schemas.Nebius:
		return nebius.NewNebiusProvider(config, logger)
	case schemas.Ollama:
		return ollama.NewOllamaProvider(config, logger)
	case schemas.Parasail:
		return parasail.NewParasailProvider(config, logger)
	case schemas.Perplexity:
		return perplexity.NewPerplexityProvider(config, logger)
	case schemas.Replicate:
		return replicate.NewReplicateProvider(config, logger)
	case schemas.Runway:
		return runway.NewRunwayProvider(config, logger)
	case schemas.SGL:
		return sgl.NewSGLProvider(config, logger)
	case schemas.Vertex:
		return vertex.NewVertexProvider(config, logger)
	case schemas.VLLM:
		return vllm.NewVLLMProvider(config, logger)
	case schemas.XAI:
		return xai.NewXAIProvider(config, logger)
	}
	return nil, fmt.Errorf("unsupported provider: %s", baseType)
}
// prepareProvider sets up a provider with its configuration, keys, and worker channels.
// It creates the provider instance, registers its request queue and wait group,
// appends the provider to the providers slice, and starts the worker goroutines.
//
// The provider instance is created BEFORE the queue and wait group are published.
// If construction fails, nothing has been registered, so later getProviderQueue
// lookups do not find (and enqueue into) a queue that has no workers consuming it.
//
// Note: This function assumes the caller has already acquired the appropriate mutex for the provider.
//
// Returns an error (wrapping the construction error with %w) when the provider
// instance cannot be created.
func (bifrost *Bifrost) prepareProvider(providerKey schemas.ModelProvider, config *schemas.ProviderConfig) error {
	// Build the provider first so a failure leaves no partial state behind.
	provider, err := bifrost.createBaseProvider(providerKey, config)
	if err != nil {
		return fmt.Errorf("failed to create provider for the given key: %w", err)
	}

	// Create ProviderQueue with lifecycle management
	pq := &ProviderQueue{
		queue:      make(chan *ChannelMessage, config.ConcurrencyAndBufferSize.BufferSize),
		done:       make(chan struct{}),
		signalOnce: sync.Once{},
	}
	workers := &sync.WaitGroup{}
	bifrost.requestQueues.Store(providerKey, pq)
	bifrost.waitGroups.Store(providerKey, workers)

	// Atomically append provider to the providers slice
	for {
		oldPtr := bifrost.providers.Load()
		var oldSlice []schemas.Provider
		if oldPtr != nil {
			oldSlice = *oldPtr
		}
		newSlice := make([]schemas.Provider, len(oldSlice)+1)
		copy(newSlice, oldSlice)
		newSlice[len(oldSlice)] = provider
		if bifrost.providers.CompareAndSwap(oldPtr, &newSlice) {
			break
		}
		// Another goroutine swapped the slice in between; retry on the fresh snapshot.
	}
	schemas.RegisterKnownProvider(providerKey)

	// Start specified number of workers
	for range config.ConcurrencyAndBufferSize.Concurrency {
		workers.Add(1)
		go bifrost.requestWorker(provider, config, pq)
	}
	return nil
}
// getProviderQueue returns the ProviderQueue registered for providerKey.
//
// An existing queue is returned after a lookup under the per-provider read
// lock. When none exists, the function takes the write lock, re-checks (another
// goroutine may have created the queue in the meantime), fetches the provider
// config from the account, and initializes the provider via prepareProvider.
//
// Callers must check the closing flag or select on the done channel before sending.
//
// Returns an error when the config cannot be fetched or is nil, when
// prepareProvider fails, or when no queue is registered after preparation.
func (bifrost *Bifrost) getProviderQueue(providerKey schemas.ModelProvider) (*ProviderQueue, error) {
	mu := bifrost.getProviderMutex(providerKey)

	// Fast path: concurrent readers are allowed, provider updates are excluded.
	mu.RLock()
	existing, found := bifrost.requestQueues.Load(providerKey)
	mu.RUnlock()
	if found {
		return existing.(*ProviderQueue), nil
	}

	// Slow path: the provider has to be created, which needs the write lock.
	mu.Lock()
	defer mu.Unlock()

	// Re-check now that the write lock is held.
	if existing, found = bifrost.requestQueues.Load(providerKey); found {
		return existing.(*ProviderQueue), nil
	}

	bifrost.logger.Debug(fmt.Sprintf("Creating new request queue for provider %s at runtime", providerKey))
	config, err := bifrost.account.GetConfigForProvider(providerKey)
	switch {
	case err != nil:
		return nil, fmt.Errorf("failed to get config for provider: %v", err)
	case config == nil:
		return nil, fmt.Errorf("config is nil for provider %s", providerKey)
	}
	if err = bifrost.prepareProvider(providerKey, config); err != nil {
		return nil, err
	}

	created, ok := bifrost.requestQueues.Load(providerKey)
	if !ok {
		return nil, fmt.Errorf("request queue not found for provider %s", providerKey)
	}
	return created.(*ProviderQueue), nil
}
// GetProviderByKey is the exported wrapper around getProviderByKey: it looks
// up the provider instance registered under providerKey.
// Returns nil if no provider with the given key exists.
func (bifrost *Bifrost) GetProviderByKey(providerKey schemas.ModelProvider) schemas.Provider {
	provider := bifrost.getProviderByKey(providerKey)
	return provider
}
// SelectKeyForProviderRequestType selects an API key for the given provider, request type, and model.
// Used by WebSocket handlers that need a key for upstream connections while honoring request-specific
// AllowedRequests gates such as realtime-only support.
//
// A nil ctx falls back to the instance context. For custom providers with a
// base provider type configured, that base type is passed along to the key
// lookup; a failed or empty config lookup silently keeps providerKey as the
// base provider.
//
// Returns:
//   - the only supported key when exactly one matches
//   - the key chosen by the configured key selector when several match
//   - a zero-value schemas.Key and a nil error when no key matches
//   - the lookup error, with a zero-value key, when key lookup fails
func (bifrost *Bifrost) SelectKeyForProviderRequestType(ctx *schemas.BifrostContext, requestType schemas.RequestType, providerKey schemas.ModelProvider, model string) (schemas.Key, error) {
	if ctx == nil {
		ctx = bifrost.ctx
	}

	// Custom providers resolve keys against their base provider type.
	baseProvider := providerKey
	config, cfgErr := bifrost.account.GetConfigForProvider(providerKey)
	if cfgErr == nil && config != nil {
		if custom := config.CustomProviderConfig; custom != nil && custom.BaseProviderType != "" {
			baseProvider = custom.BaseProviderType
		}
	}

	supportedKeys, _, err := bifrost.selectKeyFromProviderForModelWithPool(ctx, requestType, providerKey, model, baseProvider)
	if err != nil {
		return schemas.Key{}, err
	}

	switch len(supportedKeys) {
	case 0:
		return schemas.Key{}, nil
	case 1:
		return supportedKeys[0], nil
	default:
		return bifrost.keySelector(ctx, supportedKeys, providerKey, model)
	}
}
// WSStreamHooks holds the post-hook runner and cleanup function returned by RunStreamPreHooks.
// Call PostHookRunner for each streaming chunk, setting StreamEndIndicator on the final chunk.
// Call Cleanup when done to release the pipeline back to the pool.
// If ShortCircuitResponse is non-nil, a plugin short-circuited with a cached response —
// the caller should write this response to the client and skip the upstream call.
type WSStreamHooks struct {
// PostHookRunner runs the post-LLM plugin hooks for one chunk or error.
// RunStreamPreHooks leaves it nil when it returns a short-circuit response.
PostHookRunner schemas.PostHookRunner
// Cleanup cleans up the stream accumulator for the trace and releases the plugin pipeline.
// On the short-circuit response path it is a no-op, because RunStreamPreHooks
// has already cleaned up before returning.
Cleanup func()
// ShortCircuitResponse is the plugin-provided response after post-hooks have run.
ShortCircuitResponse *schemas.BifrostResponse
}
// RealtimeTurnHooks mirrors RunStreamPreHooks but is explicitly scoped to a
// single realtime turn rather than one long-lived transport connection.
// It carries the same PostHookRunner/Cleanup pair as WSStreamHooks, without a
// short-circuit response field.
type RealtimeTurnHooks struct {
// PostHookRunner is the post-hook callback for this turn.
// NOTE(review): presumably invoked per chunk like WSStreamHooks.PostHookRunner — confirm against RunRealtimeTurnPreHooks.
PostHookRunner schemas.PostHookRunner
// Cleanup is the teardown callback for this turn.
// NOTE(review): presumably releases the plugin pipeline acquired for the turn — confirm against RunRealtimeTurnPreHooks.
Cleanup func()
}
// RunStreamPreHooks acquires a plugin pipeline, sets up tracing context, runs PreLLMHooks,
// and returns a PostHookRunner for per-chunk post-processing.
// Used by WebSocket handlers that bypass the normal inference path but still need plugin hooks.
//
// Parameters:
//   - ctx: Request context; a nil ctx falls back to the instance context. A request ID,
//     the tracer, a trace ID (when none is present) and the stream start time are
//     stored on it.
//   - req: The request to run pre-hooks for. Must not be nil.
//
// Returns:
//   - (*WSStreamHooks, nil) with PostHookRunner and Cleanup set when the request should
//     proceed upstream. The caller must call Cleanup when the stream is done.
//   - (*WSStreamHooks, nil) with ShortCircuitResponse set when a plugin short-circuited
//     with a response. Post-hooks have already run and all resources are released;
//     Cleanup is a no-op.
//   - (nil, *BifrostError) when req is nil, when the pre-hooks produce neither a
//     request nor a short-circuit, or when a plugin short-circuits with an error.
//     Post-hooks have already run (except for a nil req) and all resources are released.
//
// RequestType, Provider, OriginalModelRequested, and ResolvedModelUsed are always
// overwritten around RunPostLLMHooks — plugin modifications to these 4 fields are
// no-ops by design; proper request metadata is preserved and tampering is discouraged.
func (bifrost *Bifrost) RunStreamPreHooks(ctx *schemas.BifrostContext, req *schemas.BifrostRequest) (*WSStreamHooks, *schemas.BifrostError) {
	// Reject a nil request up front, before any pipeline or trace is acquired.
	// Every path below dereferences req (req.RequestType), so a nil req would panic.
	if req == nil {
		return nil, newBifrostErrorFromMsg("stream request is nil")
	}
	if ctx == nil {
		ctx = bifrost.ctx
	}
	if _, ok := ctx.Value(schemas.BifrostContextKeyRequestID).(string); !ok {
		ctx.SetValue(schemas.BifrostContextKeyRequestID, uuid.New().String())
	}
	tracer := bifrost.getTracer()
	ctx.SetValue(schemas.BifrostContextKeyTracer, tracer)
	// Create a trace so the logging plugin can accumulate streaming chunks.
	// The traceID is used as the accumulator key in ProcessStreamingChunk.
	if _, ok := ctx.Value(schemas.BifrostContextKeyTraceID).(string); !ok {
		traceID := tracer.CreateTrace("")
		if traceID != "" {
			ctx.SetValue(schemas.BifrostContextKeyTraceID, traceID)
		}
	}
	// Mark as streaming context so RunPostLLMHooks uses accumulated timing
	ctx.SetValue(schemas.BifrostContextKeyStreamStartTime, time.Now())

	pipeline := bifrost.getPluginPipeline()
	cleanup := func() {
		if traceID, ok := ctx.Value(schemas.BifrostContextKeyTraceID).(string); ok && traceID != "" {
			tracer.CleanupStreamAccumulator(traceID)
		}
		bifrost.releasePluginPipeline(pipeline)
	}
	// finishEarly is the shared teardown for every path that ends before any
	// upstream streaming happens: attach plugin logs, flush the trace, then
	// release the accumulator and the pipeline.
	finishEarly := func() {
		drainAndAttachPluginLogs(ctx)
		if traceID, ok := ctx.Value(schemas.BifrostContextKeyTraceID).(string); ok && strings.TrimSpace(traceID) != "" {
			tracer.CompleteAndFlushTrace(strings.TrimSpace(traceID))
		}
		cleanup()
	}

	// Capture provider/model from the original request for early-exit paths below.
	reqProvider, reqModel, _ := req.GetRequestFields()

	preReq, shortCircuit, preCount := pipeline.RunLLMPreHooks(ctx, req)
	if preReq == nil && shortCircuit == nil {
		nilReqErr := newBifrostErrorFromMsg("bifrost request after plugin hooks cannot be nil")
		nilReqErr.PopulateExtraFields(req.RequestType, reqProvider, reqModel, reqModel)
		_, bifrostErr := pipeline.RunPostLLMHooks(ctx, nil, nilReqErr, preCount)
		if bifrostErr != nil {
			bifrostErr.PopulateExtraFields(req.RequestType, reqProvider, reqModel, reqModel)
		}
		finishEarly()
		if bifrostErr != nil {
			return nil, bifrostErr
		}
		// Post-hooks cleared the error, but there is still no request to proceed
		// with. Report the original error (as the short-circuit error path does)
		// rather than returning nil hooks together with a nil error.
		return nil, nilReqErr
	}

	if shortCircuit != nil {
		if shortCircuit.Error != nil {
			shortCircuit.Error.PopulateExtraFields(req.RequestType, reqProvider, reqModel, reqModel)
			_, bifrostErr := pipeline.RunPostLLMHooks(ctx, nil, shortCircuit.Error, preCount)
			if bifrostErr != nil {
				bifrostErr.PopulateExtraFields(req.RequestType, reqProvider, reqModel, reqModel)
			}
			finishEarly()
			if bifrostErr != nil {
				return nil, bifrostErr
			}
			return nil, shortCircuit.Error
		}
		if shortCircuit.Response != nil {
			shortCircuit.Response.PopulateExtraFields(req.RequestType, reqProvider, reqModel, reqModel)
			resp, bifrostErr := pipeline.RunPostLLMHooks(ctx, shortCircuit.Response, nil, preCount)
			if bifrostErr != nil {
				bifrostErr.PopulateExtraFields(req.RequestType, reqProvider, reqModel, reqModel)
			} else if resp != nil {
				resp.PopulateExtraFields(req.RequestType, reqProvider, reqModel, reqModel)
			}
			finishEarly()
			if bifrostErr != nil {
				return nil, bifrostErr
			}
			// Everything is already released, so the caller's Cleanup is a no-op.
			return &WSStreamHooks{
				Cleanup:              func() {},
				ShortCircuitResponse: resp,
			}, nil
		}
	}

	// Normal path: provider/model come from the (possibly plugin-modified) request.
	wsProvider, wsModel, _ := preReq.GetRequestFields()
	postHookRunner := func(ctx *schemas.BifrostContext, result *schemas.BifrostResponse, err *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError) {
		// Populate extra fields before RunPostLLMHooks so plugins (e.g. logging)
		// can read requestType/provider/model from the chunk or error.
		if result != nil {
			result.PopulateExtraFields(req.RequestType, wsProvider, wsModel, wsModel)
		}
		if err != nil {
			err.PopulateExtraFields(req.RequestType, wsProvider, wsModel, wsModel)
		}
		resp, bifrostErr := pipeline.RunPostLLMHooks(ctx, result, err, preCount)
		if IsFinalChunk(ctx) {
			drainAndAttachPluginLogs(ctx)
		}
		if bifrostErr != nil {
			bifrostErr.PopulateExtraFields(req.RequestType, wsProvider, wsModel, wsModel)
			return nil, bifrostErr
		}
		if resp != nil {
			resp.PopulateExtraFields(req.RequestType, wsProvider, wsModel, wsModel)
		}
		return resp, nil
	}
	return &WSStreamHooks{
		PostHookRunner: postHookRunner,
		Cleanup:        cleanup,
	}, nil
}
// RunRealtimeTurnPreHooks acquires a plugin pipeline and runs LLM pre-hooks for
// a single realtime turn. Unlike generic stream hooks, realtime turns do not
// support short-circuit responses in v1 because the transports cannot yet emit a
// fully synthetic assistant turn without an upstream generation.
//
// On success it returns hooks whose PostHookRunner runs the post-hooks for the
// turn and whose Cleanup releases the pipeline; the caller is responsible for
// invoking Cleanup. On failure the post-hooks are run with the error, plugin
// logs are drained, the trace is flushed, the pipeline is released, and a
// non-nil error is always returned (never a nil hooks / nil error pair).
func (bifrost *Bifrost) RunRealtimeTurnPreHooks(ctx *schemas.BifrostContext, req *schemas.BifrostRequest) (*RealtimeTurnHooks, *schemas.BifrostError) {
	if req == nil {
		bifrostErr := newBifrostErrorFromMsg("realtime turn request is nil")
		bifrostErr.ExtraFields.RequestType = schemas.RealtimeRequest
		return nil, bifrostErr
	}
	if ctx == nil {
		ctx = bifrost.ctx
	}
	if _, ok := ctx.Value(schemas.BifrostContextKeyRequestID).(string); !ok {
		ctx.SetValue(schemas.BifrostContextKeyRequestID, uuid.New().String())
	}
	tracer := bifrost.getTracer()
	ctx.SetValue(schemas.BifrostContextKeyTracer, tracer)
	if _, ok := ctx.Value(schemas.BifrostContextKeyTraceID).(string); !ok {
		traceID := tracer.CreateTrace("")
		if traceID != "" {
			ctx.SetValue(schemas.BifrostContextKeyTraceID, traceID)
		}
	}
	pipeline := bifrost.getPluginPipeline()
	cleanup := func() {
		if traceID, ok := ctx.Value(schemas.BifrostContextKeyTraceID).(string); ok && traceID != "" {
			tracer.CleanupStreamAccumulator(traceID)
		}
		bifrost.releasePluginPipeline(pipeline)
	}
	provider, model, _ := req.GetRequestFields()
	preReq, shortCircuit, preCount := pipeline.RunLLMPreHooks(ctx, req)

	// failTurn is the single abort path: it lets plugins observe the error via
	// the post-hooks, re-stamps request metadata on whatever error comes back,
	// drains plugin logs, flushes the trace and releases the pipeline.
	failTurn := func(turnErr *schemas.BifrostError) (*RealtimeTurnHooks, *schemas.BifrostError) {
		turnErr.PopulateExtraFields(schemas.RealtimeRequest, provider, model, model)
		_, hookErr := pipeline.RunPostLLMHooks(ctx, nil, turnErr, preCount)
		if hookErr != nil {
			hookErr.PopulateExtraFields(schemas.RealtimeRequest, provider, model, model)
		}
		drainAndAttachPluginLogs(ctx)
		if traceID, ok := ctx.Value(schemas.BifrostContextKeyTraceID).(string); ok && strings.TrimSpace(traceID) != "" {
			tracer.CompleteAndFlushTrace(strings.TrimSpace(traceID))
		}
		cleanup()
		if hookErr != nil {
			return nil, hookErr
		}
		// The post-hooks returned no error, but a realtime turn has no way to
		// surface a recovered response from the pre-hook stage, so report the
		// original failure instead of returning nil hooks with a nil error.
		return nil, turnErr
	}

	if shortCircuit != nil {
		if shortCircuit.Error != nil {
			return failTurn(shortCircuit.Error)
		}
		if shortCircuit.Response != nil {
			// Short-circuit responses are not supported for realtime turns (v1).
			// Treat this like an error turn so plugins can close pending state cleanly.
			return failTurn(newBifrostErrorFromMsg("realtime turn short-circuit responses are not supported"))
		}
	}
	// Covers both "no short-circuit and no request" and a short-circuit that
	// carries neither an error nor a response while the request is nil.
	if preReq == nil {
		return failTurn(newBifrostErrorFromMsg("bifrost request after plugin hooks cannot be nil"))
	}

	// Pre-hooks may have rewritten the request; use its provider/model from here on.
	provider, model, _ = preReq.GetRequestFields()
	return &RealtimeTurnHooks{
		PostHookRunner: func(ctx *schemas.BifrostContext, result *schemas.BifrostResponse, err *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError) {
			// Populate extra fields before and after the post-hooks so plugins see
			// the request metadata and cannot leave it altered.
			if result != nil {
				result.PopulateExtraFields(schemas.RealtimeRequest, provider, model, model)
			}
			if err != nil {
				err.PopulateExtraFields(schemas.RealtimeRequest, provider, model, model)
			}
			resp, bifrostErr := pipeline.RunPostLLMHooks(ctx, result, err, preCount)
			drainAndAttachPluginLogs(ctx)
			if bifrostErr != nil {
				bifrostErr.PopulateExtraFields(schemas.RealtimeRequest, provider, model, model)
				return nil, bifrostErr
			} else if resp != nil {
				resp.PopulateExtraFields(schemas.RealtimeRequest, provider, model, model)
			}
			return resp, nil
		},
		Cleanup: cleanup,
	}, nil
}
// getProviderByKey looks up the provider registered under providerKey.
// When the provider is not loaded yet, it is prepared on demand from the
// account configuration (or from a default configuration for providers listed
// in dynamicallyConfigurableProviders). Returns nil when no provider list is
// loaded, when no usable configuration exists, or when preparation fails.
func (bifrost *Bifrost) getProviderByKey(providerKey schemas.ModelProvider) schemas.Provider {
	loaded := bifrost.providers.Load()
	if loaded == nil {
		return nil
	}
	// Fast path: the provider is already in memory.
	for _, candidate := range *loaded {
		if candidate.GetProviderKey() == providerKey {
			return candidate
		}
	}

	// The provider may simply not be initialized yet: see whether the account
	// has a configuration for it and, if so, initialize it.
	config, err := bifrost.account.GetConfigForProvider(providerKey)
	if err != nil || config == nil {
		if !slices.Contains(dynamicallyConfigurableProviders, providerKey) {
			return nil
		}
		bifrost.logger.Info(fmt.Sprintf("initializing provider %s with default config", providerKey))
		// No configuration found, fall back to the defaults.
		config = &schemas.ProviderConfig{
			NetworkConfig:            schemas.DefaultNetworkConfig,
			ConcurrencyAndBufferSize: schemas.DefaultConcurrencyAndBufferSize,
		}
	}

	// rescan reloads the provider list and searches it again for providerKey.
	rescan := func() schemas.Provider {
		current := bifrost.providers.Load()
		if current == nil {
			return nil
		}
		for _, candidate := range *current {
			if candidate.GetProviderKey() == providerKey {
				return candidate
			}
		}
		return nil
	}

	// Hold the per-provider mutex while preparing to avoid races.
	mu := bifrost.getProviderMutex(providerKey)
	mu.Lock()
	defer mu.Unlock()

	// Check again now that the lock is held.
	if existing := rescan(); existing != nil {
		return existing
	}
	if prepErr := bifrost.prepareProvider(providerKey, config); prepErr != nil {
		return nil
	}
	// Hand back the freshly prepared provider without recursing.
	return rescan()
}
// CORE INTERNAL LOGIC
// shouldTryFallbacks inspects the outcome of the primary attempt and reports
// whether fallback providers should be attempted (true) or the primary result
// should be returned as-is (false).
func (bifrost *Bifrost) shouldTryFallbacks(req *schemas.BifrostRequest, primaryErr *schemas.BifrostError) bool {
	switch {
	case primaryErr == nil:
		// The primary attempt succeeded.
		bifrost.logger.Debug("no primary error, we should not try fallbacks")
		return false
	case primaryErr.Error != nil && primaryErr.Error.Type != nil && *primaryErr.Error.Type == schemas.RequestCancelled:
		// The request was cancelled.
		bifrost.logger.Debug("request cancelled, we should not try fallbacks")
		return false
	case primaryErr.AllowFallbacks != nil && !*primaryErr.AllowFallbacks:
		// The error explicitly forbids fallbacks. A nil AllowFallbacks is
		// treated as true (fallbacks allowed by default).
		bifrost.logger.Debug("allowFallbacks is false, we should not try fallbacks")
		return false
	}
	// Without any configured fallbacks the primary error is final.
	if _, _, configured := req.GetRequestFields(); len(configured) == 0 {
		bifrost.logger.Debug("no fallbacks configured, we should not try fallbacks")
		return false
	}
	return true
}
// prepareFallbackRequest builds a copy of req that targets the given fallback
// provider and model. Each populated typed sub-request is shallow-copied before
// its Provider and Model are overwritten, so the sub-requests referenced by req
// are not modified. Returns nil when the account lookup for the fallback
// provider's configuration fails, which tells the caller to skip this fallback.
func (bifrost *Bifrost) prepareFallbackRequest(req *schemas.BifrostRequest, fallback schemas.Fallback) *schemas.BifrostRequest {
	// Skip fallbacks whose provider configuration cannot be loaded.
	if _, err := bifrost.account.GetConfigForProvider(fallback.Provider); err != nil {
		bifrost.logger.Warn("config not found for provider %s, skipping fallback: %v", fallback.Provider, err)
		return nil
	}

	// Start from a shallow copy, then re-point every populated sub-request.
	fallbackReq := *req
	if src := req.TextCompletionRequest; src != nil {
		dup := *src
		dup.Provider, dup.Model = fallback.Provider, fallback.Model
		fallbackReq.TextCompletionRequest = &dup
	}
	if src := req.ChatRequest; src != nil {
		dup := *src
		dup.Provider, dup.Model = fallback.Provider, fallback.Model
		fallbackReq.ChatRequest = &dup
	}
	if src := req.ResponsesRequest; src != nil {
		dup := *src
		dup.Provider, dup.Model = fallback.Provider, fallback.Model
		fallbackReq.ResponsesRequest = &dup
	}
	if src := req.CountTokensRequest; src != nil {
		dup := *src
		dup.Provider, dup.Model = fallback.Provider, fallback.Model
		fallbackReq.CountTokensRequest = &dup
	}
	if src := req.EmbeddingRequest; src != nil {
		dup := *src
		dup.Provider, dup.Model = fallback.Provider, fallback.Model
		fallbackReq.EmbeddingRequest = &dup
	}
	if src := req.RerankRequest; src != nil {
		dup := *src
		dup.Provider, dup.Model = fallback.Provider, fallback.Model
		fallbackReq.RerankRequest = &dup
	}
	if src := req.OCRRequest; src != nil {
		dup := *src
		dup.Provider, dup.Model = fallback.Provider, fallback.Model
		fallbackReq.OCRRequest = &dup
	}
	if src := req.SpeechRequest; src != nil {
		dup := *src
		dup.Provider, dup.Model = fallback.Provider, fallback.Model
		fallbackReq.SpeechRequest = &dup
	}
	if src := req.TranscriptionRequest; src != nil {
		dup := *src
		dup.Provider, dup.Model = fallback.Provider, fallback.Model
		fallbackReq.TranscriptionRequest = &dup
	}
	if src := req.ImageGenerationRequest; src != nil {
		dup := *src
		dup.Provider, dup.Model = fallback.Provider, fallback.Model
		fallbackReq.ImageGenerationRequest = &dup
	}
	if src := req.VideoGenerationRequest; src != nil {
		dup := *src
		dup.Provider, dup.Model = fallback.Provider, fallback.Model
		fallbackReq.VideoGenerationRequest = &dup
	}
	return &fallbackReq
}
// shouldContinueWithFallbacks processes the error from a failed fallback attempt.
// It returns true if the remaining fallbacks should still be tried, and false if
// the caller should stop and return fallbackErr (the request was cancelled, or
// the error explicitly disallows fallbacks).
//
// fallbackErr must be non-nil; its Error field may be nil.
func (bifrost *Bifrost) shouldContinueWithFallbacks(fallback schemas.Fallback, fallbackErr *schemas.BifrostError) bool {
	// Error is a pointer and is not guaranteed to be set, so guard it before
	// reading the error type.
	if fallbackErr.Error != nil && fallbackErr.Error.Type != nil && *fallbackErr.Error.Type == schemas.RequestCancelled {
		return false
	}
	// Check if it was a short-circuit error that doesn't allow fallbacks.
	// A nil AllowFallbacks means fallbacks are allowed.
	if fallbackErr.AllowFallbacks != nil && !*fallbackErr.AllowFallbacks {
		return false
	}
	if fallbackErr.Error != nil {
		bifrost.logger.Debug(fmt.Sprintf("Fallback provider %s failed: %s", fallback.Provider, fallbackErr.Error.Message))
	} else {
		// No error details available; log the whole error value instead.
		bifrost.logger.Debug(fmt.Sprintf("Fallback provider %s failed: %v", fallback.Provider, fallbackErr))
	}
	return true
}
// handleRequest is the common wrapper behind all non-streaming public API
// methods. It validates the request, attempts it against the primary provider
// and, when that fails in a way that permits it, walks the configured fallbacks
// in order until one succeeds. Each fallback attempt gets its own tracing span.
// When every fallback fails, the primary provider's error is returned.
// The request is released back via releaseBifrostRequest on return.
func (bifrost *Bifrost) handleRequest(ctx *schemas.BifrostContext, req *schemas.BifrostRequest) (*schemas.BifrostResponse, *schemas.BifrostError) {
	defer bifrost.releaseBifrostRequest(req)

	provider, model, fallbacks := req.GetRequestFields()
	if invalid := validateRequest(req); invalid != nil {
		invalid.PopulateExtraFields(req.RequestType, provider, model, model)
		return nil, invalid
	}
	// Substitute the instance context for a nil one early to prevent blocking.
	if ctx == nil {
		ctx = bifrost.ctx
	}
	bifrost.logger.Debug(fmt.Sprintf("primary provider %s with model %s and %d fallbacks", provider, model, len(fallbacks)))

	// Index 0 marks the primary attempt.
	ctx.SetValue(schemas.BifrostContextKeyFallbackIndex, 0)
	// The request ID has to be in the context before the pre-hooks run.
	if _, hasID := ctx.Value(schemas.BifrostContextKeyRequestID).(string); !hasID {
		ctx.SetValue(schemas.BifrostContextKeyRequestID, uuid.New().String())
	}

	primaryResp, primaryErr := bifrost.tryRequest(ctx, req)
	if primaryErr != nil {
		switch {
		case primaryErr.Error != nil:
			bifrost.logger.Debug(fmt.Sprintf("primary provider %s with model %s returned error: %s", provider, model, primaryErr.Error.Message))
		default:
			bifrost.logger.Debug(fmt.Sprintf("primary provider %s with model %s returned error: %v", provider, model, primaryErr))
		}
		if len(fallbacks) > 0 {
			bifrost.logger.Debug(fmt.Sprintf("check if we should try %d fallbacks", len(fallbacks)))
		}
	}

	// Stop here on success or when fallbacks are not applicable.
	if !bifrost.shouldTryFallbacks(req, primaryErr) {
		return primaryResp, primaryErr
	}

	// Walk the fallbacks in their configured order.
	for idx := range fallbacks {
		fallback := fallbacks[idx]
		attempt := idx + 1
		ctx.SetValue(schemas.BifrostContextKeyFallbackIndex, attempt)
		bifrost.logger.Debug(fmt.Sprintf("trying fallback provider %s with model %s", fallback.Provider, fallback.Model))
		ctx.SetValue(schemas.BifrostContextKeyFallbackRequestID, uuid.New().String())
		clearCtxForFallback(ctx)

		// Open a span covering this fallback attempt.
		tracer := bifrost.getTracer()
		spanName := fmt.Sprintf("fallback.%s.%s", fallback.Provider, fallback.Model)
		spanCtx, span := tracer.StartSpan(ctx, spanName, schemas.SpanKindFallback)
		tracer.SetAttribute(span, schemas.AttrProviderName, string(fallback.Provider))
		tracer.SetAttribute(span, schemas.AttrRequestModel, fallback.Model)
		tracer.SetAttribute(span, "fallback.index", attempt)
		ctx.SetValue(schemas.BifrostContextKeySpanID, spanCtx.Value(schemas.BifrostContextKeySpanID))

		fallbackReq := bifrost.prepareFallbackRequest(req, fallback)
		if fallbackReq == nil {
			bifrost.logger.Debug(fmt.Sprintf("fallback provider %s with model %s is nil", fallback.Provider, fallback.Model))
			tracer.SetAttribute(span, "error", "fallback request preparation failed")
			tracer.EndSpan(span, schemas.SpanStatusError, "fallback request preparation failed")
			continue
		}

		fallbackResp, fallbackErr := bifrost.tryRequest(ctx, fallbackReq)
		if fallbackErr == nil {
			bifrost.logger.Debug(fmt.Sprintf("successfully used fallback provider %s with model %s", fallback.Provider, fallback.Model))
			tracer.EndSpan(span, schemas.SpanStatusOk, "")
			return fallbackResp, nil
		}

		// Close the span with an error status.
		if fallbackErr.Error != nil {
			tracer.SetAttribute(span, "error", fallbackErr.Error.Message)
		}
		tracer.EndSpan(span, schemas.SpanStatusError, "fallback failed")

		// Some failures end the fallback chain immediately.
		if !bifrost.shouldContinueWithFallbacks(fallback, fallbackErr) {
			return nil, fallbackErr
		}
	}

	// Every provider failed: surface the primary provider's error.
	return nil, primaryErr
}
// handleStreamRequest is the common wrapper behind all streaming public API
// methods. It validates the request (validation failures are returned with an
// HTTP 400 status code), attempts it against the primary provider and, when
// that fails in a way that permits it, walks the configured fallbacks in order
// until one yields a stream. Each fallback attempt gets its own tracing span.
// When every fallback fails, the primary provider's error is returned.
// The request is released back via releaseBifrostRequest on return.
func (bifrost *Bifrost) handleStreamRequest(ctx *schemas.BifrostContext, req *schemas.BifrostRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
	defer bifrost.releaseBifrostRequest(req)

	provider, model, fallbacks := req.GetRequestFields()
	if invalid := validateRequest(req); invalid != nil {
		invalid.PopulateExtraFields(req.RequestType, provider, model, model)
		invalid.StatusCode = schemas.Ptr(fasthttp.StatusBadRequest)
		return nil, invalid
	}
	// Substitute the instance context for a nil one early to prevent blocking.
	if ctx == nil {
		ctx = bifrost.ctx
	}

	// Index 0 marks the primary attempt.
	ctx.SetValue(schemas.BifrostContextKeyFallbackIndex, 0)
	// The request ID has to be in the context before the pre-hooks run.
	if _, hasID := ctx.Value(schemas.BifrostContextKeyRequestID).(string); !hasID {
		ctx.SetValue(schemas.BifrostContextKeyRequestID, uuid.New().String())
	}

	primaryStream, primaryErr := bifrost.tryStreamRequest(ctx, req)
	// Stop here on success or when fallbacks are not applicable.
	if !bifrost.shouldTryFallbacks(req, primaryErr) {
		return primaryStream, primaryErr
	}

	// Walk the fallbacks in their configured order.
	for idx := range fallbacks {
		fallback := fallbacks[idx]
		attempt := idx + 1
		ctx.SetValue(schemas.BifrostContextKeyFallbackIndex, attempt)
		ctx.SetValue(schemas.BifrostContextKeyFallbackRequestID, uuid.New().String())
		clearCtxForFallback(ctx)

		// Open a span covering this fallback attempt.
		tracer := bifrost.getTracer()
		spanName := fmt.Sprintf("fallback.%s.%s", fallback.Provider, fallback.Model)
		spanCtx, span := tracer.StartSpan(ctx, spanName, schemas.SpanKindFallback)
		tracer.SetAttribute(span, schemas.AttrProviderName, string(fallback.Provider))
		tracer.SetAttribute(span, schemas.AttrRequestModel, fallback.Model)
		tracer.SetAttribute(span, "fallback.index", attempt)
		ctx.SetValue(schemas.BifrostContextKeySpanID, spanCtx.Value(schemas.BifrostContextKeySpanID))

		fallbackReq := bifrost.prepareFallbackRequest(req, fallback)
		if fallbackReq == nil {
			tracer.SetAttribute(span, "error", "fallback request preparation failed")
			tracer.EndSpan(span, schemas.SpanStatusError, "fallback request preparation failed")
			continue
		}

		fallbackStream, fallbackErr := bifrost.tryStreamRequest(ctx, fallbackReq)
		if fallbackErr == nil {
			bifrost.logger.Debug(fmt.Sprintf("successfully used fallback provider %s with model %s", fallback.Provider, fallback.Model))
			tracer.EndSpan(span, schemas.SpanStatusOk, "")
			return fallbackStream, nil
		}

		// Close the span with an error status.
		if fallbackErr.Error != nil {
			tracer.SetAttribute(span, "error", fallbackErr.Error.Message)
		}
		tracer.EndSpan(span, schemas.SpanStatusError, "fallback failed")

		// Some failures end the fallback chain immediately.
		if !bifrost.shouldContinueWithFallbacks(fallback, fallbackErr) {
			return nil, fallbackErr
		}
	}

	// Every provider failed: surface the primary provider's error.
	return nil, primaryErr
}
// tryRequest is a generic function that handles common request processing logic
// It consolidates queue setup, plugin pipeline execution, enqueue logic, and response handling
//
// Flow, in order:
//  1. Resolve the provider queue for the request's provider.
//  2. Let the MCP manager (when configured) add tools to the request.
//  3. Run the LLM pre-hooks; a short-circuit response or error is passed through
//     the post-hooks and returned without contacting the provider.
//  4. Enqueue a channel message on the provider queue, re-routing to a replacement
//     queue when the original one is closing.
//  5. Wait for a response, an error, or context cancellation, and run the
//     post-hooks on the response or error.
//
// Every error returned from here has its extra fields (request type, provider,
// model) populated before it is returned.
func (bifrost *Bifrost) tryRequest(ctx *schemas.BifrostContext, req *schemas.BifrostRequest) (*schemas.BifrostResponse, *schemas.BifrostError) {
provider, model, _ := req.GetRequestFields()
pq, err := bifrost.getProviderQueue(provider)
if err != nil {
bifrostErr := newBifrostError(err)
bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model)
return nil, bifrostErr
}
// Add MCP tools to request if MCP is configured and requested
if bifrost.MCPManager != nil {
req = bifrost.MCPManager.AddToolsToRequest(ctx, req)
}
tracer := bifrost.getTracer()
if tracer == nil {
bifrostErr := newBifrostErrorFromMsg("tracer not found in context")
bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model)
return nil, bifrostErr
}
// Store tracer in context BEFORE calling requestHandler, so streaming goroutines
// have access to it for completing deferred spans when the stream ends.
// The streaming goroutine captures the context when it starts, so these values
// must be set before requestHandler() is called.
ctx.SetValue(schemas.BifrostContextKeyTracer, tracer)
// The pipeline is handed back via the deferred release on every return path.
pipeline := bifrost.getPluginPipeline()
defer bifrost.releasePluginPipeline(pipeline)
// RequestType, Provider, OriginalModelRequested, and ResolvedModelUsed are always
// overwritten around RunPostLLMHooks — plugin modifications to these 4 fields are
// no-ops by design; proper request metadata is preserved and tampering is discouraged.
preReq, shortCircuit, preCount := pipeline.RunLLMPreHooks(ctx, req)
if shortCircuit != nil {
// Handle short-circuit with response (success case)
if shortCircuit.Response != nil {
shortCircuit.Response.PopulateExtraFields(req.RequestType, provider, model, model)
resp, bifrostErr := pipeline.RunPostLLMHooks(ctx, shortCircuit.Response, nil, preCount)
if bifrostErr != nil {
bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model)
} else if resp != nil {
resp.PopulateExtraFields(req.RequestType, provider, model, model)
}
drainAndAttachPluginLogs(ctx)
if bifrostErr != nil {
return nil, bifrostErr
}
return resp, nil
}
// Handle short-circuit with error
// The post-hooks may return a response instead of an error here, in which
// case that response is returned to the caller.
if shortCircuit.Error != nil {
shortCircuit.Error.PopulateExtraFields(req.RequestType, provider, model, model)
resp, bifrostErr := pipeline.RunPostLLMHooks(ctx, nil, shortCircuit.Error, preCount)
if bifrostErr != nil {
bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model)
} else if resp != nil {
resp.PopulateExtraFields(req.RequestType, provider, model, model)
}
drainAndAttachPluginLogs(ctx)
if bifrostErr != nil {
return nil, bifrostErr
}
return resp, nil
}
}
// No short-circuit result was returned above, so a request is required from here on.
if preReq == nil {
bifrostErr := newBifrostErrorFromMsg("bifrost request after plugin hooks cannot be nil")
bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model)
return nil, bifrostErr
}
// From here on use the provider/model of the request returned by the pre-hooks.
provider, model, _ = preReq.GetRequestFields()
msg := bifrost.getChannelMessage(*preReq)
msg.Context = ctx
// If the queue is closing, check whether the provider was updated (new queue
// available) or removed. On update, transparently re-route to the new queue
// so in-flight producers don't get spurious errors. On removal, error out.
//
// Use a direct sync.Map lookup instead of getProviderQueue to avoid the
// lazy-creation path: getProviderQueue can resurrect a provider that was
// just removed by RemoveProvider if the account config still exists.
if pq.isClosing() {
var reroutedPq *ProviderQueue
if val, ok := bifrost.requestQueues.Load(provider); ok {
if candidate := val.(*ProviderQueue); candidate != pq && !candidate.isClosing() {
reroutedPq = candidate
}
}
if reroutedPq == nil {
bifrost.releaseChannelMessage(msg)
bifrostErr := newBifrostErrorFromMsg("provider is shutting down")
bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{
RequestType: req.RequestType,
Provider: provider,
OriginalModelRequested: model,
ResolvedModelUsed: model,
}
return nil, bifrostErr
}
pq = reroutedPq
}
// Use select with done channel to detect shutdown during send
// The default branch makes this first attempt non-blocking: it runs only when
// the send cannot proceed immediately and neither pq.done nor ctx.Done() is ready.
// Until the message is enqueued it is still owned here, so every failure branch
// below releases it.
select {
case pq.queue <- msg:
// Message was sent successfully
case <-pq.done:
bifrost.releaseChannelMessage(msg)
bifrostErr := newBifrostErrorFromMsg("provider is shutting down")
bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model)
return nil, bifrostErr
case <-ctx.Done():
bifrost.releaseChannelMessage(msg)
bifrostErr := newBifrostCtxDoneError(ctx, "while waiting for queue space")
bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model)
return nil, bifrostErr
default:
// With dropExcessRequests set, fail fast instead of waiting for queue space.
if bifrost.dropExcessRequests.Load() {
bifrost.releaseChannelMessage(msg)
bifrost.logger.Warn("request dropped: queue is full, please increase the queue size or set dropExcessRequests to false")
bifrostErr := newBifrostErrorFromMsg("request dropped: queue is full")
bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model)
return nil, bifrostErr
}
// Re-check closing flag before blocking send (lock-free atomic check)
if pq.isClosing() {
bifrost.releaseChannelMessage(msg)
bifrostErr := newBifrostErrorFromMsg("provider is shutting down")
bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model)
return nil, bifrostErr
}
// Blocking send: wait for queue space, provider shutdown, or cancellation.
select {
case pq.queue <- msg:
// Message was sent successfully
case <-pq.done:
bifrost.releaseChannelMessage(msg)
bifrostErr := newBifrostErrorFromMsg("provider is shutting down")
bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model)
return nil, bifrostErr
case <-ctx.Done():
bifrost.releaseChannelMessage(msg)
bifrostErr := newBifrostCtxDoneError(ctx, "while waiting for queue space")
bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model)
return nil, bifrostErr
}
}
var result *schemas.BifrostResponse
var resp *schemas.BifrostResponse
// Post-hooks for a provider round trip are run with the number of currently
// loaded LLM plugins, not with preCount.
pluginCount := len(*bifrost.llmPlugins.Load())
select {
case result = <-msg.Response:
// NOTE: resp is redeclared in this case and shadows the function-level resp.
resp, bifrostErr := pipeline.RunPostLLMHooks(msg.Context, result, nil, pluginCount)
if bifrostErr != nil {
bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model)
} else if resp != nil {
resp.PopulateExtraFields(req.RequestType, provider, model, model)
}
drainAndAttachPluginLogs(msg.Context)
if bifrostErr != nil {
bifrost.releaseChannelMessage(msg)
return nil, bifrostErr
}
bifrost.releaseChannelMessage(msg)
// Strip raw fields that were captured for logging but should not reach the client.
if resp != nil {
dropReq, _ := ctx.Value(schemas.BifrostContextKeyDropRawRequestFromClient).(bool)
dropResp, _ := ctx.Value(schemas.BifrostContextKeyDropRawResponseFromClient).(bool)
if dropReq || dropResp {
extraField := resp.GetExtraFields()
if dropReq {
extraField.RawRequest = nil
}
if dropResp {
extraField.RawResponse = nil
}
}
}
return resp, nil
case bifrostErrVal := <-msg.Err:
// The post-hooks may turn the error into a response; both outcomes are handled below.
bifrostErrPtr := &bifrostErrVal
resp, bifrostErrPtr = pipeline.RunPostLLMHooks(msg.Context, nil, bifrostErrPtr, pluginCount)
if bifrostErrPtr != nil {
bifrostErrPtr.PopulateExtraFields(req.RequestType, provider, model, model)
} else if resp != nil {
resp.PopulateExtraFields(req.RequestType, provider, model, model)
}
drainAndAttachPluginLogs(msg.Context)
bifrost.releaseChannelMessage(msg)
// Strip raw fields on error path too.
dropReq, _ := ctx.Value(schemas.BifrostContextKeyDropRawRequestFromClient).(bool)
dropResp, _ := ctx.Value(schemas.BifrostContextKeyDropRawResponseFromClient).(bool)
if dropReq || dropResp {
if bifrostErrPtr != nil {
if dropReq {
bifrostErrPtr.ExtraFields.RawRequest = nil
}
if dropResp {
bifrostErrPtr.ExtraFields.RawResponse = nil
}
}
if resp != nil {
extraField := resp.GetExtraFields()
if dropReq {
extraField.RawRequest = nil
}
if dropResp {
extraField.RawResponse = nil
}
}
}
if bifrostErrPtr != nil {
return nil, bifrostErrPtr
}
return resp, nil
case <-ctx.Done():
// Do NOT releaseChannelMessage here. The message is already enqueued and
// the worker still holds a reference to msg.Response and msg.Err. Returning
// those channels to the pool now would let the next request reuse them while
// the worker is still writing to them — stale data corruption. The worker
// never calls releaseChannelMessage itself, so this message leaks from the
// pool and is GC'd. That is intentional: a small pool leak on cancellation
// is far safer than corrupting another request's channels.
// NOTE(review): provider/model are re-read from req here (shadowing the values
// taken from preReq above), unlike the other post-enqueue error paths — confirm
// this is intended.
provider, model, _ := req.GetRequestFields()
bifrostErr := newBifrostCtxDoneError(ctx, "waiting for provider response")
bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model)
return nil, bifrostErr
}
}
// tryStreamRequest is a generic function that handles common request processing logic
// It consolidates queue setup, plugin pipeline execution, enqueue logic, and response handling
func (bifrost *Bifrost) tryStreamRequest(ctx *schemas.BifrostContext, req *schemas.BifrostRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
provider, model, _ := req.GetRequestFields()
pq, err := bifrost.getProviderQueue(provider)
if err != nil {
bifrostErr := newBifrostError(err)
bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model)
return nil, bifrostErr
}
// Add MCP tools to request if MCP is configured and requested
if req.RequestType != schemas.SpeechStreamRequest && req.RequestType != schemas.TranscriptionStreamRequest && bifrost.MCPManager != nil {
req = bifrost.MCPManager.AddToolsToRequest(ctx, req)
}
tracer := bifrost.getTracer()
if tracer == nil {
bifrostErr := newBifrostErrorFromMsg("tracer not found in context")
bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model)
return nil, bifrostErr
}
// Store tracer in context BEFORE calling RunLLMPreHooks, so plugins and streaming goroutines
// have access to it for completing deferred spans when the stream ends.
// The streaming goroutine captures the context when it starts, so these values
// must be set before requestHandler() is called.
ctx.SetValue(schemas.BifrostContextKeyTracer, tracer)
// Ensure traceID exists so the logging plugin can create a stream accumulator
// in PreLLMHook and accumulate chunks in PostLLMHook. For HTTP handler requests the
// tracing middleware already sets this; for WebSocket bridge and Go SDK callers it
// may be absent.
if _, ok := ctx.Value(schemas.BifrostContextKeyTraceID).(string); !ok {
traceID := tracer.CreateTrace("")
if traceID != "" {
ctx.SetValue(schemas.BifrostContextKeyTraceID, traceID)
}
}
pipeline := bifrost.getPluginPipeline()
releasePipeline := true
defer func() {
if releasePipeline {
bifrost.releasePluginPipeline(pipeline)
}
}()
// RequestType, Provider, OriginalModelRequested, and ResolvedModelUsed are always
// overwritten around RunPostLLMHooks — plugin modifications to these 4 fields are
// no-ops by design; proper request metadata is preserved and tampering is discouraged.
preReq, shortCircuit, preCount := pipeline.RunLLMPreHooks(ctx, req)
if shortCircuit != nil {
// Handle short-circuit with response (success case)
if shortCircuit.Response != nil {
shortCircuit.Response.PopulateExtraFields(req.RequestType, provider, model, model)
resp, bifrostErr := pipeline.RunPostLLMHooks(ctx, shortCircuit.Response, nil, preCount)
if bifrostErr != nil {
bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model)
} else if resp != nil {
resp.PopulateExtraFields(req.RequestType, provider, model, model)
}
drainAndAttachPluginLogs(ctx)
if bifrostErr != nil {
return nil, bifrostErr
}
return newBifrostMessageChan(resp), nil
}
// Handle short-circuit with stream
if shortCircuit.Stream != nil {
outputStream := make(chan *schemas.BifrostStreamChunk)
releasePipeline = false // pipeline is released inside the goroutine after stream drains
// Create a post hook runner cause pipeline object is put back in the pool on defer
pipelinePostHookRunner := func(ctx *schemas.BifrostContext, result *schemas.BifrostResponse, err *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError) {
if result != nil {
result.PopulateExtraFields(req.RequestType, provider, model, model)
}
if err != nil {
err.PopulateExtraFields(req.RequestType, provider, model, model)
}
resp, bifrostErr := pipeline.RunPostLLMHooks(ctx, result, err, preCount)
if IsFinalChunk(ctx) {
drainAndAttachPluginLogs(ctx)
}
if bifrostErr != nil {
bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model)
return nil, bifrostErr
} else if resp != nil {
resp.PopulateExtraFields(req.RequestType, provider, model, model)
}
return resp, nil
}
go func() {
defer func() {
drainAndAttachPluginLogs(ctx) // ensure logs are drained even if stream closes without a final chunk
pipeline.FinalizeStreamingPostHookSpans(ctx)
bifrost.releasePluginPipeline(pipeline)
}()
defer close(outputStream)
for streamMsg := range shortCircuit.Stream {
if streamMsg == nil {
continue
}
bifrostResponse := &schemas.BifrostResponse{}
if streamMsg.BifrostTextCompletionResponse != nil {
bifrostResponse.TextCompletionResponse = streamMsg.BifrostTextCompletionResponse
}
if streamMsg.BifrostChatResponse != nil {
bifrostResponse.ChatResponse = streamMsg.BifrostChatResponse
}
if streamMsg.BifrostResponsesStreamResponse != nil {
bifrostResponse.ResponsesStreamResponse = streamMsg.BifrostResponsesStreamResponse
}
if streamMsg.BifrostSpeechStreamResponse != nil {
bifrostResponse.SpeechStreamResponse = streamMsg.BifrostSpeechStreamResponse
}
if streamMsg.BifrostTranscriptionStreamResponse != nil {
bifrostResponse.TranscriptionStreamResponse = streamMsg.BifrostTranscriptionStreamResponse
}
if streamMsg.BifrostImageGenerationStreamResponse != nil {
bifrostResponse.ImageGenerationStreamResponse = streamMsg.BifrostImageGenerationStreamResponse
}
// Run post hooks on the stream message
processedResponse, processedError := pipelinePostHookRunner(ctx, bifrostResponse, streamMsg.BifrostError)
// Build the client-facing chunk via the shared helper, which strips raw
// request/response fields when in logging-only mode without mutating the
// shared processedResponse or processedError objects.
streamResponse := providerUtils.BuildClientStreamChunk(ctx, processedResponse, processedError)
// Guarded send: if the consumer abandons outputStream (client
// disconnect, ctx cancel), drain the upstream shortCircuit.Stream
// so its producer can exit cleanly instead of blocking on its send.
select {
case outputStream <- streamResponse:
case <-ctx.Done():
for range shortCircuit.Stream {
}
return
}
// TODO: Release the processed response immediately after use
}
}()
return outputStream, nil
}
// Handle short-circuit with error
if shortCircuit.Error != nil {
shortCircuit.Error.PopulateExtraFields(req.RequestType, provider, model, model)
resp, bifrostErr := pipeline.RunPostLLMHooks(ctx, nil, shortCircuit.Error, preCount)
if bifrostErr != nil {
bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model)
} else if resp != nil {
resp.PopulateExtraFields(req.RequestType, provider, model, model)
}
drainAndAttachPluginLogs(ctx)
if bifrostErr != nil {
return nil, bifrostErr
}
return newBifrostMessageChan(resp), nil
}
}
if preReq == nil {
bifrostErr := newBifrostErrorFromMsg("bifrost request after plugin hooks cannot be nil")
bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model)
return nil, bifrostErr
}
provider, model, _ = preReq.GetRequestFields()
msg := bifrost.getChannelMessage(*preReq)
msg.Context = ctx
// If the queue is closing, check whether the provider was updated (new queue
// available) or removed. On update, transparently re-route to the new queue
// so in-flight producers don't get spurious errors. On removal, error out.
//
// Use a direct sync.Map lookup instead of getProviderQueue to avoid the
// lazy-creation path: getProviderQueue can resurrect a provider that was
// just removed by RemoveProvider if the account config still exists.
if pq.isClosing() {
var reroutedPq *ProviderQueue
if val, ok := bifrost.requestQueues.Load(provider); ok {
if candidate := val.(*ProviderQueue); candidate != pq && !candidate.isClosing() {
reroutedPq = candidate
}
}
if reroutedPq == nil {
bifrost.releaseChannelMessage(msg)
bifrostErr := newBifrostErrorFromMsg("provider is shutting down")
bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{
RequestType: req.RequestType,
Provider: provider,
OriginalModelRequested: model,
}
return nil, bifrostErr
}
pq = reroutedPq
}
// Use select with done channel to detect shutdown during send
select {
case pq.queue <- msg:
// Message was sent successfully
case <-pq.done:
bifrost.releaseChannelMessage(msg)
bifrostErr := newBifrostErrorFromMsg("provider is shutting down")
bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model)
return nil, bifrostErr
case <-ctx.Done():
bifrost.releaseChannelMessage(msg)
bifrostErr := newBifrostCtxDoneError(ctx, "while waiting for queue space")
bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model)
return nil, bifrostErr
default:
if bifrost.dropExcessRequests.Load() {
bifrost.releaseChannelMessage(msg)
bifrost.logger.Warn("request dropped: queue is full, please increase the queue size or set dropExcessRequests to false")
bifrostErr := newBifrostErrorFromMsg("request dropped: queue is full")
bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model)
return nil, bifrostErr
}
// Re-check closing flag before blocking send (lock-free atomic check)
if pq.isClosing() {
bifrost.releaseChannelMessage(msg)
bifrostErr := newBifrostErrorFromMsg("provider is shutting down")
bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model)
return nil, bifrostErr
}
select {
case pq.queue <- msg:
// Message was sent successfully
case <-pq.done:
bifrost.releaseChannelMessage(msg)
bifrostErr := newBifrostErrorFromMsg("provider is shutting down")
bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model)
return nil, bifrostErr
case <-ctx.Done():
bifrost.releaseChannelMessage(msg)
bifrostErr := newBifrostCtxDoneError(ctx, "while waiting for queue space")
bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model)
return nil, bifrostErr
}
}
select {
case stream := <-msg.ResponseStream:
bifrost.releaseChannelMessage(msg)
return stream, nil
case bifrostErrVal := <-msg.Err:
if bifrostErrVal.Error != nil {
bifrost.logger.Debug("error while executing stream request: %s", bifrostErrVal.Error.Message)
} else {
bifrost.logger.Debug("error while executing stream request: %+v", bifrostErrVal)
}
// Marking final chunk
ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
// On error we will complete post-hooks
recoveredResp, recoveredErr := pipeline.RunPostLLMHooks(ctx, nil, &bifrostErrVal, len(*bifrost.llmPlugins.Load()))
if recoveredErr != nil {
recoveredErr.PopulateExtraFields(req.RequestType, provider, model, model)
} else if recoveredResp != nil {
recoveredResp.PopulateExtraFields(req.RequestType, provider, model, model)
}
drainAndAttachPluginLogs(ctx)
bifrost.releaseChannelMessage(msg)
if recoveredErr != nil {
return nil, recoveredErr
}
if recoveredResp != nil {
return newBifrostMessageChan(recoveredResp), nil
}
return nil, &bifrostErrVal
case <-ctx.Done():
// Do NOT releaseChannelMessage here — see the identical note in tryRequest.
// Worker still holds msg.ResponseStream/msg.Err; releasing now corrupts the
// next request that reuses those pooled channels.
return nil, newBifrostCtxDoneError(ctx, "while waiting for stream response")
}
}
// executeRequestWithRetries is a generic function that handles common request processing logic.
// It consolidates retry logic, backoff calculation, error handling, and key rotation.
// It is not a bifrost method because interface methods in go cannot be generic.
//
// keyProvider, when non-nil, is called on the first attempt and again whenever a rate-limit error
// triggers a key rotation. It receives the set of key IDs already used in the current rotation
// cycle so it can exclude them; when the pool is exhausted the provider resets the set and starts
// a fresh weighted round. Network errors (5xx) reuse the same key since they are transient server
// issues rather than per-key capacity problems.
//
// Parameters:
//   - ctx: request context; retry count, selected key, attempt trail and span ID are written to it.
//   - config: provider configuration; NetworkConfig.MaxRetries bounds the number of retries.
//   - requestHandler: performs a single attempt with the given key.
//   - keyProvider: selects the key for an attempt; nil means no key tracking is done.
//   - requestType, providerKey, model: used for logging and span attributes.
//   - req: when non-nil, used to populate LLM request attributes on the attempt span.
//   - logger: destination for debug/error logs.
//
// It returns the last attempt's result and error. A tracer must be present in ctx under
// schemas.BifrostContextKeyTracer, otherwise an error is returned before the attempt is made.
func executeRequestWithRetries[T any](
	ctx *schemas.BifrostContext,
	config *schemas.ProviderConfig,
	requestHandler func(key schemas.Key) (T, *schemas.BifrostError),
	keyProvider func(usedKeyIDs map[string]bool) (schemas.Key, error),
	requestType schemas.RequestType,
	providerKey schemas.ModelProvider,
	model string,
	req *schemas.BifrostRequest,
	logger schemas.Logger,
) (T, *schemas.BifrostError) {
	var result T
	var bifrostError *schemas.BifrostError
	var attempts int
	var currentKey schemas.Key
	var usedKeyIDs map[string]bool
	lastWasRateLimit := false

	for attempts = 0; attempts <= config.NetworkConfig.MaxRetries; attempts++ {
		ctx.SetValue(schemas.BifrostContextKeyNumberOfRetries, attempts)

		// Reset the trail on the first attempt so a reused or shared context (bifrost.ctx)
		// doesn't carry over records from a previous request.
		if keyProvider != nil && attempts == 0 {
			ctx.SetValue(schemas.BifrostContextKeyAttemptTrail, []schemas.KeyAttemptRecord{})
		}

		// Select / rotate key: always on attempt 0, and again when the previous failure was a
		// rate-limit (different key may have remaining capacity). Network errors keep the same key.
		if keyProvider != nil && (attempts == 0 || lastWasRateLimit) {
			if usedKeyIDs == nil {
				usedKeyIDs = make(map[string]bool)
			}
			// Wrap key selection in a dedicated span so traces show which key was chosen
			// (and when rotation happened). The span is opened before keyProvider is called
			// so selection errors are captured too.
			keyTracer, _ := ctx.Value(schemas.BifrostContextKeyTracer).(schemas.Tracer)
			var keySpanCtx context.Context
			var keyHandle schemas.SpanHandle
			if keyTracer != nil {
				keySpanCtx, keyHandle = keyTracer.StartSpan(ctx, "key.selection", schemas.SpanKindInternal)
				keyTracer.SetAttribute(keyHandle, schemas.AttrProviderName, string(providerKey))
				keyTracer.SetAttribute(keyHandle, schemas.AttrRequestModel, model)
				if attempts > 0 {
					keyTracer.SetAttribute(keyHandle, "retry.count", attempts)
				}
			}
			selectedKey, err := keyProvider(usedKeyIDs)
			if keyTracer != nil {
				if err != nil {
					keyTracer.SetAttribute(keyHandle, "error", err.Error())
					keyTracer.EndSpan(keyHandle, schemas.SpanStatusError, err.Error())
				} else {
					keyTracer.SetAttribute(keyHandle, "key.id", selectedKey.ID)
					keyTracer.SetAttribute(keyHandle, "key.name", selectedKey.Name)
					keyTracer.EndSpan(keyHandle, schemas.SpanStatusOk, "")
					// Propagate the span context so subsequent spans (llm.call / retry.attempt.N)
					// are correctly linked in the trace hierarchy.
					ctx.SetValue(schemas.BifrostContextKeySpanID, keySpanCtx.Value(schemas.BifrostContextKeySpanID))
				}
			}
			if err != nil {
				var zero T
				return zero, newBifrostErrorFromMsg(err.Error())
			}
			currentKey = selectedKey
			ctx.SetValue(schemas.BifrostContextKeySelectedKeyID, currentKey.ID)
			ctx.SetValue(schemas.BifrostContextKeySelectedKeyName, currentKey.Name)
		}

		// Append a trail record for every attempt (key rotation and same-key retries alike).
		// Skipped when keyProvider is nil (keyless providers have no key to track).
		// FailReason is populated below once the attempt outcome is known.
		if keyProvider != nil {
			schemas.AppendToContextList(ctx, schemas.BifrostContextKeyAttemptTrail, schemas.KeyAttemptRecord{
				Attempt: attempts,
				KeyID:   currentKey.ID,
				KeyName: currentKey.Name,
			})
		}

		if attempts > 0 {
			// Log retry attempt
			var retryMsg string
			if bifrostError != nil && bifrostError.Error != nil {
				retryMsg = bifrostError.Error.Message
			} else if bifrostError != nil && bifrostError.StatusCode != nil {
				retryMsg = fmt.Sprintf("status=%d", *bifrostError.StatusCode)
				if bifrostError.Type != nil {
					retryMsg += ", type=" + *bifrostError.Type
				}
			}
			logger.Debug("retrying request (attempt %d/%d) for model %s: %s", attempts, config.NetworkConfig.MaxRetries, model, retryMsg)

			// Calculate and apply backoff
			backoff := calculateBackoff(attempts-1, config)
			logger.Debug("sleeping for %s before retry", backoff)
			// Wait out the backoff, but wake early if the request context is done so an
			// abandoned request does not pin this goroutine for the whole backoff window.
			// On early wake-up the attempt below still runs, exactly as it would have after
			// a full sleep; only the waiting time is shortened.
			backoffTimer := time.NewTimer(backoff)
			select {
			case <-backoffTimer.C:
			case <-ctx.Done():
				backoffTimer.Stop()
			}
		}

		logger.Debug("attempting %s request for provider %s", requestType, providerKey)

		// Start span for LLM call (or retry attempt)
		tracer, ok := ctx.Value(schemas.BifrostContextKeyTracer).(schemas.Tracer)
		if !ok || tracer == nil {
			logger.Error("tracer not found in context of executeRequestWithRetries")
			return result, newBifrostErrorFromMsg("tracer not found in context")
		}
		var spanName string
		var spanKind schemas.SpanKind
		if attempts > 0 {
			spanName = fmt.Sprintf("retry.attempt.%d", attempts)
			spanKind = schemas.SpanKindRetry
		} else {
			spanName = "llm.call"
			spanKind = schemas.SpanKindLLMCall
		}
		spanCtx, handle := tracer.StartSpan(ctx, spanName, spanKind)
		tracer.SetAttribute(handle, schemas.AttrProviderName, string(providerKey))
		tracer.SetAttribute(handle, schemas.AttrRequestModel, model)
		tracer.SetAttribute(handle, "request.type", string(requestType))
		if attempts > 0 {
			tracer.SetAttribute(handle, "retry.count", attempts)
		}

		// Add context-related attributes (selected key, virtual key, team, customer, etc.)
		if selectedKeyID, ok := ctx.Value(schemas.BifrostContextKeySelectedKeyID).(string); ok && selectedKeyID != "" {
			tracer.SetAttribute(handle, schemas.AttrSelectedKeyID, selectedKeyID)
		}
		if selectedKeyName, ok := ctx.Value(schemas.BifrostContextKeySelectedKeyName).(string); ok && selectedKeyName != "" {
			tracer.SetAttribute(handle, schemas.AttrSelectedKeyName, selectedKeyName)
		}
		if virtualKeyID, ok := ctx.Value(schemas.BifrostContextKeyGovernanceVirtualKeyID).(string); ok && virtualKeyID != "" {
			tracer.SetAttribute(handle, schemas.AttrVirtualKeyID, virtualKeyID)
		}
		if virtualKeyName, ok := ctx.Value(schemas.BifrostContextKeyGovernanceVirtualKeyName).(string); ok && virtualKeyName != "" {
			tracer.SetAttribute(handle, schemas.AttrVirtualKeyName, virtualKeyName)
		}
		if teamID, ok := ctx.Value(schemas.BifrostContextKeyGovernanceTeamID).(string); ok && teamID != "" {
			tracer.SetAttribute(handle, schemas.AttrTeamID, teamID)
		}
		if teamName, ok := ctx.Value(schemas.BifrostContextKeyGovernanceTeamName).(string); ok && teamName != "" {
			tracer.SetAttribute(handle, schemas.AttrTeamName, teamName)
		}
		if customerID, ok := ctx.Value(schemas.BifrostContextKeyGovernanceCustomerID).(string); ok && customerID != "" {
			tracer.SetAttribute(handle, schemas.AttrCustomerID, customerID)
		}
		if customerName, ok := ctx.Value(schemas.BifrostContextKeyGovernanceCustomerName).(string); ok && customerName != "" {
			tracer.SetAttribute(handle, schemas.AttrCustomerName, customerName)
		}
		if fallbackIndex, ok := ctx.Value(schemas.BifrostContextKeyFallbackIndex).(int); ok {
			tracer.SetAttribute(handle, schemas.AttrFallbackIndex, fallbackIndex)
		}
		tracer.SetAttribute(handle, schemas.AttrNumberOfRetries, attempts)

		// Populate LLM request attributes (messages, parameters, etc.)
		if req != nil {
			tracer.PopulateLLMRequestAttributes(handle, req)
		}

		// Update context with span ID
		ctx.SetValue(schemas.BifrostContextKeySpanID, spanCtx.Value(schemas.BifrostContextKeySpanID))

		// Record stream start time for TTFT calculation (only for streaming requests)
		// This is also used by RunPostLLMHooks to detect streaming mode
		if IsStreamRequestType(requestType) {
			streamStartTime := time.Now()
			ctx.SetValue(schemas.BifrostContextKeyStreamStartTime, streamStartTime)
		}

		// Attempt the request
		result, bifrostError = requestHandler(currentKey)

		// For streaming requests that returned success, check if the first chunk
		// is actually an error (e.g., rate limits sent as SSE events in HTTP 200).
		// This enables retries and fallbacks for providers that embed errors in
		// the SSE stream instead of returning proper HTTP error status codes.
		if bifrostError == nil {
			if streamChan, ok := any(result).(chan *schemas.BifrostStreamChunk); ok {
				checkedStream, drainDone, firstChunkErr := providerUtils.CheckFirstStreamChunkForError(ctx, streamChan)
				if firstChunkErr != nil {
					// Wait for the drain signal before treating this attempt as failed.
					<-drainDone
					bifrostError = firstChunkErr
				} else {
					result = any(checkedStream).(T)
				}
			}
		}

		// Check if result is a streaming channel - if so, defer span completion
		// Only defer for successful stream setup; error paths must end the span synchronously
		isStreamChan := false
		if bifrostError == nil {
			if ch, ok := any(result).(chan *schemas.BifrostStreamChunk); ok && ch != nil {
				isStreamChan = true
			}
		}

		if isStreamChan {
			// For streaming requests, store the span handle in TraceStore keyed by trace ID
			// This allows the provider's streaming goroutine to retrieve it later
			if traceID, ok := ctx.Value(schemas.BifrostContextKeyTraceID).(string); ok && traceID != "" {
				tracer.StoreDeferredSpan(traceID, handle)
			}
			// Don't end the span here - it will be ended when streaming completes
		} else {
			// Populate LLM response attributes for non-streaming responses
			if resp, ok := any(result).(*schemas.BifrostResponse); ok {
				tracer.PopulateLLMResponseAttributes(ctx, handle, resp, bifrostError)
			}

			// End span with appropriate status
			if bifrostError != nil {
				if bifrostError.Error != nil {
					tracer.SetAttribute(handle, "error", bifrostError.Error.Message)
				}
				if bifrostError.StatusCode != nil {
					tracer.SetAttribute(handle, "status_code", *bifrostError.StatusCode)
				}
				tracer.EndSpan(handle, schemas.SpanStatusError, "request failed")
			} else {
				tracer.EndSpan(handle, schemas.SpanStatusOk, "")
			}
		}

		logger.Debug("request %s for provider %s completed", requestType, providerKey)

		// Check if successful or if we should retry
		if bifrostError == nil ||
			bifrostError.IsBifrostError ||
			(bifrostError.Error != nil && bifrostError.Error.Type != nil && *bifrostError.Error.Type == schemas.RequestCancelled) {
			break
		}

		// Check if we should retry based on status code or error message
		shouldRetry := false
		isRateLimit := (bifrostError.StatusCode != nil && *bifrostError.StatusCode == 429) ||
			(bifrostError.Error != nil &&
				(IsRateLimitErrorMessage(bifrostError.Error.Message) ||
					(bifrostError.Error.Type != nil && IsRateLimitErrorMessage(*bifrostError.Error.Type)) ||
					(bifrostError.Error.Code != nil && IsRateLimitErrorMessage(*bifrostError.Error.Code))))

		errMessage := GetErrorMessage(bifrostError)
		if bifrostError.Error != nil &&
			(bifrostError.Error.Message == schemas.ErrProviderDoRequest ||
				bifrostError.Error.Message == schemas.ErrProviderNetworkError) {
			shouldRetry = true
			logger.Debug("detected request HTTP/network error, will retry: %s", errMessage)
		} else if (bifrostError.StatusCode != nil && retryableStatusCodes[*bifrostError.StatusCode]) || isRateLimit {
			shouldRetry = true
			logger.Debug("encountered error that should be retried: %s", errMessage)
		}

		// Fill FailReason on the failed attempt's trail record (reached for both retryable
		// and non-retryable provider errors that got past the break above).
		// Use the provider error type when present; fall back to "unknown".
		if trail, ok := ctx.Value(schemas.BifrostContextKeyAttemptTrail).([]schemas.KeyAttemptRecord); ok && len(trail) > 0 {
			reason := "unknown"
			if bifrostError.Error != nil && bifrostError.Error.Type != nil && *bifrostError.Error.Type != "" {
				reason = *bifrostError.Error.Type
			} else if isRateLimit {
				reason = "rate_limit_error"
			}
			trail[len(trail)-1].FailReason = &reason
			ctx.SetValue(schemas.BifrostContextKeyAttemptTrail, trail)
		}

		if !shouldRetry {
			break
		}

		// Mark current key as used so the next selection excludes it (rate-limit only).
		// Network errors keep the same key — they are transient server issues, not per-key.
		if isRateLimit && keyProvider != nil {
			if usedKeyIDs == nil {
				usedKeyIDs = make(map[string]bool)
			}
			usedKeyIDs[currentKey.ID] = true
		}
		lastWasRateLimit = isRateLimit
	}

	// Log the retry outcome only when the request actually ended in an error; a request
	// that succeeded on a retry must not be reported as failed.
	if attempts > 0 && bifrostError != nil {
		logger.Debug("request failed after %d %s", attempts, map[bool]string{true: "attempts", false: "attempt"}[attempts > 1])
	}

	// On final error, clear selected_key so it only reflects a key that actually served a successful response.
	// The attempt trail is the authoritative record of which keys were tried.
	if bifrostError != nil && keyProvider != nil {
		ctx.SetValue(schemas.BifrostContextKeySelectedKeyID, "")
		ctx.SetValue(schemas.BifrostContextKeySelectedKeyName, "")
	}

	return result, bifrostError
}
// requestWorker handles incoming requests from the queue for a specific provider.
// It manages retries, error handling, and response processing.
//
// The worker loops until pq.done is signalled. On shutdown it drains any requests still
// buffered in pq.queue and answers each with a "provider is shutting down" error before
// returning. For every request it resolves raw-payload capture flags, builds the key
// provider (or key list), runs the request through executeRequestWithRetries, and delivers
// the result, stream, or error on the message's channels. Every delivery made while
// processing a request is guarded by the request context and a 5s timeout so the worker
// cannot block forever on a caller that has stopped listening.
func (bifrost *Bifrost) requestWorker(provider schemas.Provider, config *schemas.ProviderConfig, pq *ProviderQueue) {
	defer func() {
		if waitGroupValue, ok := bifrost.waitGroups.Load(provider.GetProviderKey()); ok {
			waitGroup := waitGroupValue.(*sync.WaitGroup)
			waitGroup.Done()
		}
	}()

	// sendError delivers bifrostErr on msg.Err with context awareness to prevent deadlock.
	// It gives up when the request context is done or after 5 seconds, so a caller that
	// has already stopped reading cannot wedge this worker.
	sendError := func(msg *ChannelMessage, bifrostErr schemas.BifrostError) {
		select {
		case msg.Err <- bifrostErr:
			// Error sent successfully
		case <-msg.Context.Done():
			// Client no longer listening, log and continue
			bifrost.logger.Debug("Client context cancelled while sending error response")
		case <-time.After(5 * time.Second):
			// Timeout to prevent indefinite blocking
			bifrost.logger.Warn("Timeout while sending error response, client may have disconnected")
		}
	}

	for {
		var req *ChannelMessage
		select {
		case r := <-pq.queue:
			req = r
		case <-pq.done:
			// Provider is shutting down. Drain any buffered requests and send
			// back errors so callers are not left blocked on their response channel.
			for {
				select {
				case r := <-pq.queue:
					provKey, mod, _ := r.GetRequestFields()
					select {
					case r.Err <- schemas.BifrostError{
						IsBifrostError: false,
						Error: &schemas.ErrorField{
							Message: "provider is shutting down",
						},
						ExtraFields: schemas.BifrostErrorExtraFields{
							RequestType:            r.RequestType,
							Provider:               provKey,
							OriginalModelRequested: mod,
						},
					}:
					case <-r.Context.Done():
					}
				default:
					return
				}
			}
		}

		_, model, _ := req.BifrostRequest.GetRequestFields()

		var result *schemas.BifrostResponse
		var stream chan *schemas.BifrostStreamChunk
		var bifrostError *schemas.BifrostError
		var err error

		// Determine the base provider type for key requirement checks
		baseProvider := provider.GetProviderKey()
		if cfg := config.CustomProviderConfig; cfg != nil && cfg.BaseProviderType != "" {
			baseProvider = cfg.BaseProviderType
		}
		req.Context.SetValue(schemas.BifrostContextKeyIsCustomProvider, !IsStandardProvider(baseProvider))

		// Determine whether this provider attempt should capture raw payloads.
		//
		// Effective values are computed by merging provider config with any per-request
		// context overrides (BifrostContextKeySendBackRawRequest/Response and
		// BifrostContextKeyStoreRawRequestResponse). A context value set to either true
		// or false fully overrides the provider config for that flag.
		//
		// Each flag is independent:
		//   send_back_raw_request      — include raw request bytes in the client response.
		//   send_back_raw_response     — include raw response bytes in the client response.
		//   store_raw_request_response — persist raw bytes in log records (logging plugin only).
		//
		// Capture is enabled per-side whenever send-back OR store is requested for that side.
		// Strip flags tell the response path to remove that side's bytes before the payload
		// reaches the caller (used when store=true but send-back=false for that side).
		//
		// All internal signals are always written explicitly on every attempt so stale values
		// from a previous provider attempt (e.g. different fallback provider config) cannot
		// leak into the new attempt on a reused context. The user override keys
		// (BifrostContextKeySendBackRaw*, BifrostContextKeyStoreRawRequestResponse) are
		// never overwritten — they are read-only from bifrost.go's perspective.

		// Step 1: compute effective value for each flag (provider config ← per-request override).
		effectiveSendBackReq := config.SendBackRawRequest
		if override, ok := req.Context.Value(schemas.BifrostContextKeySendBackRawRequest).(bool); ok {
			effectiveSendBackReq = override
		}
		effectiveSendBackResp := config.SendBackRawResponse
		if override, ok := req.Context.Value(schemas.BifrostContextKeySendBackRawResponse).(bool); ok {
			effectiveSendBackResp = override
		}
		effectiveStore := config.StoreRawRequestResponse
		if override, ok := req.Context.Value(schemas.BifrostContextKeyStoreRawRequestResponse).(bool); ok {
			effectiveStore = override
		}

		// Step 2: derive per-side capture and strip flags.
		// Capture if we need to send the data back OR store it — independent per side.
		captureReq := effectiveSendBackReq || effectiveStore
		captureResp := effectiveSendBackResp || effectiveStore
		// Strip from client response if we captured for storage but not for send-back.
		dropReq := effectiveStore && !effectiveSendBackReq
		dropResp := effectiveStore && !effectiveSendBackResp

		// Step 3: write all internal signals explicitly (never touch the user override keys).
		req.Context.SetValue(schemas.BifrostContextKeyCaptureRawRequest, captureReq)
		req.Context.SetValue(schemas.BifrostContextKeyCaptureRawResponse, captureResp)
		req.Context.SetValue(schemas.BifrostContextKeyDropRawRequestFromClient, dropReq)
		req.Context.SetValue(schemas.BifrostContextKeyDropRawResponseFromClient, dropResp)
		// Tells the logging plugin whether to persist raw bytes in log records.
		req.Context.SetValue(schemas.BifrostContextKeyShouldStoreRawInLogs, effectiveStore)

		var keys []schemas.Key
		// keyProvider is passed to executeRequestWithRetries to manage key selection and rotation.
		// It is nil when no key is required (e.g. providerRequiresKey=false) or for multi-key
		// batch/file/container operations that manage their own key lists.
		var keyProvider func(usedKeyIDs map[string]bool) (schemas.Key, error)
		if providerRequiresKey(config.CustomProviderConfig) {
			// ListModels needs all enabled/supported keys so providers can aggregate
			// and report per-key statuses (KeyStatuses).
			if req.RequestType == schemas.ListModelsRequest {
				keys, err = bifrost.getAllSupportedKeys(req.Context, provider.GetProviderKey(), baseProvider)
				if err != nil {
					bifrost.logger.Debug("error getting supported keys for list models: %v", err)
					sendError(req, schemas.BifrostError{
						IsBifrostError: false,
						Error: &schemas.ErrorField{
							Message: err.Error(),
							Error:   err,
						},
						ExtraFields: schemas.BifrostErrorExtraFields{
							Provider:               provider.GetProviderKey(),
							RequestType:            req.RequestType,
							OriginalModelRequested: model,
							ResolvedModelUsed:      model,
						},
					})
					continue
				}
			} else {
				// Determine if this is a multi-key batch/file/container operation
				// BatchCreate, FileUpload, ContainerCreate, ContainerFileCreate use single key; other batch/file/container ops use multiple keys
				isMultiKeyBatchOp := isBatchRequestType(req.RequestType) && req.RequestType != schemas.BatchCreateRequest
				isMultiKeyFileOp := isFileRequestType(req.RequestType) && req.RequestType != schemas.FileUploadRequest
				isMultiKeyContainerOp := isContainerRequestType(req.RequestType) && req.RequestType != schemas.ContainerCreateRequest && req.RequestType != schemas.ContainerFileCreateRequest

				if isMultiKeyBatchOp || isMultiKeyFileOp || isMultiKeyContainerOp {
					var modelPtr *string
					if model != "" {
						modelPtr = &model
					}
					keys, err = bifrost.getKeysForBatchAndFileOps(req.Context, provider.GetProviderKey(), baseProvider, modelPtr, isMultiKeyBatchOp)
					if err != nil {
						bifrost.logger.Debug("error getting keys for batch/file operation: %v", err)
						sendError(req, schemas.BifrostError{
							IsBifrostError: false,
							Error: &schemas.ErrorField{
								Message: err.Error(),
								Error:   err,
							},
							ExtraFields: schemas.BifrostErrorExtraFields{
								Provider:               provider.GetProviderKey(),
								RequestType:            req.RequestType,
								OriginalModelRequested: model,
								ResolvedModelUsed:      model,
							},
						})
						continue
					}
				} else {
					// Build the key pool for this request. Selection and rotation are deferred to
					// executeRequestWithRetries via keyProvider so that each retry attempt can use
					// a different key (on rate-limit errors) without re-running the full filtering.
					supportedKeys, canRotate, keyPoolErr := bifrost.selectKeyFromProviderForModelWithPool(req.Context, req.RequestType, provider.GetProviderKey(), model, baseProvider)
					if keyPoolErr != nil {
						bifrost.logger.Debug("error building key pool for model %s: %v", model, keyPoolErr)
						sendError(req, schemas.BifrostError{
							IsBifrostError: false,
							Error: &schemas.ErrorField{
								Message: keyPoolErr.Error(),
								Error:   keyPoolErr,
							},
							ExtraFields: schemas.BifrostErrorExtraFields{
								Provider:               provider.GetProviderKey(),
								RequestType:            req.RequestType,
								OriginalModelRequested: model,
								ResolvedModelUsed:      model,
							},
						})
						continue
					}
					if len(supportedKeys) == 0 {
						// SkipKeySelection path — keyProvider stays nil, zero Key is used.
					} else if !canRotate {
						// Fixed key (DirectKey, explicit ID/name, session stickiness): always
						// return the same key regardless of usedKeyIDs.
						fixedKey := supportedKeys[0]
						keyProvider = func(_ map[string]bool) (schemas.Key, error) {
							return fixedKey, nil
						}
					} else {
						// Rotating pool: weighted selection with per-cycle exclusion.
						// Captures supportedKeys, bifrost.keySelector, provider/model by value.
						pool := supportedKeys
						provKey := provider.GetProviderKey()
						mdl := model
						keyProvider = func(usedKeyIDs map[string]bool) (schemas.Key, error) {
							available := make([]schemas.Key, 0, len(pool))
							for _, k := range pool {
								if !usedKeyIDs[k.ID] {
									available = append(available, k)
								}
							}
							if len(available) == 0 {
								// All keys exhausted — start a fresh weighted round.
								for id := range usedKeyIDs {
									delete(usedKeyIDs, id)
								}
								available = pool
							}
							return bifrost.keySelector(req.Context, available, provKey, mdl)
						}
					}
				}
			}
		}

		originalModelRequested := model
		// resolvedModel is set inside the handler closures below on every attempt so that each
		// key's own alias mapping is applied. The outer var holds the LAST attempt's value and is
		// read single-threaded by the worker after retries finish (e.g. the error fallback
		// after the retry call below). Streaming postHookRunner must NOT capture this var by
		// reference — it snapshots its own attemptResolvedModel inside the per-attempt closure.
		var resolvedModel string

		// lastAttemptFinalizer captures the LAST attempt's postHookSpanFinalizer for the
		// worker-level error fallback below. Single-threaded write (assigned by the retry
		// loop's per-attempt closure) and single-threaded read (after retries finish), so
		// no synchronization needed. Earlier attempts' finalizers fire via their provider
		// goroutines' defers — passed via the postHookSpanFinalizer parameter directly to
		// handleProviderStreamRequest, never via the shared req.Context.
		var lastAttemptFinalizer func(context.Context)

		// Execute request with retries. For streaming, the plugin pipeline,
		// postHookRunner, and finalizer are allocated per-attempt inside the
		// request handler closure. If they were request-scoped, a retry
		// triggered by CheckFirstStreamChunkForError could run against a
		// pipeline the previous attempt's provider goroutine has already
		// returned to the pool via its deferred finalizer.
		if IsStreamRequestType(req.RequestType) {
			stream, bifrostError = executeRequestWithRetries(req.Context, config, func(k schemas.Key) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
				resolvedModel = k.Aliases.Resolve(originalModelRequested)
				req.SetModel(resolvedModel)
				// Snapshot per-attempt so postHookRunner doesn't observe a later retry's
				// alias while this attempt's provider goroutine is still emitting chunks.
				attemptResolvedModel := resolvedModel

				pipeline := bifrost.getPluginPipeline()
				postHookRunner := func(ctx *schemas.BifrostContext, result *schemas.BifrostResponse, err *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError) {
					// Populate extra fields before RunPostLLMHooks so plugins (e.g. logging)
					// can read requestType/provider/model from the chunk or error.
					// Uses the per-attempt snapshot — capturing the outer resolvedModel by
					// reference would let a later retry's alias bleed into this attempt's chunks.
					if result != nil {
						result.PopulateExtraFields(req.RequestType, provider.GetProviderKey(), originalModelRequested, attemptResolvedModel)
					}
					if err != nil {
						err.PopulateExtraFields(req.RequestType, provider.GetProviderKey(), originalModelRequested, attemptResolvedModel)
					}
					resp, bifrostErr := pipeline.RunPostLLMHooks(ctx, result, err, len(*bifrost.llmPlugins.Load()))
					if IsFinalChunk(ctx) {
						drainAndAttachPluginLogs(ctx)
					}
					if bifrostErr != nil {
						bifrostErr.PopulateExtraFields(req.RequestType, provider.GetProviderKey(), originalModelRequested, attemptResolvedModel)
						return nil, bifrostErr
					} else if resp != nil {
						resp.PopulateExtraFields(req.RequestType, provider.GetProviderKey(), originalModelRequested, attemptResolvedModel)
					}
					return resp, nil
				}

				// Store a finalizer callback to create aggregated post-hook spans at stream end.
				// Wrapped in sync.Once so the normal end-of-stream invocation and a deferred
				// safety-net invocation (e.g. from a provider goroutine's panic path) cannot
				// double-release the pipeline.
				var finalizerOnce sync.Once
				postHookSpanFinalizer := func(ctx context.Context) {
					finalizerOnce.Do(func() {
						pipeline.FinalizeStreamingPostHookSpans(ctx)
						bifrost.releasePluginPipeline(pipeline)
					})
				}
				lastAttemptFinalizer = postHookSpanFinalizer

				streamCh, streamErr := bifrost.handleProviderStreamRequest(provider, req, k, postHookRunner, postHookSpanFinalizer)
				// If stream setup failed before any provider goroutine started,
				// no deferred finalizer will run — release the pipeline directly
				// so a retry doesn't inherit a leaked pool entry.
				if streamErr != nil && streamCh == nil {
					finalizerOnce.Do(func() {
						bifrost.releasePluginPipeline(pipeline)
					})
				}
				return streamCh, streamErr
			}, keyProvider, req.RequestType, provider.GetProviderKey(), model, &req.BifrostRequest, bifrost.logger)
		} else {
			result, bifrostError = executeRequestWithRetries(req.Context, config, func(k schemas.Key) (*schemas.BifrostResponse, *schemas.BifrostError) {
				resolvedModel = k.Aliases.Resolve(originalModelRequested)
				req.SetModel(resolvedModel)
				return bifrost.handleProviderRequest(provider, config, req, k, keys)
			}, keyProvider, req.RequestType, provider.GetProviderKey(), model, &req.BifrostRequest, bifrost.logger)
		}

		// For streaming with an error, route release through the LAST attempt's
		// finalizer (wrapped in sync.Once) so we don't double-Put into the pool
		// or race the provider goroutine's deferred FinalizeStreamingPostHookSpans
		// call. lastAttemptFinalizer is set inside the per-attempt closure on every
		// iteration; after retries finish, it holds the LAST attempt's finalizer.
		// Earlier attempts' finalizers have already fired via their provider
		// goroutines' defers (passed via the postHookSpanFinalizer parameter
		// directly to handleProviderStreamRequest). For streaming without error,
		// the finalizer is invoked by completeDeferredSpan / the provider
		// goroutine's defer.
		if IsStreamRequestType(req.RequestType) && bifrostError != nil {
			if lastAttemptFinalizer != nil {
				lastAttemptFinalizer(req.Context)
			}
		}

		if bifrostError != nil {
			bifrostError.PopulateExtraFields(req.RequestType, provider.GetProviderKey(), originalModelRequested, resolvedModel)
			// Send error with context awareness to prevent deadlock
			sendError(req, *bifrostError)
		} else {
			if result != nil {
				result.PopulateExtraFields(req.RequestType, provider.GetProviderKey(), originalModelRequested, resolvedModel)
			}
			if IsStreamRequestType(req.RequestType) {
				// Send stream with context awareness to prevent deadlock
				select {
				case req.ResponseStream <- stream:
					// Stream sent successfully
				case <-req.Context.Done():
					// Client no longer listening, log and continue
					bifrost.logger.Debug("Client context cancelled while sending stream response")
				case <-time.After(5 * time.Second):
					// Timeout to prevent indefinite blocking
					bifrost.logger.Warn("Timeout while sending stream response, client may have disconnected")
				}
			} else {
				// Send response with context awareness to prevent deadlock
				select {
				case req.Response <- result:
					// Response sent successfully
				case <-req.Context.Done():
					// Client no longer listening, log and continue
					bifrost.logger.Debug("Client context cancelled while sending response")
				case <-time.After(5 * time.Second):
					// Timeout to prevent indefinite blocking
					bifrost.logger.Warn("Timeout while sending response, client may have disconnected")
				}
			}
		}
	}

	// bifrost.logger.Debug("worker for provider %s exiting...", provider.GetProviderKey())
}
// handleProviderRequest dispatches a non-streaming request to the provider method that
// matches req.RequestType and stores the provider's result in the corresponding field of
// a freshly allocated BifrostResponse.
//
// key is handed to provider methods that operate on a single key; keys is handed to the
// provider methods that take the whole key slice (list-models and most file/batch/container
// calls). config is consulted only by the OCR branch for its operation-allowed check.
//
// It returns (nil, err) as soon as a provider call reports an error, or when the request
// type has no branch here; otherwise it returns the populated response and a nil error.
func (bifrost *Bifrost) handleProviderRequest(provider schemas.Provider, config *schemas.ProviderConfig, req *ChannelMessage, key schemas.Key, keys []schemas.Key) (*schemas.BifrostResponse, *schemas.BifrostError) {
	ctx := req.Context
	in := &req.BifrostRequest
	out := &schemas.BifrostResponse{}

	switch req.RequestType {
	case schemas.ListModelsRequest:
		res, bifrostErr := provider.ListModels(ctx, keys, in.ListModelsRequest)
		if bifrostErr != nil {
			return nil, bifrostErr
		}
		out.ListModelsResponse = res

	case schemas.TextCompletionRequest:
		// A context override can reroute a text completion through the chat endpoint;
		// the chat result is converted back into text-completion form.
		if override, ok := ctx.Value(schemas.BifrostContextKeyChangeRequestType).(schemas.RequestType); ok && override == schemas.ChatCompletionRequest {
			if chatReq := in.TextCompletionRequest.ToBifrostChatRequest(); chatReq != nil {
				chatRes, bifrostErr := provider.ChatCompletion(ctx, key, chatReq)
				if bifrostErr != nil {
					return nil, bifrostErr
				}
				out.TextCompletionResponse = chatRes.ToBifrostTextCompletionResponse()
				return out, nil
			}
		}
		res, bifrostErr := provider.TextCompletion(ctx, key, in.TextCompletionRequest)
		if bifrostErr != nil {
			return nil, bifrostErr
		}
		out.TextCompletionResponse = res

	case schemas.ChatCompletionRequest:
		// A context override can reroute a chat completion through the responses endpoint;
		// the responses result is converted back into chat form.
		if override, ok := ctx.Value(schemas.BifrostContextKeyChangeRequestType).(schemas.RequestType); ok && override == schemas.ResponsesRequest {
			if responsesReq := in.ChatRequest.ToResponsesRequest(); responsesReq != nil {
				responsesRes, bifrostErr := provider.Responses(ctx, key, responsesReq)
				if bifrostErr != nil {
					return nil, bifrostErr
				}
				out.ChatResponse = responsesRes.ToBifrostChatResponse()
				return out, nil
			}
		}
		res, bifrostErr := provider.ChatCompletion(ctx, key, in.ChatRequest)
		if bifrostErr != nil {
			return nil, bifrostErr
		}
		res.BackfillParams(in.ChatRequest)
		out.ChatResponse = res

	case schemas.ResponsesRequest:
		res, bifrostErr := provider.Responses(ctx, key, in.ResponsesRequest)
		if bifrostErr != nil {
			return nil, bifrostErr
		}
		res.BackfillParams(in.ResponsesRequest)
		out.ResponsesResponse = res

	case schemas.CountTokensRequest:
		res, bifrostErr := provider.CountTokens(ctx, key, in.CountTokensRequest)
		if bifrostErr != nil {
			return nil, bifrostErr
		}
		out.CountTokensResponse = res

	case schemas.EmbeddingRequest:
		res, bifrostErr := provider.Embedding(ctx, key, in.EmbeddingRequest)
		if bifrostErr != nil {
			return nil, bifrostErr
		}
		out.EmbeddingResponse = res

	case schemas.RerankRequest:
		res, bifrostErr := provider.Rerank(ctx, key, in.RerankRequest)
		if bifrostErr != nil {
			return nil, bifrostErr
		}
		out.RerankResponse = res

	case schemas.OCRRequest:
		// OCR is the only branch here that checks the custom-provider allow list itself.
		var customCfg *schemas.CustomProviderConfig
		if config != nil {
			customCfg = config.CustomProviderConfig
		}
		denied := providerUtils.CheckOperationAllowed(provider.GetProviderKey(), customCfg, schemas.OCRRequest)
		if denied != nil {
			if in.OCRRequest != nil {
				denied.ExtraFields.OriginalModelRequested = in.OCRRequest.Model
			}
			return nil, denied
		}
		res, bifrostErr := provider.OCR(ctx, key, in.OCRRequest)
		if bifrostErr != nil {
			return nil, bifrostErr
		}
		out.OCRResponse = res

	case schemas.SpeechRequest:
		res, bifrostErr := provider.Speech(ctx, key, in.SpeechRequest)
		if bifrostErr != nil {
			return nil, bifrostErr
		}
		res.BackfillParams(in.SpeechRequest)
		out.SpeechResponse = res

	case schemas.TranscriptionRequest:
		res, bifrostErr := provider.Transcription(ctx, key, in.TranscriptionRequest)
		if bifrostErr != nil {
			return nil, bifrostErr
		}
		res.BackfillParams(in.TranscriptionRequest)
		out.TranscriptionResponse = res

	case schemas.ImageGenerationRequest:
		res, bifrostErr := provider.ImageGeneration(ctx, key, in.ImageGenerationRequest)
		if bifrostErr != nil {
			return nil, bifrostErr
		}
		res.BackfillParams(in)
		out.ImageGenerationResponse = res

	case schemas.ImageEditRequest:
		// Image edits are reported through the image-generation response slot.
		res, bifrostErr := provider.ImageEdit(ctx, key, in.ImageEditRequest)
		if bifrostErr != nil {
			return nil, bifrostErr
		}
		res.BackfillParams(in)
		out.ImageGenerationResponse = res

	case schemas.ImageVariationRequest:
		// Image variations are reported through the image-generation response slot.
		res, bifrostErr := provider.ImageVariation(ctx, key, in.ImageVariationRequest)
		if bifrostErr != nil {
			return nil, bifrostErr
		}
		res.BackfillParams(in)
		out.ImageGenerationResponse = res

	case schemas.VideoGenerationRequest:
		res, bifrostErr := provider.VideoGeneration(ctx, key, in.VideoGenerationRequest)
		if bifrostErr != nil {
			return nil, bifrostErr
		}
		res.BackfillParams(in)
		out.VideoGenerationResponse = res

	case schemas.VideoRetrieveRequest:
		// Video retrieval is reported through the video-generation response slot.
		res, bifrostErr := provider.VideoRetrieve(ctx, key, in.VideoRetrieveRequest)
		if bifrostErr != nil {
			return nil, bifrostErr
		}
		out.VideoGenerationResponse = res

	case schemas.VideoDownloadRequest:
		res, bifrostErr := provider.VideoDownload(ctx, key, in.VideoDownloadRequest)
		if bifrostErr != nil {
			return nil, bifrostErr
		}
		out.VideoDownloadResponse = res

	case schemas.VideoListRequest:
		res, bifrostErr := provider.VideoList(ctx, key, in.VideoListRequest)
		if bifrostErr != nil {
			return nil, bifrostErr
		}
		out.VideoListResponse = res

	case schemas.VideoDeleteRequest:
		res, bifrostErr := provider.VideoDelete(ctx, key, in.VideoDeleteRequest)
		if bifrostErr != nil {
			return nil, bifrostErr
		}
		out.VideoDeleteResponse = res

	case schemas.VideoRemixRequest:
		// Video remixes are reported through the video-generation response slot.
		res, bifrostErr := provider.VideoRemix(ctx, key, in.VideoRemixRequest)
		if bifrostErr != nil {
			return nil, bifrostErr
		}
		out.VideoGenerationResponse = res

	case schemas.FileUploadRequest:
		res, bifrostErr := provider.FileUpload(ctx, key, in.FileUploadRequest)
		if bifrostErr != nil {
			return nil, bifrostErr
		}
		out.FileUploadResponse = res

	case schemas.FileListRequest:
		res, bifrostErr := provider.FileList(ctx, keys, in.FileListRequest)
		if bifrostErr != nil {
			return nil, bifrostErr
		}
		out.FileListResponse = res

	case schemas.FileRetrieveRequest:
		res, bifrostErr := provider.FileRetrieve(ctx, keys, in.FileRetrieveRequest)
		if bifrostErr != nil {
			return nil, bifrostErr
		}
		out.FileRetrieveResponse = res

	case schemas.FileDeleteRequest:
		res, bifrostErr := provider.FileDelete(ctx, keys, in.FileDeleteRequest)
		if bifrostErr != nil {
			return nil, bifrostErr
		}
		out.FileDeleteResponse = res

	case schemas.FileContentRequest:
		res, bifrostErr := provider.FileContent(ctx, keys, in.FileContentRequest)
		if bifrostErr != nil {
			return nil, bifrostErr
		}
		out.FileContentResponse = res

	case schemas.BatchCreateRequest:
		res, bifrostErr := provider.BatchCreate(ctx, key, in.BatchCreateRequest)
		if bifrostErr != nil {
			return nil, bifrostErr
		}
		out.BatchCreateResponse = res

	case schemas.BatchListRequest:
		res, bifrostErr := provider.BatchList(ctx, keys, in.BatchListRequest)
		if bifrostErr != nil {
			return nil, bifrostErr
		}
		out.BatchListResponse = res

	case schemas.BatchRetrieveRequest:
		res, bifrostErr := provider.BatchRetrieve(ctx, keys, in.BatchRetrieveRequest)
		if bifrostErr != nil {
			return nil, bifrostErr
		}
		out.BatchRetrieveResponse = res

	case schemas.BatchCancelRequest:
		res, bifrostErr := provider.BatchCancel(ctx, keys, in.BatchCancelRequest)
		if bifrostErr != nil {
			return nil, bifrostErr
		}
		out.BatchCancelResponse = res

	case schemas.BatchDeleteRequest:
		res, bifrostErr := provider.BatchDelete(ctx, keys, in.BatchDeleteRequest)
		if bifrostErr != nil {
			return nil, bifrostErr
		}
		out.BatchDeleteResponse = res

	case schemas.BatchResultsRequest:
		res, bifrostErr := provider.BatchResults(ctx, keys, in.BatchResultsRequest)
		if bifrostErr != nil {
			return nil, bifrostErr
		}
		out.BatchResultsResponse = res

	case schemas.ContainerCreateRequest:
		res, bifrostErr := provider.ContainerCreate(ctx, key, in.ContainerCreateRequest)
		if bifrostErr != nil {
			return nil, bifrostErr
		}
		out.ContainerCreateResponse = res

	case schemas.ContainerListRequest:
		res, bifrostErr := provider.ContainerList(ctx, keys, in.ContainerListRequest)
		if bifrostErr != nil {
			return nil, bifrostErr
		}
		out.ContainerListResponse = res

	case schemas.ContainerRetrieveRequest:
		res, bifrostErr := provider.ContainerRetrieve(ctx, keys, in.ContainerRetrieveRequest)
		if bifrostErr != nil {
			return nil, bifrostErr
		}
		out.ContainerRetrieveResponse = res

	case schemas.ContainerDeleteRequest:
		res, bifrostErr := provider.ContainerDelete(ctx, keys, in.ContainerDeleteRequest)
		if bifrostErr != nil {
			return nil, bifrostErr
		}
		out.ContainerDeleteResponse = res

	case schemas.ContainerFileCreateRequest:
		res, bifrostErr := provider.ContainerFileCreate(ctx, key, in.ContainerFileCreateRequest)
		if bifrostErr != nil {
			return nil, bifrostErr
		}
		out.ContainerFileCreateResponse = res

	case schemas.ContainerFileListRequest:
		res, bifrostErr := provider.ContainerFileList(ctx, keys, in.ContainerFileListRequest)
		if bifrostErr != nil {
			return nil, bifrostErr
		}
		out.ContainerFileListResponse = res

	case schemas.ContainerFileRetrieveRequest:
		res, bifrostErr := provider.ContainerFileRetrieve(ctx, keys, in.ContainerFileRetrieveRequest)
		if bifrostErr != nil {
			return nil, bifrostErr
		}
		out.ContainerFileRetrieveResponse = res

	case schemas.ContainerFileContentRequest:
		res, bifrostErr := provider.ContainerFileContent(ctx, keys, in.ContainerFileContentRequest)
		if bifrostErr != nil {
			return nil, bifrostErr
		}
		out.ContainerFileContentResponse = res

	case schemas.ContainerFileDeleteRequest:
		res, bifrostErr := provider.ContainerFileDelete(ctx, keys, in.ContainerFileDeleteRequest)
		if bifrostErr != nil {
			return nil, bifrostErr
		}
		out.ContainerFileDeleteResponse = res

	case schemas.PassthroughRequest:
		res, bifrostErr := provider.Passthrough(ctx, key, in.PassthroughRequest)
		if bifrostErr != nil {
			return nil, bifrostErr
		}
		out.PassthroughResponse = res

	default:
		// No provider method is wired up for this request type.
		_, model, _ := in.GetRequestFields()
		return nil, &schemas.BifrostError{
			IsBifrostError: false,
			Error: &schemas.ErrorField{
				Message: fmt.Sprintf("unsupported request type: %s", req.RequestType),
			},
			ExtraFields: schemas.BifrostErrorExtraFields{
				RequestType:            req.RequestType,
				Provider:               provider.GetProviderKey(),
				OriginalModelRequested: model,
				ResolvedModelUsed:      model,
			},
		}
	}

	return out, nil
}
// handleProviderStreamRequest dispatches a streaming request to the provider stream method
// that matches req.RequestType, forwarding the post-hook runner and span finalizer unchanged
// (except for converted requests, whose runner is wrapped first).
//
// It returns whatever the selected provider method returns, or a nil channel plus an
// "unsupported request type" error when no branch matches.
func (bifrost *Bifrost) handleProviderStreamRequest(provider schemas.Provider, req *ChannelMessage, key schemas.Key, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context)) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
	ctx := req.Context
	in := &req.BifrostRequest

	switch req.RequestType {
	case schemas.TextCompletionStreamRequest:
		// A context override can reroute a text-completion stream through the chat stream.
		if override, ok := ctx.Value(schemas.BifrostContextKeyChangeRequestType).(schemas.RequestType); ok && override == schemas.ChatCompletionRequest {
			if chatReq := in.TextCompletionRequest.ToBifrostChatRequest(); chatReq != nil {
				convertedRunner := wrapConvertedStreamPostHookRunner(postHookRunner, schemas.ChatCompletionRequest)
				return provider.ChatCompletionStream(ctx, convertedRunner, postHookSpanFinalizer, key, chatReq)
			}
		}
		return provider.TextCompletionStream(ctx, postHookRunner, postHookSpanFinalizer, key, in.TextCompletionRequest)

	case schemas.ChatCompletionStreamRequest:
		// A context override can reroute a chat stream through the responses stream.
		if override, ok := ctx.Value(schemas.BifrostContextKeyChangeRequestType).(schemas.RequestType); ok && override == schemas.ResponsesRequest {
			if responsesReq := in.ChatRequest.ToResponsesRequest(); responsesReq != nil {
				convertedRunner := wrapConvertedStreamPostHookRunner(postHookRunner, schemas.ResponsesRequest)
				return provider.ResponsesStream(ctx, convertedRunner, postHookSpanFinalizer, key, responsesReq)
			}
		}
		return provider.ChatCompletionStream(ctx, postHookRunner, postHookSpanFinalizer, key, in.ChatRequest)

	case schemas.ResponsesStreamRequest:
		return provider.ResponsesStream(ctx, postHookRunner, postHookSpanFinalizer, key, in.ResponsesRequest)

	case schemas.SpeechStreamRequest:
		return provider.SpeechStream(ctx, postHookRunner, postHookSpanFinalizer, key, in.SpeechRequest)

	case schemas.TranscriptionStreamRequest:
		return provider.TranscriptionStream(ctx, postHookRunner, postHookSpanFinalizer, key, in.TranscriptionRequest)

	case schemas.ImageGenerationStreamRequest:
		return provider.ImageGenerationStream(ctx, postHookRunner, postHookSpanFinalizer, key, in.ImageGenerationRequest)

	case schemas.ImageEditStreamRequest:
		return provider.ImageEditStream(ctx, postHookRunner, postHookSpanFinalizer, key, in.ImageEditRequest)

	case schemas.PassthroughStreamRequest:
		return provider.PassthroughStream(ctx, postHookRunner, postHookSpanFinalizer, key, in.PassthroughRequest)
	}

	// No provider stream method is wired up for this request type.
	_, model, _ := in.GetRequestFields()
	return nil, &schemas.BifrostError{
		IsBifrostError: false,
		Error: &schemas.ErrorField{
			Message: fmt.Sprintf("unsupported request type: %s", req.RequestType),
		},
		ExtraFields: schemas.BifrostErrorExtraFields{
			RequestType:            req.RequestType,
			Provider:               provider.GetProviderKey(),
			OriginalModelRequested: model,
			ResolvedModelUsed:      model,
		},
	}
}
// handleMCPToolExecution runs one MCP tool call through the MCP plugin pipeline:
// pre-hooks, tool execution via the MCP manager, then post-hooks. It is shared by the
// Chat and Responses tool-call formats.
//
// Parameters:
//   - ctx: Execution context; a request ID is generated and stored on it if none is present
//   - mcpRequest: The MCP request to execute (already populated with the tool call)
//   - requestType: The request type recorded on errors built here (ChatCompletionRequest or ResponsesRequest)
//
// Returns:
//   - *schemas.BifrostMCPResponse: The MCP response after all hooks
//   - *schemas.BifrostError: Any execution or hook error
func (bifrost *Bifrost) handleMCPToolExecution(ctx *schemas.BifrostContext, mcpRequest *schemas.BifrostMCPRequest, requestType schemas.RequestType) (*schemas.BifrostMCPResponse, *schemas.BifrostError) {
	// failure builds the error shape used for every error originating in this function.
	failure := func(message string) *schemas.BifrostError {
		return &schemas.BifrostError{
			IsBifrostError: false,
			Error: &schemas.ErrorField{
				Message: message,
			},
			ExtraFields: schemas.BifrostErrorExtraFields{
				RequestType: requestType,
			},
		}
	}

	if bifrost.MCPManager == nil {
		return nil, failure("mcp is not configured in this bifrost instance")
	}

	// Hooks and tracing expect a request ID on the context.
	if _, hasID := ctx.Value(schemas.BifrostContextKeyRequestID).(string); !hasID {
		ctx.SetValue(schemas.BifrostContextKeyRequestID, uuid.New().String())
	}

	// The pipeline is borrowed for the duration of this call and always handed back.
	pipeline := bifrost.getPluginPipeline()
	defer bifrost.releasePluginPipeline(pipeline)

	hookedReq, shortCircuit, ranPreHooks := pipeline.RunMCPPreHooks(ctx, mcpRequest)

	if shortCircuit != nil {
		switch {
		case shortCircuit.Response != nil:
			// A plugin supplied the response itself; still give post-hooks their turn.
			hookedResp, hookedErr := pipeline.RunMCPPostHooks(ctx, shortCircuit.Response, nil, ranPreHooks)
			drainAndAttachPluginLogs(ctx)
			if hookedErr != nil {
				return nil, hookedErr
			}
			return hookedResp, nil

		case shortCircuit.Error != nil:
			// A plugin supplied an error; post-hooks may transform it or recover from it.
			hookedResp, hookedErr := pipeline.RunMCPPostHooks(ctx, nil, shortCircuit.Error, ranPreHooks)
			drainAndAttachPluginLogs(ctx)
			if hookedErr != nil {
				return nil, hookedErr
			}
			if hookedResp != nil {
				return hookedResp, nil
			}
			// Post-hooks returned nil/nil: report the original short-circuit error.
			return nil, shortCircuit.Error
		}
		// A short-circuit carrying neither a response nor an error falls through to execution.
	}

	// NOTE(review): this path returns without running post-hooks or draining plugin logs,
	// unlike every other exit after the pre-hooks — confirm that is intended.
	if hookedReq == nil {
		return nil, failure("MCP request after plugin hooks cannot be nil")
	}

	// Execute the tool with the (possibly plugin-modified) request.
	toolResult, execErr := bifrost.MCPManager.ExecuteToolCall(ctx, hookedReq)

	var (
		mcpResp *schemas.BifrostMCPResponse
		toolErr *schemas.BifrostError
	)
	switch {
	case execErr != nil:
		toolErr = failure(execErr.Error())
		// Keep the OAuth-required error attached so callers can detect it later.
		var oauthRequired *schemas.MCPUserOAuthRequiredError
		if errors.As(execErr, &oauthRequired) {
			toolErr.ExtraFields.MCPAuthRequired = oauthRequired
		}
	case toolResult == nil:
		toolErr = failure("tool execution returned nil result")
	default:
		mcpResp = toolResult
	}

	finalResp, finalErr := pipeline.RunMCPPostHooks(ctx, mcpResp, toolErr, ranPreHooks)
	drainAndAttachPluginLogs(ctx)
	if finalErr != nil {
		return nil, finalErr
	}
	return finalResp, nil
}
// executeMCPToolWithHooks adapts handleMCPToolExecution to a plain (response, error)
// signature, as expected by the agent's executeToolFunc parameter. It validates its
// arguments, maps the MCP request type onto the matching Bifrost request type, and
// flattens any BifrostError into a Go error.
//
// When the BifrostError carries an MCPAuthRequired value, that value is returned as the
// error; otherwise the error text comes from GetErrorMessage.
func (bifrost *Bifrost) executeMCPToolWithHooks(ctx *schemas.BifrostContext, request *schemas.BifrostMCPRequest) (*schemas.BifrostMCPResponse, error) {
	// Reject nil arguments up front so plugin hooks never see them.
	switch {
	case ctx == nil:
		return nil, errors.New("context cannot be nil")
	case request == nil:
		return nil, errors.New("request cannot be nil")
	}

	// Only the two known tool-call flavours are accepted; anything else is an error
	// rather than a silent default.
	var hookRequestType schemas.RequestType
	switch request.RequestType {
	case schemas.MCPRequestTypeChatToolCall:
		hookRequestType = schemas.ChatCompletionRequest
	case schemas.MCPRequestTypeResponsesToolCall:
		hookRequestType = schemas.ResponsesRequest
	default:
		return nil, fmt.Errorf("unsupported MCP request type: %s", request.RequestType)
	}

	resp, bifrostErr := bifrost.handleMCPToolExecution(ctx, request, hookRequestType)
	if bifrostErr == nil {
		return resp, nil
	}
	if authRequired := bifrostErr.ExtraFields.MCPAuthRequired; authRequired != nil {
		return nil, authRequired
	}
	return nil, fmt.Errorf("%s", GetErrorMessage(bifrostErr))
}
// PLUGIN MANAGEMENT
// RunLLMPreHooks calls PreLLMHook on each LLM plugin in registration order, threading the
// request through them. It records how many hooks ran in p.executedPreHooks and returns the
// final request, the short-circuit decision (if any plugin produced one), and that count.
//
// A hook error is recorded and logged but does not stop the loop; a short-circuit does.
// When the skip-plugin-pipeline context flag is true, the request is returned untouched
// with a count of 0.
func (p *PluginPipeline) RunLLMPreHooks(ctx *schemas.BifrostContext, req *schemas.BifrostRequest) (*schemas.BifrostRequest, *schemas.LLMPluginShortCircuit, int) {
	if skip, _ := ctx.Value(schemas.BifrostContextKeySkipPluginPipeline).(bool); skip {
		return req, nil, 0
	}

	ctx.BlockRestrictedWrites()
	defer ctx.UnblockRestrictedWrites()

	for idx, plugin := range p.llmPlugins {
		name := plugin.GetName()
		p.logger.Debug("running pre-hook for plugin %s", name)

		// One span per plugin pre-hook.
		spanCtx, span := p.tracer.StartSpan(ctx, fmt.Sprintf("plugin.%s.prehook", sanitizeSpanName(name)), schemas.SpanKindPlugin)
		if spanCtx != nil {
			// Publish the new span ID on ctx so nested operations can pick it up.
			if spanID, found := spanCtx.Value(schemas.BifrostContextKeySpanID).(string); found {
				ctx.SetValue(schemas.BifrostContextKeySpanID, spanID)
			}
		}

		var (
			decision *schemas.LLMPluginShortCircuit
			hookErr  error
		)
		scoped := ctx.WithPluginScope(&name)
		req, decision, hookErr = plugin.PreLLMHook(scoped, req)
		scoped.ReleasePluginScope()

		// Close the span with a status reflecting the hook outcome.
		switch {
		case hookErr != nil:
			p.tracer.SetAttribute(span, "error", hookErr.Error())
			p.tracer.EndSpan(span, schemas.SpanStatusError, hookErr.Error())
			p.preHookErrors = append(p.preHookErrors, hookErr)
			p.logger.Warn("error in PreLLMHook for plugin %s: %s", name, hookErr.Error())
		case decision != nil:
			p.tracer.SetAttribute(span, "short_circuit", true)
			p.tracer.EndSpan(span, schemas.SpanStatusOk, "short-circuit")
		default:
			p.tracer.EndSpan(span, schemas.SpanStatusOk, "")
		}

		p.executedPreHooks = idx + 1
		if decision != nil {
			// Short-circuit: only plugins up to and including this one ran.
			return req, decision, p.executedPreHooks
		}
	}

	return req, nil, p.executedPreHooks
}
// RunPostLLMHooks executes PostLLMHook in reverse order for the plugins whose PreLLMHook ran.
// It accepts the response and error and lets each plugin transform either one (e.g., recover
// from an error, or invalidate a response).
//
// Parameters:
//   - ctx: request context; restricted writes are blocked on it for the duration of the hook loop
//   - resp, bifrostErr: the response/error pair handed to the hooks
//   - runFrom: the count of plugins whose PreHooks ran; PostHooks run in reverse from index
//     (runFrom - 1) down to 0, after runFrom is clamped to [0, len(p.llmPlugins)]
//
// Returns the response and error left after all hooks. A non-nil error is returned together
// with resp, except when resp is non-nil and the error is effectively empty (nil StatusCode,
// non-nil Error field with nil Type, empty Message and nil inner Error); that case is treated
// as a recovery and yields (resp, nil).
//
// When the skip-plugin-pipeline context flag is true, resp and bifrostErr are returned untouched.
// For streaming requests, it accumulates timing per plugin instead of creating individual spans
// per chunk, and bumps p.chunkCount once per call.
func (p *PluginPipeline) RunPostLLMHooks(ctx *schemas.BifrostContext, resp *schemas.BifrostResponse, bifrostErr *schemas.BifrostError, runFrom int) (*schemas.BifrostResponse, *schemas.BifrostError) {
	// If the skip plugin pipeline flag is set, skip the plugin pipeline
	if skipPluginPipeline, ok := ctx.Value(schemas.BifrostContextKeySkipPluginPipeline).(bool); ok && skipPluginPipeline {
		return resp, bifrostErr
	}
	// Defensive: ensure count is within valid bounds
	if runFrom < 0 {
		runFrom = 0
	}
	if runFrom > len(p.llmPlugins) {
		runFrom = len(p.llmPlugins)
	}
	// The request type is read once, from the inputs as they were before any hook ran.
	requestType, _, _, _ := GetResponseFields(resp, bifrostErr)
	// Realtime turns carry StreamStartTime for plugin latency/final-chunk context,
	// but they are finalized as one completed turn, not chunk-by-chunk stream output.
	isStreaming := ctx.Value(schemas.BifrostContextKeyStreamStartTime) != nil && requestType != schemas.RealtimeRequest
	ctx.BlockRestrictedWrites()
	defer ctx.UnblockRestrictedWrites()
	var err error
	for i := runFrom - 1; i >= 0; i-- {
		plugin := p.llmPlugins[i]
		pluginName := plugin.GetName()
		p.logger.Debug("running post-hook for plugin %s", pluginName)
		if isStreaming {
			// For streaming: accumulate timing, don't create individual spans per chunk
			// Lazily create cached scoped contexts on first chunk (reused across all chunks).
			// The cache is built for every LLM plugin at once and keyed by plugin name.
			if p.streamScopedCtxs == nil {
				p.streamScopedCtxs = make(map[string]*schemas.BifrostContext, len(p.llmPlugins))
				for _, pl := range p.llmPlugins {
					name := pl.GetName()
					p.streamScopedCtxs[name] = ctx.WithPluginScope(&name)
				}
			}
			// Unlike the non-streaming branch, the cached scoped context is not released here.
			pluginCtx := p.streamScopedCtxs[pluginName]
			start := time.Now()
			resp, bifrostErr, err = plugin.PostLLMHook(pluginCtx, resp, bifrostErr)
			duration := time.Since(start)
			p.accumulatePluginTiming(pluginName, duration, err != nil)
			if err != nil {
				// Hook errors are recorded and logged; the loop continues with the next plugin.
				p.postHookErrors = append(p.postHookErrors, err)
				p.logger.Warn("error in PostLLMHook for plugin %s: %v", pluginName, err)
			}
		} else {
			// For non-streaming: create span per plugin (existing behavior)
			spanCtx, handle := p.tracer.StartSpan(ctx, fmt.Sprintf("plugin.%s.posthook", sanitizeSpanName(pluginName)), schemas.SpanKindPlugin)
			// Publish the new span ID on ctx so nested operations can pick it up
			if spanCtx != nil {
				if spanID, ok := spanCtx.Value(schemas.BifrostContextKeySpanID).(string); ok {
					ctx.SetValue(schemas.BifrostContextKeySpanID, spanID)
				}
			}
			pluginCtx := ctx.WithPluginScope(&pluginName)
			resp, bifrostErr, err = plugin.PostLLMHook(pluginCtx, resp, bifrostErr)
			pluginCtx.ReleasePluginScope()
			// End span with appropriate status
			if err != nil {
				p.tracer.SetAttribute(handle, "error", err.Error())
				p.tracer.EndSpan(handle, schemas.SpanStatusError, err.Error())
				p.postHookErrors = append(p.postHookErrors, err)
				p.logger.Warn("error in PostLLMHook for plugin %s: %v", pluginName, err)
			} else {
				p.tracer.EndSpan(handle, schemas.SpanStatusOk, "")
			}
		}
		// If a plugin recovers from an error (sets bifrostErr to nil and sets resp), allow that
		// If a plugin invalidates a response (sets resp to nil and sets bifrostErr), allow that
	}
	// Increment chunk count for streaming (once per call, guarded by streamingMu)
	if isStreaming {
		p.streamingMu.Lock()
		p.chunkCount++
		p.streamingMu.Unlock()
	}
	// Final logic: a non-nil error is returned alongside resp, unless it is effectively empty
	if bifrostErr != nil {
		if resp != nil && bifrostErr.StatusCode == nil && bifrostErr.Error != nil && bifrostErr.Error.Type == nil &&
			bifrostErr.Error.Message == "" && bifrostErr.Error.Error == nil {
			// Defensive: treat as recovery if error is empty
			return resp, nil
		}
		return resp, bifrostErr
	}
	return resp, nil
}
// RunMCPPreHooks calls PreMCPHook on each registered MCP plugin in order, threading the
// request through them. It records how many hooks ran in p.executedPreHooks and returns the
// final request, the short-circuit decision (if any plugin produced one), and that count.
//
// A hook error is recorded and logged but does not stop the loop; a short-circuit does, so
// only plugins up to and including the short-circuiting one are counted. When the
// skip-plugin-pipeline context flag is true, the request is returned untouched with a count of 0.
func (p *PluginPipeline) RunMCPPreHooks(ctx *schemas.BifrostContext, req *schemas.BifrostMCPRequest) (*schemas.BifrostMCPRequest, *schemas.MCPPluginShortCircuit, int) {
	if skip, _ := ctx.Value(schemas.BifrostContextKeySkipPluginPipeline).(bool); skip {
		return req, nil, 0
	}

	ctx.BlockRestrictedWrites()
	defer ctx.UnblockRestrictedWrites()

	for idx, plugin := range p.mcpPlugins {
		name := plugin.GetName()
		p.logger.Debug("running MCP pre-hook for plugin %s", name)

		// One span per plugin pre-hook.
		spanCtx, span := p.tracer.StartSpan(ctx, fmt.Sprintf("plugin.%s.mcp_prehook", sanitizeSpanName(name)), schemas.SpanKindPlugin)
		if spanCtx != nil {
			// Publish the new span ID on ctx so nested operations can pick it up.
			if spanID, found := spanCtx.Value(schemas.BifrostContextKeySpanID).(string); found {
				ctx.SetValue(schemas.BifrostContextKeySpanID, spanID)
			}
		}

		var (
			decision *schemas.MCPPluginShortCircuit
			hookErr  error
		)
		scoped := ctx.WithPluginScope(&name)
		req, decision, hookErr = plugin.PreMCPHook(scoped, req)
		scoped.ReleasePluginScope()

		// Close the span with a status reflecting the hook outcome.
		switch {
		case hookErr != nil:
			p.tracer.SetAttribute(span, "error", hookErr.Error())
			p.tracer.EndSpan(span, schemas.SpanStatusError, hookErr.Error())
			p.preHookErrors = append(p.preHookErrors, hookErr)
			p.logger.Warn("error in PreMCPHook for plugin %s: %s", name, hookErr.Error())
		case decision != nil:
			p.tracer.SetAttribute(span, "short_circuit", true)
			p.tracer.EndSpan(span, schemas.SpanStatusOk, "short-circuit")
		default:
			p.tracer.EndSpan(span, schemas.SpanStatusOk, "")
		}

		p.executedPreHooks = idx + 1
		if decision != nil {
			// Short-circuit: only plugins up to and including this one ran.
			return req, decision, p.executedPreHooks
		}
	}

	return req, nil, p.executedPreHooks
}
// RunMCPPostHooks calls PostMCPHook in reverse order on the plugins whose PreMCPHook ran,
// letting each plugin transform the MCP response or the error (for example, recover from an
// error or invalidate a response).
//
// runFrom is the number of plugins whose pre-hooks ran; after clamping it to
// [0, len(p.mcpPlugins)], hooks run from index runFrom-1 down to 0. Hook errors are recorded
// and logged without stopping the loop.
//
// It returns the response and error left after all hooks. A non-nil error is returned
// together with the response, except when the response is non-nil and the error is
// effectively empty (nil StatusCode, non-nil Error field with nil Type, empty Message and
// nil inner Error), which is treated as a recovery. When the skip-plugin-pipeline context
// flag is true, both inputs are returned untouched.
func (p *PluginPipeline) RunMCPPostHooks(ctx *schemas.BifrostContext, mcpResp *schemas.BifrostMCPResponse, bifrostErr *schemas.BifrostError, runFrom int) (*schemas.BifrostMCPResponse, *schemas.BifrostError) {
	if skip, _ := ctx.Value(schemas.BifrostContextKeySkipPluginPipeline).(bool); skip {
		return mcpResp, bifrostErr
	}

	// Defensive clamp of the hook count to the valid index range.
	switch {
	case runFrom < 0:
		runFrom = 0
	case runFrom > len(p.mcpPlugins):
		runFrom = len(p.mcpPlugins)
	}

	ctx.BlockRestrictedWrites()
	defer ctx.UnblockRestrictedWrites()

	for remaining := runFrom; remaining > 0; remaining-- {
		plugin := p.mcpPlugins[remaining-1]
		name := plugin.GetName()
		p.logger.Debug("running MCP post-hook for plugin %s", name)

		// One span per plugin post-hook.
		spanCtx, span := p.tracer.StartSpan(ctx, fmt.Sprintf("plugin.%s.mcp_posthook", sanitizeSpanName(name)), schemas.SpanKindPlugin)
		if spanCtx != nil {
			// Publish the new span ID on ctx so nested operations can pick it up.
			if spanID, found := spanCtx.Value(schemas.BifrostContextKeySpanID).(string); found {
				ctx.SetValue(schemas.BifrostContextKeySpanID, spanID)
			}
		}

		var hookErr error
		scoped := ctx.WithPluginScope(&name)
		mcpResp, bifrostErr, hookErr = plugin.PostMCPHook(scoped, mcpResp, bifrostErr)
		scoped.ReleasePluginScope()

		if hookErr == nil {
			p.tracer.EndSpan(span, schemas.SpanStatusOk, "")
			continue
		}
		p.tracer.SetAttribute(span, "error", hookErr.Error())
		p.tracer.EndSpan(span, schemas.SpanStatusError, hookErr.Error())
		p.postHookErrors = append(p.postHookErrors, hookErr)
		p.logger.Warn("error in PostMCPHook for plugin %s: %v", name, hookErr)
	}

	if bifrostErr == nil {
		return mcpResp, nil
	}

	// An error with nothing in it, paired with a response, counts as a recovery.
	blankErr := bifrostErr.StatusCode == nil &&
		bifrostErr.Error != nil &&
		bifrostErr.Error.Type == nil &&
		bifrostErr.Error.Message == "" &&
		bifrostErr.Error.Error == nil
	if mcpResp != nil && blankErr {
		return mcpResp, nil
	}
	return mcpResp, bifrostErr
}
// resetPluginPipeline scrubs a PluginPipeline so it can sit in the pool and be reused.
//
// IMPORTANT: drainAndAttachPluginLogs must run on the root BifrostContext
// BEFORE this method. This method calls ReleasePluginScope on the cached
// scoped contexts, which nils out their pluginLogs pointer; the drain reads
// from the shared store on the root context and therefore has to happen while
// that store is still referenced.
func (p *PluginPipeline) resetPluginPipeline() {
	// Per-request hook bookkeeping. clear() zeroes the live elements before the
	// slice is truncated so the backing array does not pin old error values.
	p.executedPreHooks = 0
	clear(p.preHookErrors)
	p.preHookErrors = p.preHookErrors[:0]
	clear(p.postHookErrors)
	p.postHookErrors = p.postHookErrors[:0]

	// Drop the plugin slices while pooled. getPluginPipeline rebinds them (and
	// logger/tracer) on acquisition, so this is purely GC hygiene, which matters
	// when plugins are hot-swapped.
	p.llmPlugins, p.mcpPlugins = nil, nil

	// Streaming state is reset under streamingMu: the provider goroutine's
	// deferred finalizer may still be reading these fields when the pipeline is
	// returned to the pool. logger/tracer are cleared inside the same critical
	// section so the finalizer's read is synchronized with this write.
	p.streamingMu.Lock()
	p.logger, p.tracer = nil, nil
	p.chunkCount = 0

	// clear() on a map removes every entry (releasing the accumulators for GC)
	// and is a no-op when the map is nil.
	clear(p.postHookTimings)

	// Zero the plugin-name slots before truncating so no strings stay referenced.
	clear(p.postHookPluginOrder)
	p.postHookPluginOrder = p.postHookPluginOrder[:0]

	// Release the scoped contexts cached for streaming, then drop the cache.
	for _, cached := range p.streamScopedCtxs {
		cached.ReleasePluginScope()
	}
	p.streamScopedCtxs = nil
	p.streamingMu.Unlock()
}
// drainAndAttachPluginLogs moves the plugin logs accumulated on ctx onto the
// trace identified by ctx, so observability plugins can retrieve them later.
// It does nothing (and leaves the logs undrained) when no tracer or trace ID
// can be resolved from ctx, and attaches nothing when there are no logs.
func drainAndAttachPluginLogs(ctx *schemas.BifrostContext) {
	tracer, traceID, err := GetTracerFromContext(ctx)
	if usable := err == nil && tracer != nil && traceID != ""; !usable {
		return
	}
	if pending := ctx.DrainPluginLogs(); len(pending) > 0 {
		tracer.AttachPluginLogs(traceID, pending)
	}
}
// accumulatePluginTiming folds one post-hook invocation into the per-plugin
// streaming totals: it adds duration, bumps the invocation counter and, when
// hasError is true, the error counter. The first time a plugin name is seen
// its name is appended to postHookPluginOrder so the order of first execution
// is remembered. All updates happen under streamingMu.
func (p *PluginPipeline) accumulatePluginTiming(pluginName string, duration time.Duration, hasError bool) {
	p.streamingMu.Lock()
	defer p.streamingMu.Unlock()

	// Lazily create the map; writing to a nil map would panic.
	if p.postHookTimings == nil {
		p.postHookTimings = make(map[string]*pluginTimingAccumulator)
	}

	acc, seen := p.postHookTimings[pluginName]
	if !seen {
		acc = new(pluginTimingAccumulator)
		p.postHookTimings[pluginName] = acc
		// First occurrence (first chunk): remember execution order.
		p.postHookPluginOrder = append(p.postHookPluginOrder, pluginName)
	}

	acc.invocations++
	acc.totalDuration += duration
	if hasError {
		acc.errors++
	}
}
// FinalizeStreamingPostHookSpans creates one aggregated span per plugin after
// streaming completes. It should be called once at the end of a stream.
//
// Each span carries the invocation count, the total duration in whole
// milliseconds, the average duration per invocation in (fractional)
// milliseconds and, when any invocation failed, the error count. Spans are
// started in the order plugins first ran and nested so that each is a child
// of the previous one, then ended innermost-first.
//
// The method is a no-op when the pipeline has no tracer or no timings were
// accumulated.
func (p *PluginPipeline) FinalizeStreamingPostHookSpans(ctx context.Context) {
	// Snapshot the accumulators under lock so per-chunk writers in the
	// provider goroutine can't race with the finalizer. Tracer calls below
	// run unlocked so chunk writers are never stalled on span I/O.
	type snapshotEntry struct {
		pluginName    string
		totalDuration time.Duration
		invocations   int
		errors        int
	}

	p.streamingMu.Lock()
	// Capture tracer under the same lock that guards resetPluginPipeline's
	// writes so the read/write pair on p.tracer is synchronized and the
	// unlocked tracer calls below use a stable local.
	tracer := p.tracer
	if tracer == nil || p.postHookTimings == nil || len(p.postHookPluginOrder) == 0 {
		p.streamingMu.Unlock()
		return
	}
	snapshot := make([]snapshotEntry, 0, len(p.postHookPluginOrder))
	for _, pluginName := range p.postHookPluginOrder {
		timing, ok := p.postHookTimings[pluginName]
		if !ok || timing.invocations == 0 {
			// Skipping zero-invocation entries also protects the division below.
			continue
		}
		snapshot = append(snapshot, snapshotEntry{
			pluginName:    pluginName,
			totalDuration: timing.totalDuration,
			invocations:   timing.invocations,
			errors:        timing.errors,
		})
	}
	p.streamingMu.Unlock()

	if len(snapshot) == 0 {
		return
	}

	// Collect handles and error flags so spans can be ended in reverse order.
	type spanInfo struct {
		handle    schemas.SpanHandle
		hasErrors bool
	}
	spans := make([]spanInfo, 0, len(snapshot))
	currentCtx := ctx

	// Start spans in execution order (nested: each is a child of the previous).
	for _, entry := range snapshot {
		newCtx, handle := tracer.StartSpan(currentCtx, fmt.Sprintf("plugin.%s.posthook", sanitizeSpanName(entry.pluginName)), schemas.SpanKindPlugin)
		if handle == nil {
			continue
		}

		// Average in fractional milliseconds. Divide the raw Duration rather
		// than Milliseconds(): the latter truncates to an integer first, which
		// discards the sub-millisecond part of the total and makes fast hooks
		// report an average of 0.
		avgMs := float64(entry.totalDuration) / float64(time.Millisecond) / float64(entry.invocations)

		// Set aggregated attributes.
		tracer.SetAttribute(handle, schemas.AttrPluginInvocations, entry.invocations)
		tracer.SetAttribute(handle, schemas.AttrPluginAvgDurationMs, avgMs)
		tracer.SetAttribute(handle, schemas.AttrPluginTotalDurationMs, entry.totalDuration.Milliseconds())
		if entry.errors > 0 {
			tracer.SetAttribute(handle, schemas.AttrPluginErrorCount, entry.errors)
		}

		spans = append(spans, spanInfo{handle: handle, hasErrors: entry.errors > 0})
		currentCtx = newCtx
	}

	// End spans in reverse order (innermost first, like unwinding a call stack).
	for i := len(spans) - 1; i >= 0; i-- {
		if spans[i].hasErrors {
			tracer.EndSpan(spans[i].handle, schemas.SpanStatusError, "some invocations failed")
		} else {
			tracer.EndSpan(spans[i].handle, schemas.SpanStatusOk, "")
		}
	}
}
// GetChunkCount reports the pipeline's current chunkCount, i.e. how many
// chunks have been counted for the stream so far. The value is read under
// streamingMu.
func (p *PluginPipeline) GetChunkCount() int {
	p.streamingMu.Lock()
	count := p.chunkCount
	p.streamingMu.Unlock()
	return count
}
// getPluginPipeline takes a PluginPipeline from the pool and binds it to the
// Bifrost instance's current LLM plugins, MCP plugins, logger and tracer.
// The plugin slices are whatever snapshot is loaded at call time.
func (bifrost *Bifrost) getPluginPipeline() *PluginPipeline {
	pp := bifrost.pluginPipelinePool.Get().(*PluginPipeline)
	pp.llmPlugins, pp.mcpPlugins = *bifrost.llmPlugins.Load(), *bifrost.mcpPlugins.Load()
	pp.logger, pp.tracer = bifrost.logger, bifrost.getTracer()
	return pp
}
// releasePluginPipeline resets a PluginPipeline and returns it to the pool.
// Caller must ensure drainAndAttachPluginLogs has already been called on the
// associated BifrostContext before calling this method, because the reset
// releases the cached plugin-scoped contexts.
func (bifrost *Bifrost) releasePluginPipeline(pipeline *PluginPipeline) {
// Reset before Put so the pooled object holds no per-request references.
pipeline.resetPluginPipeline()
bifrost.pluginPipelinePool.Put(pipeline)
}
// POOL & RESOURCE MANAGEMENT
// getChannelMessage builds a ChannelMessage for req out of pooled parts: the
// message itself, a response channel and an error channel. For streaming
// request types a response-stream channel is attached as well. Each pooled
// channel is polled once without blocking so a value left behind by a
// previous request is not delivered to this one.
func (bifrost *Bifrost) getChannelMessage(req schemas.BifrostRequest) *ChannelMessage {
	respCh := bifrost.responseChannelPool.Get().(chan *schemas.BifrostResponse)
	select {
	case <-respCh: // discard a stale response
	default:
	}

	errCh := bifrost.errorChannelPool.Get().(chan schemas.BifrostError)
	select {
	case <-errCh: // discard a stale error
	default:
	}

	msg := bifrost.channelMessagePool.Get().(*ChannelMessage)
	msg.BifrostRequest = req
	msg.Response = respCh
	msg.Err = errCh

	// Only streaming requests need a stream channel.
	if !IsStreamRequestType(req.RequestType) {
		return msg
	}

	streamCh := bifrost.responseStreamPool.Get().(chan chan *schemas.BifrostStreamChunk)
	select {
	case <-streamCh: // discard a stale stream
	default:
	}
	msg.ResponseStream = streamCh
	return msg
}
// drainQueueWithErrors empties pq.queue without blocking and answers every
// message it finds with a "provider is shutting down" error.
//
// It is meant to run after all workers for the queue have exited (i.e. after
// wg.Wait()). It covers a TOCTOU window: a producer that passed isClosing()
// just before signalClosing fired can still win the `case pq.queue <- msg`
// branch in tryRequest and land a message in the queue after the last
// worker's drain loop has already exited. Without this sweep those callers
// would block forever on <-msg.Response / <-msg.Err.
//
// Residual TOCTOU window (known limitation): the sweep stops at the first
// moment the queue is observed empty. A producer that deposits a message
// after that has neither a worker nor a sweep to drain it, and its caller
// blocks until its own context is cancelled. Fully closing the window would
// need a sender-side reference count (so the last producer can signal "queue
// is fully idle"), which is intentionally not implemented because it would
// add per-send atomic overhead on the hot path.
func (bifrost *Bifrost) drainQueueWithErrors(pq *ProviderQueue) {
	for {
		select {
		case msg := <-pq.queue:
			provider, model, _ := msg.GetRequestFields()
			shutdownErr := schemas.BifrostError{
				IsBifrostError: false,
				Error:          &schemas.ErrorField{Message: "provider is shutting down"},
				ExtraFields: schemas.BifrostErrorExtraFields{
					RequestType:            msg.RequestType,
					Provider:               provider,
					OriginalModelRequested: model,
				},
			}
			// msg.Err comes from errorChannelPool and is polled empty in
			// getChannelMessage. Assuming it is buffered (capacity is not
			// visible from here; confirm at the pool's constructor), the send
			// completes immediately, so no timeout is used. The caller's
			// context being done is the only other way out.
			select {
			case msg.Err <- shutdownErr:
			case <-msg.Context.Done():
			}
		default:
			// Queue observed empty: sweep finished.
			return
		}
	}
}
// releaseChannelMessage hands a ChannelMessage and the channels it carries
// back to their pools and clears the message's channel references. The
// embedded BifrostRequest is deliberately left alone: it is released by the
// handle methods, which still need it for fallbacks.
func (bifrost *Bifrost) releaseChannelMessage(msg *ChannelMessage) {
	// The stream channel exists only for streaming requests.
	if streamCh := msg.ResponseStream; streamCh != nil {
		// Drop one pending stream, if any, so the pooled channel does not keep it alive.
		select {
		case <-streamCh:
		default:
		}
		bifrost.responseStreamPool.Put(streamCh)
		msg.ResponseStream = nil
	}

	bifrost.responseChannelPool.Put(msg.Response)
	bifrost.errorChannelPool.Put(msg.Err)

	// Detach the channels before pooling the message itself.
	msg.Response = nil
	msg.Err = nil
	bifrost.channelMessagePool.Put(msg)
}
// resetBifrostRequest clears a BifrostRequest for reuse: the request type is
// set to the empty string and every per-type request pointer listed below is
// set to nil. Fields not listed here are left untouched.
func resetBifrostRequest(req *schemas.BifrostRequest) {
req.RequestType = ""
// Model listing, text, chat and related payloads
req.ListModelsRequest = nil
req.TextCompletionRequest = nil
req.ChatRequest = nil
req.ResponsesRequest = nil
req.CountTokensRequest = nil
req.EmbeddingRequest = nil
req.RerankRequest = nil
req.OCRRequest = nil
// Audio payloads
req.SpeechRequest = nil
req.TranscriptionRequest = nil
// Image payloads
req.ImageGenerationRequest = nil
req.ImageEditRequest = nil
req.ImageVariationRequest = nil
// Video payloads
req.VideoGenerationRequest = nil
req.VideoRetrieveRequest = nil
req.VideoDownloadRequest = nil
req.VideoListRequest = nil
req.VideoRemixRequest = nil
req.VideoDeleteRequest = nil
// File payloads
req.FileUploadRequest = nil
req.FileListRequest = nil
req.FileRetrieveRequest = nil
req.FileDeleteRequest = nil
req.FileContentRequest = nil
// Batch payloads
req.BatchCreateRequest = nil
req.BatchListRequest = nil
req.BatchRetrieveRequest = nil
req.BatchCancelRequest = nil
req.BatchDeleteRequest = nil
req.BatchResultsRequest = nil
// Container and container-file payloads
req.ContainerCreateRequest = nil
req.ContainerListRequest = nil
req.ContainerRetrieveRequest = nil
req.ContainerDeleteRequest = nil
req.ContainerFileCreateRequest = nil
req.ContainerFileListRequest = nil
req.ContainerFileRetrieveRequest = nil
req.ContainerFileContentRequest = nil
req.ContainerFileDeleteRequest = nil
// Passthrough payload
req.PassthroughRequest = nil
}
// getBifrostRequest fetches a *schemas.BifrostRequest from bifrostRequestPool.
// Pair every call with releaseBifrostRequest, which resets the request before
// it goes back to the pool.
func (bifrost *Bifrost) getBifrostRequest() *schemas.BifrostRequest {
	return bifrost.bifrostRequestPool.Get().(*schemas.BifrostRequest)
}
// releaseBifrostRequest resets a BifrostRequest and returns it to the pool.
// The caller must not use req after this call.
func (bifrost *Bifrost) releaseBifrostRequest(req *schemas.BifrostRequest) {
// Reset before Put so the pooled object holds no references to request payloads.
resetBifrostRequest(req)
bifrost.bifrostRequestPool.Put(req)
}
// resetMCPRequest clears a BifrostMCPRequest for reuse: the request type is
// set to the empty string and both tool-call payload pointers are set to nil.
func resetMCPRequest(req *schemas.BifrostMCPRequest) {
req.RequestType = ""
req.ChatAssistantMessageToolCall = nil
req.ResponsesToolMessage = nil
}
// getMCPRequest fetches a *schemas.BifrostMCPRequest from mcpRequestPool.
// Pair every call with releaseMCPRequest, which resets the request before it
// goes back to the pool.
func (bifrost *Bifrost) getMCPRequest() *schemas.BifrostMCPRequest {
	return bifrost.mcpRequestPool.Get().(*schemas.BifrostMCPRequest)
}
// releaseMCPRequest resets a BifrostMCPRequest and returns it to the pool.
// The caller must not use req after this call.
func (bifrost *Bifrost) releaseMCPRequest(req *schemas.BifrostMCPRequest) {
// Reset before Put so the pooled object holds no references to tool-call payloads.
resetMCPRequest(req)
bifrost.mcpRequestPool.Put(req)
}
// getAllSupportedKeys returns every usable key for a ListModels request, so
// the provider can aggregate results across multiple keys.
//
// A direct key set on ctx short-circuits everything: it is validated and
// returned as the only element. Otherwise the account's keys for providerKey
// are filtered: disabled keys are dropped (a nil Enabled means enabled), keys
// failing validateKey are logged and dropped, and keys with a blank value are
// dropped unless the base provider type permits empty key values. No model
// filtering is applied.
//
// An error is returned when the direct key is invalid, when the account
// lookup fails, or when no key survives.
func (bifrost *Bifrost) getAllSupportedKeys(ctx *schemas.BifrostContext, providerKey schemas.ModelProvider, baseProviderType schemas.ModelProvider) ([]schemas.Key, error) {
	// An explicitly supplied key bypasses the account's key list.
	if ctx != nil {
		if directKey, isDirect := ctx.Value(schemas.BifrostContextKeyDirectKey).(schemas.Key); isDirect {
			if err := validateKey(baseProviderType, &directKey); err != nil {
				return nil, fmt.Errorf("invalid direct key for provider %v: %w", baseProviderType, err)
			}
			return []schemas.Key{directKey}, nil
		}
	}

	providerKeys, err := bifrost.account.GetKeysForProvider(ctx, providerKey)
	if err != nil {
		return nil, err
	}
	if len(providerKeys) == 0 {
		return nil, fmt.Errorf("no keys found for provider: %v", providerKey)
	}

	var usable []schemas.Key
	for _, candidate := range providerKeys {
		// Disabled keys are skipped; nil Enabled counts as enabled.
		if candidate.Enabled != nil && !*candidate.Enabled {
			continue
		}
		if err := validateKey(baseProviderType, &candidate); err != nil {
			bifrost.logger.Warn("error validating key %s (%s) for provider %s: %s, skipping key", candidate.Name, candidate.ID, providerKey, err.Error())
			continue
		}
		// A blank value is only acceptable for providers that allow it.
		if strings.TrimSpace(candidate.Value.GetValue()) == "" && !CanProviderKeyValueBeEmpty(baseProviderType) {
			continue
		}
		usable = append(usable, candidate)
	}

	bifrost.logger.Debug("[Bifrost] Provider %s: %d valid keys found", providerKey, len(usable))
	if len(usable) == 0 {
		return nil, fmt.Errorf("no valid keys found for provider: %v", providerKey)
	}
	return usable, nil
}
// getKeysForBatchAndFileOps returns the keys usable for a batch or file
// operation, sorted by key ID so pagination order is deterministic across
// requests.
//
// A direct key set on ctx short-circuits everything: it is validated and
// returned as the only element. Otherwise each account key for providerKey
// must pass, in order:
//   - it is not disabled (a nil Enabled means enabled);
//   - when isBatchOp is true, UseForBatchAPI is set and true;
//   - validateKey succeeds (failures are logged and the key skipped);
//   - when model is non-nil and non-empty, the model is not blocked by
//     BlacklistedModels and is allowed by Models (the blacklist wins over the
//     allow list);
//   - its value is non-blank, or the base provider type permits empty values.
//
// An error is returned when the direct key is invalid, when the account
// lookup fails, or when no key survives the filters.
func (bifrost *Bifrost) getKeysForBatchAndFileOps(ctx *schemas.BifrostContext, providerKey schemas.ModelProvider, baseProviderType schemas.ModelProvider, model *string, isBatchOp bool) ([]schemas.Key, error) {
	// An explicitly supplied key bypasses the account's key list.
	if ctx != nil {
		if directKey, isDirect := ctx.Value(schemas.BifrostContextKeyDirectKey).(schemas.Key); isDirect {
			if err := validateKey(baseProviderType, &directKey); err != nil {
				return nil, fmt.Errorf("invalid direct key for provider %v: %w", baseProviderType, err)
			}
			return []schemas.Key{directKey}, nil
		}
	}

	providerKeys, err := bifrost.account.GetKeysForProvider(ctx, providerKey)
	if err != nil {
		return nil, err
	}
	if len(providerKeys) == 0 {
		return nil, fmt.Errorf("no keys found for provider: %v", providerKey)
	}

	// Model filtering only applies when a concrete model name was supplied.
	filterByModel := model != nil && *model != ""

	var eligible []schemas.Key
	for _, candidate := range providerKeys {
		if candidate.Enabled != nil && !*candidate.Enabled {
			continue
		}
		// Batch operations require an explicit opt-in on the key.
		if isBatchOp && (candidate.UseForBatchAPI == nil || !*candidate.UseForBatchAPI) {
			continue
		}
		if err := validateKey(baseProviderType, &candidate); err != nil {
			bifrost.logger.Warn("error validating key %s (%s) for provider %s: %s, skipping key", candidate.Name, candidate.ID, providerKey, err.Error())
			continue
		}
		// Blacklist is checked first and wins over the allow list.
		if filterByModel && (candidate.BlacklistedModels.IsBlocked(*model) || !candidate.Models.IsAllowed(*model)) {
			continue
		}
		// A blank value is only acceptable for providers that allow it.
		if strings.TrimSpace(candidate.Value.GetValue()) == "" && !CanProviderKeyValueBeEmpty(baseProviderType) {
			continue
		}
		eligible = append(eligible, candidate)
	}

	if len(eligible) == 0 {
		var modelName string
		if model != nil {
			modelName = *model
		}
		if isBatchOp {
			return nil, fmt.Errorf("no batch-enabled keys found for provider: %v and model: %s", providerKey, modelName)
		}
		return nil, fmt.Errorf("no keys found for provider: %v and model: %s", providerKey, modelName)
	}

	// Deterministic order for pagination.
	sort.Slice(eligible, func(a, b int) bool {
		return eligible[a].ID < eligible[b].ID
	})
	return eligible, nil
}
// selectKeyFromProviderForModelWithPool returns the filtered pool of eligible keys for the given
// provider/model, along with a canRotate flag indicating whether key rotation across retries is
// permitted. Key selection (choosing which key to use) is deferred to executeRequestWithRetries
// via the keyProvider closure built by the caller.
//
// canRotate=false is returned for cases where the caller must always use the same key:
//   - DirectKey (caller-supplied key bypasses all selection)
//   - SkipKeySelection (provider allows keyless requests; empty slice returned)
//   - Explicit BifrostContextKeyAPIKeyID / APIKeyName (user pinned a specific key)
//   - Session stickiness (key persisted in KV store for the session lifetime)
//   - Single-key pool (only one eligible key — rotation is a no-op, KV write skipped)
//
// canRotate=true is returned when there are two or more eligible keys and no pinning
// or stickiness constraint is in effect.
//
// A nil ctx is tolerated: every context lookup is skipped and the function behaves as if no
// direct key, skip flag, pinned key or session were set.
//
// An error is returned when the direct key is invalid, the account lookup fails, no key is
// configured (or batch-enabled, for batch/file requests), no key supports the model, a pinned
// key ID/name is not among the supported keys, or the key selector fails.
func (bifrost *Bifrost) selectKeyFromProviderForModelWithPool(ctx *schemas.BifrostContext, requestType schemas.RequestType, providerKey schemas.ModelProvider, model string, baseProviderType schemas.ModelProvider) ([]schemas.Key, bool, error) {
	if ctx != nil {
		// DirectKey: caller supplied a key directly — no pool, no rotation.
		if key, ok := ctx.Value(schemas.BifrostContextKeyDirectKey).(schemas.Key); ok {
			if err := validateKey(baseProviderType, &key); err != nil {
				return nil, false, fmt.Errorf("invalid direct key for provider %v: %w", baseProviderType, err)
			}
			return []schemas.Key{key}, false, nil
		}
		// SkipKeySelection: provider allows keyless requests — return empty pool, no rotation.
		// Kept under the ctx nil guard like every other context lookup in this function.
		if skipKeySelection, ok := ctx.Value(schemas.BifrostContextKeySkipKeySelection).(bool); ok && skipKeySelection && isKeySkippingAllowed(providerKey) {
			return []schemas.Key{}, false, nil
		}
	}

	// Get keys for provider
	keys, err := bifrost.account.GetKeysForProvider(ctx, providerKey)
	if err != nil {
		return nil, false, err
	}
	if len(keys) == 0 {
		return nil, false, fmt.Errorf("no keys found for provider: %v and model: %s", providerKey, model)
	}

	// For batch and file request types, keep only keys with UseForBatchAPI enabled.
	if isBatchRequestType(requestType) || isFileRequestType(requestType) {
		var batchEnabledKeys []schemas.Key
		for _, k := range keys {
			if k.UseForBatchAPI != nil && *k.UseForBatchAPI {
				batchEnabledKeys = append(batchEnabledKeys, k)
			}
		}
		if len(batchEnabledKeys) == 0 {
			return nil, false, fmt.Errorf("no config found for batch apis; enable 'Use for Batch APIs' on at least one key for provider: %v", providerKey)
		}
		keys = batchEnabledKeys
	}

	// Filter out keys that don't support the model. blacklisted_models wins over the models
	// allow list; Models semantics are ["*"] = allow all, [] = deny all, list = allow listed.
	var supportedKeys []schemas.Key
	// Skip model check conditions
	// We can improve these conditions in the future
	skipModelCheck := (model == "" && (isFileRequestType(requestType) || isBatchRequestType(requestType) || isContainerRequestType(requestType) || isModellessVideoRequestType(requestType) || isPassthroughRequestType(requestType))) || requestType == schemas.ListModelsRequest
	if skipModelCheck {
		// When skipping model check: just verify keys are enabled, valid and have values
		for _, key := range keys {
			// Skip disabled keys (nil Enabled means enabled)
			if key.Enabled != nil && !*key.Enabled {
				continue
			}
			if err := validateKey(baseProviderType, &key); err != nil {
				bifrost.logger.Warn("error validating key %s (%s) for provider %s: %s, skipping key", key.Name, key.ID, providerKey, err.Error())
				continue
			}
			if strings.TrimSpace(key.Value.GetValue()) != "" || CanProviderKeyValueBeEmpty(baseProviderType) {
				supportedKeys = append(supportedKeys, key)
			}
		}
	} else {
		// When NOT skipping model check: do full model filtering
		for _, key := range keys {
			// Skip disabled keys (nil Enabled means enabled)
			if key.Enabled != nil && !*key.Enabled {
				continue
			}
			if err := validateKey(baseProviderType, &key); err != nil {
				bifrost.logger.Warn("error validating key %s (%s) for provider %s: %s, skipping key", key.Name, key.ID, providerKey, err.Error())
				continue
			}
			hasValue := strings.TrimSpace(key.Value.GetValue()) != "" || CanProviderKeyValueBeEmpty(baseProviderType)
			// ["*"] = allow all models; [] = deny all; specific list = allow only listed
			// NOTE: Model filtering uses the original requested model (which may be an alias).
			// key.Models and key.BlacklistedModels must therefore be expressed in alias keys.
			// The provider-specific identifier is resolved later in the handler closure via key.Aliases.Resolve(model).
			modelSupported := hasValue && key.Models.IsAllowed(model) && !key.BlacklistedModels.IsBlocked(model)
			// A vLLM key configured with a model name only serves exactly that model.
			if baseProviderType == schemas.VLLM && key.VLLMKeyConfig != nil {
				if key.VLLMKeyConfig.ModelName != "" {
					modelSupported = modelSupported && (key.VLLMKeyConfig.ModelName == model)
				}
			}
			if modelSupported {
				supportedKeys = append(supportedKeys, key)
			}
		}
	}
	if len(supportedKeys) == 0 {
		return nil, false, fmt.Errorf("no keys found that support model: %s", model)
	}

	// Explicit key ID takes priority over key name — pin to that key, no rotation.
	if ctx != nil {
		if keyID, ok := ctx.Value(schemas.BifrostContextKeyAPIKeyID).(string); ok {
			if keyID = strings.TrimSpace(keyID); keyID != "" {
				for _, key := range supportedKeys {
					if key.ID == keyID {
						return []schemas.Key{key}, false, nil
					}
				}
				return nil, false, fmt.Errorf("no supported key found with id %q for provider: %v and model: %s", keyID, providerKey, model)
			}
		}
		if keyName, ok := ctx.Value(schemas.BifrostContextKeyAPIKeyName).(string); ok {
			if keyName = strings.TrimSpace(keyName); keyName != "" {
				for _, key := range supportedKeys {
					if key.Name == keyName {
						return []schemas.Key{key}, false, nil
					}
				}
				return nil, false, fmt.Errorf("no supported key found with name %q for provider: %v and model: %s", keyName, providerKey, model)
			}
		}
	}

	// Single key: no rotation possible, skip session stickiness (no KV write needed).
	if len(supportedKeys) == 1 {
		return []schemas.Key{supportedKeys[0]}, false, nil
	}

	// Session stickiness: on the first request for a session ID, the selected key is
	// persisted in the KV store. Subsequent requests reuse it for the session lifetime. The sticky
	// key is intentionally kept fixed across all retry attempts — return it as a single-element
	// pool with canRotate=false so rate-limit retries also stay on the same key.
	sessionID := ""
	if ctx != nil {
		if id, ok := ctx.Value(schemas.BifrostContextKeySessionID).(string); ok && id != "" {
			sessionID = id
		}
	}
	fallbackIndex := 0
	if ctx != nil {
		fallbackIndex, _ = ctx.Value(schemas.BifrostContextKeyFallbackIndex).(int)
	}
	// Stickiness applies only to the primary attempt (fallbackIndex 0) and needs a KV store.
	// A non-empty sessionID implies ctx is non-nil, so ctx is safe to use in the branch below.
	stickinessActive := sessionID != "" && bifrost.kvStore != nil && fallbackIndex == 0
	if stickinessActive {
		kvKey := buildSessionKey(providerKey, sessionID, model)
		ttl, _ := ctx.Value(schemas.BifrostContextKeySessionTTL).(time.Duration)
		if ttl <= 0 {
			ttl = schemas.DefaultSessionStickyTTL
		}
		if cachedKey, found, stale := getCachedKeyFromStore(bifrost.kvStore, kvKey, supportedKeys); found {
			// Cache hit: refresh the TTL and stay on the persisted key.
			if err := bifrost.kvStore.SetWithTTL(kvKey, cachedKey.ID, ttl); err != nil {
				bifrost.logger.Warn("error setting session cache for provider=%s key_id=%s: %s", providerKey, cachedKey.ID, err.Error())
			}
			return []schemas.Key{cachedKey}, false, nil
		} else if stale {
			// The persisted ID is no longer an eligible key: remove it so SetNXWithTTL can succeed.
			if _, err := bifrost.kvStore.Delete(kvKey); err != nil {
				bifrost.logger.Warn("error deleting stale session cache for provider=%s: %s", providerKey, err.Error())
			}
		}
		selectedKey, err := bifrost.keySelector(ctx, supportedKeys, providerKey, model)
		if err != nil {
			return nil, false, err
		}
		wasSet, err := bifrost.kvStore.SetNXWithTTL(kvKey, selectedKey.ID, ttl)
		if err != nil {
			// Best effort: a KV failure must not fail the request; use the selected key unpersisted.
			bifrost.logger.Warn("error setting session cache for provider=%s key_id=%s: %s", providerKey, selectedKey.ID, err.Error())
			return []schemas.Key{selectedKey}, false, nil
		}
		if wasSet {
			return []schemas.Key{selectedKey}, false, nil
		}
		// Another concurrent request won the race — re-read the persisted key.
		if currentKey, found, stale := getCachedKeyFromStore(bifrost.kvStore, kvKey, supportedKeys); found {
			return []schemas.Key{currentKey}, false, nil
		} else if stale {
			if _, err := bifrost.kvStore.Delete(kvKey); err != nil {
				bifrost.logger.Warn("error deleting stale session cache for provider=%s: %s", providerKey, err.Error())
			}
			return []schemas.Key{selectedKey}, false, nil
		}
		return []schemas.Key{selectedKey}, false, nil
	}

	// Normal case: return the full filtered pool with rotation enabled.
	return supportedKeys, true, nil
}
// getCachedKeyFromStore reads the key ID stored under kvKey and resolves it
// against supportedKeys.
//
// Returns (key, found, stale):
//   - found is true, with the matching key, when the stored ID belongs to one
//     of supportedKeys;
//   - stale is true when the store holds a non-empty ID that matches none of
//     supportedKeys (the caller should delete the entry before SetNXWithTTL);
//   - both are false when the lookup returns an error or yields no usable ID.
//
// The stored value may be a string or a []byte. A []byte is first decoded as
// a JSON string and, if that fails, used as raw text.
func getCachedKeyFromStore(kvStore schemas.KVStore, kvKey string, supportedKeys []schemas.Key) (schemas.Key, bool, bool) {
	var none schemas.Key

	stored, err := kvStore.Get(kvKey)
	if err != nil {
		return none, false, false
	}

	keyID := ""
	switch val := stored.(type) {
	case string:
		keyID = val
	case []byte:
		// Fall back to the raw bytes unless they decode as a JSON string.
		keyID = string(val)
		var decoded string
		if sonic.Unmarshal(val, &decoded) == nil {
			keyID = decoded
		}
	}
	if keyID == "" {
		return none, false, false
	}

	for i := range supportedKeys {
		if supportedKeys[i].ID == keyID {
			return supportedKeys[i], true, false
		}
	}
	// An ID is stored, but it is not an eligible key any more.
	return none, false, true
}
// Shutdown gracefully stops all workers and releases shared resources.
//
// Steps, in order: cancel the instance context (if still live), signal every
// provider queue to close, wait for all workers to exit, sweep each queue so
// callers that enqueued late receive a shutdown error, clean up the MCP
// manager, stop the tracer, and run Cleanup on every LLM and MCP plugin.
// Cleanup failures are logged as warnings and do not abort the shutdown.
func (bifrost *Bifrost) Shutdown() {
	bifrost.logger.Info("closing all request channels...")

	// Cancel the context if not already done
	if bifrost.ctx.Err() == nil && bifrost.cancel != nil {
		bifrost.cancel()
	}

	// Signal all provider queues to close. Workers exit via pq.done;
	// we never close pq.queue to avoid "send on closed channel" panics in
	// producers that are concurrently in tryRequest.
	bifrost.requestQueues.Range(func(_, value any) bool {
		value.(*ProviderQueue).signalClosing()
		return true
	})

	// Wait for all workers to exit
	bifrost.waitGroups.Range(func(_, value any) bool {
		value.(*sync.WaitGroup).Wait()
		return true
	})

	// Final drain sweep — same reasoning as RemoveProvider's Step 3b.
	bifrost.requestQueues.Range(func(_, value any) bool {
		bifrost.drainQueueWithErrors(value.(*ProviderQueue))
		return true
	})

	// Cleanup MCP manager
	if bifrost.MCPManager != nil {
		if err := bifrost.MCPManager.Cleanup(); err != nil {
			bifrost.logger.Warn("Error cleaning up MCP manager: %s", err.Error())
		}
	}

	// Stop the tracer to clean up background goroutines. The comma-ok assertion
	// keeps Shutdown from panicking when no tracer wrapper was ever stored.
	if tw, ok := bifrost.tracer.Load().(*tracerWrapper); ok && tw != nil && tw.tracer != nil {
		tw.tracer.Stop()
	}

	// Cleanup plugins. The message is passed as a format argument, never as the
	// format string itself, so a '%' inside a plugin's error text is logged verbatim.
	if llmPlugins := bifrost.llmPlugins.Load(); llmPlugins != nil {
		for _, plugin := range *llmPlugins {
			if err := plugin.Cleanup(); err != nil {
				bifrost.logger.Warn("Error cleaning up LLM plugin: %s", err.Error())
			}
		}
	}
	if mcpPlugins := bifrost.mcpPlugins.Load(); mcpPlugins != nil {
		for _, plugin := range *mcpPlugins {
			if err := plugin.Cleanup(); err != nil {
				bifrost.logger.Warn("Error cleaning up MCP plugin: %s", err.Error())
			}
		}
	}

	bifrost.logger.Info("all request channels closed")
}