7066 lines
252 KiB
Go
7066 lines
252 KiB
Go
// Package openai provides the OpenAI provider implementation for the Bifrost framework.
|
|
package openai
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/base64"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"maps"
|
|
"mime/multipart"
|
|
"net/http"
|
|
"net/url"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/bytedance/sonic"
|
|
|
|
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
|
|
schemas "github.com/maximhq/bifrost/core/schemas"
|
|
"github.com/valyala/fasthttp"
|
|
)
|
|
|
|
// OpenAIProvider implements the Provider interface for OpenAI's GPT API.
|
|
type OpenAIProvider struct {
|
|
logger schemas.Logger // Logger for provider operations
|
|
client *fasthttp.Client // HTTP client for unary API requests (ReadTimeout bounds overall response)
|
|
streamingClient *fasthttp.Client // HTTP client for streaming API requests (no ReadTimeout; idle governed by NewIdleTimeoutReader)
|
|
networkConfig schemas.NetworkConfig // Network configuration including extra headers
|
|
sendBackRawRequest bool // Whether to include raw request in BifrostResponse
|
|
sendBackRawResponse bool // Whether to include raw response in BifrostResponse
|
|
customProviderConfig *schemas.CustomProviderConfig // Custom provider config
|
|
disableStore bool // Whether to force store=false on outgoing requests
|
|
}
|
|
|
|
// NewOpenAIProvider creates a new OpenAI provider instance.
|
|
// It initializes the HTTP client with the provided configuration and sets up response pools.
|
|
// The client is configured with timeouts, concurrency limits, and optional proxy settings.
|
|
func NewOpenAIProvider(config *schemas.ProviderConfig, logger schemas.Logger) *OpenAIProvider {
|
|
config.CheckAndSetDefaults()
|
|
|
|
requestTimeout := time.Second * time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds)
|
|
client := &fasthttp.Client{
|
|
ReadTimeout: requestTimeout,
|
|
WriteTimeout: requestTimeout,
|
|
MaxConnsPerHost: config.NetworkConfig.MaxConnsPerHost,
|
|
MaxIdleConnDuration: 30 * time.Second,
|
|
MaxConnWaitTimeout: requestTimeout,
|
|
MaxConnDuration: time.Second * time.Duration(schemas.DefaultMaxConnDurationInSeconds),
|
|
ConnPoolStrategy: fasthttp.FIFO,
|
|
}
|
|
|
|
// // Pre-warm response pools
|
|
// for range config.ConcurrencyAndBufferSize.Concurrency {
|
|
// openAIResponsePool.Put(&schemas.BifrostResponse{})
|
|
// }
|
|
|
|
// Configure proxy and retry policy
|
|
client = providerUtils.ConfigureProxy(client, config.ProxyConfig, logger)
|
|
client = providerUtils.ConfigureDialer(client)
|
|
client = providerUtils.ConfigureTLS(client, config.NetworkConfig, logger)
|
|
streamingClient := providerUtils.BuildStreamingClient(client)
|
|
// Set default BaseURL if not provided
|
|
if config.NetworkConfig.BaseURL == "" {
|
|
config.NetworkConfig.BaseURL = "https://api.openai.com"
|
|
}
|
|
config.NetworkConfig.BaseURL = strings.TrimRight(config.NetworkConfig.BaseURL, "/")
|
|
|
|
return &OpenAIProvider{
|
|
logger: logger,
|
|
client: client,
|
|
streamingClient: streamingClient,
|
|
networkConfig: config.NetworkConfig,
|
|
sendBackRawRequest: config.SendBackRawRequest,
|
|
sendBackRawResponse: config.SendBackRawResponse,
|
|
customProviderConfig: config.CustomProviderConfig,
|
|
disableStore: config.OpenAIConfig != nil && config.OpenAIConfig.DisableStore,
|
|
}
|
|
}
|
|
|
|
// GetProviderKey returns the provider identifier for OpenAI.
|
|
func (provider *OpenAIProvider) GetProviderKey() schemas.ModelProvider {
|
|
return providerUtils.GetProviderName(schemas.OpenAI, provider.customProviderConfig)
|
|
}
|
|
|
|
// buildRequestURL constructs the full request URL using the provider's configuration.
|
|
func (provider *OpenAIProvider) buildRequestURL(ctx *schemas.BifrostContext, defaultPath string, requestType schemas.RequestType) string {
|
|
path, isCompleteURL := providerUtils.GetRequestPath(ctx, defaultPath, provider.customProviderConfig, requestType)
|
|
if isCompleteURL {
|
|
return path
|
|
}
|
|
return provider.networkConfig.BaseURL + path
|
|
}
|
|
|
|
func (provider *OpenAIProvider) ListModels(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) {
|
|
if err := providerUtils.CheckOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.ListModelsRequest); err != nil {
|
|
return nil, err
|
|
}
|
|
providerName := provider.GetProviderKey()
|
|
|
|
if provider.customProviderConfig != nil && provider.customProviderConfig.IsKeyLess {
|
|
return providerUtils.HandleKeylessListModelsRequest(providerName, func() (*schemas.BifrostListModelsResponse, *schemas.BifrostError) {
|
|
return ListModelsByKey(
|
|
ctx,
|
|
provider.client,
|
|
provider.buildRequestURL(ctx, "/v1/models", schemas.ListModelsRequest),
|
|
schemas.Key{},
|
|
request.Unfiltered,
|
|
provider.networkConfig.ExtraHeaders,
|
|
providerName,
|
|
providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest),
|
|
providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse),
|
|
)
|
|
})
|
|
}
|
|
|
|
return HandleOpenAIListModelsRequest(ctx,
|
|
provider.client,
|
|
request,
|
|
provider.buildRequestURL(ctx, "/v1/models", schemas.ListModelsRequest),
|
|
keys,
|
|
provider.networkConfig.ExtraHeaders,
|
|
providerName,
|
|
providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest),
|
|
providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse),
|
|
)
|
|
}
|
|
|
|
// ListModelsByKey performs a list models request for a single key.
|
|
// Returns the list-models response, or an error if the request fails.
|
|
func ListModelsByKey(
|
|
ctx *schemas.BifrostContext,
|
|
client *fasthttp.Client,
|
|
url string,
|
|
key schemas.Key,
|
|
unfiltered bool,
|
|
extraHeaders map[string]string,
|
|
providerName schemas.ModelProvider,
|
|
sendBackRawRequest bool,
|
|
sendBackRawResponse bool,
|
|
) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) {
|
|
// Create request
|
|
req := fasthttp.AcquireRequest()
|
|
resp := fasthttp.AcquireResponse()
|
|
defer fasthttp.ReleaseRequest(req)
|
|
defer fasthttp.ReleaseResponse(resp)
|
|
|
|
// Set any extra headers from network config
|
|
providerUtils.SetExtraHeaders(ctx, req, extraHeaders, nil)
|
|
|
|
req.SetRequestURI(url)
|
|
req.Header.SetMethod(http.MethodGet)
|
|
req.Header.SetContentType("application/json")
|
|
|
|
if key.Value.GetValue() != "" {
|
|
req.Header.Set("Authorization", "Bearer "+key.Value.GetValue())
|
|
}
|
|
|
|
// Make request
|
|
latency, bifrostErr, wait := providerUtils.MakeRequestWithContext(ctx, client, req, resp)
|
|
defer wait()
|
|
if bifrostErr != nil {
|
|
return nil, bifrostErr
|
|
}
|
|
// Extract provider response headers early so they're available on error paths too
|
|
providerResponseHeaders := providerUtils.ExtractProviderResponseHeaders(resp)
|
|
ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerResponseHeaders)
|
|
|
|
// Handle error response
|
|
if resp.StatusCode() != fasthttp.StatusOK {
|
|
bifrostErr := ParseOpenAIError(resp)
|
|
return nil, bifrostErr
|
|
}
|
|
|
|
// Copy response body before releasing
|
|
responseBody := append([]byte(nil), resp.Body()...)
|
|
|
|
openaiResponse := &OpenAIListModelsResponse{}
|
|
|
|
// Use enhanced response handler with pre-allocated response
|
|
rawRequest, rawResponse, bifrostErr := providerUtils.HandleProviderResponse(responseBody, openaiResponse, nil, sendBackRawRequest, sendBackRawResponse)
|
|
if bifrostErr != nil {
|
|
return nil, bifrostErr
|
|
}
|
|
|
|
response := openaiResponse.ToBifrostListModelsResponse(providerName, key.Models, key.BlacklistedModels, key.Aliases, unfiltered)
|
|
|
|
response.ExtraFields.Latency = latency.Milliseconds()
|
|
response.ExtraFields.ProviderResponseHeaders = providerResponseHeaders
|
|
|
|
// Set raw request if enabled
|
|
if providerUtils.ShouldSendBackRawRequest(ctx, sendBackRawRequest) {
|
|
response.ExtraFields.RawRequest = rawRequest
|
|
}
|
|
|
|
// Set raw response if enabled
|
|
if sendBackRawResponse {
|
|
response.ExtraFields.RawResponse = rawResponse
|
|
}
|
|
|
|
return response, nil
|
|
}
|
|
|
|
// HandleOpenAIListModelsRequest handles a list models request to OpenAI's API.
|
|
func HandleOpenAIListModelsRequest(
|
|
ctx *schemas.BifrostContext,
|
|
client *fasthttp.Client,
|
|
request *schemas.BifrostListModelsRequest,
|
|
url string,
|
|
keys []schemas.Key,
|
|
extraHeaders map[string]string,
|
|
providerName schemas.ModelProvider,
|
|
sendBackRawRequest bool,
|
|
sendBackRawResponse bool,
|
|
) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) {
|
|
if len(keys) == 0 {
|
|
return ListModelsByKey(ctx, client, url, schemas.Key{}, request.Unfiltered, extraHeaders, providerName, sendBackRawRequest, sendBackRawResponse)
|
|
}
|
|
listModelsByKeyWrapper := func(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) {
|
|
return ListModelsByKey(ctx, client, url, key, request.Unfiltered, extraHeaders, providerName, sendBackRawRequest, sendBackRawResponse)
|
|
}
|
|
return providerUtils.HandleMultipleListModelsRequests(
|
|
ctx,
|
|
keys,
|
|
request,
|
|
listModelsByKeyWrapper,
|
|
)
|
|
}
|
|
|
|
// TextCompletion is not supported by the OpenAI provider.
|
|
// Returns an error indicating that text completion is not available.
|
|
func (provider *OpenAIProvider) TextCompletion(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (*schemas.BifrostTextCompletionResponse, *schemas.BifrostError) {
|
|
if err := providerUtils.CheckOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.TextCompletionRequest); err != nil {
|
|
return nil, err
|
|
}
|
|
return HandleOpenAITextCompletionRequest(
|
|
ctx,
|
|
provider.client,
|
|
provider.buildRequestURL(ctx, "/v1/completions", schemas.TextCompletionRequest),
|
|
request,
|
|
key,
|
|
provider.networkConfig.ExtraHeaders,
|
|
provider.GetProviderKey(),
|
|
providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest),
|
|
providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse),
|
|
nil,
|
|
nil,
|
|
provider.logger,
|
|
)
|
|
}
|
|
|
|
// HandleOpenAITextCompletionRequest handles a text completion request to OpenAI's API.
|
|
func HandleOpenAITextCompletionRequest(
|
|
ctx *schemas.BifrostContext,
|
|
client *fasthttp.Client,
|
|
url string,
|
|
request *schemas.BifrostTextCompletionRequest,
|
|
key schemas.Key,
|
|
extraHeaders map[string]string,
|
|
providerName schemas.ModelProvider,
|
|
sendBackRawRequest bool,
|
|
sendBackRawResponse bool,
|
|
customResponseHandler responseHandler[schemas.BifrostTextCompletionResponse],
|
|
customErrorConverter ErrorConverter,
|
|
logger schemas.Logger,
|
|
) (*schemas.BifrostTextCompletionResponse, *schemas.BifrostError) {
|
|
// Create request
|
|
req := fasthttp.AcquireRequest()
|
|
resp := fasthttp.AcquireResponse()
|
|
defer fasthttp.ReleaseRequest(req)
|
|
// resp lifecycle: managed by finalizeOpenAIResponse or released on error paths
|
|
respOwned := true
|
|
defer func() {
|
|
if respOwned {
|
|
fasthttp.ReleaseResponse(resp)
|
|
}
|
|
}()
|
|
activeClient := providerUtils.PrepareResponseStreaming(ctx, client, resp)
|
|
|
|
// Set any extra headers from network config
|
|
providerUtils.SetExtraHeaders(ctx, req, extraHeaders, nil)
|
|
|
|
req.SetRequestURI(url)
|
|
req.Header.SetMethod(http.MethodPost)
|
|
req.Header.SetContentType("application/json")
|
|
|
|
if key.Value.GetValue() != "" {
|
|
req.Header.Set("Authorization", "Bearer "+key.Value.GetValue())
|
|
}
|
|
|
|
// Large payload passthrough: stream body directly without JSON marshaling
|
|
if lpResult, lpErr, handled := handleOpenAILargePayloadPassthrough(ctx, client, url, key, extraHeaders, providerName, logger); handled {
|
|
if lpErr != nil {
|
|
return nil, lpErr
|
|
}
|
|
if len(lpResult.ResponseBody) > 0 {
|
|
response := &schemas.BifrostTextCompletionResponse{}
|
|
if err := sonic.Unmarshal(lpResult.ResponseBody, response); err != nil {
|
|
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err)
|
|
}
|
|
response.ExtraFields = schemas.BifrostResponseExtraFields{Latency: lpResult.Latency}
|
|
return response, nil
|
|
}
|
|
return &schemas.BifrostTextCompletionResponse{
|
|
Model: request.Model,
|
|
Usage: lpResult.Usage,
|
|
ExtraFields: schemas.BifrostResponseExtraFields{Latency: lpResult.Latency},
|
|
}, nil
|
|
}
|
|
|
|
jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody(
|
|
ctx,
|
|
request,
|
|
func() (providerUtils.RequestBodyWithExtraParams, error) {
|
|
return ToOpenAITextCompletionRequest(request), nil
|
|
})
|
|
if bifrostErr != nil {
|
|
return nil, bifrostErr
|
|
}
|
|
|
|
req.SetBody(jsonData)
|
|
|
|
// Make request
|
|
latency, bifrostErr, wait := providerUtils.MakeRequestWithContext(ctx, activeClient, req, resp)
|
|
defer wait()
|
|
if bifrostErr != nil {
|
|
return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonData, nil, sendBackRawRequest, sendBackRawResponse)
|
|
}
|
|
// Extract provider response headers early so they're available on error paths too
|
|
providerResponseHeaders := providerUtils.ExtractProviderResponseHeaders(resp)
|
|
ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerResponseHeaders)
|
|
|
|
// Handle error response
|
|
if resp.StatusCode() != fasthttp.StatusOK {
|
|
providerUtils.MaterializeStreamErrorBody(ctx, resp)
|
|
if customErrorConverter != nil {
|
|
return nil, providerUtils.EnrichError(ctx, customErrorConverter(resp), jsonData, nil, sendBackRawRequest, sendBackRawResponse)
|
|
}
|
|
return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp), jsonData, nil, sendBackRawRequest, sendBackRawResponse)
|
|
}
|
|
|
|
body, lpResult, finalErr := finalizeOpenAIResponse(ctx, resp, latency, providerName, logger)
|
|
respOwned = false // ownership transferred
|
|
if finalErr != nil {
|
|
return nil, providerUtils.EnrichError(ctx, finalErr, jsonData, nil, sendBackRawRequest, sendBackRawResponse)
|
|
}
|
|
if lpResult != nil {
|
|
return &schemas.BifrostTextCompletionResponse{
|
|
Model: request.Model,
|
|
Usage: lpResult.Usage,
|
|
ExtraFields: schemas.BifrostResponseExtraFields{Latency: lpResult.Latency},
|
|
}, nil
|
|
}
|
|
|
|
response := &schemas.BifrostTextCompletionResponse{}
|
|
|
|
var rawRequest, rawResponse interface{}
|
|
|
|
if customResponseHandler != nil {
|
|
rawRequest, rawResponse, bifrostErr = customResponseHandler(body, response, jsonData, sendBackRawRequest, sendBackRawResponse)
|
|
} else {
|
|
rawRequest, rawResponse, bifrostErr = providerUtils.HandleProviderResponse(body, response, jsonData, sendBackRawRequest, sendBackRawResponse)
|
|
}
|
|
|
|
if bifrostErr != nil {
|
|
return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonData, body, sendBackRawRequest, sendBackRawResponse)
|
|
}
|
|
|
|
response.ExtraFields.Latency = latency.Milliseconds()
|
|
response.ExtraFields.ProviderResponseHeaders = providerResponseHeaders
|
|
|
|
// Set raw request if enabled
|
|
if providerUtils.ShouldSendBackRawRequest(ctx, sendBackRawRequest) {
|
|
response.ExtraFields.RawRequest = rawRequest
|
|
}
|
|
|
|
// Set raw response if enabled
|
|
if sendBackRawResponse {
|
|
response.ExtraFields.RawResponse = rawResponse
|
|
}
|
|
|
|
return response, nil
|
|
}
|
|
|
|
// TextCompletionStream performs a streaming text completion request to OpenAI's API.
|
|
// It formats the request, sends it to OpenAI, and processes the response.
|
|
// Returns a channel of BifrostStreamChunk objects or an error if the request fails.
|
|
func (provider *OpenAIProvider) TextCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostTextCompletionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
|
|
if err := providerUtils.CheckOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.TextCompletionStreamRequest); err != nil {
|
|
return nil, err
|
|
}
|
|
var authHeader map[string]string
|
|
if key.Value.GetValue() != "" {
|
|
authHeader = map[string]string{"Authorization": "Bearer " + key.Value.GetValue()}
|
|
}
|
|
return HandleOpenAITextCompletionStreaming(
|
|
ctx,
|
|
provider.streamingClient,
|
|
provider.buildRequestURL(ctx, "/v1/completions", schemas.TextCompletionStreamRequest),
|
|
request,
|
|
authHeader,
|
|
provider.networkConfig.ExtraHeaders,
|
|
providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest),
|
|
providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse),
|
|
provider.GetProviderKey(),
|
|
nil,
|
|
postHookRunner,
|
|
nil,
|
|
nil,
|
|
provider.logger,
|
|
postHookSpanFinalizer,
|
|
)
|
|
}
|
|
|
|
// HandleOpenAITextCompletionStreaming handles text completion streaming for OpenAI-compatible APIs.
|
|
// This shared function reduces code duplication between providers that use the same SSE format.
|
|
func HandleOpenAITextCompletionStreaming(
|
|
ctx *schemas.BifrostContext,
|
|
client *fasthttp.Client,
|
|
url string,
|
|
request *schemas.BifrostTextCompletionRequest,
|
|
authHeader map[string]string,
|
|
extraHeaders map[string]string,
|
|
sendBackRawRequest bool,
|
|
sendBackRawResponse bool,
|
|
providerName schemas.ModelProvider,
|
|
customErrorConverter ErrorConverter,
|
|
postHookRunner schemas.PostHookRunner,
|
|
customResponseHandler responseHandler[schemas.BifrostTextCompletionResponse],
|
|
postResponseConverter func(*schemas.BifrostTextCompletionResponse) *schemas.BifrostTextCompletionResponse,
|
|
logger schemas.Logger,
|
|
postHookSpanFinalizer func(context.Context),
|
|
) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
|
|
headers := map[string]string{
|
|
"Content-Type": "application/json",
|
|
"Accept": "text/event-stream",
|
|
"Cache-Control": "no-cache",
|
|
}
|
|
|
|
if authHeader != nil {
|
|
maps.Copy(headers, authHeader)
|
|
}
|
|
|
|
jsonBody, bifrostErr := providerUtils.CheckContextAndGetRequestBody(
|
|
ctx,
|
|
request,
|
|
func() (providerUtils.RequestBodyWithExtraParams, error) {
|
|
reqBody := ToOpenAITextCompletionRequest(request)
|
|
if reqBody != nil {
|
|
reqBody.Stream = schemas.Ptr(true)
|
|
reqBody.StreamOptions = &schemas.ChatStreamOptions{
|
|
IncludeUsage: schemas.Ptr(true),
|
|
}
|
|
}
|
|
return reqBody, nil
|
|
})
|
|
|
|
if bifrostErr != nil {
|
|
return nil, bifrostErr
|
|
}
|
|
|
|
// Create HTTP request for streaming
|
|
req := fasthttp.AcquireRequest()
|
|
resp := fasthttp.AcquireResponse()
|
|
resp.StreamBody = true
|
|
defer fasthttp.ReleaseRequest(req)
|
|
|
|
req.Header.SetMethod(http.MethodPost)
|
|
req.SetRequestURI(url)
|
|
req.Header.SetContentType("application/json")
|
|
|
|
// Set any extra headers from network config
|
|
providerUtils.SetExtraHeaders(ctx, req, extraHeaders, nil)
|
|
|
|
// Set headers
|
|
for key, value := range headers {
|
|
req.Header.Set(key, value)
|
|
}
|
|
|
|
setStreamingRequestBody(ctx, req, jsonBody, providerName)
|
|
|
|
// Use streaming-aware client when large payload optimization is active — ensures
|
|
// MaxResponseBodySize > 0 so ErrBodyTooLarge triggers StreamBody for Content-Length responses.
|
|
activeClient := providerUtils.PrepareResponseStreaming(ctx, client, resp)
|
|
|
|
// Make the request
|
|
err := activeClient.Do(req, resp)
|
|
if err != nil {
|
|
defer providerUtils.ReleaseStreamingResponse(resp)
|
|
if errors.Is(err, context.Canceled) {
|
|
return nil, providerUtils.EnrichError(ctx, &schemas.BifrostError{
|
|
IsBifrostError: false,
|
|
Error: &schemas.ErrorField{
|
|
Type: schemas.Ptr(schemas.RequestCancelled),
|
|
Message: schemas.ErrRequestCancelled,
|
|
Error: err,
|
|
},
|
|
}, jsonBody, nil, sendBackRawRequest, sendBackRawResponse)
|
|
}
|
|
if errors.Is(err, fasthttp.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) {
|
|
return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err), jsonBody, nil, sendBackRawRequest, sendBackRawResponse)
|
|
}
|
|
return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err), jsonBody, nil, sendBackRawRequest, sendBackRawResponse)
|
|
}
|
|
|
|
// Store provider response headers in context before status check so error responses also forward them
|
|
ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerUtils.ExtractProviderResponseHeaders(resp))
|
|
|
|
// Check for HTTP errors
|
|
if resp.StatusCode() != fasthttp.StatusOK {
|
|
defer providerUtils.ReleaseStreamingResponse(resp)
|
|
providerUtils.MaterializeStreamErrorBody(ctx, resp)
|
|
if customErrorConverter != nil {
|
|
return nil, providerUtils.EnrichError(ctx, customErrorConverter(resp), jsonBody, nil, sendBackRawRequest, sendBackRawResponse)
|
|
}
|
|
return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp), jsonBody, nil, sendBackRawRequest, sendBackRawResponse)
|
|
}
|
|
|
|
// Large payload streaming passthrough — pipe raw upstream SSE to client
|
|
if providerUtils.SetupStreamingPassthrough(ctx, resp) {
|
|
responseChan := make(chan *schemas.BifrostStreamChunk)
|
|
close(responseChan)
|
|
return responseChan, nil
|
|
}
|
|
|
|
// Create response channel
|
|
responseChan := make(chan *schemas.BifrostStreamChunk, schemas.DefaultStreamBufferSize)
|
|
|
|
// Start streaming in a goroutine
|
|
go func() {
|
|
defer providerUtils.EnsureStreamFinalizerCalled(ctx, postHookSpanFinalizer)
|
|
defer func() {
|
|
if ctx.Err() == context.Canceled {
|
|
providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, logger, postHookSpanFinalizer)
|
|
} else if ctx.Err() == context.DeadlineExceeded {
|
|
providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, logger, postHookSpanFinalizer)
|
|
}
|
|
close(responseChan)
|
|
}()
|
|
defer providerUtils.ReleaseStreamingResponse(resp)
|
|
// Decompress gzip-encoded streams transparently (no-op for non-gzip)
|
|
reader, releaseGzip := providerUtils.DecompressStreamBody(resp)
|
|
defer releaseGzip()
|
|
|
|
// Wrap reader with idle timeout to detect stalled streams.
|
|
reader, stopIdleTimeout := providerUtils.NewIdleTimeoutReader(reader, resp.BodyStream(), providerUtils.GetStreamIdleTimeout(ctx))
|
|
defer stopIdleTimeout()
|
|
|
|
// Setup cancellation handler to close the raw network stream on ctx cancellation,
|
|
// which immediately unblocks any in-progress read (including reads blocked inside a gzip decompression layer).
|
|
stopCancellation := providerUtils.SetupStreamCancellation(ctx, resp.BodyStream(), logger)
|
|
defer stopCancellation()
|
|
|
|
// Skip scanner for non-SSE responses — avoids bufio.Scanner buffer bloat
|
|
// on non-line-delimited data (e.g. provider returned JSON instead of SSE).
|
|
if providerUtils.DrainNonSSEStreamResponse(resp) {
|
|
ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
|
|
providerUtils.ProcessAndSendError(ctx, postHookRunner, errors.New("provider returned non-SSE response for streaming request"), responseChan, logger, postHookSpanFinalizer)
|
|
return
|
|
}
|
|
|
|
sseReader := providerUtils.GetSSEDataReader(ctx, reader)
|
|
|
|
chunkIndex := -1
|
|
usage := &schemas.BifrostLLMUsage{}
|
|
|
|
var finishReason *string
|
|
var messageID string
|
|
startTime := time.Now()
|
|
lastChunkTime := startTime
|
|
|
|
for {
|
|
// If context was cancelled/timed out, let defer handle it
|
|
if ctx.Err() != nil {
|
|
return
|
|
}
|
|
data, readErr := sseReader.ReadDataLine()
|
|
if readErr != nil {
|
|
if readErr != io.EOF {
|
|
if ctx.Err() != nil {
|
|
return
|
|
}
|
|
ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
|
|
logger.Warn("Error reading stream: %v", readErr)
|
|
providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, logger, postHookSpanFinalizer)
|
|
return
|
|
}
|
|
break
|
|
}
|
|
jsonData := string(data)
|
|
var response schemas.BifrostTextCompletionResponse
|
|
if customResponseHandler != nil {
|
|
rawRequest, rawResponse, handlerErr := customResponseHandler([]byte(jsonData), &response, nil, sendBackRawRequest, sendBackRawResponse)
|
|
if handlerErr != nil {
|
|
// TODO fix this
|
|
if sendBackRawRequest {
|
|
handlerErr.ExtraFields.RawRequest = rawRequest
|
|
}
|
|
if sendBackRawResponse {
|
|
handlerErr.ExtraFields.RawResponse = rawResponse
|
|
}
|
|
ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
|
|
providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, providerUtils.EnrichError(ctx, handlerErr, jsonBody, nil, sendBackRawRequest, sendBackRawResponse), responseChan, logger, postHookSpanFinalizer)
|
|
return
|
|
}
|
|
} else {
|
|
|
|
// Quick check for error field (allocation-free using sonic.GetFromString)
|
|
if errorNode, _ := sonic.GetFromString(jsonData, "error"); errorNode.Exists() {
|
|
// Only unmarshal when we know there's an error
|
|
var bifrostErr schemas.BifrostError
|
|
if err := sonic.UnmarshalString(jsonData, &bifrostErr); err == nil {
|
|
if bifrostErr.Error != nil && bifrostErr.Error.Message != "" {
|
|
ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
|
|
providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, providerUtils.EnrichError(ctx, &bifrostErr, jsonBody, nil, sendBackRawRequest, sendBackRawResponse), responseChan, logger, postHookSpanFinalizer)
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
// Parse into bifrost response
|
|
if err := sonic.UnmarshalString(jsonData, &response); err != nil {
|
|
logger.Warn("Failed to parse stream response: %v", err)
|
|
continue
|
|
}
|
|
}
|
|
|
|
// choices be array if nil
|
|
if response.Choices == nil {
|
|
response.Choices = []schemas.BifrostResponseChoice{}
|
|
}
|
|
|
|
if postResponseConverter != nil {
|
|
if converted := postResponseConverter(&response); converted != nil {
|
|
response = *converted
|
|
} else {
|
|
logger.Warn("postResponseConverter returned nil; leaving chunk unmodified")
|
|
}
|
|
}
|
|
|
|
// Handle usage-only chunks (when stream_options include_usage is true)
|
|
if response.Usage != nil {
|
|
// Collect usage information and send at the end of the stream
|
|
// Here in some cases usage comes before final message
|
|
// So we need to check if the response.Usage is nil and then if usage != nil
|
|
// then add up all tokens
|
|
if response.Usage.PromptTokens > usage.PromptTokens {
|
|
usage.PromptTokens = response.Usage.PromptTokens
|
|
}
|
|
if response.Usage.CompletionTokens > usage.CompletionTokens {
|
|
usage.CompletionTokens = response.Usage.CompletionTokens
|
|
}
|
|
if response.Usage.TotalTokens > usage.TotalTokens {
|
|
usage.TotalTokens = response.Usage.TotalTokens
|
|
}
|
|
calculatedTotal := usage.PromptTokens + usage.CompletionTokens
|
|
if calculatedTotal > usage.TotalTokens {
|
|
usage.TotalTokens = calculatedTotal
|
|
}
|
|
if response.Usage.CompletionTokensDetails != nil {
|
|
usage.CompletionTokensDetails = response.Usage.CompletionTokensDetails
|
|
}
|
|
if response.Usage.PromptTokensDetails != nil {
|
|
usage.PromptTokensDetails = response.Usage.PromptTokensDetails
|
|
}
|
|
response.Usage = nil
|
|
}
|
|
|
|
// Skip empty responses or responses without choices
|
|
if len(response.Choices) == 0 {
|
|
continue
|
|
}
|
|
|
|
// Handle finish reason, usually in the final chunk
|
|
choice := response.Choices[0]
|
|
if choice.FinishReason != nil && *choice.FinishReason != "" {
|
|
// Collect finish reason and send at the end of the stream
|
|
finishReason = choice.FinishReason
|
|
response.Choices[0].FinishReason = nil
|
|
}
|
|
|
|
if response.ID != "" && messageID == "" {
|
|
messageID = response.ID
|
|
}
|
|
|
|
// Handle regular content chunks
|
|
if choice.TextCompletionResponseChoice != nil && choice.TextCompletionResponseChoice.Text != nil {
|
|
chunkIndex++
|
|
|
|
response.ExtraFields.ChunkIndex = chunkIndex
|
|
response.ExtraFields.Latency = time.Since(lastChunkTime).Milliseconds()
|
|
lastChunkTime = time.Now()
|
|
|
|
if sendBackRawResponse {
|
|
response.ExtraFields.RawResponse = jsonData
|
|
}
|
|
|
|
providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(&response, nil, nil, nil, nil, nil), responseChan, postHookSpanFinalizer)
|
|
}
|
|
|
|
// For providers that don't send [DONE] marker break on finish_reason
|
|
if !providerUtils.ProviderSendsDoneMarker(providerName) && finishReason != nil {
|
|
break
|
|
}
|
|
}
|
|
|
|
response := providerUtils.CreateBifrostTextCompletionChunkResponse(messageID, usage, finishReason, chunkIndex, schemas.TextCompletionStreamRequest)
|
|
if postResponseConverter != nil {
|
|
response = postResponseConverter(response)
|
|
if response == nil {
|
|
logger.Warn("postResponseConverter returned nil; leaving chunk unmodified")
|
|
return
|
|
}
|
|
}
|
|
// Set raw request if enabled
|
|
if sendBackRawRequest {
|
|
providerUtils.ParseAndSetRawRequest(&response.ExtraFields, jsonBody)
|
|
}
|
|
response.ExtraFields.Latency = time.Since(startTime).Milliseconds()
|
|
ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
|
|
providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(response, nil, nil, nil, nil, nil), responseChan, postHookSpanFinalizer)
|
|
}()
|
|
|
|
return responseChan, nil
|
|
}
|
|
|
|
// ChatCompletion performs a chat completion request to the OpenAI API.
|
|
// It supports both text and image content in messages.
|
|
// Returns a BifrostResponse containing the completion results or an error if the request fails.
|
|
func (provider *OpenAIProvider) ChatCompletion(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError) {
|
|
// Check if chat completion is allowed for this provider
|
|
if err := providerUtils.CheckOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.ChatCompletionRequest); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if provider.disableStore {
|
|
if request.Params == nil {
|
|
request.Params = &schemas.ChatParameters{}
|
|
}
|
|
request.Params.Store = schemas.Ptr(false)
|
|
}
|
|
|
|
return HandleOpenAIChatCompletionRequest(
|
|
ctx,
|
|
provider.client,
|
|
provider.buildRequestURL(ctx, "/v1/chat/completions", schemas.ChatCompletionRequest),
|
|
request,
|
|
key,
|
|
provider.networkConfig.ExtraHeaders,
|
|
providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest),
|
|
providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse),
|
|
provider.GetProviderKey(),
|
|
nil,
|
|
nil,
|
|
provider.logger,
|
|
)
|
|
}
|
|
|
|
// HandleOpenAIChatCompletionRequest handles a chat completion request to OpenAI's API.
|
|
func HandleOpenAIChatCompletionRequest(
|
|
ctx *schemas.BifrostContext,
|
|
client *fasthttp.Client,
|
|
url string,
|
|
request *schemas.BifrostChatRequest,
|
|
key schemas.Key,
|
|
extraHeaders map[string]string,
|
|
sendBackRawRequest bool,
|
|
sendBackRawResponse bool,
|
|
providerName schemas.ModelProvider,
|
|
customResponseHandler responseHandler[schemas.BifrostChatResponse],
|
|
customErrorConverter ErrorConverter,
|
|
logger schemas.Logger,
|
|
) (*schemas.BifrostChatResponse, *schemas.BifrostError) {
|
|
// Create request
|
|
req := fasthttp.AcquireRequest()
|
|
resp := fasthttp.AcquireResponse()
|
|
defer fasthttp.ReleaseRequest(req)
|
|
// resp lifecycle: managed by finalizeOpenAIResponse or released on error paths
|
|
respOwned := true
|
|
defer func() {
|
|
if respOwned {
|
|
fasthttp.ReleaseResponse(resp)
|
|
}
|
|
}()
|
|
activeClient := providerUtils.PrepareResponseStreaming(ctx, client, resp)
|
|
|
|
// Set any extra headers from network config
|
|
providerUtils.SetExtraHeaders(ctx, req, extraHeaders, nil)
|
|
|
|
req.SetRequestURI(url)
|
|
req.Header.SetMethod(http.MethodPost)
|
|
req.Header.SetContentType("application/json")
|
|
|
|
if key.Value.GetValue() != "" {
|
|
req.Header.Set("Authorization", "Bearer "+key.Value.GetValue())
|
|
}
|
|
|
|
// Large payload passthrough: stream body directly without JSON marshaling
|
|
if lpResult, lpErr, handled := handleOpenAILargePayloadPassthrough(ctx, client, url, key, extraHeaders, providerName, logger); handled {
|
|
if lpErr != nil {
|
|
return nil, lpErr
|
|
}
|
|
if len(lpResult.ResponseBody) > 0 {
|
|
response := &schemas.BifrostChatResponse{}
|
|
if err := sonic.Unmarshal(lpResult.ResponseBody, response); err != nil {
|
|
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err)
|
|
}
|
|
response.ExtraFields = schemas.BifrostResponseExtraFields{Latency: lpResult.Latency}
|
|
return response, nil
|
|
}
|
|
return &schemas.BifrostChatResponse{
|
|
Model: request.Model,
|
|
Usage: lpResult.Usage,
|
|
ExtraFields: schemas.BifrostResponseExtraFields{Latency: lpResult.Latency},
|
|
}, nil
|
|
}
|
|
|
|
jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody(
|
|
ctx,
|
|
request,
|
|
func() (providerUtils.RequestBodyWithExtraParams, error) {
|
|
return ToOpenAIChatRequest(ctx, request), nil
|
|
})
|
|
if bifrostErr != nil {
|
|
return nil, bifrostErr
|
|
}
|
|
|
|
req.SetBody(jsonData)
|
|
|
|
// Make request
|
|
latency, bifrostErr, wait := providerUtils.MakeRequestWithContext(ctx, activeClient, req, resp)
|
|
defer wait()
|
|
if bifrostErr != nil {
|
|
return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonData, nil, sendBackRawRequest, sendBackRawResponse)
|
|
}
|
|
// Extract provider response headers early so they're available on error paths too
|
|
providerResponseHeaders := providerUtils.ExtractProviderResponseHeaders(resp)
|
|
ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerResponseHeaders)
|
|
|
|
// Handle error response
|
|
if resp.StatusCode() != fasthttp.StatusOK {
|
|
providerUtils.MaterializeStreamErrorBody(ctx, resp)
|
|
logger.Debug("error from %s provider: %s", providerName, string(resp.Body()))
|
|
if customErrorConverter != nil {
|
|
return nil, providerUtils.EnrichError(ctx, customErrorConverter(resp), jsonData, nil, sendBackRawRequest, sendBackRawResponse)
|
|
}
|
|
return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp), jsonData, nil, sendBackRawRequest, sendBackRawResponse)
|
|
}
|
|
|
|
body, lpResult, finalErr := finalizeOpenAIResponse(ctx, resp, latency, providerName, logger)
|
|
respOwned = false // ownership transferred
|
|
if finalErr != nil {
|
|
return nil, providerUtils.EnrichError(ctx, finalErr, jsonData, nil, sendBackRawRequest, sendBackRawResponse)
|
|
}
|
|
if lpResult != nil {
|
|
return &schemas.BifrostChatResponse{
|
|
Model: request.Model,
|
|
Usage: lpResult.Usage,
|
|
ExtraFields: schemas.BifrostResponseExtraFields{Latency: lpResult.Latency},
|
|
}, nil
|
|
}
|
|
response := &schemas.BifrostChatResponse{}
|
|
response.ExtraFields.ProviderResponseHeaders = providerResponseHeaders
|
|
|
|
var rawRequest, rawResponse interface{}
|
|
|
|
if customResponseHandler != nil {
|
|
rawRequest, rawResponse, bifrostErr = customResponseHandler(body, response, jsonData, sendBackRawRequest, sendBackRawResponse)
|
|
} else {
|
|
rawRequest, rawResponse, bifrostErr = providerUtils.HandleProviderResponse(body, response, jsonData, sendBackRawRequest, sendBackRawResponse)
|
|
}
|
|
|
|
if bifrostErr != nil {
|
|
return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonData, body, sendBackRawRequest, sendBackRawResponse)
|
|
}
|
|
|
|
response.ExtraFields.Latency = latency.Milliseconds()
|
|
|
|
// Set raw request if enabled
|
|
if providerUtils.ShouldSendBackRawRequest(ctx, sendBackRawRequest) {
|
|
response.ExtraFields.RawRequest = rawRequest
|
|
}
|
|
|
|
// Set raw response if enabled
|
|
if providerUtils.ShouldSendBackRawResponse(ctx, sendBackRawResponse) {
|
|
response.ExtraFields.RawResponse = rawResponse
|
|
}
|
|
|
|
return response, nil
|
|
}
|
|
|
|
// ChatCompletionStream handles streaming for OpenAI chat completions.
|
|
// It formats messages, prepares request body, and uses shared streaming logic.
|
|
// Returns a channel for streaming responses and any error that occurred.
|
|
func (provider *OpenAIProvider) ChatCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
|
|
// Check if chat completion stream is allowed for this provider
|
|
if err := providerUtils.CheckOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.ChatCompletionStreamRequest); err != nil {
|
|
return nil, err
|
|
}
|
|
var authHeader map[string]string
|
|
if key.Value.GetValue() != "" {
|
|
authHeader = map[string]string{"Authorization": "Bearer " + key.Value.GetValue()}
|
|
}
|
|
if provider.disableStore {
|
|
if request.Params == nil {
|
|
request.Params = &schemas.ChatParameters{}
|
|
}
|
|
request.Params.Store = schemas.Ptr(false)
|
|
}
|
|
|
|
// Use shared streaming logic
|
|
return HandleOpenAIChatCompletionStreaming(
|
|
ctx,
|
|
provider.streamingClient,
|
|
provider.buildRequestURL(ctx, "/v1/chat/completions", schemas.ChatCompletionStreamRequest),
|
|
request,
|
|
authHeader,
|
|
provider.networkConfig.ExtraHeaders,
|
|
providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest),
|
|
providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse),
|
|
provider.GetProviderKey(),
|
|
postHookRunner,
|
|
nil,
|
|
nil,
|
|
nil,
|
|
nil,
|
|
nil,
|
|
provider.logger,
|
|
postHookSpanFinalizer,
|
|
)
|
|
}
|
|
|
|
// HandleOpenAIChatCompletionStreaming handles streaming for OpenAI-compatible APIs.
|
|
// This shared function reduces code duplication between providers that use the same SSE format.
|
|
func HandleOpenAIChatCompletionStreaming(
|
|
ctx *schemas.BifrostContext,
|
|
client *fasthttp.Client,
|
|
url string,
|
|
request *schemas.BifrostChatRequest,
|
|
authHeader map[string]string,
|
|
extraHeaders map[string]string,
|
|
sendBackRawRequest bool,
|
|
sendBackRawResponse bool,
|
|
providerName schemas.ModelProvider,
|
|
postHookRunner schemas.PostHookRunner,
|
|
customRequestConverter func(*schemas.BifrostChatRequest) (providerUtils.RequestBodyWithExtraParams, error),
|
|
customResponseHandler responseHandler[schemas.BifrostChatResponse],
|
|
customErrorConverter ErrorConverter,
|
|
postRequestConverter func(*OpenAIChatRequest) *OpenAIChatRequest,
|
|
postResponseConverter func(*schemas.BifrostChatResponse) *schemas.BifrostChatResponse,
|
|
logger schemas.Logger,
|
|
postHookSpanFinalizer func(context.Context),
|
|
) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
|
|
// Check if the request is a redirect from ResponsesStream to ChatCompletionStream
|
|
isResponsesToChatCompletionsFallback := false
|
|
var responsesStreamState *schemas.ChatToResponsesStreamState
|
|
if ctx.Value(schemas.BifrostContextKeyIsResponsesToChatCompletionFallback) != nil {
|
|
isResponsesToChatCompletionsFallbackValue, ok := ctx.Value(schemas.BifrostContextKeyIsResponsesToChatCompletionFallback).(bool)
|
|
if ok && isResponsesToChatCompletionsFallbackValue {
|
|
isResponsesToChatCompletionsFallback = true
|
|
responsesStreamState = schemas.AcquireChatToResponsesStreamState()
|
|
}
|
|
}
|
|
|
|
headers := map[string]string{
|
|
"Content-Type": "application/json",
|
|
"Accept": "text/event-stream",
|
|
"Cache-Control": "no-cache",
|
|
}
|
|
|
|
if authHeader != nil {
|
|
// Copy auth header to headers
|
|
maps.Copy(headers, authHeader)
|
|
}
|
|
|
|
jsonBody, bifrostErr := providerUtils.CheckContextAndGetRequestBody(
|
|
ctx,
|
|
request,
|
|
func() (providerUtils.RequestBodyWithExtraParams, error) {
|
|
if customRequestConverter != nil {
|
|
return customRequestConverter(request)
|
|
}
|
|
reqBody := ToOpenAIChatRequest(ctx, request)
|
|
if reqBody != nil {
|
|
reqBody.Stream = schemas.Ptr(true)
|
|
reqBody.StreamOptions = &schemas.ChatStreamOptions{
|
|
IncludeUsage: schemas.Ptr(true),
|
|
}
|
|
if postRequestConverter != nil {
|
|
reqBody = postRequestConverter(reqBody)
|
|
}
|
|
}
|
|
return reqBody, nil
|
|
})
|
|
if bifrostErr != nil {
|
|
return nil, bifrostErr
|
|
}
|
|
|
|
// Create HTTP request for streaming
|
|
req := fasthttp.AcquireRequest()
|
|
resp := fasthttp.AcquireResponse()
|
|
resp.StreamBody = true
|
|
defer fasthttp.ReleaseRequest(req)
|
|
|
|
// Updating request
|
|
req.Header.SetMethod(http.MethodPost)
|
|
req.SetRequestURI(url)
|
|
req.Header.SetContentType("application/json")
|
|
|
|
// Set any extra headers from network config
|
|
providerUtils.SetExtraHeaders(ctx, req, extraHeaders, nil)
|
|
|
|
// Set headers
|
|
for key, value := range headers {
|
|
req.Header.Set(key, value)
|
|
}
|
|
|
|
setStreamingRequestBody(ctx, req, jsonBody, providerName)
|
|
|
|
// Use streaming-aware client when large payload optimization is active — ensures
|
|
// MaxResponseBodySize > 0 so ErrBodyTooLarge triggers StreamBody for Content-Length responses.
|
|
activeClient := providerUtils.PrepareResponseStreaming(ctx, client, resp)
|
|
|
|
// Make the request
|
|
err := activeClient.Do(req, resp)
|
|
if err != nil {
|
|
defer providerUtils.ReleaseStreamingResponse(resp)
|
|
if errors.Is(err, context.Canceled) {
|
|
return nil, providerUtils.EnrichError(ctx, &schemas.BifrostError{
|
|
IsBifrostError: false,
|
|
Error: &schemas.ErrorField{
|
|
Type: schemas.Ptr(schemas.RequestCancelled),
|
|
Message: schemas.ErrRequestCancelled,
|
|
Error: err,
|
|
},
|
|
}, jsonBody, nil, sendBackRawRequest, sendBackRawResponse)
|
|
}
|
|
if errors.Is(err, fasthttp.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) {
|
|
return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err), jsonBody, nil, sendBackRawRequest, sendBackRawResponse)
|
|
}
|
|
return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err), jsonBody, nil, sendBackRawRequest, sendBackRawResponse)
|
|
}
|
|
|
|
// Store provider response headers in context before status check so error responses also forward them
|
|
ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerUtils.ExtractProviderResponseHeaders(resp))
|
|
|
|
// Check for HTTP errors
|
|
if resp.StatusCode() != fasthttp.StatusOK {
|
|
defer providerUtils.ReleaseStreamingResponse(resp)
|
|
providerUtils.MaterializeStreamErrorBody(ctx, resp)
|
|
if customErrorConverter != nil {
|
|
return nil, providerUtils.EnrichError(ctx, customErrorConverter(resp), jsonBody, nil, sendBackRawRequest, sendBackRawResponse)
|
|
}
|
|
return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp), jsonBody, nil, sendBackRawRequest, sendBackRawResponse)
|
|
}
|
|
|
|
// Large payload streaming passthrough — pipe raw upstream SSE to client
|
|
if providerUtils.SetupStreamingPassthrough(ctx, resp) {
|
|
responseChan := make(chan *schemas.BifrostStreamChunk)
|
|
close(responseChan)
|
|
return responseChan, nil
|
|
}
|
|
|
|
// Create response channel
|
|
responseChan := make(chan *schemas.BifrostStreamChunk, schemas.DefaultStreamBufferSize)
|
|
|
|
// Start streaming in a goroutine
|
|
go func() {
|
|
defer providerUtils.EnsureStreamFinalizerCalled(ctx, postHookSpanFinalizer)
|
|
defer func() {
|
|
if ctx.Err() == context.Canceled {
|
|
providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, logger, postHookSpanFinalizer)
|
|
} else if ctx.Err() == context.DeadlineExceeded {
|
|
providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, logger, postHookSpanFinalizer)
|
|
}
|
|
// Release the responses stream state if it was acquired (for ResponsesToChatCompletions fallback)
|
|
schemas.ReleaseChatToResponsesStreamState(responsesStreamState)
|
|
close(responseChan)
|
|
}()
|
|
defer providerUtils.ReleaseStreamingResponse(resp)
|
|
// Decompress gzip-encoded streams transparently (no-op for non-gzip)
|
|
reader, releaseGzip := providerUtils.DecompressStreamBody(resp)
|
|
defer releaseGzip()
|
|
|
|
// Wrap reader with idle timeout to detect stalled streams.
|
|
reader, stopIdleTimeout := providerUtils.NewIdleTimeoutReader(reader, resp.BodyStream(), providerUtils.GetStreamIdleTimeout(ctx))
|
|
defer stopIdleTimeout()
|
|
|
|
// Setup cancellation handler to close the raw network stream on ctx cancellation,
|
|
// which immediately unblocks any in-progress read (including reads blocked inside a gzip decompression layer).
|
|
stopCancellation := providerUtils.SetupStreamCancellation(ctx, resp.BodyStream(), logger)
|
|
defer stopCancellation()
|
|
|
|
// Skip scanner for non-SSE responses — avoids bufio.Scanner buffer bloat
|
|
// on non-line-delimited data (e.g. provider returned JSON instead of SSE).
|
|
if providerUtils.DrainNonSSEStreamResponse(resp) {
|
|
ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
|
|
providerUtils.ProcessAndSendError(ctx, postHookRunner, errors.New("provider returned non-SSE response for streaming request"), responseChan, logger, postHookSpanFinalizer)
|
|
return
|
|
}
|
|
|
|
sseReader := providerUtils.GetSSEDataReader(ctx, reader)
|
|
|
|
chunkIndex := -1
|
|
usage := &schemas.BifrostLLMUsage{}
|
|
|
|
startTime := time.Now()
|
|
lastChunkTime := startTime
|
|
|
|
var finishReason *string
|
|
var messageID string
|
|
var modelName string
|
|
var created int
|
|
forwardedTerminalFinishReason := false
|
|
|
|
for {
|
|
// If context was cancelled/timed out, let defer handle it
|
|
if ctx.Err() != nil {
|
|
return
|
|
}
|
|
data, readErr := sseReader.ReadDataLine()
|
|
if readErr != nil {
|
|
if readErr != io.EOF {
|
|
if ctx.Err() != nil {
|
|
return
|
|
}
|
|
ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
|
|
logger.Warn("Error reading stream: %v", readErr)
|
|
providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, logger, postHookSpanFinalizer)
|
|
return
|
|
}
|
|
break
|
|
}
|
|
jsonData := string(data)
|
|
|
|
// Quick check for error field (allocation-free using sonic.GetFromString)
|
|
if errorNode, _ := sonic.GetFromString(jsonData, "error"); errorNode.Exists() {
|
|
// Only unmarshal when we know there's an error
|
|
var bifrostErr schemas.BifrostError
|
|
if err := sonic.UnmarshalString(jsonData, &bifrostErr); err == nil {
|
|
if bifrostErr.Error != nil && bifrostErr.Error.Message != "" {
|
|
ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
|
|
providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, providerUtils.EnrichError(ctx, &bifrostErr, jsonBody, nil, sendBackRawRequest, sendBackRawResponse), responseChan, logger, postHookSpanFinalizer)
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
// Parse into bifrost response
|
|
var response schemas.BifrostChatResponse
|
|
// TODO fix this
|
|
if customResponseHandler != nil {
|
|
rawRequest, rawResponse, handlerErr := customResponseHandler([]byte(jsonData), &response, nil, sendBackRawRequest, sendBackRawResponse)
|
|
if handlerErr != nil {
|
|
if sendBackRawRequest {
|
|
handlerErr.ExtraFields.RawRequest = rawRequest
|
|
}
|
|
if sendBackRawResponse {
|
|
handlerErr.ExtraFields.RawResponse = rawResponse
|
|
}
|
|
ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
|
|
providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, providerUtils.EnrichError(ctx, handlerErr, jsonBody, nil, sendBackRawRequest, sendBackRawResponse), responseChan, logger, postHookSpanFinalizer)
|
|
return
|
|
}
|
|
} else {
|
|
if err := sonic.UnmarshalString(jsonData, &response); err != nil {
|
|
logger.Warn("Failed to parse stream response: %v", err)
|
|
continue
|
|
}
|
|
}
|
|
|
|
// choices be array if nil
|
|
if response.Choices == nil {
|
|
response.Choices = []schemas.BifrostResponseChoice{}
|
|
}
|
|
|
|
if isResponsesToChatCompletionsFallback {
|
|
spreadResponses := response.ToBifrostResponsesStreamResponse(responsesStreamState)
|
|
for _, response := range spreadResponses {
|
|
if response.Type == schemas.ResponsesStreamResponseTypeError {
|
|
bifrostErr := &schemas.BifrostError{
|
|
Type: schemas.Ptr(string(schemas.ResponsesStreamResponseTypeError)),
|
|
IsBifrostError: false,
|
|
Error: &schemas.ErrorField{},
|
|
}
|
|
|
|
if response.Message != nil {
|
|
bifrostErr.Error.Message = *response.Message
|
|
}
|
|
if response.Param != nil {
|
|
bifrostErr.Error.Param = *response.Param
|
|
}
|
|
if response.Code != nil {
|
|
bifrostErr.Error.Code = response.Code
|
|
}
|
|
|
|
ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
|
|
providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, providerUtils.EnrichError(ctx, bifrostErr, jsonBody, nil, sendBackRawRequest, sendBackRawResponse), responseChan, logger, postHookSpanFinalizer)
|
|
return
|
|
}
|
|
|
|
response.ExtraFields.ChunkIndex = response.SequenceNumber
|
|
|
|
if sendBackRawResponse {
|
|
response.ExtraFields.RawResponse = jsonData
|
|
}
|
|
|
|
if response.Type == schemas.ResponsesStreamResponseTypeCompleted || response.Type == schemas.ResponsesStreamResponseTypeIncomplete {
|
|
// Set raw request if enabled
|
|
if sendBackRawRequest {
|
|
providerUtils.ParseAndSetRawRequest(&response.ExtraFields, jsonBody)
|
|
}
|
|
response.ExtraFields.Latency = time.Since(startTime).Milliseconds()
|
|
ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
|
|
providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, response, nil, nil, nil), responseChan, postHookSpanFinalizer)
|
|
return
|
|
}
|
|
|
|
response.ExtraFields.Latency = time.Since(lastChunkTime).Milliseconds()
|
|
lastChunkTime = time.Now()
|
|
|
|
providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, response, nil, nil, nil), responseChan, postHookSpanFinalizer)
|
|
}
|
|
} else {
|
|
if postResponseConverter != nil {
|
|
if converted := postResponseConverter(&response); converted != nil {
|
|
response = *converted
|
|
} else {
|
|
logger.Warn("postResponseConverter returned nil; leaving chunk unmodified")
|
|
}
|
|
}
|
|
|
|
// Handle usage-only chunks (when stream_options include_usage is true)
|
|
if response.Usage != nil {
|
|
// Collect usage information and send at the end of the stream
|
|
// Here in some cases usage comes before final message
|
|
// So we need to check if the response.Usage is nil and then if usage != nil
|
|
// then add up all tokens
|
|
if response.Usage.PromptTokens > usage.PromptTokens {
|
|
usage.PromptTokens = response.Usage.PromptTokens
|
|
}
|
|
if response.Usage.CompletionTokens > usage.CompletionTokens {
|
|
usage.CompletionTokens = response.Usage.CompletionTokens
|
|
}
|
|
if response.Usage.TotalTokens > usage.TotalTokens {
|
|
usage.TotalTokens = response.Usage.TotalTokens
|
|
}
|
|
calculatedTotal := usage.PromptTokens + usage.CompletionTokens
|
|
if calculatedTotal > usage.TotalTokens {
|
|
usage.TotalTokens = calculatedTotal
|
|
}
|
|
if response.Usage.PromptTokensDetails != nil {
|
|
usage.PromptTokensDetails = response.Usage.PromptTokensDetails
|
|
}
|
|
if response.Usage.CompletionTokensDetails != nil {
|
|
usage.CompletionTokensDetails = response.Usage.CompletionTokensDetails
|
|
}
|
|
if response.Usage.Cost != nil {
|
|
usage.Cost = response.Usage.Cost
|
|
}
|
|
response.Usage = nil
|
|
}
|
|
|
|
if response.Model != "" {
|
|
modelName = response.Model
|
|
}
|
|
|
|
// Skip empty responses or responses without choices
|
|
if len(response.Choices) == 0 {
|
|
continue
|
|
}
|
|
|
|
// Handle finish reason, usually in the final chunk
|
|
choice := response.Choices[0]
|
|
if choice.FinishReason != nil && *choice.FinishReason != "" {
|
|
// Collect finish reason and send at the end of the stream
|
|
finishReason = choice.FinishReason
|
|
}
|
|
|
|
if response.ID != "" && messageID == "" {
|
|
messageID = response.ID
|
|
}
|
|
if response.Created != 0 && created == 0 {
|
|
created = response.Created
|
|
}
|
|
|
|
// Handle regular content chunks, including reasoning
|
|
if choice.ChatStreamResponseChoice != nil &&
|
|
choice.ChatStreamResponseChoice.Delta != nil &&
|
|
((choice.ChatStreamResponseChoice.Delta.Content != nil && *choice.ChatStreamResponseChoice.Delta.Content != "") ||
|
|
choice.ChatStreamResponseChoice.Delta.Reasoning != nil ||
|
|
len(choice.ChatStreamResponseChoice.Delta.ReasoningDetails) > 0 ||
|
|
choice.ChatStreamResponseChoice.Delta.Audio != nil ||
|
|
len(choice.ChatStreamResponseChoice.Delta.ToolCalls) > 0) {
|
|
if choice.FinishReason != nil && *choice.FinishReason != "" {
|
|
forwardedTerminalFinishReason = true
|
|
}
|
|
chunkIndex++
|
|
|
|
response.ExtraFields.ChunkIndex = chunkIndex
|
|
response.ExtraFields.Latency = time.Since(lastChunkTime).Milliseconds()
|
|
lastChunkTime = time.Now()
|
|
|
|
if sendBackRawResponse {
|
|
response.ExtraFields.RawResponse = jsonData
|
|
}
|
|
|
|
providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, &response, nil, nil, nil, nil), responseChan, postHookSpanFinalizer)
|
|
}
|
|
|
|
// For providers that don't send [DONE] marker break on finish_reason
|
|
if !providerUtils.ProviderSendsDoneMarker(providerName) && finishReason != nil {
|
|
break
|
|
}
|
|
}
|
|
}
|
|
|
|
if !isResponsesToChatCompletionsFallback {
|
|
finalFinishReason := finishReason
|
|
if forwardedTerminalFinishReason {
|
|
finalFinishReason = nil
|
|
}
|
|
response := providerUtils.CreateBifrostChatCompletionChunkResponse(messageID, usage, finalFinishReason, chunkIndex, modelName, created)
|
|
if postResponseConverter != nil {
|
|
response = postResponseConverter(response)
|
|
}
|
|
// Set raw request if enabled
|
|
if sendBackRawRequest {
|
|
providerUtils.ParseAndSetRawRequest(&response.ExtraFields, jsonBody)
|
|
}
|
|
response.ExtraFields.Latency = time.Since(startTime).Milliseconds()
|
|
ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
|
|
providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, response, nil, nil, nil, nil), responseChan, postHookSpanFinalizer)
|
|
}
|
|
}()
|
|
|
|
return responseChan, nil
|
|
}
|
|
|
|
// Responses performs a responses request to the OpenAI API.
|
|
func (provider *OpenAIProvider) Responses(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) {
|
|
// Check if chat completion is allowed for this provider
|
|
if err := providerUtils.CheckOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.ResponsesRequest); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if provider.disableStore {
|
|
if request.Params == nil {
|
|
request.Params = &schemas.ResponsesParameters{}
|
|
}
|
|
request.Params.Store = schemas.Ptr(false)
|
|
}
|
|
|
|
return HandleOpenAIResponsesRequest(
|
|
ctx,
|
|
provider.client,
|
|
provider.buildRequestURL(ctx, "/v1/responses", schemas.ResponsesRequest),
|
|
request,
|
|
key,
|
|
provider.networkConfig.ExtraHeaders,
|
|
providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest),
|
|
providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse),
|
|
provider.GetProviderKey(),
|
|
nil,
|
|
nil,
|
|
provider.logger,
|
|
)
|
|
}
|
|
|
|
// HandleOpenAIResponsesRequest handles a responses request to OpenAI's API.
|
|
func HandleOpenAIResponsesRequest(
|
|
ctx *schemas.BifrostContext,
|
|
client *fasthttp.Client,
|
|
url string,
|
|
request *schemas.BifrostResponsesRequest,
|
|
key schemas.Key,
|
|
extraHeaders map[string]string,
|
|
sendBackRawRequest bool,
|
|
sendBackRawResponse bool,
|
|
providerName schemas.ModelProvider,
|
|
customResponseHandler responseHandler[schemas.BifrostResponsesResponse],
|
|
customErrorConverter ErrorConverter,
|
|
logger schemas.Logger,
|
|
) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) {
|
|
// Create request
|
|
req := fasthttp.AcquireRequest()
|
|
resp := fasthttp.AcquireResponse()
|
|
defer fasthttp.ReleaseRequest(req)
|
|
// resp lifecycle: managed by finalizeOpenAIResponse or released on error paths
|
|
respOwned := true
|
|
defer func() {
|
|
if respOwned {
|
|
fasthttp.ReleaseResponse(resp)
|
|
}
|
|
}()
|
|
activeClient := providerUtils.PrepareResponseStreaming(ctx, client, resp)
|
|
|
|
// Set any extra headers from network config
|
|
providerUtils.SetExtraHeaders(ctx, req, extraHeaders, nil)
|
|
|
|
req.SetRequestURI(url)
|
|
req.Header.SetMethod(http.MethodPost)
|
|
req.Header.SetContentType("application/json")
|
|
|
|
if key.Value.GetValue() != "" {
|
|
req.Header.Set("Authorization", "Bearer "+key.Value.GetValue())
|
|
}
|
|
|
|
// Large payload passthrough: stream body directly without JSON marshaling
|
|
if lpResult, lpErr, handled := handleOpenAILargePayloadPassthrough(ctx, client, url, key, extraHeaders, providerName, logger); handled {
|
|
if lpErr != nil {
|
|
return nil, lpErr
|
|
}
|
|
if len(lpResult.ResponseBody) > 0 {
|
|
response := &schemas.BifrostResponsesResponse{}
|
|
if err := sonic.Unmarshal(lpResult.ResponseBody, response); err != nil {
|
|
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err)
|
|
}
|
|
response.ExtraFields = schemas.BifrostResponseExtraFields{Latency: lpResult.Latency}
|
|
return response, nil
|
|
}
|
|
return &schemas.BifrostResponsesResponse{
|
|
Model: request.Model,
|
|
ExtraFields: schemas.BifrostResponseExtraFields{Latency: lpResult.Latency},
|
|
}, nil
|
|
}
|
|
|
|
// Use centralized converter
|
|
jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody(
|
|
ctx,
|
|
request,
|
|
func() (providerUtils.RequestBodyWithExtraParams, error) {
|
|
return ToOpenAIResponsesRequest(request), nil
|
|
})
|
|
if bifrostErr != nil {
|
|
return nil, bifrostErr
|
|
}
|
|
|
|
req.SetBody(jsonData)
|
|
|
|
// Make request
|
|
latency, bifrostErr, wait := providerUtils.MakeRequestWithContext(ctx, activeClient, req, resp)
|
|
defer wait()
|
|
if bifrostErr != nil {
|
|
return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonData, nil, sendBackRawRequest, sendBackRawResponse)
|
|
}
|
|
// Extract provider response headers early so they're available on error paths too
|
|
providerResponseHeaders := providerUtils.ExtractProviderResponseHeaders(resp)
|
|
ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerResponseHeaders)
|
|
|
|
// Handle error response
|
|
if resp.StatusCode() != fasthttp.StatusOK {
|
|
providerUtils.MaterializeStreamErrorBody(ctx, resp)
|
|
logger.Debug("error from %s provider: %s", providerName, string(resp.Body()))
|
|
if customErrorConverter != nil {
|
|
return nil, providerUtils.EnrichError(ctx, customErrorConverter(resp), jsonData, nil, sendBackRawRequest, sendBackRawResponse)
|
|
}
|
|
return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp), jsonData, nil, sendBackRawRequest, sendBackRawResponse)
|
|
}
|
|
|
|
body, lpResult, finalErr := finalizeOpenAIResponse(ctx, resp, latency, providerName, logger)
|
|
respOwned = false // ownership transferred
|
|
if finalErr != nil {
|
|
return nil, providerUtils.EnrichError(ctx, finalErr, jsonData, nil, sendBackRawRequest, sendBackRawResponse)
|
|
}
|
|
if lpResult != nil {
|
|
return &schemas.BifrostResponsesResponse{
|
|
Model: request.Model,
|
|
ExtraFields: schemas.BifrostResponseExtraFields{Latency: lpResult.Latency},
|
|
}, nil
|
|
}
|
|
|
|
response := &schemas.BifrostResponsesResponse{}
|
|
|
|
var rawRequest, rawResponse interface{}
|
|
|
|
if customResponseHandler != nil {
|
|
rawRequest, rawResponse, bifrostErr = customResponseHandler(body, response, jsonData, sendBackRawRequest, sendBackRawResponse)
|
|
} else {
|
|
rawRequest, rawResponse, bifrostErr = providerUtils.HandleProviderResponse(body, response, jsonData, sendBackRawRequest, sendBackRawResponse)
|
|
}
|
|
|
|
if bifrostErr != nil {
|
|
return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonData, body, sendBackRawRequest, sendBackRawResponse)
|
|
}
|
|
|
|
response.ExtraFields.Latency = latency.Milliseconds()
|
|
response.ExtraFields.ProviderResponseHeaders = providerResponseHeaders
|
|
|
|
// Set raw request if enabled
|
|
if sendBackRawRequest {
|
|
response.ExtraFields.RawRequest = rawRequest
|
|
}
|
|
|
|
// Set raw response if enabled
|
|
if sendBackRawResponse {
|
|
response.ExtraFields.RawResponse = rawResponse
|
|
}
|
|
|
|
return response, nil
|
|
}
|
|
|
|
// ResponsesStream performs a streaming responses request to the OpenAI API.
|
|
func (provider *OpenAIProvider) ResponsesStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
|
|
// Check if chat completion stream is allowed for this provider
|
|
if err := providerUtils.CheckOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.ResponsesStreamRequest); err != nil {
|
|
return nil, err
|
|
}
|
|
var authHeader map[string]string
|
|
if key.Value.GetValue() != "" {
|
|
authHeader = map[string]string{"Authorization": "Bearer " + key.Value.GetValue()}
|
|
}
|
|
if provider.disableStore {
|
|
if request.Params == nil {
|
|
request.Params = &schemas.ResponsesParameters{}
|
|
}
|
|
request.Params.Store = schemas.Ptr(false)
|
|
}
|
|
|
|
// Use shared streaming logic
|
|
return HandleOpenAIResponsesStreaming(
|
|
ctx,
|
|
provider.streamingClient,
|
|
provider.buildRequestURL(ctx, "/v1/responses", schemas.ResponsesStreamRequest),
|
|
request,
|
|
authHeader,
|
|
provider.networkConfig.ExtraHeaders,
|
|
providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest),
|
|
providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse),
|
|
provider.GetProviderKey(),
|
|
postHookRunner,
|
|
nil,
|
|
nil,
|
|
nil,
|
|
nil,
|
|
provider.logger,
|
|
postHookSpanFinalizer,
|
|
)
|
|
}
|
|
|
|
// HandleOpenAIResponsesStreaming handles streaming for OpenAI-compatible APIs.
|
|
// This shared function reduces code duplication between providers that use the same SSE format.
|
|
func HandleOpenAIResponsesStreaming(
|
|
ctx *schemas.BifrostContext,
|
|
client *fasthttp.Client,
|
|
url string,
|
|
request *schemas.BifrostResponsesRequest,
|
|
authHeader map[string]string,
|
|
extraHeaders map[string]string,
|
|
sendBackRawRequest bool,
|
|
sendBackRawResponse bool,
|
|
providerName schemas.ModelProvider,
|
|
postHookRunner schemas.PostHookRunner,
|
|
customResponseHandler responseHandler[schemas.BifrostResponsesStreamResponse],
|
|
customErrorConverter ErrorConverter,
|
|
postRequestConverter func(*OpenAIResponsesRequest) *OpenAIResponsesRequest,
|
|
postResponseConverter func(*schemas.BifrostResponsesStreamResponse) *schemas.BifrostResponsesStreamResponse,
|
|
logger schemas.Logger,
|
|
postHookSpanFinalizer func(context.Context),
|
|
) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
|
|
// Prepare SGL headers (SGL typically doesn't require authorization, but we include it if provided)
|
|
headers := map[string]string{
|
|
"Content-Type": "application/json",
|
|
"Accept": "text/event-stream",
|
|
"Cache-Control": "no-cache",
|
|
}
|
|
|
|
if authHeader != nil {
|
|
// Copy auth header to headers
|
|
maps.Copy(headers, authHeader)
|
|
}
|
|
|
|
jsonBody, bifrostErr := providerUtils.CheckContextAndGetRequestBody(
|
|
ctx,
|
|
request,
|
|
func() (providerUtils.RequestBodyWithExtraParams, error) {
|
|
reqBody := ToOpenAIResponsesRequest(request)
|
|
if reqBody != nil {
|
|
reqBody.Stream = schemas.Ptr(true)
|
|
if postRequestConverter != nil {
|
|
reqBody = postRequestConverter(reqBody)
|
|
}
|
|
}
|
|
return reqBody, nil
|
|
})
|
|
if bifrostErr != nil {
|
|
return nil, bifrostErr
|
|
}
|
|
|
|
// Create HTTP request for streaming
|
|
req := fasthttp.AcquireRequest()
|
|
resp := fasthttp.AcquireResponse()
|
|
resp.StreamBody = true
|
|
defer fasthttp.ReleaseRequest(req)
|
|
|
|
req.Header.SetMethod(http.MethodPost)
|
|
req.SetRequestURI(url)
|
|
req.Header.SetContentType("application/json")
|
|
|
|
// Set any extra headers from network config
|
|
providerUtils.SetExtraHeaders(ctx, req, extraHeaders, nil)
|
|
|
|
// Set headers
|
|
for key, value := range headers {
|
|
req.Header.Set(key, value)
|
|
}
|
|
|
|
setStreamingRequestBody(ctx, req, jsonBody, providerName)
|
|
|
|
// Use streaming-aware client when large payload optimization is active — ensures
|
|
// MaxResponseBodySize > 0 so ErrBodyTooLarge triggers StreamBody for Content-Length responses.
|
|
activeClient := providerUtils.PrepareResponseStreaming(ctx, client, resp)
|
|
|
|
// Make the request
|
|
err := activeClient.Do(req, resp)
|
|
if err != nil {
|
|
defer providerUtils.ReleaseStreamingResponse(resp)
|
|
if errors.Is(err, context.Canceled) {
|
|
return nil, providerUtils.EnrichError(ctx, &schemas.BifrostError{
|
|
IsBifrostError: false,
|
|
Error: &schemas.ErrorField{
|
|
Type: schemas.Ptr(schemas.RequestCancelled),
|
|
Message: schemas.ErrRequestCancelled,
|
|
Error: err,
|
|
},
|
|
}, jsonBody, nil, sendBackRawRequest, sendBackRawResponse)
|
|
}
|
|
if errors.Is(err, fasthttp.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) {
|
|
return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err), jsonBody, nil, sendBackRawRequest, sendBackRawResponse)
|
|
}
|
|
return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err), jsonBody, nil, sendBackRawRequest, sendBackRawResponse)
|
|
}
|
|
|
|
// Store provider response headers in context before status check so error responses also forward them
|
|
ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerUtils.ExtractProviderResponseHeaders(resp))
|
|
|
|
// Check for HTTP errors
|
|
if resp.StatusCode() != fasthttp.StatusOK {
|
|
defer providerUtils.ReleaseStreamingResponse(resp)
|
|
providerUtils.MaterializeStreamErrorBody(ctx, resp)
|
|
if customErrorConverter != nil {
|
|
return nil, providerUtils.EnrichError(ctx, customErrorConverter(resp), jsonBody, nil, sendBackRawRequest, sendBackRawResponse)
|
|
}
|
|
return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp), jsonBody, nil, sendBackRawRequest, sendBackRawResponse)
|
|
}
|
|
|
|
// Large payload streaming passthrough — pipe raw upstream SSE to client
|
|
if providerUtils.SetupStreamingPassthrough(ctx, resp) {
|
|
responseChan := make(chan *schemas.BifrostStreamChunk)
|
|
close(responseChan)
|
|
return responseChan, nil
|
|
}
|
|
|
|
// Create response channel
|
|
responseChan := make(chan *schemas.BifrostStreamChunk, schemas.DefaultStreamBufferSize)
|
|
|
|
// Start streaming in a goroutine
|
|
go func() {
|
|
defer providerUtils.EnsureStreamFinalizerCalled(ctx, postHookSpanFinalizer)
|
|
defer func() {
|
|
if ctx.Err() == context.Canceled {
|
|
providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, logger, postHookSpanFinalizer)
|
|
} else if ctx.Err() == context.DeadlineExceeded {
|
|
providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, logger, postHookSpanFinalizer)
|
|
}
|
|
close(responseChan)
|
|
}()
|
|
defer providerUtils.ReleaseStreamingResponse(resp)
|
|
// Decompress gzip-encoded streams transparently (no-op for non-gzip)
|
|
reader, releaseGzip := providerUtils.DecompressStreamBody(resp)
|
|
defer releaseGzip()
|
|
|
|
// Wrap reader with idle timeout to detect stalled streams.
|
|
reader, stopIdleTimeout := providerUtils.NewIdleTimeoutReader(reader, resp.BodyStream(), providerUtils.GetStreamIdleTimeout(ctx))
|
|
defer stopIdleTimeout()
|
|
|
|
// Setup cancellation handler to close the raw network stream on ctx cancellation,
|
|
// which immediately unblocks any in-progress read (including reads blocked inside a gzip decompression layer).
|
|
stopCancellation := providerUtils.SetupStreamCancellation(ctx, resp.BodyStream(), logger)
|
|
defer stopCancellation()
|
|
|
|
// Skip scanner for non-SSE responses — avoids bufio.Scanner buffer bloat
|
|
// on non-line-delimited data (e.g. provider returned JSON instead of SSE).
|
|
if providerUtils.DrainNonSSEStreamResponse(resp) {
|
|
ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
|
|
providerUtils.ProcessAndSendError(ctx, postHookRunner, errors.New("provider returned non-SSE response for streaming request"), responseChan, logger, postHookSpanFinalizer)
|
|
return
|
|
}
|
|
|
|
sseReader := providerUtils.GetSSEDataReader(ctx, reader)
|
|
|
|
startTime := time.Now()
|
|
lastChunkTime := startTime
|
|
|
|
for {
|
|
// If context was cancelled/timed out, let defer handle it
|
|
if ctx.Err() != nil {
|
|
return
|
|
}
|
|
data, readErr := sseReader.ReadDataLine()
|
|
if readErr != nil {
|
|
if readErr != io.EOF {
|
|
if ctx.Err() != nil {
|
|
return
|
|
}
|
|
ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
|
|
logger.Warn("Error reading stream: %v", readErr)
|
|
providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, logger, postHookSpanFinalizer)
|
|
}
|
|
break
|
|
}
|
|
jsonData := string(data)
|
|
|
|
// Parse into bifrost response
|
|
var response schemas.BifrostResponsesStreamResponse
|
|
// TODO fix this
|
|
if customResponseHandler != nil {
|
|
rawRequest, rawResponse, bifrostErr := customResponseHandler([]byte(jsonData), &response, nil, false, false)
|
|
if bifrostErr != nil {
|
|
if sendBackRawRequest {
|
|
bifrostErr.ExtraFields.RawRequest = rawRequest
|
|
}
|
|
if sendBackRawResponse {
|
|
bifrostErr.ExtraFields.RawResponse = rawResponse
|
|
}
|
|
ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
|
|
providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, providerUtils.EnrichError(ctx, bifrostErr, jsonBody, nil, sendBackRawRequest, sendBackRawResponse), responseChan, logger, postHookSpanFinalizer)
|
|
return
|
|
}
|
|
} else {
|
|
if err := sonic.UnmarshalString(jsonData, &response); err != nil {
|
|
logger.Warn("Failed to parse stream response: %v", err)
|
|
continue
|
|
}
|
|
|
|
if postResponseConverter != nil {
|
|
if converted := postResponseConverter(&response); converted != nil {
|
|
response = *converted
|
|
} else {
|
|
logger.Warn("postResponseConverter returned nil; leaving chunk unmodified")
|
|
}
|
|
}
|
|
|
|
if sendBackRawResponse {
|
|
response.ExtraFields.RawResponse = jsonData
|
|
}
|
|
|
|
if response.Type == schemas.ResponsesStreamResponseTypeError {
|
|
bifrostErr := &schemas.BifrostError{
|
|
Type: schemas.Ptr(string(schemas.ResponsesStreamResponseTypeError)),
|
|
IsBifrostError: false,
|
|
Error: &schemas.ErrorField{},
|
|
}
|
|
|
|
if response.Message != nil {
|
|
bifrostErr.Error.Message = *response.Message
|
|
}
|
|
if response.Param != nil {
|
|
bifrostErr.Error.Param = *response.Param
|
|
}
|
|
if response.Code != nil {
|
|
bifrostErr.Error.Code = response.Code
|
|
}
|
|
if response.Response != nil && response.Response.Error != nil {
|
|
if response.Response.Error.Message != "" && bifrostErr.Error.Message == "" {
|
|
bifrostErr.Error.Message = response.Response.Error.Message
|
|
}
|
|
if response.Response.Error.Code != "" && (bifrostErr.Error.Code == nil || *bifrostErr.Error.Code == "") {
|
|
bifrostErr.Error.Code = schemas.Ptr(response.Response.Error.Code)
|
|
}
|
|
}
|
|
|
|
ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
|
|
providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, providerUtils.EnrichError(ctx, bifrostErr, jsonBody, []byte(jsonData), sendBackRawRequest, sendBackRawResponse), responseChan, logger, postHookSpanFinalizer)
|
|
return
|
|
}
|
|
|
|
// Some providers (e.g. Fireworks) send response.failed on HTTP 200 streams
|
|
// instead of a pre-stream 4xx. Convert to BifrostError for consistent handling.
|
|
if response.Type == schemas.ResponsesStreamResponseTypeFailed {
|
|
bifrostErr := &schemas.BifrostError{
|
|
Type: schemas.Ptr(string(schemas.ResponsesStreamResponseTypeFailed)),
|
|
IsBifrostError: false,
|
|
Error: &schemas.ErrorField{},
|
|
}
|
|
if response.Response != nil && response.Response.Error != nil {
|
|
bifrostErr.Error.Message = response.Response.Error.Message
|
|
bifrostErr.Error.Code = &response.Response.Error.Code
|
|
}
|
|
ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
|
|
providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, providerUtils.EnrichError(ctx, bifrostErr, jsonBody, []byte(jsonData), sendBackRawRequest, sendBackRawResponse), responseChan, logger, postHookSpanFinalizer)
|
|
return
|
|
}
|
|
|
|
response.ExtraFields.ChunkIndex = response.SequenceNumber
|
|
if response.Type == schemas.ResponsesStreamResponseTypeCompleted || response.Type == schemas.ResponsesStreamResponseTypeIncomplete {
|
|
// Set raw request if enabled
|
|
if sendBackRawRequest {
|
|
providerUtils.ParseAndSetRawRequest(&response.ExtraFields, jsonBody)
|
|
}
|
|
response.ExtraFields.Latency = time.Since(startTime).Milliseconds()
|
|
ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
|
|
providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, &response, nil, nil, nil), responseChan, postHookSpanFinalizer)
|
|
return
|
|
}
|
|
|
|
response.ExtraFields.Latency = time.Since(lastChunkTime).Milliseconds()
|
|
lastChunkTime = time.Now()
|
|
|
|
providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, &response, nil, nil, nil), responseChan, postHookSpanFinalizer)
|
|
}
|
|
}
|
|
}()
|
|
|
|
return responseChan, nil
|
|
}
|
|
|
|
// Embedding generates embeddings for the given input text(s).
|
|
// The input can be either a single string or a slice of strings for batch embedding.
|
|
// Returns a BifrostResponse containing the embedding(s) and any error that occurred.
|
|
func (provider *OpenAIProvider) Embedding(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostEmbeddingRequest) (*schemas.BifrostEmbeddingResponse, *schemas.BifrostError) {
|
|
// Check if embedding is allowed for this provider
|
|
if err := providerUtils.CheckOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.EmbeddingRequest); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Use the shared embedding request handler
|
|
return HandleOpenAIEmbeddingRequest(
|
|
ctx,
|
|
provider.client,
|
|
provider.buildRequestURL(ctx, "/v1/embeddings", schemas.EmbeddingRequest),
|
|
request,
|
|
key,
|
|
provider.networkConfig.ExtraHeaders,
|
|
provider.GetProviderKey(),
|
|
providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest),
|
|
providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse),
|
|
nil,
|
|
provider.logger,
|
|
)
|
|
}
|
|
|
|
// HandleOpenAIEmbeddingRequest handles embedding requests for OpenAI-compatible APIs.
|
|
// This shared function reduces code duplication between providers that use the same embedding request format.
|
|
func HandleOpenAIEmbeddingRequest(
|
|
ctx *schemas.BifrostContext,
|
|
client *fasthttp.Client,
|
|
url string,
|
|
request *schemas.BifrostEmbeddingRequest,
|
|
key schemas.Key,
|
|
extraHeaders map[string]string,
|
|
providerName schemas.ModelProvider,
|
|
sendBackRawRequest bool,
|
|
sendBackRawResponse bool,
|
|
customResponseHandler responseHandler[schemas.BifrostEmbeddingResponse],
|
|
logger schemas.Logger,
|
|
) (*schemas.BifrostEmbeddingResponse, *schemas.BifrostError) {
|
|
// Create request
|
|
req := fasthttp.AcquireRequest()
|
|
resp := fasthttp.AcquireResponse()
|
|
defer fasthttp.ReleaseRequest(req)
|
|
// resp lifecycle: managed by finalizeOpenAIResponse or released on error paths
|
|
respOwned := true
|
|
defer func() {
|
|
if respOwned {
|
|
fasthttp.ReleaseResponse(resp)
|
|
}
|
|
}()
|
|
activeClient := providerUtils.PrepareResponseStreaming(ctx, client, resp)
|
|
|
|
// Set any extra headers from network config
|
|
providerUtils.SetExtraHeaders(ctx, req, extraHeaders, nil)
|
|
|
|
req.SetRequestURI(url)
|
|
req.Header.SetMethod(http.MethodPost)
|
|
req.Header.SetContentType("application/json")
|
|
|
|
if key.Value.GetValue() != "" {
|
|
req.Header.Set("Authorization", "Bearer "+key.Value.GetValue())
|
|
}
|
|
|
|
// Large payload passthrough: stream body directly without JSON marshaling
|
|
if lpResult, lpErr, handled := handleOpenAILargePayloadPassthrough(ctx, client, url, key, extraHeaders, providerName, logger); handled {
|
|
if lpErr != nil {
|
|
return nil, lpErr
|
|
}
|
|
if len(lpResult.ResponseBody) > 0 {
|
|
response := &schemas.BifrostEmbeddingResponse{}
|
|
if err := sonic.Unmarshal(lpResult.ResponseBody, response); err != nil {
|
|
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err)
|
|
}
|
|
response.ExtraFields = schemas.BifrostResponseExtraFields{Latency: lpResult.Latency}
|
|
return response, nil
|
|
}
|
|
return &schemas.BifrostEmbeddingResponse{
|
|
Model: request.Model,
|
|
Usage: lpResult.Usage,
|
|
ExtraFields: schemas.BifrostResponseExtraFields{Latency: lpResult.Latency},
|
|
}, nil
|
|
}
|
|
|
|
// Use centralized converter
|
|
jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody(
|
|
ctx,
|
|
request,
|
|
func() (providerUtils.RequestBodyWithExtraParams, error) {
|
|
return ToOpenAIEmbeddingRequest(request), nil
|
|
})
|
|
if bifrostErr != nil {
|
|
return nil, bifrostErr
|
|
}
|
|
|
|
req.SetBody(jsonData)
|
|
|
|
// Make request
|
|
latency, bifrostErr, wait := providerUtils.MakeRequestWithContext(ctx, activeClient, req, resp)
|
|
defer wait()
|
|
if bifrostErr != nil {
|
|
return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonData, nil, sendBackRawRequest, sendBackRawResponse)
|
|
}
|
|
// Extract provider response headers early so they're available on error paths too
|
|
providerResponseHeaders := providerUtils.ExtractProviderResponseHeaders(resp)
|
|
ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerResponseHeaders)
|
|
|
|
// Handle error response
|
|
if resp.StatusCode() != fasthttp.StatusOK {
|
|
providerUtils.MaterializeStreamErrorBody(ctx, resp)
|
|
logger.Debug(fmt.Sprintf("error from %s provider: %s", providerName, string(resp.Body())))
|
|
return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp), jsonData, nil, sendBackRawRequest, sendBackRawResponse)
|
|
}
|
|
|
|
body, lpResult, finalErr := finalizeOpenAIResponse(ctx, resp, latency, providerName, logger)
|
|
respOwned = false // ownership transferred
|
|
if finalErr != nil {
|
|
return nil, providerUtils.EnrichError(ctx, finalErr, jsonData, nil, sendBackRawRequest, sendBackRawResponse)
|
|
}
|
|
if lpResult != nil {
|
|
return &schemas.BifrostEmbeddingResponse{
|
|
Model: request.Model,
|
|
Usage: lpResult.Usage,
|
|
ExtraFields: schemas.BifrostResponseExtraFields{Latency: lpResult.Latency},
|
|
}, nil
|
|
}
|
|
|
|
response := &schemas.BifrostEmbeddingResponse{}
|
|
|
|
var rawRequest, rawResponse interface{}
|
|
|
|
if customResponseHandler != nil {
|
|
rawRequest, rawResponse, bifrostErr = customResponseHandler(body, response, jsonData, sendBackRawRequest, sendBackRawResponse)
|
|
} else {
|
|
rawRequest, rawResponse, bifrostErr = providerUtils.HandleProviderResponse(body, response, jsonData, sendBackRawRequest, sendBackRawResponse)
|
|
}
|
|
|
|
if bifrostErr != nil {
|
|
return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonData, body, sendBackRawRequest, sendBackRawResponse)
|
|
}
|
|
|
|
response.ExtraFields.Latency = latency.Milliseconds()
|
|
response.ExtraFields.ProviderResponseHeaders = providerResponseHeaders
|
|
|
|
// Set raw request if enabled
|
|
if sendBackRawRequest {
|
|
response.ExtraFields.RawRequest = rawRequest
|
|
}
|
|
|
|
// Set raw response if enabled
|
|
if sendBackRawResponse {
|
|
response.ExtraFields.RawResponse = rawResponse
|
|
}
|
|
|
|
return response, nil
|
|
}
|
|
|
|
// Speech handles non-streaming speech synthesis requests.
|
|
// It formats the request body, makes the API call, and returns the response.
|
|
// Returns the response and any error that occurred.
|
|
func (provider *OpenAIProvider) Speech(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostSpeechRequest) (*schemas.BifrostSpeechResponse, *schemas.BifrostError) {
|
|
if err := providerUtils.CheckOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.SpeechRequest); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return HandleOpenAISpeechRequest(
|
|
ctx,
|
|
provider.client,
|
|
provider.buildRequestURL(ctx, "/v1/audio/speech", schemas.SpeechRequest),
|
|
request,
|
|
key,
|
|
provider.networkConfig.ExtraHeaders,
|
|
provider.GetProviderKey(),
|
|
providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest),
|
|
providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse),
|
|
nil,
|
|
provider.logger,
|
|
)
|
|
}
|
|
|
|
// HandleOpenAISpeechRequest handles speech requests for OpenAI-compatible APIs.
|
|
// This shared function reduces code duplication between providers that use the same speech request format.
|
|
func HandleOpenAISpeechRequest(
|
|
ctx *schemas.BifrostContext,
|
|
client *fasthttp.Client,
|
|
url string,
|
|
request *schemas.BifrostSpeechRequest,
|
|
key schemas.Key,
|
|
extraHeaders map[string]string,
|
|
providerName schemas.ModelProvider,
|
|
sendBackRawRequest bool,
|
|
sendBackRawResponse bool,
|
|
customResponseHandler responseHandler[schemas.BifrostSpeechResponse],
|
|
logger schemas.Logger,
|
|
) (*schemas.BifrostSpeechResponse, *schemas.BifrostError) {
|
|
// Create request
|
|
req := fasthttp.AcquireRequest()
|
|
resp := fasthttp.AcquireResponse()
|
|
defer fasthttp.ReleaseRequest(req)
|
|
// resp lifecycle: managed by finalizeOpenAIResponse or released on error paths
|
|
respOwned := true
|
|
defer func() {
|
|
if respOwned {
|
|
fasthttp.ReleaseResponse(resp)
|
|
}
|
|
}()
|
|
activeClient := providerUtils.PrepareResponseStreaming(ctx, client, resp)
|
|
|
|
// Set any extra headers from network config
|
|
providerUtils.SetExtraHeaders(ctx, req, extraHeaders, nil)
|
|
|
|
req.SetRequestURI(url)
|
|
req.Header.SetMethod(http.MethodPost)
|
|
req.Header.SetContentType("application/json")
|
|
if key.Value.GetValue() != "" {
|
|
req.Header.Set("Authorization", "Bearer "+key.Value.GetValue())
|
|
}
|
|
|
|
// Large payload passthrough: stream body directly without JSON marshaling
|
|
if lpResult, lpErr, handled := handleOpenAILargePayloadPassthrough(ctx, client, url, key, extraHeaders, providerName, logger); handled {
|
|
if lpErr != nil {
|
|
return nil, lpErr
|
|
}
|
|
// Speech response is raw audio bytes (MP3/WAV), not JSON
|
|
return &schemas.BifrostSpeechResponse{
|
|
Audio: lpResult.ResponseBody,
|
|
ExtraFields: schemas.BifrostResponseExtraFields{Latency: lpResult.Latency},
|
|
}, nil
|
|
}
|
|
|
|
jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody(
|
|
ctx,
|
|
request,
|
|
func() (providerUtils.RequestBodyWithExtraParams, error) { return ToOpenAISpeechRequest(request), nil })
|
|
if bifrostErr != nil {
|
|
return nil, bifrostErr
|
|
}
|
|
|
|
req.SetBody(jsonData)
|
|
|
|
// Make request
|
|
latency, bifrostErr, wait := providerUtils.MakeRequestWithContext(ctx, activeClient, req, resp)
|
|
defer wait()
|
|
if bifrostErr != nil {
|
|
return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonData, nil, sendBackRawRequest, sendBackRawResponse)
|
|
}
|
|
// Extract provider response headers early so they're available on error paths too
|
|
providerResponseHeaders := providerUtils.ExtractProviderResponseHeaders(resp)
|
|
ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerResponseHeaders)
|
|
|
|
// Handle error response
|
|
if resp.StatusCode() != fasthttp.StatusOK {
|
|
providerUtils.MaterializeStreamErrorBody(ctx, resp)
|
|
logger.Debug(fmt.Sprintf("error from %s provider: %s", providerName, string(resp.Body())))
|
|
return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp), jsonData, nil, sendBackRawRequest, sendBackRawResponse)
|
|
}
|
|
|
|
// Get the binary audio data from the response body
|
|
body, lpResult, finalErr := finalizeOpenAIResponse(ctx, resp, latency, providerName, logger)
|
|
respOwned = false // ownership transferred
|
|
if finalErr != nil {
|
|
return nil, providerUtils.EnrichError(ctx, finalErr, jsonData, nil, sendBackRawRequest, sendBackRawResponse)
|
|
}
|
|
if lpResult != nil {
|
|
return &schemas.BifrostSpeechResponse{
|
|
ExtraFields: schemas.BifrostResponseExtraFields{Latency: lpResult.Latency},
|
|
}, nil
|
|
}
|
|
|
|
// Create final response with the audio data
|
|
// Note: For speech synthesis, we return the binary audio data in the raw response
|
|
// The audio data is typically in MP3, WAV, or other audio formats as specified by response_format
|
|
bifrostResponse := &schemas.BifrostSpeechResponse{
|
|
Audio: body,
|
|
ExtraFields: schemas.BifrostResponseExtraFields{
|
|
Latency: latency.Milliseconds(),
|
|
ProviderResponseHeaders: providerResponseHeaders,
|
|
},
|
|
}
|
|
|
|
if sendBackRawRequest {
|
|
providerUtils.ParseAndSetRawRequest(&bifrostResponse.ExtraFields, jsonData)
|
|
}
|
|
|
|
return bifrostResponse, nil
|
|
}
|
|
|
|
// SpeechStream handles streaming for speech synthesis.
|
|
// It formats the request body, creates HTTP request, and uses shared streaming logic.
|
|
// Returns a channel for streaming responses and any error that occurred.
|
|
func (provider *OpenAIProvider) SpeechStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostSpeechRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
|
|
if err := providerUtils.CheckOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.SpeechStreamRequest); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
for _, model := range providerUtils.UnsupportedSpeechStreamModels {
|
|
if model == request.Model {
|
|
return nil, providerUtils.NewBifrostOperationError(fmt.Sprintf("model %s is not supported for streaming speech synthesis", model), nil)
|
|
}
|
|
}
|
|
|
|
var authHeader map[string]string
|
|
if key.Value.GetValue() != "" {
|
|
authHeader = map[string]string{"Authorization": "Bearer " + key.Value.GetValue()}
|
|
}
|
|
|
|
return HandleOpenAISpeechStreamRequest(
|
|
ctx,
|
|
provider.streamingClient,
|
|
provider.buildRequestURL(ctx, "/v1/audio/speech", schemas.SpeechStreamRequest),
|
|
request,
|
|
authHeader,
|
|
provider.networkConfig.ExtraHeaders,
|
|
providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest),
|
|
providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse),
|
|
provider.GetProviderKey(),
|
|
postHookRunner,
|
|
nil,
|
|
nil,
|
|
provider.logger,
|
|
postHookSpanFinalizer,
|
|
)
|
|
}
|
|
|
|
// HandleOpenAISpeechStreamRequest handles speech stream requests for OpenAI-compatible APIs.
|
|
// This shared function reduces code duplication between providers that use the same speech stream request format.
|
|
func HandleOpenAISpeechStreamRequest(
|
|
ctx *schemas.BifrostContext,
|
|
client *fasthttp.Client,
|
|
url string,
|
|
request *schemas.BifrostSpeechRequest,
|
|
authHeader map[string]string,
|
|
extraHeaders map[string]string,
|
|
sendBackRawRequest bool,
|
|
sendBackRawResponse bool,
|
|
providerName schemas.ModelProvider,
|
|
postHookRunner schemas.PostHookRunner,
|
|
postRequestConverter func(*OpenAISpeechRequest) *OpenAISpeechRequest,
|
|
postResponseConverter func(*schemas.BifrostSpeechStreamResponse) *schemas.BifrostSpeechStreamResponse,
|
|
logger schemas.Logger,
|
|
postHookSpanFinalizer func(context.Context),
|
|
) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
|
|
// Create HTTP request for streaming
|
|
req := fasthttp.AcquireRequest()
|
|
resp := fasthttp.AcquireResponse()
|
|
resp.StreamBody = true
|
|
defer fasthttp.ReleaseRequest(req)
|
|
|
|
// Prepare OpenAI headers
|
|
headers := map[string]string{
|
|
"Content-Type": "application/json",
|
|
"Accept": "text/event-stream",
|
|
"Cache-Control": "no-cache",
|
|
}
|
|
|
|
if authHeader != nil {
|
|
maps.Copy(headers, authHeader)
|
|
}
|
|
|
|
req.Header.SetMethod(http.MethodPost)
|
|
req.SetRequestURI(url)
|
|
req.Header.SetContentType("application/json")
|
|
|
|
providerUtils.SetExtraHeaders(ctx, req, extraHeaders, nil)
|
|
|
|
// Set any extra headers from network config
|
|
// Set headers
|
|
for key, value := range headers {
|
|
req.Header.Set(key, value)
|
|
}
|
|
|
|
// Use centralized converter
|
|
jsonBody, bifrostErr := providerUtils.CheckContextAndGetRequestBody(
|
|
ctx,
|
|
request,
|
|
func() (providerUtils.RequestBodyWithExtraParams, error) {
|
|
reqBody := ToOpenAISpeechRequest(request)
|
|
if reqBody != nil {
|
|
reqBody.StreamFormat = schemas.Ptr("sse")
|
|
if postRequestConverter != nil {
|
|
reqBody = postRequestConverter(reqBody)
|
|
}
|
|
}
|
|
return reqBody, nil
|
|
})
|
|
if bifrostErr != nil {
|
|
return nil, bifrostErr
|
|
}
|
|
|
|
setStreamingRequestBody(ctx, req, jsonBody, providerName)
|
|
|
|
// Use streaming-aware client when large payload optimization is active — ensures
|
|
// MaxResponseBodySize > 0 so ErrBodyTooLarge triggers StreamBody for Content-Length responses.
|
|
activeClient := providerUtils.PrepareResponseStreaming(ctx, client, resp)
|
|
|
|
// Make the request
|
|
err := activeClient.Do(req, resp)
|
|
if err != nil {
|
|
defer providerUtils.ReleaseStreamingResponse(resp)
|
|
if errors.Is(err, context.Canceled) {
|
|
return nil, providerUtils.EnrichError(ctx, &schemas.BifrostError{
|
|
IsBifrostError: false,
|
|
Error: &schemas.ErrorField{
|
|
Type: schemas.Ptr(schemas.RequestCancelled),
|
|
Message: schemas.ErrRequestCancelled,
|
|
Error: err,
|
|
},
|
|
}, jsonBody, nil, sendBackRawRequest, sendBackRawResponse)
|
|
}
|
|
if errors.Is(err, fasthttp.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) {
|
|
return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err), jsonBody, nil, sendBackRawRequest, sendBackRawResponse)
|
|
}
|
|
return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err), jsonBody, nil, sendBackRawRequest, sendBackRawResponse)
|
|
}
|
|
|
|
// Store provider response headers in context before status check so error responses also forward them
|
|
ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerUtils.ExtractProviderResponseHeaders(resp))
|
|
|
|
// Check for HTTP errors
|
|
if resp.StatusCode() != fasthttp.StatusOK {
|
|
defer providerUtils.ReleaseStreamingResponse(resp)
|
|
providerUtils.MaterializeStreamErrorBody(ctx, resp)
|
|
return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp), jsonBody, nil, sendBackRawRequest, sendBackRawResponse)
|
|
}
|
|
|
|
// Large payload streaming passthrough — pipe raw upstream SSE to client
|
|
if providerUtils.SetupStreamingPassthrough(ctx, resp) {
|
|
responseChan := make(chan *schemas.BifrostStreamChunk)
|
|
close(responseChan)
|
|
return responseChan, nil
|
|
}
|
|
|
|
// Create response channel
|
|
responseChan := make(chan *schemas.BifrostStreamChunk, schemas.DefaultStreamBufferSize)
|
|
|
|
// Start streaming in a goroutine
|
|
go func() {
|
|
defer providerUtils.EnsureStreamFinalizerCalled(ctx, postHookSpanFinalizer)
|
|
defer func() {
|
|
if ctx.Err() == context.Canceled {
|
|
providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, logger, postHookSpanFinalizer)
|
|
} else if ctx.Err() == context.DeadlineExceeded {
|
|
providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, logger, postHookSpanFinalizer)
|
|
}
|
|
close(responseChan)
|
|
}()
|
|
defer providerUtils.ReleaseStreamingResponse(resp)
|
|
// Decompress gzip-encoded streams transparently (no-op for non-gzip)
|
|
reader, releaseGzip := providerUtils.DecompressStreamBody(resp)
|
|
defer releaseGzip()
|
|
|
|
// Wrap reader with idle timeout to detect stalled streams.
|
|
reader, stopIdleTimeout := providerUtils.NewIdleTimeoutReader(reader, resp.BodyStream(), providerUtils.GetStreamIdleTimeout(ctx))
|
|
defer stopIdleTimeout()
|
|
|
|
// Setup cancellation handler to close the raw network stream on ctx cancellation,
|
|
// which immediately unblocks any in-progress read (including reads blocked inside a gzip decompression layer).
|
|
stopCancellation := providerUtils.SetupStreamCancellation(ctx, resp.BodyStream(), logger)
|
|
defer stopCancellation()
|
|
|
|
// Skip scanner for non-SSE responses — avoids bufio.Scanner buffer bloat
|
|
// on non-line-delimited data (e.g. provider returned JSON instead of SSE).
|
|
if providerUtils.DrainNonSSEStreamResponse(resp) {
|
|
ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
|
|
providerUtils.ProcessAndSendError(ctx, postHookRunner, errors.New("provider returned non-SSE response for streaming request"), responseChan, logger, postHookSpanFinalizer)
|
|
return
|
|
}
|
|
|
|
sseReader := providerUtils.GetSSEDataReader(ctx, reader)
|
|
chunkIndex := -1
|
|
|
|
startTime := time.Now()
|
|
lastChunkTime := startTime
|
|
|
|
for {
|
|
// If context was cancelled/timed out, let defer handle it
|
|
if ctx.Err() != nil {
|
|
return
|
|
}
|
|
|
|
data, readErr := sseReader.ReadDataLine()
|
|
if readErr != nil {
|
|
if readErr != io.EOF {
|
|
if ctx.Err() != nil {
|
|
return
|
|
}
|
|
ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
|
|
logger.Warn("Error reading stream: %v", readErr)
|
|
providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, logger, postHookSpanFinalizer)
|
|
}
|
|
break
|
|
}
|
|
jsonData := string(data)
|
|
|
|
// Quick check for error field (allocation-free using sonic.GetFromString)
|
|
if errorNode, _ := sonic.GetFromString(jsonData, "error"); errorNode.Exists() {
|
|
// Only unmarshal when we know there's an error
|
|
var bifrostErr schemas.BifrostError
|
|
if err := sonic.UnmarshalString(jsonData, &bifrostErr); err == nil {
|
|
if bifrostErr.Error != nil && bifrostErr.Error.Message != "" {
|
|
ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
|
|
providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, providerUtils.EnrichError(ctx, &bifrostErr, jsonBody, nil, sendBackRawRequest, sendBackRawResponse), responseChan, logger, postHookSpanFinalizer)
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
// Parse into bifrost response
|
|
var response schemas.BifrostSpeechStreamResponse
|
|
if err := sonic.UnmarshalString(jsonData, &response); err != nil {
|
|
logger.Warn("Failed to parse stream response: %v", err)
|
|
continue
|
|
}
|
|
|
|
if postResponseConverter != nil {
|
|
if converted := postResponseConverter(&response); converted != nil {
|
|
response = *converted
|
|
} else {
|
|
logger.Warn("postResponseConverter returned nil; leaving chunk unmodified")
|
|
}
|
|
}
|
|
|
|
chunkIndex++
|
|
|
|
response.ExtraFields = schemas.BifrostResponseExtraFields{
|
|
ChunkIndex: chunkIndex,
|
|
Latency: time.Since(lastChunkTime).Milliseconds(),
|
|
}
|
|
lastChunkTime = time.Now()
|
|
|
|
if sendBackRawResponse {
|
|
response.ExtraFields.RawResponse = jsonData
|
|
}
|
|
|
|
if response.Usage != nil {
|
|
response.ExtraFields.Latency = time.Since(startTime).Milliseconds()
|
|
if sendBackRawRequest {
|
|
providerUtils.ParseAndSetRawRequest(&response.ExtraFields, jsonBody)
|
|
}
|
|
response.BackfillParams(request)
|
|
ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
|
|
providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, nil, &response, nil, nil), responseChan, postHookSpanFinalizer)
|
|
return
|
|
}
|
|
|
|
providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, nil, &response, nil, nil), responseChan, postHookSpanFinalizer)
|
|
}
|
|
}()
|
|
|
|
return responseChan, nil
|
|
}
|
|
|
|
// Transcription handles non-streaming transcription requests.
|
|
// It creates a multipart form, adds fields, makes the API call, and returns the response.
|
|
// Returns the response and any error that occurred.
|
|
func (provider *OpenAIProvider) Transcription(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (*schemas.BifrostTranscriptionResponse, *schemas.BifrostError) {
|
|
if err := providerUtils.CheckOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.TranscriptionRequest); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return HandleOpenAITranscriptionRequest(
|
|
ctx,
|
|
provider.client,
|
|
provider.buildRequestURL(ctx, "/v1/audio/transcriptions", schemas.TranscriptionRequest),
|
|
request,
|
|
key,
|
|
provider.networkConfig.ExtraHeaders,
|
|
provider.GetProviderKey(),
|
|
providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse),
|
|
nil,
|
|
provider.logger,
|
|
)
|
|
}
|
|
|
|
func HandleOpenAITranscriptionRequest(
|
|
ctx *schemas.BifrostContext,
|
|
client *fasthttp.Client,
|
|
url string,
|
|
request *schemas.BifrostTranscriptionRequest,
|
|
key schemas.Key,
|
|
extraHeaders map[string]string,
|
|
providerName schemas.ModelProvider,
|
|
sendBackRawResponse bool,
|
|
customResponseHandler responseHandler[schemas.BifrostTranscriptionResponse],
|
|
logger schemas.Logger,
|
|
) (*schemas.BifrostTranscriptionResponse, *schemas.BifrostError) {
|
|
// Large payload passthrough: stream multipart body directly without parsing
|
|
if lpResult, lpErr, handled := handleOpenAILargePayloadPassthrough(ctx, client, url, key, extraHeaders, providerName, logger); handled {
|
|
if lpErr != nil {
|
|
return nil, lpErr
|
|
}
|
|
// Unmarshal the upstream response body to preserve transcription text and fields
|
|
if len(lpResult.ResponseBody) > 0 {
|
|
response := &schemas.BifrostTranscriptionResponse{}
|
|
if err := sonic.Unmarshal(lpResult.ResponseBody, response); err != nil {
|
|
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err)
|
|
}
|
|
response.ExtraFields = schemas.BifrostResponseExtraFields{Latency: lpResult.Latency}
|
|
return response, nil
|
|
}
|
|
return &schemas.BifrostTranscriptionResponse{
|
|
ExtraFields: schemas.BifrostResponseExtraFields{Latency: lpResult.Latency},
|
|
}, nil
|
|
}
|
|
|
|
// Create request
|
|
req := fasthttp.AcquireRequest()
|
|
resp := fasthttp.AcquireResponse()
|
|
defer fasthttp.ReleaseRequest(req)
|
|
// resp lifecycle: managed by finalizeOpenAIResponse or released on error paths
|
|
respOwned := true
|
|
defer func() {
|
|
if respOwned {
|
|
fasthttp.ReleaseResponse(resp)
|
|
}
|
|
}()
|
|
activeClient := providerUtils.PrepareResponseStreaming(ctx, client, resp)
|
|
|
|
// Set any extra headers from network config
|
|
providerUtils.SetExtraHeaders(ctx, req, extraHeaders, nil)
|
|
|
|
req.SetRequestURI(url)
|
|
req.Header.SetMethod(http.MethodPost)
|
|
if key.Value.GetValue() != "" {
|
|
req.Header.Set("Authorization", "Bearer "+key.Value.GetValue())
|
|
}
|
|
|
|
// Use centralized converter
|
|
reqBody := ToOpenAITranscriptionRequest(request)
|
|
if reqBody == nil {
|
|
return nil, providerUtils.NewBifrostOperationError("transcription input is not provided", nil)
|
|
}
|
|
|
|
// Create multipart form
|
|
var body bytes.Buffer
|
|
writer := multipart.NewWriter(&body)
|
|
if err := ParseTranscriptionFormDataBodyFromRequest(writer, reqBody, providerName); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
req.Header.SetContentType(writer.FormDataContentType()) // This sets multipart/form-data with boundary
|
|
req.SetBody(body.Bytes())
|
|
|
|
// Make request
|
|
latency, bifrostErr, wait := providerUtils.MakeRequestWithContext(ctx, activeClient, req, resp)
|
|
defer wait()
|
|
if bifrostErr != nil {
|
|
return nil, bifrostErr
|
|
}
|
|
// Extract provider response headers early so they're available on error paths too
|
|
providerResponseHeaders := providerUtils.ExtractProviderResponseHeaders(resp)
|
|
ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerResponseHeaders)
|
|
|
|
// Handle error response
|
|
if resp.StatusCode() != fasthttp.StatusOK {
|
|
providerUtils.MaterializeStreamErrorBody(ctx, resp)
|
|
logger.Debug("error from %s provider: %s", providerName, string(resp.Body()))
|
|
return nil, ParseOpenAIError(resp)
|
|
}
|
|
|
|
responseBody, lpResult, finalErr := finalizeOpenAIResponse(ctx, resp, latency, providerName, logger)
|
|
respOwned = false // ownership transferred
|
|
if finalErr != nil {
|
|
return nil, finalErr
|
|
}
|
|
if lpResult != nil {
|
|
return &schemas.BifrostTranscriptionResponse{
|
|
ExtraFields: schemas.BifrostResponseExtraFields{Latency: lpResult.Latency},
|
|
}, nil
|
|
}
|
|
|
|
// Check for empty response
|
|
trimmed := strings.TrimSpace(string(responseBody))
|
|
if len(trimmed) == 0 {
|
|
return nil, &schemas.BifrostError{
|
|
IsBifrostError: true,
|
|
Error: &schemas.ErrorField{
|
|
Message: schemas.ErrProviderResponseEmpty,
|
|
},
|
|
}
|
|
}
|
|
|
|
copiedResponseBody := append([]byte(nil), responseBody...)
|
|
|
|
// Parse OpenAI's transcription response directly into BifrostTranscribe
|
|
response := &schemas.BifrostTranscriptionResponse{}
|
|
var rawResponse interface{}
|
|
if request.Params != nil && schemas.IsPlainTextTranscriptionFormat(request.Params.ResponseFormat) {
|
|
response.Text = string(copiedResponseBody)
|
|
if sendBackRawResponse {
|
|
rawResponse = string(copiedResponseBody)
|
|
}
|
|
} else if customResponseHandler != nil {
|
|
_, rawResponse, bifrostErr = customResponseHandler(copiedResponseBody, response, nil, false, sendBackRawResponse)
|
|
} else {
|
|
if err := sonic.Unmarshal(copiedResponseBody, response); err != nil {
|
|
// Check if it's an HTML response
|
|
if providerUtils.IsHTMLResponse(resp, copiedResponseBody) {
|
|
return nil, &schemas.BifrostError{
|
|
IsBifrostError: false,
|
|
Error: &schemas.ErrorField{
|
|
Message: schemas.ErrProviderResponseHTML,
|
|
Error: errors.New(string(copiedResponseBody)),
|
|
},
|
|
}
|
|
}
|
|
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err)
|
|
}
|
|
|
|
// TODO: add HandleProviderResponse here
|
|
|
|
// Parse raw response for RawResponse field
|
|
if sendBackRawResponse {
|
|
if err := sonic.Unmarshal(copiedResponseBody, &rawResponse); err != nil {
|
|
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRawResponseUnmarshal, err)
|
|
}
|
|
}
|
|
}
|
|
|
|
if bifrostErr != nil {
|
|
return nil, bifrostErr
|
|
}
|
|
|
|
response.ExtraFields = schemas.BifrostResponseExtraFields{
|
|
Latency: latency.Milliseconds(),
|
|
ProviderResponseHeaders: providerResponseHeaders,
|
|
}
|
|
|
|
if sendBackRawResponse {
|
|
response.ExtraFields.RawResponse = rawResponse
|
|
}
|
|
|
|
return response, nil
|
|
}
|
|
|
|
// TranscriptionStream performs a streaming transcription request to the OpenAI API.
|
|
func (provider *OpenAIProvider) TranscriptionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
|
|
if err := providerUtils.CheckOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.TranscriptionStreamRequest); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
var authHeader map[string]string
|
|
if key.Value.GetValue() != "" {
|
|
authHeader = map[string]string{"Authorization": "Bearer " + key.Value.GetValue()}
|
|
}
|
|
|
|
return HandleOpenAITranscriptionStreamRequest(
|
|
ctx,
|
|
provider.streamingClient,
|
|
provider.buildRequestURL(ctx, "/v1/audio/transcriptions", schemas.TranscriptionStreamRequest),
|
|
request,
|
|
authHeader,
|
|
provider.networkConfig.ExtraHeaders,
|
|
providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse),
|
|
false,
|
|
provider.GetProviderKey(),
|
|
postHookRunner,
|
|
nil,
|
|
nil,
|
|
nil,
|
|
provider.logger,
|
|
postHookSpanFinalizer,
|
|
)
|
|
}
|
|
|
|
// HandleOpenAITranscriptionStreamRequest handles transcription stream requests for OpenAI-compatible APIs.
|
|
// This shared function reduces code duplication between providers that use the same transcription stream request format.
|
|
func HandleOpenAITranscriptionStreamRequest(
|
|
ctx *schemas.BifrostContext,
|
|
client *fasthttp.Client,
|
|
url string,
|
|
request *schemas.BifrostTranscriptionRequest,
|
|
authHeader map[string]string,
|
|
extraHeaders map[string]string,
|
|
sendBackRawResponse bool,
|
|
accumulateText bool,
|
|
providerName schemas.ModelProvider,
|
|
postHookRunner schemas.PostHookRunner,
|
|
customResponseHandler responseHandler[schemas.BifrostTranscriptionStreamResponse],
|
|
postRequestConverter func(*OpenAITranscriptionRequest) *OpenAITranscriptionRequest,
|
|
postResponseConverter func(*schemas.BifrostTranscriptionStreamResponse) *schemas.BifrostTranscriptionStreamResponse,
|
|
logger schemas.Logger,
|
|
postHookSpanFinalizer func(context.Context),
|
|
) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
|
|
// Use centralized converter
|
|
reqBody := ToOpenAITranscriptionRequest(request)
|
|
if reqBody == nil {
|
|
return nil, providerUtils.NewBifrostOperationError("transcription input is not provided", nil)
|
|
}
|
|
reqBody.Stream = schemas.Ptr(true)
|
|
if postRequestConverter != nil {
|
|
reqBody = postRequestConverter(reqBody)
|
|
}
|
|
|
|
// Create multipart form
|
|
var body bytes.Buffer
|
|
writer := multipart.NewWriter(&body)
|
|
|
|
if bifrostErr := ParseTranscriptionFormDataBodyFromRequest(writer, reqBody, providerName); bifrostErr != nil {
|
|
return nil, bifrostErr
|
|
}
|
|
|
|
// Prepare OpenAI headers
|
|
headers := map[string]string{
|
|
"Content-Type": writer.FormDataContentType(),
|
|
"Accept": "text/event-stream",
|
|
"Cache-Control": "no-cache",
|
|
}
|
|
|
|
if authHeader != nil {
|
|
maps.Copy(headers, authHeader)
|
|
}
|
|
|
|
// Create HTTP request for streaming
|
|
req := fasthttp.AcquireRequest()
|
|
resp := fasthttp.AcquireResponse()
|
|
resp.StreamBody = true
|
|
defer fasthttp.ReleaseRequest(req)
|
|
|
|
// Set any extra headers from network config
|
|
providerUtils.SetExtraHeaders(ctx, req, extraHeaders, nil)
|
|
|
|
req.Header.SetMethod(http.MethodPost)
|
|
req.SetRequestURI(url)
|
|
req.Header.SetContentType("application/json")
|
|
|
|
// Set headers
|
|
for key, value := range headers {
|
|
req.Header.Set(key, value)
|
|
}
|
|
|
|
req.SetBody(body.Bytes())
|
|
|
|
// Make the request
|
|
err := client.Do(req, resp)
|
|
if err != nil {
|
|
defer providerUtils.ReleaseStreamingResponse(resp)
|
|
if errors.Is(err, context.Canceled) {
|
|
return nil, &schemas.BifrostError{
|
|
IsBifrostError: false,
|
|
Error: &schemas.ErrorField{
|
|
Type: schemas.Ptr(schemas.RequestCancelled),
|
|
Message: schemas.ErrRequestCancelled,
|
|
Error: err,
|
|
},
|
|
}
|
|
}
|
|
if errors.Is(err, fasthttp.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) {
|
|
return nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err)
|
|
}
|
|
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err)
|
|
}
|
|
|
|
// Store provider response headers in context before status check so error responses also forward them
|
|
ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerUtils.ExtractProviderResponseHeaders(resp))
|
|
|
|
// Check for HTTP errors
|
|
if resp.StatusCode() != fasthttp.StatusOK {
|
|
defer providerUtils.ReleaseStreamingResponse(resp)
|
|
providerUtils.MaterializeStreamErrorBody(ctx, resp)
|
|
return nil, ParseOpenAIError(resp)
|
|
}
|
|
|
|
// Large payload streaming passthrough — pipe raw upstream SSE to client
|
|
if providerUtils.SetupStreamingPassthrough(ctx, resp) {
|
|
responseChan := make(chan *schemas.BifrostStreamChunk)
|
|
close(responseChan)
|
|
return responseChan, nil
|
|
}
|
|
|
|
// Create response channel
|
|
responseChan := make(chan *schemas.BifrostStreamChunk, schemas.DefaultStreamBufferSize)
|
|
|
|
// Start streaming in a goroutine
|
|
go func() {
|
|
defer providerUtils.EnsureStreamFinalizerCalled(ctx, postHookSpanFinalizer)
|
|
defer func() {
|
|
if ctx.Err() == context.Canceled {
|
|
providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, logger, postHookSpanFinalizer)
|
|
} else if ctx.Err() == context.DeadlineExceeded {
|
|
providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, logger, postHookSpanFinalizer)
|
|
}
|
|
close(responseChan)
|
|
}()
|
|
defer providerUtils.ReleaseStreamingResponse(resp)
|
|
// Decompress gzip-encoded streams transparently (no-op for non-gzip)
|
|
reader, releaseGzip := providerUtils.DecompressStreamBody(resp)
|
|
defer releaseGzip()
|
|
|
|
// Wrap reader with idle timeout to detect stalled streams.
|
|
reader, stopIdleTimeout := providerUtils.NewIdleTimeoutReader(reader, resp.BodyStream(), providerUtils.GetStreamIdleTimeout(ctx))
|
|
defer stopIdleTimeout()
|
|
|
|
// Setup cancellation handler to close the raw network stream on ctx cancellation,
|
|
// which immediately unblocks any in-progress read (including reads blocked inside a gzip decompression layer).
|
|
stopCancellation := providerUtils.SetupStreamCancellation(ctx, resp.BodyStream(), logger)
|
|
defer stopCancellation()
|
|
|
|
// Skip scanner for non-SSE responses — avoids bufio.Scanner buffer bloat
|
|
// on non-line-delimited data (e.g. provider returned JSON instead of SSE).
|
|
if providerUtils.DrainNonSSEStreamResponse(resp) {
|
|
ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
|
|
providerUtils.ProcessAndSendError(ctx, postHookRunner, errors.New("provider returned non-SSE response for streaming request"), responseChan, logger, postHookSpanFinalizer)
|
|
return
|
|
}
|
|
|
|
sseReader := providerUtils.GetSSEDataReader(ctx, reader)
|
|
chunkIndex := -1
|
|
|
|
startTime := time.Now()
|
|
lastChunkTime := startTime
|
|
var fullTranscriptionText string
|
|
|
|
for {
|
|
// If context was cancelled/timed out, let defer handle it
|
|
if ctx.Err() != nil {
|
|
return
|
|
}
|
|
|
|
data, readErr := sseReader.ReadDataLine()
|
|
if readErr != nil {
|
|
if readErr != io.EOF {
|
|
if ctx.Err() != nil {
|
|
return
|
|
}
|
|
ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
|
|
logger.Warn("Error reading stream: %v", readErr)
|
|
providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, logger, postHookSpanFinalizer)
|
|
}
|
|
break
|
|
}
|
|
jsonData := string(data)
|
|
// TODo fix this
|
|
response := &schemas.BifrostTranscriptionStreamResponse{}
|
|
var bifrostErr *schemas.BifrostError
|
|
if customResponseHandler != nil {
|
|
_, _, bifrostErr = customResponseHandler([]byte(jsonData), response, nil, false, false)
|
|
if bifrostErr != nil {
|
|
if sendBackRawResponse {
|
|
bifrostErr.ExtraFields.RawResponse = jsonData
|
|
}
|
|
ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
|
|
providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, providerUtils.EnrichError(ctx, bifrostErr, body.Bytes(), []byte(jsonData), false, sendBackRawResponse), responseChan, logger, postHookSpanFinalizer)
|
|
return
|
|
}
|
|
} else {
|
|
// Quick check for error field (allocation-free using sonic.GetFromString)
|
|
if errorNode, _ := sonic.GetFromString(jsonData, "error"); errorNode.Exists() {
|
|
// Only unmarshal when we know there's an error
|
|
var bifrostErrVal schemas.BifrostError
|
|
if err := sonic.UnmarshalString(jsonData, &bifrostErrVal); err == nil {
|
|
if bifrostErrVal.Error != nil && bifrostErrVal.Error.Message != "" {
|
|
ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
|
|
respBody := append([]byte(nil), resp.Body()...)
|
|
providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, providerUtils.EnrichError(ctx, &bifrostErrVal, body.Bytes(), respBody, false, sendBackRawResponse), responseChan, logger, postHookSpanFinalizer)
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
if err := sonic.UnmarshalString(jsonData, response); err != nil {
|
|
logger.Warn("Failed to parse stream response: %v", err)
|
|
continue
|
|
|
|
}
|
|
}
|
|
|
|
if postResponseConverter != nil {
|
|
if converted := postResponseConverter(response); converted != nil {
|
|
response = converted
|
|
} else {
|
|
logger.Warn("postResponseConverter returned nil; leaving chunk unmodified")
|
|
}
|
|
}
|
|
|
|
chunkIndex++
|
|
|
|
response.ExtraFields = schemas.BifrostResponseExtraFields{
|
|
ChunkIndex: chunkIndex,
|
|
Latency: time.Since(lastChunkTime).Milliseconds(),
|
|
}
|
|
lastChunkTime = time.Now()
|
|
|
|
if sendBackRawResponse {
|
|
response.ExtraFields.RawResponse = jsonData
|
|
}
|
|
|
|
if response.Usage != nil || response.Type == schemas.TranscriptionStreamResponseTypeDone {
|
|
response.ExtraFields.Latency = time.Since(startTime).Milliseconds()
|
|
ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
|
|
|
|
if accumulateText {
|
|
response.Text = fullTranscriptionText
|
|
}
|
|
|
|
providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, nil, nil, response, nil), responseChan, postHookSpanFinalizer)
|
|
return
|
|
}
|
|
|
|
providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, nil, nil, response, nil), responseChan, postHookSpanFinalizer)
|
|
}
|
|
}()
|
|
|
|
return responseChan, nil
|
|
}
|
|
|
|
// ImageGeneration performs an Image Generation request to OpenAI's API.
|
|
// It formats the request, sends it to OpenAI, and processes the response.
|
|
// Returns a BifrostResponse containing the bifrost response or an error if the request fails.
|
|
func (provider *OpenAIProvider) ImageGeneration(ctx *schemas.BifrostContext, key schemas.Key,
|
|
req *schemas.BifrostImageGenerationRequest,
|
|
) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) {
|
|
if err := providerUtils.CheckOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.ImageGenerationRequest); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return HandleOpenAIImageGenerationRequest(
|
|
ctx,
|
|
provider.client,
|
|
provider.buildRequestURL(ctx, "/v1/images/generations", schemas.ImageGenerationRequest),
|
|
req,
|
|
key,
|
|
provider.networkConfig.ExtraHeaders,
|
|
provider.GetProviderKey(),
|
|
providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest),
|
|
providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse),
|
|
provider.logger,
|
|
)
|
|
}
|
|
|
|
// HandleOpenAIImageGenerationRequest handles image generation requests for OpenAI-compatible APIs.
|
|
// This shared function reduces code duplication between providers that use the same image generation request format.
|
|
func HandleOpenAIImageGenerationRequest(
|
|
ctx *schemas.BifrostContext,
|
|
client *fasthttp.Client,
|
|
url string,
|
|
request *schemas.BifrostImageGenerationRequest,
|
|
key schemas.Key,
|
|
extraHeaders map[string]string,
|
|
providerName schemas.ModelProvider,
|
|
sendBackRawRequest bool,
|
|
sendBackRawResponse bool,
|
|
logger schemas.Logger,
|
|
) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) {
|
|
// Create request
|
|
req := fasthttp.AcquireRequest()
|
|
resp := fasthttp.AcquireResponse()
|
|
defer fasthttp.ReleaseRequest(req)
|
|
// resp lifecycle: managed by finalizeOpenAIResponse or released on error paths
|
|
respOwned := true
|
|
defer func() {
|
|
if respOwned {
|
|
fasthttp.ReleaseResponse(resp)
|
|
}
|
|
}()
|
|
activeClient := providerUtils.PrepareResponseStreaming(ctx, client, resp)
|
|
|
|
// Set any extra headers from network config
|
|
providerUtils.SetExtraHeaders(ctx, req, extraHeaders, nil)
|
|
|
|
req.SetRequestURI(url)
|
|
req.Header.SetMethod(http.MethodPost)
|
|
req.Header.SetContentType("application/json")
|
|
|
|
if value := key.Value.GetValue(); value != "" {
|
|
req.Header.Set("Authorization", "Bearer "+value)
|
|
}
|
|
|
|
// Large payload passthrough: stream body directly without JSON marshaling
|
|
if lpResult, lpErr, handled := handleOpenAILargePayloadPassthrough(ctx, client, url, key, extraHeaders, providerName, logger); handled {
|
|
if lpErr != nil {
|
|
return nil, lpErr
|
|
}
|
|
if len(lpResult.ResponseBody) > 0 {
|
|
response := &schemas.BifrostImageGenerationResponse{}
|
|
if err := sonic.Unmarshal(lpResult.ResponseBody, response); err != nil {
|
|
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err)
|
|
}
|
|
response.ExtraFields = schemas.BifrostResponseExtraFields{Latency: lpResult.Latency}
|
|
return response, nil
|
|
}
|
|
return &schemas.BifrostImageGenerationResponse{
|
|
ExtraFields: schemas.BifrostResponseExtraFields{Latency: lpResult.Latency},
|
|
}, nil
|
|
}
|
|
|
|
// Use centralized converter
|
|
jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody(
|
|
ctx,
|
|
request,
|
|
func() (providerUtils.RequestBodyWithExtraParams, error) {
|
|
return ToOpenAIImageGenerationRequest(request), nil
|
|
})
|
|
if bifrostErr != nil {
|
|
return nil, bifrostErr
|
|
}
|
|
|
|
req.SetBody(jsonData)
|
|
|
|
// Make request
|
|
latency, bifrostErr, wait := providerUtils.MakeRequestWithContext(ctx, activeClient, req, resp)
|
|
defer wait()
|
|
if bifrostErr != nil {
|
|
return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonData, nil, sendBackRawRequest, sendBackRawResponse)
|
|
}
|
|
// Extract provider response headers early so they're available on error paths too
|
|
providerResponseHeaders := providerUtils.ExtractProviderResponseHeaders(resp)
|
|
ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerResponseHeaders)
|
|
|
|
// Handle error response
|
|
if resp.StatusCode() != fasthttp.StatusOK {
|
|
providerUtils.MaterializeStreamErrorBody(ctx, resp)
|
|
logger.Debug(fmt.Sprintf("error from %s provider: %s", providerName, string(resp.Body())))
|
|
return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp), jsonData, nil, sendBackRawRequest, sendBackRawResponse)
|
|
}
|
|
|
|
body, lpResult, finalErr := finalizeOpenAIResponse(ctx, resp, latency, providerName, logger)
|
|
respOwned = false // ownership transferred
|
|
if finalErr != nil {
|
|
return nil, finalErr
|
|
}
|
|
if lpResult != nil {
|
|
return &schemas.BifrostImageGenerationResponse{
|
|
ExtraFields: schemas.BifrostResponseExtraFields{Latency: lpResult.Latency},
|
|
}, nil
|
|
}
|
|
|
|
response := &schemas.BifrostImageGenerationResponse{}
|
|
|
|
// Use enhanced response handler with pre-allocated response
|
|
rawRequest, rawResponse, bifrostErr := providerUtils.HandleProviderResponse(body, response, jsonData, sendBackRawRequest, sendBackRawResponse)
|
|
if bifrostErr != nil {
|
|
return nil, bifrostErr
|
|
}
|
|
|
|
response.ExtraFields.Latency = latency.Milliseconds()
|
|
response.ExtraFields.ProviderResponseHeaders = providerResponseHeaders
|
|
|
|
// Set raw request if enabled
|
|
if sendBackRawRequest {
|
|
response.ExtraFields.RawRequest = rawRequest
|
|
}
|
|
|
|
// Set raw response if enabled
|
|
if sendBackRawResponse {
|
|
response.ExtraFields.RawResponse = rawResponse
|
|
}
|
|
|
|
return response, nil
|
|
}
|
|
|
|
// ImageGenerationStream handles streaming for image generation.
|
|
// It formats the request body, creates HTTP request, and uses shared streaming logic.
|
|
// Returns a channel for streaming responses and any error that occurred.
|
|
func (provider *OpenAIProvider) ImageGenerationStream(
|
|
ctx *schemas.BifrostContext,
|
|
postHookRunner schemas.PostHookRunner,
|
|
postHookSpanFinalizer func(context.Context),
|
|
key schemas.Key,
|
|
request *schemas.BifrostImageGenerationRequest,
|
|
) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
|
|
if request == nil {
|
|
return nil, providerUtils.NewBifrostOperationError("invalid request: nil", nil)
|
|
}
|
|
|
|
// Check if image generation stream is allowed for this provider
|
|
if err := providerUtils.CheckOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.ImageGenerationStreamRequest); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
var authHeader map[string]string
|
|
if value := key.Value.GetValue(); value != "" {
|
|
authHeader = map[string]string{"Authorization": "Bearer " + value}
|
|
}
|
|
// Use shared streaming logic
|
|
return HandleOpenAIImageGenerationStreaming(
|
|
ctx,
|
|
provider.streamingClient,
|
|
provider.buildRequestURL(ctx, "/v1/images/generations", schemas.ImageGenerationStreamRequest),
|
|
request,
|
|
authHeader,
|
|
provider.networkConfig.ExtraHeaders,
|
|
providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest),
|
|
providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse),
|
|
provider.GetProviderKey(),
|
|
postHookRunner,
|
|
nil,
|
|
nil,
|
|
nil,
|
|
provider.logger,
|
|
postHookSpanFinalizer,
|
|
)
|
|
}
|
|
|
|
func HandleOpenAIImageGenerationStreaming(
|
|
ctx *schemas.BifrostContext,
|
|
client *fasthttp.Client,
|
|
url string,
|
|
request *schemas.BifrostImageGenerationRequest,
|
|
authHeader map[string]string,
|
|
extraHeaders map[string]string,
|
|
sendBackRawRequest bool,
|
|
sendBackRawResponse bool,
|
|
providerName schemas.ModelProvider,
|
|
postHookRunner schemas.PostHookRunner,
|
|
customRequestConverter func(*schemas.BifrostImageGenerationRequest) (providerUtils.RequestBodyWithExtraParams, error),
|
|
postRequestConverter func(*OpenAIImageGenerationRequest) *OpenAIImageGenerationRequest,
|
|
postResponseConverter func(*schemas.BifrostImageGenerationStreamResponse) *schemas.BifrostImageGenerationStreamResponse,
|
|
logger schemas.Logger,
|
|
postHookSpanFinalizer func(context.Context),
|
|
) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
|
|
// Set headers
|
|
headers := map[string]string{
|
|
"Content-Type": "application/json",
|
|
"Accept": "text/event-stream",
|
|
"Cache-Control": "no-cache",
|
|
}
|
|
|
|
if authHeader != nil {
|
|
// Copy auth header to headers
|
|
maps.Copy(headers, authHeader)
|
|
}
|
|
|
|
jsonBody, bifrostErr := providerUtils.CheckContextAndGetRequestBody(
|
|
ctx,
|
|
request,
|
|
func() (providerUtils.RequestBodyWithExtraParams, error) {
|
|
if customRequestConverter != nil {
|
|
return customRequestConverter(request)
|
|
}
|
|
reqBody := ToOpenAIImageGenerationRequest(request)
|
|
if reqBody != nil {
|
|
reqBody.Stream = schemas.Ptr(true)
|
|
if postRequestConverter != nil {
|
|
reqBody = postRequestConverter(reqBody)
|
|
}
|
|
}
|
|
return reqBody, nil
|
|
})
|
|
if bifrostErr != nil {
|
|
return nil, bifrostErr
|
|
}
|
|
|
|
// Create HTTP request for streaming
|
|
req := fasthttp.AcquireRequest()
|
|
resp := fasthttp.AcquireResponse()
|
|
resp.StreamBody = true
|
|
defer fasthttp.ReleaseRequest(req)
|
|
|
|
// Updating request
|
|
req.Header.SetMethod(http.MethodPost)
|
|
req.SetRequestURI(url)
|
|
req.Header.SetContentType("application/json")
|
|
|
|
// Set any extra headers from network config
|
|
providerUtils.SetExtraHeaders(ctx, req, extraHeaders, nil)
|
|
|
|
// Set headers
|
|
for key, value := range headers {
|
|
req.Header.Set(key, value)
|
|
}
|
|
|
|
setStreamingRequestBody(ctx, req, jsonBody, providerName)
|
|
|
|
// Use streaming-aware client when large payload optimization is active — ensures
|
|
// MaxResponseBodySize > 0 so ErrBodyTooLarge triggers StreamBody for Content-Length responses.
|
|
activeClient := providerUtils.PrepareResponseStreaming(ctx, client, resp)
|
|
|
|
// Make the request
|
|
err := activeClient.Do(req, resp)
|
|
if err != nil {
|
|
defer providerUtils.ReleaseStreamingResponse(resp)
|
|
if errors.Is(err, context.Canceled) {
|
|
return nil, &schemas.BifrostError{
|
|
IsBifrostError: false,
|
|
Error: &schemas.ErrorField{
|
|
Type: schemas.Ptr(schemas.RequestCancelled),
|
|
Message: schemas.ErrRequestCancelled,
|
|
Error: err,
|
|
},
|
|
}
|
|
}
|
|
if errors.Is(err, fasthttp.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) {
|
|
return nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err)
|
|
}
|
|
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err)
|
|
}
|
|
|
|
// Store provider response headers in context before status check so error responses also forward them
|
|
ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerUtils.ExtractProviderResponseHeaders(resp))
|
|
|
|
// Check for HTTP errors
|
|
if resp.StatusCode() != fasthttp.StatusOK {
|
|
defer providerUtils.ReleaseStreamingResponse(resp)
|
|
providerUtils.MaterializeStreamErrorBody(ctx, resp)
|
|
return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp), jsonBody, nil, sendBackRawRequest, sendBackRawResponse)
|
|
}
|
|
|
|
// Large payload streaming passthrough — pipe raw upstream SSE to client
|
|
if providerUtils.SetupStreamingPassthrough(ctx, resp) {
|
|
responseChan := make(chan *schemas.BifrostStreamChunk)
|
|
close(responseChan)
|
|
return responseChan, nil
|
|
}
|
|
|
|
// Create response channel
|
|
responseChan := make(chan *schemas.BifrostStreamChunk, schemas.DefaultStreamBufferSize)
|
|
|
|
// Start streaming in a goroutine
|
|
go func() {
|
|
defer providerUtils.EnsureStreamFinalizerCalled(ctx, postHookSpanFinalizer)
|
|
defer func() {
|
|
if ctx.Err() == context.Canceled {
|
|
providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, logger, postHookSpanFinalizer)
|
|
} else if ctx.Err() == context.DeadlineExceeded {
|
|
providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, logger, postHookSpanFinalizer)
|
|
}
|
|
close(responseChan)
|
|
}()
|
|
defer providerUtils.ReleaseStreamingResponse(resp)
|
|
// Decompress gzip-encoded streams transparently (no-op for non-gzip)
|
|
reader, releaseGzip := providerUtils.DecompressStreamBody(resp)
|
|
defer releaseGzip()
|
|
|
|
// Wrap reader with idle timeout to detect stalled streams.
|
|
reader, stopIdleTimeout := providerUtils.NewIdleTimeoutReader(reader, resp.BodyStream(), providerUtils.GetStreamIdleTimeout(ctx))
|
|
defer stopIdleTimeout()
|
|
|
|
// Setup cancellation handler to close the raw network stream on ctx cancellation,
|
|
// which immediately unblocks any in-progress read (including reads blocked inside a gzip decompression layer).
|
|
stopCancellation := providerUtils.SetupStreamCancellation(ctx, resp.BodyStream(), logger)
|
|
defer stopCancellation()
|
|
|
|
// Skip scanner for non-SSE responses — avoids bufio.Scanner buffer bloat
|
|
// on non-line-delimited data (e.g. provider returned JSON instead of SSE).
|
|
if providerUtils.DrainNonSSEStreamResponse(resp) {
|
|
ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
|
|
providerUtils.ProcessAndSendError(ctx, postHookRunner, errors.New("provider returned non-SSE response for streaming request"), responseChan, logger, postHookSpanFinalizer)
|
|
return
|
|
}
|
|
|
|
sseReader := providerUtils.GetSSEDataReader(ctx, reader)
|
|
|
|
startTime := time.Now()
|
|
lastChunkTime := startTime
|
|
var collectedUsage *schemas.ImageUsage
|
|
// Track chunk indices per image - similar to how speech/transcription track chunkIndex
|
|
imageChunkIndices := make(map[int]int) // image index -> chunk index
|
|
// Track images that have started (via partial chunks) but not yet completed
|
|
// This allows us to correctly match completed events to images even if chunks are interleaved
|
|
incompleteImages := make(map[int]bool)
|
|
maxImageIndex := -1 // Track maximum image index for NImages calculation
|
|
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return
|
|
default:
|
|
}
|
|
|
|
data, readErr := sseReader.ReadDataLine()
|
|
if readErr != nil {
|
|
if readErr != io.EOF {
|
|
logger.Warn("Error reading stream: %v", readErr)
|
|
providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, logger, postHookSpanFinalizer)
|
|
}
|
|
break
|
|
}
|
|
jsonData := string(data)
|
|
|
|
// Quick check for error field (allocation-free using sonic.GetFromString)
|
|
if errorNode, _ := sonic.GetFromString(jsonData, "error"); errorNode.Exists() {
|
|
// Only unmarshal when we know there's an error
|
|
var bifrostErr schemas.BifrostError
|
|
if err := sonic.UnmarshalString(jsonData, &bifrostErr); err == nil {
|
|
if bifrostErr.Error != nil && bifrostErr.Error.Message != "" {
|
|
ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
|
|
providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, providerUtils.EnrichError(ctx, &bifrostErr, jsonBody, nil, sendBackRawRequest, sendBackRawResponse), responseChan, logger, postHookSpanFinalizer)
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
// Parse minimally to extract usage and check for errors
|
|
var response OpenAIImageStreamResponse
|
|
if err := sonic.UnmarshalString(jsonData, &response); err != nil {
|
|
logger.Warn("Failed to parse stream response: %v", err)
|
|
continue
|
|
}
|
|
|
|
// Check if response type indicates an error
|
|
if response.Type == "error" {
|
|
bifrostErr := &schemas.BifrostError{
|
|
IsBifrostError: false,
|
|
Error: &schemas.ErrorField{},
|
|
}
|
|
// Guard access to response.Error fields
|
|
if response.Error != nil {
|
|
bifrostErr.Error.Message = response.Error.Message
|
|
if response.Error.Code != nil {
|
|
bifrostErr.Error.Code = response.Error.Code
|
|
}
|
|
if response.Error.Param != nil {
|
|
bifrostErr.Error.Param = response.Error.Param
|
|
}
|
|
if response.Error.Type != nil {
|
|
bifrostErr.Error.Type = response.Error.Type
|
|
}
|
|
}
|
|
ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
|
|
providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, logger, postHookSpanFinalizer)
|
|
return
|
|
}
|
|
|
|
// Collect usage from completed event
|
|
if response.Usage != nil {
|
|
collectedUsage = &schemas.ImageUsage{
|
|
InputTokens: response.Usage.InputTokens,
|
|
OutputTokens: response.Usage.OutputTokens,
|
|
TotalTokens: response.Usage.TotalTokens,
|
|
}
|
|
}
|
|
|
|
// Determine if this is the final chunk
|
|
isCompleted := response.Type == schemas.ImageGenerationEventTypeCompleted
|
|
|
|
// Determine image index with robust tracking for interleaved chunks
|
|
// Both partial and completed chunks should use PartialImageIndex when available
|
|
var imageIndex int
|
|
if response.PartialImageIndex != nil {
|
|
// Use explicit index from response
|
|
imageIndex = *response.PartialImageIndex
|
|
if isCompleted {
|
|
// Mark this image as completed
|
|
delete(incompleteImages, imageIndex)
|
|
} else {
|
|
// Mark this image as started (incomplete)
|
|
incompleteImages[imageIndex] = true
|
|
}
|
|
} else {
|
|
// Fallback: PartialImageIndex is nil, use tracked state
|
|
if isCompleted {
|
|
// For completed chunks, match to the oldest incomplete image
|
|
// This handles interleaved chunks correctly
|
|
if len(incompleteImages) == 0 {
|
|
// Fallback: if no incomplete images tracked, this shouldn't happen in normal flow
|
|
// but we'll default to 0 to prevent panics
|
|
imageIndex = 0
|
|
logger.Warn("Received completed event but no incomplete images tracked, defaulting to index 0")
|
|
} else {
|
|
// Find the minimum (oldest) incomplete image index
|
|
// Completed events should match the oldest image that was started
|
|
minIndex := -1
|
|
for idx := range incompleteImages {
|
|
if minIndex == -1 || idx < minIndex {
|
|
minIndex = idx
|
|
}
|
|
}
|
|
imageIndex = minIndex
|
|
// Mark this image as completed
|
|
delete(incompleteImages, imageIndex)
|
|
logger.Warn("Completed event missing PartialImageIndex, using oldest incomplete image index %d", imageIndex)
|
|
}
|
|
} else {
|
|
// For partial chunks without PartialImageIndex, allocate a new unique index
|
|
// Use maxImageIndex + 1 to ensure uniqueness
|
|
imageIndex = maxImageIndex + 1
|
|
// Mark this image as started (incomplete)
|
|
incompleteImages[imageIndex] = true
|
|
}
|
|
}
|
|
|
|
// Update maximum image index for NImages calculation
|
|
if imageIndex > maxImageIndex {
|
|
maxImageIndex = imageIndex
|
|
}
|
|
|
|
// Increment chunk index for this image
|
|
if _, exists := imageChunkIndices[imageIndex]; !exists {
|
|
imageChunkIndices[imageIndex] = 0
|
|
} else {
|
|
imageChunkIndices[imageIndex]++
|
|
}
|
|
chunkIndex := imageChunkIndices[imageIndex]
|
|
// Build chunk with all OpenAI fields
|
|
chunk := &schemas.BifrostImageGenerationStreamResponse{
|
|
Type: response.Type,
|
|
Index: imageIndex, // Which image (0-N)
|
|
ChunkIndex: chunkIndex, // Chunk order within this image (top-level)
|
|
CreatedAt: response.CreatedAt,
|
|
Size: response.Size,
|
|
Quality: response.Quality,
|
|
Background: response.Background,
|
|
OutputFormat: response.OutputFormat,
|
|
ExtraFields: schemas.BifrostResponseExtraFields{
|
|
ChunkIndex: chunkIndex, // Chunk order within this image
|
|
Latency: time.Since(lastChunkTime).Milliseconds(),
|
|
},
|
|
}
|
|
|
|
if postResponseConverter != nil {
|
|
if converted := postResponseConverter(chunk); converted != nil {
|
|
chunk = converted
|
|
} else {
|
|
logger.Warn("postResponseConverter returned nil; leaving chunk unmodified")
|
|
}
|
|
}
|
|
|
|
// Only set PartialImageIndex for partial images, not for completed events
|
|
if !isCompleted {
|
|
chunk.PartialImageIndex = response.PartialImageIndex
|
|
}
|
|
// Set SequenceNumber if present
|
|
if response.SequenceNumber != nil {
|
|
chunk.SequenceNumber = *response.SequenceNumber
|
|
}
|
|
lastChunkTime = time.Now()
|
|
|
|
// Copy b64_json if present
|
|
if response.B64JSON != nil {
|
|
chunk.B64JSON = *response.B64JSON
|
|
}
|
|
|
|
// Set raw response on every chunk if enabled
|
|
if sendBackRawResponse {
|
|
chunk.ExtraFields.RawResponse = jsonData
|
|
}
|
|
|
|
if isCompleted {
|
|
if collectedUsage != nil {
|
|
// Set NImages based on maximum image index seen (maxImageIndex + 1 since indices are 0-based)
|
|
if maxImageIndex >= 0 {
|
|
nImages := maxImageIndex + 1
|
|
collectedUsage.OutputTokensDetails = &schemas.ImageTokenDetails{
|
|
NImages: nImages,
|
|
}
|
|
}
|
|
chunk.Usage = collectedUsage
|
|
}
|
|
// For completed chunk, use total latency from start
|
|
chunk.ExtraFields.Latency = time.Since(startTime).Milliseconds()
|
|
chunk.BackfillParams(&schemas.BifrostRequest{
|
|
ImageGenerationRequest: request,
|
|
})
|
|
// Set raw request only on final chunk if enabled
|
|
if sendBackRawRequest {
|
|
providerUtils.ParseAndSetRawRequest(&chunk.ExtraFields, jsonBody)
|
|
}
|
|
ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
|
|
}
|
|
|
|
providerUtils.ProcessAndSendResponse(ctx, postHookRunner,
|
|
providerUtils.GetBifrostResponseForStreamResponse(nil, nil, nil, nil, nil, chunk),
|
|
responseChan, postHookSpanFinalizer)
|
|
|
|
if isCompleted {
|
|
return
|
|
}
|
|
}
|
|
}()
|
|
|
|
return responseChan, nil
|
|
}
|
|
|
|
// Rerank is not supported by the OpenAI provider.
|
|
func (provider *OpenAIProvider) Rerank(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostRerankRequest) (*schemas.BifrostRerankResponse, *schemas.BifrostError) {
|
|
return nil, providerUtils.NewUnsupportedOperationError(schemas.RerankRequest, provider.GetProviderKey())
|
|
}
|
|
|
|
// OCR is not supported by the Openai provider.
|
|
func (provider *OpenAIProvider) OCR(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostOCRRequest) (*schemas.BifrostOCRResponse, *schemas.BifrostError) {
|
|
return nil, providerUtils.NewUnsupportedOperationError(schemas.OCRRequest, provider.GetProviderKey())
|
|
}
|
|
|
|
// VideoGeneration performs a video generation request via the OpenAI API.
|
|
func (provider *OpenAIProvider) VideoGeneration(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostVideoGenerationRequest) (*schemas.BifrostVideoGenerationResponse, *schemas.BifrostError) {
|
|
if err := providerUtils.CheckOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.VideoGenerationRequest); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return HandleOpenAIVideoGenerationRequest(
|
|
ctx,
|
|
provider.client,
|
|
provider.buildRequestURL(ctx, "/v1/videos", schemas.VideoGenerationRequest),
|
|
request,
|
|
key,
|
|
provider.networkConfig.ExtraHeaders,
|
|
provider.GetProviderKey(),
|
|
providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest),
|
|
providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse),
|
|
provider.logger,
|
|
)
|
|
}
|
|
|
|
// VideoRetrieve retrieves a video generation job from the OpenAI API.
|
|
func (provider *OpenAIProvider) VideoRetrieve(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostVideoRetrieveRequest) (*schemas.BifrostVideoGenerationResponse, *schemas.BifrostError) {
|
|
if err := providerUtils.CheckOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.VideoRetrieveRequest); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
providerName := provider.GetProviderKey()
|
|
if request.ID == "" {
|
|
return nil, providerUtils.NewBifrostOperationError("video_id is required", nil)
|
|
}
|
|
videoID := providerUtils.StripVideoIDProviderSuffix(request.ID, providerName)
|
|
|
|
return HandleOpenAIVideoRetrieveRequest(
|
|
ctx,
|
|
provider.client,
|
|
provider.buildRequestURL(ctx, "/v1/videos/"+videoID, schemas.VideoRetrieveRequest),
|
|
request,
|
|
key,
|
|
provider.networkConfig.ExtraHeaders,
|
|
nil, // OpenAI uses Bearer from key
|
|
providerName,
|
|
providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest),
|
|
providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse),
|
|
provider.VideoDownload,
|
|
provider.logger,
|
|
)
|
|
}
|
|
|
|
// VideoDownload downloads video content from OpenAI.
|
|
func (provider *OpenAIProvider) VideoDownload(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostVideoDownloadRequest) (*schemas.BifrostVideoDownloadResponse, *schemas.BifrostError) {
|
|
if err := providerUtils.CheckOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.VideoDownloadRequest); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
providerName := provider.GetProviderKey()
|
|
|
|
if request.ID == "" {
|
|
return nil, providerUtils.NewBifrostOperationError("video_id is required", nil)
|
|
}
|
|
videoID := providerUtils.StripVideoIDProviderSuffix(request.ID, providerName)
|
|
|
|
// Create request
|
|
req := fasthttp.AcquireRequest()
|
|
resp := fasthttp.AcquireResponse()
|
|
defer fasthttp.ReleaseRequest(req)
|
|
defer fasthttp.ReleaseResponse(resp)
|
|
|
|
// Set headers
|
|
providerUtils.SetExtraHeaders(ctx, req, provider.networkConfig.ExtraHeaders, nil)
|
|
|
|
// Build URL: /v1/videos/{video_id}/content
|
|
requestURL := provider.buildRequestURL(ctx, "/v1/videos/"+videoID+"/content", schemas.VideoDownloadRequest)
|
|
|
|
if request.Variant != nil && *request.Variant != "" {
|
|
// attach variant to url if present
|
|
requestURL = fmt.Sprintf("%s?variant=%s", requestURL, url.QueryEscape(string(*request.Variant)))
|
|
}
|
|
|
|
req.SetRequestURI(requestURL)
|
|
req.Header.SetMethod(http.MethodGet)
|
|
|
|
if key.Value.GetValue() != "" {
|
|
req.Header.Set("Authorization", "Bearer "+key.Value.GetValue())
|
|
}
|
|
|
|
// Make request
|
|
latency, bifrostErr, wait := providerUtils.MakeRequestWithContext(ctx, provider.client, req, resp)
|
|
defer wait()
|
|
if bifrostErr != nil {
|
|
return nil, bifrostErr
|
|
}
|
|
// Extract provider response headers early so they're available on error paths too
|
|
providerResponseHeaders := providerUtils.ExtractProviderResponseHeaders(resp)
|
|
ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerResponseHeaders)
|
|
|
|
// Handle error response
|
|
if resp.StatusCode() != fasthttp.StatusOK {
|
|
provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body()))
|
|
return nil, ParseOpenAIError(resp)
|
|
}
|
|
|
|
body, err := providerUtils.CheckAndDecodeBody(resp)
|
|
if err != nil {
|
|
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err)
|
|
}
|
|
|
|
// Get content type from response
|
|
contentType := string(resp.Header.ContentType())
|
|
if contentType == "" {
|
|
// Default to video/mp4 if not specified
|
|
contentType = "video/mp4"
|
|
}
|
|
|
|
// Copy the binary content
|
|
content := append([]byte(nil), body...)
|
|
|
|
return &schemas.BifrostVideoDownloadResponse{
|
|
VideoID: providerUtils.AddVideoIDProviderSuffix(videoID, providerName),
|
|
Content: content,
|
|
ContentType: contentType,
|
|
ExtraFields: schemas.BifrostResponseExtraFields{
|
|
Latency: latency.Milliseconds(),
|
|
ProviderResponseHeaders: providerResponseHeaders,
|
|
},
|
|
}, nil
|
|
}
|
|
|
|
// VideoDelete deletes a video generation job from the OpenAI API.
|
|
func (provider *OpenAIProvider) VideoDelete(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostVideoDeleteRequest) (*schemas.BifrostVideoDeleteResponse, *schemas.BifrostError) {
|
|
if err := providerUtils.CheckOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.VideoDeleteRequest); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
providerName := provider.GetProviderKey()
|
|
|
|
if request.ID == "" {
|
|
return nil, providerUtils.NewBifrostOperationError("video_id is required", nil)
|
|
}
|
|
videoID := providerUtils.StripVideoIDProviderSuffix(request.ID, providerName)
|
|
|
|
return HandleOpenAIVideoDeleteRequest(
|
|
ctx,
|
|
provider.client,
|
|
provider.buildRequestURL(ctx, "/v1/videos/"+videoID, schemas.VideoDeleteRequest),
|
|
videoID,
|
|
key,
|
|
provider.networkConfig.ExtraHeaders,
|
|
providerName,
|
|
providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest),
|
|
providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse),
|
|
provider.logger,
|
|
)
|
|
}
|
|
|
|
// VideoList lists videos from OpenAI.
|
|
func (provider *OpenAIProvider) VideoList(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostVideoListRequest) (*schemas.BifrostVideoListResponse, *schemas.BifrostError) {
|
|
if err := providerUtils.CheckOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.VideoListRequest); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return HandleOpenAIVideoListRequest(
|
|
ctx,
|
|
provider.client,
|
|
provider.buildRequestURL(ctx, "/v1/videos", schemas.VideoListRequest),
|
|
request,
|
|
key,
|
|
provider.networkConfig.ExtraHeaders,
|
|
provider.GetProviderKey(),
|
|
providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest),
|
|
providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse),
|
|
provider.logger,
|
|
)
|
|
}
|
|
|
|
// HandleOpenAIVideoGenerationRequest handles video generation requests for OpenAI-compatible APIs.
|
|
// It creates a multipart form, adds fields, makes the API call, and returns the response.
|
|
func HandleOpenAIVideoGenerationRequest(
|
|
ctx *schemas.BifrostContext,
|
|
client *fasthttp.Client,
|
|
url string,
|
|
request *schemas.BifrostVideoGenerationRequest,
|
|
key schemas.Key,
|
|
extraHeaders map[string]string,
|
|
providerName schemas.ModelProvider,
|
|
sendBackRawRequest bool,
|
|
sendBackRawResponse bool,
|
|
logger schemas.Logger,
|
|
) (*schemas.BifrostVideoGenerationResponse, *schemas.BifrostError) {
|
|
// Create request
|
|
req := fasthttp.AcquireRequest()
|
|
resp := fasthttp.AcquireResponse()
|
|
defer fasthttp.ReleaseRequest(req)
|
|
defer fasthttp.ReleaseResponse(resp)
|
|
|
|
// Set any extra headers from network config
|
|
providerUtils.SetExtraHeaders(ctx, req, extraHeaders, nil)
|
|
|
|
req.SetRequestURI(url)
|
|
req.Header.SetMethod(http.MethodPost)
|
|
if key.Value.GetValue() != "" {
|
|
req.Header.Set("Authorization", "Bearer "+key.Value.GetValue())
|
|
}
|
|
|
|
// Use centralized converter
|
|
reqBody, err := ToOpenAIVideoGenerationRequest(request)
|
|
if err != nil {
|
|
return nil, providerUtils.NewBifrostOperationError("failed to convert video generation request to openai format", err)
|
|
}
|
|
if reqBody == nil {
|
|
return nil, providerUtils.NewBifrostOperationError("video generation input is not provided", nil)
|
|
}
|
|
|
|
// Create multipart form
|
|
var body bytes.Buffer
|
|
writer := multipart.NewWriter(&body)
|
|
if err := parseVideoGenerationFormDataBodyFromRequest(writer, reqBody, providerName); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
req.Header.SetContentType(writer.FormDataContentType()) // This sets multipart/form-data with boundary
|
|
req.SetBody(body.Bytes())
|
|
|
|
// Make request
|
|
latency, bifrostErr, wait := providerUtils.MakeRequestWithContext(ctx, client, req, resp)
|
|
defer wait()
|
|
if bifrostErr != nil {
|
|
return nil, bifrostErr
|
|
}
|
|
// Extract provider response headers early so they're available on error paths too
|
|
providerResponseHeaders := providerUtils.ExtractProviderResponseHeaders(resp)
|
|
ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerResponseHeaders)
|
|
|
|
// Handle error response
|
|
if resp.StatusCode() != fasthttp.StatusOK {
|
|
logger.Debug("error from %s provider: %s", providerName, string(resp.Body()))
|
|
return nil, ParseOpenAIError(resp)
|
|
}
|
|
|
|
responseBody, err := providerUtils.CheckAndDecodeBody(resp)
|
|
if err != nil {
|
|
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err)
|
|
}
|
|
|
|
// Check for empty response
|
|
trimmed := strings.TrimSpace(string(responseBody))
|
|
if len(trimmed) == 0 {
|
|
return nil, &schemas.BifrostError{
|
|
IsBifrostError: true,
|
|
Error: &schemas.ErrorField{
|
|
Message: schemas.ErrProviderResponseEmpty,
|
|
},
|
|
}
|
|
}
|
|
|
|
// Parse OpenAI's video generation response
|
|
response := &schemas.BifrostVideoGenerationResponse{}
|
|
rawRequest, rawResponse, bifrostErr := providerUtils.HandleProviderResponse(responseBody, response, nil, sendBackRawRequest, sendBackRawResponse)
|
|
if bifrostErr != nil {
|
|
return nil, bifrostErr
|
|
}
|
|
|
|
if response.ID != "" {
|
|
response.ID = providerUtils.AddVideoIDProviderSuffix(response.ID, providerName)
|
|
}
|
|
|
|
response.ExtraFields = schemas.BifrostResponseExtraFields{
|
|
Latency: latency.Milliseconds(),
|
|
ProviderResponseHeaders: providerResponseHeaders,
|
|
}
|
|
|
|
if sendBackRawResponse {
|
|
response.ExtraFields.RawResponse = rawResponse
|
|
}
|
|
|
|
if sendBackRawRequest {
|
|
response.ExtraFields.RawRequest = rawRequest
|
|
}
|
|
|
|
return response, nil
|
|
}
|
|
|
|
// VideoDownloadFunc downloads video content. Used by HandleOpenAIVideoRetrieveRequest for enrichment.
|
|
type VideoDownloadHandler func(ctx *schemas.BifrostContext, key schemas.Key, req *schemas.BifrostVideoDownloadRequest) (*schemas.BifrostVideoDownloadResponse, *schemas.BifrostError)
|
|
|
|
// HandleOpenAIVideoRetrieveRequest handles video retrieve requests for OpenAI-compatible APIs.
|
|
// When authHeaders is non-nil, they are applied for authentication (e.g. Azure api-key); otherwise Bearer from key is used.
|
|
// When videoDownloadFunc is non-nil and ctx has VideoOutputRequested with status completed, the handler fetches video content and appends to response.
|
|
func HandleOpenAIVideoRetrieveRequest(
|
|
ctx *schemas.BifrostContext,
|
|
client *fasthttp.Client,
|
|
url string,
|
|
request *schemas.BifrostVideoRetrieveRequest,
|
|
key schemas.Key,
|
|
extraHeaders map[string]string,
|
|
authHeaders map[string]string,
|
|
providerName schemas.ModelProvider,
|
|
sendBackRawRequest bool,
|
|
sendBackRawResponse bool,
|
|
videoDownloaddHandler VideoDownloadHandler,
|
|
logger schemas.Logger,
|
|
) (*schemas.BifrostVideoGenerationResponse, *schemas.BifrostError) {
|
|
req := fasthttp.AcquireRequest()
|
|
resp := fasthttp.AcquireResponse()
|
|
defer fasthttp.ReleaseRequest(req)
|
|
defer fasthttp.ReleaseResponse(resp)
|
|
|
|
providerUtils.SetExtraHeaders(ctx, req, extraHeaders, nil)
|
|
req.SetRequestURI(url)
|
|
req.Header.SetMethod(http.MethodGet)
|
|
req.Header.SetContentType("application/json")
|
|
|
|
if len(authHeaders) > 0 {
|
|
for k, v := range authHeaders {
|
|
req.Header.Set(k, v)
|
|
}
|
|
} else if key.Value.GetValue() != "" {
|
|
req.Header.Set("Authorization", "Bearer "+key.Value.GetValue())
|
|
}
|
|
|
|
latency, bifrostErr, wait := providerUtils.MakeRequestWithContext(ctx, client, req, resp)
|
|
defer wait()
|
|
if bifrostErr != nil {
|
|
return nil, bifrostErr
|
|
}
|
|
// Extract provider response headers early so they're available on error paths too
|
|
providerResponseHeaders := providerUtils.ExtractProviderResponseHeaders(resp)
|
|
ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerResponseHeaders)
|
|
|
|
if resp.StatusCode() != fasthttp.StatusOK {
|
|
logger.Debug("error from %s provider: %s", providerName, string(resp.Body()))
|
|
return nil, ParseOpenAIError(resp)
|
|
}
|
|
|
|
body, err := providerUtils.CheckAndDecodeBody(resp)
|
|
if err != nil {
|
|
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err)
|
|
}
|
|
|
|
response := &schemas.BifrostVideoGenerationResponse{}
|
|
_, rawResponse, bifrostErr := providerUtils.HandleProviderResponse(body, response, nil, sendBackRawRequest, sendBackRawResponse)
|
|
if bifrostErr != nil {
|
|
return nil, bifrostErr
|
|
}
|
|
if response.ID != "" {
|
|
response.ID = providerUtils.AddVideoIDProviderSuffix(response.ID, providerName)
|
|
}
|
|
if response.RemixedFromVideoID != nil && *response.RemixedFromVideoID != "" {
|
|
remixID := providerUtils.AddVideoIDProviderSuffix(*response.RemixedFromVideoID, providerName)
|
|
response.RemixedFromVideoID = &remixID
|
|
}
|
|
|
|
if videoDownloaddHandler != nil {
|
|
downloadVideo, ok := ctx.Value(schemas.BifrostContextKeyVideoOutputRequested).(bool)
|
|
if ok && downloadVideo && response.Status == schemas.VideoStatusCompleted {
|
|
videoDownloadRequest := &schemas.BifrostVideoDownloadRequest{
|
|
Provider: providerName,
|
|
ID: response.ID,
|
|
}
|
|
videoDownloadResponse, bifrostErr := videoDownloaddHandler(ctx, key, videoDownloadRequest)
|
|
if bifrostErr != nil {
|
|
return nil, bifrostErr
|
|
}
|
|
if len(videoDownloadResponse.Content) > 0 {
|
|
output := schemas.VideoOutput{
|
|
Type: schemas.VideoOutputTypeBase64,
|
|
ContentType: videoDownloadResponse.ContentType,
|
|
}
|
|
base64Data := base64.StdEncoding.EncodeToString(videoDownloadResponse.Content)
|
|
output.Base64Data = &base64Data
|
|
response.Videos = append(response.Videos, output)
|
|
} else {
|
|
logger.Warn("no content found for video download request for %s video retrieve request", providerName)
|
|
}
|
|
}
|
|
}
|
|
|
|
response.ExtraFields = schemas.BifrostResponseExtraFields{
|
|
Latency: latency.Milliseconds(),
|
|
ProviderResponseHeaders: providerResponseHeaders,
|
|
}
|
|
if sendBackRawResponse {
|
|
response.ExtraFields.RawResponse = rawResponse
|
|
}
|
|
return response, nil
|
|
}
|
|
|
|
// HandleOpenAIVideoDeleteRequest handles video deletion requests for OpenAI-compatible APIs.
|
|
func HandleOpenAIVideoDeleteRequest(
|
|
ctx *schemas.BifrostContext,
|
|
client *fasthttp.Client,
|
|
url string,
|
|
videoID string,
|
|
key schemas.Key,
|
|
extraHeaders map[string]string,
|
|
providerName schemas.ModelProvider,
|
|
sendBackRawRequest bool,
|
|
sendBackRawResponse bool,
|
|
logger schemas.Logger,
|
|
) (*schemas.BifrostVideoDeleteResponse, *schemas.BifrostError) {
|
|
// Create request
|
|
req := fasthttp.AcquireRequest()
|
|
resp := fasthttp.AcquireResponse()
|
|
defer fasthttp.ReleaseRequest(req)
|
|
defer fasthttp.ReleaseResponse(resp)
|
|
|
|
// Set headers
|
|
providerUtils.SetExtraHeaders(ctx, req, extraHeaders, nil)
|
|
req.SetRequestURI(url)
|
|
req.Header.SetMethod(http.MethodDelete)
|
|
req.Header.SetContentType("application/json")
|
|
|
|
if key.Value.GetValue() != "" {
|
|
req.Header.Set("Authorization", "Bearer "+key.Value.GetValue())
|
|
}
|
|
|
|
// Make request
|
|
latency, bifrostErr, wait := providerUtils.MakeRequestWithContext(ctx, client, req, resp)
|
|
defer wait()
|
|
if bifrostErr != nil {
|
|
return nil, bifrostErr
|
|
}
|
|
// Extract provider response headers early so they're available on error paths too
|
|
providerResponseHeaders := providerUtils.ExtractProviderResponseHeaders(resp)
|
|
ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerResponseHeaders)
|
|
|
|
// Handle error response
|
|
if resp.StatusCode() != fasthttp.StatusOK {
|
|
logger.Debug("error from %s provider: %s", providerName, string(resp.Body()))
|
|
return nil, ParseOpenAIError(resp)
|
|
}
|
|
|
|
body, err := providerUtils.CheckAndDecodeBody(resp)
|
|
if err != nil {
|
|
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err)
|
|
}
|
|
|
|
// Parse OpenAI's video response
|
|
response := &schemas.BifrostVideoDeleteResponse{}
|
|
rawRequest, rawResponse, bifrostErr := providerUtils.HandleProviderResponse(body, response, nil, sendBackRawRequest, sendBackRawResponse)
|
|
if bifrostErr != nil {
|
|
return nil, bifrostErr
|
|
}
|
|
if response.ID != "" {
|
|
response.ID = providerUtils.AddVideoIDProviderSuffix(response.ID, providerName)
|
|
}
|
|
|
|
response.ExtraFields = schemas.BifrostResponseExtraFields{
|
|
Latency: latency.Milliseconds(),
|
|
ProviderResponseHeaders: providerResponseHeaders,
|
|
}
|
|
|
|
if sendBackRawResponse {
|
|
response.ExtraFields.RawResponse = rawResponse
|
|
}
|
|
if sendBackRawRequest {
|
|
response.ExtraFields.RawRequest = rawRequest
|
|
}
|
|
|
|
return response, nil
|
|
}
|
|
|
|
// HandleOpenAIVideoListRequest handles video list requests for OpenAI-compatible APIs.
|
|
func HandleOpenAIVideoListRequest(
|
|
ctx *schemas.BifrostContext,
|
|
client *fasthttp.Client,
|
|
baseURL string,
|
|
request *schemas.BifrostVideoListRequest,
|
|
key schemas.Key,
|
|
extraHeaders map[string]string,
|
|
providerName schemas.ModelProvider,
|
|
sendBackRawRequest bool,
|
|
sendBackRawResponse bool,
|
|
logger schemas.Logger,
|
|
) (*schemas.BifrostVideoListResponse, *schemas.BifrostError) {
|
|
// Create request
|
|
req := fasthttp.AcquireRequest()
|
|
resp := fasthttp.AcquireResponse()
|
|
defer fasthttp.ReleaseRequest(req)
|
|
defer fasthttp.ReleaseResponse(resp)
|
|
|
|
// Build URL with query parameters
|
|
values := url.Values{}
|
|
if request.After != nil && *request.After != "" {
|
|
values.Set("after", providerUtils.StripVideoIDProviderSuffix(*request.After, providerName))
|
|
}
|
|
if request.Limit != nil {
|
|
values.Set("limit", fmt.Sprintf("%d", *request.Limit))
|
|
}
|
|
if request.Order != nil && *request.Order != "" {
|
|
values.Set("order", *request.Order)
|
|
}
|
|
finalURL := baseURL
|
|
if encoded := values.Encode(); encoded != "" {
|
|
finalURL = baseURL + "?" + encoded
|
|
}
|
|
|
|
// Set headers
|
|
providerUtils.SetExtraHeaders(ctx, req, extraHeaders, nil)
|
|
req.SetRequestURI(finalURL)
|
|
req.Header.SetMethod(http.MethodGet)
|
|
req.Header.SetContentType("application/json")
|
|
|
|
if key.Value.GetValue() != "" {
|
|
req.Header.Set("Authorization", "Bearer "+key.Value.GetValue())
|
|
}
|
|
|
|
// Make request
|
|
latency, bifrostErr, wait := providerUtils.MakeRequestWithContext(ctx, client, req, resp)
|
|
defer wait()
|
|
if bifrostErr != nil {
|
|
return nil, bifrostErr
|
|
}
|
|
// Extract provider response headers early so they're available on error paths too
|
|
providerResponseHeaders := providerUtils.ExtractProviderResponseHeaders(resp)
|
|
ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerResponseHeaders)
|
|
|
|
// Handle error response
|
|
if resp.StatusCode() != fasthttp.StatusOK {
|
|
logger.Debug("error from %s provider: %s", providerName, string(resp.Body()))
|
|
return nil, ParseOpenAIError(resp)
|
|
}
|
|
|
|
body, err := providerUtils.CheckAndDecodeBody(resp)
|
|
if err != nil {
|
|
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err)
|
|
}
|
|
|
|
response := &schemas.BifrostVideoListResponse{}
|
|
_, rawResponse, bifrostErr := providerUtils.HandleProviderResponse(body, response, nil, sendBackRawRequest, sendBackRawResponse)
|
|
if bifrostErr != nil {
|
|
return nil, bifrostErr
|
|
}
|
|
for i := range response.Data {
|
|
if response.Data[i].ID != "" {
|
|
response.Data[i].ID = providerUtils.AddVideoIDProviderSuffix(response.Data[i].ID, providerName)
|
|
}
|
|
if response.Data[i].RemixedFromVideoID != nil && *response.Data[i].RemixedFromVideoID != "" {
|
|
remixID := providerUtils.AddVideoIDProviderSuffix(*response.Data[i].RemixedFromVideoID, providerName)
|
|
response.Data[i].RemixedFromVideoID = &remixID
|
|
}
|
|
}
|
|
if response.FirstID != nil && *response.FirstID != "" {
|
|
firstID := providerUtils.AddVideoIDProviderSuffix(*response.FirstID, providerName)
|
|
response.FirstID = &firstID
|
|
}
|
|
if response.LastID != nil && *response.LastID != "" {
|
|
lastID := providerUtils.AddVideoIDProviderSuffix(*response.LastID, providerName)
|
|
response.LastID = &lastID
|
|
}
|
|
|
|
response.ExtraFields = schemas.BifrostResponseExtraFields{
|
|
Latency: latency.Milliseconds(),
|
|
ProviderResponseHeaders: providerResponseHeaders,
|
|
}
|
|
|
|
if sendBackRawResponse {
|
|
response.ExtraFields.RawResponse = rawResponse
|
|
}
|
|
|
|
return response, nil
|
|
}
|
|
|
|
// CountTokens performs a count tokens request to the OpenAI API.
|
|
func (provider *OpenAIProvider) CountTokens(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostResponsesRequest) (*schemas.BifrostCountTokensResponse, *schemas.BifrostError) {
|
|
if err := providerUtils.CheckOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.CountTokensRequest); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return HandleOpenAICountTokensRequest(
|
|
ctx,
|
|
provider.client,
|
|
provider.buildRequestURL(ctx, "/v1/responses/input_tokens", schemas.CountTokensRequest),
|
|
request,
|
|
key,
|
|
provider.networkConfig.ExtraHeaders,
|
|
providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest),
|
|
providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse),
|
|
provider.GetProviderKey(),
|
|
provider.logger,
|
|
)
|
|
}
|
|
|
|
// HandleOpenAICountTokensRequest handles a count tokens request to OpenAI's API.
|
|
func HandleOpenAICountTokensRequest(
|
|
ctx *schemas.BifrostContext,
|
|
client *fasthttp.Client,
|
|
url string,
|
|
request *schemas.BifrostResponsesRequest,
|
|
key schemas.Key,
|
|
extraHeaders map[string]string,
|
|
sendBackRawRequest bool,
|
|
sendBackRawResponse bool,
|
|
providerName schemas.ModelProvider,
|
|
logger schemas.Logger,
|
|
) (*schemas.BifrostCountTokensResponse, *schemas.BifrostError) {
|
|
// Create request
|
|
req := fasthttp.AcquireRequest()
|
|
resp := fasthttp.AcquireResponse()
|
|
defer fasthttp.ReleaseRequest(req)
|
|
// resp lifecycle: managed by finalizeOpenAIResponse or released on error paths
|
|
respOwned := true
|
|
defer func() {
|
|
if respOwned {
|
|
fasthttp.ReleaseResponse(resp)
|
|
}
|
|
}()
|
|
activeClient := providerUtils.PrepareResponseStreaming(ctx, client, resp)
|
|
|
|
// Set any extra headers from network config
|
|
providerUtils.SetExtraHeaders(ctx, req, extraHeaders, nil)
|
|
|
|
req.SetRequestURI(url)
|
|
req.Header.SetMethod(http.MethodPost)
|
|
req.Header.SetContentType("application/json")
|
|
|
|
if key.Value.GetValue() != "" {
|
|
req.Header.Set("Authorization", "Bearer "+key.Value.GetValue())
|
|
}
|
|
|
|
// Large payload passthrough: stream body directly without JSON marshaling
|
|
if lpResult, lpErr, handled := handleOpenAILargePayloadPassthrough(ctx, client, url, key, extraHeaders, providerName, logger); handled {
|
|
if lpErr != nil {
|
|
return nil, lpErr
|
|
}
|
|
if len(lpResult.ResponseBody) > 0 {
|
|
response := &schemas.BifrostCountTokensResponse{}
|
|
if err := sonic.Unmarshal(lpResult.ResponseBody, response); err != nil {
|
|
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err)
|
|
}
|
|
response.ExtraFields = schemas.BifrostResponseExtraFields{Latency: lpResult.Latency}
|
|
return response, nil
|
|
}
|
|
return &schemas.BifrostCountTokensResponse{
|
|
ExtraFields: schemas.BifrostResponseExtraFields{Latency: lpResult.Latency},
|
|
}, nil
|
|
}
|
|
|
|
jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody(
|
|
ctx,
|
|
request,
|
|
func() (providerUtils.RequestBodyWithExtraParams, error) {
|
|
return ToOpenAIResponsesRequest(request), nil
|
|
})
|
|
if bifrostErr != nil {
|
|
return nil, bifrostErr
|
|
}
|
|
|
|
req.SetBody(jsonData)
|
|
|
|
// Make request
|
|
latency, bifrostErr, wait := providerUtils.MakeRequestWithContext(ctx, activeClient, req, resp)
|
|
defer wait()
|
|
if bifrostErr != nil {
|
|
return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonData, nil, sendBackRawRequest, sendBackRawResponse)
|
|
}
|
|
// Extract provider response headers early so they're available on error paths too
|
|
providerResponseHeaders := providerUtils.ExtractProviderResponseHeaders(resp)
|
|
ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerResponseHeaders)
|
|
|
|
// Handle error response
|
|
if resp.StatusCode() != fasthttp.StatusOK {
|
|
providerUtils.MaterializeStreamErrorBody(ctx, resp)
|
|
logger.Debug(fmt.Sprintf("error from %s provider: %s", providerName, string(resp.Body())))
|
|
return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp), jsonData, nil, sendBackRawRequest, sendBackRawResponse)
|
|
}
|
|
|
|
body, lpResult, finalErr := finalizeOpenAIResponse(ctx, resp, latency, providerName, logger)
|
|
respOwned = false // ownership transferred
|
|
if finalErr != nil {
|
|
return nil, finalErr
|
|
}
|
|
if lpResult != nil {
|
|
return &schemas.BifrostCountTokensResponse{
|
|
ExtraFields: schemas.BifrostResponseExtraFields{Latency: lpResult.Latency},
|
|
}, nil
|
|
}
|
|
|
|
response := &schemas.BifrostCountTokensResponse{}
|
|
|
|
// Use enhanced response handler with pre-allocated response
|
|
rawRequest, rawResponse, bifrostErr := providerUtils.HandleProviderResponse(body, response, jsonData, sendBackRawRequest, sendBackRawResponse)
|
|
if bifrostErr != nil {
|
|
return nil, bifrostErr
|
|
}
|
|
|
|
response.Model = request.Model
|
|
response.ExtraFields.Latency = latency.Milliseconds()
|
|
response.ExtraFields.ProviderResponseHeaders = providerResponseHeaders
|
|
|
|
if providerUtils.ShouldSendBackRawRequest(ctx, sendBackRawRequest) {
|
|
response.ExtraFields.RawRequest = rawRequest
|
|
}
|
|
|
|
if providerUtils.ShouldSendBackRawResponse(ctx, sendBackRawResponse) {
|
|
response.ExtraFields.RawResponse = rawResponse
|
|
}
|
|
|
|
return response, nil
|
|
}
|
|
|
|
// ImageEdit performs image editing via the OpenAI Images API.
|
|
func (provider *OpenAIProvider) ImageEdit(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostImageEditRequest) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) {
|
|
if err := providerUtils.CheckOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.ImageEditRequest); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return HandleOpenAIImageEditRequest(
|
|
ctx,
|
|
provider.client,
|
|
provider.buildRequestURL(ctx, "/v1/images/edits", schemas.ImageEditRequest),
|
|
request,
|
|
key,
|
|
provider.networkConfig.ExtraHeaders,
|
|
false,
|
|
providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse),
|
|
provider.GetProviderKey(),
|
|
provider.logger,
|
|
)
|
|
}
|
|
|
|
func HandleOpenAIImageEditRequest(
|
|
ctx *schemas.BifrostContext,
|
|
client *fasthttp.Client,
|
|
url string,
|
|
request *schemas.BifrostImageEditRequest,
|
|
key schemas.Key,
|
|
extraHeaders map[string]string,
|
|
sendBackRawRequest bool,
|
|
sendBackRawResponse bool,
|
|
providerName schemas.ModelProvider,
|
|
logger schemas.Logger,
|
|
) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) {
|
|
// Large payload passthrough: stream multipart body directly without parsing
|
|
if lpResult, lpErr, handled := handleOpenAILargePayloadPassthrough(ctx, client, url, key, extraHeaders, providerName, logger); handled {
|
|
if lpErr != nil {
|
|
return nil, lpErr
|
|
}
|
|
if len(lpResult.ResponseBody) > 0 {
|
|
response := &schemas.BifrostImageGenerationResponse{}
|
|
if err := sonic.Unmarshal(lpResult.ResponseBody, response); err != nil {
|
|
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err)
|
|
}
|
|
response.ExtraFields = schemas.BifrostResponseExtraFields{Latency: lpResult.Latency}
|
|
return response, nil
|
|
}
|
|
return &schemas.BifrostImageGenerationResponse{
|
|
ExtraFields: schemas.BifrostResponseExtraFields{Latency: lpResult.Latency},
|
|
}, nil
|
|
}
|
|
|
|
openaiReq := ToOpenAIImageEditRequest(request)
|
|
if openaiReq == nil {
|
|
return nil, providerUtils.NewBifrostOperationError("failed to convert request to OpenAI format", nil)
|
|
}
|
|
|
|
// Create request
|
|
req := fasthttp.AcquireRequest()
|
|
resp := fasthttp.AcquireResponse()
|
|
defer fasthttp.ReleaseRequest(req)
|
|
// resp lifecycle: managed by finalizeOpenAIResponse or released on error paths
|
|
respOwned := true
|
|
defer func() {
|
|
if respOwned {
|
|
fasthttp.ReleaseResponse(resp)
|
|
}
|
|
}()
|
|
activeClient := providerUtils.PrepareResponseStreaming(ctx, client, resp)
|
|
|
|
providerUtils.SetExtraHeaders(ctx, req, extraHeaders, nil)
|
|
req.SetRequestURI(url)
|
|
req.Header.SetMethod(http.MethodPost)
|
|
|
|
if key.Value.GetValue() != "" {
|
|
req.Header.Set("Authorization", "Bearer "+key.Value.GetValue())
|
|
}
|
|
req.Header.Set("Content-Type", "multipart/form-data")
|
|
|
|
// Create multipart form
|
|
var body bytes.Buffer
|
|
writer := multipart.NewWriter(&body)
|
|
if err := parseImageEditFormDataBodyFromRequest(writer, openaiReq, providerName); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
req.Header.SetContentType(writer.FormDataContentType())
|
|
bodyData := body.Bytes()
|
|
req.SetBody(bodyData)
|
|
|
|
latency, bifrostErr, wait := providerUtils.MakeRequestWithContext(ctx, activeClient, req, resp)
|
|
defer wait()
|
|
if bifrostErr != nil {
|
|
return nil, providerUtils.EnrichError(ctx, bifrostErr, bodyData, nil, sendBackRawRequest, sendBackRawResponse)
|
|
}
|
|
// Extract provider response headers early so they're available on error paths too
|
|
providerResponseHeaders := providerUtils.ExtractProviderResponseHeaders(resp)
|
|
ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerResponseHeaders)
|
|
|
|
if resp.StatusCode() != fasthttp.StatusOK {
|
|
providerUtils.MaterializeStreamErrorBody(ctx, resp)
|
|
return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp), bodyData, nil, sendBackRawRequest, sendBackRawResponse)
|
|
}
|
|
|
|
bodyBytes, lpResult, finalErr := finalizeOpenAIResponse(ctx, resp, latency, providerName, logger)
|
|
respOwned = false // ownership transferred
|
|
if finalErr != nil {
|
|
return nil, finalErr
|
|
}
|
|
if lpResult != nil {
|
|
return &schemas.BifrostImageGenerationResponse{
|
|
ExtraFields: schemas.BifrostResponseExtraFields{Latency: lpResult.Latency},
|
|
}, nil
|
|
}
|
|
|
|
response := &schemas.BifrostImageGenerationResponse{}
|
|
rawRequest, rawResponse, bifrostErr := providerUtils.HandleProviderResponse(bodyBytes, response, bodyData, false, sendBackRawResponse)
|
|
if bifrostErr != nil {
|
|
return nil, bifrostErr
|
|
}
|
|
response.ExtraFields.Latency = latency.Milliseconds()
|
|
response.ExtraFields.ProviderResponseHeaders = providerResponseHeaders
|
|
|
|
// Set raw request if enabled
|
|
if sendBackRawRequest {
|
|
response.ExtraFields.RawRequest = rawRequest
|
|
}
|
|
|
|
// Set raw response if enabled
|
|
if sendBackRawResponse {
|
|
response.ExtraFields.RawResponse = rawResponse
|
|
}
|
|
return response, nil
|
|
}
|
|
|
|
// ImageEditStream streams image edits via the OpenAI Images API.
|
|
func (provider *OpenAIProvider) ImageEditStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostImageEditRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
|
|
// Check if image generation stream is allowed for this provider
|
|
if err := providerUtils.CheckOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.ImageEditStreamRequest); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
var authHeader map[string]string
|
|
if value := key.Value.GetValue(); value != "" {
|
|
authHeader = map[string]string{"Authorization": "Bearer " + value}
|
|
}
|
|
|
|
return HandleOpenAIImageEditStreamRequest(
|
|
ctx,
|
|
provider.streamingClient,
|
|
provider.buildRequestURL(ctx, "/v1/images/edits", schemas.ImageEditStreamRequest),
|
|
request,
|
|
authHeader,
|
|
provider.networkConfig.ExtraHeaders,
|
|
false,
|
|
providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse),
|
|
provider.GetProviderKey(),
|
|
postHookRunner,
|
|
nil,
|
|
nil,
|
|
nil,
|
|
provider.logger,
|
|
postHookSpanFinalizer,
|
|
)
|
|
}
|
|
|
|
func HandleOpenAIImageEditStreamRequest(
|
|
ctx *schemas.BifrostContext,
|
|
client *fasthttp.Client,
|
|
url string,
|
|
request *schemas.BifrostImageEditRequest,
|
|
authHeader map[string]string,
|
|
extraHeaders map[string]string,
|
|
sendBackRawRequest bool,
|
|
sendBackRawResponse bool,
|
|
providerName schemas.ModelProvider,
|
|
postHookRunner schemas.PostHookRunner,
|
|
customRequestConverter func(*schemas.BifrostImageEditRequest) (providerUtils.RequestBodyWithExtraParams, error),
|
|
postRequestConverter func(*OpenAIImageEditRequest) *OpenAIImageEditRequest,
|
|
postResponseConverter func(*schemas.BifrostImageGenerationStreamResponse) *schemas.BifrostImageGenerationStreamResponse,
|
|
logger schemas.Logger,
|
|
postHookSpanFinalizer func(context.Context),
|
|
) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
|
|
reqBody := ToOpenAIImageEditRequest(request)
|
|
if reqBody == nil {
|
|
return nil, providerUtils.NewBifrostOperationError("image edit input is not provided", nil)
|
|
}
|
|
|
|
reqBody.Stream = schemas.Ptr(true)
|
|
if postRequestConverter != nil {
|
|
reqBody = postRequestConverter(reqBody)
|
|
}
|
|
// Create multipart form
|
|
var body bytes.Buffer
|
|
writer := multipart.NewWriter(&body)
|
|
|
|
if bifrostErr := parseImageEditFormDataBodyFromRequest(writer, reqBody, providerName); bifrostErr != nil {
|
|
return nil, bifrostErr
|
|
}
|
|
|
|
// Prepare OpenAI headers
|
|
headers := map[string]string{
|
|
"Content-Type": writer.FormDataContentType(),
|
|
"Accept": "text/event-stream",
|
|
"Cache-Control": "no-cache",
|
|
}
|
|
|
|
if authHeader != nil {
|
|
maps.Copy(headers, authHeader)
|
|
}
|
|
|
|
// Create HTTP request for streaming
|
|
req := fasthttp.AcquireRequest()
|
|
resp := fasthttp.AcquireResponse()
|
|
resp.StreamBody = true
|
|
defer fasthttp.ReleaseRequest(req)
|
|
|
|
// Set any extra headers from network config
|
|
providerUtils.SetExtraHeaders(ctx, req, extraHeaders, nil)
|
|
|
|
req.Header.SetMethod(http.MethodPost)
|
|
req.SetRequestURI(url)
|
|
|
|
// Set headers
|
|
for key, value := range headers {
|
|
req.Header.Set(key, value)
|
|
}
|
|
|
|
req.SetBody(body.Bytes())
|
|
|
|
// Make the request
|
|
err := client.Do(req, resp)
|
|
if err != nil {
|
|
defer providerUtils.ReleaseStreamingResponse(resp)
|
|
if errors.Is(err, context.Canceled) {
|
|
return nil, &schemas.BifrostError{
|
|
IsBifrostError: false,
|
|
Error: &schemas.ErrorField{
|
|
Type: schemas.Ptr(schemas.RequestCancelled),
|
|
Message: schemas.ErrRequestCancelled,
|
|
Error: err,
|
|
},
|
|
}
|
|
}
|
|
if errors.Is(err, fasthttp.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) {
|
|
return nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err)
|
|
}
|
|
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err)
|
|
}
|
|
// Store provider response headers in context before status check so error responses also forward them
|
|
ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerUtils.ExtractProviderResponseHeaders(resp))
|
|
|
|
// Check for HTTP errors
|
|
if resp.StatusCode() != fasthttp.StatusOK {
|
|
defer providerUtils.ReleaseStreamingResponse(resp)
|
|
providerUtils.MaterializeStreamErrorBody(ctx, resp)
|
|
return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp), body.Bytes(), nil, sendBackRawRequest, sendBackRawResponse)
|
|
}
|
|
|
|
// Large payload streaming passthrough — pipe raw upstream SSE to client
|
|
if providerUtils.SetupStreamingPassthrough(ctx, resp) {
|
|
responseChan := make(chan *schemas.BifrostStreamChunk)
|
|
close(responseChan)
|
|
return responseChan, nil
|
|
}
|
|
|
|
// Create response channel
|
|
responseChan := make(chan *schemas.BifrostStreamChunk, schemas.DefaultStreamBufferSize)
|
|
|
|
// Start streaming in a goroutine
|
|
go func() {
|
|
defer providerUtils.EnsureStreamFinalizerCalled(ctx, postHookSpanFinalizer)
|
|
defer func() {
|
|
if ctx.Err() == context.Canceled {
|
|
providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, logger, postHookSpanFinalizer)
|
|
} else if ctx.Err() == context.DeadlineExceeded {
|
|
providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, logger, postHookSpanFinalizer)
|
|
}
|
|
close(responseChan)
|
|
}()
|
|
defer providerUtils.ReleaseStreamingResponse(resp)
|
|
// Decompress gzip-encoded streams transparently (no-op for non-gzip)
|
|
reader, releaseGzip := providerUtils.DecompressStreamBody(resp)
|
|
defer releaseGzip()
|
|
|
|
// Wrap reader with idle timeout to detect stalled streams.
|
|
reader, stopIdleTimeout := providerUtils.NewIdleTimeoutReader(reader, resp.BodyStream(), providerUtils.GetStreamIdleTimeout(ctx))
|
|
defer stopIdleTimeout()
|
|
|
|
// Setup cancellation handler to close the raw network stream on ctx cancellation,
|
|
// which immediately unblocks any in-progress read (including reads blocked inside a gzip decompression layer).
|
|
stopCancellation := providerUtils.SetupStreamCancellation(ctx, resp.BodyStream(), logger)
|
|
defer stopCancellation()
|
|
|
|
// Skip scanner for non-SSE responses — avoids bufio.Scanner buffer bloat
|
|
// on non-line-delimited data (e.g. provider returned JSON instead of SSE).
|
|
if providerUtils.DrainNonSSEStreamResponse(resp) {
|
|
ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
|
|
providerUtils.ProcessAndSendError(ctx, postHookRunner, errors.New("provider returned non-SSE response for streaming request"), responseChan, logger, postHookSpanFinalizer)
|
|
return
|
|
}
|
|
|
|
sseReader := providerUtils.GetSSEDataReader(ctx, reader)
|
|
|
|
startTime := time.Now()
|
|
lastChunkTime := startTime
|
|
var collectedUsage *schemas.ImageUsage
|
|
// Track chunk indices per image - similar to how speech/transcription track chunkIndex
|
|
imageChunkIndices := make(map[int]int) // image index -> chunk index
|
|
// Track images that have started (via partial chunks) but not yet completed
|
|
// This allows us to correctly match completed events to images even if chunks are interleaved
|
|
incompleteImages := make(map[int]bool)
|
|
maxImageIndex := -1 // Track maximum image index for NImages calculation
|
|
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return
|
|
default:
|
|
}
|
|
|
|
data, readErr := sseReader.ReadDataLine()
|
|
if readErr != nil {
|
|
if readErr != io.EOF {
|
|
logger.Warn(fmt.Sprintf("Error reading stream: %v", readErr))
|
|
providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, responseChan, logger, postHookSpanFinalizer)
|
|
}
|
|
break
|
|
}
|
|
jsonData := string(data)
|
|
|
|
// Quick check for error field (allocation-free using sonic.GetFromString)
|
|
if errorNode, _ := sonic.GetFromString(jsonData, "error"); errorNode.Exists() {
|
|
// Only unmarshal when we know there's an error
|
|
var bifrostErr schemas.BifrostError
|
|
if err := sonic.UnmarshalString(jsonData, &bifrostErr); err == nil {
|
|
if bifrostErr.Error != nil && bifrostErr.Error.Message != "" {
|
|
ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
|
|
providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, providerUtils.EnrichError(ctx, &bifrostErr, body.Bytes(), nil, sendBackRawRequest, sendBackRawResponse), responseChan, logger, postHookSpanFinalizer)
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
// Parse minimally to extract usage and check for errors
|
|
var response OpenAIImageStreamResponse
|
|
if err := sonic.UnmarshalString(jsonData, &response); err != nil {
|
|
logger.Warn("Failed to parse stream response: %v", err)
|
|
continue
|
|
}
|
|
|
|
// Check if response type indicates an error
|
|
if response.Type == "error" {
|
|
bifrostErr := &schemas.BifrostError{
|
|
IsBifrostError: false,
|
|
Error: &schemas.ErrorField{},
|
|
}
|
|
// Guard access to response.Error fields
|
|
if response.Error != nil {
|
|
bifrostErr.Error.Message = response.Error.Message
|
|
if response.Error.Code != nil {
|
|
bifrostErr.Error.Code = response.Error.Code
|
|
}
|
|
if response.Error.Param != nil {
|
|
bifrostErr.Error.Param = response.Error.Param
|
|
}
|
|
if response.Error.Type != nil {
|
|
bifrostErr.Error.Type = response.Error.Type
|
|
}
|
|
}
|
|
ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
|
|
providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, bifrostErr, responseChan, logger, postHookSpanFinalizer)
|
|
return
|
|
}
|
|
|
|
// Collect usage from completed event
|
|
if response.Usage != nil {
|
|
collectedUsage = &schemas.ImageUsage{
|
|
InputTokens: response.Usage.InputTokens,
|
|
OutputTokens: response.Usage.OutputTokens,
|
|
TotalTokens: response.Usage.TotalTokens,
|
|
}
|
|
}
|
|
|
|
// Determine if this is the final chunk
|
|
isCompleted := response.Type == schemas.ImageGenerationEventTypeCompleted || response.Type == schemas.ImageEditEventTypeCompleted
|
|
|
|
// Determine image index with robust tracking for interleaved chunks
|
|
// Both partial and completed chunks should use PartialImageIndex when available
|
|
var imageIndex int
|
|
if response.PartialImageIndex != nil {
|
|
// Use explicit index from response
|
|
imageIndex = *response.PartialImageIndex
|
|
if isCompleted {
|
|
// Mark this image as completed
|
|
delete(incompleteImages, imageIndex)
|
|
} else {
|
|
// Mark this image as started (incomplete)
|
|
incompleteImages[imageIndex] = true
|
|
}
|
|
} else {
|
|
// Fallback: PartialImageIndex is nil, use tracked state
|
|
if isCompleted {
|
|
// For completed chunks, match to the oldest incomplete image
|
|
// This handles interleaved chunks correctly
|
|
if len(incompleteImages) == 0 {
|
|
// Fallback: if no incomplete images tracked, this shouldn't happen in normal flow
|
|
// but we'll default to 0 to prevent panics
|
|
imageIndex = 0
|
|
logger.Warn("Received completed event but no incomplete images tracked, defaulting to index 0")
|
|
} else {
|
|
// Find the minimum (oldest) incomplete image index
|
|
// Completed events should match the oldest image that was started
|
|
minIndex := -1
|
|
for idx := range incompleteImages {
|
|
if minIndex == -1 || idx < minIndex {
|
|
minIndex = idx
|
|
}
|
|
}
|
|
imageIndex = minIndex
|
|
// Mark this image as completed
|
|
delete(incompleteImages, imageIndex)
|
|
logger.Warn(fmt.Sprintf("Completed event missing PartialImageIndex, using oldest incomplete image index %d", imageIndex))
|
|
}
|
|
} else {
|
|
// For partial chunks without PartialImageIndex, allocate a new unique index
|
|
// Use maxImageIndex + 1 to ensure uniqueness
|
|
imageIndex = maxImageIndex + 1
|
|
// Mark this image as started (incomplete)
|
|
incompleteImages[imageIndex] = true
|
|
}
|
|
}
|
|
|
|
// Update maximum image index for NImages calculation
|
|
if imageIndex > maxImageIndex {
|
|
maxImageIndex = imageIndex
|
|
}
|
|
|
|
// Increment chunk index for this image
|
|
if _, exists := imageChunkIndices[imageIndex]; !exists {
|
|
imageChunkIndices[imageIndex] = 0
|
|
} else {
|
|
imageChunkIndices[imageIndex]++
|
|
}
|
|
chunkIndex := imageChunkIndices[imageIndex]
|
|
// Build chunk with all OpenAI fields
|
|
chunk := &schemas.BifrostImageGenerationStreamResponse{
|
|
Type: response.Type,
|
|
Index: imageIndex, // Which image (0-N)
|
|
ChunkIndex: chunkIndex, // Chunk order within this image (top-level)
|
|
CreatedAt: response.CreatedAt,
|
|
Size: response.Size,
|
|
Quality: response.Quality,
|
|
Background: response.Background,
|
|
OutputFormat: response.OutputFormat,
|
|
ExtraFields: schemas.BifrostResponseExtraFields{
|
|
ChunkIndex: chunkIndex, // Chunk order within this image
|
|
Latency: time.Since(lastChunkTime).Milliseconds(),
|
|
},
|
|
}
|
|
|
|
if postResponseConverter != nil {
|
|
if converted := postResponseConverter(chunk); converted != nil {
|
|
chunk = converted
|
|
} else {
|
|
logger.Warn("postResponseConverter returned nil; leaving chunk unmodified")
|
|
}
|
|
}
|
|
|
|
// Only set PartialImageIndex for partial images, not for completed events
|
|
if !isCompleted {
|
|
chunk.PartialImageIndex = response.PartialImageIndex
|
|
}
|
|
// Set SequenceNumber if present
|
|
if response.SequenceNumber != nil {
|
|
chunk.SequenceNumber = *response.SequenceNumber
|
|
}
|
|
lastChunkTime = time.Now()
|
|
|
|
// Copy b64_json if present
|
|
if response.B64JSON != nil {
|
|
chunk.B64JSON = *response.B64JSON
|
|
}
|
|
|
|
// Set raw response on every chunk if enabled
|
|
if sendBackRawResponse {
|
|
chunk.ExtraFields.RawResponse = jsonData
|
|
}
|
|
|
|
if isCompleted {
|
|
if collectedUsage != nil {
|
|
// Set NImages based on maximum image index seen (maxImageIndex + 1 since indices are 0-based)
|
|
if maxImageIndex >= 0 {
|
|
nImages := maxImageIndex + 1
|
|
collectedUsage.OutputTokensDetails = &schemas.ImageTokenDetails{
|
|
NImages: nImages,
|
|
}
|
|
}
|
|
chunk.Usage = collectedUsage
|
|
}
|
|
// For completed chunk, use total latency from start
|
|
chunk.ExtraFields.Latency = time.Since(startTime).Milliseconds()
|
|
chunk.BackfillParams(&schemas.BifrostRequest{
|
|
ImageEditRequest: request,
|
|
})
|
|
ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
|
|
}
|
|
|
|
providerUtils.ProcessAndSendResponse(ctx, postHookRunner,
|
|
providerUtils.GetBifrostResponseForStreamResponse(nil, nil, nil, nil, nil, chunk),
|
|
responseChan, postHookSpanFinalizer)
|
|
|
|
if isCompleted {
|
|
return
|
|
}
|
|
}
|
|
}()
|
|
|
|
return responseChan, nil
|
|
}
|
|
|
|
// ImageVariation performs an image variation request to openai's images api.
|
|
func (provider *OpenAIProvider) ImageVariation(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostImageVariationRequest) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) {
|
|
if err := providerUtils.CheckOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.ImageVariationRequest); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
response, err := HandleOpenAIImageVariationRequest(
|
|
ctx,
|
|
provider.client,
|
|
provider.buildRequestURL(ctx, "/v1/images/variations", schemas.ImageVariationRequest),
|
|
request,
|
|
key,
|
|
provider.networkConfig.ExtraHeaders,
|
|
false,
|
|
providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse),
|
|
provider.GetProviderKey(),
|
|
provider.logger,
|
|
)
|
|
return response, err
|
|
}
|
|
|
|
// ImageVariation performs an image variation request
|
|
// HandleOpenAIImageVariationRequest handles image variation requests for OpenAI-compatible providers
|
|
func HandleOpenAIImageVariationRequest(
|
|
ctx *schemas.BifrostContext,
|
|
client *fasthttp.Client,
|
|
url string,
|
|
request *schemas.BifrostImageVariationRequest,
|
|
key schemas.Key,
|
|
extraHeaders map[string]string,
|
|
sendBackRawRequest bool,
|
|
sendBackRawResponse bool,
|
|
providerName schemas.ModelProvider,
|
|
logger schemas.Logger,
|
|
) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) {
|
|
// Large payload passthrough: stream multipart body directly without parsing
|
|
if lpResult, lpErr, handled := handleOpenAILargePayloadPassthrough(ctx, client, url, key, extraHeaders, providerName, logger); handled {
|
|
if lpErr != nil {
|
|
return nil, lpErr
|
|
}
|
|
if len(lpResult.ResponseBody) > 0 {
|
|
response := &schemas.BifrostImageGenerationResponse{}
|
|
if err := sonic.Unmarshal(lpResult.ResponseBody, response); err != nil {
|
|
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err)
|
|
}
|
|
response.ExtraFields = schemas.BifrostResponseExtraFields{Latency: lpResult.Latency}
|
|
return response, nil
|
|
}
|
|
return &schemas.BifrostImageGenerationResponse{
|
|
ExtraFields: schemas.BifrostResponseExtraFields{Latency: lpResult.Latency},
|
|
}, nil
|
|
}
|
|
|
|
openaiReq := ToOpenAIImageVariationRequest(request)
|
|
if openaiReq == nil {
|
|
return nil, providerUtils.NewBifrostOperationError("failed to convert request to OpenAI format", nil)
|
|
}
|
|
|
|
// Create request
|
|
req := fasthttp.AcquireRequest()
|
|
resp := fasthttp.AcquireResponse()
|
|
defer fasthttp.ReleaseRequest(req)
|
|
// resp lifecycle: managed by finalizeOpenAIResponse or released on error paths
|
|
respOwned := true
|
|
defer func() {
|
|
if respOwned {
|
|
fasthttp.ReleaseResponse(resp)
|
|
}
|
|
}()
|
|
activeClient := providerUtils.PrepareResponseStreaming(ctx, client, resp)
|
|
|
|
providerUtils.SetExtraHeaders(ctx, req, extraHeaders, nil)
|
|
req.SetRequestURI(url)
|
|
req.Header.SetMethod(http.MethodPost)
|
|
|
|
if key.Value.GetValue() != "" {
|
|
req.Header.Set("Authorization", "Bearer "+key.Value.GetValue())
|
|
}
|
|
|
|
// Create multipart form
|
|
var body bytes.Buffer
|
|
writer := multipart.NewWriter(&body)
|
|
if err := parseImageVariationFormDataBodyFromRequest(writer, openaiReq, providerName); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
req.Header.SetContentType(writer.FormDataContentType())
|
|
bodyData := body.Bytes()
|
|
req.SetBody(bodyData)
|
|
|
|
latency, bifrostErr, wait := providerUtils.MakeRequestWithContext(ctx, activeClient, req, resp)
|
|
defer wait()
|
|
if bifrostErr != nil {
|
|
return nil, providerUtils.EnrichError(ctx, bifrostErr, bodyData, nil, sendBackRawRequest, sendBackRawResponse)
|
|
}
|
|
// Extract provider response headers early so they're available on error paths too
|
|
providerResponseHeaders := providerUtils.ExtractProviderResponseHeaders(resp)
|
|
ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerResponseHeaders)
|
|
|
|
if resp.StatusCode() != fasthttp.StatusOK {
|
|
providerUtils.MaterializeStreamErrorBody(ctx, resp)
|
|
return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp), bodyData, nil, sendBackRawRequest, sendBackRawResponse)
|
|
}
|
|
|
|
bodyBytes, lpResult, finalErr := finalizeOpenAIResponse(ctx, resp, latency, providerName, logger)
|
|
respOwned = false // ownership transferred
|
|
if finalErr != nil {
|
|
return nil, finalErr
|
|
}
|
|
if lpResult != nil {
|
|
return &schemas.BifrostImageGenerationResponse{
|
|
ExtraFields: schemas.BifrostResponseExtraFields{Latency: lpResult.Latency},
|
|
}, nil
|
|
}
|
|
|
|
response := &schemas.BifrostImageGenerationResponse{}
|
|
_, rawResponse, bifrostErr := providerUtils.HandleProviderResponse(bodyBytes, response, bodyData, sendBackRawRequest, sendBackRawResponse)
|
|
if bifrostErr != nil {
|
|
return nil, bifrostErr
|
|
}
|
|
response.ExtraFields.Latency = latency.Milliseconds()
|
|
response.ExtraFields.ProviderResponseHeaders = providerResponseHeaders
|
|
|
|
// Set raw response if enabled
|
|
if sendBackRawResponse {
|
|
response.ExtraFields.RawResponse = rawResponse
|
|
}
|
|
return response, nil
|
|
}
|
|
|
|
// FileUpload uploads a file to OpenAI.
|
|
func (provider *OpenAIProvider) FileUpload(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostFileUploadRequest) (*schemas.BifrostFileUploadResponse, *schemas.BifrostError) {
|
|
if err := providerUtils.CheckOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.FileUploadRequest); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if len(request.File) == 0 {
|
|
return nil, providerUtils.NewBifrostOperationError("file content is required", nil)
|
|
}
|
|
|
|
if request.Purpose == "" {
|
|
return nil, providerUtils.NewBifrostOperationError("purpose is required", nil)
|
|
}
|
|
|
|
// Create multipart form data
|
|
var buf bytes.Buffer
|
|
writer := multipart.NewWriter(&buf)
|
|
|
|
// Add purpose field
|
|
if err := writer.WriteField("purpose", string(request.Purpose)); err != nil {
|
|
return nil, providerUtils.NewBifrostOperationError("failed to write purpose field", err)
|
|
}
|
|
|
|
// Add expires_after fields if provided
|
|
if request.ExpiresAfter != nil {
|
|
if err := writer.WriteField("expires_after[anchor]", request.ExpiresAfter.Anchor); err != nil {
|
|
return nil, providerUtils.NewBifrostOperationError("failed to write expires_after[anchor] field", err)
|
|
}
|
|
if err := writer.WriteField("expires_after[seconds]", fmt.Sprintf("%d", request.ExpiresAfter.Seconds)); err != nil {
|
|
return nil, providerUtils.NewBifrostOperationError("failed to write expires_after[seconds] field", err)
|
|
}
|
|
}
|
|
|
|
// Add file field
|
|
filename := request.Filename
|
|
if filename == "" {
|
|
filename = "file.jsonl"
|
|
}
|
|
part, err := writer.CreateFormFile("file", filename)
|
|
if err != nil {
|
|
return nil, providerUtils.NewBifrostOperationError("failed to create form file", err)
|
|
}
|
|
if _, err := part.Write(request.File); err != nil {
|
|
return nil, providerUtils.NewBifrostOperationError("failed to write file content", err)
|
|
}
|
|
|
|
if err := writer.Close(); err != nil {
|
|
return nil, providerUtils.NewBifrostOperationError("failed to close multipart writer", err)
|
|
}
|
|
|
|
// Create request
|
|
req := fasthttp.AcquireRequest()
|
|
resp := fasthttp.AcquireResponse()
|
|
defer fasthttp.ReleaseRequest(req)
|
|
defer fasthttp.ReleaseResponse(resp)
|
|
|
|
// Set headers
|
|
providerUtils.SetExtraHeaders(ctx, req, provider.networkConfig.ExtraHeaders, nil)
|
|
req.SetRequestURI(provider.buildRequestURL(ctx, "/v1/files", schemas.FileUploadRequest))
|
|
req.Header.SetMethod(http.MethodPost)
|
|
req.Header.SetContentType(writer.FormDataContentType())
|
|
|
|
if key.Value.GetValue() != "" {
|
|
req.Header.Set("Authorization", "Bearer "+key.Value.GetValue())
|
|
}
|
|
|
|
req.SetBody(buf.Bytes())
|
|
|
|
// Make request
|
|
latency, bifrostErr, wait := providerUtils.MakeRequestWithContext(ctx, provider.client, req, resp)
|
|
defer wait()
|
|
if bifrostErr != nil {
|
|
return nil, bifrostErr
|
|
}
|
|
|
|
// Handle error response
|
|
if resp.StatusCode() != fasthttp.StatusOK {
|
|
provider.logger.Debug("error from %s provider: %s", provider.GetProviderKey(), string(resp.Body()))
|
|
return nil, ParseOpenAIError(resp)
|
|
}
|
|
|
|
body, err := providerUtils.CheckAndDecodeBody(resp)
|
|
if err != nil {
|
|
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err)
|
|
}
|
|
|
|
var openAIResp OpenAIFileResponse
|
|
sendBackRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest)
|
|
sendBackRawResponse := providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse)
|
|
rawRequest, rawResponse, bifrostErr := providerUtils.HandleProviderResponse(body, &openAIResp, nil, sendBackRawRequest, sendBackRawResponse)
|
|
if bifrostErr != nil {
|
|
return nil, bifrostErr
|
|
}
|
|
|
|
fileResponse := openAIResp.ToBifrostFileUploadResponse(latency, sendBackRawRequest, sendBackRawResponse, rawRequest, rawResponse)
|
|
fileResponse.ExtraFields.ProviderResponseHeaders = providerUtils.ExtractProviderResponseHeaders(resp)
|
|
return fileResponse, nil
|
|
}
|
|
|
|
// FileList lists files using serial pagination across keys.
|
|
// Exhausts all pages from one key before moving to the next.
|
|
func (provider *OpenAIProvider) FileList(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostFileListRequest) (*schemas.BifrostFileListResponse, *schemas.BifrostError) {
|
|
if err := providerUtils.CheckOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.FileListRequest); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
providerName := provider.GetProviderKey()
|
|
sendBackRawResponse := providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse)
|
|
sendBackRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest)
|
|
|
|
// Initialize serial pagination helper
|
|
helper, err := providerUtils.NewSerialListHelper(keys, request.After, provider.logger)
|
|
if err != nil {
|
|
return nil, providerUtils.NewBifrostOperationError("invalid pagination cursor", err)
|
|
}
|
|
|
|
// Get current key to query
|
|
key, nativeCursor, ok := helper.GetCurrentKey()
|
|
if !ok {
|
|
// All keys exhausted
|
|
return &schemas.BifrostFileListResponse{
|
|
Object: "list",
|
|
Data: []schemas.FileObject{},
|
|
HasMore: false,
|
|
}, nil
|
|
}
|
|
|
|
// Create request
|
|
req := fasthttp.AcquireRequest()
|
|
resp := fasthttp.AcquireResponse()
|
|
defer fasthttp.ReleaseRequest(req)
|
|
defer fasthttp.ReleaseResponse(resp)
|
|
|
|
// Build URL with query params
|
|
requestURL := provider.buildRequestURL(ctx, "/v1/files", schemas.FileListRequest)
|
|
values := url.Values{}
|
|
if request.Purpose != "" {
|
|
values.Set("purpose", string(request.Purpose))
|
|
}
|
|
if request.Limit > 0 {
|
|
values.Set("limit", fmt.Sprintf("%d", request.Limit))
|
|
}
|
|
// Use native cursor from serial helper instead of request.After
|
|
if nativeCursor != "" {
|
|
values.Set("after", nativeCursor)
|
|
}
|
|
if request.Order != nil && *request.Order != "" {
|
|
values.Set("order", *request.Order)
|
|
}
|
|
if encoded := values.Encode(); encoded != "" {
|
|
requestURL += "?" + encoded
|
|
}
|
|
|
|
// Set headers
|
|
providerUtils.SetExtraHeaders(ctx, req, provider.networkConfig.ExtraHeaders, nil)
|
|
req.SetRequestURI(requestURL)
|
|
req.Header.SetMethod(http.MethodGet)
|
|
req.Header.SetContentType("application/json")
|
|
|
|
if key.Value.GetValue() != "" {
|
|
req.Header.Set("Authorization", "Bearer "+key.Value.GetValue())
|
|
}
|
|
|
|
// Make request
|
|
latency, bifrostErr, wait := providerUtils.MakeRequestWithContext(ctx, provider.client, req, resp)
|
|
defer wait()
|
|
if bifrostErr != nil {
|
|
return nil, bifrostErr
|
|
}
|
|
|
|
// Handle error response
|
|
if resp.StatusCode() != fasthttp.StatusOK {
|
|
provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body()))
|
|
return nil, ParseOpenAIError(resp)
|
|
}
|
|
|
|
body, decodeErr := providerUtils.CheckAndDecodeBody(resp)
|
|
if decodeErr != nil {
|
|
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, decodeErr)
|
|
}
|
|
|
|
var openAIResp OpenAIFileListResponse
|
|
_, _, bifrostErr = providerUtils.HandleProviderResponse(body, &openAIResp, nil, sendBackRawRequest, sendBackRawResponse)
|
|
if bifrostErr != nil {
|
|
return nil, bifrostErr
|
|
}
|
|
|
|
// Convert files to Bifrost format
|
|
files := make([]schemas.FileObject, 0, len(openAIResp.Data))
|
|
var lastFileID string
|
|
for _, file := range openAIResp.Data {
|
|
files = append(files, schemas.FileObject{
|
|
ID: file.ID,
|
|
Object: file.Object,
|
|
Bytes: file.Bytes,
|
|
CreatedAt: file.CreatedAt,
|
|
Filename: file.Filename,
|
|
Purpose: schemas.FilePurpose(file.Purpose),
|
|
Status: ToBifrostFileStatus(file.Status),
|
|
StatusDetails: file.StatusDetails,
|
|
})
|
|
lastFileID = file.ID
|
|
}
|
|
|
|
// Build cursor for next request
|
|
// OpenAI uses LastID as the cursor for pagination
|
|
nextCursor, hasMore := helper.BuildNextCursor(openAIResp.HasMore, lastFileID)
|
|
|
|
// Convert to Bifrost response
|
|
bifrostResp := &schemas.BifrostFileListResponse{
|
|
Object: "list",
|
|
Data: files,
|
|
HasMore: hasMore,
|
|
ExtraFields: schemas.BifrostResponseExtraFields{
|
|
Latency: latency.Milliseconds(),
|
|
ProviderResponseHeaders: providerUtils.ExtractProviderResponseHeaders(resp),
|
|
},
|
|
}
|
|
if nextCursor != "" {
|
|
bifrostResp.After = &nextCursor
|
|
}
|
|
|
|
return bifrostResp, nil
|
|
}
|
|
|
|
// FileRetrieve retrieves file metadata from OpenAI by trying each key until found.
|
|
func (provider *OpenAIProvider) FileRetrieve(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostFileRetrieveRequest) (*schemas.BifrostFileRetrieveResponse, *schemas.BifrostError) {
|
|
if err := providerUtils.CheckOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.FileRetrieveRequest); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
providerName := provider.GetProviderKey()
|
|
|
|
if request.FileID == "" {
|
|
return nil, providerUtils.NewBifrostOperationError("file_id is required", nil)
|
|
}
|
|
|
|
sendBackRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest)
|
|
sendBackRawResponse := providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse)
|
|
|
|
var lastErr *schemas.BifrostError
|
|
for _, key := range keys {
|
|
// Create request
|
|
req := fasthttp.AcquireRequest()
|
|
resp := fasthttp.AcquireResponse()
|
|
|
|
// Set headers
|
|
providerUtils.SetExtraHeaders(ctx, req, provider.networkConfig.ExtraHeaders, nil)
|
|
req.SetRequestURI(provider.networkConfig.BaseURL + "/v1/files/" + request.FileID)
|
|
req.Header.SetMethod(http.MethodGet)
|
|
req.Header.SetContentType("application/json")
|
|
|
|
if key.Value.GetValue() != "" {
|
|
req.Header.Set("Authorization", "Bearer "+key.Value.GetValue())
|
|
}
|
|
|
|
// Make request
|
|
latency, bifrostErr, wait := providerUtils.MakeRequestWithContext(ctx, provider.client, req, resp)
|
|
wait()
|
|
if bifrostErr != nil {
|
|
fasthttp.ReleaseRequest(req)
|
|
fasthttp.ReleaseResponse(resp)
|
|
lastErr = bifrostErr
|
|
continue
|
|
}
|
|
|
|
// Handle error response
|
|
if resp.StatusCode() != fasthttp.StatusOK {
|
|
provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body()))
|
|
lastErr = ParseOpenAIError(resp)
|
|
fasthttp.ReleaseRequest(req)
|
|
fasthttp.ReleaseResponse(resp)
|
|
continue
|
|
}
|
|
|
|
body, err := providerUtils.CheckAndDecodeBody(resp)
|
|
if err != nil {
|
|
fasthttp.ReleaseRequest(req)
|
|
fasthttp.ReleaseResponse(resp)
|
|
lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err)
|
|
continue
|
|
}
|
|
|
|
var openAIResp OpenAIFileResponse
|
|
rawRequest, rawResponse, bifrostErr := providerUtils.HandleProviderResponse(body, &openAIResp, nil, sendBackRawRequest, sendBackRawResponse)
|
|
if bifrostErr != nil {
|
|
fasthttp.ReleaseRequest(req)
|
|
fasthttp.ReleaseResponse(resp)
|
|
lastErr = bifrostErr
|
|
continue
|
|
}
|
|
|
|
fasthttp.ReleaseRequest(req)
|
|
fasthttp.ReleaseResponse(resp)
|
|
|
|
return openAIResp.ToBifrostFileRetrieveResponse(providerName, latency, sendBackRawRequest, sendBackRawResponse, rawRequest, rawResponse), nil
|
|
}
|
|
|
|
return nil, lastErr
|
|
}
|
|
|
|
// FileDelete deletes a file from OpenAI by trying each key until successful.
|
|
func (provider *OpenAIProvider) FileDelete(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostFileDeleteRequest) (*schemas.BifrostFileDeleteResponse, *schemas.BifrostError) {
|
|
if err := providerUtils.CheckOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.FileDeleteRequest); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
providerName := provider.GetProviderKey()
|
|
|
|
if request.FileID == "" {
|
|
return nil, providerUtils.NewBifrostOperationError("file_id is required", nil)
|
|
}
|
|
|
|
sendBackRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest)
|
|
sendBackRawResponse := providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse)
|
|
|
|
var lastErr *schemas.BifrostError
|
|
for _, key := range keys {
|
|
// Create request
|
|
req := fasthttp.AcquireRequest()
|
|
resp := fasthttp.AcquireResponse()
|
|
|
|
// Set headers
|
|
providerUtils.SetExtraHeaders(ctx, req, provider.networkConfig.ExtraHeaders, nil)
|
|
req.SetRequestURI(provider.networkConfig.BaseURL + "/v1/files/" + request.FileID)
|
|
req.Header.SetMethod(http.MethodDelete)
|
|
req.Header.SetContentType("application/json")
|
|
|
|
if key.Value.GetValue() != "" {
|
|
req.Header.Set("Authorization", "Bearer "+key.Value.GetValue())
|
|
}
|
|
|
|
// Make request
|
|
latency, bifrostErr, wait := providerUtils.MakeRequestWithContext(ctx, provider.client, req, resp)
|
|
wait()
|
|
if bifrostErr != nil {
|
|
fasthttp.ReleaseRequest(req)
|
|
fasthttp.ReleaseResponse(resp)
|
|
lastErr = bifrostErr
|
|
continue
|
|
}
|
|
|
|
// Handle error response
|
|
if resp.StatusCode() != fasthttp.StatusOK {
|
|
provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body()))
|
|
lastErr = ParseOpenAIError(resp)
|
|
fasthttp.ReleaseRequest(req)
|
|
fasthttp.ReleaseResponse(resp)
|
|
continue
|
|
}
|
|
|
|
body, err := providerUtils.CheckAndDecodeBody(resp)
|
|
if err != nil {
|
|
fasthttp.ReleaseRequest(req)
|
|
fasthttp.ReleaseResponse(resp)
|
|
lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err)
|
|
continue
|
|
}
|
|
|
|
var openAIResp OpenAIFileDeleteResponse
|
|
rawRequest, rawResponse, bifrostErr := providerUtils.HandleProviderResponse(body, &openAIResp, nil, sendBackRawRequest, sendBackRawResponse)
|
|
if bifrostErr != nil {
|
|
fasthttp.ReleaseRequest(req)
|
|
fasthttp.ReleaseResponse(resp)
|
|
lastErr = bifrostErr
|
|
continue
|
|
}
|
|
|
|
fasthttp.ReleaseRequest(req)
|
|
fasthttp.ReleaseResponse(resp)
|
|
|
|
result := &schemas.BifrostFileDeleteResponse{
|
|
ID: openAIResp.ID,
|
|
Object: openAIResp.Object,
|
|
Deleted: openAIResp.Deleted,
|
|
ExtraFields: schemas.BifrostResponseExtraFields{
|
|
Latency: latency.Milliseconds(),
|
|
},
|
|
}
|
|
|
|
if sendBackRawRequest {
|
|
result.ExtraFields.RawRequest = rawRequest
|
|
}
|
|
|
|
if sendBackRawResponse {
|
|
result.ExtraFields.RawResponse = rawResponse
|
|
}
|
|
|
|
return result, nil
|
|
}
|
|
|
|
return nil, lastErr
|
|
}
|
|
|
|
// FileContent downloads file content from OpenAI by trying each key until found.
|
|
func (provider *OpenAIProvider) FileContent(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostFileContentRequest) (*schemas.BifrostFileContentResponse, *schemas.BifrostError) {
|
|
if err := providerUtils.CheckOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.FileContentRequest); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
providerName := provider.GetProviderKey()
|
|
|
|
if request.FileID == "" {
|
|
return nil, providerUtils.NewBifrostOperationError("file_id is required", nil)
|
|
}
|
|
|
|
var lastErr *schemas.BifrostError
|
|
for _, key := range keys {
|
|
// Create request
|
|
req := fasthttp.AcquireRequest()
|
|
resp := fasthttp.AcquireResponse()
|
|
|
|
// Set headers
|
|
providerUtils.SetExtraHeaders(ctx, req, provider.networkConfig.ExtraHeaders, nil)
|
|
req.SetRequestURI(provider.networkConfig.BaseURL + "/v1/files/" + request.FileID + "/content")
|
|
req.Header.SetMethod(http.MethodGet)
|
|
|
|
if key.Value.GetValue() != "" {
|
|
req.Header.Set("Authorization", "Bearer "+key.Value.GetValue())
|
|
}
|
|
|
|
// Make request
|
|
latency, bifrostErr, wait := providerUtils.MakeRequestWithContext(ctx, provider.client, req, resp)
|
|
wait()
|
|
if bifrostErr != nil {
|
|
fasthttp.ReleaseRequest(req)
|
|
fasthttp.ReleaseResponse(resp)
|
|
lastErr = bifrostErr
|
|
continue
|
|
}
|
|
|
|
// Handle error response
|
|
if resp.StatusCode() != fasthttp.StatusOK {
|
|
provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body()))
|
|
lastErr = ParseOpenAIError(resp)
|
|
fasthttp.ReleaseRequest(req)
|
|
fasthttp.ReleaseResponse(resp)
|
|
continue
|
|
}
|
|
|
|
body, err := providerUtils.CheckAndDecodeBody(resp)
|
|
if err != nil {
|
|
fasthttp.ReleaseRequest(req)
|
|
fasthttp.ReleaseResponse(resp)
|
|
lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err)
|
|
continue
|
|
}
|
|
|
|
// Get content type from response
|
|
contentType := string(resp.Header.ContentType())
|
|
if contentType == "" {
|
|
contentType = "application/octet-stream"
|
|
}
|
|
content := append([]byte(nil), body...)
|
|
|
|
fasthttp.ReleaseRequest(req)
|
|
fasthttp.ReleaseResponse(resp)
|
|
|
|
return &schemas.BifrostFileContentResponse{
|
|
FileID: request.FileID,
|
|
Content: content,
|
|
ContentType: contentType,
|
|
ExtraFields: schemas.BifrostResponseExtraFields{
|
|
Latency: latency.Milliseconds(),
|
|
},
|
|
}, nil
|
|
}
|
|
|
|
return nil, lastErr
|
|
}
|
|
|
|
// VideoRemix remixes an existing video from the OpenAI provider.
|
|
func (provider *OpenAIProvider) VideoRemix(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostVideoRemixRequest) (*schemas.BifrostVideoGenerationResponse, *schemas.BifrostError) {
|
|
if err := providerUtils.CheckOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.VideoRemixRequest); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
providerName := provider.GetProviderKey()
|
|
|
|
if request.ID == "" {
|
|
return nil, providerUtils.NewBifrostOperationError("video_id is required", nil)
|
|
}
|
|
if request.Input == nil || request.Input.Prompt == "" {
|
|
return nil, providerUtils.NewBifrostOperationError("prompt is required", nil)
|
|
}
|
|
|
|
jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody(
|
|
ctx,
|
|
request,
|
|
func() (providerUtils.RequestBodyWithExtraParams, error) {
|
|
return ToOpenAIVideoRemixRequest(request)
|
|
})
|
|
if bifrostErr != nil {
|
|
return nil, bifrostErr
|
|
}
|
|
|
|
videoID := providerUtils.StripVideoIDProviderSuffix(request.ID, providerName)
|
|
|
|
sendBackRawResponse := providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse)
|
|
sendBackRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest)
|
|
|
|
// Create request
|
|
req := fasthttp.AcquireRequest()
|
|
resp := fasthttp.AcquireResponse()
|
|
defer fasthttp.ReleaseRequest(req)
|
|
defer fasthttp.ReleaseResponse(resp)
|
|
|
|
// Set headers
|
|
providerUtils.SetExtraHeaders(ctx, req, provider.networkConfig.ExtraHeaders, nil)
|
|
req.SetRequestURI(provider.buildRequestURL(ctx, "/v1/videos/"+videoID+"/remix", schemas.VideoRemixRequest))
|
|
req.Header.SetMethod(http.MethodPost)
|
|
req.Header.SetContentType("application/json")
|
|
|
|
if key.Value.GetValue() != "" {
|
|
req.Header.Set("Authorization", "Bearer "+key.Value.GetValue())
|
|
}
|
|
|
|
req.SetBody(jsonData)
|
|
|
|
// Make request
|
|
latency, bifrostErr, wait := providerUtils.MakeRequestWithContext(ctx, provider.client, req, resp)
|
|
defer wait()
|
|
if bifrostErr != nil {
|
|
return nil, bifrostErr
|
|
}
|
|
|
|
// Handle error response
|
|
if resp.StatusCode() != fasthttp.StatusOK {
|
|
provider.logger.Debug("error from %s provider: %s", providerName, string(resp.Body()))
|
|
return nil, ParseOpenAIError(resp)
|
|
}
|
|
|
|
body, err := providerUtils.CheckAndDecodeBody(resp)
|
|
if err != nil {
|
|
return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err), jsonData, nil, sendBackRawRequest, sendBackRawResponse)
|
|
}
|
|
|
|
// Parse OpenAI's video response
|
|
response := &schemas.BifrostVideoGenerationResponse{}
|
|
rawRequest, rawResponse, bifrostErr := providerUtils.HandleProviderResponse(body, response, jsonData, sendBackRawRequest, sendBackRawResponse)
|
|
if bifrostErr != nil {
|
|
return nil, bifrostErr
|
|
}
|
|
if response.ID != "" {
|
|
response.ID = providerUtils.AddVideoIDProviderSuffix(response.ID, providerName)
|
|
}
|
|
if response.RemixedFromVideoID != nil && *response.RemixedFromVideoID != "" {
|
|
remixID := providerUtils.AddVideoIDProviderSuffix(*response.RemixedFromVideoID, providerName)
|
|
response.RemixedFromVideoID = &remixID
|
|
}
|
|
|
|
response.ExtraFields = schemas.BifrostResponseExtraFields{
|
|
Latency: latency.Milliseconds(),
|
|
}
|
|
|
|
if sendBackRawResponse {
|
|
response.ExtraFields.RawResponse = rawResponse
|
|
}
|
|
if sendBackRawRequest {
|
|
response.ExtraFields.RawRequest = rawRequest
|
|
}
|
|
|
|
return response, nil
|
|
}
|
|
|
|
// BatchCreate creates a new batch job.
|
|
func (provider *OpenAIProvider) BatchCreate(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostBatchCreateRequest) (*schemas.BifrostBatchCreateResponse, *schemas.BifrostError) {
|
|
if err := providerUtils.CheckOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.BatchCreateRequest); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
inputFileID := request.InputFileID
|
|
|
|
// If no file_id provided but inline requests are available, upload them first
|
|
if inputFileID == "" && len(request.Requests) > 0 {
|
|
// Convert inline requests to JSONL format
|
|
jsonlData, err := ConvertRequestsToJSONL(request.Requests)
|
|
if err != nil {
|
|
return nil, providerUtils.NewBifrostOperationError("failed to convert requests to JSONL", err)
|
|
}
|
|
|
|
// Upload the file with purpose "batch"
|
|
uploadResp, bifrostErr := provider.FileUpload(ctx, key, &schemas.BifrostFileUploadRequest{
|
|
Provider: schemas.OpenAI,
|
|
File: jsonlData,
|
|
Filename: "batch_requests.jsonl",
|
|
Purpose: "batch",
|
|
})
|
|
if bifrostErr != nil {
|
|
return nil, bifrostErr
|
|
}
|
|
|
|
inputFileID = uploadResp.ID
|
|
}
|
|
|
|
// Validate that we have a file ID (either provided or uploaded)
|
|
if inputFileID == "" {
|
|
return nil, providerUtils.NewBifrostOperationError("either input_file_id or requests array is required for OpenAI batch API", nil)
|
|
}
|
|
|
|
// Validate that we have an endpoint
|
|
if request.Endpoint == "" {
|
|
return nil, providerUtils.NewBifrostOperationError("endpoint is required for OpenAI batch API", nil)
|
|
}
|
|
|
|
// Create request
|
|
req := fasthttp.AcquireRequest()
|
|
resp := fasthttp.AcquireResponse()
|
|
defer fasthttp.ReleaseRequest(req)
|
|
defer fasthttp.ReleaseResponse(resp)
|
|
|
|
// Set headers
|
|
providerUtils.SetExtraHeaders(ctx, req, provider.networkConfig.ExtraHeaders, nil)
|
|
req.SetRequestURI(provider.buildRequestURL(ctx, "/v1/batches", schemas.BatchCreateRequest))
|
|
req.Header.SetMethod(http.MethodPost)
|
|
req.Header.SetContentType("application/json")
|
|
|
|
if key.Value.GetValue() != "" {
|
|
req.Header.Set("Authorization", "Bearer "+key.Value.GetValue())
|
|
}
|
|
|
|
// Build request body
|
|
openAIReq := &OpenAIBatchRequest{
|
|
InputFileID: inputFileID,
|
|
Endpoint: string(request.Endpoint),
|
|
CompletionWindow: request.CompletionWindow,
|
|
Metadata: request.Metadata,
|
|
OutputExpiresAfter: request.OutputExpiresAfter,
|
|
}
|
|
|
|
// Set default completion window if not provided
|
|
if openAIReq.CompletionWindow == "" {
|
|
openAIReq.CompletionWindow = "24h"
|
|
}
|
|
|
|
jsonData, err := providerUtils.MarshalSorted(openAIReq)
|
|
if err != nil {
|
|
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err)
|
|
}
|
|
req.SetBody(jsonData)
|
|
|
|
sendBackRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest)
|
|
sendBackRawResponse := providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse)
|
|
|
|
// Make request
|
|
latency, bifrostErr, wait := providerUtils.MakeRequestWithContext(ctx, provider.client, req, resp)
|
|
defer wait()
|
|
if bifrostErr != nil {
|
|
return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonData, nil, sendBackRawRequest, sendBackRawResponse)
|
|
}
|
|
|
|
// Handle error response
|
|
if resp.StatusCode() != fasthttp.StatusOK {
|
|
return nil, providerUtils.EnrichError(ctx, ParseOpenAIError(resp), jsonData, nil, sendBackRawRequest, sendBackRawResponse)
|
|
}
|
|
|
|
body, err := providerUtils.CheckAndDecodeBody(resp)
|
|
if err != nil {
|
|
return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err), jsonData, nil, sendBackRawRequest, sendBackRawResponse)
|
|
}
|
|
|
|
var openAIResp OpenAIBatchResponse
|
|
rawRequest, rawResponse, bifrostErr := providerUtils.HandleProviderResponse(body, &openAIResp, jsonData, sendBackRawRequest, sendBackRawResponse)
|
|
if bifrostErr != nil {
|
|
return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonData, body, sendBackRawRequest, sendBackRawResponse)
|
|
}
|
|
|
|
return openAIResp.ToBifrostBatchCreateResponse(latency, sendBackRawRequest, sendBackRawResponse, rawRequest, rawResponse), nil
|
|
}
|
|
|
|
// BatchList lists batch jobs using serial pagination across keys.
|
|
// Exhausts all pages from one key before moving to the next.
|
|
func (provider *OpenAIProvider) BatchList(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostBatchListRequest) (*schemas.BifrostBatchListResponse, *schemas.BifrostError) {
|
|
if err := providerUtils.CheckOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.BatchListRequest); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
sendBackRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest)
|
|
sendBackRawResponse := providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse)
|
|
|
|
// Initialize serial pagination helper
|
|
helper, err := providerUtils.NewSerialListHelper(keys, request.After, provider.logger)
|
|
if err != nil {
|
|
return nil, providerUtils.NewBifrostOperationError("invalid pagination cursor", err)
|
|
}
|
|
|
|
// Get current key to query
|
|
key, nativeCursor, ok := helper.GetCurrentKey()
|
|
if !ok {
|
|
// All keys exhausted
|
|
return &schemas.BifrostBatchListResponse{
|
|
Object: "list",
|
|
Data: []schemas.BifrostBatchRetrieveResponse{},
|
|
HasMore: false,
|
|
}, nil
|
|
}
|
|
|
|
// Create request
|
|
req := fasthttp.AcquireRequest()
|
|
resp := fasthttp.AcquireResponse()
|
|
defer fasthttp.ReleaseRequest(req)
|
|
defer fasthttp.ReleaseResponse(resp)
|
|
|
|
// Build URL with query params
|
|
baseURL := provider.buildRequestURL(ctx, "/v1/batches", schemas.BatchListRequest)
|
|
values := url.Values{}
|
|
if request.Limit > 0 {
|
|
values.Set("limit", fmt.Sprintf("%d", request.Limit))
|
|
}
|
|
// Use native cursor from serial helper instead of request.After
|
|
if nativeCursor != "" {
|
|
values.Set("after", nativeCursor)
|
|
}
|
|
requestURL := baseURL
|
|
if encodedValues := values.Encode(); encodedValues != "" {
|
|
requestURL += "?" + encodedValues
|
|
}
|
|
|
|
// Set headers
|
|
providerUtils.SetExtraHeaders(ctx, req, provider.networkConfig.ExtraHeaders, nil)
|
|
req.SetRequestURI(requestURL)
|
|
req.Header.SetMethod(http.MethodGet)
|
|
req.Header.SetContentType("application/json")
|
|
|
|
if key.Value.GetValue() != "" {
|
|
req.Header.Set("Authorization", "Bearer "+key.Value.GetValue())
|
|
}
|
|
|
|
// Make request
|
|
latency, bifrostErr, wait := providerUtils.MakeRequestWithContext(ctx, provider.client, req, resp)
|
|
defer wait()
|
|
if bifrostErr != nil {
|
|
return nil, bifrostErr
|
|
}
|
|
|
|
// Handle error response
|
|
if resp.StatusCode() != fasthttp.StatusOK {
|
|
return nil, ParseOpenAIError(resp)
|
|
}
|
|
|
|
body, decodeErr := providerUtils.CheckAndDecodeBody(resp)
|
|
if decodeErr != nil {
|
|
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, decodeErr)
|
|
}
|
|
|
|
var openAIResp OpenAIBatchListResponse
|
|
rawRequest, rawResponse, bifrostErr := providerUtils.HandleProviderResponse(body, &openAIResp, nil, sendBackRawRequest, sendBackRawResponse)
|
|
if bifrostErr != nil {
|
|
return nil, bifrostErr
|
|
}
|
|
|
|
// Convert batches to Bifrost format
|
|
batches := make([]schemas.BifrostBatchRetrieveResponse, 0, len(openAIResp.Data))
|
|
var lastBatchID string
|
|
for _, batch := range openAIResp.Data {
|
|
batches = append(batches, *batch.ToBifrostBatchRetrieveResponse(latency, sendBackRawRequest, sendBackRawResponse, rawRequest, rawResponse))
|
|
lastBatchID = batch.ID
|
|
}
|
|
|
|
// Build cursor for next request
|
|
// OpenAI uses LastID as the cursor for pagination
|
|
nextCursor, hasMore := helper.BuildNextCursor(openAIResp.HasMore, lastBatchID)
|
|
|
|
// Convert to Bifrost response
|
|
bifrostResp := &schemas.BifrostBatchListResponse{
|
|
Object: "list",
|
|
Data: batches,
|
|
HasMore: hasMore,
|
|
ExtraFields: schemas.BifrostResponseExtraFields{
|
|
Latency: latency.Milliseconds(),
|
|
},
|
|
}
|
|
if nextCursor != "" {
|
|
bifrostResp.NextCursor = &nextCursor
|
|
}
|
|
|
|
return bifrostResp, nil
|
|
}
|
|
|
|
// BatchRetrieve retrieves a specific batch job by trying each key until found.
|
|
func (provider *OpenAIProvider) BatchRetrieve(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostBatchRetrieveRequest) (*schemas.BifrostBatchRetrieveResponse, *schemas.BifrostError) {
|
|
if err := providerUtils.CheckOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.BatchRetrieveRequest); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if request.BatchID == "" {
|
|
return nil, providerUtils.NewBifrostOperationError("batch_id is required", nil)
|
|
}
|
|
|
|
sendBackRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest)
|
|
sendBackRawResponse := providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse)
|
|
|
|
var lastErr *schemas.BifrostError
|
|
for _, key := range keys {
|
|
// Create request
|
|
req := fasthttp.AcquireRequest()
|
|
resp := fasthttp.AcquireResponse()
|
|
|
|
// Set headers
|
|
providerUtils.SetExtraHeaders(ctx, req, provider.networkConfig.ExtraHeaders, nil)
|
|
req.SetRequestURI(provider.networkConfig.BaseURL + "/v1/batches/" + request.BatchID)
|
|
req.Header.SetMethod(http.MethodGet)
|
|
req.Header.SetContentType("application/json")
|
|
|
|
if key.Value.GetValue() != "" {
|
|
req.Header.Set("Authorization", "Bearer "+key.Value.GetValue())
|
|
}
|
|
|
|
// Make request
|
|
latency, bifrostErr, wait := providerUtils.MakeRequestWithContext(ctx, provider.client, req, resp)
|
|
wait()
|
|
if bifrostErr != nil {
|
|
fasthttp.ReleaseRequest(req)
|
|
fasthttp.ReleaseResponse(resp)
|
|
lastErr = bifrostErr
|
|
continue
|
|
}
|
|
|
|
// Handle error response
|
|
if resp.StatusCode() != fasthttp.StatusOK {
|
|
lastErr = ParseOpenAIError(resp)
|
|
fasthttp.ReleaseRequest(req)
|
|
fasthttp.ReleaseResponse(resp)
|
|
continue
|
|
}
|
|
|
|
body, err := providerUtils.CheckAndDecodeBody(resp)
|
|
if err != nil {
|
|
fasthttp.ReleaseRequest(req)
|
|
fasthttp.ReleaseResponse(resp)
|
|
lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err)
|
|
continue
|
|
}
|
|
|
|
var openAIResp OpenAIBatchResponse
|
|
rawRequest, rawResponse, bifrostErr := providerUtils.HandleProviderResponse(body, &openAIResp, nil, sendBackRawRequest, sendBackRawResponse)
|
|
if bifrostErr != nil {
|
|
fasthttp.ReleaseRequest(req)
|
|
fasthttp.ReleaseResponse(resp)
|
|
lastErr = bifrostErr
|
|
continue
|
|
}
|
|
|
|
fasthttp.ReleaseRequest(req)
|
|
fasthttp.ReleaseResponse(resp)
|
|
|
|
result := openAIResp.ToBifrostBatchRetrieveResponse(latency, sendBackRawRequest, sendBackRawResponse, rawRequest, rawResponse)
|
|
return result, nil
|
|
}
|
|
|
|
return nil, lastErr
|
|
}
|
|
|
|
// BatchCancel cancels a batch job by trying each key until successful.
|
|
func (provider *OpenAIProvider) BatchCancel(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostBatchCancelRequest) (*schemas.BifrostBatchCancelResponse, *schemas.BifrostError) {
|
|
if err := providerUtils.CheckOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.BatchCancelRequest); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if request.BatchID == "" {
|
|
return nil, providerUtils.NewBifrostOperationError("batch_id is required", nil)
|
|
}
|
|
|
|
sendBackRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest)
|
|
sendBackRawResponse := providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse)
|
|
|
|
var lastErr *schemas.BifrostError
|
|
for _, key := range keys {
|
|
// Create request
|
|
req := fasthttp.AcquireRequest()
|
|
resp := fasthttp.AcquireResponse()
|
|
|
|
// Set headers
|
|
providerUtils.SetExtraHeaders(ctx, req, provider.networkConfig.ExtraHeaders, nil)
|
|
req.SetRequestURI(provider.networkConfig.BaseURL + "/v1/batches/" + request.BatchID + "/cancel")
|
|
req.Header.SetMethod(http.MethodPost)
|
|
req.Header.SetContentType("application/json")
|
|
|
|
if key.Value.GetValue() != "" {
|
|
req.Header.Set("Authorization", "Bearer "+key.Value.GetValue())
|
|
}
|
|
|
|
// Make request
|
|
latency, bifrostErr, wait := providerUtils.MakeRequestWithContext(ctx, provider.client, req, resp)
|
|
wait()
|
|
if bifrostErr != nil {
|
|
fasthttp.ReleaseRequest(req)
|
|
fasthttp.ReleaseResponse(resp)
|
|
lastErr = bifrostErr
|
|
continue
|
|
}
|
|
|
|
// Handle error response
|
|
if resp.StatusCode() != fasthttp.StatusOK {
|
|
lastErr = ParseOpenAIError(resp)
|
|
fasthttp.ReleaseRequest(req)
|
|
fasthttp.ReleaseResponse(resp)
|
|
continue
|
|
}
|
|
|
|
body, err := providerUtils.CheckAndDecodeBody(resp)
|
|
if err != nil {
|
|
fasthttp.ReleaseRequest(req)
|
|
fasthttp.ReleaseResponse(resp)
|
|
lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err)
|
|
continue
|
|
}
|
|
|
|
var openAIResp OpenAIBatchResponse
|
|
rawRequest, rawResponse, bifrostErr := providerUtils.HandleProviderResponse(body, &openAIResp, nil, sendBackRawRequest, sendBackRawResponse)
|
|
if bifrostErr != nil {
|
|
fasthttp.ReleaseRequest(req)
|
|
fasthttp.ReleaseResponse(resp)
|
|
lastErr = bifrostErr
|
|
continue
|
|
}
|
|
|
|
fasthttp.ReleaseRequest(req)
|
|
fasthttp.ReleaseResponse(resp)
|
|
|
|
result := &schemas.BifrostBatchCancelResponse{
|
|
ID: openAIResp.ID,
|
|
Object: openAIResp.Object,
|
|
Status: ToBifrostBatchStatus(openAIResp.Status),
|
|
CancellingAt: openAIResp.CancellingAt,
|
|
CancelledAt: openAIResp.CancelledAt,
|
|
ExtraFields: schemas.BifrostResponseExtraFields{
|
|
Latency: latency.Milliseconds(),
|
|
},
|
|
}
|
|
|
|
if openAIResp.RequestCounts != nil {
|
|
result.RequestCounts = schemas.BatchRequestCounts{
|
|
Total: openAIResp.RequestCounts.Total,
|
|
Completed: openAIResp.RequestCounts.Completed,
|
|
Failed: openAIResp.RequestCounts.Failed,
|
|
}
|
|
}
|
|
|
|
if sendBackRawRequest {
|
|
result.ExtraFields.RawRequest = rawRequest
|
|
}
|
|
|
|
if sendBackRawResponse {
|
|
result.ExtraFields.RawResponse = rawResponse
|
|
}
|
|
|
|
return result, nil
|
|
}
|
|
|
|
return nil, lastErr
|
|
}
|
|
|
|
// BatchDelete is not supported by the OpenAI provider.
|
|
func (provider *OpenAIProvider) BatchDelete(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostBatchDeleteRequest) (*schemas.BifrostBatchDeleteResponse, *schemas.BifrostError) {
|
|
return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchDeleteRequest, provider.GetProviderKey())
|
|
}
|
|
|
|
// BatchResults retrieves batch results by trying each key until successful.
|
|
// Note: For OpenAI, batch results are obtained by downloading the output_file_id.
|
|
// This method returns the file content parsed as batch results.
|
|
func (provider *OpenAIProvider) BatchResults(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostBatchResultsRequest) (*schemas.BifrostBatchResultsResponse, *schemas.BifrostError) {
|
|
if err := providerUtils.CheckOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.BatchResultsRequest); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if request.BatchID == "" {
|
|
return nil, providerUtils.NewBifrostOperationError("batch_id is required", nil)
|
|
}
|
|
|
|
// First, retrieve the batch to get the output_file_id (this already iterates over keys)
|
|
batchResp, bifrostErr := provider.BatchRetrieve(ctx, keys, &schemas.BifrostBatchRetrieveRequest{
|
|
Provider: request.Provider,
|
|
BatchID: request.BatchID,
|
|
})
|
|
if bifrostErr != nil {
|
|
return nil, bifrostErr
|
|
}
|
|
|
|
if batchResp.OutputFileID == nil || *batchResp.OutputFileID == "" {
|
|
return nil, providerUtils.NewBifrostOperationError("batch results not available: output_file_id is empty (batch may not be completed)", nil)
|
|
}
|
|
|
|
// Download the output file - try each key
|
|
var lastErr *schemas.BifrostError
|
|
for _, key := range keys {
|
|
req := fasthttp.AcquireRequest()
|
|
resp := fasthttp.AcquireResponse()
|
|
|
|
// Set headers
|
|
providerUtils.SetExtraHeaders(ctx, req, provider.networkConfig.ExtraHeaders, nil)
|
|
req.SetRequestURI(provider.networkConfig.BaseURL + "/v1/files/" + *batchResp.OutputFileID + "/content")
|
|
req.Header.SetMethod(http.MethodGet)
|
|
|
|
if key.Value.GetValue() != "" {
|
|
req.Header.Set("Authorization", "Bearer "+key.Value.GetValue())
|
|
}
|
|
|
|
// Make request
|
|
latency, bifrostErr, wait := providerUtils.MakeRequestWithContext(ctx, provider.client, req, resp)
|
|
wait()
|
|
if bifrostErr != nil {
|
|
fasthttp.ReleaseRequest(req)
|
|
fasthttp.ReleaseResponse(resp)
|
|
lastErr = bifrostErr
|
|
continue
|
|
}
|
|
|
|
// Handle error response
|
|
if resp.StatusCode() != fasthttp.StatusOK {
|
|
lastErr = ParseOpenAIError(resp)
|
|
fasthttp.ReleaseRequest(req)
|
|
fasthttp.ReleaseResponse(resp)
|
|
continue
|
|
}
|
|
|
|
body, err := providerUtils.CheckAndDecodeBody(resp)
|
|
if err != nil {
|
|
fasthttp.ReleaseRequest(req)
|
|
fasthttp.ReleaseResponse(resp)
|
|
lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err)
|
|
continue
|
|
}
|
|
|
|
fasthttp.ReleaseRequest(req)
|
|
fasthttp.ReleaseResponse(resp)
|
|
|
|
// Parse JSONL content - each line is a separate result
|
|
var results []schemas.BatchResultItem
|
|
|
|
parseResult := providerUtils.ParseJSONL(body, func(line []byte) error {
|
|
var resultItem schemas.BatchResultItem
|
|
if err := sonic.Unmarshal(line, &resultItem); err != nil {
|
|
provider.logger.Warn("failed to parse batch result line: %v", err)
|
|
return err
|
|
}
|
|
results = append(results, resultItem)
|
|
return nil
|
|
})
|
|
|
|
batchResultsResp := &schemas.BifrostBatchResultsResponse{
|
|
BatchID: request.BatchID,
|
|
Results: results,
|
|
ExtraFields: schemas.BifrostResponseExtraFields{
|
|
Latency: latency.Milliseconds(),
|
|
},
|
|
}
|
|
|
|
if len(parseResult.Errors) > 0 {
|
|
batchResultsResp.ExtraFields.ParseErrors = parseResult.Errors
|
|
}
|
|
|
|
return batchResultsResp, nil
|
|
}
|
|
|
|
return nil, lastErr
|
|
}
|
|
|
|
// ContainerCreate creates a new container via OpenAI's API.
|
|
func (provider *OpenAIProvider) ContainerCreate(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostContainerCreateRequest) (*schemas.BifrostContainerCreateResponse, *schemas.BifrostError) {
|
|
if err := providerUtils.CheckOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.ContainerCreateRequest); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if request == nil {
|
|
return nil, providerUtils.NewBifrostOperationError("invalid request: nil", nil)
|
|
}
|
|
|
|
if request.Name == "" {
|
|
return nil, providerUtils.NewBifrostOperationError("invalid request: name is required", nil)
|
|
}
|
|
|
|
// Build request body
|
|
reqBody := map[string]interface{}{
|
|
"name": request.Name,
|
|
}
|
|
|
|
if request.ExpiresAfter != nil {
|
|
reqBody["expires_after"] = map[string]interface{}{
|
|
"anchor": request.ExpiresAfter.Anchor,
|
|
"minutes": request.ExpiresAfter.Minutes,
|
|
}
|
|
}
|
|
|
|
if len(request.FileIDs) > 0 {
|
|
reqBody["file_ids"] = request.FileIDs
|
|
}
|
|
|
|
if request.MemoryLimit != "" {
|
|
reqBody["memory_limit"] = request.MemoryLimit
|
|
}
|
|
|
|
if len(request.Metadata) > 0 {
|
|
reqBody["metadata"] = request.Metadata
|
|
}
|
|
|
|
// Merge ExtraParams into reqBody (do not overwrite mandatory keys)
|
|
for k, v := range request.ExtraParams {
|
|
if _, exists := reqBody[k]; !exists {
|
|
reqBody[k] = v
|
|
}
|
|
}
|
|
|
|
jsonBody, err := providerUtils.MarshalSorted(reqBody)
|
|
if err != nil {
|
|
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err)
|
|
}
|
|
|
|
// Create request
|
|
req := fasthttp.AcquireRequest()
|
|
resp := fasthttp.AcquireResponse()
|
|
defer fasthttp.ReleaseRequest(req)
|
|
defer fasthttp.ReleaseResponse(resp)
|
|
|
|
providerUtils.SetExtraHeaders(ctx, req, provider.networkConfig.ExtraHeaders, nil)
|
|
|
|
req.SetRequestURI(provider.buildRequestURL(ctx, "/v1/containers", schemas.ContainerCreateRequest))
|
|
req.Header.SetMethod(http.MethodPost)
|
|
req.Header.SetContentType("application/json")
|
|
req.SetBody(jsonBody)
|
|
|
|
if key.Value.GetValue() != "" {
|
|
req.Header.Set("Authorization", "Bearer "+key.Value.GetValue())
|
|
}
|
|
|
|
// Make request
|
|
latency, bifrostErr, wait := providerUtils.MakeRequestWithContext(ctx, provider.client, req, resp)
|
|
defer wait()
|
|
if bifrostErr != nil {
|
|
return nil, bifrostErr
|
|
}
|
|
|
|
// Handle error response
|
|
if resp.StatusCode() != fasthttp.StatusOK && resp.StatusCode() != fasthttp.StatusCreated {
|
|
return nil, ParseOpenAIError(resp)
|
|
}
|
|
|
|
// Parse response
|
|
responseBody := append([]byte(nil), resp.Body()...)
|
|
|
|
var containerResp struct {
|
|
ID string `json:"id"`
|
|
Object string `json:"object"`
|
|
Name string `json:"name"`
|
|
CreatedAt int64 `json:"created_at"`
|
|
Status schemas.ContainerStatus `json:"status"`
|
|
ExpiresAfter *schemas.ContainerExpiresAfter `json:"expires_after"`
|
|
LastActiveAt *int64 `json:"last_active_at"`
|
|
MemoryLimit string `json:"memory_limit"`
|
|
Metadata map[string]string `json:"metadata"`
|
|
}
|
|
|
|
rawRequest, rawResponse, bifrostErr := providerUtils.HandleProviderResponse(responseBody, &containerResp, jsonBody, providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse))
|
|
if bifrostErr != nil {
|
|
return nil, bifrostErr
|
|
}
|
|
|
|
response := &schemas.BifrostContainerCreateResponse{
|
|
ID: containerResp.ID,
|
|
Object: containerResp.Object,
|
|
Name: containerResp.Name,
|
|
CreatedAt: containerResp.CreatedAt,
|
|
Status: containerResp.Status,
|
|
ExpiresAfter: containerResp.ExpiresAfter,
|
|
LastActiveAt: containerResp.LastActiveAt,
|
|
MemoryLimit: containerResp.MemoryLimit,
|
|
Metadata: containerResp.Metadata,
|
|
ExtraFields: schemas.BifrostResponseExtraFields{
|
|
Latency: latency.Milliseconds(),
|
|
},
|
|
}
|
|
|
|
if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) {
|
|
response.ExtraFields.RawRequest = rawRequest
|
|
}
|
|
if providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) {
|
|
response.ExtraFields.RawResponse = rawResponse
|
|
}
|
|
|
|
return response, nil
|
|
}
|
|
|
|
// ContainerList lists containers via OpenAI's API.
|
|
// Uses SerialListHelper for multi-key pagination - exhausts all pages from one key before moving to next.
|
|
func (provider *OpenAIProvider) ContainerList(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostContainerListRequest) (*schemas.BifrostContainerListResponse, *schemas.BifrostError) {
|
|
if request == nil {
|
|
return nil, providerUtils.NewBifrostOperationError("invalid request: nil", nil)
|
|
}
|
|
if len(keys) == 0 {
|
|
if provider.customProviderConfig != nil && provider.customProviderConfig.IsKeyLess {
|
|
keys = []schemas.Key{{}}
|
|
} else {
|
|
return nil, providerUtils.NewBifrostOperationError("provider config not found", nil)
|
|
}
|
|
}
|
|
|
|
if err := providerUtils.CheckOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.ContainerListRequest); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
sendBackRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest)
|
|
sendBackRawResponse := providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse)
|
|
|
|
// Initialize serial pagination helper for multi-key support
|
|
helper, err := providerUtils.NewSerialListHelper(keys, request.After, provider.logger)
|
|
if err != nil {
|
|
return nil, providerUtils.NewBifrostOperationError("invalid pagination cursor", err)
|
|
}
|
|
|
|
// Get current key to query
|
|
key, nativeCursor, ok := helper.GetCurrentKey()
|
|
if !ok {
|
|
// All keys exhausted
|
|
return &schemas.BifrostContainerListResponse{
|
|
Object: "list",
|
|
Data: []schemas.ContainerObject{},
|
|
HasMore: false,
|
|
}, nil
|
|
}
|
|
|
|
// Build query string
|
|
queryParams := url.Values{}
|
|
if request.Limit > 0 {
|
|
queryParams.Set("limit", fmt.Sprintf("%d", request.Limit))
|
|
}
|
|
// Use native cursor from helper instead of request.After
|
|
if nativeCursor != "" {
|
|
queryParams.Set("after", nativeCursor)
|
|
}
|
|
if request.Order != nil {
|
|
queryParams.Set("order", *request.Order)
|
|
}
|
|
|
|
requestURL := provider.buildRequestURL(ctx, "/v1/containers", schemas.ContainerListRequest)
|
|
if len(queryParams) > 0 {
|
|
requestURL += "?" + queryParams.Encode()
|
|
}
|
|
|
|
// Create request
|
|
req := fasthttp.AcquireRequest()
|
|
resp := fasthttp.AcquireResponse()
|
|
defer fasthttp.ReleaseRequest(req)
|
|
defer fasthttp.ReleaseResponse(resp)
|
|
|
|
providerUtils.SetExtraHeaders(ctx, req, provider.networkConfig.ExtraHeaders, nil)
|
|
|
|
req.SetRequestURI(requestURL)
|
|
req.Header.SetMethod(http.MethodGet)
|
|
req.Header.SetContentType("application/json")
|
|
|
|
if key.Value.GetValue() != "" {
|
|
req.Header.Set("Authorization", "Bearer "+key.Value.GetValue())
|
|
}
|
|
|
|
// Make request
|
|
latency, bifrostErr, wait := providerUtils.MakeRequestWithContext(ctx, provider.client, req, resp)
|
|
defer wait()
|
|
if bifrostErr != nil {
|
|
return nil, bifrostErr
|
|
}
|
|
|
|
// Handle error response
|
|
if resp.StatusCode() != fasthttp.StatusOK {
|
|
return nil, ParseOpenAIError(resp)
|
|
}
|
|
|
|
// Parse response
|
|
responseBody := append([]byte(nil), resp.Body()...)
|
|
|
|
var listResp struct {
|
|
Object string `json:"object"`
|
|
Data []schemas.ContainerObject `json:"data"`
|
|
FirstID *string `json:"first_id"`
|
|
LastID *string `json:"last_id"`
|
|
HasMore bool `json:"has_more"`
|
|
}
|
|
|
|
rawRequest, rawResponse, bifrostErr := providerUtils.HandleProviderResponse(responseBody, &listResp, nil, sendBackRawRequest, sendBackRawResponse)
|
|
if bifrostErr != nil {
|
|
return nil, bifrostErr
|
|
}
|
|
|
|
// Track last container ID for pagination cursor
|
|
var lastContainerID string
|
|
for _, container := range listResp.Data {
|
|
lastContainerID = container.ID
|
|
}
|
|
|
|
// Build cursor for next request (handles cross-key pagination)
|
|
nextCursor, hasMore := helper.BuildNextCursor(listResp.HasMore, lastContainerID)
|
|
|
|
response := &schemas.BifrostContainerListResponse{
|
|
Object: listResp.Object,
|
|
Data: listResp.Data,
|
|
FirstID: listResp.FirstID,
|
|
LastID: listResp.LastID,
|
|
HasMore: hasMore,
|
|
ExtraFields: schemas.BifrostResponseExtraFields{
|
|
Latency: latency.Milliseconds(),
|
|
},
|
|
}
|
|
|
|
// Set encoded cursor for next page
|
|
if nextCursor != "" {
|
|
response.After = &nextCursor
|
|
}
|
|
|
|
if sendBackRawRequest {
|
|
response.ExtraFields.RawRequest = rawRequest
|
|
}
|
|
if sendBackRawResponse {
|
|
response.ExtraFields.RawResponse = rawResponse
|
|
}
|
|
|
|
return response, nil
|
|
}
|
|
|
|
// ContainerRetrieve retrieves a specific container via OpenAI's API.
|
|
func (provider *OpenAIProvider) ContainerRetrieve(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostContainerRetrieveRequest) (*schemas.BifrostContainerRetrieveResponse, *schemas.BifrostError) {
|
|
if request == nil {
|
|
return nil, providerUtils.NewBifrostOperationError("invalid request: nil", nil)
|
|
}
|
|
if len(keys) == 0 {
|
|
if provider.customProviderConfig != nil && provider.customProviderConfig.IsKeyLess {
|
|
keys = []schemas.Key{{}}
|
|
} else {
|
|
return nil, providerUtils.NewBifrostOperationError("provider config not found", nil)
|
|
}
|
|
}
|
|
if request.ContainerID == "" {
|
|
return nil, providerUtils.NewBifrostOperationError("container_id is required", nil)
|
|
}
|
|
|
|
if err := providerUtils.CheckOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.ContainerRetrieveRequest); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
var lastErr *schemas.BifrostError
|
|
for _, key := range keys {
|
|
// Create request
|
|
req := fasthttp.AcquireRequest()
|
|
resp := fasthttp.AcquireResponse()
|
|
|
|
providerUtils.SetExtraHeaders(ctx, req, provider.networkConfig.ExtraHeaders, nil)
|
|
|
|
req.SetRequestURI(provider.buildRequestURL(ctx, "/v1/containers/"+request.ContainerID, schemas.ContainerRetrieveRequest))
|
|
req.Header.SetMethod(http.MethodGet)
|
|
req.Header.SetContentType("application/json")
|
|
|
|
if key.Value.GetValue() != "" {
|
|
req.Header.Set("Authorization", "Bearer "+key.Value.GetValue())
|
|
}
|
|
|
|
// Make request
|
|
latency, bifrostErr, wait := providerUtils.MakeRequestWithContext(ctx, provider.client, req, resp)
|
|
wait()
|
|
if bifrostErr != nil {
|
|
fasthttp.ReleaseRequest(req)
|
|
fasthttp.ReleaseResponse(resp)
|
|
lastErr = bifrostErr
|
|
continue
|
|
}
|
|
|
|
// Handle error response
|
|
if resp.StatusCode() != fasthttp.StatusOK {
|
|
lastErr = ParseOpenAIError(resp)
|
|
fasthttp.ReleaseRequest(req)
|
|
fasthttp.ReleaseResponse(resp)
|
|
continue
|
|
}
|
|
|
|
// Parse response
|
|
responseBody := append([]byte(nil), resp.Body()...)
|
|
|
|
var containerResp struct {
|
|
ID string `json:"id"`
|
|
Object string `json:"object"`
|
|
Name string `json:"name"`
|
|
CreatedAt int64 `json:"created_at"`
|
|
Status schemas.ContainerStatus `json:"status"`
|
|
ExpiresAfter *schemas.ContainerExpiresAfter `json:"expires_after"`
|
|
LastActiveAt *int64 `json:"last_active_at"`
|
|
MemoryLimit string `json:"memory_limit"`
|
|
Metadata map[string]string `json:"metadata"`
|
|
}
|
|
|
|
rawRequest, rawResponse, bifrostErr := providerUtils.HandleProviderResponse(responseBody, &containerResp, nil, providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse))
|
|
if bifrostErr != nil {
|
|
fasthttp.ReleaseRequest(req)
|
|
fasthttp.ReleaseResponse(resp)
|
|
lastErr = bifrostErr
|
|
continue
|
|
}
|
|
|
|
response := &schemas.BifrostContainerRetrieveResponse{
|
|
ID: containerResp.ID,
|
|
Object: containerResp.Object,
|
|
Name: containerResp.Name,
|
|
CreatedAt: containerResp.CreatedAt,
|
|
Status: containerResp.Status,
|
|
ExpiresAfter: containerResp.ExpiresAfter,
|
|
LastActiveAt: containerResp.LastActiveAt,
|
|
MemoryLimit: containerResp.MemoryLimit,
|
|
Metadata: containerResp.Metadata,
|
|
ExtraFields: schemas.BifrostResponseExtraFields{
|
|
Latency: latency.Milliseconds(),
|
|
},
|
|
}
|
|
|
|
if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) {
|
|
response.ExtraFields.RawRequest = rawRequest
|
|
}
|
|
if providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) {
|
|
response.ExtraFields.RawResponse = rawResponse
|
|
}
|
|
|
|
fasthttp.ReleaseRequest(req)
|
|
fasthttp.ReleaseResponse(resp)
|
|
return response, nil
|
|
}
|
|
|
|
return nil, lastErr
|
|
}
|
|
|
|
// ContainerDelete deletes a container via OpenAI's API.
|
|
func (provider *OpenAIProvider) ContainerDelete(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostContainerDeleteRequest) (*schemas.BifrostContainerDeleteResponse, *schemas.BifrostError) {
|
|
if request == nil {
|
|
return nil, providerUtils.NewBifrostOperationError("invalid request: nil", nil)
|
|
}
|
|
if len(keys) == 0 {
|
|
if provider.customProviderConfig != nil && provider.customProviderConfig.IsKeyLess {
|
|
keys = []schemas.Key{{}}
|
|
} else {
|
|
return nil, providerUtils.NewBifrostOperationError("provider config not found", nil)
|
|
}
|
|
}
|
|
if request.ContainerID == "" {
|
|
return nil, providerUtils.NewBifrostOperationError("container_id is required", nil)
|
|
}
|
|
|
|
if err := providerUtils.CheckOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.ContainerDeleteRequest); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
var lastErr *schemas.BifrostError
|
|
for _, key := range keys {
|
|
// Create request
|
|
req := fasthttp.AcquireRequest()
|
|
resp := fasthttp.AcquireResponse()
|
|
|
|
providerUtils.SetExtraHeaders(ctx, req, provider.networkConfig.ExtraHeaders, nil)
|
|
|
|
req.SetRequestURI(provider.buildRequestURL(ctx, "/v1/containers/"+request.ContainerID, schemas.ContainerDeleteRequest))
|
|
req.Header.SetMethod(http.MethodDelete)
|
|
req.Header.SetContentType("application/json")
|
|
|
|
if key.Value.GetValue() != "" {
|
|
req.Header.Set("Authorization", "Bearer "+key.Value.GetValue())
|
|
}
|
|
|
|
// Make request
|
|
latency, bifrostErr, wait := providerUtils.MakeRequestWithContext(ctx, provider.client, req, resp)
|
|
wait()
|
|
if bifrostErr != nil {
|
|
fasthttp.ReleaseRequest(req)
|
|
fasthttp.ReleaseResponse(resp)
|
|
lastErr = bifrostErr
|
|
continue
|
|
}
|
|
|
|
// Handle error response
|
|
if resp.StatusCode() != fasthttp.StatusOK {
|
|
lastErr = ParseOpenAIError(resp)
|
|
fasthttp.ReleaseRequest(req)
|
|
fasthttp.ReleaseResponse(resp)
|
|
continue
|
|
}
|
|
|
|
// Parse response
|
|
responseBody := append([]byte(nil), resp.Body()...)
|
|
|
|
var deleteResp struct {
|
|
ID string `json:"id"`
|
|
Object string `json:"object"`
|
|
Deleted bool `json:"deleted"`
|
|
}
|
|
|
|
rawRequest, rawResponse, bifrostErr := providerUtils.HandleProviderResponse(responseBody, &deleteResp, nil, providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse))
|
|
if bifrostErr != nil {
|
|
fasthttp.ReleaseRequest(req)
|
|
fasthttp.ReleaseResponse(resp)
|
|
lastErr = bifrostErr
|
|
continue
|
|
}
|
|
|
|
response := &schemas.BifrostContainerDeleteResponse{
|
|
ID: deleteResp.ID,
|
|
Object: deleteResp.Object,
|
|
Deleted: deleteResp.Deleted,
|
|
ExtraFields: schemas.BifrostResponseExtraFields{
|
|
Latency: latency.Milliseconds(),
|
|
},
|
|
}
|
|
|
|
if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) {
|
|
response.ExtraFields.RawRequest = rawRequest
|
|
}
|
|
if providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) {
|
|
response.ExtraFields.RawResponse = rawResponse
|
|
}
|
|
|
|
fasthttp.ReleaseRequest(req)
|
|
fasthttp.ReleaseResponse(resp)
|
|
return response, nil
|
|
}
|
|
|
|
return nil, lastErr
|
|
}
|
|
|
|
// =============================================================================
|
|
// CONTAINER FILES API
|
|
// =============================================================================
|
|
|
|
// ContainerFileCreate creates a file in a container via OpenAI's API.
|
|
func (provider *OpenAIProvider) ContainerFileCreate(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostContainerFileCreateRequest) (*schemas.BifrostContainerFileCreateResponse, *schemas.BifrostError) {
|
|
if err := providerUtils.CheckOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.ContainerFileCreateRequest); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if request == nil {
|
|
return nil, providerUtils.NewBifrostOperationError("invalid request: nil", nil)
|
|
}
|
|
|
|
if request.ContainerID == "" {
|
|
return nil, providerUtils.NewBifrostOperationError("invalid request: container_id is required", nil)
|
|
}
|
|
|
|
// Create request
|
|
req := fasthttp.AcquireRequest()
|
|
resp := fasthttp.AcquireResponse()
|
|
defer fasthttp.ReleaseRequest(req)
|
|
defer fasthttp.ReleaseResponse(resp)
|
|
|
|
providerUtils.SetExtraHeaders(ctx, req, provider.networkConfig.ExtraHeaders, nil)
|
|
|
|
endpoint := fmt.Sprintf("/v1/containers/%s/files", request.ContainerID)
|
|
req.SetRequestURI(provider.buildRequestURL(ctx, endpoint, schemas.ContainerFileCreateRequest))
|
|
req.Header.SetMethod(http.MethodPost)
|
|
|
|
// Handle file upload (multipart only)
|
|
if len(request.File) == 0 {
|
|
return nil, providerUtils.NewBifrostOperationError("invalid request: file is required", nil)
|
|
}
|
|
|
|
// Multipart file upload
|
|
var body bytes.Buffer
|
|
writer := multipart.NewWriter(&body)
|
|
|
|
// Add file
|
|
part, err := writer.CreateFormFile("file", "file")
|
|
if err != nil {
|
|
return nil, providerUtils.NewBifrostOperationError("failed to create multipart form", err)
|
|
}
|
|
if _, err = part.Write(request.File); err != nil {
|
|
return nil, providerUtils.NewBifrostOperationError("failed to write file to multipart form", err)
|
|
}
|
|
if err := writer.Close(); err != nil {
|
|
return nil, providerUtils.NewBifrostOperationError("failed to close multipart form", err)
|
|
}
|
|
req.Header.Set("Content-Type", writer.FormDataContentType())
|
|
req.SetBody(body.Bytes())
|
|
|
|
if key.Value.GetValue() != "" {
|
|
req.Header.Set("Authorization", "Bearer "+key.Value.GetValue())
|
|
}
|
|
|
|
// Make request
|
|
latency, bifrostErr, wait := providerUtils.MakeRequestWithContext(ctx, provider.client, req, resp)
|
|
defer wait()
|
|
if bifrostErr != nil {
|
|
return nil, bifrostErr
|
|
}
|
|
|
|
// Handle error response
|
|
if resp.StatusCode() >= 400 {
|
|
return nil, ParseOpenAIError(resp)
|
|
}
|
|
|
|
// Decode response body (handles content-encoding like gzip)
|
|
responseBody, err := providerUtils.CheckAndDecodeBody(resp)
|
|
if err != nil {
|
|
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err)
|
|
}
|
|
sendBackRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest)
|
|
sendBackRawResponse := providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse)
|
|
|
|
var fileResp struct {
|
|
ID string `json:"id"`
|
|
Object string `json:"object"`
|
|
Bytes int64 `json:"bytes"`
|
|
CreatedAt int64 `json:"created_at"`
|
|
ContainerID string `json:"container_id"`
|
|
Path string `json:"path"`
|
|
Source string `json:"source"`
|
|
}
|
|
|
|
_, rawResponse, bifrostErr := providerUtils.HandleProviderResponse(responseBody, &fileResp, nil, false, sendBackRawResponse)
|
|
if bifrostErr != nil {
|
|
return nil, bifrostErr
|
|
}
|
|
|
|
containerFileCreateResponse := &schemas.BifrostContainerFileCreateResponse{
|
|
ID: fileResp.ID,
|
|
Object: fileResp.Object,
|
|
Bytes: fileResp.Bytes,
|
|
CreatedAt: fileResp.CreatedAt,
|
|
ContainerID: fileResp.ContainerID,
|
|
Path: fileResp.Path,
|
|
Source: fileResp.Source,
|
|
ExtraFields: schemas.BifrostResponseExtraFields{
|
|
Latency: latency.Milliseconds(),
|
|
},
|
|
}
|
|
|
|
// We don't capture payload for security reasons
|
|
if sendBackRawRequest {
|
|
containerFileCreateResponse.ExtraFields.RawRequest = "<REDACTED>"
|
|
}
|
|
if sendBackRawResponse {
|
|
containerFileCreateResponse.ExtraFields.RawResponse = rawResponse
|
|
}
|
|
|
|
return containerFileCreateResponse, nil
|
|
}
|
|
|
|
// ContainerFileList lists files in a container via OpenAI's API.
|
|
// Uses SerialListHelper for multi-key pagination - exhausts all pages from one key before moving to next.
|
|
func (provider *OpenAIProvider) ContainerFileList(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostContainerFileListRequest) (*schemas.BifrostContainerFileListResponse, *schemas.BifrostError) {
|
|
if request == nil {
|
|
return nil, providerUtils.NewBifrostOperationError("invalid request: nil", nil)
|
|
}
|
|
|
|
if request.ContainerID == "" {
|
|
return nil, providerUtils.NewBifrostOperationError("invalid request: container_id is required", nil)
|
|
}
|
|
|
|
if len(keys) == 0 {
|
|
if provider.customProviderConfig != nil && provider.customProviderConfig.IsKeyLess {
|
|
keys = []schemas.Key{{}}
|
|
} else {
|
|
return nil, providerUtils.NewBifrostOperationError("no keys provided", nil)
|
|
}
|
|
}
|
|
|
|
if err := providerUtils.CheckOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.ContainerFileListRequest); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
sendBackRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest)
|
|
sendBackRawResponse := providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse)
|
|
|
|
// Initialize serial pagination helper for multi-key support
|
|
helper, err := providerUtils.NewSerialListHelper(keys, request.After, provider.logger)
|
|
if err != nil {
|
|
return nil, providerUtils.NewBifrostOperationError("invalid pagination cursor", err)
|
|
}
|
|
|
|
// Get current key to query
|
|
key, nativeCursor, ok := helper.GetCurrentKey()
|
|
if !ok {
|
|
// All keys exhausted
|
|
return &schemas.BifrostContainerFileListResponse{
|
|
Object: "list",
|
|
Data: []schemas.ContainerFileObject{},
|
|
HasMore: false,
|
|
}, nil
|
|
}
|
|
|
|
// Build URL with query parameters
|
|
endpoint := fmt.Sprintf("/v1/containers/%s/files", request.ContainerID)
|
|
requestURL := provider.buildRequestURL(ctx, endpoint, schemas.ContainerFileListRequest)
|
|
|
|
// Add query parameters
|
|
queryParams := url.Values{}
|
|
if request.Limit > 0 {
|
|
queryParams.Set("limit", fmt.Sprintf("%d", request.Limit))
|
|
}
|
|
// Use native cursor from helper instead of request.After
|
|
if nativeCursor != "" {
|
|
queryParams.Set("after", nativeCursor)
|
|
}
|
|
if request.Order != nil {
|
|
queryParams.Set("order", *request.Order)
|
|
}
|
|
if len(queryParams) > 0 {
|
|
requestURL = requestURL + "?" + queryParams.Encode()
|
|
}
|
|
|
|
// Create request
|
|
req := fasthttp.AcquireRequest()
|
|
resp := fasthttp.AcquireResponse()
|
|
defer fasthttp.ReleaseRequest(req)
|
|
defer fasthttp.ReleaseResponse(resp)
|
|
|
|
providerUtils.SetExtraHeaders(ctx, req, provider.networkConfig.ExtraHeaders, nil)
|
|
|
|
req.SetRequestURI(requestURL)
|
|
req.Header.SetMethod(http.MethodGet)
|
|
|
|
if key.Value.GetValue() != "" {
|
|
req.Header.Set("Authorization", "Bearer "+key.Value.GetValue())
|
|
}
|
|
|
|
latency, bifrostErr, wait := providerUtils.MakeRequestWithContext(ctx, provider.client, req, resp)
|
|
defer wait()
|
|
if bifrostErr != nil {
|
|
return nil, bifrostErr
|
|
}
|
|
|
|
if resp.StatusCode() >= 400 {
|
|
return nil, ParseOpenAIError(resp)
|
|
}
|
|
|
|
// Decode response body (handles content-encoding like gzip)
|
|
responseBody, err := providerUtils.CheckAndDecodeBody(resp)
|
|
if err != nil {
|
|
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err)
|
|
}
|
|
|
|
var listResp struct {
|
|
Object string `json:"object"`
|
|
Data []schemas.ContainerFileObject `json:"data"`
|
|
FirstID *string `json:"first_id"`
|
|
LastID *string `json:"last_id"`
|
|
HasMore bool `json:"has_more"`
|
|
}
|
|
|
|
rawRequest, rawResponse, bifrostErr := providerUtils.HandleProviderResponse(responseBody, &listResp, nil, sendBackRawRequest, sendBackRawResponse)
|
|
if bifrostErr != nil {
|
|
return nil, bifrostErr
|
|
}
|
|
|
|
// Track last file ID for pagination cursor
|
|
var lastFileID string
|
|
for _, file := range listResp.Data {
|
|
lastFileID = file.ID
|
|
}
|
|
|
|
// Build cursor for next request (handles cross-key pagination)
|
|
nextCursor, hasMore := helper.BuildNextCursor(listResp.HasMore, lastFileID)
|
|
|
|
containerFileListResponse := &schemas.BifrostContainerFileListResponse{
|
|
Object: listResp.Object,
|
|
Data: listResp.Data,
|
|
FirstID: listResp.FirstID,
|
|
LastID: listResp.LastID,
|
|
HasMore: hasMore,
|
|
ExtraFields: schemas.BifrostResponseExtraFields{
|
|
Latency: latency.Milliseconds(),
|
|
},
|
|
}
|
|
|
|
// Set encoded cursor for next page
|
|
if nextCursor != "" {
|
|
containerFileListResponse.After = &nextCursor
|
|
}
|
|
|
|
if sendBackRawRequest {
|
|
containerFileListResponse.ExtraFields.RawRequest = rawRequest
|
|
}
|
|
if sendBackRawResponse {
|
|
containerFileListResponse.ExtraFields.RawResponse = rawResponse
|
|
}
|
|
|
|
return containerFileListResponse, nil
|
|
}
|
|
|
|
// ContainerFileRetrieve retrieves a file from a container via OpenAI's API.
|
|
func (provider *OpenAIProvider) ContainerFileRetrieve(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostContainerFileRetrieveRequest) (*schemas.BifrostContainerFileRetrieveResponse, *schemas.BifrostError) {
|
|
if len(keys) == 0 {
|
|
if provider.customProviderConfig != nil && provider.customProviderConfig.IsKeyLess {
|
|
keys = []schemas.Key{{}}
|
|
} else {
|
|
return nil, providerUtils.NewBifrostOperationError("no keys provided", nil)
|
|
}
|
|
}
|
|
|
|
if err := providerUtils.CheckOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.ContainerFileRetrieveRequest); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if request == nil {
|
|
return nil, providerUtils.NewBifrostOperationError("invalid request: nil", nil)
|
|
}
|
|
|
|
if request.ContainerID == "" {
|
|
return nil, providerUtils.NewBifrostOperationError("invalid request: container_id is required", nil)
|
|
}
|
|
|
|
if request.FileID == "" {
|
|
return nil, providerUtils.NewBifrostOperationError("invalid request: file_id is required", nil)
|
|
}
|
|
|
|
var lastErr *schemas.BifrostError
|
|
for _, key := range keys {
|
|
req := fasthttp.AcquireRequest()
|
|
resp := fasthttp.AcquireResponse()
|
|
|
|
providerUtils.SetExtraHeaders(ctx, req, provider.networkConfig.ExtraHeaders, nil)
|
|
|
|
endpoint := fmt.Sprintf("/v1/containers/%s/files/%s", request.ContainerID, request.FileID)
|
|
req.SetRequestURI(provider.buildRequestURL(ctx, endpoint, schemas.ContainerFileRetrieveRequest))
|
|
req.Header.SetMethod(http.MethodGet)
|
|
|
|
if key.Value.GetValue() != "" {
|
|
req.Header.Set("Authorization", "Bearer "+key.Value.GetValue())
|
|
}
|
|
|
|
latency, bifrostErr, wait := providerUtils.MakeRequestWithContext(ctx, provider.client, req, resp)
|
|
wait()
|
|
if bifrostErr != nil {
|
|
lastErr = bifrostErr
|
|
fasthttp.ReleaseRequest(req)
|
|
fasthttp.ReleaseResponse(resp)
|
|
continue
|
|
}
|
|
|
|
if resp.StatusCode() >= 400 {
|
|
lastErr = ParseOpenAIError(resp)
|
|
fasthttp.ReleaseRequest(req)
|
|
fasthttp.ReleaseResponse(resp)
|
|
continue
|
|
}
|
|
|
|
// Decode response body (handles content-encoding like gzip)
|
|
responseBody, err := providerUtils.CheckAndDecodeBody(resp)
|
|
if err != nil {
|
|
lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err)
|
|
fasthttp.ReleaseRequest(req)
|
|
fasthttp.ReleaseResponse(resp)
|
|
continue
|
|
}
|
|
sendBackRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest)
|
|
sendBackRawResponse := providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse)
|
|
|
|
var fileResp struct {
|
|
ID string `json:"id"`
|
|
Object string `json:"object"`
|
|
Bytes int64 `json:"bytes"`
|
|
CreatedAt int64 `json:"created_at"`
|
|
ContainerID string `json:"container_id"`
|
|
Path string `json:"path"`
|
|
Source string `json:"source"`
|
|
}
|
|
|
|
rawRequest, rawResponse, bifrostErr := providerUtils.HandleProviderResponse(responseBody, &fileResp, nil, sendBackRawRequest, sendBackRawResponse)
|
|
if bifrostErr != nil {
|
|
lastErr = bifrostErr
|
|
fasthttp.ReleaseRequest(req)
|
|
fasthttp.ReleaseResponse(resp)
|
|
continue
|
|
}
|
|
|
|
containerFileRetrieveResponse := &schemas.BifrostContainerFileRetrieveResponse{
|
|
ID: fileResp.ID,
|
|
Object: fileResp.Object,
|
|
Bytes: fileResp.Bytes,
|
|
CreatedAt: fileResp.CreatedAt,
|
|
ContainerID: fileResp.ContainerID,
|
|
Path: fileResp.Path,
|
|
Source: fileResp.Source,
|
|
ExtraFields: schemas.BifrostResponseExtraFields{
|
|
Latency: latency.Milliseconds(),
|
|
},
|
|
}
|
|
|
|
if sendBackRawRequest {
|
|
containerFileRetrieveResponse.ExtraFields.RawRequest = rawRequest
|
|
}
|
|
if sendBackRawResponse {
|
|
containerFileRetrieveResponse.ExtraFields.RawResponse = rawResponse
|
|
}
|
|
|
|
fasthttp.ReleaseRequest(req)
|
|
fasthttp.ReleaseResponse(resp)
|
|
return containerFileRetrieveResponse, nil
|
|
}
|
|
|
|
return nil, lastErr
|
|
}
|
|
|
|
// ContainerFileContent retrieves the content of a file from a container via OpenAI's API.
|
|
func (provider *OpenAIProvider) ContainerFileContent(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostContainerFileContentRequest) (*schemas.BifrostContainerFileContentResponse, *schemas.BifrostError) {
|
|
if len(keys) == 0 {
|
|
if provider.customProviderConfig != nil && provider.customProviderConfig.IsKeyLess {
|
|
keys = []schemas.Key{{}}
|
|
} else {
|
|
return nil, providerUtils.NewBifrostOperationError("no keys provided", nil)
|
|
}
|
|
}
|
|
|
|
if err := providerUtils.CheckOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.ContainerFileContentRequest); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if request == nil {
|
|
return nil, providerUtils.NewBifrostOperationError("invalid request: nil", nil)
|
|
}
|
|
|
|
if request.ContainerID == "" {
|
|
return nil, providerUtils.NewBifrostOperationError("invalid request: container_id is required", nil)
|
|
}
|
|
|
|
if request.FileID == "" {
|
|
return nil, providerUtils.NewBifrostOperationError("invalid request: file_id is required", nil)
|
|
}
|
|
|
|
var lastErr *schemas.BifrostError
|
|
for _, key := range keys {
|
|
req := fasthttp.AcquireRequest()
|
|
resp := fasthttp.AcquireResponse()
|
|
|
|
providerUtils.SetExtraHeaders(ctx, req, provider.networkConfig.ExtraHeaders, nil)
|
|
|
|
endpoint := fmt.Sprintf("/v1/containers/%s/files/%s/content", request.ContainerID, request.FileID)
|
|
req.SetRequestURI(provider.buildRequestURL(ctx, endpoint, schemas.ContainerFileContentRequest))
|
|
req.Header.SetMethod(http.MethodGet)
|
|
|
|
if key.Value.GetValue() != "" {
|
|
req.Header.Set("Authorization", "Bearer "+key.Value.GetValue())
|
|
}
|
|
|
|
latency, bifrostErr, wait := providerUtils.MakeRequestWithContext(ctx, provider.client, req, resp)
|
|
wait()
|
|
if bifrostErr != nil {
|
|
lastErr = bifrostErr
|
|
fasthttp.ReleaseRequest(req)
|
|
fasthttp.ReleaseResponse(resp)
|
|
continue
|
|
}
|
|
|
|
if resp.StatusCode() >= 400 {
|
|
lastErr = ParseOpenAIError(resp)
|
|
fasthttp.ReleaseRequest(req)
|
|
fasthttp.ReleaseResponse(resp)
|
|
continue
|
|
}
|
|
|
|
// Get content type from response header
|
|
contentType := string(resp.Header.ContentType())
|
|
if contentType == "" {
|
|
contentType = "application/octet-stream"
|
|
}
|
|
|
|
// Decode response body (handles content-encoding like gzip)
|
|
body, err := providerUtils.CheckAndDecodeBody(resp)
|
|
if err != nil {
|
|
fasthttp.ReleaseRequest(req)
|
|
fasthttp.ReleaseResponse(resp)
|
|
lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err)
|
|
continue
|
|
}
|
|
content := append([]byte(nil), body...)
|
|
|
|
containerFileContentResponse := &schemas.BifrostContainerFileContentResponse{
|
|
Content: content,
|
|
ContentType: contentType,
|
|
ExtraFields: schemas.BifrostResponseExtraFields{
|
|
Latency: latency.Milliseconds(),
|
|
},
|
|
}
|
|
|
|
if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) {
|
|
containerFileContentResponse.ExtraFields.RawRequest = map[string]string{
|
|
"container_id": request.ContainerID,
|
|
"file_id": request.FileID,
|
|
}
|
|
}
|
|
if providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) {
|
|
containerFileContentResponse.ExtraFields.RawResponse = "<REDACTED>"
|
|
}
|
|
|
|
fasthttp.ReleaseRequest(req)
|
|
fasthttp.ReleaseResponse(resp)
|
|
return containerFileContentResponse, nil
|
|
}
|
|
|
|
return nil, lastErr
|
|
}
|
|
|
|
// ContainerFileDelete deletes a file from a container via OpenAI's API.
|
|
func (provider *OpenAIProvider) ContainerFileDelete(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostContainerFileDeleteRequest) (*schemas.BifrostContainerFileDeleteResponse, *schemas.BifrostError) {
|
|
if len(keys) == 0 {
|
|
if provider.customProviderConfig != nil && provider.customProviderConfig.IsKeyLess {
|
|
keys = []schemas.Key{{}}
|
|
} else {
|
|
return nil, providerUtils.NewBifrostOperationError("no keys provided", nil)
|
|
}
|
|
}
|
|
|
|
if err := providerUtils.CheckOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.ContainerFileDeleteRequest); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if request == nil {
|
|
return nil, providerUtils.NewBifrostOperationError("invalid request: nil", nil)
|
|
}
|
|
|
|
if request.ContainerID == "" {
|
|
return nil, providerUtils.NewBifrostOperationError("invalid request: container_id is required", nil)
|
|
}
|
|
|
|
if request.FileID == "" {
|
|
return nil, providerUtils.NewBifrostOperationError("invalid request: file_id is required", nil)
|
|
}
|
|
|
|
var lastErr *schemas.BifrostError
|
|
for _, key := range keys {
|
|
req := fasthttp.AcquireRequest()
|
|
resp := fasthttp.AcquireResponse()
|
|
|
|
providerUtils.SetExtraHeaders(ctx, req, provider.networkConfig.ExtraHeaders, nil)
|
|
|
|
endpoint := fmt.Sprintf("/v1/containers/%s/files/%s", request.ContainerID, request.FileID)
|
|
req.SetRequestURI(provider.buildRequestURL(ctx, endpoint, schemas.ContainerFileDeleteRequest))
|
|
req.Header.SetMethod(http.MethodDelete)
|
|
|
|
if key.Value.GetValue() != "" {
|
|
req.Header.Set("Authorization", "Bearer "+key.Value.GetValue())
|
|
}
|
|
|
|
latency, bifrostErr, wait := providerUtils.MakeRequestWithContext(ctx, provider.client, req, resp)
|
|
wait()
|
|
if bifrostErr != nil {
|
|
lastErr = bifrostErr
|
|
fasthttp.ReleaseRequest(req)
|
|
fasthttp.ReleaseResponse(resp)
|
|
continue
|
|
}
|
|
|
|
if resp.StatusCode() >= 400 {
|
|
lastErr = ParseOpenAIError(resp)
|
|
fasthttp.ReleaseRequest(req)
|
|
fasthttp.ReleaseResponse(resp)
|
|
continue
|
|
}
|
|
|
|
// Decode response body (handles content-encoding like gzip)
|
|
responseBody, err := providerUtils.CheckAndDecodeBody(resp)
|
|
if err != nil {
|
|
lastErr = providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err)
|
|
fasthttp.ReleaseRequest(req)
|
|
fasthttp.ReleaseResponse(resp)
|
|
continue
|
|
}
|
|
sendBackRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest)
|
|
sendBackRawResponse := providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse)
|
|
|
|
var deleteResp struct {
|
|
ID string `json:"id"`
|
|
Object string `json:"object"`
|
|
Deleted bool `json:"deleted"`
|
|
}
|
|
|
|
rawRequest, rawResponse, bifrostErr := providerUtils.HandleProviderResponse(responseBody, &deleteResp, nil, sendBackRawRequest, sendBackRawResponse)
|
|
if bifrostErr != nil {
|
|
lastErr = bifrostErr
|
|
fasthttp.ReleaseRequest(req)
|
|
fasthttp.ReleaseResponse(resp)
|
|
continue
|
|
}
|
|
|
|
containerFileDeleteResponse := &schemas.BifrostContainerFileDeleteResponse{
|
|
ID: deleteResp.ID,
|
|
Object: deleteResp.Object,
|
|
Deleted: deleteResp.Deleted,
|
|
ExtraFields: schemas.BifrostResponseExtraFields{
|
|
Latency: latency.Milliseconds(),
|
|
},
|
|
}
|
|
|
|
if sendBackRawRequest {
|
|
containerFileDeleteResponse.ExtraFields.RawRequest = rawRequest
|
|
}
|
|
if sendBackRawResponse {
|
|
containerFileDeleteResponse.ExtraFields.RawResponse = rawResponse
|
|
}
|
|
|
|
fasthttp.ReleaseRequest(req)
|
|
fasthttp.ReleaseResponse(resp)
|
|
return containerFileDeleteResponse, nil
|
|
}
|
|
|
|
return nil, lastErr
|
|
}
|
|
|
|
func (provider *OpenAIProvider) Passthrough(
|
|
ctx *schemas.BifrostContext,
|
|
key schemas.Key,
|
|
req *schemas.BifrostPassthroughRequest,
|
|
) (*schemas.BifrostPassthroughResponse, *schemas.BifrostError) {
|
|
if err := providerUtils.CheckOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.PassthroughRequest); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
path := req.Path
|
|
// if path has v1 or v1/ remove it
|
|
if after, ok := strings.CutPrefix(path, "/v1"); ok {
|
|
path = after
|
|
}
|
|
|
|
url := provider.networkConfig.BaseURL + "/v1" + path
|
|
if req.RawQuery != "" {
|
|
url += "?" + req.RawQuery
|
|
}
|
|
|
|
fasthttpReq := fasthttp.AcquireRequest()
|
|
resp := fasthttp.AcquireResponse()
|
|
defer fasthttp.ReleaseResponse(resp)
|
|
defer fasthttp.ReleaseRequest(fasthttpReq)
|
|
|
|
fasthttpReq.Header.SetMethod(req.Method)
|
|
fasthttpReq.SetRequestURI(url)
|
|
|
|
providerUtils.SetExtraHeaders(ctx, fasthttpReq, provider.networkConfig.ExtraHeaders, nil)
|
|
|
|
for k, v := range req.SafeHeaders {
|
|
fasthttpReq.Header.Set(k, v)
|
|
}
|
|
|
|
if key.Value.GetValue() != "" {
|
|
fasthttpReq.Header.Set("Authorization", "Bearer "+key.Value.GetValue())
|
|
}
|
|
|
|
fasthttpReq.SetBody(req.Body)
|
|
|
|
latency, bifrostErr, wait := providerUtils.MakeRequestWithContext(ctx, provider.client, fasthttpReq, resp)
|
|
defer wait()
|
|
if bifrostErr != nil {
|
|
return nil, bifrostErr
|
|
}
|
|
|
|
headers := providerUtils.ExtractProviderResponseHeaders(resp)
|
|
|
|
body, err := providerUtils.CheckAndDecodeBody(resp)
|
|
if err != nil {
|
|
return nil, providerUtils.NewBifrostOperationError("failed to decode response body", err)
|
|
}
|
|
|
|
// Remove wire-level encoding headers after decoding; downstream should recalculate them for the buffered body.
|
|
for k := range headers {
|
|
if strings.EqualFold(k, "Content-Encoding") || strings.EqualFold(k, "Content-Length") {
|
|
delete(headers, k)
|
|
}
|
|
}
|
|
|
|
bifrostResponse := &schemas.BifrostPassthroughResponse{
|
|
StatusCode: resp.StatusCode(),
|
|
Headers: headers,
|
|
Body: body,
|
|
}
|
|
|
|
bifrostResponse.ExtraFields.Latency = latency.Milliseconds()
|
|
|
|
if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) {
|
|
providerUtils.ParseAndSetRawRequestIfJSON(fasthttpReq, &bifrostResponse.ExtraFields)
|
|
}
|
|
|
|
return bifrostResponse, nil
|
|
}
|
|
|
|
func (provider *OpenAIProvider) PassthroughStream(
|
|
ctx *schemas.BifrostContext,
|
|
postHookRunner schemas.PostHookRunner,
|
|
postHookSpanFinalizer func(context.Context),
|
|
key schemas.Key,
|
|
req *schemas.BifrostPassthroughRequest,
|
|
) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
|
|
if err := providerUtils.CheckOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.PassthroughStreamRequest); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
path := req.Path
|
|
if after, ok := strings.CutPrefix(path, "/v1"); ok {
|
|
path = after
|
|
}
|
|
url := provider.networkConfig.BaseURL + "/v1" + path
|
|
if req.RawQuery != "" {
|
|
url += "?" + req.RawQuery
|
|
}
|
|
|
|
fasthttpReq := fasthttp.AcquireRequest()
|
|
resp := fasthttp.AcquireResponse()
|
|
resp.StreamBody = true
|
|
defer fasthttp.ReleaseRequest(fasthttpReq)
|
|
|
|
fasthttpReq.Header.SetMethod(req.Method)
|
|
fasthttpReq.SetRequestURI(url)
|
|
|
|
providerUtils.SetExtraHeaders(ctx, fasthttpReq, provider.networkConfig.ExtraHeaders, nil)
|
|
|
|
for k, v := range req.SafeHeaders {
|
|
fasthttpReq.Header.Set(k, v)
|
|
}
|
|
|
|
fasthttpReq.Header.Set("Connection", "close")
|
|
|
|
if key.Value.GetValue() != "" {
|
|
fasthttpReq.Header.Set("Authorization", "Bearer "+key.Value.GetValue())
|
|
}
|
|
|
|
fasthttpReq.SetBody(req.Body)
|
|
|
|
activeClient := providerUtils.PrepareResponseStreaming(ctx, provider.streamingClient, resp)
|
|
|
|
startTime := time.Now()
|
|
|
|
if err := activeClient.Do(fasthttpReq, resp); err != nil {
|
|
providerUtils.ReleaseStreamingResponse(resp)
|
|
if errors.Is(err, context.Canceled) {
|
|
return nil, &schemas.BifrostError{
|
|
IsBifrostError: false,
|
|
Error: &schemas.ErrorField{
|
|
Type: schemas.Ptr(schemas.RequestCancelled),
|
|
Message: schemas.ErrRequestCancelled,
|
|
Error: err,
|
|
},
|
|
}
|
|
}
|
|
if errors.Is(err, fasthttp.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) {
|
|
return nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err)
|
|
}
|
|
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err)
|
|
}
|
|
|
|
headers := make(map[string]string)
|
|
resp.Header.All()(func(k, v []byte) bool {
|
|
headers[string(k)] = string(v)
|
|
return true
|
|
})
|
|
|
|
rawBodyStream := resp.BodyStream()
|
|
if rawBodyStream == nil {
|
|
providerUtils.ReleaseStreamingResponse(resp)
|
|
return nil, providerUtils.NewBifrostOperationError(
|
|
"provider returned an empty stream body",
|
|
fmt.Errorf("provider returned an empty stream body"))
|
|
}
|
|
|
|
// Wrap reader with idle timeout to detect stalled streams.
|
|
bodyStream, stopIdleTimeout := providerUtils.NewIdleTimeoutReader(rawBodyStream, rawBodyStream, providerUtils.GetStreamIdleTimeout(ctx))
|
|
|
|
// Cancellation must close the raw stream to unblock reads.
|
|
stopCancellation := providerUtils.SetupStreamCancellation(ctx, rawBodyStream, provider.logger)
|
|
|
|
extraFields := schemas.BifrostResponseExtraFields{}
|
|
if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) {
|
|
providerUtils.ParseAndSetRawRequestIfJSON(fasthttpReq, &extraFields)
|
|
}
|
|
statusCode := resp.StatusCode()
|
|
|
|
ch := make(chan *schemas.BifrostStreamChunk, schemas.DefaultStreamBufferSize)
|
|
go func() {
|
|
defer providerUtils.EnsureStreamFinalizerCalled(ctx, postHookSpanFinalizer)
|
|
defer func() {
|
|
if ctx.Err() == context.Canceled {
|
|
providerUtils.HandleStreamCancellation(ctx, postHookRunner, ch, provider.logger, postHookSpanFinalizer)
|
|
} else if ctx.Err() == context.DeadlineExceeded {
|
|
providerUtils.HandleStreamTimeout(ctx, postHookRunner, ch, provider.logger, postHookSpanFinalizer)
|
|
}
|
|
close(ch)
|
|
}()
|
|
defer providerUtils.ReleaseStreamingResponse(resp)
|
|
defer stopIdleTimeout()
|
|
defer stopCancellation()
|
|
|
|
buf := make([]byte, 4096)
|
|
for {
|
|
n, readErr := bodyStream.Read(buf)
|
|
if n > 0 {
|
|
chunk := make([]byte, n)
|
|
copy(chunk, buf[:n])
|
|
select {
|
|
case ch <- &schemas.BifrostStreamChunk{
|
|
BifrostPassthroughResponse: &schemas.BifrostPassthroughResponse{
|
|
StatusCode: statusCode,
|
|
Headers: headers,
|
|
Body: chunk,
|
|
ExtraFields: extraFields,
|
|
},
|
|
}:
|
|
case <-ctx.Done():
|
|
return
|
|
}
|
|
}
|
|
if readErr == io.EOF {
|
|
ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
|
|
extraFields.Latency = time.Since(startTime).Milliseconds()
|
|
finalResp := &schemas.BifrostResponse{
|
|
PassthroughResponse: &schemas.BifrostPassthroughResponse{
|
|
StatusCode: statusCode,
|
|
Headers: headers,
|
|
ExtraFields: extraFields,
|
|
},
|
|
}
|
|
postHookRunner(ctx, finalResp, nil)
|
|
if postHookSpanFinalizer != nil {
|
|
postHookSpanFinalizer(ctx)
|
|
}
|
|
return
|
|
}
|
|
if readErr != nil {
|
|
if ctx.Err() != nil {
|
|
return // let defer handle cancel/timeout
|
|
}
|
|
ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
|
|
extraFields.Latency = time.Since(startTime).Milliseconds()
|
|
providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, ch, provider.logger, postHookSpanFinalizer)
|
|
return
|
|
}
|
|
}
|
|
}()
|
|
return ch, nil
|
|
}
|