Files
bifrost/transports/bifrost-http/handlers/inference.go
Beyhan Oğur 880f412e2c first commit
2026-04-26 21:52:23 +03:00

3808 lines
130 KiB
Go

// Package handlers provides HTTP request handlers for the Bifrost HTTP transport.
// This file contains completion request handlers for text and chat completions.
package handlers
import (
"context"
"encoding/json"
"fmt"
"io"
"mime/multipart"
"net/http"
"net/url"
"path/filepath"
"strconv"
"strings"
"sync/atomic"
"time"
"github.com/bytedance/sonic"
"github.com/fasthttp/router"
bifrost "github.com/maximhq/bifrost/core"
"github.com/maximhq/bifrost/core/schemas"
"github.com/maximhq/bifrost/transports/bifrost-http/lib"
"github.com/valyala/fasthttp"
)
// forwardProviderHeaders forwards provider response headers to the HTTP response.
func forwardProviderHeaders(ctx *fasthttp.RequestCtx, headers map[string]string) {
for key, value := range headers {
ctx.Response.Header.Set(key, value)
}
}
// forwardProviderHeadersFromContext extracts provider response headers from the bifrost context
// and forwards them to the HTTP response. This ensures error responses also include provider headers.
func forwardProviderHeadersFromContext(ctx *fasthttp.RequestCtx, bifrostCtx *schemas.BifrostContext) {
if headers, ok := bifrostCtx.Value(schemas.BifrostContextKeyProviderResponseHeaders).(map[string]string); ok {
forwardProviderHeaders(ctx, headers)
}
}
// CompletionHandler manages HTTP requests for completion operations
type CompletionHandler struct {
client *bifrost.Bifrost
handlerStore lib.HandlerStore
config *lib.Config
}
// NewInferenceHandler creates a new completion handler instance
func NewInferenceHandler(client *bifrost.Bifrost, config *lib.Config) *CompletionHandler {
return &CompletionHandler{
client: client,
handlerStore: config,
config: config,
}
}
// Known fields for CompletionRequest
var textParamsKnownFields = map[string]bool{
"prompt": true,
"model": true,
"fallbacks": true,
"best_of": true,
"echo": true,
"frequency_penalty": true,
"logit_bias": true,
"logprobs": true,
"max_tokens": true,
"n": true,
"presence_penalty": true,
"seed": true,
"stop": true,
"suffix": true,
"temperature": true,
"top_p": true,
"user": true,
}
// Known fields for CompletionRequest
var chatParamsKnownFields = map[string]bool{
"model": true,
"messages": true,
"fallbacks": true,
"stream": true,
"frequency_penalty": true,
"logit_bias": true,
"logprobs": true,
"max_completion_tokens": true,
"metadata": true,
"modalities": true,
"parallel_tool_calls": true,
"presence_penalty": true,
"prompt_cache_key": true,
"reasoning": true,
"reasoning_effort": true,
"reasoning_max_tokens": true,
"response_format": true,
"safety_identifier": true,
"service_tier": true,
"stream_options": true,
"store": true,
"temperature": true,
"tool_choice": true,
"tools": true,
"truncation": true,
"user": true,
"verbosity": true,
}
var responsesParamsKnownFields = map[string]bool{
"model": true,
"input": true,
"fallbacks": true,
"stream": true,
"background": true,
"conversation": true,
"include": true,
"instructions": true,
"max_output_tokens": true,
"max_tool_calls": true,
"metadata": true,
"parallel_tool_calls": true,
"previous_response_id": true,
"prompt_cache_key": true,
"reasoning": true,
"safety_identifier": true,
"service_tier": true,
"stream_options": true,
"store": true,
"temperature": true,
"text": true,
"top_logprobs": true,
"top_p": true,
"tool_choice": true,
"tools": true,
"truncation": true,
}
var embeddingParamsKnownFields = map[string]bool{
"model": true,
"input": true,
"fallbacks": true,
"encoding_format": true,
"dimensions": true,
}
var rerankParamsKnownFields = map[string]bool{
"model": true,
"query": true,
"documents": true,
"fallbacks": true,
"top_n": true,
"max_tokens_per_doc": true,
"priority": true,
"return_documents": true,
}
var ocrParamsKnownFields = map[string]bool{
"model": true,
"id": true,
"document": true,
"fallbacks": true,
"include_image_base64": true,
"pages": true,
"image_limit": true,
"image_min_size": true,
"table_format": true,
"extract_header": true,
"extract_footer": true,
"bbox_annotation_format": true,
"document_annotation_format": true,
"document_annotation_prompt": true,
}
var speechParamsKnownFields = map[string]bool{
"model": true,
"input": true,
"fallbacks": true,
"stream_format": true,
"voice": true,
"instructions": true,
"response_format": true,
"speed": true,
}
// imageGenerationParamsKnownFields contains known fields for image generation requests
// Based on ImageGenerationInput and ImageGenerationParameters structs
var imageGenerationParamsKnownFields = map[string]bool{
"model": true,
"prompt": true,
"fallbacks": true,
"stream": true,
"n": true,
"background": true,
"moderation": true,
"partial_images": true,
"size": true,
"quality": true,
"output_compression": true,
"output_format": true,
"style": true,
"response_format": true,
"seed": true,
"negative_prompt": true,
"num_inference_steps": true,
"user": true,
"aspect_ratio": true,
"input_images": true,
}
// imageEditParamsKnownFields contains known fields for image edit requests
// Based on ImageEditInput and ImageEditParameters structs
var imageEditParamsKnownFields = map[string]bool{
"model": true,
"prompt": true,
"fallbacks": true,
"image": true,
"image[]": true,
"mask": true,
"type": true,
"background": true,
"input_fidelity": true,
"n": true,
"output_compression": true,
"output_format": true,
"partial_images": true,
"quality": true,
"response_format": true,
"size": true,
"user": true,
"negative_prompt": true,
"seed": true,
"num_inference_steps": true,
"stream": true,
}
// imageVariationParamsKnownFields contains known fields for image variation requests
// Based on ImageVariationInput and ImageVariationParameters structs
var imageVariationParamsKnownFields = map[string]bool{
"model": true,
"fallbacks": true,
"image": true,
"image[]": true,
"n": true,
"response_format": true,
"size": true,
"user": true,
}
// videoGenerationParamsKnownFields contains known fields for video generation requests
// Based on VideoGenerationInput and VideoGenerationParameters structs
var videoGenerationParamsKnownFields = map[string]bool{
"model": true,
"prompt": true,
"input_reference": true,
"seconds": true,
"size": true,
"negative_prompt": true,
"seed": true,
"video_uri": true,
"audio": true,
"fallbacks": true,
}
var videoRemixParamsKnownFields = map[string]bool{
"prompt": true,
"fallbacks": true,
}
var transcriptionParamsKnownFields = map[string]bool{
"model": true,
"file": true,
"fallbacks": true,
"stream": true,
"language": true,
"prompt": true,
"response_format": true,
"file_format": true,
}
var countTokensParamsKnownFields = map[string]bool{
"model": true,
"messages": true,
"fallbacks": true,
"tools": true,
"instructions": true,
"text": true,
}
var batchCreateParamsKnownFields = map[string]bool{
"model": true,
"input_file_id": true,
"requests": true,
"endpoint": true,
"completion_window": true,
"metadata": true,
}
var containerCreateParamsKnownFields = map[string]bool{
"provider": true,
"name": true,
"expires_after": true,
"file_ids": true,
"memory_limit": true,
"metadata": true,
}
type BifrostParams struct {
Model string `json:"model"` // Model to use in "provider/model" format
Fallbacks []string `json:"fallbacks"` // Fallback providers and models in "provider/model" format
Stream *bool `json:"stream"` // Whether to stream the response
StreamFormat *string `json:"stream_format,omitempty"` // For speech
}
type TextRequest struct {
Prompt *schemas.TextCompletionInput `json:"prompt"`
BifrostParams
*schemas.TextCompletionParameters
}
type ChatRequest struct {
Messages []schemas.ChatMessage `json:"messages"`
BifrostParams
*schemas.ChatParameters
}
// UnmarshalJSON implements custom JSON unmarshalling for ChatRequest.
// This is needed because ChatParameters has a custom UnmarshalJSON method,
// which interferes with sonic's handling of the embedded BifrostParams struct.
func (cr *ChatRequest) UnmarshalJSON(data []byte) error {
// First, unmarshal BifrostParams fields directly
type bifrostAlias BifrostParams
var bp bifrostAlias
if err := sonic.Unmarshal(data, &bp); err != nil {
return err
}
cr.BifrostParams = BifrostParams(bp)
// Unmarshal messages
var msgStruct struct {
Messages []schemas.ChatMessage `json:"messages"`
}
if err := sonic.Unmarshal(data, &msgStruct); err != nil {
return err
}
cr.Messages = msgStruct.Messages
// Unmarshal ChatParameters (which has its own custom unmarshaller)
if cr.ChatParameters == nil {
cr.ChatParameters = &schemas.ChatParameters{}
}
if err := sonic.Unmarshal(data, cr.ChatParameters); err != nil {
return err
}
return nil
}
// ResponsesRequestInput is a union of string and array of responses messages
type ResponsesRequestInput struct {
ResponsesRequestInputStr *string
ResponsesRequestInputArray []schemas.ResponsesMessage
}
type ImageGenerationHTTPRequest struct {
*schemas.ImageGenerationInput
*schemas.ImageGenerationParameters
BifrostParams
}
type ImageEditHTTPRequest struct {
*schemas.ImageEditInput
*schemas.ImageEditParameters
BifrostParams
}
type ImageVariationHTTPRequest struct {
*schemas.ImageVariationInput
*schemas.ImageVariationParameters
BifrostParams
}
type VideoGenerationHTTPRequest struct {
*schemas.VideoGenerationInput
*schemas.VideoGenerationParameters
BifrostParams
}
// UnmarshalJSON unmarshals the responses request input
func (r *ResponsesRequestInput) UnmarshalJSON(data []byte) error {
var str string
if err := sonic.Unmarshal(data, &str); err == nil {
r.ResponsesRequestInputStr = &str
r.ResponsesRequestInputArray = nil
return nil
}
var array []schemas.ResponsesMessage
if err := sonic.Unmarshal(data, &array); err == nil {
r.ResponsesRequestInputStr = nil
r.ResponsesRequestInputArray = array
return nil
}
return fmt.Errorf("invalid responses request input")
}
// UnmarshalJSON implements custom JSON unmarshalling for ResponsesRequest.
// This is needed because ResponsesParameters has a custom UnmarshalJSON method,
// which interferes with sonic's handling of the embedded BifrostParams struct.
func (rr *ResponsesRequest) UnmarshalJSON(data []byte) error {
// First, unmarshal BifrostParams fields directly
type bifrostAlias BifrostParams
var bp bifrostAlias
if err := sonic.Unmarshal(data, &bp); err != nil {
return err
}
rr.BifrostParams = BifrostParams(bp)
// Unmarshal messages
var inputStruct struct {
Input ResponsesRequestInput `json:"input"`
}
if err := sonic.Unmarshal(data, &inputStruct); err != nil {
return err
}
rr.Input = inputStruct.Input
// Unmarshal ResponsesParameters (which has its own custom unmarshaller)
if rr.ResponsesParameters == nil {
rr.ResponsesParameters = &schemas.ResponsesParameters{}
}
if err := sonic.Unmarshal(data, rr.ResponsesParameters); err != nil {
return err
}
return nil
}
// ResponsesRequest is a bifrost responses request
type ResponsesRequest struct {
Input ResponsesRequestInput `json:"input"`
BifrostParams
*schemas.ResponsesParameters
}
// EmbeddingRequest is a bifrost embedding request
type EmbeddingRequest struct {
Input *schemas.EmbeddingInput `json:"input"`
BifrostParams
*schemas.EmbeddingParameters
}
// RerankRequest is a bifrost rerank request
type RerankRequest struct {
Query string `json:"query"`
Documents []schemas.RerankDocument `json:"documents"`
BifrostParams
*schemas.RerankParameters
}
// OCRHandlerRequest is a bifrost OCR request
type OCRHandlerRequest struct {
ID *string `json:"id,omitempty"`
Document schemas.OCRDocument `json:"document"`
BifrostParams
*schemas.OCRParameters
}
type SpeechRequest struct {
*schemas.SpeechInput
BifrostParams
*schemas.SpeechParameters
}
type TranscriptionRequest struct {
*schemas.TranscriptionInput
BifrostParams
*schemas.TranscriptionParameters
}
type VideoGenerationRequest struct {
*schemas.VideoGenerationInput
BifrostParams
*schemas.VideoGenerationParameters
}
type VideoRemixRequest struct {
*schemas.VideoGenerationInput
BifrostParams
ExtraParams map[string]any `json:"extra_params,omitempty"`
}
// BatchCreateRequest is a bifrost batch create request
type BatchCreateRequest struct {
Model string `json:"model"` // Model in "provider/model" format
InputFileID string `json:"input_file_id,omitempty"` // OpenAI-style file ID
Requests []schemas.BatchRequestItem `json:"requests,omitempty"` // Anthropic-style inline requests
Endpoint string `json:"endpoint,omitempty"` // e.g., "/v1/chat/completions"
CompletionWindow string `json:"completion_window,omitempty"` // e.g., "24h"
Metadata map[string]string `json:"metadata,omitempty"`
}
// BatchListRequest is a bifrost batch list request
type BatchListRequest struct {
Provider string `json:"provider"` // Provider name
Limit int `json:"limit,omitempty"` // Maximum number of batches to return
After *string `json:"after,omitempty"` // Cursor for pagination
Before *string `json:"before,omitempty"` // Cursor for pagination
}
// ContainerCreateRequest is a bifrost container create request
type ContainerCreateRequest struct {
Provider string `json:"provider"` // Provider name
Name string `json:"name"` // Name of the container
ExpiresAfter *schemas.ContainerExpiresAfter `json:"expires_after,omitempty"` // Expiration configuration
FileIDs []string `json:"file_ids,omitempty"` // IDs of existing files to copy into this container
MemoryLimit string `json:"memory_limit,omitempty"` // Memory limit (e.g., "1g", "4g")
Metadata map[string]string `json:"metadata,omitempty"` // User-provided metadata
}
// Helper functions
// enableRawRequestResponseForContainer sets per-request overrides to always capture and
// send back raw request/response for container operations. Container operations don't have
// model-specific content, so raw data is useful for debugging and should be enabled by default.
func enableRawRequestResponseForContainer(bifrostCtx *schemas.BifrostContext) {
bifrostCtx.SetValue(schemas.BifrostContextKeySendBackRawRequest, true)
bifrostCtx.SetValue(schemas.BifrostContextKeySendBackRawResponse, true)
bifrostCtx.SetValue(schemas.BifrostContextKeyStoreRawRequestResponse, true)
}
// parseFallbacks extracts fallbacks from string array and converts to Fallback structs
func parseFallbacks(fallbackStrings []string) ([]schemas.Fallback, error) {
fallbacks := make([]schemas.Fallback, 0, len(fallbackStrings))
for _, fallback := range fallbackStrings {
fallbackProvider, fallbackModelName := schemas.ParseModelString(fallback, "")
if fallbackProvider != "" && fallbackModelName != "" {
fallbacks = append(fallbacks, schemas.Fallback{
Provider: fallbackProvider,
Model: fallbackModelName,
})
}
}
return fallbacks, nil
}
func effectiveStream(bodyStream *bool) bool {
if bodyStream != nil {
return *bodyStream
}
return false
}
// extractExtraParams processes unknown fields from JSON data into ExtraParams
func extractExtraParams(data []byte, knownFields map[string]bool) (map[string]any, error) {
// Parse JSON to extract unknown fields
var rawData map[string]json.RawMessage
if err := sonic.Unmarshal(data, &rawData); err != nil {
return nil, err
}
// Extract unknown fields
extraParams := make(map[string]any)
for key, value := range rawData {
if !knownFields[key] {
var v any
if err := sonic.Unmarshal(value, &v); err != nil {
continue // Skip fields that can't be unmarshaled
}
extraParams[key] = v
}
}
return extraParams, nil
}
const (
// Maximum file size (25MB)
MaxFileSize = 25 * 1024 * 1024
// Primary MIME types for audio formats
AudioMimeMP3 = "audio/mpeg" // Covers MP3, MPEG, MPGA
AudioMimeMP4 = "audio/mp4" // MP4 audio
AudioMimeM4A = "audio/x-m4a" // M4A specific
AudioMimeOGG = "audio/ogg" // OGG audio
AudioMimeWAV = "audio/wav" // WAV audio
AudioMimeWEBM = "audio/webm" // WEBM audio
AudioMimeFLAC = "audio/flac" // FLAC audio
AudioMimeFLAC2 = "audio/x-flac" // Alternative FLAC
)
// PathToTypeMapping maps exact paths to request types (only for non-parameterized paths)
// Parameterized paths are set per-route in RegisterRoutes
var PathToTypeMapping = map[string]schemas.RequestType{
"/v1/completions": schemas.TextCompletionRequest,
"/v1/chat/completions": schemas.ChatCompletionRequest,
"/v1/responses": schemas.ResponsesRequest,
"/v1/embeddings": schemas.EmbeddingRequest,
"/v1/rerank": schemas.RerankRequest,
"/v1/ocr": schemas.OCRRequest,
"/v1/audio/speech": schemas.SpeechRequest,
"/v1/audio/transcriptions": schemas.TranscriptionRequest,
"/v1/images/generations": schemas.ImageGenerationRequest,
"/v1/responses/input_tokens": schemas.CountTokensRequest,
"/v1/images/edits": schemas.ImageEditRequest,
"/v1/images/variations": schemas.ImageVariationRequest,
"/v1/models": schemas.ListModelsRequest,
}
// createRequestTypeMiddleware creates a middleware that sets the request type for a specific route
func createRequestTypeMiddleware(requestType schemas.RequestType) schemas.BifrostHTTPMiddleware {
return func(next fasthttp.RequestHandler) fasthttp.RequestHandler {
return func(ctx *fasthttp.RequestCtx) {
ctx.SetUserValue(schemas.BifrostContextKeyHTTPRequestType, requestType)
next(ctx)
}
}
}
// RegisterRequestTypeMiddleware handles exact path matching for non-parameterized routes
func RegisterRequestTypeMiddleware(next fasthttp.RequestHandler) fasthttp.RequestHandler {
return func(ctx *fasthttp.RequestCtx) {
path := string(ctx.Path())
if requestType, ok := PathToTypeMapping[path]; ok {
ctx.SetUserValue(schemas.BifrostContextKeyHTTPRequestType, requestType)
}
next(ctx)
}
}
// RegisterRoutes registers all completion-related routes
func (h *CompletionHandler) RegisterRoutes(r *router.Router, middlewares ...schemas.BifrostHTTPMiddleware) {
// Base middlewares for all routes
baseMiddlewares := append([]schemas.BifrostHTTPMiddleware{RegisterRequestTypeMiddleware}, middlewares...)
// Model endpoints
r.GET("/v1/models", lib.ChainMiddlewares(h.listModels, baseMiddlewares...))
// Completion endpoints (non-parameterized)
r.POST("/v1/completions", lib.ChainMiddlewares(h.textCompletion, baseMiddlewares...))
r.POST("/v1/chat/completions", lib.ChainMiddlewares(h.chatCompletion, baseMiddlewares...))
r.POST("/v1/responses", lib.ChainMiddlewares(h.responses, baseMiddlewares...))
r.POST("/v1/embeddings", lib.ChainMiddlewares(h.embeddings, baseMiddlewares...))
r.POST("/v1/rerank", lib.ChainMiddlewares(h.rerank, baseMiddlewares...))
r.POST("/v1/ocr", lib.ChainMiddlewares(h.ocr, baseMiddlewares...))
r.POST("/v1/audio/speech", lib.ChainMiddlewares(h.speech, baseMiddlewares...))
r.POST("/v1/audio/transcriptions", lib.ChainMiddlewares(h.transcription, baseMiddlewares...))
r.POST("/v1/images/generations", lib.ChainMiddlewares(h.imageGeneration, baseMiddlewares...))
r.POST("/v1/responses/input_tokens", lib.ChainMiddlewares(h.countTokens, baseMiddlewares...))
r.POST("/v1/images/edits", lib.ChainMiddlewares(h.imageEdit, baseMiddlewares...))
r.POST("/v1/images/variations", lib.ChainMiddlewares(h.imageVariation, baseMiddlewares...))
r.POST("/v1/videos", lib.ChainMiddlewares(h.videoGeneration, baseMiddlewares...))
// Video API endpoints (parameterized routes need explicit request type middleware)
videoListMW := append([]schemas.BifrostHTTPMiddleware{createRequestTypeMiddleware(schemas.VideoListRequest)}, middlewares...)
videoRetrieveMW := append([]schemas.BifrostHTTPMiddleware{createRequestTypeMiddleware(schemas.VideoRetrieveRequest)}, middlewares...)
videoDownloadMW := append([]schemas.BifrostHTTPMiddleware{createRequestTypeMiddleware(schemas.VideoDownloadRequest)}, middlewares...)
videoDeleteMW := append([]schemas.BifrostHTTPMiddleware{createRequestTypeMiddleware(schemas.VideoDeleteRequest)}, middlewares...)
videoRemixMW := append([]schemas.BifrostHTTPMiddleware{createRequestTypeMiddleware(schemas.VideoRemixRequest)}, middlewares...)
r.GET("/v1/videos", lib.ChainMiddlewares(h.videoList, videoListMW...))
r.GET("/v1/videos/{video_id}", lib.ChainMiddlewares(h.videoRetrieve, videoRetrieveMW...))
r.GET("/v1/videos/{video_id}/content", lib.ChainMiddlewares(h.videoDownload, videoDownloadMW...))
r.DELETE("/v1/videos/{video_id}", lib.ChainMiddlewares(h.videoDelete, videoDeleteMW...))
r.POST("/v1/videos/{video_id}/remix", lib.ChainMiddlewares(h.videoRemix, videoRemixMW...))
// Batch API endpoints (parameterized routes need explicit request type middleware)
batchCreateMW := append([]schemas.BifrostHTTPMiddleware{createRequestTypeMiddleware(schemas.BatchCreateRequest)}, middlewares...)
batchListMW := append([]schemas.BifrostHTTPMiddleware{createRequestTypeMiddleware(schemas.BatchListRequest)}, middlewares...)
batchRetrieveMW := append([]schemas.BifrostHTTPMiddleware{createRequestTypeMiddleware(schemas.BatchRetrieveRequest)}, middlewares...)
batchCancelMW := append([]schemas.BifrostHTTPMiddleware{createRequestTypeMiddleware(schemas.BatchCancelRequest)}, middlewares...)
batchResultsMW := append([]schemas.BifrostHTTPMiddleware{createRequestTypeMiddleware(schemas.BatchResultsRequest)}, middlewares...)
r.POST("/v1/batches", lib.ChainMiddlewares(h.batchCreate, batchCreateMW...))
r.GET("/v1/batches", lib.ChainMiddlewares(h.batchList, batchListMW...))
r.GET("/v1/batches/{batch_id}", lib.ChainMiddlewares(h.batchRetrieve, batchRetrieveMW...))
r.POST("/v1/batches/{batch_id}/cancel", lib.ChainMiddlewares(h.batchCancel, batchCancelMW...))
r.GET("/v1/batches/{batch_id}/results", lib.ChainMiddlewares(h.batchResults, batchResultsMW...))
// File API endpoints (parameterized routes need explicit request type middleware)
fileUploadMW := append([]schemas.BifrostHTTPMiddleware{createRequestTypeMiddleware(schemas.FileUploadRequest)}, middlewares...)
fileListMW := append([]schemas.BifrostHTTPMiddleware{createRequestTypeMiddleware(schemas.FileListRequest)}, middlewares...)
fileRetrieveMW := append([]schemas.BifrostHTTPMiddleware{createRequestTypeMiddleware(schemas.FileRetrieveRequest)}, middlewares...)
fileDeleteMW := append([]schemas.BifrostHTTPMiddleware{createRequestTypeMiddleware(schemas.FileDeleteRequest)}, middlewares...)
fileContentMW := append([]schemas.BifrostHTTPMiddleware{createRequestTypeMiddleware(schemas.FileContentRequest)}, middlewares...)
r.POST("/v1/files", lib.ChainMiddlewares(h.fileUpload, fileUploadMW...))
r.GET("/v1/files", lib.ChainMiddlewares(h.fileList, fileListMW...))
r.GET("/v1/files/{file_id}", lib.ChainMiddlewares(h.fileRetrieve, fileRetrieveMW...))
r.DELETE("/v1/files/{file_id}", lib.ChainMiddlewares(h.fileDelete, fileDeleteMW...))
r.GET("/v1/files/{file_id}/content", lib.ChainMiddlewares(h.fileContent, fileContentMW...))
// Container API endpoints (parameterized routes need explicit request type middleware)
containerCreateMW := append([]schemas.BifrostHTTPMiddleware{createRequestTypeMiddleware(schemas.ContainerCreateRequest)}, middlewares...)
containerListMW := append([]schemas.BifrostHTTPMiddleware{createRequestTypeMiddleware(schemas.ContainerListRequest)}, middlewares...)
containerRetrieveMW := append([]schemas.BifrostHTTPMiddleware{createRequestTypeMiddleware(schemas.ContainerRetrieveRequest)}, middlewares...)
containerDeleteMW := append([]schemas.BifrostHTTPMiddleware{createRequestTypeMiddleware(schemas.ContainerDeleteRequest)}, middlewares...)
r.POST("/v1/containers", lib.ChainMiddlewares(h.containerCreate, containerCreateMW...))
r.GET("/v1/containers", lib.ChainMiddlewares(h.containerList, containerListMW...))
r.GET("/v1/containers/{container_id}", lib.ChainMiddlewares(h.containerRetrieve, containerRetrieveMW...))
r.DELETE("/v1/containers/{container_id}", lib.ChainMiddlewares(h.containerDelete, containerDeleteMW...))
// Container Files API endpoints (parameterized routes need explicit request type middleware)
containerFileCreateMW := append([]schemas.BifrostHTTPMiddleware{createRequestTypeMiddleware(schemas.ContainerFileCreateRequest)}, middlewares...)
containerFileListMW := append([]schemas.BifrostHTTPMiddleware{createRequestTypeMiddleware(schemas.ContainerFileListRequest)}, middlewares...)
containerFileRetrieveMW := append([]schemas.BifrostHTTPMiddleware{createRequestTypeMiddleware(schemas.ContainerFileRetrieveRequest)}, middlewares...)
containerFileContentMW := append([]schemas.BifrostHTTPMiddleware{createRequestTypeMiddleware(schemas.ContainerFileContentRequest)}, middlewares...)
containerFileDeleteMW := append([]schemas.BifrostHTTPMiddleware{createRequestTypeMiddleware(schemas.ContainerFileDeleteRequest)}, middlewares...)
r.POST("/v1/containers/{container_id}/files", lib.ChainMiddlewares(h.containerFileCreate, containerFileCreateMW...))
r.GET("/v1/containers/{container_id}/files", lib.ChainMiddlewares(h.containerFileList, containerFileListMW...))
r.GET("/v1/containers/{container_id}/files/{file_id}", lib.ChainMiddlewares(h.containerFileRetrieve, containerFileRetrieveMW...))
r.GET("/v1/containers/{container_id}/files/{file_id}/content", lib.ChainMiddlewares(h.containerFileContent, containerFileContentMW...))
r.DELETE("/v1/containers/{container_id}/files/{file_id}", lib.ChainMiddlewares(h.containerFileDelete, containerFileDeleteMW...))
}
// listModels handles GET /v1/models - Process list models requests
// If provider is not specified, lists all models from all configured providers
func (h *CompletionHandler) listModels(ctx *fasthttp.RequestCtx) {
// Get provider from query parameters
provider := string(ctx.QueryArgs().Peek("provider"))
// Convert context
bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist())
defer cancel() // Ensure cleanup on function exit
if bifrostCtx == nil {
SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context")
return
}
var resp *schemas.BifrostListModelsResponse
var bifrostErr *schemas.BifrostError
pageSize := 0
if pageSizeStr := ctx.QueryArgs().Peek("page_size"); len(pageSizeStr) > 0 {
if n, err := strconv.Atoi(string(pageSizeStr)); err == nil && n >= 0 {
pageSize = n
}
}
pageToken := string(ctx.QueryArgs().Peek("page_token"))
bifrostListModelsReq := &schemas.BifrostListModelsRequest{
Provider: schemas.ModelProvider(provider),
PageSize: pageSize,
PageToken: pageToken,
}
// Pass-through unknown query params for provider-specific features
extraParams := map[string]interface{}{}
for k, v := range ctx.QueryArgs().All() {
s := string(k)
if s != "provider" && s != "page_size" && s != "page_token" {
extraParams[s] = string(v)
}
}
if len(extraParams) > 0 {
bifrostListModelsReq.ExtraParams = extraParams
}
// If provider is empty, list all models from all providers
if provider == "" {
resp, bifrostErr = h.client.ListAllModels(bifrostCtx, bifrostListModelsReq)
} else {
resp, bifrostErr = h.client.ListModelsRequest(bifrostCtx, bifrostListModelsReq)
}
if bifrostErr != nil {
forwardProviderHeadersFromContext(ctx, bifrostCtx)
SendBifrostError(ctx, bifrostErr)
return
}
if streamLargeResponseIfActive(ctx, bifrostCtx) {
return
}
// Add pricing data to the response
if len(resp.Data) > 0 && h.config.ModelCatalog != nil {
for i, modelEntry := range resp.Data {
provider, modelName := schemas.ParseModelString(modelEntry.ID, "")
pricingEntry := h.config.ModelCatalog.GetPricingEntryForModel(modelName, provider)
if pricingEntry == nil && modelEntry.Alias != nil {
// Retry with alias
pricingEntry = h.config.ModelCatalog.GetPricingEntryForModel(*modelEntry.Alias, provider)
}
if pricingEntry != nil && modelEntry.Pricing == nil {
pricing := &schemas.Pricing{}
if pricingEntry.InputCostPerToken != nil {
pricing.Prompt = bifrost.Ptr(fmt.Sprintf("%.10f", *pricingEntry.InputCostPerToken))
}
if pricingEntry.OutputCostPerToken != nil {
pricing.Completion = bifrost.Ptr(fmt.Sprintf("%.10f", *pricingEntry.OutputCostPerToken))
}
if pricingEntry.InputCostPerImage != nil {
pricing.Image = bifrost.Ptr(fmt.Sprintf("%.10f", *pricingEntry.InputCostPerImage))
}
if pricingEntry.CacheReadInputTokenCost != nil {
pricing.InputCacheRead = bifrost.Ptr(fmt.Sprintf("%.10f", *pricingEntry.CacheReadInputTokenCost))
}
if pricingEntry.CacheCreationInputTokenCost != nil {
pricing.InputCacheWrite = bifrost.Ptr(fmt.Sprintf("%.10f", *pricingEntry.CacheCreationInputTokenCost))
}
resp.Data[i].Pricing = pricing
}
}
}
if resp != nil && resp.ExtraFields.ProviderResponseHeaders != nil {
forwardProviderHeaders(ctx, resp.ExtraFields.ProviderResponseHeaders)
}
// Send successful response
SendJSON(ctx, resp)
}
// prepareTextCompletionRequest prepares a BifrostTextCompletionRequest from the HTTP request body
func prepareTextCompletionRequest(ctx *fasthttp.RequestCtx) (*TextRequest, *schemas.BifrostTextCompletionRequest, error) {
var req TextRequest
if err := sonic.Unmarshal(ctx.PostBody(), &req); err != nil {
return nil, nil, fmt.Errorf("invalid request format: %v", err)
}
provider, modelName := schemas.ParseModelString(req.Model, "")
if provider == "" || modelName == "" {
return nil, nil, fmt.Errorf("model should be in provider/model format")
}
fallbacks, err := parseFallbacks(req.Fallbacks)
if err != nil {
return nil, nil, err
}
if req.Prompt == nil || (req.Prompt.PromptStr == nil && req.Prompt.PromptArray == nil) {
return nil, nil, fmt.Errorf("prompt is required for text completion")
}
if req.TextCompletionParameters == nil {
req.TextCompletionParameters = &schemas.TextCompletionParameters{}
}
extraParams, err := extractExtraParams(ctx.PostBody(), textParamsKnownFields)
if err != nil {
logger.Warn("Failed to extract extra params: %v", err)
} else {
req.TextCompletionParameters.ExtraParams = extraParams
}
bifrostTextReq := &schemas.BifrostTextCompletionRequest{
Provider: schemas.ModelProvider(provider),
Model: modelName,
Input: req.Prompt,
Params: req.TextCompletionParameters,
Fallbacks: fallbacks,
}
return &req, bifrostTextReq, nil
}
// textCompletion handles POST /v1/completions - Process text completion requests
func (h *CompletionHandler) textCompletion(ctx *fasthttp.RequestCtx) {
req, bifrostTextReq, err := prepareTextCompletionRequest(ctx)
if err != nil {
SendError(ctx, fasthttp.StatusBadRequest, err.Error())
return
}
bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist())
if bifrostCtx == nil {
SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context")
return
}
if req.Stream != nil && *req.Stream {
h.handleStreamingTextCompletion(ctx, bifrostTextReq, bifrostCtx, cancel)
return
}
// NOTE: these defers wont work as expected when a non-streaming request is cancelled on flight.
// valyala/fasthttp does not support cancelling a request in the middle of a request.
// This is a known issue of valyala/fasthttp. And will be fixed here once it is fixed upstream.
defer cancel() // Ensure cleanup on function exit
resp, bifrostErr := h.client.TextCompletionRequest(bifrostCtx, bifrostTextReq)
if bifrostErr != nil {
forwardProviderHeadersFromContext(ctx, bifrostCtx)
SendBifrostError(ctx, bifrostErr)
return
}
if resp != nil && resp.ExtraFields.ProviderResponseHeaders != nil {
forwardProviderHeaders(ctx, resp.ExtraFields.ProviderResponseHeaders)
}
if streamLargeResponseIfActive(ctx, bifrostCtx) {
return
}
// Send successful response
SendJSON(ctx, resp)
}
// prepareChatCompletionRequest prepares a BifrostChatRequest from a ChatRequest
func prepareChatCompletionRequest(ctx *fasthttp.RequestCtx) (*ChatRequest, *schemas.BifrostChatRequest, error) {
req := ChatRequest{
ChatParameters: &schemas.ChatParameters{},
}
if err := sonic.Unmarshal(ctx.PostBody(), &req); err != nil {
return nil, nil, fmt.Errorf("invalid request format: %v", err)
}
// Create BifrostChatRequest directly using segregated structure
provider, modelName := schemas.ParseModelString(req.Model, "")
if provider == "" || modelName == "" {
return nil, nil, fmt.Errorf("model should be in provider/model format")
}
// Parse fallbacks using helper function
fallbacks, err := parseFallbacks(req.Fallbacks)
if err != nil {
return nil, nil, fmt.Errorf("failed to parse fallbacks: %v", err)
}
if len(req.Messages) == 0 {
return nil, nil, fmt.Errorf("messages is required for chat completion")
}
// Extract extra params
if req.ChatParameters == nil {
req.ChatParameters = &schemas.ChatParameters{}
}
extraParams, err := extractExtraParams(ctx.PostBody(), chatParamsKnownFields)
if err != nil {
logger.Warn("Failed to extract extra params: %v", err)
} else {
// Handle max_tokens -> max_completion_tokens mapping after extracting extra params
// If max_completion_tokens is nil and max_tokens is present in extra params, map it
// This is to support the legacy max_tokens field, which is still used by some implementations.
if req.ChatParameters.MaxCompletionTokens == nil {
if maxTokensVal, exists := extraParams["max_tokens"]; exists {
// Type check and convert to int
// JSON numbers are unmarshaled as float64, so we need to handle that
var maxTokens int
if maxTokensFloat, ok := maxTokensVal.(float64); ok {
maxTokens = int(maxTokensFloat)
req.ChatParameters.MaxCompletionTokens = &maxTokens
// Remove max_tokens from extra params since we've mapped it
delete(extraParams, "max_tokens")
} else if maxTokensInt, ok := maxTokensVal.(int); ok {
req.ChatParameters.MaxCompletionTokens = &maxTokensInt
// Remove max_tokens from extra params since we've mapped it
delete(extraParams, "max_tokens")
}
}
}
req.ChatParameters.ExtraParams = extraParams
}
// Create segregated BifrostChatRequest
bifrostChatReq := &schemas.BifrostChatRequest{
Provider: schemas.ModelProvider(provider),
Model: modelName,
Input: req.Messages,
Params: req.ChatParameters,
Fallbacks: fallbacks,
}
return &req, bifrostChatReq, nil
}
// chatCompletion handles POST /v1/chat/completions - Process chat completion requests
func (h *CompletionHandler) chatCompletion(ctx *fasthttp.RequestCtx) {
req, bifrostChatReq, err := prepareChatCompletionRequest(ctx)
if err != nil {
SendError(ctx, fasthttp.StatusBadRequest, err.Error())
return
}
// Convert context
bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist())
if bifrostCtx == nil {
SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context")
return
}
if effectiveStream(req.Stream) {
h.handleStreamingChatCompletion(ctx, bifrostChatReq, bifrostCtx, cancel)
return
}
defer cancel() // Ensure cleanup on function exit
// Complete the request
resp, bifrostErr := h.client.ChatCompletionRequest(bifrostCtx, bifrostChatReq)
if bifrostErr != nil {
forwardProviderHeadersFromContext(ctx, bifrostCtx)
SendBifrostError(ctx, bifrostErr)
return
}
if resp != nil && resp.ExtraFields.ProviderResponseHeaders != nil {
forwardProviderHeaders(ctx, resp.ExtraFields.ProviderResponseHeaders)
}
if streamLargeResponseIfActive(ctx, bifrostCtx) {
return
}
// Send successful response
SendJSON(ctx, resp)
}
// prepareResponsesRequest prepares a BifrostResponsesRequest from a ResponsesRequest
func prepareResponsesRequest(ctx *fasthttp.RequestCtx) (*ResponsesRequest, *schemas.BifrostResponsesRequest, error) {
var req ResponsesRequest
if err := sonic.Unmarshal(ctx.PostBody(), &req); err != nil {
return nil, nil, fmt.Errorf("invalid request format: %v", err)
}
// Create BifrostResponsesRequest directly using segregated structure
provider, modelName := schemas.ParseModelString(req.Model, "")
if provider == "" || modelName == "" {
return nil, nil, fmt.Errorf("model should be in provider/model format")
}
// Parse fallbacks using helper function
fallbacks, err := parseFallbacks(req.Fallbacks)
if err != nil {
return nil, nil, fmt.Errorf("failed to parse fallbacks: %v", err)
}
if len(req.Input.ResponsesRequestInputArray) == 0 && req.Input.ResponsesRequestInputStr == nil {
return nil, nil, fmt.Errorf("input is required for responses")
}
// Extract extra params
if req.ResponsesParameters == nil {
req.ResponsesParameters = &schemas.ResponsesParameters{}
}
extraParams, err := extractExtraParams(ctx.PostBody(), responsesParamsKnownFields)
if err != nil {
logger.Warn("Failed to extract extra params: %v", err)
} else {
req.ResponsesParameters.ExtraParams = extraParams
}
input := req.Input.ResponsesRequestInputArray
if input == nil {
input = []schemas.ResponsesMessage{
{
Role: schemas.Ptr(schemas.ResponsesInputMessageRoleUser),
Content: &schemas.ResponsesMessageContent{ContentStr: req.Input.ResponsesRequestInputStr},
},
}
}
// Create segregated BifrostResponsesRequest
bifrostResponsesReq := &schemas.BifrostResponsesRequest{
Provider: schemas.ModelProvider(provider),
Model: modelName,
Input: input,
Params: req.ResponsesParameters,
Fallbacks: fallbacks,
}
return &req, bifrostResponsesReq, nil
}
// responses handles POST /v1/responses - Process responses requests
func (h *CompletionHandler) responses(ctx *fasthttp.RequestCtx) {
req, bifrostResponsesReq, err := prepareResponsesRequest(ctx)
if err != nil {
SendError(ctx, fasthttp.StatusBadRequest, err.Error())
return
}
// Convert context
bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist())
if bifrostCtx == nil {
SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context")
return
}
if effectiveStream(req.Stream) {
h.handleStreamingResponses(ctx, bifrostResponsesReq, bifrostCtx, cancel)
return
}
defer cancel() // Ensure cleanup on function exit
resp, bifrostErr := h.client.ResponsesRequest(bifrostCtx, bifrostResponsesReq)
if bifrostErr != nil {
forwardProviderHeadersFromContext(ctx, bifrostCtx)
SendBifrostError(ctx, bifrostErr)
return
}
if resp != nil && resp.ExtraFields.ProviderResponseHeaders != nil {
forwardProviderHeaders(ctx, resp.ExtraFields.ProviderResponseHeaders)
}
if streamLargeResponseIfActive(ctx, bifrostCtx) {
return
}
// Send successful response
SendJSON(ctx, resp)
}
// prepareEmbeddingRequest prepares a BifrostEmbeddingRequest from the HTTP request body
func prepareEmbeddingRequest(ctx *fasthttp.RequestCtx) (*EmbeddingRequest, *schemas.BifrostEmbeddingRequest, error) {
var req EmbeddingRequest
if err := sonic.Unmarshal(ctx.PostBody(), &req); err != nil {
return nil, nil, fmt.Errorf("invalid request format: %v", err)
}
provider, modelName := schemas.ParseModelString(req.Model, "")
if provider == "" || modelName == "" {
return nil, nil, fmt.Errorf("model should be in provider/model format")
}
fallbacks, err := parseFallbacks(req.Fallbacks)
if err != nil {
return nil, nil, err
}
if req.Input == nil || (req.Input.Text == nil && req.Input.Texts == nil && req.Input.Embedding == nil && req.Input.Embeddings == nil) {
return nil, nil, fmt.Errorf("input is required for embeddings")
}
if req.EmbeddingParameters == nil {
req.EmbeddingParameters = &schemas.EmbeddingParameters{}
}
extraParams, err := extractExtraParams(ctx.PostBody(), embeddingParamsKnownFields)
if err != nil {
logger.Warn("Failed to extract extra params: %v", err)
} else {
req.EmbeddingParameters.ExtraParams = extraParams
}
bifrostEmbeddingReq := &schemas.BifrostEmbeddingRequest{
Provider: schemas.ModelProvider(provider),
Model: modelName,
Input: req.Input,
Params: req.EmbeddingParameters,
Fallbacks: fallbacks,
}
return &req, bifrostEmbeddingReq, nil
}
// embeddings handles POST /v1/embeddings - Process embeddings requests
func (h *CompletionHandler) embeddings(ctx *fasthttp.RequestCtx) {
_, bifrostEmbeddingReq, err := prepareEmbeddingRequest(ctx)
if err != nil {
SendError(ctx, fasthttp.StatusBadRequest, err.Error())
return
}
bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist())
defer cancel()
if bifrostCtx == nil {
SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context")
return
}
resp, bifrostErr := h.client.EmbeddingRequest(bifrostCtx, bifrostEmbeddingReq)
if bifrostErr != nil {
forwardProviderHeadersFromContext(ctx, bifrostCtx)
SendBifrostError(ctx, bifrostErr)
return
}
if resp != nil && resp.ExtraFields.ProviderResponseHeaders != nil {
forwardProviderHeaders(ctx, resp.ExtraFields.ProviderResponseHeaders)
}
if streamLargeResponseIfActive(ctx, bifrostCtx) {
return
}
// Send successful response
SendJSON(ctx, resp)
}
// prepareRerankRequest prepares a BifrostRerankRequest from the HTTP request body
func prepareRerankRequest(ctx *fasthttp.RequestCtx) (*RerankRequest, *schemas.BifrostRerankRequest, error) {
var req RerankRequest
if err := sonic.Unmarshal(ctx.PostBody(), &req); err != nil {
return nil, nil, fmt.Errorf("invalid request format: %v", err)
}
// Parse model
provider, modelName := schemas.ParseModelString(req.Model, "")
if provider == "" || modelName == "" {
return nil, nil, fmt.Errorf("model should be in provider/model format")
}
// Parse fallbacks
fallbacks, err := parseFallbacks(req.Fallbacks)
if err != nil {
return nil, nil, fmt.Errorf("failed to parse fallbacks: %v", err)
}
if strings.TrimSpace(req.Query) == "" {
return nil, nil, fmt.Errorf("query is required for rerank")
}
if len(req.Documents) == 0 {
return nil, nil, fmt.Errorf("documents are required for rerank")
}
for i, doc := range req.Documents {
if strings.TrimSpace(doc.Text) == "" {
return nil, nil, fmt.Errorf("document text is required for rerank at index %d", i)
}
}
// Extract extra params
if req.RerankParameters == nil {
req.RerankParameters = &schemas.RerankParameters{}
}
if req.RerankParameters.TopN != nil && *req.RerankParameters.TopN < 1 {
return nil, nil, fmt.Errorf("top_n must be at least 1")
}
extraParams, err := extractExtraParams(ctx.PostBody(), rerankParamsKnownFields)
if err != nil {
logger.Warn("Failed to extract extra params: %v", err)
} else {
req.RerankParameters.ExtraParams = extraParams
}
// Create BifrostRerankRequest
bifrostRerankReq := &schemas.BifrostRerankRequest{
Provider: schemas.ModelProvider(provider),
Model: modelName,
Query: req.Query,
Documents: req.Documents,
Params: req.RerankParameters,
Fallbacks: fallbacks,
}
return &req, bifrostRerankReq, nil
}
// rerank handles POST /v1/rerank - Process rerank requests
func (h *CompletionHandler) rerank(ctx *fasthttp.RequestCtx) {
_, bifrostRerankReq, err := prepareRerankRequest(ctx)
if err != nil {
SendError(ctx, fasthttp.StatusBadRequest, err.Error())
return
}
// Convert context
bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist())
defer cancel()
if bifrostCtx == nil {
SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context")
return
}
resp, bifrostErr := h.client.RerankRequest(bifrostCtx, bifrostRerankReq)
if bifrostErr != nil {
forwardProviderHeadersFromContext(ctx, bifrostCtx)
SendBifrostError(ctx, bifrostErr)
return
}
if resp != nil && resp.ExtraFields.ProviderResponseHeaders != nil {
forwardProviderHeaders(ctx, resp.ExtraFields.ProviderResponseHeaders)
}
if streamLargeResponseIfActive(ctx, bifrostCtx) {
return
}
// Send successful response
SendJSON(ctx, resp)
}
// prepareOCRRequest prepares a BifrostOCRRequest from the HTTP request body
func prepareOCRRequest(ctx *fasthttp.RequestCtx) (*OCRHandlerRequest, *schemas.BifrostOCRRequest, error) {
var req OCRHandlerRequest
if err := sonic.Unmarshal(ctx.PostBody(), &req); err != nil {
return nil, nil, fmt.Errorf("invalid request format: %v", err)
}
// Parse model
provider, modelName := schemas.ParseModelString(req.Model, "")
if provider == "" || modelName == "" {
return nil, nil, fmt.Errorf("model should be in provider/model format")
}
// Parse fallbacks
fallbacks, err := parseFallbacks(req.Fallbacks)
if err != nil {
return nil, nil, fmt.Errorf("failed to parse fallbacks: %v", err)
}
if req.Document.Type == "" {
return nil, nil, fmt.Errorf("document type is required for ocr")
}
if req.Document.Type == schemas.OCRDocumentTypeDocumentURL && (req.Document.DocumentURL == nil || *req.Document.DocumentURL == "") {
return nil, nil, fmt.Errorf("document_url is required when document type is document_url")
}
if req.Document.Type == schemas.OCRDocumentTypeImageURL && (req.Document.ImageURL == nil || *req.Document.ImageURL == "") {
return nil, nil, fmt.Errorf("image_url is required when document type is image_url")
}
// Extract extra params
if req.OCRParameters == nil {
req.OCRParameters = &schemas.OCRParameters{}
}
extraParams, err := extractExtraParams(ctx.PostBody(), ocrParamsKnownFields)
if err != nil {
logger.Warn("Failed to extract extra params: %v", err)
} else {
req.OCRParameters.ExtraParams = extraParams
}
// Create BifrostOCRRequest
bifrostOCRReq := &schemas.BifrostOCRRequest{
Provider: schemas.ModelProvider(provider),
Model: modelName,
ID: req.ID,
Document: req.Document,
Params: req.OCRParameters,
Fallbacks: fallbacks,
}
return &req, bifrostOCRReq, nil
}
// ocr handles POST /v1/ocr - Process OCR requests
func (h *CompletionHandler) ocr(ctx *fasthttp.RequestCtx) {
_, bifrostOCRReq, err := prepareOCRRequest(ctx)
if err != nil {
SendError(ctx, fasthttp.StatusBadRequest, err.Error())
return
}
// Convert context
bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist())
defer cancel()
if bifrostCtx == nil {
SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context")
return
}
resp, bifrostErr := h.client.OCRRequest(bifrostCtx, bifrostOCRReq)
if bifrostErr != nil {
forwardProviderHeadersFromContext(ctx, bifrostCtx)
SendBifrostError(ctx, bifrostErr)
return
}
if resp != nil && resp.ExtraFields.ProviderResponseHeaders != nil {
forwardProviderHeaders(ctx, resp.ExtraFields.ProviderResponseHeaders)
}
if streamLargeResponseIfActive(ctx, bifrostCtx) {
return
}
// Send successful response
SendJSON(ctx, resp)
}
// prepareSpeechRequest prepares a BifrostSpeechRequest from the HTTP request body
func prepareSpeechRequest(ctx *fasthttp.RequestCtx) (*SpeechRequest, *schemas.BifrostSpeechRequest, error) {
var req SpeechRequest
if err := sonic.Unmarshal(ctx.PostBody(), &req); err != nil {
return nil, nil, fmt.Errorf("invalid request format: %v", err)
}
provider, modelName := schemas.ParseModelString(req.Model, "")
if provider == "" || modelName == "" {
return nil, nil, fmt.Errorf("model should be in provider/model format")
}
fallbacks, err := parseFallbacks(req.Fallbacks)
if err != nil {
return nil, nil, err
}
if req.SpeechInput == nil || req.SpeechInput.Input == "" {
return nil, nil, fmt.Errorf("input is required for speech completion")
}
if req.SpeechParameters == nil || req.VoiceConfig == nil || (req.VoiceConfig.Voice == nil && len(req.VoiceConfig.MultiVoiceConfig) == 0) {
return nil, nil, fmt.Errorf("voice is required for speech completion")
}
if req.SpeechParameters == nil {
req.SpeechParameters = &schemas.SpeechParameters{}
}
extraParams, err := extractExtraParams(ctx.PostBody(), speechParamsKnownFields)
if err != nil {
logger.Warn("Failed to extract extra params: %v", err)
} else {
req.SpeechParameters.ExtraParams = extraParams
}
bifrostSpeechReq := &schemas.BifrostSpeechRequest{
Provider: schemas.ModelProvider(provider),
Model: modelName,
Input: req.SpeechInput,
Params: req.SpeechParameters,
Fallbacks: fallbacks,
}
return &req, bifrostSpeechReq, nil
}
// speech handles POST /v1/audio/speech - Process speech completion requests
func (h *CompletionHandler) speech(ctx *fasthttp.RequestCtx) {
req, bifrostSpeechReq, err := prepareSpeechRequest(ctx)
if err != nil {
SendError(ctx, fasthttp.StatusBadRequest, err.Error())
return
}
bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist())
if bifrostCtx == nil {
SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context")
return
}
if req.StreamFormat != nil && *req.StreamFormat == "sse" {
h.handleStreamingSpeech(ctx, bifrostSpeechReq, bifrostCtx, cancel)
return
}
defer cancel() // Ensure cleanup on function exit
resp, bifrostErr := h.client.SpeechRequest(bifrostCtx, bifrostSpeechReq)
if bifrostErr != nil {
forwardProviderHeadersFromContext(ctx, bifrostCtx)
SendBifrostError(ctx, bifrostErr)
return
}
// Preserve the attachment header through the large-response shortcut; the
// normal binary path sets this explicitly after the stream check.
if !(bifrostSpeechReq.Provider == schemas.Elevenlabs && req.WithTimestamps != nil && *req.WithTimestamps) {
bifrostCtx.SetValue(schemas.BifrostContextKeyLargeResponseContentDisposition, "attachment; filename=speech.mp3")
}
if resp != nil && resp.ExtraFields.ProviderResponseHeaders != nil {
forwardProviderHeaders(ctx, resp.ExtraFields.ProviderResponseHeaders)
}
if streamLargeResponseIfActive(ctx, bifrostCtx) {
return
}
// Send successful response
// When with_timestamps is true, Elevenlabs returns base64 encoded audio
hasTimestamps := req.WithTimestamps != nil && *req.WithTimestamps
if bifrostSpeechReq.Provider == schemas.Elevenlabs && hasTimestamps {
ctx.Response.Header.Set("Content-Type", "application/json")
SendJSON(ctx, resp)
return
}
if resp.Audio == nil {
SendError(ctx, fasthttp.StatusBadRequest, "Speech response is missing audio data")
return
}
ctx.Response.Header.Set("Content-Type", "audio/mpeg")
ctx.Response.Header.Set("Content-Disposition", "attachment; filename=speech.mp3")
ctx.Response.Header.Set("Content-Length", strconv.Itoa(len(resp.Audio)))
ctx.Response.SetBody(resp.Audio)
}
// prepareTranscriptionRequest prepares a BifrostTranscriptionRequest from a multipart form.
// Returns the request, whether streaming was requested, and any error.
func prepareTranscriptionRequest(ctx *fasthttp.RequestCtx) (*schemas.BifrostTranscriptionRequest, bool, error) {
form, err := ctx.MultipartForm()
if err != nil {
return nil, false, fmt.Errorf("failed to parse multipart form: %v", err)
}
modelValues := form.Value["model"]
if len(modelValues) == 0 || modelValues[0] == "" {
return nil, false, fmt.Errorf("model is required")
}
provider, modelName := schemas.ParseModelString(modelValues[0], "")
if provider == "" || modelName == "" {
return nil, false, fmt.Errorf("model should be in provider/model format")
}
fileHeaders := form.File["file"]
if len(fileHeaders) == 0 {
return nil, false, fmt.Errorf("file is required")
}
fileHeader := fileHeaders[0]
file, err := fileHeader.Open()
if err != nil {
return nil, false, fmt.Errorf("failed to open uploaded file: %v", err)
}
defer file.Close()
fileData, err := io.ReadAll(file)
if err != nil {
return nil, false, fmt.Errorf("failed to read uploaded file: %v", err)
}
transcriptionInput := &schemas.TranscriptionInput{
File: fileData,
Filename: fileHeader.Filename,
}
transcriptionParams := &schemas.TranscriptionParameters{}
if languageValues := form.Value["language"]; len(languageValues) > 0 && languageValues[0] != "" {
transcriptionParams.Language = &languageValues[0]
}
if promptValues := form.Value["prompt"]; len(promptValues) > 0 && promptValues[0] != "" {
transcriptionParams.Prompt = &promptValues[0]
}
if responseFormatValues := form.Value["response_format"]; len(responseFormatValues) > 0 && responseFormatValues[0] != "" {
transcriptionParams.ResponseFormat = &responseFormatValues[0]
}
if transcriptionParams.ExtraParams == nil {
transcriptionParams.ExtraParams = make(map[string]interface{})
}
for key, value := range form.Value {
if len(value) > 0 && value[0] != "" && !transcriptionParamsKnownFields[key] {
transcriptionParams.ExtraParams[key] = value[0]
}
}
stream := false
if streamValues := form.Value["stream"]; len(streamValues) > 0 && streamValues[0] == "true" {
stream = true
}
bifrostTranscriptionReq := &schemas.BifrostTranscriptionRequest{
Model: modelName,
Provider: schemas.ModelProvider(provider),
Input: transcriptionInput,
Params: transcriptionParams,
}
return bifrostTranscriptionReq, stream, nil
}
// transcription handles POST /v1/audio/transcriptions - Process transcription requests
func (h *CompletionHandler) transcription(ctx *fasthttp.RequestCtx) {
bifrostTranscriptionReq, stream, err := prepareTranscriptionRequest(ctx)
if err != nil {
SendError(ctx, fasthttp.StatusBadRequest, err.Error())
return
}
bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist())
if bifrostCtx == nil {
SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context")
return
}
if stream {
h.handleStreamingTranscriptionRequest(ctx, bifrostTranscriptionReq, bifrostCtx, cancel)
return
}
defer cancel()
resp, bifrostErr := h.client.TranscriptionRequest(bifrostCtx, bifrostTranscriptionReq)
// Handle response
if bifrostErr != nil {
forwardProviderHeadersFromContext(ctx, bifrostCtx)
SendBifrostError(ctx, bifrostErr)
return
}
if resp != nil && resp.ExtraFields.ProviderResponseHeaders != nil {
forwardProviderHeaders(ctx, resp.ExtraFields.ProviderResponseHeaders)
}
if streamLargeResponseIfActive(ctx, bifrostCtx) {
return
}
// Send successful response
SendJSON(ctx, resp)
}
// countTokens handles POST /v1/responses/input_tokens - Process count tokens requests
func (h *CompletionHandler) countTokens(ctx *fasthttp.RequestCtx) {
_, bifrostResponsesReq, err := prepareResponsesRequest(ctx)
if err != nil {
SendError(ctx, fasthttp.StatusBadRequest, err.Error())
return
}
bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist())
if bifrostCtx == nil {
SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context")
return
}
defer cancel()
response, bifrostErr := h.client.CountTokensRequest(bifrostCtx, bifrostResponsesReq)
if bifrostErr != nil {
forwardProviderHeadersFromContext(ctx, bifrostCtx)
SendBifrostError(ctx, bifrostErr)
return
}
forwardProviderHeaders(ctx, response.ExtraFields.ProviderResponseHeaders)
// Send successful response
if streamLargeResponseIfActive(ctx, bifrostCtx) {
return
}
SendJSON(ctx, response)
}
// handleStreamingTextCompletion handles streaming text completion requests using Server-Sent Events (SSE)
func (h *CompletionHandler) handleStreamingTextCompletion(ctx *fasthttp.RequestCtx, req *schemas.BifrostTextCompletionRequest, bifrostCtx *schemas.BifrostContext, cancel context.CancelFunc) {
// Use the cancellable context from ConvertToBifrostContext
// See router.go for detailed explanation of why we need a cancellable context
getStream := func() (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
return h.client.TextCompletionStreamRequest(bifrostCtx, req)
}
h.handleStreamingResponse(ctx, bifrostCtx, getStream, cancel)
}
// handleStreamingChatCompletion handles streaming chat completion requests using Server-Sent Events (SSE)
func (h *CompletionHandler) handleStreamingChatCompletion(ctx *fasthttp.RequestCtx, req *schemas.BifrostChatRequest, bifrostCtx *schemas.BifrostContext, cancel context.CancelFunc) {
// Use the cancellable context from ConvertToBifrostContext
// See router.go for detailed explanation of why we need a cancellable context
getStream := func() (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
return h.client.ChatCompletionStreamRequest(bifrostCtx, req)
}
h.handleStreamingResponse(ctx, bifrostCtx, getStream, cancel)
}
// handleStreamingResponses handles streaming responses requests using Server-Sent Events (SSE)
func (h *CompletionHandler) handleStreamingResponses(ctx *fasthttp.RequestCtx, req *schemas.BifrostResponsesRequest, bifrostCtx *schemas.BifrostContext, cancel context.CancelFunc) {
// Use the cancellable context from ConvertToBifrostContext
// See router.go for detailed explanation of why we need a cancellable context
getStream := func() (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
return h.client.ResponsesStreamRequest(bifrostCtx, req)
}
h.handleStreamingResponse(ctx, bifrostCtx, getStream, cancel)
}
// handleStreamingSpeech handles streaming speech requests using Server-Sent Events (SSE)
func (h *CompletionHandler) handleStreamingSpeech(ctx *fasthttp.RequestCtx, req *schemas.BifrostSpeechRequest, bifrostCtx *schemas.BifrostContext, cancel context.CancelFunc) {
// Use the cancellable context from ConvertToBifrostContext
// See router.go for detailed explanation of why we need a cancellable context
getStream := func() (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
return h.client.SpeechStreamRequest(bifrostCtx, req)
}
h.handleStreamingResponse(ctx, bifrostCtx, getStream, cancel)
}
// handleStreamingTranscriptionRequest handles streaming transcription requests using Server-Sent Events (SSE)
func (h *CompletionHandler) handleStreamingTranscriptionRequest(ctx *fasthttp.RequestCtx, req *schemas.BifrostTranscriptionRequest, bifrostCtx *schemas.BifrostContext, cancel context.CancelFunc) {
// Use the cancellable context from ConvertToBifrostContext
// See router.go for detailed explanation of why we need a cancellable context
getStream := func() (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
return h.client.TranscriptionStreamRequest(bifrostCtx, req)
}
h.handleStreamingResponse(ctx, bifrostCtx, getStream, cancel)
}
// handleStreamingResponse is a generic function to handle streaming responses using Server-Sent Events (SSE)
// The cancel function is called ONLY when client disconnects are detected via write errors.
// Bifrost handles cleanup internally for normal completion and errors, so we only cancel
// upstream streams when write errors indicate the client has disconnected.
func (h *CompletionHandler) handleStreamingResponse(ctx *fasthttp.RequestCtx, bifrostCtx *schemas.BifrostContext, getStream func() (chan *schemas.BifrostStreamChunk, *schemas.BifrostError), cancel context.CancelFunc) {
// Get the streaming channel — called BEFORE setting SSE headers so that
// provider errors return proper HTTP status codes + JSON content type.
stream, bifrostErr := getStream()
if bifrostErr != nil {
cancel()
forwardProviderHeadersFromContext(ctx, bifrostCtx)
SendBifrostError(ctx, bifrostErr)
return
}
// SSE headers set only after successful stream setup
ctx.SetContentType("text/event-stream")
ctx.Response.Header.Set("Cache-Control", "no-cache")
ctx.Response.Header.Set("Connection", "keep-alive")
// Forward provider response headers stored in context by streaming handlers
if headers, ok := bifrostCtx.Value(schemas.BifrostContextKeyProviderResponseHeaders).(map[string]string); ok {
forwardProviderHeaders(ctx, headers)
}
// Signal to tracing middleware that trace completion should be deferred
// The streaming callback will complete the trace after the stream ends
ctx.SetUserValue(schemas.BifrostContextKeyDeferTraceCompletion, true)
// Pre-allocate atomic.Value slot for the transport post-hook completer.
// TransportInterceptorMiddleware stores the completer into this slot after next(ctx)
// returns. The goroutine reads from the closure-captured pointer, avoiding any ctx
// access after the handler returns (fasthttp recycles RequestCtx).
var completerSlot atomic.Value
ctx.SetUserValue(schemas.BifrostContextKeyTransportPostHookCompleter, &completerSlot)
// Get the trace completer function for use in the streaming callback.
// Signature: func([]schemas.PluginLogEntry) — accepts transport plugin logs so it
// never needs to read from ctx.UserValue (ctx may be recycled).
traceCompleter, _ := ctx.UserValue(schemas.BifrostContextKeyTraceCompleter).(func([]schemas.PluginLogEntry))
// Get stream chunk interceptor for plugin hooks
interceptor := h.config.GetStreamChunkInterceptor()
var httpReq *schemas.HTTPRequest
if interceptor != nil {
httpReq = lib.BuildHTTPRequestFromFastHTTP(ctx)
}
// Use SSEStreamReader to bypass fasthttp's internal pipe (fasthttputil.PipeConns)
// which batches multiple SSE events into single TCP segments.
// Each event is delivered individually via a channel, ensuring one HTTP chunk per event.
reader := lib.NewSSEStreamReader()
ctx.Response.SetBodyStream(reader, -1)
// Producer goroutine: processes the stream channel, formats SSE events, sends to reader
go func() {
var transportLogs []schemas.PluginLogEntry
completerRan := false
// runCompleter invokes the transport post-hook completer at most once.
// sendSSEOnError=true emits plugin errors as SSE "event: error" frames so the
// client sees them (happy path, before [DONE]); =false logs server-side only
// (early-return / defer fallback, after stream termination).
runCompleter := func(sendSSEOnError bool) {
if completerRan {
return
}
// Bounded wait for TransportInterceptorMiddleware to publish the completer.
// It calls slot.Store after next(ctx) returns, which races with this goroutine
// on fast/empty streams. 100ms is ample — the store runs a few instructions
// after the handler returns.
var loaded any
deadline := time.Now().Add(100 * time.Millisecond)
for {
if loaded = completerSlot.Load(); loaded != nil {
break
}
if time.Now().After(deadline) {
break
}
time.Sleep(time.Millisecond)
}
if loaded == nil {
return
}
postHookCompleter, ok := loaded.(func() ([]schemas.PluginLogEntry, error))
if !ok {
return
}
completerRan = true
logs, err := postHookCompleter()
if err != nil {
if sendSSEOnError {
errorJSON, marshalErr := sonic.Marshal(map[string]string{"error": err.Error()})
if marshalErr == nil {
reader.SendError(errorJSON)
}
} else {
logger.Warn("transport post-hook failed after stream terminated: %v", err)
}
}
transportLogs = logs
}
defer func() {
schemas.ReleaseHTTPRequest(httpReq)
// Fallback: on early-return paths (client disconnect, interceptor error)
// we never reached the pre-[DONE] invocation, so run it now. Any error is
// logged server-side only — the stream is already closing.
runCompleter(false)
reader.Done()
// Complete the trace after streaming finishes, passing transport plugin logs.
// This ensures all spans (including llm.call) are properly ended before the trace is sent to OTEL.
if traceCompleter != nil {
traceCompleter(transportLogs)
}
}()
var includeEventType bool
var skipDoneMarker bool
// Process streaming responses
for chunk := range stream {
if chunk == nil {
continue
}
includeEventType = false
if chunk.BifrostResponsesStreamResponse != nil ||
chunk.BifrostImageGenerationStreamResponse != nil ||
(chunk.BifrostError != nil && (chunk.BifrostError.ExtraFields.RequestType == schemas.ResponsesStreamRequest || chunk.BifrostError.ExtraFields.RequestType == schemas.ImageGenerationStreamRequest || chunk.BifrostError.ExtraFields.RequestType == schemas.ImageEditStreamRequest)) {
includeEventType = true
}
// Image generation streams don't use [DONE] marker
if chunk.BifrostImageGenerationStreamResponse != nil {
skipDoneMarker = true
}
// Allow plugins to modify/filter the chunk via StreamChunkInterceptor
if interceptor != nil {
var err error
chunk, err = interceptor.InterceptChunk(bifrostCtx, httpReq, chunk)
if err != nil {
if chunk == nil {
errorJSON, marshalErr := sonic.Marshal(map[string]string{"error": err.Error()})
if marshalErr != nil {
cancel() // Payload invalid
return
}
// Return error event and stop streaming
reader.SendError(errorJSON)
cancel()
return
}
// Else add warn log and continue
logger.Warn("%v", err)
}
if chunk == nil {
// Skip chunk if plugin wants to skip it
continue
}
}
// Convert response to JSON
chunkJSON, err := sonic.Marshal(chunk)
if err != nil {
logger.Warn("Failed to marshal streaming response: %v", err)
continue
}
// Format and send as SSE data
var eventType string
if includeEventType {
// For responses and image gen API, use OpenAI-compatible format with event line
if chunk.BifrostResponsesStreamResponse != nil {
eventType = string(chunk.BifrostResponsesStreamResponse.Type)
} else if chunk.BifrostImageGenerationStreamResponse != nil {
eventType = string(chunk.BifrostImageGenerationStreamResponse.Type)
} else if chunk.BifrostError != nil {
eventType = string(schemas.ResponsesStreamResponseTypeError)
}
}
if !reader.SendEvent(eventType, chunkJSON) {
cancel() // Client disconnected, cancel upstream stream
return
}
}
// Run the transport post-hook completer BEFORE the terminal [DONE] marker so
// that any plugin error can still be delivered to the client as an SSE event.
// Post-hooks emitted after [DONE] reach the wire but most clients stop reading
// once they see [DONE], so they'd be silently dropped.
runCompleter(true)
if !includeEventType && !skipDoneMarker {
// Send the [DONE] marker to indicate the end of the stream (only for non-responses/image-gen APIs)
if !reader.SendDone() {
cancel()
return
}
}
// Note: OpenAI responses API doesn't use [DONE] marker, it ends when the stream closes
// Stream completed normally, Bifrost handles cleanup internally
cancel()
}()
}
// validateAudioFile checks if the file size and format are valid
func (h *CompletionHandler) validateAudioFile(fileHeader *multipart.FileHeader) error {
// Check file size
if fileHeader.Size > MaxFileSize {
return fmt.Errorf("file size exceeds maximum limit of %d MB", MaxFileSize/1024/1024)
}
// Get file extension
ext := strings.ToLower(filepath.Ext(fileHeader.Filename))
// Check file extension
validExtensions := map[string]bool{
".flac": true,
".mp3": true,
".mp4": true,
".mpeg": true,
".mpga": true,
".m4a": true,
".ogg": true,
".wav": true,
".webm": true,
}
if !validExtensions[ext] {
return fmt.Errorf("unsupported file format: %s. Supported formats: flac, mp3, mp4, mpeg, mpga, m4a, ogg, wav, webm", ext)
}
// Open file to check MIME type
file, err := fileHeader.Open()
if err != nil {
return fmt.Errorf("failed to open file: %v", err)
}
defer file.Close()
// Read first 512 bytes for MIME type detection
buffer := make([]byte, 512)
_, err = file.Read(buffer)
if err != nil && err != io.EOF {
return fmt.Errorf("failed to read file header: %v", err)
}
// Check MIME type
mimeType := http.DetectContentType(buffer)
validMimeTypes := map[string]bool{
// Primary MIME types
AudioMimeMP3: true, // Covers MP3, MPEG, MPGA
AudioMimeMP4: true,
AudioMimeM4A: true,
AudioMimeOGG: true,
AudioMimeWAV: true,
AudioMimeWEBM: true,
AudioMimeFLAC: true,
AudioMimeFLAC2: true,
// Alternative MIME types
"audio/mpeg3": true,
"audio/x-wav": true,
"audio/vnd.wave": true,
"audio/x-mpeg": true,
"audio/x-mpeg3": true,
"audio/x-mpg": true,
"audio/x-mpegaudio": true,
}
if !validMimeTypes[mimeType] {
return fmt.Errorf("invalid file type: %s. Supported audio formats: flac, mp3, mp4, mpeg, mpga, m4a, ogg, wav, webm", mimeType)
}
// Reset file pointer for subsequent reads
_, err = file.Seek(0, 0)
if err != nil {
return fmt.Errorf("failed to reset file pointer: %v", err)
}
return nil
}
// prepareImageGenerationRequest prepares a BifrostImageGenerationRequest from the HTTP request body
func prepareImageGenerationRequest(ctx *fasthttp.RequestCtx) (*ImageGenerationHTTPRequest, *schemas.BifrostImageGenerationRequest, error) {
var req ImageGenerationHTTPRequest
if err := sonic.Unmarshal(ctx.PostBody(), &req); err != nil {
return nil, nil, fmt.Errorf("invalid request format: %v", err)
}
provider, modelName := schemas.ParseModelString(req.Model, "")
if provider == "" || modelName == "" {
return nil, nil, fmt.Errorf("model should be in provider/model format")
}
if req.ImageGenerationInput == nil || req.Prompt == "" {
return nil, nil, fmt.Errorf("prompt cannot be empty")
}
if req.ImageGenerationParameters == nil {
req.ImageGenerationParameters = &schemas.ImageGenerationParameters{}
}
extraParams, err := extractExtraParams(ctx.PostBody(), imageGenerationParamsKnownFields)
if err != nil {
logger.Warn("Failed to extract extra params: %v", err)
} else {
req.ImageGenerationParameters.ExtraParams = extraParams
}
fallbacks, err := parseFallbacks(req.Fallbacks)
if err != nil {
return nil, nil, err
}
bifrostReq := &schemas.BifrostImageGenerationRequest{
Provider: schemas.ModelProvider(provider),
Model: modelName,
Input: req.ImageGenerationInput,
Params: req.ImageGenerationParameters,
Fallbacks: fallbacks,
}
return &req, bifrostReq, nil
}
// imageGeneration handles POST /v1/images/generations - Processes image generation requests
func (h *CompletionHandler) imageGeneration(ctx *fasthttp.RequestCtx) {
req, bifrostReq, err := prepareImageGenerationRequest(ctx)
if err != nil {
SendError(ctx, fasthttp.StatusBadRequest, err.Error())
return
}
bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist())
if bifrostCtx == nil {
cancel()
SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context")
return
}
// Handle streaming image generation
if req.BifrostParams.Stream != nil && *req.BifrostParams.Stream {
h.handleStreamingImageGeneration(ctx, bifrostReq, bifrostCtx, cancel)
return
}
defer cancel()
// Execute request
resp, bifrostErr := h.client.ImageGenerationRequest(bifrostCtx, bifrostReq)
if bifrostErr != nil {
forwardProviderHeadersFromContext(ctx, bifrostCtx)
SendBifrostError(ctx, bifrostErr)
return
}
if resp != nil && resp.ExtraFields.ProviderResponseHeaders != nil {
forwardProviderHeaders(ctx, resp.ExtraFields.ProviderResponseHeaders)
}
if streamLargeResponseIfActive(ctx, bifrostCtx) {
return
}
SendJSON(ctx, resp)
}
// handleStreamingImageGeneration handles streaming image generation requests using Server-Sent Events (SSE)
func (h *CompletionHandler) handleStreamingImageGeneration(ctx *fasthttp.RequestCtx, req *schemas.BifrostImageGenerationRequest, bifrostCtx *schemas.BifrostContext, cancel context.CancelFunc) {
// Use the cancellable context from ConvertToBifrostContext
// See router.go for detailed explanation of why we need a cancellable context
// Pass the context directly instead of copying to avoid copying lock values
getStream := func() (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
return h.client.ImageGenerationStreamRequest(bifrostCtx, req)
}
h.handleStreamingResponse(ctx, bifrostCtx, getStream, cancel)
}
// prepareImageEditRequest prepares a BifrostImageEditRequest from a multipart form
func prepareImageEditRequest(ctx *fasthttp.RequestCtx) (*ImageEditHTTPRequest, *schemas.BifrostImageEditRequest, error) {
var req ImageEditHTTPRequest
form, err := ctx.MultipartForm()
if err != nil {
return nil, nil, fmt.Errorf("failed to parse multipart form: %v", err)
}
modelValues := form.Value["model"]
if len(modelValues) == 0 || modelValues[0] == "" {
return nil, nil, fmt.Errorf("model is required")
}
req.Model = modelValues[0]
provider, modelName := schemas.ParseModelString(req.Model, "")
if provider == "" || modelName == "" {
return nil, nil, fmt.Errorf("model should be in provider/model format")
}
var editType string
if typeValues := form.Value["type"]; len(typeValues) > 0 && typeValues[0] != "" {
editType = typeValues[0]
}
promptValues := form.Value["prompt"]
var imageFiles []*multipart.FileHeader
if imageFilesArray := form.File["image[]"]; len(imageFilesArray) > 0 {
imageFiles = imageFilesArray
} else if imageFilesSingle := form.File["image"]; len(imageFilesSingle) > 0 {
imageFiles = imageFilesSingle
}
if len(imageFiles) == 0 {
return nil, nil, fmt.Errorf("at least one image is required")
}
images := make([]schemas.ImageInput, 0, len(imageFiles))
for _, fh := range imageFiles {
f, err := fh.Open()
if err != nil {
return nil, nil, fmt.Errorf("failed to open uploaded file: %v", err)
}
fileData, err := io.ReadAll(f)
f.Close()
if err != nil {
return nil, nil, fmt.Errorf("failed to read uploaded file: %v", err)
}
images = append(images, schemas.ImageInput{Image: fileData})
}
prompt := ""
if len(promptValues) > 0 && promptValues[0] != "" {
prompt = promptValues[0]
}
req.ImageEditInput = &schemas.ImageEditInput{
Images: images,
Prompt: prompt,
}
req.ImageEditParameters = &schemas.ImageEditParameters{}
if nValues := form.Value["n"]; len(nValues) > 0 && nValues[0] != "" {
n, err := strconv.Atoi(nValues[0])
if err != nil {
return nil, nil, fmt.Errorf("invalid n value: %v", err)
}
req.ImageEditParameters.N = &n
}
if backgroundValues := form.Value["background"]; len(backgroundValues) > 0 && backgroundValues[0] != "" {
req.ImageEditParameters.Background = &backgroundValues[0]
}
if inputFidelityValues := form.Value["input_fidelity"]; len(inputFidelityValues) > 0 && inputFidelityValues[0] != "" {
req.ImageEditParameters.InputFidelity = &inputFidelityValues[0]
}
if partialImagesValues := form.Value["partial_images"]; len(partialImagesValues) > 0 && partialImagesValues[0] != "" {
partialImages, err := strconv.Atoi(partialImagesValues[0])
if err != nil {
return nil, nil, fmt.Errorf("invalid partial_images value: %v", err)
}
req.ImageEditParameters.PartialImages = &partialImages
}
if sizeValues := form.Value["size"]; len(sizeValues) > 0 && sizeValues[0] != "" {
req.ImageEditParameters.Size = &sizeValues[0]
}
if qualityValues := form.Value["quality"]; len(qualityValues) > 0 && qualityValues[0] != "" {
req.ImageEditParameters.Quality = &qualityValues[0]
}
if outputFormatValues := form.Value["output_format"]; len(outputFormatValues) > 0 && outputFormatValues[0] != "" {
req.ImageEditParameters.OutputFormat = &outputFormatValues[0]
}
if numInferenceStepsValues := form.Value["num_inference_steps"]; len(numInferenceStepsValues) > 0 && numInferenceStepsValues[0] != "" {
numInferenceSteps, err := strconv.Atoi(numInferenceStepsValues[0])
if err != nil {
return nil, nil, fmt.Errorf("invalid num_inference_steps value: %v", err)
}
req.ImageEditParameters.NumInferenceSteps = &numInferenceSteps
}
if seedValues := form.Value["seed"]; len(seedValues) > 0 && seedValues[0] != "" {
seed, err := strconv.Atoi(seedValues[0])
if err != nil {
return nil, nil, fmt.Errorf("invalid seed value: %v", err)
}
req.ImageEditParameters.Seed = &seed
}
if outputCompressionValues := form.Value["output_compression"]; len(outputCompressionValues) > 0 && outputCompressionValues[0] != "" {
outputCompression, err := strconv.Atoi(outputCompressionValues[0])
if err != nil {
return nil, nil, fmt.Errorf("invalid output_compression value: %v", err)
}
req.ImageEditParameters.OutputCompression = &outputCompression
}
if negativePromptValues := form.Value["negative_prompt"]; len(negativePromptValues) > 0 && negativePromptValues[0] != "" {
req.ImageEditParameters.NegativePrompt = &negativePromptValues[0]
}
if responseFormatValues := form.Value["response_format"]; len(responseFormatValues) > 0 && responseFormatValues[0] != "" {
req.ImageEditParameters.ResponseFormat = &responseFormatValues[0]
}
if userValues := form.Value["user"]; len(userValues) > 0 && userValues[0] != "" {
req.ImageEditParameters.User = &userValues[0]
}
if editType != "" {
req.ImageEditParameters.Type = &editType
}
if maskFiles := form.File["mask"]; len(maskFiles) > 0 {
maskFile := maskFiles[0]
f, err := maskFile.Open()
if err != nil {
return nil, nil, fmt.Errorf("failed to open mask file: %v", err)
}
maskData, err := io.ReadAll(f)
f.Close()
if err != nil {
return nil, nil, fmt.Errorf("failed to read mask file: %v", err)
}
req.ImageEditParameters.Mask = maskData
}
if req.ImageEditParameters.ExtraParams == nil {
req.ImageEditParameters.ExtraParams = make(map[string]interface{})
}
for key, value := range form.Value {
if len(value) > 0 && value[0] != "" && !imageEditParamsKnownFields[key] {
req.ImageEditParameters.ExtraParams[key] = value[0]
}
}
if fallbackValues := form.Value["fallbacks"]; len(fallbackValues) > 0 {
req.Fallbacks = fallbackValues
}
if streamValues := form.Value["stream"]; len(streamValues) > 0 && streamValues[0] != "" {
stream := streamValues[0] == "true"
req.Stream = &stream
}
fallbacks, err := parseFallbacks(req.Fallbacks)
if err != nil {
return nil, nil, err
}
bifrostReq := &schemas.BifrostImageEditRequest{
Provider: schemas.ModelProvider(provider),
Model: modelName,
Input: req.ImageEditInput,
Params: req.ImageEditParameters,
Fallbacks: fallbacks,
}
return &req, bifrostReq, nil
}
// imageEdit handles POST /v1/images/edits - Processes image edit requests
func (h *CompletionHandler) imageEdit(ctx *fasthttp.RequestCtx) {
req, bifrostReq, err := prepareImageEditRequest(ctx)
if err != nil {
SendError(ctx, fasthttp.StatusBadRequest, err.Error())
return
}
bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist())
if bifrostCtx == nil {
SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context")
return
}
// Handle streaming image edit
if req.Stream != nil && *req.Stream {
h.handleStreamingImageEditRequest(ctx, bifrostReq, bifrostCtx, cancel)
return
}
defer cancel()
// Execute request
resp, bifrostErr := h.client.ImageEditRequest(bifrostCtx, bifrostReq)
if bifrostErr != nil {
forwardProviderHeadersFromContext(ctx, bifrostCtx)
SendBifrostError(ctx, bifrostErr)
return
}
if resp != nil && resp.ExtraFields.ProviderResponseHeaders != nil {
forwardProviderHeaders(ctx, resp.ExtraFields.ProviderResponseHeaders)
}
if streamLargeResponseIfActive(ctx, bifrostCtx) {
return
}
SendJSON(ctx, resp)
}
// handleStreamingImageEditRequest handles streaming image edit requests using Server-Sent Events (SSE)
func (h *CompletionHandler) handleStreamingImageEditRequest(ctx *fasthttp.RequestCtx, req *schemas.BifrostImageEditRequest, bifrostCtx *schemas.BifrostContext, cancel context.CancelFunc) {
// Use the cancellable context from ConvertToBifrostContext
// See router.go for detailed explanation of why we need a cancellable context
getStream := func() (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
return h.client.ImageEditStreamRequest(bifrostCtx, req)
}
h.handleStreamingResponse(ctx, bifrostCtx, getStream, cancel)
}
// prepareImageVariationRequest prepares a BifrostImageVariationRequest from a multipart form
func prepareImageVariationRequest(ctx *fasthttp.RequestCtx) (*schemas.BifrostImageVariationRequest, error) {
rawBody := ctx.Request.Body()
form, err := ctx.MultipartForm()
if err != nil {
return nil, fmt.Errorf("failed to parse multipart form: %v", err)
}
modelValues := form.Value["model"]
if len(modelValues) == 0 || modelValues[0] == "" {
return nil, fmt.Errorf("model is required")
}
provider, modelName := schemas.ParseModelString(modelValues[0], "")
if provider == "" || modelName == "" {
return nil, fmt.Errorf("model should be in provider/model format")
}
var imageFiles []*multipart.FileHeader
if imageFilesArray := form.File["image[]"]; len(imageFilesArray) > 0 {
imageFiles = imageFilesArray
} else if imageFilesSingle := form.File["image"]; len(imageFilesSingle) > 0 {
imageFiles = imageFilesSingle
}
if len(imageFiles) == 0 {
return nil, fmt.Errorf("at least one image is required")
}
images := make([][]byte, 0, len(imageFiles))
for _, fileHeader := range imageFiles {
file, err := fileHeader.Open()
if err != nil {
return nil, fmt.Errorf("failed to open uploaded file: %v", err)
}
fileData, err := io.ReadAll(file)
file.Close()
if err != nil {
return nil, fmt.Errorf("failed to read uploaded file: %v", err)
}
images = append(images, fileData)
}
variationInput := &schemas.ImageVariationInput{
Image: schemas.ImageInput{
Image: images[0],
},
}
variationParams := &schemas.ImageVariationParameters{}
if nValues := form.Value["n"]; len(nValues) > 0 && nValues[0] != "" {
n, err := strconv.Atoi(nValues[0])
if err != nil {
return nil, fmt.Errorf("invalid n value: %v", err)
}
variationParams.N = &n
}
if responseFormatValues := form.Value["response_format"]; len(responseFormatValues) > 0 && responseFormatValues[0] != "" {
variationParams.ResponseFormat = &responseFormatValues[0]
}
if sizeValues := form.Value["size"]; len(sizeValues) > 0 && sizeValues[0] != "" {
variationParams.Size = &sizeValues[0]
}
if userValues := form.Value["user"]; len(userValues) > 0 && userValues[0] != "" {
variationParams.User = &userValues[0]
}
if variationParams.ExtraParams == nil {
variationParams.ExtraParams = make(map[string]interface{})
}
if len(images) > 1 {
variationParams.ExtraParams["images"] = images[1:]
}
for key, value := range form.Value {
if len(value) > 0 && value[0] != "" && !imageVariationParamsKnownFields[key] {
variationParams.ExtraParams[key] = value[0]
}
}
if fallbackValues := form.Value["fallbacks"]; len(fallbackValues) > 0 {
fallbacks, err := parseFallbacks(fallbackValues)
if err != nil {
return nil, err
}
return &schemas.BifrostImageVariationRequest{
Provider: schemas.ModelProvider(provider),
Model: modelName,
Input: variationInput,
Params: variationParams,
Fallbacks: fallbacks,
RawRequestBody: rawBody,
}, nil
}
return &schemas.BifrostImageVariationRequest{
Provider: schemas.ModelProvider(provider),
Model: modelName,
Input: variationInput,
Params: variationParams,
RawRequestBody: rawBody,
}, nil
}
// imageVariation handles POST /v1/images/variations - Processes image variation requests
func (h *CompletionHandler) imageVariation(ctx *fasthttp.RequestCtx) {
bifrostReq, err := prepareImageVariationRequest(ctx)
if err != nil {
SendError(ctx, fasthttp.StatusBadRequest, err.Error())
return
}
bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist())
if bifrostCtx == nil {
SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context")
return
}
defer cancel()
// Execute request (no streaming for variations)
resp, bifrostErr := h.client.ImageVariationRequest(bifrostCtx, bifrostReq)
if bifrostErr != nil {
forwardProviderHeadersFromContext(ctx, bifrostCtx)
SendBifrostError(ctx, bifrostErr)
return
}
if resp != nil && resp.ExtraFields.ProviderResponseHeaders != nil {
forwardProviderHeaders(ctx, resp.ExtraFields.ProviderResponseHeaders)
}
if streamLargeResponseIfActive(ctx, bifrostCtx) {
return
}
SendJSON(ctx, resp)
}
// videoGeneration handles POST /v1/videos - Processes video generation requests
func (h *CompletionHandler) videoGeneration(ctx *fasthttp.RequestCtx) {
var req VideoGenerationRequest
if err := sonic.Unmarshal(ctx.PostBody(), &req); err != nil {
SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid request format: %v", err))
return
}
// Create BifrostVideoGenerationRequest directly using segregated structure
provider, modelName := schemas.ParseModelString(req.Model, "")
if provider == "" || modelName == "" {
SendError(ctx, fasthttp.StatusBadRequest, "model should be in provider/model format")
return
}
fallbacks, err := parseFallbacks(req.Fallbacks)
if err != nil {
SendError(ctx, fasthttp.StatusBadRequest, err.Error())
return
}
if req.VideoGenerationInput == nil || req.Prompt == "" {
SendError(ctx, fasthttp.StatusBadRequest, "prompt cannot be empty")
return
}
if req.VideoGenerationParameters == nil {
req.VideoGenerationParameters = &schemas.VideoGenerationParameters{}
}
extraParams, err := extractExtraParams(ctx.PostBody(), videoGenerationParamsKnownFields)
if err != nil {
logger.Warn("Failed to extract extra params: %v", err)
} else {
req.VideoGenerationParameters.ExtraParams = extraParams
}
bifrostReq := &schemas.BifrostVideoGenerationRequest{
Provider: schemas.ModelProvider(provider),
Model: modelName,
Input: req.VideoGenerationInput,
Params: req.VideoGenerationParameters,
Fallbacks: fallbacks,
}
bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist())
if bifrostCtx == nil {
cancel()
SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context")
return
}
defer cancel()
resp, bifrostErr := h.client.VideoGenerationRequest(bifrostCtx, bifrostReq)
if bifrostErr != nil {
forwardProviderHeadersFromContext(ctx, bifrostCtx)
SendBifrostError(ctx, bifrostErr)
return
}
if resp != nil && resp.ExtraFields.ProviderResponseHeaders != nil {
forwardProviderHeaders(ctx, resp.ExtraFields.ProviderResponseHeaders)
}
if streamLargeResponseIfActive(ctx, bifrostCtx) {
return
}
SendJSON(ctx, resp)
}
// videoRetrieve handles GET /v1/videos/{video_id} - Retrieve a video generation job
func (h *CompletionHandler) videoRetrieve(ctx *fasthttp.RequestCtx) {
// Get video ID from URL parameter
videoID, ok := ctx.UserValue("video_id").(string)
if !ok || videoID == "" {
SendError(ctx, fasthttp.StatusBadRequest, "video_id is required")
return
}
// Decode URL-encoded video ID
decodedID, err := url.PathUnescape(videoID)
if err != nil {
SendError(ctx, fasthttp.StatusBadRequest, "invalid video_id encoding")
return
}
idParts := strings.SplitN(decodedID, ":", 2)
if len(idParts) != 2 || idParts[0] == "" || idParts[1] == "" {
SendError(ctx, fasthttp.StatusBadRequest, "video_id must be in id:provider format")
return
}
provider := schemas.ModelProvider(idParts[1])
// Build Bifrost video retrieve request
bifrostVideoReq := &schemas.BifrostVideoRetrieveRequest{
Provider: provider,
ID: idParts[0],
}
// Convert context
bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist())
defer cancel()
if bifrostCtx == nil {
SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context")
return
}
resp, bifrostErr := h.client.VideoRetrieveRequest(bifrostCtx, bifrostVideoReq)
if bifrostErr != nil {
forwardProviderHeadersFromContext(ctx, bifrostCtx)
SendBifrostError(ctx, bifrostErr)
return
}
if resp != nil && resp.ExtraFields.ProviderResponseHeaders != nil {
forwardProviderHeaders(ctx, resp.ExtraFields.ProviderResponseHeaders)
}
if streamLargeResponseIfActive(ctx, bifrostCtx) {
return
}
SendJSON(ctx, resp)
}
// videoDownload handles GET /v1/videos/{video_id}/content - Download video content
func (h *CompletionHandler) videoDownload(ctx *fasthttp.RequestCtx) {
// Get video ID from URL parameter
videoID, ok := ctx.UserValue("video_id").(string)
if !ok || videoID == "" {
SendError(ctx, fasthttp.StatusBadRequest, "video_id is required")
return
}
// Decode URL-encoded video ID
decodedID, err := url.PathUnescape(videoID)
if err != nil {
SendError(ctx, fasthttp.StatusBadRequest, "invalid video_id encoding")
return
}
idParts := strings.SplitN(decodedID, ":", 2)
if len(idParts) != 2 || idParts[0] == "" || idParts[1] == "" {
SendError(ctx, fasthttp.StatusBadRequest, "video_id must be in id:provider format")
return
}
// take variant from query parameters
variant := string(ctx.QueryArgs().Peek("variant"))
// Build Bifrost video download request
bifrostVideoReq := &schemas.BifrostVideoDownloadRequest{
Provider: schemas.ModelProvider(idParts[1]),
ID: idParts[0],
}
if variant != "" {
bifrostVideoReq.Variant = schemas.Ptr(schemas.VideoDownloadVariant(variant))
}
// Convert context
bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist())
defer cancel()
if bifrostCtx == nil {
SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context")
return
}
resp, bifrostErr := h.client.VideoDownloadRequest(bifrostCtx, bifrostVideoReq)
if bifrostErr != nil {
forwardProviderHeadersFromContext(ctx, bifrostCtx)
SendBifrostError(ctx, bifrostErr)
return
}
if resp != nil && resp.ExtraFields.ProviderResponseHeaders != nil {
forwardProviderHeaders(ctx, resp.ExtraFields.ProviderResponseHeaders)
}
if streamLargeResponseIfActive(ctx, bifrostCtx) {
return
}
// Set appropriate headers for binary download
ctx.Response.Header.Set("Content-Type", resp.ContentType)
ctx.Response.Header.Set("Content-Length", strconv.Itoa(len(resp.Content)))
ctx.Response.SetBody(resp.Content)
}
// videoList handles GET /v1/videos - List video generation jobs
func (h *CompletionHandler) videoList(ctx *fasthttp.RequestCtx) {
// Get provider from query parameters
provider := string(ctx.QueryArgs().Peek("provider"))
if provider == "" {
SendError(ctx, fasthttp.StatusBadRequest, "provider query parameter is required")
return
}
// Build Bifrost video list request
bifrostVideoReq := &schemas.BifrostVideoListRequest{
Provider: schemas.ModelProvider(provider),
}
// Parse optional query parameters
if afterBytes := ctx.QueryArgs().Peek("after"); len(afterBytes) > 0 {
after := string(afterBytes)
bifrostVideoReq.After = &after
}
if limitBytes := ctx.QueryArgs().Peek("limit"); len(limitBytes) > 0 {
limit, err := strconv.Atoi(string(limitBytes))
if err != nil {
SendError(ctx, fasthttp.StatusBadRequest, "invalid limit parameter")
return
}
bifrostVideoReq.Limit = &limit
}
if orderBytes := ctx.QueryArgs().Peek("order"); len(orderBytes) > 0 {
order := string(orderBytes)
bifrostVideoReq.Order = &order
}
// Convert context
bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist())
defer cancel()
if bifrostCtx == nil {
SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context")
return
}
resp, bifrostErr := h.client.VideoListRequest(bifrostCtx, bifrostVideoReq)
if bifrostErr != nil {
forwardProviderHeadersFromContext(ctx, bifrostCtx)
SendBifrostError(ctx, bifrostErr)
return
}
if resp != nil && resp.ExtraFields.ProviderResponseHeaders != nil {
forwardProviderHeaders(ctx, resp.ExtraFields.ProviderResponseHeaders)
}
if streamLargeResponseIfActive(ctx, bifrostCtx) {
return
}
SendJSON(ctx, resp)
}
// videoDelete handles DELETE /v1/videos/{video_id} - Delete a video generation job
func (h *CompletionHandler) videoDelete(ctx *fasthttp.RequestCtx) {
// Get video ID from URL parameter
videoID, ok := ctx.UserValue("video_id").(string)
if !ok || videoID == "" {
SendError(ctx, fasthttp.StatusBadRequest, "video_id is required")
return
}
// Decode URL-encoded video ID
decodedID, err := url.PathUnescape(videoID)
if err != nil {
SendError(ctx, fasthttp.StatusBadRequest, "invalid video_id encoding")
return
}
idParts := strings.SplitN(decodedID, ":", 2)
if len(idParts) != 2 || idParts[0] == "" || idParts[1] == "" {
SendError(ctx, fasthttp.StatusBadRequest, "video_id must be in id:provider format")
return
}
// Build Bifrost video delete request
bifrostVideoReq := &schemas.BifrostVideoDeleteRequest{
Provider: schemas.ModelProvider(idParts[1]),
ID: idParts[0],
}
// Convert context
bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist())
defer cancel()
if bifrostCtx == nil {
SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context")
return
}
resp, bifrostErr := h.client.VideoDeleteRequest(bifrostCtx, bifrostVideoReq)
if bifrostErr != nil {
forwardProviderHeadersFromContext(ctx, bifrostCtx)
SendBifrostError(ctx, bifrostErr)
return
}
if resp != nil && resp.ExtraFields.ProviderResponseHeaders != nil {
forwardProviderHeaders(ctx, resp.ExtraFields.ProviderResponseHeaders)
}
if streamLargeResponseIfActive(ctx, bifrostCtx) {
return
}
SendJSON(ctx, resp)
}
// videoRemix handles POST /v1/videos/{video_id}/remix - Remix an existing video
func (h *CompletionHandler) videoRemix(ctx *fasthttp.RequestCtx) {
// Get video ID from URL parameter
videoID, ok := ctx.UserValue("video_id").(string)
if !ok || videoID == "" {
SendError(ctx, fasthttp.StatusBadRequest, "video_id is required")
return
}
// Decode URL-encoded video ID
decodedID, err := url.PathUnescape(videoID)
if err != nil {
SendError(ctx, fasthttp.StatusBadRequest, "invalid video_id encoding")
return
}
idParts := strings.SplitN(decodedID, ":", 2)
if len(idParts) != 2 || idParts[0] == "" || idParts[1] == "" {
SendError(ctx, fasthttp.StatusBadRequest, "video_id must be in id:provider format")
return
}
// Parse request body
var req VideoRemixRequest
if err := sonic.Unmarshal(ctx.PostBody(), &req); err != nil {
SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid request format: %v", err))
return
}
// Validate prompt
if req.VideoGenerationInput == nil || req.VideoGenerationInput.Prompt == "" {
SendError(ctx, fasthttp.StatusBadRequest, "prompt is required")
return
}
provider := schemas.ModelProvider(idParts[1])
extraParams, err := extractExtraParams(ctx.PostBody(), videoRemixParamsKnownFields)
if err != nil {
logger.Warn("Failed to extract extra params: %v", err)
} else {
req.ExtraParams = extraParams
}
// Build Bifrost video remix request
bifrostVideoReq := &schemas.BifrostVideoRemixRequest{
Provider: provider,
ID: idParts[0],
Input: &schemas.VideoGenerationInput{
Prompt: req.VideoGenerationInput.Prompt,
},
ExtraParams: req.ExtraParams,
}
// Convert context
bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist())
defer cancel()
if bifrostCtx == nil {
SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context")
return
}
resp, bifrostErr := h.client.VideoRemixRequest(bifrostCtx, bifrostVideoReq)
if bifrostErr != nil {
forwardProviderHeadersFromContext(ctx, bifrostCtx)
SendBifrostError(ctx, bifrostErr)
return
}
if resp != nil && resp.ExtraFields.ProviderResponseHeaders != nil {
forwardProviderHeaders(ctx, resp.ExtraFields.ProviderResponseHeaders)
}
if streamLargeResponseIfActive(ctx, bifrostCtx) {
return
}
SendJSON(ctx, resp)
}
// batchCreate handles POST /v1/batches - Create a new batch job
func (h *CompletionHandler) batchCreate(ctx *fasthttp.RequestCtx) {
var req BatchCreateRequest
if err := sonic.Unmarshal(ctx.PostBody(), &req); err != nil {
SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid request format: %v", err))
return
}
// Parse provider from model string
provider, modelName := schemas.ParseModelString(req.Model, "")
if provider == "" {
SendError(ctx, fasthttp.StatusBadRequest, "model should be in provider/model format or provider must be specified")
return
}
// Validate that at least one of InputFileID or Requests is provided
if req.InputFileID == "" && len(req.Requests) == 0 {
SendError(ctx, fasthttp.StatusBadRequest, "either input_file_id or requests is required")
return
}
// Extract extra params
extraParams, err := extractExtraParams(ctx.PostBody(), batchCreateParamsKnownFields)
if err != nil {
logger.Warn("Failed to extract extra params: %v", err)
}
var model *string
if modelName != "" {
model = schemas.Ptr(modelName)
}
// Build Bifrost batch create request
bifrostBatchReq := &schemas.BifrostBatchCreateRequest{
Provider: schemas.ModelProvider(provider),
Model: model,
InputFileID: req.InputFileID,
Requests: req.Requests,
Endpoint: schemas.BatchEndpoint(req.Endpoint),
CompletionWindow: req.CompletionWindow,
Metadata: req.Metadata,
ExtraParams: extraParams,
}
// Convert context
bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist())
defer cancel()
if bifrostCtx == nil {
SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context")
return
}
resp, bifrostErr := h.client.BatchCreateRequest(bifrostCtx, bifrostBatchReq)
if bifrostErr != nil {
forwardProviderHeadersFromContext(ctx, bifrostCtx)
SendBifrostError(ctx, bifrostErr)
return
}
if resp != nil && resp.ExtraFields.ProviderResponseHeaders != nil {
forwardProviderHeaders(ctx, resp.ExtraFields.ProviderResponseHeaders)
}
if streamLargeResponseIfActive(ctx, bifrostCtx) {
return
}
SendJSON(ctx, resp)
}
// batchList handles GET /v1/batches - List batch jobs
func (h *CompletionHandler) batchList(ctx *fasthttp.RequestCtx) {
// Get provider from query parameters
provider := string(ctx.QueryArgs().Peek("provider"))
if provider == "" {
SendError(ctx, fasthttp.StatusBadRequest, "provider query parameter is required")
return
}
// Parse limit parameter
limit := 0
if limitStr := ctx.QueryArgs().Peek("limit"); len(limitStr) > 0 {
if n, err := strconv.Atoi(string(limitStr)); err == nil && n > 0 {
limit = n
}
}
// Parse pagination parameters
var after, before *string
if afterStr := ctx.QueryArgs().Peek("after"); len(afterStr) > 0 {
s := string(afterStr)
after = &s
}
if beforeStr := ctx.QueryArgs().Peek("before"); len(beforeStr) > 0 {
s := string(beforeStr)
before = &s
}
// Build Bifrost batch list request
bifrostBatchReq := &schemas.BifrostBatchListRequest{
Provider: schemas.ModelProvider(provider),
Limit: limit,
After: after,
BeforeID: before,
}
// Convert context
bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist())
defer cancel()
if bifrostCtx == nil {
SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context")
return
}
resp, bifrostErr := h.client.BatchListRequest(bifrostCtx, bifrostBatchReq)
if bifrostErr != nil {
forwardProviderHeadersFromContext(ctx, bifrostCtx)
SendBifrostError(ctx, bifrostErr)
return
}
if resp != nil && resp.ExtraFields.ProviderResponseHeaders != nil {
forwardProviderHeaders(ctx, resp.ExtraFields.ProviderResponseHeaders)
}
if streamLargeResponseIfActive(ctx, bifrostCtx) {
return
}
SendJSON(ctx, resp)
}
// batchRetrieve handles GET /v1/batches/{batch_id} - Retrieve a batch job
func (h *CompletionHandler) batchRetrieve(ctx *fasthttp.RequestCtx) {
// Get batch ID from URL parameter
batchID, ok := ctx.UserValue("batch_id").(string)
if !ok || batchID == "" {
SendError(ctx, fasthttp.StatusBadRequest, "batch_id is required")
return
}
// Get provider from query parameters
provider := string(ctx.QueryArgs().Peek("provider"))
if provider == "" {
SendError(ctx, fasthttp.StatusBadRequest, "provider query parameter is required")
return
}
// Build Bifrost batch retrieve request
bifrostBatchReq := &schemas.BifrostBatchRetrieveRequest{
Provider: schemas.ModelProvider(provider),
BatchID: batchID,
}
// Convert context
bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist())
defer cancel()
if bifrostCtx == nil {
SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context")
return
}
resp, bifrostErr := h.client.BatchRetrieveRequest(bifrostCtx, bifrostBatchReq)
if bifrostErr != nil {
forwardProviderHeadersFromContext(ctx, bifrostCtx)
SendBifrostError(ctx, bifrostErr)
return
}
if resp != nil && resp.ExtraFields.ProviderResponseHeaders != nil {
forwardProviderHeaders(ctx, resp.ExtraFields.ProviderResponseHeaders)
}
if streamLargeResponseIfActive(ctx, bifrostCtx) {
return
}
SendJSON(ctx, resp)
}
// batchCancel handles POST /v1/batches/{batch_id}/cancel - Cancel a batch job
func (h *CompletionHandler) batchCancel(ctx *fasthttp.RequestCtx) {
// Get batch ID from URL parameter
batchID := ctx.UserValue("batch_id").(string)
if batchID == "" {
SendError(ctx, fasthttp.StatusBadRequest, "batch_id is required")
return
}
// Get provider from query parameters
provider := string(ctx.QueryArgs().Peek("provider"))
if provider == "" {
SendError(ctx, fasthttp.StatusBadRequest, "provider query parameter is required")
return
}
// Build Bifrost batch cancel request
bifrostBatchReq := &schemas.BifrostBatchCancelRequest{
Provider: schemas.ModelProvider(provider),
BatchID: batchID,
}
// Convert context
bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist())
defer cancel()
if bifrostCtx == nil {
SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context")
return
}
resp, bifrostErr := h.client.BatchCancelRequest(bifrostCtx, bifrostBatchReq)
if bifrostErr != nil {
forwardProviderHeadersFromContext(ctx, bifrostCtx)
SendBifrostError(ctx, bifrostErr)
return
}
if resp != nil && resp.ExtraFields.ProviderResponseHeaders != nil {
forwardProviderHeaders(ctx, resp.ExtraFields.ProviderResponseHeaders)
}
if streamLargeResponseIfActive(ctx, bifrostCtx) {
return
}
SendJSON(ctx, resp)
}
// batchResults handles GET /v1/batches/{batch_id}/results - Get batch results
func (h *CompletionHandler) batchResults(ctx *fasthttp.RequestCtx) {
// Get batch ID from URL parameter
batchID := ctx.UserValue("batch_id").(string)
if batchID == "" {
SendError(ctx, fasthttp.StatusBadRequest, "batch_id is required")
return
}
// Get provider from query parameters
provider := string(ctx.QueryArgs().Peek("provider"))
if provider == "" {
SendError(ctx, fasthttp.StatusBadRequest, "provider query parameter is required")
return
}
// Build Bifrost batch results request
bifrostBatchReq := &schemas.BifrostBatchResultsRequest{
Provider: schemas.ModelProvider(provider),
BatchID: batchID,
}
// Convert context
bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist())
defer cancel()
if bifrostCtx == nil {
SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context")
return
}
resp, bifrostErr := h.client.BatchResultsRequest(bifrostCtx, bifrostBatchReq)
if bifrostErr != nil {
forwardProviderHeadersFromContext(ctx, bifrostCtx)
SendBifrostError(ctx, bifrostErr)
return
}
if resp != nil && resp.ExtraFields.ProviderResponseHeaders != nil {
forwardProviderHeaders(ctx, resp.ExtraFields.ProviderResponseHeaders)
}
if streamLargeResponseIfActive(ctx, bifrostCtx) {
return
}
SendJSON(ctx, resp)
}
// fileUpload handles POST /v1/files - Upload a file
func (h *CompletionHandler) fileUpload(ctx *fasthttp.RequestCtx) {
// Parse multipart form
form, err := ctx.MultipartForm()
if err != nil {
SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Failed to parse multipart form: %v", err))
return
}
// Get provider from query parameters or header
provider := string(ctx.QueryArgs().Peek("provider"))
if provider == "" {
// Try to get from header (for OpenAI SDK compatibility)
provider = string(ctx.Request.Header.Peek("x-model-provider"))
// Try to get from extra_body
if provider == "" && len(form.Value["provider"]) > 0 {
provider = string(form.Value["provider"][0])
}
}
if provider == "" {
SendError(ctx, fasthttp.StatusBadRequest, "provider query parameter or x-model-provider header is required")
return
}
// Extract purpose (required)
purposeValues := form.Value["purpose"]
if len(purposeValues) == 0 || purposeValues[0] == "" {
SendError(ctx, fasthttp.StatusBadRequest, "purpose is required")
return
}
purpose := purposeValues[0]
// Extract file (required)
fileHeaders := form.File["file"]
if len(fileHeaders) == 0 {
SendError(ctx, fasthttp.StatusBadRequest, "file is required")
return
}
fileHeader := fileHeaders[0]
// Open and read the file
file, err := fileHeader.Open()
if err != nil {
SendError(ctx, fasthttp.StatusInternalServerError, "Internal Server Error")
return
}
defer file.Close()
// Read file data
fileData, err := io.ReadAll(file)
if err != nil {
SendError(ctx, fasthttp.StatusInternalServerError, "Internal Server Error")
return
}
// Build Bifrost file upload request
bifrostFileReq := &schemas.BifrostFileUploadRequest{
Provider: schemas.ModelProvider(provider),
File: fileData,
Filename: fileHeader.Filename,
Purpose: schemas.FilePurpose(purpose),
}
// Convert context
bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist())
defer cancel()
if bifrostCtx == nil {
SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context")
return
}
resp, bifrostErr := h.client.FileUploadRequest(bifrostCtx, bifrostFileReq)
if bifrostErr != nil {
forwardProviderHeadersFromContext(ctx, bifrostCtx)
SendBifrostError(ctx, bifrostErr)
return
}
if resp != nil && resp.ExtraFields.ProviderResponseHeaders != nil {
forwardProviderHeaders(ctx, resp.ExtraFields.ProviderResponseHeaders)
}
if streamLargeResponseIfActive(ctx, bifrostCtx) {
return
}
SendJSON(ctx, resp)
}
// fileList handles GET /v1/files - List files
func (h *CompletionHandler) fileList(ctx *fasthttp.RequestCtx) {
// Get provider from query parameters
provider := string(ctx.QueryArgs().Peek("x-model-provider"))
if provider == "" {
// Try to get from header
provider = string(ctx.Request.Header.Peek("x-model-provider"))
if provider == "" {
SendError(ctx, fasthttp.StatusBadRequest, "x-model-provider query parameter or x-model-provider header is required")
return
}
}
// Parse optional parameters
purpose := string(ctx.QueryArgs().Peek("purpose"))
limit := 0
if limitStr := ctx.QueryArgs().Peek("limit"); len(limitStr) > 0 {
if n, err := strconv.Atoi(string(limitStr)); err == nil && n > 0 {
limit = n
}
}
var after, order *string
if afterStr := ctx.QueryArgs().Peek("after"); len(afterStr) > 0 {
s := string(afterStr)
after = &s
}
if orderStr := ctx.QueryArgs().Peek("order"); len(orderStr) > 0 {
s := string(orderStr)
order = &s
}
// Build Bifrost file list request
bifrostFileReq := &schemas.BifrostFileListRequest{
Provider: schemas.ModelProvider(provider),
Purpose: schemas.FilePurpose(purpose),
Limit: limit,
After: after,
Order: order,
}
// Convert context
bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist())
defer cancel()
if bifrostCtx == nil {
SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context")
return
}
resp, bifrostErr := h.client.FileListRequest(bifrostCtx, bifrostFileReq)
if bifrostErr != nil {
forwardProviderHeadersFromContext(ctx, bifrostCtx)
SendBifrostError(ctx, bifrostErr)
return
}
if resp != nil && resp.ExtraFields.ProviderResponseHeaders != nil {
forwardProviderHeaders(ctx, resp.ExtraFields.ProviderResponseHeaders)
}
if streamLargeResponseIfActive(ctx, bifrostCtx) {
return
}
SendJSON(ctx, resp)
}
// fileRetrieve handles GET /v1/files/{file_id} - Retrieve file metadata
func (h *CompletionHandler) fileRetrieve(ctx *fasthttp.RequestCtx) {
// Get file ID from URL parameter
fileID := ctx.UserValue("file_id").(string)
if fileID == "" {
SendError(ctx, fasthttp.StatusBadRequest, "file_id is required")
return
}
// Get provider from query parameters
provider := string(ctx.QueryArgs().Peek("provider"))
if provider == "" {
SendError(ctx, fasthttp.StatusBadRequest, "provider query parameter is required")
return
}
// Build Bifrost file retrieve request
bifrostFileReq := &schemas.BifrostFileRetrieveRequest{
Provider: schemas.ModelProvider(provider),
FileID: fileID,
}
// Convert context
bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist())
defer cancel()
if bifrostCtx == nil {
SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context")
return
}
resp, bifrostErr := h.client.FileRetrieveRequest(bifrostCtx, bifrostFileReq)
if bifrostErr != nil {
forwardProviderHeadersFromContext(ctx, bifrostCtx)
SendBifrostError(ctx, bifrostErr)
return
}
if resp != nil && resp.ExtraFields.ProviderResponseHeaders != nil {
forwardProviderHeaders(ctx, resp.ExtraFields.ProviderResponseHeaders)
}
if streamLargeResponseIfActive(ctx, bifrostCtx) {
return
}
SendJSON(ctx, resp)
}
// fileDelete handles DELETE /v1/files/{file_id} - Delete a file
func (h *CompletionHandler) fileDelete(ctx *fasthttp.RequestCtx) {
// Get file ID from URL parameter
fileID := ctx.UserValue("file_id").(string)
if fileID == "" {
SendError(ctx, fasthttp.StatusBadRequest, "file_id is required")
return
}
// Get provider from query parameters
provider := string(ctx.QueryArgs().Peek("provider"))
if provider == "" {
SendError(ctx, fasthttp.StatusBadRequest, "provider query parameter is required")
return
}
// Build Bifrost file delete request
bifrostFileReq := &schemas.BifrostFileDeleteRequest{
Provider: schemas.ModelProvider(provider),
FileID: fileID,
}
// Convert context
bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist())
defer cancel()
if bifrostCtx == nil {
SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context")
return
}
resp, bifrostErr := h.client.FileDeleteRequest(bifrostCtx, bifrostFileReq)
if bifrostErr != nil {
forwardProviderHeadersFromContext(ctx, bifrostCtx)
SendBifrostError(ctx, bifrostErr)
return
}
if resp != nil && resp.ExtraFields.ProviderResponseHeaders != nil {
forwardProviderHeaders(ctx, resp.ExtraFields.ProviderResponseHeaders)
}
if streamLargeResponseIfActive(ctx, bifrostCtx) {
return
}
SendJSON(ctx, resp)
}
// fileContent handles GET /v1/files/{file_id}/content - Download file content
func (h *CompletionHandler) fileContent(ctx *fasthttp.RequestCtx) {
// Get file ID from URL parameter
fileID := ctx.UserValue("file_id").(string)
if fileID == "" {
SendError(ctx, fasthttp.StatusBadRequest, "file_id is required")
return
}
// Get provider from query parameters
provider := string(ctx.QueryArgs().Peek("provider"))
if provider == "" {
SendError(ctx, fasthttp.StatusBadRequest, "provider query parameter is required")
return
}
// Build Bifrost file content request
bifrostFileReq := &schemas.BifrostFileContentRequest{
Provider: schemas.ModelProvider(provider),
FileID: fileID,
}
// Convert context
bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist())
defer cancel()
if bifrostCtx == nil {
SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context")
return
}
resp, bifrostErr := h.client.FileContentRequest(bifrostCtx, bifrostFileReq)
if bifrostErr != nil {
forwardProviderHeadersFromContext(ctx, bifrostCtx)
SendBifrostError(ctx, bifrostErr)
return
}
if streamLargeResponseIfActive(ctx, bifrostCtx) {
return
}
// Set appropriate headers for file download
ctx.Response.Header.Set("Content-Type", resp.ContentType)
ctx.Response.Header.Set("Content-Length", strconv.Itoa(len(resp.Content)))
ctx.Response.SetBody(resp.Content)
}
// containerCreate handles POST /v1/containers - Create a new container
func (h *CompletionHandler) containerCreate(ctx *fasthttp.RequestCtx) {
var req ContainerCreateRequest
if err := sonic.Unmarshal(ctx.PostBody(), &req); err != nil {
SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid request format: %v", err))
return
}
// Validate required fields
if req.Provider == "" {
SendError(ctx, fasthttp.StatusBadRequest, "provider is required")
return
}
if req.Name == "" {
SendError(ctx, fasthttp.StatusBadRequest, "name is required")
return
}
// Extract extra params
extraParams, err := extractExtraParams(ctx.PostBody(), containerCreateParamsKnownFields)
if err != nil {
logger.Warn("Failed to extract extra params: %v", err)
}
// Build Bifrost container create request
bifrostContainerReq := &schemas.BifrostContainerCreateRequest{
Provider: schemas.ModelProvider(req.Provider),
Name: req.Name,
ExpiresAfter: req.ExpiresAfter,
FileIDs: req.FileIDs,
MemoryLimit: req.MemoryLimit,
Metadata: req.Metadata,
ExtraParams: extraParams,
}
// Convert context
bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist())
defer cancel()
if bifrostCtx == nil {
SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context")
return
}
enableRawRequestResponseForContainer(bifrostCtx)
resp, bifrostErr := h.client.ContainerCreateRequest(bifrostCtx, bifrostContainerReq)
if bifrostErr != nil {
forwardProviderHeadersFromContext(ctx, bifrostCtx)
SendBifrostError(ctx, bifrostErr)
return
}
if resp != nil && resp.ExtraFields.ProviderResponseHeaders != nil {
forwardProviderHeaders(ctx, resp.ExtraFields.ProviderResponseHeaders)
}
if streamLargeResponseIfActive(ctx, bifrostCtx) {
return
}
SendJSON(ctx, resp)
}
// containerList handles GET /v1/containers - List containers
func (h *CompletionHandler) containerList(ctx *fasthttp.RequestCtx) {
// Get provider from query parameters
provider := string(ctx.QueryArgs().Peek("provider"))
if provider == "" {
SendError(ctx, fasthttp.StatusBadRequest, "provider query parameter is required")
return
}
// Parse limit parameter
limit := 0
if limitStr := ctx.QueryArgs().Peek("limit"); len(limitStr) > 0 {
if n, err := strconv.Atoi(string(limitStr)); err == nil && n > 0 {
limit = n
}
}
// Parse pagination parameters
var after, order *string
if afterStr := ctx.QueryArgs().Peek("after"); len(afterStr) > 0 {
after = bifrost.Ptr(string(afterStr))
}
if orderStr := ctx.QueryArgs().Peek("order"); len(orderStr) > 0 {
order = bifrost.Ptr(string(orderStr))
}
// Build Bifrost container list request
bifrostContainerReq := &schemas.BifrostContainerListRequest{
Provider: schemas.ModelProvider(provider),
Limit: limit,
After: after,
Order: order,
}
// Convert context
bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist())
defer cancel()
if bifrostCtx == nil {
SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context")
return
}
enableRawRequestResponseForContainer(bifrostCtx)
resp, bifrostErr := h.client.ContainerListRequest(bifrostCtx, bifrostContainerReq)
if bifrostErr != nil {
forwardProviderHeadersFromContext(ctx, bifrostCtx)
SendBifrostError(ctx, bifrostErr)
return
}
if resp != nil && resp.ExtraFields.ProviderResponseHeaders != nil {
forwardProviderHeaders(ctx, resp.ExtraFields.ProviderResponseHeaders)
}
if streamLargeResponseIfActive(ctx, bifrostCtx) {
return
}
SendJSON(ctx, resp)
}
// containerRetrieve handles GET /v1/containers/{container_id} - Retrieve a container
func (h *CompletionHandler) containerRetrieve(ctx *fasthttp.RequestCtx) {
// Get container ID from URL parameter
containerID, ok := ctx.UserValue("container_id").(string)
if !ok || containerID == "" {
SendError(ctx, fasthttp.StatusBadRequest, "container_id is required")
return
}
// Get provider from query parameters
provider := string(ctx.QueryArgs().Peek("provider"))
if provider == "" {
SendError(ctx, fasthttp.StatusBadRequest, "provider query parameter is required")
return
}
// Build Bifrost container retrieve request
bifrostContainerReq := &schemas.BifrostContainerRetrieveRequest{
Provider: schemas.ModelProvider(provider),
ContainerID: containerID,
}
// Convert context
bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist())
defer cancel()
if bifrostCtx == nil {
SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context")
return
}
enableRawRequestResponseForContainer(bifrostCtx)
resp, bifrostErr := h.client.ContainerRetrieveRequest(bifrostCtx, bifrostContainerReq)
if bifrostErr != nil {
forwardProviderHeadersFromContext(ctx, bifrostCtx)
SendBifrostError(ctx, bifrostErr)
return
}
if resp != nil && resp.ExtraFields.ProviderResponseHeaders != nil {
forwardProviderHeaders(ctx, resp.ExtraFields.ProviderResponseHeaders)
}
if streamLargeResponseIfActive(ctx, bifrostCtx) {
return
}
SendJSON(ctx, resp)
}
// containerDelete handles DELETE /v1/containers/{container_id} - Delete a container
func (h *CompletionHandler) containerDelete(ctx *fasthttp.RequestCtx) {
// Get container ID from URL parameter
containerID, ok := ctx.UserValue("container_id").(string)
if !ok || containerID == "" {
SendError(ctx, fasthttp.StatusBadRequest, "container_id is required")
return
}
// Get provider from query parameters
provider := string(ctx.QueryArgs().Peek("provider"))
if provider == "" {
SendError(ctx, fasthttp.StatusBadRequest, "provider query parameter is required")
return
}
// Build Bifrost container delete request
bifrostContainerReq := &schemas.BifrostContainerDeleteRequest{
Provider: schemas.ModelProvider(provider),
ContainerID: containerID,
}
// Convert context
bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist())
defer cancel()
if bifrostCtx == nil {
SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context")
return
}
enableRawRequestResponseForContainer(bifrostCtx)
resp, bifrostErr := h.client.ContainerDeleteRequest(bifrostCtx, bifrostContainerReq)
if bifrostErr != nil {
forwardProviderHeadersFromContext(ctx, bifrostCtx)
SendBifrostError(ctx, bifrostErr)
return
}
if resp != nil && resp.ExtraFields.ProviderResponseHeaders != nil {
forwardProviderHeaders(ctx, resp.ExtraFields.ProviderResponseHeaders)
}
if streamLargeResponseIfActive(ctx, bifrostCtx) {
return
}
SendJSON(ctx, resp)
}
// =============================================================================
// CONTAINER FILES HANDLERS
// =============================================================================
// containerFileCreate handles POST /v1/containers/{container_id}/files - Create a file in a container
func (h *CompletionHandler) containerFileCreate(ctx *fasthttp.RequestCtx) {
// Get container ID from URL parameter
containerID, ok := ctx.UserValue("container_id").(string)
if !ok || containerID == "" {
SendError(ctx, fasthttp.StatusBadRequest, "container_id is required")
return
}
// Get provider from query parameters
provider := string(ctx.QueryArgs().Peek("provider"))
if provider == "" {
SendError(ctx, fasthttp.StatusBadRequest, "provider query parameter is required")
return
}
// Build Bifrost container file create request
bifrostContainerFileReq := &schemas.BifrostContainerFileCreateRequest{
Provider: schemas.ModelProvider(provider),
ContainerID: containerID,
}
// Check if this is a multipart request or JSON request
contentType := string(ctx.Request.Header.ContentType())
if strings.HasPrefix(contentType, "multipart/form-data") {
// Handle multipart file upload
fileHeader, err := ctx.FormFile("file")
if err != nil {
SendError(ctx, fasthttp.StatusBadRequest, "file is required for multipart upload")
return
}
file, err := fileHeader.Open()
if err != nil {
SendError(ctx, fasthttp.StatusInternalServerError, "Internal Server Error")
return
}
defer file.Close()
fileContent, err := io.ReadAll(file)
if err != nil {
SendError(ctx, fasthttp.StatusInternalServerError, "Internal Server Error")
return
}
bifrostContainerFileReq.File = fileContent
// Extract optional file_path from multipart form
if filePath := ctx.FormValue("file_path"); len(filePath) > 0 {
bifrostContainerFileReq.Path = bifrost.Ptr(string(filePath))
}
} else {
// Handle JSON request with file_id
var reqBody struct {
FileID string `json:"file_id"`
FilePath string `json:"file_path,omitempty"`
}
if err := sonic.Unmarshal(ctx.PostBody(), &reqBody); err != nil {
SendError(ctx, fasthttp.StatusBadRequest, "Invalid JSON body")
return
}
if reqBody.FileID == "" {
SendError(ctx, fasthttp.StatusBadRequest, "file_id is required in JSON body")
return
}
bifrostContainerFileReq.FileID = bifrost.Ptr(reqBody.FileID)
if reqBody.FilePath != "" {
bifrostContainerFileReq.Path = bifrost.Ptr(reqBody.FilePath)
}
}
// Convert context
bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist())
defer cancel()
if bifrostCtx == nil {
SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context")
return
}
enableRawRequestResponseForContainer(bifrostCtx)
resp, bifrostErr := h.client.ContainerFileCreateRequest(bifrostCtx, bifrostContainerFileReq)
if bifrostErr != nil {
forwardProviderHeadersFromContext(ctx, bifrostCtx)
SendBifrostError(ctx, bifrostErr)
return
}
if resp != nil && resp.ExtraFields.ProviderResponseHeaders != nil {
forwardProviderHeaders(ctx, resp.ExtraFields.ProviderResponseHeaders)
}
if streamLargeResponseIfActive(ctx, bifrostCtx) {
return
}
SendJSON(ctx, resp)
}
// containerFileList handles GET /v1/containers/{container_id}/files - List files in a container
func (h *CompletionHandler) containerFileList(ctx *fasthttp.RequestCtx) {
// Get container ID from URL parameter
containerID, ok := ctx.UserValue("container_id").(string)
if !ok || containerID == "" {
SendError(ctx, fasthttp.StatusBadRequest, "container_id is required")
return
}
// Get provider from query parameters
provider := string(ctx.QueryArgs().Peek("provider"))
if provider == "" {
SendError(ctx, fasthttp.StatusBadRequest, "provider query parameter is required")
return
}
// Build Bifrost container file list request
bifrostContainerFileReq := &schemas.BifrostContainerFileListRequest{
Provider: schemas.ModelProvider(provider),
ContainerID: containerID,
}
// Parse pagination parameters
if limit := ctx.QueryArgs().Peek("limit"); len(limit) > 0 {
if limitInt, err := strconv.Atoi(string(limit)); err == nil && limitInt > 0 {
bifrostContainerFileReq.Limit = limitInt
}
}
if after := string(ctx.QueryArgs().Peek("after")); after != "" {
bifrostContainerFileReq.After = bifrost.Ptr(after)
}
if order := string(ctx.QueryArgs().Peek("order")); order != "" {
bifrostContainerFileReq.Order = bifrost.Ptr(order)
}
// Convert context
bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist())
defer cancel()
if bifrostCtx == nil {
SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context")
return
}
enableRawRequestResponseForContainer(bifrostCtx)
resp, bifrostErr := h.client.ContainerFileListRequest(bifrostCtx, bifrostContainerFileReq)
if bifrostErr != nil {
forwardProviderHeadersFromContext(ctx, bifrostCtx)
SendBifrostError(ctx, bifrostErr)
return
}
if resp != nil && resp.ExtraFields.ProviderResponseHeaders != nil {
forwardProviderHeaders(ctx, resp.ExtraFields.ProviderResponseHeaders)
}
if streamLargeResponseIfActive(ctx, bifrostCtx) {
return
}
SendJSON(ctx, resp)
}
// containerFileRetrieve handles GET /v1/containers/{container_id}/files/{file_id} - Retrieve a file from a container
func (h *CompletionHandler) containerFileRetrieve(ctx *fasthttp.RequestCtx) {
// Get container ID from URL parameter
containerID, ok := ctx.UserValue("container_id").(string)
if !ok || containerID == "" {
SendError(ctx, fasthttp.StatusBadRequest, "container_id is required")
return
}
// Get file ID from URL parameter
fileID, ok := ctx.UserValue("file_id").(string)
if !ok || fileID == "" {
SendError(ctx, fasthttp.StatusBadRequest, "file_id is required")
return
}
// Get provider from query parameters
provider := string(ctx.QueryArgs().Peek("provider"))
if provider == "" {
SendError(ctx, fasthttp.StatusBadRequest, "provider query parameter is required")
return
}
// Build Bifrost container file retrieve request
bifrostContainerFileReq := &schemas.BifrostContainerFileRetrieveRequest{
Provider: schemas.ModelProvider(provider),
ContainerID: containerID,
FileID: fileID,
}
// Convert context
bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist())
defer cancel()
if bifrostCtx == nil {
SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context")
return
}
enableRawRequestResponseForContainer(bifrostCtx)
resp, bifrostErr := h.client.ContainerFileRetrieveRequest(bifrostCtx, bifrostContainerFileReq)
if bifrostErr != nil {
forwardProviderHeadersFromContext(ctx, bifrostCtx)
SendBifrostError(ctx, bifrostErr)
return
}
if resp != nil && resp.ExtraFields.ProviderResponseHeaders != nil {
forwardProviderHeaders(ctx, resp.ExtraFields.ProviderResponseHeaders)
}
if streamLargeResponseIfActive(ctx, bifrostCtx) {
return
}
SendJSON(ctx, resp)
}
// containerFileContent handles GET /v1/containers/{container_id}/files/{file_id}/content - Retrieve file content from a container
func (h *CompletionHandler) containerFileContent(ctx *fasthttp.RequestCtx) {
// Get container ID from URL parameter
containerID, ok := ctx.UserValue("container_id").(string)
if !ok || containerID == "" {
SendError(ctx, fasthttp.StatusBadRequest, "container_id is required")
return
}
// Get file ID from URL parameter
fileID, ok := ctx.UserValue("file_id").(string)
if !ok || fileID == "" {
SendError(ctx, fasthttp.StatusBadRequest, "file_id is required")
return
}
// Get provider from query parameters
provider := string(ctx.QueryArgs().Peek("provider"))
if provider == "" {
SendError(ctx, fasthttp.StatusBadRequest, "provider query parameter is required")
return
}
// Build Bifrost container file content request
bifrostContainerFileReq := &schemas.BifrostContainerFileContentRequest{
Provider: schemas.ModelProvider(provider),
ContainerID: containerID,
FileID: fileID,
}
// Convert context
bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist())
defer cancel()
if bifrostCtx == nil {
SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context")
return
}
enableRawRequestResponseForContainer(bifrostCtx)
resp, bifrostErr := h.client.ContainerFileContentRequest(bifrostCtx, bifrostContainerFileReq)
if bifrostErr != nil {
forwardProviderHeadersFromContext(ctx, bifrostCtx)
SendBifrostError(ctx, bifrostErr)
return
}
if streamLargeResponseIfActive(ctx, bifrostCtx) {
return
}
// Send binary content with appropriate content type
ctx.SetContentType(resp.ContentType)
ctx.SetBody(resp.Content)
}
// containerFileDelete handles DELETE /v1/containers/{container_id}/files/{file_id} - Delete a file from a container
func (h *CompletionHandler) containerFileDelete(ctx *fasthttp.RequestCtx) {
// Get container ID from URL parameter
containerID, ok := ctx.UserValue("container_id").(string)
if !ok || containerID == "" {
SendError(ctx, fasthttp.StatusBadRequest, "container_id is required")
return
}
// Get file ID from URL parameter
fileID, ok := ctx.UserValue("file_id").(string)
if !ok || fileID == "" {
SendError(ctx, fasthttp.StatusBadRequest, "file_id is required")
return
}
// Get provider from query parameters
provider := string(ctx.QueryArgs().Peek("provider"))
if provider == "" {
SendError(ctx, fasthttp.StatusBadRequest, "provider query parameter is required")
return
}
// Build Bifrost container file delete request
bifrostContainerFileReq := &schemas.BifrostContainerFileDeleteRequest{
Provider: schemas.ModelProvider(provider),
ContainerID: containerID,
FileID: fileID,
}
// Convert context
bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys(), h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist())
defer cancel()
if bifrostCtx == nil {
SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context")
return
}
enableRawRequestResponseForContainer(bifrostCtx)
resp, bifrostErr := h.client.ContainerFileDeleteRequest(bifrostCtx, bifrostContainerFileReq)
if bifrostErr != nil {
forwardProviderHeadersFromContext(ctx, bifrostCtx)
SendBifrostError(ctx, bifrostErr)
return
}
if resp != nil && resp.ExtraFields.ProviderResponseHeaders != nil {
forwardProviderHeaders(ctx, resp.ExtraFields.ProviderResponseHeaders)
}
if streamLargeResponseIfActive(ctx, bifrostCtx) {
return
}
SendJSON(ctx, resp)
}