first commit

This commit is contained in:
Beyhan Oğur
2026-04-26 21:52:23 +03:00
commit 880f412e2c
2662 changed files with 866266 additions and 0 deletions

View File

@@ -0,0 +1,172 @@
package openai
import (
"time"
"github.com/maximhq/bifrost/core/schemas"
)
// OpenAI Batch API Types
// OpenAIBatchRequest is the JSON body sent to OpenAI when creating a batch.
type OpenAIBatchRequest struct {
	// Always serialized (no omitempty on these tags).
	InputFileID      string `json:"input_file_id"`
	Endpoint         string `json:"endpoint"`
	CompletionWindow string `json:"completion_window"`

	// Left out of the JSON when unset.
	Metadata           map[string]string          `json:"metadata,omitempty"`
	OutputExpiresAfter *schemas.BatchExpiresAfter `json:"output_expires_after,omitempty"`
}
// OpenAIBatchResponse mirrors the batch object that OpenAI returns.
// Fields tagged omitempty are pointers (or maps) so that an absent value can be
// told apart from a zero value.
type OpenAIBatchResponse struct {
	// Identity, the echoed request settings, and the raw status string
	// (see ToBifrostBatchStatus for the mapping to Bifrost statuses).
	ID               string               `json:"id"`
	Object           string               `json:"object"`
	Endpoint         string               `json:"endpoint"`
	Errors           *schemas.BatchErrors `json:"errors,omitempty"`
	InputFileID      string               `json:"input_file_id"`
	CompletionWindow string               `json:"completion_window"`
	Status           string               `json:"status"`

	// File IDs for the batch output and error files, when present.
	OutputFileID *string `json:"output_file_id,omitempty"`
	ErrorFileID  *string `json:"error_file_id,omitempty"`

	// Lifecycle timestamps as integers. Only CreatedAt is always serialized.
	// NOTE(review): presumably Unix seconds — confirm against the OpenAI
	// Batch API reference.
	CreatedAt    int64  `json:"created_at"`
	InProgressAt *int64 `json:"in_progress_at,omitempty"`
	ExpiresAt    *int64 `json:"expires_at,omitempty"`
	FinalizingAt *int64 `json:"finalizing_at,omitempty"`
	CompletedAt  *int64 `json:"completed_at,omitempty"`
	FailedAt     *int64 `json:"failed_at,omitempty"`
	ExpiredAt    *int64 `json:"expired_at,omitempty"`
	CancellingAt *int64 `json:"cancelling_at,omitempty"`
	CancelledAt  *int64 `json:"cancelled_at,omitempty"`

	// Request tallies and the caller-supplied metadata.
	RequestCounts *OpenAIBatchRequestCounts `json:"request_counts,omitempty"`
	Metadata      map[string]string         `json:"metadata,omitempty"`
}
// OpenAIBatchRequestCounts holds the request tallies reported for a batch.
// All three counters are always serialized, including when they are zero.
type OpenAIBatchRequestCounts struct {
	Total     int `json:"total"`
	Completed int `json:"completed"`
	Failed    int `json:"failed"`
}
// OpenAIBatchListResponse is the payload returned when listing batches.
type OpenAIBatchListResponse struct {
	Object string                `json:"object"`
	Data   []OpenAIBatchResponse `json:"data"`

	// Cursor fields: FirstID and LastID are omitted when nil, HasMore is
	// always serialized.
	FirstID *string `json:"first_id,omitempty"`
	LastID  *string `json:"last_id,omitempty"`
	HasMore bool    `json:"has_more"`
}
// ToBifrostBatchStatus maps an OpenAI batch status string onto the matching
// schemas.BatchStatus constant. A status that is not one of the eight known
// values is passed through unchanged as schemas.BatchStatus(status).
func ToBifrostBatchStatus(status string) schemas.BatchStatus {
	// Start from the pass-through value; a known status overrides it below.
	mapped := schemas.BatchStatus(status)

	switch status {
	case "validating":
		mapped = schemas.BatchStatusValidating
	case "failed":
		mapped = schemas.BatchStatusFailed
	case "in_progress":
		mapped = schemas.BatchStatusInProgress
	case "finalizing":
		mapped = schemas.BatchStatusFinalizing
	case "completed":
		mapped = schemas.BatchStatusCompleted
	case "expired":
		mapped = schemas.BatchStatusExpired
	case "cancelling":
		mapped = schemas.BatchStatusCancelling
	case "cancelled":
		mapped = schemas.BatchStatusCancelled
	}

	return mapped
}
// ToBifrostBatchCreateResponse builds the Bifrost batch-create response from
// this OpenAI batch object.
//
// latency is stored in ExtraFields in milliseconds. rawRequest and rawResponse
// are attached to ExtraFields only when the matching sendBackRaw* flag is true.
// ExpiresAt and RequestCounts are copied only when set on the OpenAI response.
func (r *OpenAIBatchResponse) ToBifrostBatchCreateResponse(latency time.Duration, sendBackRawRequest bool, sendBackRawResponse bool, rawRequest interface{}, rawResponse interface{}) *schemas.BifrostBatchCreateResponse {
	// Assemble the extra fields first so the main literal stays declarative.
	extra := schemas.BifrostResponseExtraFields{Latency: latency.Milliseconds()}
	if sendBackRawRequest {
		extra.RawRequest = rawRequest
	}
	if sendBackRawResponse {
		extra.RawResponse = rawResponse
	}

	out := &schemas.BifrostBatchCreateResponse{
		ID:               r.ID,
		Object:           r.Object,
		Endpoint:         r.Endpoint,
		InputFileID:      r.InputFileID,
		CompletionWindow: r.CompletionWindow,
		Status:           ToBifrostBatchStatus(r.Status),
		Metadata:         r.Metadata,
		CreatedAt:        r.CreatedAt,
		OutputFileID:     r.OutputFileID,
		ErrorFileID:      r.ErrorFileID,
		ExtraFields:      extra,
	}

	if r.ExpiresAt != nil {
		out.ExpiresAt = r.ExpiresAt
	}
	if counts := r.RequestCounts; counts != nil {
		out.RequestCounts = schemas.BatchRequestCounts{
			Total:     counts.Total,
			Completed: counts.Completed,
			Failed:    counts.Failed,
		}
	}

	return out
}
// ToBifrostBatchRetrieveResponse builds the Bifrost batch-retrieve response
// from this OpenAI batch object, including the lifecycle timestamps and any
// batch errors.
//
// latency is stored in ExtraFields in milliseconds. rawRequest and rawResponse
// are attached to ExtraFields only when the matching sendBackRaw* flag is true.
// ExpiresAt and RequestCounts are copied only when set on the OpenAI response.
func (r *OpenAIBatchResponse) ToBifrostBatchRetrieveResponse(latency time.Duration, sendBackRawRequest bool, sendBackRawResponse bool, rawRequest interface{}, rawResponse interface{}) *schemas.BifrostBatchRetrieveResponse {
	// Assemble the extra fields first so the main literal stays declarative.
	extra := schemas.BifrostResponseExtraFields{Latency: latency.Milliseconds()}
	if sendBackRawRequest {
		extra.RawRequest = rawRequest
	}
	if sendBackRawResponse {
		extra.RawResponse = rawResponse
	}

	out := &schemas.BifrostBatchRetrieveResponse{
		ID:               r.ID,
		Object:           r.Object,
		Endpoint:         r.Endpoint,
		InputFileID:      r.InputFileID,
		CompletionWindow: r.CompletionWindow,
		Status:           ToBifrostBatchStatus(r.Status),
		Metadata:         r.Metadata,
		CreatedAt:        r.CreatedAt,
		InProgressAt:     r.InProgressAt,
		FinalizingAt:     r.FinalizingAt,
		CompletedAt:      r.CompletedAt,
		FailedAt:         r.FailedAt,
		ExpiredAt:        r.ExpiredAt,
		CancellingAt:     r.CancellingAt,
		CancelledAt:      r.CancelledAt,
		OutputFileID:     r.OutputFileID,
		ErrorFileID:      r.ErrorFileID,
		Errors:           r.Errors,
		ExtraFields:      extra,
	}

	if r.ExpiresAt != nil {
		out.ExpiresAt = r.ExpiresAt
	}
	if counts := r.RequestCounts; counts != nil {
		out.RequestCounts = schemas.BatchRequestCounts{
			Total:     counts.Total,
			Completed: counts.Completed,
			Failed:    counts.Failed,
		}
	}

	return out
}

View File

@@ -0,0 +1,192 @@
package openai
import (
"strings"
"github.com/maximhq/bifrost/core/providers/utils"
"github.com/maximhq/bifrost/core/schemas"
)
// ToBifrostChatRequest converts an OpenAI-format chat request into a Bifrost
// chat request.
//
// The provider and model are parsed from req.Model; the default provider handed
// to the parser comes from utils.CheckAndSetDefaultProvider(ctx, schemas.OpenAI).
// Params points at req's own ChatParameters rather than at a copy.
func (req *OpenAIChatRequest) ToBifrostChatRequest(ctx *schemas.BifrostContext) *schemas.BifrostChatRequest {
	defaultProvider := utils.CheckAndSetDefaultProvider(ctx, schemas.OpenAI)
	provider, model := schemas.ParseModelString(req.Model, defaultProvider)

	bifrostReq := &schemas.BifrostChatRequest{
		Provider:  provider,
		Model:     model,
		Input:     ConvertOpenAIMessagesToBifrostMessages(req.Messages),
		Params:    &req.ChatParameters,
		Fallbacks: schemas.ParseFallbacks(req.Fallbacks),
	}
	return bifrostReq
}
// ToOpenAIChatRequest converts a Bifrost chat completion request into the
// OpenAI wire format, then adjusts it for the target provider.
//
// It returns nil when bifrostReq or its Input is nil. Parameters are copied
// shallowly from bifrostReq.Params; the tool list is rebuilt so that normalizing
// tool parameters does not touch the caller's tools.
func ToOpenAIChatRequest(ctx *schemas.BifrostContext, bifrostReq *schemas.BifrostChatRequest) *OpenAIChatRequest {
	if bifrostReq == nil || bifrostReq.Input == nil {
		return nil
	}

	out := &OpenAIChatRequest{
		Model:    bifrostReq.Model,
		Messages: ConvertBifrostMessagesToOpenAIMessages(bifrostReq.Input),
	}

	if params := bifrostReq.Params; params != nil {
		out.ChatParameters = *params

		// Raise max_completion_tokens to the supported minimum when it is set too low.
		if limit := out.ChatParameters.MaxCompletionTokens; limit != nil && *limit < MinMaxCompletionTokens {
			out.ChatParameters.MaxCompletionTokens = schemas.Ptr(MinMaxCompletionTokens)
		}

		// Drop user field if it exceeds OpenAI's 64 character limit
		out.ChatParameters.User = SanitizeUserField(out.ChatParameters.User)
		out.ExtraParams = params.ExtraParams

		// Normalize tool parameters for deterministic JSON serialization (improves prompt caching).
		// Each affected tool gets its own Function copy so the caller's tools stay untouched.
		if tools := out.ChatParameters.Tools; len(tools) > 0 {
			normalized := make([]schemas.ChatTool, len(tools))
			for idx := range tools {
				normalized[idx] = tools[idx]
				fn := tools[idx].Function
				if fn == nil || fn.Parameters == nil {
					continue
				}
				fnCopy := *fn
				fnCopy.Parameters = fn.Parameters.Normalized()
				normalized[idx].Function = &fnCopy
			}
			out.ChatParameters.Tools = normalized
		}
	}

	// Provider-specific adjustments; every branch falls through to the single return.
	switch bifrostReq.Provider {
	case schemas.OpenAI, schemas.Azure:
		// Sent as-is.
	case schemas.XAI:
		out.filterOpenAISpecificParameters()
		out.applyXAICompatibility(bifrostReq.Model)
	case schemas.Gemini:
		out.filterOpenAISpecificParameters()
		// Removing extra parameters that are not supported by Gemini
		out.ServiceTier = nil
	case schemas.Mistral:
		out.filterOpenAISpecificParameters()
		out.applyMistralCompatibility()
	case schemas.Vertex:
		out.filterOpenAISpecificParameters()
		// Apply Mistral-specific transformations for Vertex Mistral models
		if schemas.IsMistralModel(bifrostReq.Model) {
			out.applyMistralCompatibility()
		}
	case schemas.Fireworks:
		// Fireworks uses prompt_cache_isolation_key for cache isolation on chat/completions.
		// Carry the key over before the generic filter strips prompt_cache_key.
		if out.ChatParameters.PromptCacheKey != nil && out.PromptCacheIsolationKey == nil {
			out.PromptCacheIsolationKey = out.ChatParameters.PromptCacheKey
		}
		// Fireworks supports predicted outputs; keep them across the filter.
		savedPrediction := out.ChatParameters.Prediction
		out.filterOpenAISpecificParameters()
		out.ChatParameters.Prediction = savedPrediction
	default:
		// Custom providers receive the request unfiltered; everything else is filtered.
		custom, ok := ctx.Value(schemas.BifrostContextKeyIsCustomProvider).(bool)
		if !(ok && custom) {
			out.filterOpenAISpecificParameters()
		}
	}

	return out
}
// filterOpenAISpecificParameters rewrites the reasoning settings into the
// effort-based form and clears the OpenAI-only parameters (prediction, prompt
// cache key/retention, verbosity, store, web search options).
//
// Reasoning priority: effort (native) > max_tokens (estimated). The Reasoning
// struct is cloned before it is changed, so the struct it originally pointed
// to is not modified.
func (req *OpenAIChatRequest) filterOpenAISpecificParameters() {
	params := &req.ChatParameters

	if params.Reasoning != nil {
		clone := *params.Reasoning
		reasoning := &clone
		params.Reasoning = reasoning

		switch {
		case reasoning.Effort != nil:
			// Native effort wins. Convert "minimal" to "low"; cap "xhigh"/"max" to "high".
			switch *reasoning.Effort {
			case "minimal":
				reasoning.Effort = schemas.Ptr("low")
			case "xhigh", "max":
				reasoning.Effort = schemas.Ptr("high")
			}
			// max_tokens is not used alongside effort.
			reasoning.MaxTokens = nil

		case reasoning.MaxTokens != nil:
			// No effort given: estimate one from the token budget, measured against
			// the request's max completion tokens (or the model default).
			completionCap := utils.GetMaxOutputTokensOrDefault(req.Model, DefaultCompletionMaxTokens)
			if params.MaxCompletionTokens != nil {
				completionCap = *params.MaxCompletionTokens
			}
			estimated := utils.GetReasoningEffortFromBudgetTokens(*reasoning.MaxTokens, MinReasoningMaxTokens, completionCap)
			reasoning.Effort = schemas.Ptr(estimated)
			reasoning.MaxTokens = nil
		}
	}

	// Clear the OpenAI-only parameters.
	params.Prediction = nil
	params.PromptCacheKey = nil
	params.PromptCacheRetention = nil
	params.Verbosity = nil
	params.Store = nil
	params.WebSearchOptions = nil
}
// applyMistralCompatibility applies Mistral-specific transformations to the
// request:
//
//   - max_completion_tokens is moved to max_tokens;
//   - a structured tool choice is replaced by the plain string "any".
//
// The tool choice is cloned before it is changed. ChatParameters is a shallow
// copy of the caller's parameters (see ToOpenAIChatRequest), so req.ToolChoice
// may point at the caller's struct; writing through it would silently alter the
// original Bifrost request.
func (req *OpenAIChatRequest) applyMistralCompatibility() {
	// Mistral uses max_tokens instead of max_completion_tokens
	if req.MaxCompletionTokens != nil {
		req.MaxTokens = req.MaxCompletionTokens
		req.MaxCompletionTokens = nil
	}

	// Mistral does not support ToolChoiceStruct, only simple tool choice strings are supported
	if req.ToolChoice != nil && req.ToolChoice.ChatToolChoiceStruct != nil {
		toolChoiceCopy := *req.ToolChoice
		toolChoiceCopy.ChatToolChoiceStr = schemas.Ptr("any")
		toolChoiceCopy.ChatToolChoiceStruct = nil
		req.ToolChoice = &toolChoiceCopy
	}
}
// applyXAICompatibility strips parameters from the request based on which grok
// reasoning model is targeted. Models for which schemas.IsGrokReasoningModel
// reports false are left untouched.
//
// For reasoning models:
//   - presence_penalty is always cleared;
//   - frequency_penalty and stop are kept only for non-mini grok-3 models;
//   - reasoning effort is kept only for grok-3-mini.
func (req *OpenAIChatRequest) applyXAICompatibility(model string) {
	if !schemas.IsGrokReasoningModel(model) {
		return
	}

	params := &req.ChatParameters
	isGrok3Mini := strings.Contains(model, "grok-3-mini")
	isNonMiniGrok3 := strings.Contains(model, "grok-3") && !isGrok3Mini

	params.PresencePenalty = nil

	// Only non-mini grok-3 models support frequency_penalty and stop.
	if !isNonMiniGrok3 {
		params.FrequencyPenalty = nil
		params.Stop = nil
	}

	// Only grok-3-mini supports reasoning_effort.
	if !isGrok3Mini && params.Reasoning != nil {
		params.Reasoning.Effort = nil
	}
}

View File

@@ -0,0 +1,707 @@
package openai
import (
"encoding/json"
"strings"
"testing"
"github.com/bytedance/sonic"
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
"github.com/maximhq/bifrost/core/schemas"
"github.com/stretchr/testify/require"
)
// TestToOpenAIChatRequest_ToolNormalization checks that converting a request
// with tools keeps the declared property order, leaves a parameter-less tool
// untouched, and does not mutate or alias the caller's tool definitions.
func TestToOpenAIChatRequest_ToolNormalization(t *testing.T) {
	// Property names are deliberately declared out of alphabetical order
	// ("zebra" before "alpha"); the assertions expect that order to survive.
	declaredProps := schemas.NewOrderedMapFromPairs(
		schemas.KV("zebra", map[string]interface{}{"type": "string"}),
		schemas.KV("alpha", map[string]interface{}{"type": "number"}),
	)
	toolWithParams := schemas.ChatTool{
		Type: "function",
		Function: &schemas.ChatToolFunction{
			Name: "test_func",
			Parameters: &schemas.ToolFunctionParameters{
				Type:       "object",
				Properties: declaredProps,
				Required:   []string{"zebra"},
			},
		},
	}
	toolWithoutParams := schemas.ChatTool{
		Type:     "function",
		Function: &schemas.ChatToolFunction{Name: "no_params_func"},
	}
	input := &schemas.BifrostChatRequest{
		Provider: schemas.OpenAI,
		Model:    "gpt-4o",
		Input:    []schemas.ChatMessage{{Role: schemas.ChatMessageRoleUser}},
		Params: &schemas.ChatParameters{
			Tools: []schemas.ChatTool{toolWithParams, toolWithoutParams},
		},
	}

	ctx, cancel := schemas.NewBifrostContextWithCancel(nil)
	defer cancel()

	converted := ToOpenAIChatRequest(ctx, input)
	if converted == nil {
		t.Fatal("expected non-nil result")
	}
	convertedTools := converted.ChatParameters.Tools
	sourceTools := input.Params.Tools

	// Property names must come back in the order the client declared them.
	gotParams := convertedTools[0].Function.Parameters
	if gotParams == nil {
		t.Fatal("expected normalized parameters to be non-nil")
	}
	if keys := gotParams.Properties.Keys(); !(len(keys) == 2 && keys[0] == "zebra" && keys[1] == "alpha") {
		t.Errorf("expected Properties keys preserved as [zebra, alpha], got %v", keys)
	}

	// A tool declared without parameters stays without parameters.
	if convertedTools[1].Function.Parameters != nil {
		t.Error("expected nil parameters for tool without parameters")
	}

	// The caller's tool definitions must be left exactly as they were.
	if keys := sourceTools[0].Function.Parameters.Properties.Keys(); !(len(keys) == 2 && keys[0] == "zebra" && keys[1] == "alpha") {
		t.Errorf("original parameters were mutated: expected [zebra, alpha], got %v", keys)
	}

	// The converted tool must hold its own Function value, not the caller's pointer.
	if convertedTools[0].Function == sourceTools[0].Function {
		t.Error("expected Function pointer to be a copy, not the original")
	}
}
// TestToOpenAIChatRequest_PreservesN checks that the n parameter is carried
// over unchanged when converting a request for the OpenAI provider.
func TestToOpenAIChatRequest_PreservesN(t *testing.T) {
	userMessage := schemas.ChatMessage{
		Role:    schemas.ChatMessageRoleUser,
		Content: &schemas.ChatMessageContent{ContentStr: schemas.Ptr("hello")},
	}
	bifrostReq := &schemas.BifrostChatRequest{
		Provider: schemas.OpenAI,
		Model:    "gpt-4.1",
		Input:    []schemas.ChatMessage{userMessage},
		Params:   &schemas.ChatParameters{N: schemas.Ptr(2)},
	}
	ctx := schemas.NewBifrostContext(nil, schemas.NoDeadline)

	converted := ToOpenAIChatRequest(ctx, bifrostReq)
	switch {
	case converted == nil:
		t.Fatal("expected request")
	case converted.N == nil || *converted.N != 2:
		t.Fatalf("expected n=2, got %#v", converted.N)
	}
}
// TestToOpenAIChatRequest_PreservesPropertyOrder checks that the order in which
// tool properties are declared (reasoning, answer, confidence) is kept by the
// conversion.
//
// The result is guarded before it is dereferenced so that a regression shows up
// as a test failure with a message rather than as a nil-pointer or index panic,
// matching the sibling tests in this file.
func TestToOpenAIChatRequest_PreservesPropertyOrder(t *testing.T) {
	params := &schemas.ToolFunctionParameters{
		Type: "object",
		Properties: schemas.NewOrderedMapFromPairs(
			schemas.KV("reasoning", map[string]interface{}{"type": "string", "description": "Step by step"}),
			schemas.KV("answer", map[string]interface{}{"type": "string", "description": "Final answer"}),
			schemas.KV("confidence", map[string]interface{}{"type": "number", "description": "Score"}),
		),
		Required: []string{"reasoning", "answer"},
	}

	bifrostReq := &schemas.BifrostChatRequest{
		Provider: schemas.OpenAI,
		Model:    "gpt-4o",
		Input:    []schemas.ChatMessage{{Role: schemas.ChatMessageRoleUser}},
		Params: &schemas.ChatParameters{
			Tools: []schemas.ChatTool{{
				Type:     "function",
				Function: &schemas.ChatToolFunction{Name: "test_func", Parameters: params},
			}},
		},
	}

	ctx, cancel := schemas.NewBifrostContextWithCancel(nil)
	defer cancel()

	result := ToOpenAIChatRequest(ctx, bifrostReq)
	if result == nil {
		t.Fatal("expected non-nil result")
	}
	if len(result.ChatParameters.Tools) != 1 || result.ChatParameters.Tools[0].Function == nil {
		t.Fatalf("expected exactly one function tool, got %#v", result.ChatParameters.Tools)
	}

	normalizedParams := result.ChatParameters.Tools[0].Function.Parameters
	if normalizedParams == nil {
		t.Fatal("expected normalized parameters to be non-nil")
	}

	// CoT: property order preserved
	keys := normalizedParams.Properties.Keys()
	if len(keys) != 3 || keys[0] != "reasoning" || keys[1] != "answer" || keys[2] != "confidence" {
		t.Errorf("expected property order [reasoning, answer, confidence], got %v", keys)
	}
}
// TestToOpenAIChatRequest_PreservesExplicitEmptyToolParameters checks that a
// tool whose JSON declares "parameters": {} still serializes its parameters as
// {} after conversion, rather than losing them.
func TestToOpenAIChatRequest_PreservesExplicitEmptyToolParameters(t *testing.T) {
	const rawTool = `{"type":"function","function":{"name":"empty_schema","parameters":{},"strict":false}}`

	var tool schemas.ChatTool
	if err := json.Unmarshal([]byte(rawTool), &tool); err != nil {
		t.Fatalf("failed to unmarshal tool: %v", err)
	}

	bifrostReq := &schemas.BifrostChatRequest{
		Provider: schemas.OpenAI,
		Model:    "gpt-4o",
		Input:    []schemas.ChatMessage{{Role: schemas.ChatMessageRoleUser}},
		Params:   &schemas.ChatParameters{Tools: []schemas.ChatTool{tool}},
	}

	ctx, cancel := schemas.NewBifrostContextWithCancel(nil)
	defer cancel()

	converted := ToOpenAIChatRequest(ctx, bifrostReq)
	if converted == nil {
		t.Fatal("expected non-nil result")
	}

	toolParams := converted.ChatParameters.Tools[0].Function.Parameters
	if toolParams == nil {
		t.Fatal("expected tool parameters to be preserved")
	}

	// The explicit empty schema must round-trip as an empty JSON object.
	encoded, err := schemas.Marshal(toolParams)
	if err != nil {
		t.Fatalf("failed to marshal parameters: %v", err)
	}
	if string(encoded) != `{}` {
		t.Fatalf("expected parameters to remain {}, got %s", encoded)
	}
}
// TestToOpenAIChatRequest_CachingDeterminism checks that two tool schemas that
// differ only in the key order inside each property definition serialize to
// identical JSON after conversion.
//
// Each conversion result is guarded before it is dereferenced so that a
// regression shows up as a test failure with a message rather than as a
// nil-pointer or index panic, matching the sibling tests in this file.
func TestToOpenAIChatRequest_CachingDeterminism(t *testing.T) {
	// Same properties, different structural key orders within property definitions
	makeReq := func(props *schemas.OrderedMap) *schemas.BifrostChatRequest {
		return &schemas.BifrostChatRequest{
			Provider: schemas.OpenAI,
			Model:    "gpt-4o",
			Input:    []schemas.ChatMessage{{Role: schemas.ChatMessageRoleUser}},
			Params: &schemas.ChatParameters{
				Tools: []schemas.ChatTool{{
					Type: "function",
					Function: &schemas.ChatToolFunction{
						Name:       "test",
						Parameters: &schemas.ToolFunctionParameters{Type: "object", Properties: props},
					},
				}},
			},
		}
	}

	// Version A: type before description
	propsA := schemas.NewOrderedMapFromPairs(
		schemas.KV("reasoning", schemas.NewOrderedMapFromPairs(
			schemas.KV("type", "string"),
			schemas.KV("description", "Step by step"),
		)),
		schemas.KV("answer", schemas.NewOrderedMapFromPairs(
			schemas.KV("type", "string"),
			schemas.KV("description", "Final answer"),
		)),
	)

	// Version B: description before type (different structural order)
	propsB := schemas.NewOrderedMapFromPairs(
		schemas.KV("reasoning", schemas.NewOrderedMapFromPairs(
			schemas.KV("description", "Step by step"),
			schemas.KV("type", "string"),
		)),
		schemas.KV("answer", schemas.NewOrderedMapFromPairs(
			schemas.KV("description", "Final answer"),
			schemas.KV("type", "string"),
		)),
	)

	ctx, cancel := schemas.NewBifrostContextWithCancel(nil)
	defer cancel()

	// convertAndMarshal converts one schema version and returns the JSON of the
	// single tool's parameters, failing the test on any unexpected shape.
	convertAndMarshal := func(label string, props *schemas.OrderedMap) string {
		t.Helper()
		result := ToOpenAIChatRequest(ctx, makeReq(props))
		if result == nil {
			t.Fatalf("expected non-nil result for version %s", label)
		}
		tools := result.ChatParameters.Tools
		if len(tools) != 1 || tools[0].Function == nil {
			t.Fatalf("expected exactly one function tool for version %s, got %#v", label, tools)
		}
		encoded, err := schemas.Marshal(tools[0].Function.Parameters)
		if err != nil {
			t.Fatalf("failed to marshal params %s: %v", label, err)
		}
		return string(encoded)
	}

	jsonA := convertAndMarshal("A", propsA)
	jsonB := convertAndMarshal("B", propsB)

	if jsonA != jsonB {
		t.Errorf("caching broken: same schema produced different JSON\nA: %s\nB: %s", jsonA, jsonB)
	}
}
// TestToOpenAIChatRequest_FireworksPreservesReasoningAndCacheIsolation checks
// the Fireworks conversion twice: on the converted struct and on the JSON body
// produced through providerUtils.CheckContextAndGetRequestBody. In both, the
// prompt cache key must appear as prompt_cache_isolation_key (and not as
// prompt_cache_key) and the assistant reasoning must be carried over; the
// struct check also expects the prediction to be kept.
func TestToOpenAIChatRequest_FireworksPreservesReasoningAndCacheIsolation(t *testing.T) {
	ctx, cancel := schemas.NewBifrostContextWithCancel(nil)
	defer cancel()

	var (
		promptCacheKey = "cache-key-1"
		thinking       = "step by step"
		predicted      = "fireworks ok"
		prompt         = "Reply with exactly: fireworks ok"
	)

	userMsg := schemas.ChatMessage{
		Role:    schemas.ChatMessageRoleUser,
		Content: &schemas.ChatMessageContent{ContentStr: &prompt},
	}
	assistantMsg := schemas.ChatMessage{
		Role:                 schemas.ChatMessageRoleAssistant,
		Content:              &schemas.ChatMessageContent{ContentStr: &predicted},
		ChatAssistantMessage: &schemas.ChatAssistantMessage{Reasoning: &thinking},
	}
	bifrostReq := &schemas.BifrostChatRequest{
		Provider: schemas.Fireworks,
		Model:    "accounts/fireworks/models/deepseek-v3p2",
		Input:    []schemas.ChatMessage{userMsg, assistantMsg},
		Params: &schemas.ChatParameters{
			PromptCacheKey: &promptCacheKey,
			Prediction: &schemas.ChatPrediction{
				Type:    "content",
				Content: predicted,
			},
		},
	}

	// Part 1: inspect the converted struct.
	converted := ToOpenAIChatRequest(ctx, bifrostReq)
	if converted == nil {
		t.Fatal("expected non-nil result")
	}
	if key := converted.PromptCacheIsolationKey; key == nil || *key != promptCacheKey {
		t.Fatalf("expected prompt_cache_isolation_key %q, got %v", promptCacheKey, key)
	}
	if key := converted.PromptCacheKey; key != nil {
		t.Fatalf("expected prompt_cache_key to be stripped, got %v", *key)
	}
	if prediction := converted.Prediction; prediction == nil || prediction.Content != predicted {
		t.Fatalf("expected prediction to be preserved, got %#v", prediction)
	}
	if len(converted.Messages) != 2 || converted.Messages[1].OpenAIChatAssistantMessage == nil {
		t.Fatalf("expected assistant message with OpenAI assistant payload, got %#v", converted.Messages)
	}
	assistant := converted.Messages[1].OpenAIChatAssistantMessage
	if assistant.Reasoning == nil || *assistant.Reasoning != thinking {
		t.Fatalf("expected assistant reasoning_content %q, got %#v", thinking, assistant)
	}

	// Part 2: inspect the JSON body built through the provider utilities.
	ctx.SetValue(schemas.BifrostContextKeyPassthroughExtraParams, true)
	buildBody := func() (providerUtils.RequestBodyWithExtraParams, error) {
		return ToOpenAIChatRequest(ctx, bifrostReq), nil
	}
	wireBody, bifrostErr := providerUtils.CheckContextAndGetRequestBody(ctx, bifrostReq, buildBody)
	if bifrostErr != nil {
		t.Fatalf("failed to build request body: %v", bifrostErr.Error.Message)
	}

	var payload map[string]interface{}
	if err := sonic.Unmarshal(wireBody, &payload); err != nil {
		t.Fatalf("failed to parse marshaled request body: %v", err)
	}

	if got, ok := payload["prompt_cache_isolation_key"].(string); !ok || got != promptCacheKey {
		t.Fatalf("expected prompt_cache_isolation_key %q in wire payload, got %#v", promptCacheKey, payload["prompt_cache_isolation_key"])
	}
	if leaked, present := payload["prompt_cache_key"]; present {
		t.Fatalf("expected prompt_cache_key to be absent from wire payload, got %#v", leaked)
	}

	wireMessages, ok := payload["messages"].([]interface{})
	if !ok || len(wireMessages) != 2 {
		t.Fatalf("expected 2 messages in wire payload, got %#v", payload["messages"])
	}
	wireAssistant, ok := wireMessages[1].(map[string]interface{})
	if !ok {
		t.Fatalf("expected assistant message object, got %#v", wireMessages[1])
	}
	if got, ok := wireAssistant["reasoning_content"].(string); !ok || got != thinking {
		t.Fatalf("expected reasoning_content %q in assistant payload, got %#v", thinking, wireAssistant["reasoning_content"])
	}
}
// TestToOpenAIChatRequest_AnnotationsNotInWirePayload verifies that the MCP
// tool annotations attached to a ChatTool never show up in the JSON produced
// from the converted request, while the function definition itself does.
func TestToOpenAIChatRequest_AnnotationsNotInWirePayload(t *testing.T) {
	readOnly := true

	annotatedTool := schemas.ChatTool{
		Type: schemas.ChatToolTypeFunction,
		Function: &schemas.ChatToolFunction{
			Name:        "read_file",
			Description: schemas.Ptr("Read a file"),
			Parameters: &schemas.ToolFunctionParameters{
				Type: "object",
				Properties: schemas.NewOrderedMapFromPairs(
					schemas.KV("path", map[string]interface{}{"type": "string"}),
				),
				Required: []string{"path"},
			},
		},
		Annotations: &schemas.MCPToolAnnotations{
			Title:        "File Reader",
			ReadOnlyHint: &readOnly,
		},
	}
	bifrostReq := &schemas.BifrostChatRequest{
		Provider: schemas.OpenAI,
		Model:    "gpt-4o",
		Input: []schemas.ChatMessage{
			{Role: schemas.ChatMessageRoleUser, Content: &schemas.ChatMessageContent{ContentStr: schemas.Ptr("hello")}},
		},
		Params: &schemas.ChatParameters{Tools: []schemas.ChatTool{annotatedTool}},
	}

	ctx, cancel := schemas.NewBifrostContextWithCancel(nil)
	defer cancel()

	result := ToOpenAIChatRequest(ctx, bifrostReq)
	require.NotNil(t, result)

	wireBody, err := json.Marshal(result)
	require.NoError(t, err)
	wire := string(wireBody)

	// None of the annotation-related strings may appear in the wire payload.
	for _, forbidden := range []struct{ needle, label string }{
		{"annotations", "annotations field"},
		{"readOnlyHint", "readOnlyHint"},
		{"File Reader", "annotation title"},
	} {
		if strings.Contains(wire, forbidden.needle) {
			t.Errorf("%s leaked into OpenAI wire payload: %s", forbidden.label, wire)
		}
	}

	// The function definition must still be intact
	if !strings.Contains(wire, "read_file") {
		t.Errorf("function name missing from OpenAI wire payload: %s", wire)
	}
}
func TestApplyXAICompatibility(t *testing.T) {
tests := []struct {
name string
model string
request *OpenAIChatRequest
validate func(t *testing.T, req *OpenAIChatRequest)
}{
{
name: "grok-3: preserves frequency_penalty and stop, clears presence_penalty and reasoning_effort",
model: "grok-3",
request: &OpenAIChatRequest{
Model: "grok-3",
Messages: []OpenAIMessage{},
ChatParameters: schemas.ChatParameters{
FrequencyPenalty: schemas.Ptr(0.5),
PresencePenalty: schemas.Ptr(0.3),
Stop: []string{"STOP"},
Reasoning: &schemas.ChatReasoning{
Effort: schemas.Ptr("high"),
},
},
},
validate: func(t *testing.T, req *OpenAIChatRequest) {
// frequency_penalty should be preserved
if req.FrequencyPenalty == nil || *req.FrequencyPenalty != 0.5 {
t.Errorf("Expected FrequencyPenalty to be preserved at 0.5, got %v", req.FrequencyPenalty)
}
// stop should be preserved
if len(req.Stop) != 1 || req.Stop[0] != "STOP" {
t.Errorf("Expected Stop to be preserved as ['STOP'], got %v", req.Stop)
}
// presence_penalty should be cleared
if req.PresencePenalty != nil {
t.Errorf("Expected PresencePenalty to be cleared (nil), got %v", *req.PresencePenalty)
}
// reasoning_effort should be cleared for non-mini grok-3
if req.Reasoning == nil {
t.Fatal("Expected Reasoning to remain non-nil")
}
if req.Reasoning.Effort != nil {
t.Errorf("Expected Reasoning.Effort to be cleared (nil) for grok-3, got %v", *req.Reasoning.Effort)
}
},
},
{
name: "grok-3-mini: clears all penalties and stop, preserves reasoning_effort",
model: "grok-3-mini",
request: &OpenAIChatRequest{
Model: "grok-3-mini",
Messages: []OpenAIMessage{},
ChatParameters: schemas.ChatParameters{
FrequencyPenalty: schemas.Ptr(0.5),
PresencePenalty: schemas.Ptr(0.3),
Stop: []string{"STOP"},
Reasoning: &schemas.ChatReasoning{
Effort: schemas.Ptr("medium"),
},
},
},
validate: func(t *testing.T, req *OpenAIChatRequest) {
// presence_penalty should be cleared
if req.PresencePenalty != nil {
t.Errorf("Expected PresencePenalty to be cleared (nil), got %v", *req.PresencePenalty)
}
// frequency_penalty should be cleared for grok-3-mini
if req.FrequencyPenalty != nil {
t.Errorf("Expected FrequencyPenalty to be cleared (nil) for grok-3-mini, got %v", *req.FrequencyPenalty)
}
// stop should be cleared for grok-3-mini
if req.Stop != nil {
t.Errorf("Expected Stop to be cleared (nil) for grok-3-mini, got %v", req.Stop)
}
// reasoning_effort should be preserved for grok-3-mini
if req.Reasoning == nil || req.Reasoning.Effort == nil {
t.Fatal("Expected Reasoning.Effort to be preserved for grok-3-mini")
}
if *req.Reasoning.Effort != "medium" {
t.Errorf("Expected Reasoning.Effort to be 'medium', got %v", *req.Reasoning.Effort)
}
},
},
{
name: "grok-4: clears all penalties, stop, and reasoning_effort",
model: "grok-4",
request: &OpenAIChatRequest{
Model: "grok-4",
Messages: []OpenAIMessage{},
ChatParameters: schemas.ChatParameters{
FrequencyPenalty: schemas.Ptr(0.5),
PresencePenalty: schemas.Ptr(0.3),
Stop: []string{"STOP"},
Reasoning: &schemas.ChatReasoning{
Effort: schemas.Ptr("high"),
},
},
},
validate: func(t *testing.T, req *OpenAIChatRequest) {
// presence_penalty should be cleared
if req.PresencePenalty != nil {
t.Errorf("Expected PresencePenalty to be cleared (nil), got %v", *req.PresencePenalty)
}
// frequency_penalty should be cleared for grok-4
if req.FrequencyPenalty != nil {
t.Errorf("Expected FrequencyPenalty to be cleared (nil) for grok-4, got %v", *req.FrequencyPenalty)
}
// stop should be cleared for grok-4
if req.Stop != nil {
t.Errorf("Expected Stop to be cleared (nil) for grok-4, got %v", req.Stop)
}
// reasoning_effort should be cleared for grok-4
if req.Reasoning == nil {
t.Fatal("Expected Reasoning to remain non-nil")
}
if req.Reasoning.Effort != nil {
t.Errorf("Expected Reasoning.Effort to be cleared (nil) for grok-4, got %v", *req.Reasoning.Effort)
}
},
},
{
name: "grok-4-fast-reasoning: clears all penalties, stop, and reasoning_effort",
model: "grok-4-fast-reasoning",
request: &OpenAIChatRequest{
Model: "grok-4-fast-reasoning",
Messages: []OpenAIMessage{},
ChatParameters: schemas.ChatParameters{
FrequencyPenalty: schemas.Ptr(0.5),
PresencePenalty: schemas.Ptr(0.3),
Stop: []string{"STOP", "END"},
Reasoning: &schemas.ChatReasoning{
Effort: schemas.Ptr("high"),
},
},
},
validate: func(t *testing.T, req *OpenAIChatRequest) {
// presence_penalty should be cleared
if req.PresencePenalty != nil {
t.Errorf("Expected PresencePenalty to be cleared (nil), got %v", *req.PresencePenalty)
}
// frequency_penalty should be cleared
if req.FrequencyPenalty != nil {
t.Errorf("Expected FrequencyPenalty to be cleared (nil), got %v", *req.FrequencyPenalty)
}
// stop should be cleared
if req.Stop != nil {
t.Errorf("Expected Stop to be cleared (nil), got %v", req.Stop)
}
// reasoning_effort should be cleared
if req.Reasoning == nil {
t.Fatal("Expected Reasoning to remain non-nil")
}
if req.Reasoning.Effort != nil {
t.Errorf("Expected Reasoning.Effort to be cleared (nil), got %v", *req.Reasoning.Effort)
}
},
},
{
name: "grok-code-fast-1: clears all penalties, stop, and reasoning_effort",
model: "grok-code-fast-1",
request: &OpenAIChatRequest{
Model: "grok-code-fast-1",
Messages: []OpenAIMessage{},
ChatParameters: schemas.ChatParameters{
FrequencyPenalty: schemas.Ptr(0.2),
PresencePenalty: schemas.Ptr(0.1),
Stop: []string{"END"},
Reasoning: &schemas.ChatReasoning{
Effort: schemas.Ptr("low"),
},
},
},
validate: func(t *testing.T, req *OpenAIChatRequest) {
// presence_penalty should be cleared
if req.PresencePenalty != nil {
t.Errorf("Expected PresencePenalty to be cleared (nil), got %v", *req.PresencePenalty)
}
// frequency_penalty should be cleared
if req.FrequencyPenalty != nil {
t.Errorf("Expected FrequencyPenalty to be cleared (nil), got %v", *req.FrequencyPenalty)
}
// stop should be cleared
if req.Stop != nil {
t.Errorf("Expected Stop to be cleared (nil), got %v", req.Stop)
}
// reasoning_effort should be cleared
if req.Reasoning == nil {
t.Fatal("Expected Reasoning to remain non-nil")
}
if req.Reasoning.Effort != nil {
t.Errorf("Expected Reasoning.Effort to be cleared (nil), got %v", *req.Reasoning.Effort)
}
},
},
{
name: "non-reasoning grok model: no changes applied",
model: "grok-2-latest",
request: &OpenAIChatRequest{
Model: "grok-2-latest",
Messages: []OpenAIMessage{},
ChatParameters: schemas.ChatParameters{
FrequencyPenalty: schemas.Ptr(0.5),
PresencePenalty: schemas.Ptr(0.3),
Stop: []string{"STOP"},
Reasoning: &schemas.ChatReasoning{
Effort: schemas.Ptr("high"),
},
},
},
validate: func(t *testing.T, req *OpenAIChatRequest) {
// All parameters should be preserved for non-reasoning models
if req.FrequencyPenalty == nil || *req.FrequencyPenalty != 0.5 {
t.Errorf("Expected FrequencyPenalty to be preserved at 0.5, got %v", req.FrequencyPenalty)
}
if req.PresencePenalty == nil || *req.PresencePenalty != 0.3 {
t.Errorf("Expected PresencePenalty to be preserved at 0.3, got %v", req.PresencePenalty)
}
if len(req.Stop) != 1 || req.Stop[0] != "STOP" {
t.Errorf("Expected Stop to be preserved as ['STOP'], got %v", req.Stop)
}
if req.Reasoning == nil || req.Reasoning.Effort == nil {
t.Fatal("Expected Reasoning.Effort to be preserved")
}
if *req.Reasoning.Effort != "high" {
t.Errorf("Expected Reasoning.Effort to be 'high', got %v", *req.Reasoning.Effort)
}
},
},
{
name: "grok-3: handles nil reasoning gracefully",
model: "grok-3",
request: &OpenAIChatRequest{
Model: "grok-3",
Messages: []OpenAIMessage{},
ChatParameters: schemas.ChatParameters{
FrequencyPenalty: schemas.Ptr(0.5),
PresencePenalty: schemas.Ptr(0.3),
Stop: []string{"STOP"},
Reasoning: nil,
},
},
validate: func(t *testing.T, req *OpenAIChatRequest) {
// Should handle nil reasoning without panicking
if req.Reasoning != nil {
t.Errorf("Expected Reasoning to remain nil, got %v", req.Reasoning)
}
// Other parameters should still be processed
if req.PresencePenalty != nil {
t.Errorf("Expected PresencePenalty to be cleared (nil), got %v", *req.PresencePenalty)
}
if req.FrequencyPenalty == nil || *req.FrequencyPenalty != 0.5 {
t.Errorf("Expected FrequencyPenalty to be preserved at 0.5, got %v", req.FrequencyPenalty)
}
},
},
{
name: "grok-3: preserves other parameters like temperature",
model: "grok-3",
request: &OpenAIChatRequest{
Model: "grok-3",
Messages: []OpenAIMessage{},
ChatParameters: schemas.ChatParameters{
Temperature: schemas.Ptr(0.8),
TopP: schemas.Ptr(0.9),
FrequencyPenalty: schemas.Ptr(0.5),
PresencePenalty: schemas.Ptr(0.3),
},
},
validate: func(t *testing.T, req *OpenAIChatRequest) {
// Unrelated parameters should be preserved
if req.Temperature == nil || *req.Temperature != 0.8 {
t.Errorf("Expected Temperature to be preserved at 0.8, got %v", req.Temperature)
}
if req.TopP == nil || *req.TopP != 0.9 {
t.Errorf("Expected TopP to be preserved at 0.9, got %v", req.TopP)
}
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Apply the compatibility function
tt.request.applyXAICompatibility(tt.model)
// Validate the results
tt.validate(t, tt.request)
})
}
}

View File

@@ -0,0 +1,40 @@
package openai
import (
"github.com/maximhq/bifrost/core/providers/utils"
"github.com/maximhq/bifrost/core/schemas"
)
// ToBifrostEmbeddingRequest converts an OpenAI embedding request to Bifrost format.
//
// The provider and model are obtained by parsing request.Model with
// schemas.ParseModelString, using the value returned by
// utils.CheckAndSetDefaultProvider(ctx, schemas.OpenAI) as the default provider.
// The returned Params field points at the receiver's embedded
// EmbeddingParameters (it is not copied).
//
// A nil receiver yields nil, matching the other ToBifrost*Request converters
// in this package.
func (request *OpenAIEmbeddingRequest) ToBifrostEmbeddingRequest(ctx *schemas.BifrostContext) *schemas.BifrostEmbeddingRequest {
	// Guard against a nil receiver: request.Model below would otherwise panic.
	if request == nil {
		return nil
	}

	provider, model := schemas.ParseModelString(request.Model, utils.CheckAndSetDefaultProvider(ctx, schemas.OpenAI))

	return &schemas.BifrostEmbeddingRequest{
		Provider:  provider,
		Model:     model,
		Input:     request.Input,
		Params:    &request.EmbeddingParameters,
		Fallbacks: schemas.ParseFallbacks(request.Fallbacks),
	}
}
// ToOpenAIEmbeddingRequest builds the OpenAI wire representation of a Bifrost
// embedding request. It returns nil when bifrostReq is nil.
//
// Model and Input are always carried over. When Params is set, its value is
// copied into the embedded EmbeddingParameters and its ExtraParams are
// attached to the resulting request.
func ToOpenAIEmbeddingRequest(bifrostReq *schemas.BifrostEmbeddingRequest) *OpenAIEmbeddingRequest {
	if bifrostReq == nil {
		return nil
	}

	converted := new(OpenAIEmbeddingRequest)
	converted.Model = bifrostReq.Model
	converted.Input = bifrostReq.Input

	// Optional parameters are only mapped when the caller supplied them.
	if p := bifrostReq.Params; p != nil {
		converted.EmbeddingParameters = *p
		converted.ExtraParams = p.ExtraParams
	}

	return converted
}

View File

@@ -0,0 +1,54 @@
package openai
import (
"fmt"
"strings"
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
"github.com/maximhq/bifrost/core/schemas"
"github.com/valyala/fasthttp"
)
// ErrorConverter is a function that converts provider-specific error responses to BifrostError.
// It receives the raw fasthttp response and returns the converted error.
// ParseOpenAIError in this file has this signature.
type ErrorConverter func(resp *fasthttp.Response) *schemas.BifrostError
// ParseOpenAIError converts an upstream error response into a BifrostError.
//
// The response is first handed to providerUtils.HandleProviderAPIError, which
// also decodes the body into a scratch BifrostError. Any event IDs, type, code,
// param and non-empty message found in that decoded value are then copied onto
// the result. The returned error always has a non-nil Error field, and when its
// message is empty or whitespace-only a generic fallback is substituted
// ("provider API error", with the status code appended when one is known).
func ParseOpenAIError(resp *fasthttp.Response) *schemas.BifrostError {
	var decoded schemas.BifrostError
	out := providerUtils.HandleProviderAPIError(resp, &decoded)

	// Make sure there is always an ErrorField to populate.
	if out.Error == nil {
		out.Error = &schemas.ErrorField{}
	}

	if decoded.EventID != nil {
		out.EventID = decoded.EventID
	}

	if src := decoded.Error; src != nil {
		dst := out.Error
		dst.Type = src.Type
		dst.Code = src.Code
		// Only a non-empty decoded message replaces what the helper produced.
		if src.Message != "" {
			dst.Message = src.Message
		}
		dst.Param = src.Param
		if src.EventID != nil {
			dst.EventID = src.EventID
		}
	}

	// A usable message is already present: nothing more to do.
	if strings.TrimSpace(out.Error.Message) != "" {
		return out
	}

	if out.StatusCode == nil {
		out.Error.Message = "provider API error"
		return out
	}
	out.Error.Message = fmt.Sprintf("provider API error (status %d)", *out.StatusCode)
	return out
}

View File

@@ -0,0 +1,83 @@
package openai
import (
"testing"
"github.com/maximhq/bifrost/core/schemas"
"github.com/valyala/fasthttp"
)
// TestParseOpenAIError_FallbackMessageWhenProviderBodyIsNonOpenAIShape checks
// that a 422 body without an OpenAI-style "error" object yields the generic
// status-based fallback message.
func TestParseOpenAIError_FallbackMessageWhenProviderBodyIsNonOpenAIShape(t *testing.T) {
	var upstream fasthttp.Response
	upstream.SetStatusCode(fasthttp.StatusUnprocessableEntity)
	upstream.SetBodyString(`{"detail":[{"loc":["body","messages",0,"role"],"msg":"value is not a valid enumeration member"}]}`)

	parsed := ParseOpenAIError(&upstream)
	if parsed == nil || parsed.Error == nil {
		t.Fatal("expected non-nil error response")
	}

	got := parsed.Error.Message
	if got == "" {
		t.Fatal("expected non-empty error message")
	}
	if got != "provider API error (status 422)" {
		t.Fatalf("expected fallback message, got %q", got)
	}
}
// TestParseOpenAIError_PreservesProviderMessageWhenPresent checks that a
// message supplied in an OpenAI-shaped error body is passed through unchanged.
func TestParseOpenAIError_PreservesProviderMessageWhenPresent(t *testing.T) {
	var upstream fasthttp.Response
	upstream.SetStatusCode(fasthttp.StatusUnprocessableEntity)
	upstream.SetBodyString(`{"error":{"message":"unsupported role: developer","type":"invalid_request_error","param":"messages.0.role","code":"invalid_value"}}`)

	parsed := ParseOpenAIError(&upstream)
	if parsed == nil || parsed.Error == nil {
		t.Fatal("expected non-nil error response")
	}

	if got := parsed.Error.Message; got != "unsupported role: developer" {
		t.Fatalf("expected provider message, got %q", got)
	}
}
// TestParseOpenAIError_FallbackMessageWhenBodyIsEmpty checks the message
// produced for a 400 response that carries no body at all.
func TestParseOpenAIError_FallbackMessageWhenBodyIsEmpty(t *testing.T) {
	var upstream fasthttp.Response
	upstream.SetStatusCode(fasthttp.StatusBadRequest)
	upstream.SetBody(nil)

	parsed := ParseOpenAIError(&upstream)
	if parsed == nil || parsed.Error == nil {
		t.Fatal("expected non-nil error response")
	}

	// HandleProviderAPIError returns ErrProviderResponseEmpty with HTTP status for empty bodies.
	want := schemas.ErrProviderResponseEmpty + " (HTTP 400)"
	if got := parsed.Error.Message; got != want {
		t.Fatalf("expected %q, got %q", want, got)
	}
}
// TestParseOpenAIError_WhitespaceProviderMessageFallsBack checks that a
// whitespace-only provider message is treated as empty and replaced by the
// status-based fallback.
func TestParseOpenAIError_WhitespaceProviderMessageFallsBack(t *testing.T) {
	var upstream fasthttp.Response
	upstream.SetStatusCode(fasthttp.StatusBadRequest)
	upstream.SetBodyString(`{"error":{"message":" ","type":"invalid_request_error"}}`)

	parsed := ParseOpenAIError(&upstream)
	if parsed == nil || parsed.Error == nil {
		t.Fatal("expected non-nil error response")
	}

	if got := parsed.Error.Message; got != "provider API error (status 400)" {
		t.Fatalf("expected fallback message, got %q", got)
	}
}
// TestParseOpenAIError_DefaultStatusCodeFallsBackWithStatusNumber checks that
// the fallback message still carries a status number when the test never sets
// one explicitly on the response.
func TestParseOpenAIError_DefaultStatusCodeFallsBackWithStatusNumber(t *testing.T) {
	// fasthttp defaults zero-value response status code to 200.
	var upstream fasthttp.Response
	upstream.SetBodyString(`{"error":{"message":""}}`)

	parsed := ParseOpenAIError(&upstream)
	if parsed == nil || parsed.Error == nil {
		t.Fatal("expected non-nil error response")
	}

	if got := parsed.Error.Message; got != "provider API error (status 200)" {
		t.Fatalf("expected fallback message with default status, got %q", got)
	}
}

View File

@@ -0,0 +1,124 @@
package openai
import (
"bytes"
"time"
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
"github.com/maximhq/bifrost/core/schemas"
)
// OpenAI File API Types
// OpenAIFileResponse represents an OpenAI file response.
// It is the source type for ToBifrostFileUploadResponse and
// ToBifrostFileRetrieveResponse, which copy these fields into the Bifrost
// response types.
type OpenAIFileResponse struct {
// ID is the file identifier, serialized as "id".
ID string `json:"id"`
// Object is serialized as "object"; presumably the API object type name — TODO confirm.
Object string `json:"object"`
// Bytes is serialized as "bytes"; presumably the file size in bytes — TODO confirm.
Bytes int64 `json:"bytes"`
// CreatedAt is serialized as "created_at"; presumably a Unix timestamp — TODO confirm unit.
CreatedAt int64 `json:"created_at"`
// Filename is serialized as "filename".
Filename string `json:"filename"`
// Purpose is serialized as "purpose" using the shared schemas.FilePurpose type.
Purpose schemas.FilePurpose `json:"purpose"`
// Status is the raw provider status string; it is mapped through
// ToBifrostFileStatus when converting to Bifrost responses. Omitted when empty.
Status string `json:"status,omitempty"`
// StatusDetails is optional and omitted from JSON when nil.
StatusDetails *string `json:"status_details,omitempty"`
}
// OpenAIFileListResponse represents the response from listing files.
type OpenAIFileListResponse struct {
// Object is serialized as "object".
Object string `json:"object"`
// Data holds the listed files, serialized as "data".
Data []OpenAIFileResponse `json:"data"`
// HasMore is serialized as "has_more" and omitted when false;
// presumably signals that further pages exist — TODO confirm.
HasMore bool `json:"has_more,omitempty"`
}
// OpenAIFileDeleteResponse represents the response from deleting a file.
type OpenAIFileDeleteResponse struct {
// ID is the file identifier, serialized as "id".
ID string `json:"id"`
// Object is serialized as "object".
Object string `json:"object"`
// Deleted is serialized as "deleted" (always emitted, even when false).
Deleted bool `json:"deleted"`
}
// ToBifrostFileStatus maps a provider status string onto the Bifrost file
// status enumeration.
//
// Recognized values are grouped as follows:
//   - "uploaded"                  -> FileStatusUploaded
//   - "processed", "completed"    -> FileStatusProcessed
//   - "processing", "in_progress" -> FileStatusProcessing
//   - "error", "failed"           -> FileStatusError
//   - "deleted", "cancelled"      -> FileStatusDeleted
//
// Any other value is passed through unchanged as a schemas.FileStatus.
func ToBifrostFileStatus(status string) schemas.FileStatus {
	if status == "uploaded" {
		return schemas.FileStatusUploaded
	}
	if status == "processed" || status == "completed" {
		return schemas.FileStatusProcessed
	}
	if status == "processing" || status == "in_progress" {
		return schemas.FileStatusProcessing
	}
	if status == "error" || status == "failed" {
		return schemas.FileStatusError
	}
	if status == "deleted" || status == "cancelled" {
		return schemas.FileStatusDeleted
	}
	// Unknown statuses are forwarded verbatim.
	return schemas.FileStatus(status)
}
// ToBifrostFileUploadResponse builds a Bifrost file upload response from the
// OpenAI file response.
//
// File fields are copied as-is, the status is translated with
// ToBifrostFileStatus, and the storage backend is set to schemas.FileStorageAPI.
// latency is recorded in milliseconds. rawRequest and rawResponse are attached
// to ExtraFields only when the corresponding sendBack flag is true.
func (r *OpenAIFileResponse) ToBifrostFileUploadResponse(latency time.Duration, sendBackRawRequest bool, sendBackRawResponse bool, rawRequest interface{}, rawResponse interface{}) *schemas.BifrostFileUploadResponse {
	extra := schemas.BifrostResponseExtraFields{
		Latency: latency.Milliseconds(),
	}
	if sendBackRawRequest {
		extra.RawRequest = rawRequest
	}
	if sendBackRawResponse {
		extra.RawResponse = rawResponse
	}

	return &schemas.BifrostFileUploadResponse{
		ID:             r.ID,
		Object:         r.Object,
		Bytes:          r.Bytes,
		CreatedAt:      r.CreatedAt,
		Filename:       r.Filename,
		Purpose:        r.Purpose,
		Status:         ToBifrostFileStatus(r.Status),
		StatusDetails:  r.StatusDetails,
		StorageBackend: schemas.FileStorageAPI,
		ExtraFields:    extra,
	}
}
// ToBifrostFileRetrieveResponse builds a Bifrost file retrieve response from
// the OpenAI file response.
//
// File fields are copied as-is, the status is translated with
// ToBifrostFileStatus, and the storage backend is set to schemas.FileStorageAPI.
// latency is recorded in milliseconds. rawRequest and rawResponse are attached
// to ExtraFields only when the corresponding sendBack flag is true.
// providerName is accepted but not used by this function.
func (r *OpenAIFileResponse) ToBifrostFileRetrieveResponse(providerName schemas.ModelProvider, latency time.Duration, sendBackRawRequest bool, sendBackRawResponse bool, rawRequest interface{}, rawResponse interface{}) *schemas.BifrostFileRetrieveResponse {
	extra := schemas.BifrostResponseExtraFields{
		Latency: latency.Milliseconds(),
	}
	if sendBackRawRequest {
		extra.RawRequest = rawRequest
	}
	if sendBackRawResponse {
		extra.RawResponse = rawResponse
	}

	return &schemas.BifrostFileRetrieveResponse{
		ID:             r.ID,
		Object:         r.Object,
		Bytes:          r.Bytes,
		CreatedAt:      r.CreatedAt,
		Filename:       r.Filename,
		Purpose:        r.Purpose,
		Status:         ToBifrostFileStatus(r.Status),
		StatusDetails:  r.StatusDetails,
		StorageBackend: schemas.FileStorageAPI,
		ExtraFields:    extra,
	}
}
// ConvertRequestsToJSONL serializes batch request items as JSON Lines: each
// item is encoded with providerUtils.MarshalSorted and followed by a single
// newline byte. The first encoding failure aborts the conversion and is
// returned with a nil payload.
func ConvertRequestsToJSONL(requests []schemas.BatchRequestItem) ([]byte, error) {
	var out bytes.Buffer
	for idx := range requests {
		encoded, encodeErr := providerUtils.MarshalSorted(requests[idx])
		if encodeErr != nil {
			return nil, encodeErr
		}
		out.Write(encoded)
		out.WriteByte('\n')
	}
	return out.Bytes(), nil
}

View File

@@ -0,0 +1,361 @@
package openai
import (
"fmt"
"mime/multipart"
"net/http"
"strconv"
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
"github.com/maximhq/bifrost/core/schemas"
)
// ToOpenAIImageGenerationRequest builds the OpenAI wire representation of a
// Bifrost image generation request.
//
// It returns nil when the request, its input, or the prompt is missing.
// Parameters are copied by value and then pruned per provider: the xAI filter
// runs for schemas.XAI, the OpenAI filter runs for schemas.OpenAI and
// schemas.Azure, and other providers are left untouched. ExtraParams are
// attached after filtering.
func ToOpenAIImageGenerationRequest(bifrostReq *schemas.BifrostImageGenerationRequest) *OpenAIImageGenerationRequest {
	if bifrostReq == nil {
		return nil
	}
	if input := bifrostReq.Input; input == nil || input.Prompt == "" {
		return nil
	}

	out := &OpenAIImageGenerationRequest{
		Model:  bifrostReq.Model,
		Prompt: bifrostReq.Input.Prompt,
	}

	params := bifrostReq.Params
	if params != nil {
		out.ImageGenerationParameters = *params
	}

	// Strip the parameters the filter for this provider removes.
	if provider := bifrostReq.Provider; provider == schemas.XAI {
		filterXAISpecificParameters(out)
	} else if provider == schemas.OpenAI || provider == schemas.Azure {
		filterOpenAISpecificParameters(out)
	}

	if params != nil {
		out.ExtraParams = params.ExtraParams
	}
	return out
}
// filterXAISpecificParameters clears the Quality, Style, Size and
// OutputCompression image generation parameters on req. It is applied to
// requests whose provider is schemas.XAI.
func filterXAISpecificParameters(req *OpenAIImageGenerationRequest) {
	params := &req.ImageGenerationParameters
	params.OutputCompression = nil
	params.Size = nil
	params.Style = nil
	params.Quality = nil
}
// filterOpenAISpecificParameters clears the Seed, NumInferenceSteps and
// NegativePrompt parameters on req. It is applied to requests whose provider
// is schemas.OpenAI or schemas.Azure.
func filterOpenAISpecificParameters(req *OpenAIImageGenerationRequest) {
	req.NegativePrompt = nil
	req.NumInferenceSteps = nil
	req.ImageGenerationParameters.Seed = nil
}
// ToBifrostImageGenerationRequest converts an OpenAI image generation request
// into its Bifrost form. A nil receiver yields nil.
//
// The provider and model come from parsing request.Model, with the default
// provider obtained from providerUtils.CheckAndSetDefaultProvider(ctx,
// schemas.OpenAI). Params points at the receiver's embedded parameters rather
// than a copy.
func (request *OpenAIImageGenerationRequest) ToBifrostImageGenerationRequest(ctx *schemas.BifrostContext) *schemas.BifrostImageGenerationRequest {
	if request == nil {
		return nil
	}

	defaultProvider := providerUtils.CheckAndSetDefaultProvider(ctx, schemas.OpenAI)
	provider, model := schemas.ParseModelString(request.Model, defaultProvider)

	converted := &schemas.BifrostImageGenerationRequest{
		Provider: provider,
		Model:    model,
		Input:    &schemas.ImageGenerationInput{Prompt: request.Prompt},
		Params:   &request.ImageGenerationParameters,
	}
	converted.Fallbacks = schemas.ParseFallbacks(request.Fallbacks)
	return converted
}
// ToBifrostImageEditRequest converts an OpenAI image edit request into its
// Bifrost form. A nil receiver yields nil.
//
// The provider and model come from parsing request.Model, with the default
// provider obtained from providerUtils.CheckAndSetDefaultProvider(ctx,
// schemas.OpenAI). Input is shared with the receiver and Params points at the
// receiver's embedded parameters rather than a copy.
func (request *OpenAIImageEditRequest) ToBifrostImageEditRequest(ctx *schemas.BifrostContext) *schemas.BifrostImageEditRequest {
	if request == nil {
		return nil
	}

	defaultProvider := providerUtils.CheckAndSetDefaultProvider(ctx, schemas.OpenAI)
	provider, model := schemas.ParseModelString(request.Model, defaultProvider)

	converted := &schemas.BifrostImageEditRequest{
		Provider: provider,
		Model:    model,
		Input:    request.Input,
		Params:   &request.ImageEditParameters,
	}
	converted.Fallbacks = schemas.ParseFallbacks(request.Fallbacks)
	return converted
}
// ToBifrostImageVariationRequest converts an OpenAI image variation request
// into its Bifrost form. A nil receiver yields nil.
//
// The provider and model come from parsing request.Model, with the default
// provider obtained from providerUtils.CheckAndSetDefaultProvider(ctx,
// schemas.OpenAI). Input is shared with the receiver and Params points at the
// receiver's embedded parameters rather than a copy.
func (request *OpenAIImageVariationRequest) ToBifrostImageVariationRequest(ctx *schemas.BifrostContext) *schemas.BifrostImageVariationRequest {
	if request == nil {
		return nil
	}

	defaultProvider := providerUtils.CheckAndSetDefaultProvider(ctx, schemas.OpenAI)
	provider, model := schemas.ParseModelString(request.Model, defaultProvider)

	converted := &schemas.BifrostImageVariationRequest{
		Provider: provider,
		Model:    model,
		Input:    request.Input,
		Params:   &request.ImageVariationParameters,
	}
	converted.Fallbacks = schemas.ParseFallbacks(request.Fallbacks)
	return converted
}
// ToOpenAIImageEditRequest builds the OpenAI wire representation of a Bifrost
// image edit request.
//
// It returns nil when the request or its input is nil, when no images are
// supplied (nil or empty slice), or when the prompt is empty. Model and Input
// are carried over; when Params is set its value is copied into the embedded
// ImageEditParameters and its ExtraParams are attached to the request.
func ToOpenAIImageEditRequest(bifrostReq *schemas.BifrostImageEditRequest) *OpenAIImageEditRequest {
	if bifrostReq == nil || bifrostReq.Input == nil {
		return nil
	}
	// len() also rejects an empty non-nil slice, which would otherwise produce
	// an edit request with no image parts.
	if len(bifrostReq.Input.Images) == 0 || bifrostReq.Input.Prompt == "" {
		return nil
	}

	req := &OpenAIImageEditRequest{
		Model: bifrostReq.Model,
		Input: bifrostReq.Input,
	}

	if params := bifrostReq.Params; params != nil {
		req.ImageEditParameters = *params
		req.ExtraParams = params.ExtraParams
	}

	return req
}
// parseImageEditFormDataBodyFromRequest serializes an image edit request as
// multipart form data onto writer and closes the writer on success.
//
// Parts are written in this order: the required "model" and "prompt" fields,
// "stream" (only when streaming is requested), the optional scalar parameters
// that are set, one "image[]" file part per input image, and finally the
// optional "mask" file part.
//
// It returns a BifrostError when the request or its input is nil, or when any
// write to the multipart writer fails; nil otherwise. providerName is accepted
// but not used by this function.
func parseImageEditFormDataBodyFromRequest(writer *multipart.Writer, openaiReq *OpenAIImageEditRequest, providerName schemas.ModelProvider) *schemas.BifrostError {
	// Guard against a nil request/input: the prompt and images are read from
	// Input below and would otherwise cause a nil dereference.
	if openaiReq == nil || openaiReq.Input == nil {
		return providerUtils.NewBifrostOperationError("image edit input is required", nil)
	}

	// writeField writes one text field and wraps any failure in the standard
	// "failed to write <name> field" operation error.
	writeField := func(name, value string) *schemas.BifrostError {
		if err := writer.WriteField(name, value); err != nil {
			return providerUtils.NewBifrostOperationError("failed to write "+name+" field", err)
		}
		return nil
	}

	// Required fields.
	if bifrostErr := writeField("model", openaiReq.Model); bifrostErr != nil {
		return bifrostErr
	}
	if bifrostErr := writeField("prompt", openaiReq.Input.Prompt); bifrostErr != nil {
		return bifrostErr
	}

	// Add stream field when requesting streaming.
	if openaiReq.Stream != nil && *openaiReq.Stream {
		if bifrostErr := writeField("stream", "true"); bifrostErr != nil {
			return bifrostErr
		}
	}

	// Add optional parameters before file parts so routing metadata arrives first upstream.
	// Each entry carries either a string or an int pointer; unset entries are skipped.
	optionalFields := []struct {
		name   string
		text   *string
		number *int
	}{
		{name: "n", number: openaiReq.N},
		{name: "size", text: openaiReq.Size},
		{name: "response_format", text: openaiReq.ResponseFormat},
		{name: "quality", text: openaiReq.Quality},
		{name: "background", text: openaiReq.Background},
		{name: "input_fidelity", text: openaiReq.InputFidelity},
		{name: "partial_images", number: openaiReq.PartialImages},
		{name: "output_format", text: openaiReq.OutputFormat},
		{name: "output_compression", number: openaiReq.OutputCompression},
		{name: "user", text: openaiReq.User},
	}
	for _, field := range optionalFields {
		var value string
		switch {
		case field.text != nil:
			value = *field.text
		case field.number != nil:
			value = strconv.Itoa(*field.number)
		default:
			continue
		}
		if bifrostErr := writeField(field.name, value); bifrostErr != nil {
			return bifrostErr
		}
	}

	// Add image[] fields (one for each image).
	for i, imageInput := range openaiReq.Input.Images {
		// Detect the MIME type from the image bytes; fall back to PNG when the
		// content type is undetectable or generic.
		mimeType := http.DetectContentType(imageInput.Image)
		if mimeType == "" || mimeType == "application/octet-stream" {
			mimeType = "image/png"
		}

		// Determine filename based on MIME type.
		var filename string
		switch mimeType {
		case "image/jpeg":
			filename = fmt.Sprintf("image%d.jpg", i)
		case "image/webp":
			filename = fmt.Sprintf("image%d.webp", i)
		default:
			filename = fmt.Sprintf("image%d.png", i)
		}

		// Create form part with proper Content-Type header (not CreateFormFile
		// which defaults to application/octet-stream).
		part, err := writer.CreatePart(map[string][]string{
			"Content-Disposition": {fmt.Sprintf(`form-data; name="image[]"; filename="%s"`, filename)},
			"Content-Type":        {mimeType},
		})
		if err != nil {
			return providerUtils.NewBifrostOperationError(fmt.Sprintf("failed to create form part for image %d", i), err)
		}
		if _, err := part.Write(imageInput.Image); err != nil {
			return providerUtils.NewBifrostOperationError(fmt.Sprintf("failed to write image %d data", i), err)
		}
	}

	// Add mask if present.
	if len(openaiReq.Mask) > 0 {
		// Anything other than PNG/JPEG/WebP is sent as PNG.
		maskMimeType := http.DetectContentType(openaiReq.Mask)
		if maskMimeType != "image/png" && maskMimeType != "image/jpeg" && maskMimeType != "image/webp" {
			maskMimeType = "image/png"
		}

		var maskFilename string
		switch maskMimeType {
		case "image/jpeg":
			maskFilename = "mask.jpg"
		case "image/webp":
			maskFilename = "mask.webp"
		default:
			maskFilename = "mask.png"
		}

		// Create form part with proper Content-Type header.
		maskPart, err := writer.CreatePart(map[string][]string{
			"Content-Disposition": {`form-data; name="mask"; filename="` + maskFilename + `"`},
			"Content-Type":        {maskMimeType},
		})
		if err != nil {
			return providerUtils.NewBifrostOperationError("failed to create mask form part", err)
		}
		if _, err := maskPart.Write(openaiReq.Mask); err != nil {
			return providerUtils.NewBifrostOperationError("failed to write mask data", err)
		}
	}

	// Close the multipart writer so the trailing boundary is emitted.
	if err := writer.Close(); err != nil {
		return providerUtils.NewBifrostOperationError("failed to close multipart writer", err)
	}

	return nil
}
// ToOpenAIImageVariationRequest builds the OpenAI wire representation of a
// Bifrost image variation request.
//
// It returns nil when the request or its input is nil, or when the input image
// has no bytes. Model and Input are carried over; when Params is set its value
// is copied into the embedded ImageVariationParameters and its ExtraParams are
// attached to the request.
func ToOpenAIImageVariationRequest(bifrostReq *schemas.BifrostImageVariationRequest) *OpenAIImageVariationRequest {
	if bifrostReq == nil || bifrostReq.Input == nil {
		return nil
	}
	// A nil slice has length zero, so one length check covers both cases.
	if len(bifrostReq.Input.Image.Image) == 0 {
		return nil
	}

	out := &OpenAIImageVariationRequest{
		Model: bifrostReq.Model,
		Input: bifrostReq.Input,
	}

	if params := bifrostReq.Params; params != nil {
		out.ImageVariationParameters = *params
		out.ExtraParams = params.ExtraParams
	}

	return out
}
// parseImageVariationFormDataBodyFromRequest serializes an image variation
// request as multipart form data onto writer and closes the writer on success.
//
// The "model" field is written first; the request is then rejected with an
// "image is required" error if it carries no image bytes. The optional "n",
// "response_format", "size" and "user" fields follow (only those that are
// set), and the image itself is written last as the "image" file part.
//
// It returns a BifrostError on a missing image or any write failure, nil
// otherwise. providerName is accepted but not used by this function.
func parseImageVariationFormDataBodyFromRequest(writer *multipart.Writer, openaiReq *OpenAIImageVariationRequest, providerName schemas.ModelProvider) *schemas.BifrostError {
	// put writes one text field, wrapping failures in an operation error.
	put := func(name, value string) *schemas.BifrostError {
		if writeErr := writer.WriteField(name, value); writeErr != nil {
			return providerUtils.NewBifrostOperationError("failed to write "+name+" field", writeErr)
		}
		return nil
	}

	// Model is required and goes first.
	if bifrostErr := put("model", openaiReq.Model); bifrostErr != nil {
		return bifrostErr
	}

	// The image is required; a nil slice has length zero.
	if openaiReq.Input == nil || len(openaiReq.Input.Image.Image) == 0 {
		return providerUtils.NewBifrostOperationError("image is required", nil)
	}
	imageBytes := openaiReq.Input.Image.Image

	// Add optional parameters before the image part so metadata arrives first upstream.
	if openaiReq.N != nil {
		if bifrostErr := put("n", strconv.Itoa(*openaiReq.N)); bifrostErr != nil {
			return bifrostErr
		}
	}
	optionalText := []struct {
		name  string
		value *string
	}{
		{"response_format", openaiReq.ResponseFormat},
		{"size", openaiReq.Size},
		{"user", openaiReq.User},
	}
	for _, field := range optionalText {
		if field.value == nil {
			continue
		}
		if bifrostErr := put(field.name, *field.value); bifrostErr != nil {
			return bifrostErr
		}
	}

	// Sniff the content type from the bytes; generic or empty results default to PNG.
	contentType := http.DetectContentType(imageBytes)
	if contentType == "" || contentType == "application/octet-stream" {
		contentType = "image/png"
	}

	const uploadName = "image"
	header := map[string][]string{
		"Content-Disposition": {fmt.Sprintf(`form-data; name="image"; filename="%s"`, uploadName)},
		"Content-Type":        {contentType},
	}
	imagePart, partErr := writer.CreatePart(header)
	if partErr != nil {
		return providerUtils.NewBifrostOperationError("failed to create image part", partErr)
	}
	if _, writeErr := imagePart.Write(imageBytes); writeErr != nil {
		return providerUtils.NewBifrostOperationError("failed to write image data", writeErr)
	}

	// Closing the writer emits the terminating boundary.
	if closeErr := writer.Close(); closeErr != nil {
		return providerUtils.NewBifrostOperationError("failed to close multipart writer", closeErr)
	}
	return nil
}

View File

@@ -0,0 +1,156 @@
// Package openai provides the OpenAI provider implementation for the Bifrost framework.
package openai
import (
"net/http"
"time"
"github.com/bytedance/sonic"
"github.com/valyala/fasthttp"
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
schemas "github.com/maximhq/bifrost/core/schemas"
)
// largePayloadResult holds the lightweight metadata extracted from a large payload passthrough.
// It is produced by handleOpenAILargePayloadPassthrough and finalizeOpenAIResponse.
type largePayloadResult struct {
// Usage is the token usage parsed from the response bytes (or from the
// response preview on the large-response path); nil when none could be extracted.
Usage *schemas.BifrostLLMUsage
// Latency is the request latency in milliseconds (populated from latency.Milliseconds()).
Latency int64
ResponseBody []byte // non-nil for request types that need the raw upstream response (transcription, speech, etc.)
}
// setStreamingRequestBody installs the request body for streaming handlers.
//
// It first offers the request to
// providerUtils.ApplyLargePayloadRequestBodyWithModelNormalization; when that
// helper reports it handled the body (large payload mode), nothing else is
// done. Otherwise the already-marshaled jsonBody is set as the body.
func setStreamingRequestBody(ctx *schemas.BifrostContext, req *fasthttp.Request, jsonBody []byte, providerName schemas.ModelProvider) {
	if providerUtils.ApplyLargePayloadRequestBodyWithModelNormalization(ctx, req, providerName) {
		// Large payload mode took care of the body.
		return
	}
	req.SetBody(jsonBody)
}
// handleOpenAILargePayloadPassthrough handles a complete large payload request-response cycle
// for OpenAI-compatible providers. When large payload mode is active, it streams the request
// body to upstream and optionally streams the response back without full materialization.
//
// The request is always sent as a POST to url. An "Authorization: Bearer ..." header is added
// only when the key value is non-empty. Any response status other than 200 is treated as an
// error and parsed with ParseOpenAIError.
//
// Returns (result, nil, true) on success, (nil, err, true) on error, or (nil, nil, false) when
// large payload mode is not active and the caller should use the normal path. (nil, nil, false)
// is also returned when ApplyLargePayloadRequestBodyWithModelNormalization reports that it did
// not set a request body.
func handleOpenAILargePayloadPassthrough(
ctx *schemas.BifrostContext,
client *fasthttp.Client,
url string,
key schemas.Key,
extraHeaders map[string]string,
providerName schemas.ModelProvider,
logger schemas.Logger,
) (*largePayloadResult, *schemas.BifrostError, bool) {
// A missing or non-bool context value leaves isLargePayload false.
isLargePayload, _ := ctx.Value(schemas.BifrostContextKeyLargePayloadMode).(bool)
if !isLargePayload {
return nil, nil, false
}
req := fasthttp.AcquireRequest()
resp := fasthttp.AcquireResponse()
defer fasthttp.ReleaseRequest(req)
// resp lifecycle: managed manually when large response streaming is active
// (every early return below releases resp explicitly; otherwise ownership
// passes to finalizeOpenAIResponse).
providerUtils.SetExtraHeaders(ctx, req, extraHeaders, nil)
req.SetRequestURI(url)
req.Header.SetMethod(http.MethodPost)
if key.Value.GetValue() != "" {
req.Header.Set("Authorization", "Bearer "+key.Value.GetValue())
}
// Rewrite model prefix and stream request body to upstream.
// Sets content-type from context; falls back to JSON if not set.
if !providerUtils.ApplyLargePayloadRequestBodyWithModelNormalization(ctx, req, providerName) {
fasthttp.ReleaseResponse(resp)
return nil, nil, false
}
if len(req.Header.ContentType()) == 0 {
req.Header.SetContentType("application/json")
}
// Choose client: enable response body streaming when threshold is configured
activeClient := providerUtils.PrepareResponseStreaming(ctx, client, resp)
latency, bifrostErr, wait := providerUtils.MakeRequestWithContext(ctx, activeClient, req, resp)
// NOTE(review): wait is called before the error check; presumably it blocks until
// the request machinery is done with req/resp so they can be released safely — confirm.
wait()
if bifrostErr != nil {
fasthttp.ReleaseResponse(resp)
return nil, bifrostErr, true
}
// Extract provider response headers early so they're available on error and large-response paths
if headers := providerUtils.ExtractProviderResponseHeaders(resp); headers != nil {
ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, headers)
}
// Error responses are always small — materialize stream body for error parsing
if resp.StatusCode() != fasthttp.StatusOK {
providerUtils.MaterializeStreamErrorBody(ctx, resp)
parsedErr := ParseOpenAIError(resp)
fasthttp.ReleaseResponse(resp)
return nil, parsedErr, true
}
// Delegate response body handling (large detection + resp lifecycle) to finalizeOpenAIResponse
// (resp must not be released here after this call; finalizeOpenAIResponse owns it).
body, result, respErr := finalizeOpenAIResponse(ctx, resp, latency, providerName, logger)
if respErr != nil {
return nil, respErr, true
}
// A non-nil result means the large-response path was taken; it carries no ResponseBody.
if result != nil {
return result, nil, true
}
// Normal path — extract usage from raw bytes (passthrough doesn't parse structured response)
usage := extractOpenAIUsageFromBytes(body)
return &largePayloadResult{Usage: usage, Latency: latency.Milliseconds(), ResponseBody: body}, nil, true
}
// finalizeOpenAIResponse processes the response body, with optional detection
// of large responses, by delegating to
// providerUtils.FinalizeResponseWithLargeDetection.
//
// The function takes ownership of resp: callers must not release it
// themselves once this has been called.
//
// Outcomes:
//   - (body, nil, nil): normal path; body is ready for parsing and resp has
//     been released.
//   - (nil, result, nil): large response detected; usage is extracted from
//     the response preview stored in the context and resp is NOT released
//     here (it stays owned by the large-response reader).
//   - (nil, nil, err): the shared utility failed; resp has been released.
//
// providerName is accepted but not used by this function.
func finalizeOpenAIResponse(
	ctx *schemas.BifrostContext,
	resp *fasthttp.Response,
	latency time.Duration,
	providerName schemas.ModelProvider,
	logger schemas.Logger,
) ([]byte, *largePayloadResult, *schemas.BifrostError) {
	payload, large, finalizeErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, logger)

	switch {
	case finalizeErr != nil:
		fasthttp.ReleaseResponse(resp)
		return nil, nil, finalizeErr

	case large:
		// The preview was stored in context by FinalizeResponseWithLargeDetection;
		// a missing or non-string value yields an empty preview.
		preview, _ := ctx.Value(schemas.BifrostContextKeyLargePayloadResponsePreview).(string)
		// resp owned by LargeResponseReader in context — don't release
		return nil, &largePayloadResult{
			Usage:   extractOpenAIUsageFromBytes([]byte(preview)),
			Latency: latency.Milliseconds(),
		}, nil

	default:
		// Normal path — body already copied by shared utility, safe to release resp
		fasthttp.ReleaseResponse(resp)
		return payload, nil, nil
	}
}
// extractOpenAIUsageFromBytes pulls the top-level "usage" object out of raw
// response bytes with sonic.Get and parses it with
// providerUtils.ParseOpenAIUsageFromBytes.
//
// It returns nil when the "usage" key cannot be looked up, when its raw JSON
// cannot be obtained, or when that raw JSON is empty.
func extractOpenAIUsageFromBytes(data []byte) *schemas.BifrostLLMUsage {
	usageNode, lookupErr := sonic.Get(data, "usage")
	if lookupErr != nil {
		return nil
	}

	if rawUsage, rawErr := usageNode.Raw(); rawErr == nil && rawUsage != "" {
		return providerUtils.ParseOpenAIUsageFromBytes([]byte(rawUsage))
	}
	return nil
}

View File

@@ -0,0 +1,80 @@
package openai
import (
"strings"
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
"github.com/maximhq/bifrost/core/schemas"
)
// ToBifrostListModelsResponse converts an OpenAI list-models response into a
// Bifrost list-models response. A nil receiver yields nil.
//
// Filtering is delegated to a providerUtils.ListModelsPipeline configured with
// the allow list, block list, aliases and the unfiltered flag. If the pipeline
// signals an early exit, an empty response is returned. Otherwise every result
// the pipeline yields for a model becomes an entry whose ID is
// "<providerKey>/<resolved ID>", carrying the alias when one is set. Finally,
// the models returned by the pipeline's BackfillModels are appended, given the
// set of lower-cased resolved IDs already included.
func (response *OpenAIListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, allowedModels schemas.WhiteList, blacklistedModels schemas.BlackList, aliases map[string]string, unfiltered bool) *schemas.BifrostListModelsResponse {
	if response == nil {
		return nil
	}

	out := &schemas.BifrostListModelsResponse{
		Data: make([]schemas.Model, 0, len(response.Data)),
	}

	filter := &providerUtils.ListModelsPipeline{
		AllowedModels:     allowedModels,
		BlacklistedModels: blacklistedModels,
		Aliases:           aliases,
		Unfiltered:        unfiltered,
		ProviderKey:       providerKey,
		MatchFns:          providerUtils.DefaultMatchFns(),
	}
	if filter.ShouldEarlyExit() {
		return out
	}

	// seen records lower-cased resolved IDs so backfill can skip them.
	seen := make(map[string]bool)
	for _, upstream := range response.Data {
		matches := filter.FilterModel(upstream.ID)
		for _, match := range matches {
			item := schemas.Model{
				ID:            string(providerKey) + "/" + match.ResolvedID,
				Created:       upstream.Created,
				OwnedBy:       schemas.Ptr(upstream.OwnedBy),
				ContextLength: upstream.ContextWindow,
			}
			if match.AliasValue != "" {
				item.Alias = schemas.Ptr(match.AliasValue)
			}
			out.Data = append(out.Data, item)
			seen[strings.ToLower(match.ResolvedID)] = true
		}
	}

	backfilled := filter.BackfillModels(seen)
	out.Data = append(out.Data, backfilled...)
	return out
}
// ToOpenAIListModelsResponse converts a Bifrost list-models response into the
// OpenAI list-models shape. A nil input yields nil.
//
// Each model keeps its ID and Created value, gets the fixed object type
// "model", and has OwnedBy filled in only when the Bifrost model supplies one.
func ToOpenAIListModelsResponse(response *schemas.BifrostListModelsResponse) *OpenAIListModelsResponse {
	if response == nil {
		return nil
	}

	converted := make([]OpenAIModel, 0, len(response.Data))
	for i := range response.Data {
		src := &response.Data[i]
		dst := OpenAIModel{
			ID:      src.ID,
			Object:  "model",
			Created: src.Created, // a nil Created simply stays nil
		}
		if src.OwnedBy != nil {
			dst.OwnedBy = *src.OwnedBy
		}
		converted = append(converted, dst)
	}

	return &OpenAIListModelsResponse{Data: converted}
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,119 @@
package openai_test
import (
"os"
"strings"
"testing"
"github.com/maximhq/bifrost/core/internal/llmtests"
"github.com/maximhq/bifrost/core/schemas"
)
// TestOpenAI runs the shared llmtests comprehensive suite against the OpenAI
// provider. The test is skipped when OPENAI_API_KEY is unset or blank.
func TestOpenAI(t *testing.T) {
t.Parallel()
// Without an API key there is nothing to test against; skip rather than fail.
if strings.TrimSpace(os.Getenv("OPENAI_API_KEY")) == "" {
t.Skip("Skipping OpenAI tests because OPENAI_API_KEY is not set")
}
client, ctx, cancel, err := llmtests.SetupTest()
if err != nil {
t.Fatalf("Error initializing test setup: %v", err)
}
// Deferred calls run in reverse order: the client is shut down first, then
// the context returned by SetupTest is cancelled.
defer cancel()
defer client.Shutdown()
// Model selection per capability, followed by the scenario toggles that
// decide which parts of the comprehensive suite run for this provider.
testConfig := llmtests.ComprehensiveTestConfig{
Provider: schemas.OpenAI,
TextModel: "gpt-3.5-turbo-instruct",
ChatModel: "gpt-4o",
PromptCachingModel: "gpt-4.1",
Fallbacks: []schemas.Fallback{
{Provider: schemas.OpenAI, Model: "gpt-4o"},
},
VisionModel: "gpt-4o",
EmbeddingModel: "text-embedding-3-small",
TranscriptionModel: "gpt-4o-transcribe",
TranscriptionFallbacks: []schemas.Fallback{
{Provider: schemas.OpenAI, Model: "whisper-1"},
},
SpeechSynthesisModel: "gpt-4o-mini-tts",
ReasoningModel: "o4-mini", // o4-mini properly returns both reasoning items and message output
ImageGenerationModel: "gpt-image-1",
ImageEditModel: "gpt-image-1",
ImageVariationModel: "", // dall-e-2 is deprecated and no other OpenAI model supports image variations
VideoGenerationModel: "sora-2",
ChatAudioModel: "gpt-4o-mini-audio-preview",
PassthroughModel: "gpt-4o",
Scenarios: llmtests.TestScenarios{
TextCompletion: true,
TextCompletionStream: true,
SimpleChat: true,
CompletionStream: true,
MultiTurnConversation: true,
ToolCalls: true,
ToolCallsStreaming: true,
MultipleToolCalls: true,
MultipleToolCallsStreaming: true,
End2EndToolCalling: true,
AutomaticFunctionCall: true,
WebSearchTool: true,
ImageURL: true,
ImageBase64: true,
MultipleImages: true,
FileBase64: true,
FileURL: true,
CompleteEnd2End: true,
SpeechSynthesis: true,
SpeechSynthesisStream: true,
Transcription: true,
TranscriptionStream: true,
Embedding: true,
Reasoning: true,
ListModels: true,
ImageGeneration: true,
ImageGenerationStream: true,
ImageEdit: true,
ImageEditStream: true,
ImageVariation: false, // dall-e-2 is deprecated and no other OpenAI model supports image variations
VideoGeneration: false, // disabled for now because of long running operations
// The remaining video scenarios are switched off together with VideoGeneration.
VideoRetrieve: false,
VideoRemix: false,
VideoDownload: false,
VideoList: false,
VideoDelete: false,
BatchCreate: true,
BatchList: true,
BatchRetrieve: true,
BatchCancel: true,
BatchResults: true,
FileUpload: true,
FileList: true,
FileRetrieve: true,
FileDelete: true,
FileContent: true,
FileBatchInput: true,
CountTokens: true,
ChatAudio: true,
StructuredOutputs: true, // Structured outputs with nullable enum support
ContainerCreate: true,
ContainerList: true,
ContainerRetrieve: true,
ContainerDelete: true,
ContainerFileCreate: true,
ContainerFileList: true,
ContainerFileRetrieve: true,
ContainerFileContent: true,
ContainerFileDelete: true,
PromptCaching: true,
PassthroughAPI: true,
WebSocketResponses: true,
// Realtime is disabled here even though RealtimeModel is configured below.
Realtime: false,
},
RealtimeModel: "gpt-4o-realtime-preview",
}
// Run the whole suite as a single named subtest.
t.Run("OpenAITests", func(t *testing.T) {
llmtests.RunAllComprehensiveTests(t, client, ctx, testConfig)
})
}

View File

@@ -0,0 +1,121 @@
package openai
import (
"bytes"
"mime/multipart"
"testing"
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
schemas "github.com/maximhq/bifrost/core/schemas"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestPayloadOrdering_OpenAIChatRequest(t *testing.T) {
req := &OpenAIChatRequest{
Model: "gpt-4o",
Messages: []OpenAIMessage{
{
Role: schemas.ChatMessageRoleUser,
Content: &schemas.ChatMessageContent{ContentStr: schemas.Ptr("hello")},
},
},
ChatParameters: schemas.ChatParameters{
Temperature: schemas.Ptr(0.7),
Tools: []schemas.ChatTool{
{
Type: "function",
Function: &schemas.ChatToolFunction{
Name: "get_weather",
Description: schemas.Ptr("Get weather"),
Parameters: &schemas.ToolFunctionParameters{
Type: "object",
Properties: schemas.NewOrderedMapFromPairs(
schemas.KV("location", map[string]interface{}{"type": "string"}),
),
Required: []string{"location"},
},
},
},
},
Reasoning: &schemas.ChatReasoning{
Effort: schemas.Ptr("high"),
},
},
Stream: schemas.Ptr(true),
}
result, err := providerUtils.MarshalSorted(req)
require.NoError(t, err)
golden := `{"model":"gpt-4o","temperature":0.7,"stream":true,"messages":[{"role":"user","content":"hello"}],"tools":[{"type":"function","function":{"name":"get_weather","description":"Get weather","parameters":{"type":"object","properties":{"location":{"type":"string"}},"required":["location"]}}}],"reasoning_effort":"high"}`
assert.Equal(t, golden, string(result), "payload field ordering changed — if intentional, update the golden string")
// Determinism: 100 iterations must produce identical bytes
for i := 0; i < 100; i++ {
iter, err := providerUtils.MarshalSorted(req)
require.NoError(t, err)
assert.Equal(t, string(result), string(iter), "non-deterministic marshal output on iteration %d", i)
}
}
func TestParseImageEditFormDataBodyFromRequest_OrdersMetadataBeforeFiles(t *testing.T) {
req := &OpenAIImageEditRequest{
Model: "gpt-image-1",
Input: &schemas.ImageEditInput{
Prompt: "edit this",
Images: []schemas.ImageInput{{Image: []byte("image-one")}, {Image: []byte("image-two")}},
},
ImageEditParameters: schemas.ImageEditParameters{
N: schemas.Ptr(2),
Size: schemas.Ptr("1024x1024"),
ResponseFormat: schemas.Ptr("b64_json"),
Quality: schemas.Ptr("high"),
Background: schemas.Ptr("transparent"),
InputFidelity: schemas.Ptr("high"),
PartialImages: schemas.Ptr(1),
OutputFormat: schemas.Ptr("png"),
OutputCompression: schemas.Ptr(80),
User: schemas.Ptr("user-123"),
Mask: []byte("mask-image"),
},
Stream: schemas.Ptr(true),
}
var body bytes.Buffer
writer := multipart.NewWriter(&body)
require.Nil(t, parseImageEditFormDataBodyFromRequest(writer, req, schemas.OpenAI))
order := multipartPartOrder(t, writer.FormDataContentType(), body.Bytes())
assert.Equal(t,
[]string{"model", "prompt", "stream", "n", "size", "response_format", "quality", "background", "input_fidelity", "partial_images", "output_format", "output_compression", "user", "image[]", "image[]", "mask"},
order,
)
}
func TestParseImageVariationFormDataBodyFromRequest_OrdersMetadataBeforeFile(t *testing.T) {
req := &OpenAIImageVariationRequest{
Model: "gpt-image-1",
Input: &schemas.ImageVariationInput{
Image: schemas.ImageInput{Image: []byte("image-variation")},
},
ImageVariationParameters: schemas.ImageVariationParameters{
N: schemas.Ptr(3),
ResponseFormat: schemas.Ptr("url"),
Size: schemas.Ptr("512x512"),
User: schemas.Ptr("user-456"),
},
}
var body bytes.Buffer
writer := multipart.NewWriter(&body)
require.Nil(t, parseImageVariationFormDataBodyFromRequest(writer, req, schemas.OpenAI))
order := multipartPartOrder(t, writer.FormDataContentType(), body.Bytes())
assert.Equal(t,
[]string{"model", "n", "response_format", "size", "user", "image"},
order,
)
}

View File

@@ -0,0 +1,967 @@
package openai
import (
"bytes"
"encoding/json"
"fmt"
"mime/multipart"
"net/http"
"net/url"
"strings"
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
"github.com/maximhq/bifrost/core/schemas"
"github.com/valyala/fasthttp"
)
// SupportsRealtimeAPI reports whether this provider supports the Realtime API.
// It always returns true, since OpenAI natively supports the Realtime API.
func (provider *OpenAIProvider) SupportsRealtimeAPI() bool {
return true
}
// RealtimeWebSocketURL returns the WebSocket URL for the OpenAI Realtime API.
// Format: wss://api.openai.com/v1/realtime?model=<model>
//
// The URL is derived from the provider's configured base URL by mapping a
// leading "https://" to "wss://" or a leading "http://" to "ws://", then
// appending the realtime path with the query-escaped model. Only the scheme
// prefix is rewritten, so an "http://" or "https://" occurring later in the
// base URL (for example inside a proxy path) is left intact. A base URL with
// any other prefix is used unchanged. The key parameter is not used.
func (provider *OpenAIProvider) RealtimeWebSocketURL(key schemas.Key, model string) string {
	base := provider.networkConfig.BaseURL
	switch {
	case strings.HasPrefix(base, "https://"):
		base = "wss://" + strings.TrimPrefix(base, "https://")
	case strings.HasPrefix(base, "http://"):
		base = "ws://" + strings.TrimPrefix(base, "http://")
	}
	return base + "/v1/realtime?model=" + url.QueryEscape(model)
}
// RealtimeHeaders builds the header set for the OpenAI Realtime WebSocket
// connection: a bearer Authorization header from the key, followed by the
// provider's configured extra headers. Extra headers are applied last, so an
// extra header with the same name replaces the default value.
func (provider *OpenAIProvider) RealtimeHeaders(key schemas.Key) map[string]string {
	extra := provider.networkConfig.ExtraHeaders

	result := make(map[string]string, len(extra)+1)
	result["Authorization"] = "Bearer " + key.Value.GetValue()
	for name, value := range extra {
		result[name] = value
	}
	return result
}
// SupportsRealtimeWebRTC reports whether this provider supports the WebRTC
// SDP exchange. It always returns true for OpenAI.
func (provider *OpenAIProvider) SupportsRealtimeWebRTC() bool {
return true
}
// ExchangeRealtimeWebRTCSDP performs the GA SDP exchange via multipart POST to
// /v1/realtime/calls.
//
// The model is appended as a query parameter only when no session payload is
// supplied and the model is not blank; otherwise the bare path is used. The
// exchange itself is delegated to exchangeWebRTCSDP.
func (provider *OpenAIProvider) ExchangeRealtimeWebRTCSDP(
	ctx *schemas.BifrostContext,
	key schemas.Key,
	model string,
	sdp string,
	session json.RawMessage,
) (string, *schemas.BifrostError) {
	const callsPath = "/v1/realtime/calls"

	target := callsPath
	modelGiven := strings.TrimSpace(model) != ""
	if session == nil && modelGiven {
		target = callsPath + "?model=" + url.QueryEscape(model)
	}
	return provider.exchangeWebRTCSDP(ctx, key, target, sdp, session)
}
// ExchangeLegacyRealtimeWebRTCSDP performs the beta SDP exchange via multipart
// POST to /v1/realtime. It uses the same multipart format as the GA exchange
// but targets the legacy endpoint, always passing the query-escaped model in
// the URL.
func (provider *OpenAIProvider) ExchangeLegacyRealtimeWebRTCSDP(
	ctx *schemas.BifrostContext,
	key schemas.Key,
	sdp string,
	session json.RawMessage,
	model string,
) (string, *schemas.BifrostError) {
	legacyPath := "/v1/realtime?model=" + url.QueryEscape(model)
	return provider.exchangeWebRTCSDP(ctx, key, legacyPath, sdp, session)
}
// exchangeWebRTCSDP is the shared multipart SDP exchange implementation.
// Builds a multipart body with sdp + optional session, POSTs to the given path.
//
// On a 2xx upstream status the response body is returned as a string.
// Failures while building the multipart body are returned as 500
// "server_error" Bifrost errors, request failures are returned as reported by
// MakeRequestWithContext, and any non-2xx status is converted by
// realtimeWebRTCUpstreamError.
func (provider *OpenAIProvider) exchangeWebRTCSDP(
ctx *schemas.BifrostContext,
key schemas.Key,
path string,
sdp string,
session json.RawMessage,
) (string, *schemas.BifrostError) {
bodyBuf := &bytes.Buffer{}
writer := multipart.NewWriter(bodyBuf)
if err := writer.WriteField("sdp", sdp); err != nil {
return "", newRealtimeWebRTCSDPError(fasthttp.StatusInternalServerError, "server_error", "failed to encode upstream SDP body", err)
}
// The session field is optional and is sent as the raw JSON text.
if session != nil {
if err := writer.WriteField("session", string(session)); err != nil {
return "", newRealtimeWebRTCSDPError(fasthttp.StatusInternalServerError, "server_error", "failed to encode upstream session body", err)
}
}
// Close writes the terminating multipart boundary into bodyBuf.
if err := writer.Close(); err != nil {
return "", newRealtimeWebRTCSDPError(fasthttp.StatusInternalServerError, "server_error", "failed to finalize upstream SDP body", err)
}
// Request and response objects are acquired from fasthttp and released on return.
req := fasthttp.AcquireRequest()
resp := fasthttp.AcquireResponse()
defer fasthttp.ReleaseRequest(req)
defer fasthttp.ReleaseResponse(resp)
req.SetRequestURI(provider.buildRequestURL(ctx, path, schemas.RealtimeRequest))
req.Header.SetMethod(http.MethodPost)
req.Header.SetContentType(writer.FormDataContentType())
req.Header.Set("Authorization", "Bearer "+key.Value.GetValue())
// Extra headers are set after Authorization, so a configured header with
// the same name replaces it.
for k, v := range provider.networkConfig.ExtraHeaders {
req.Header.Set(k, v)
}
// Forward the caller's x-openai-agents-sdk header upstream when the
// context carries request headers and the value is non-empty.
if headers, _ := ctx.Value(schemas.BifrostContextKeyRequestHeaders).(map[string]string); headers != nil {
if agentsSDK := headers["x-openai-agents-sdk"]; agentsSDK != "" {
req.Header.Set("X-OpenAI-Agents-SDK", agentsSDK)
}
}
req.SetBody(bodyBuf.Bytes())
_, bifrostErr, wait := providerUtils.MakeRequestWithContext(ctx, provider.client, req, resp)
// NOTE(review): wait is deferred last, so it runs before req/resp are
// released; presumably it blocks until the request no longer uses them. Confirm.
defer wait()
if bifrostErr != nil {
return "", bifrostErr
}
answerBody := resp.Body()
// Anything outside the 2xx range is treated as an upstream handshake failure.
if resp.StatusCode() < fasthttp.StatusOK || resp.StatusCode() >= fasthttp.StatusMultipleChoices {
return "", provider.realtimeWebRTCUpstreamError(ctx, resp.StatusCode(), answerBody)
}
// string(...) copies the bytes, so the result stays valid after resp is released.
return string(answerBody), nil
}
// realtimeWebRTCUpstreamError builds the error returned when the upstream
// WebRTC handshake answers with a non-success status. The error always
// carries status 502 and type "upstream_connection_error"; the upstream
// status and body are attached as the raw response only when raw responses
// are enabled for this request.
func (provider *OpenAIProvider) realtimeWebRTCUpstreamError(ctx *schemas.BifrostContext, statusCode int, body []byte) *schemas.BifrostError {
	upstreamErr := &schemas.BifrostError{
		IsBifrostError: false,
		StatusCode:     schemas.Ptr(fasthttp.StatusBadGateway),
		Error: &schemas.ErrorField{
			Type:    schemas.Ptr("upstream_connection_error"),
			Message: fmt.Sprintf("upstream realtime WebRTC handshake failed for %s", provider.GetProviderKey()),
		},
		ExtraFields: schemas.BifrostErrorExtraFields{
			RequestType: schemas.RealtimeRequest,
			Provider:    provider.GetProviderKey(),
		},
	}

	if !providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) {
		return upstreamErr
	}

	upstreamErr.ExtraFields.RawResponse = map[string]any{
		"status": statusCode,
		"body":   string(body),
	}
	return upstreamErr
}
// newRealtimeWebRTCSDPError builds a Bifrost-originated error
// (IsBifrostError is true) with the given HTTP status, error type and
// message. err is stored as the underlying cause and may be nil.
func newRealtimeWebRTCSDPError(status int, errorType, message string, err error) *schemas.BifrostError {
	// Assigning a nil err leaves the cause unset, the same as not assigning it.
	detail := &schemas.ErrorField{
		Type:    schemas.Ptr(errorType),
		Message: message,
		Error:   err,
	}
	return &schemas.BifrostError{
		IsBifrostError: true,
		StatusCode:     schemas.Ptr(status),
		Error:          detail,
	}
}
// ShouldStartRealtimeTurn reports whether the event begins a realtime turn.
// Only response-create and input-audio-buffer-committed events do; a nil
// event never does.
func (provider *OpenAIProvider) ShouldStartRealtimeTurn(event *schemas.BifrostRealtimeEvent) bool {
	if event == nil {
		return false
	}
	kind := event.Type
	return kind == schemas.RTEventResponseCreate || kind == schemas.RTEventInputAudioBufferCommitted
}
// RealtimeTurnFinalEvent returns schemas.RTEventResponseDone, the event type
// this provider reports as the final event of a realtime turn.
func (provider *OpenAIProvider) RealtimeTurnFinalEvent() schemas.RealtimeEventType {
return schemas.RTEventResponseDone
}
// RealtimeWebRTCDataChannelLabel returns the WebRTC data channel label this
// provider reports for realtime traffic: "oai-events".
func (provider *OpenAIProvider) RealtimeWebRTCDataChannelLabel() string {
return "oai-events"
}
// RealtimeWebSocketSubprotocol returns the WebSocket subprotocol name this
// provider reports for realtime connections: "realtime".
func (provider *OpenAIProvider) RealtimeWebSocketSubprotocol() string {
return "realtime"
}
// ShouldForwardRealtimeEvent reports whether a realtime event should be
// forwarded. It always returns true; the event argument is not inspected.
func (provider *OpenAIProvider) ShouldForwardRealtimeEvent(event *schemas.BifrostRealtimeEvent) bool {
return true
}
// ShouldAccumulateRealtimeOutput reports whether events of the given type
// should be accumulated as turn output. It is true for the text-delta and
// audio-transcript-delta event constants and for the
// "response.output_text.delta" / "response.output_audio_transcript.delta"
// event names; every other type returns false.
func (provider *OpenAIProvider) ShouldAccumulateRealtimeOutput(eventType schemas.RealtimeEventType) bool {
	return eventType == schemas.RTEventResponseTextDelta ||
		eventType == schemas.RTEventResponseAudioTransDelta ||
		eventType == schemas.RealtimeEventType("response.output_text.delta") ||
		eventType == schemas.RealtimeEventType("response.output_audio_transcript.delta")
}
// CreateRealtimeClientSecret mints an OpenAI Realtime client secret and returns
// the native OpenAI response body unchanged.
//
// rawRequest is normalized for the selected endpoint type before being POSTed
// as JSON to the matching upstream path. On a 2xx status the decoded body,
// status code and provider response headers (minus Content-Encoding and
// Content-Length) are returned as a passthrough response; a non-2xx status is
// converted with ParseOpenAIError.
func (provider *OpenAIProvider) CreateRealtimeClientSecret(
ctx *schemas.BifrostContext,
key schemas.Key,
endpointType schemas.RealtimeSessionEndpointType,
rawRequest json.RawMessage,
) (*schemas.BifrostPassthroughResponse, *schemas.BifrostError) {
// Reject the call up front when the operation check fails for this provider configuration.
if err := providerUtils.CheckOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.RealtimeRequest); err != nil {
return nil, err
}
// normalizedBody is the upstream payload; requestedModel is the model name
// returned by normalization and reported in the response's extra fields.
normalizedBody, requestedModel, bifrostErr := normalizeRealtimeClientSecretRequest(rawRequest, provider.GetProviderKey(), endpointType)
if bifrostErr != nil {
return nil, bifrostErr
}
req := fasthttp.AcquireRequest()
resp := fasthttp.AcquireResponse()
defer fasthttp.ReleaseRequest(req)
defer fasthttp.ReleaseResponse(resp)
req.SetRequestURI(provider.buildRequestURL(ctx, realtimeSessionUpstreamPath(endpointType), schemas.RealtimeRequest))
req.Header.SetMethod(http.MethodPost)
req.Header.SetContentType("application/json")
for k, v := range provider.realtimeSessionHeaders(key, endpointType) {
req.Header.Set(k, v)
}
req.SetBody(normalizedBody)
latency, bifrostErr, wait := providerUtils.MakeRequestWithContext(ctx, provider.client, req, resp)
// NOTE(review): wait is deferred last, so it runs before req/resp are
// released; presumably it blocks until the request no longer uses them. Confirm.
defer wait()
if bifrostErr != nil {
return nil, bifrostErr
}
// Provider response headers are stored on the context before the status
// check, so they are recorded for error responses as well.
headers := providerUtils.ExtractProviderResponseHeaders(resp)
ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, headers)
if resp.StatusCode() < fasthttp.StatusOK || resp.StatusCode() >= fasthttp.StatusMultipleChoices {
return nil, ParseOpenAIError(resp)
}
body, err := providerUtils.CheckAndDecodeBody(resp)
if err != nil {
return nil, providerUtils.NewBifrostOperationError("failed to decode response body", err)
}
// The body returned above has been through CheckAndDecodeBody, so the
// upstream Content-Encoding / Content-Length headers are dropped
// (case-insensitively) rather than passed along with it.
for k := range headers {
if strings.EqualFold(k, "Content-Encoding") || strings.EqualFold(k, "Content-Length") {
delete(headers, k)
}
}
out := &schemas.BifrostPassthroughResponse{
StatusCode: resp.StatusCode(),
Headers: headers,
Body: body,
}
out.ExtraFields.Provider = provider.GetProviderKey()
out.ExtraFields.OriginalModelRequested = requestedModel
out.ExtraFields.RequestType = schemas.RealtimeRequest
out.ExtraFields.Latency = latency.Milliseconds()
// Attach the raw request only when raw requests are enabled for this call.
if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) {
providerUtils.ParseAndSetRawRequestIfJSON(req, &out.ExtraFields)
}
return out, nil
}
// normalizeRealtimeClientSecretRequest parses a raw client-secret request,
// extracts and normalizes its model string, and rewrites the body for the
// upstream endpoint selected by endpointType. It returns the rewritten body
// and the normalized model name.
//
// A 400 "invalid_request_error" is returned when the model resolves to an
// empty name, or when no provider can be determined from either the model
// string or defaultProvider.
func normalizeRealtimeClientSecretRequest(
	rawRequest json.RawMessage,
	defaultProvider schemas.ModelProvider,
	endpointType schemas.RealtimeSessionEndpointType,
) ([]byte, string, *schemas.BifrostError) {
	body, parseErr := schemas.ParseRealtimeClientSecretBody(rawRequest)
	if parseErr != nil {
		return nil, "", parseErr
	}

	rawModel, modelErr := schemas.ExtractRealtimeClientSecretModel(body)
	if modelErr != nil {
		return nil, "", modelErr
	}

	resolvedProvider, modelName := schemas.ParseModelString(rawModel, defaultProvider)
	if modelName == "" {
		return nil, "", newRealtimeClientSecretError(fasthttp.StatusBadRequest, "invalid_request_error", "session.model is required", nil)
	}
	// The provider falls back to defaultProvider; the request is rejected only
	// when both are empty.
	if resolvedProvider == "" && defaultProvider == "" {
		return nil, "", newRealtimeClientSecretError(fasthttp.StatusBadRequest, "invalid_request_error", "unable to determine provider from model", nil)
	}

	switch endpointType {
	case schemas.RealtimeSessionEndpointSessions:
		return normalizeRealtimeSessionsRequest(body, modelName)
	default:
		return normalizeRealtimeClientSecretsRequest(body, modelName)
	}
}
// normalizeRealtimeClientSecretsRequest rewrites a parsed request body into
// the shape sent to the client_secrets endpoint: the model lives inside the
// "session" object, not at the top level.
//
// An existing non-null "session" object is kept and must decode as a JSON
// object (otherwise a 400 is returned). Its "model" is overwritten with
// normalizedModel, "type" defaults to "realtime" when absent, and any
// top-level "model" is removed. root is modified in place. The encoded body
// and normalizedModel are returned.
func normalizeRealtimeClientSecretsRequest(
	root map[string]json.RawMessage,
	normalizedModel string,
) ([]byte, string, *schemas.BifrostError) {
	// encodeFailure reports an unexpected marshal failure as a 500 server_error.
	encodeFailure := func(message string, cause error) ([]byte, string, *schemas.BifrostError) {
		return nil, "", newRealtimeClientSecretError(fasthttp.StatusInternalServerError, "server_error", message, cause)
	}

	sessionFields := map[string]json.RawMessage{}
	if rawSession, present := root["session"]; present && len(rawSession) > 0 && string(rawSession) != "null" {
		if err := json.Unmarshal(rawSession, &sessionFields); err != nil {
			return nil, "", newRealtimeClientSecretError(fasthttp.StatusBadRequest, "invalid_request_error", "session must be an object", err)
		}
	}

	encodedModel, err := json.Marshal(normalizedModel)
	if err != nil {
		return encodeFailure("failed to encode normalized model", err)
	}
	sessionFields["model"] = encodedModel

	if _, hasType := sessionFields["type"]; !hasType {
		encodedType, err := json.Marshal("realtime")
		if err != nil {
			return encodeFailure("failed to encode realtime session type", err)
		}
		sessionFields["type"] = encodedType
	}

	delete(root, "model")

	encodedSession, err := json.Marshal(sessionFields)
	if err != nil {
		return encodeFailure("failed to encode realtime session", err)
	}
	root["session"] = encodedSession

	payload, err := json.Marshal(root)
	if err != nil {
		return encodeFailure("failed to encode realtime request", err)
	}
	return payload, normalizedModel, nil
}
// normalizeRealtimeSessionsRequest rewrites a parsed request body into the
// flat shape sent to the sessions endpoint: there is no nested "session"
// object and the model sits at the top level.
//
// Fields of an existing non-null "session" object are lifted to the top
// level unless the same key already exists there; the session must decode as
// a JSON object (otherwise a 400 is returned). The top-level "model" is then
// set to normalizedModel and "session" is removed. root is modified in
// place. The encoded body and normalizedModel are returned.
func normalizeRealtimeSessionsRequest(
	root map[string]json.RawMessage,
	normalizedModel string,
) ([]byte, string, *schemas.BifrostError) {
	rawSession, present := root["session"]
	if present && len(rawSession) > 0 && string(rawSession) != "null" {
		nested := map[string]json.RawMessage{}
		if err := json.Unmarshal(rawSession, &nested); err != nil {
			return nil, "", newRealtimeClientSecretError(fasthttp.StatusBadRequest, "invalid_request_error", "session must be an object", err)
		}
		// Top-level values win over nested ones with the same key.
		for name, raw := range nested {
			if _, taken := root[name]; taken {
				continue
			}
			root[name] = raw
		}
	}

	encodedModel, err := json.Marshal(normalizedModel)
	if err != nil {
		return nil, "", newRealtimeClientSecretError(fasthttp.StatusInternalServerError, "server_error", "failed to encode normalized model", err)
	}
	root["model"] = encodedModel
	delete(root, "session")

	payload, err := json.Marshal(root)
	if err != nil {
		return nil, "", newRealtimeClientSecretError(fasthttp.StatusInternalServerError, "server_error", "failed to encode realtime request", err)
	}
	return payload, normalizedModel, nil
}
// realtimeSessionHeaders builds the headers for a realtime session or
// client-secret request: a bearer Authorization header, plus
// "OpenAI-Beta: realtime=v1" for the sessions endpoint. The provider's
// configured extra headers are applied last, so they replace either default
// when the names collide.
func (provider *OpenAIProvider) realtimeSessionHeaders(
	key schemas.Key,
	endpointType schemas.RealtimeSessionEndpointType,
) map[string]string {
	result := map[string]string{}
	result["Authorization"] = "Bearer " + key.Value.GetValue()

	switch endpointType {
	case schemas.RealtimeSessionEndpointSessions:
		result["OpenAI-Beta"] = "realtime=v1"
	}

	for name, value := range provider.networkConfig.ExtraHeaders {
		result[name] = value
	}
	return result
}
// realtimeSessionUpstreamPath maps the endpoint type to its upstream path:
// "/v1/realtime/sessions" for the sessions endpoint and
// "/v1/realtime/client_secrets" for every other value.
func realtimeSessionUpstreamPath(endpointType schemas.RealtimeSessionEndpointType) string {
	switch endpointType {
	case schemas.RealtimeSessionEndpointSessions:
		return "/v1/realtime/sessions"
	default:
		return "/v1/realtime/client_secrets"
	}
}
// newRealtimeClientSecretError builds an error for the realtime client-secret
// flow with the given HTTP status, error type and message. IsBifrostError is
// false, and the extra fields are tagged with the realtime request type and
// the OpenAI provider. err is stored as the underlying cause and may be nil.
func newRealtimeClientSecretError(status int, errorType, message string, err error) *schemas.BifrostError {
	detail := &schemas.ErrorField{
		Type:    schemas.Ptr(errorType),
		Message: message,
		Error:   err,
	}
	origin := schemas.BifrostErrorExtraFields{
		RequestType: schemas.RealtimeRequest,
		Provider:    schemas.OpenAI,
	}
	return &schemas.BifrostError{
		IsBifrostError: false,
		StatusCode:     schemas.Ptr(status),
		Error:          detail,
		ExtraFields:    origin,
	}
}
// openAIRealtimeEvent is the raw shape of an OpenAI Realtime protocol event.
// Nested objects are kept as raw JSON so they can be decoded separately or
// passed through untouched.
type openAIRealtimeEvent struct {
// Type is the event type name; EventID is the optional event identifier.
Type string `json:"type"`
EventID string `json:"event_id,omitempty"`
// Session, Item and Error are decoded into their typed forms by
// ToBifrostRealtimeEvent.
Session json.RawMessage `json:"session,omitempty"`
// Conversation, Response and Part are carried along as raw JSON.
Conversation json.RawMessage `json:"conversation,omitempty"`
Item json.RawMessage `json:"item,omitempty"`
Response json.RawMessage `json:"response,omitempty"`
Part json.RawMessage `json:"part,omitempty"`
// Delta, Audio, Transcript and Text are the string payload fields of an event.
Delta string `json:"delta,omitempty"`
Audio string `json:"audio,omitempty"`
Transcript string `json:"transcript,omitempty"`
Text string `json:"text,omitempty"`
Error json.RawMessage `json:"error,omitempty"`
// The remaining fields are optional identifiers and positions; the pointer
// types distinguish an absent value from zero.
ItemID string `json:"item_id,omitempty"`
OutputIndex *int `json:"output_index,omitempty"`
ContentIndex *int `json:"content_index,omitempty"`
ResponseID string `json:"response_id,omitempty"`
AudioEndMS *int `json:"audio_end_ms,omitempty"`
PreviousItemID string `json:"previous_item_id,omitempty"`
}
// openAIRealtimeSession is the session object within an OpenAI Realtime event.
// These are the session fields decoded into typed form; any other keys in the
// session JSON are not captured by this struct.
type openAIRealtimeSession struct {
ID string `json:"id,omitempty"`
Model string `json:"model,omitempty"`
Modalities []string `json:"modalities,omitempty"`
Instructions string `json:"instructions,omitempty"`
Voice string `json:"voice,omitempty"`
// Pointer so an absent temperature is distinguishable from 0.
Temperature *float64 `json:"temperature,omitempty"`
// MaxOutputTokens, TurnDetection and Tools are kept as raw JSON and not interpreted here.
MaxOutputTokens json.RawMessage `json:"max_output_tokens,omitempty"`
TurnDetection json.RawMessage `json:"turn_detection,omitempty"`
InputAudioFormat string `json:"input_audio_format,omitempty"`
OutputAudioType string `json:"output_audio_type,omitempty"`
Tools json.RawMessage `json:"tools,omitempty"`
}
// openAIRealtimeItem is the item object within an OpenAI Realtime event.
// These are the item fields decoded into typed form; any other keys in the
// item JSON are not captured by this struct.
type openAIRealtimeItem struct {
ID string `json:"id,omitempty"`
Type string `json:"type,omitempty"`
Role string `json:"role,omitempty"`
Status string `json:"status,omitempty"`
// Content is kept as raw JSON and not interpreted here.
Content json.RawMessage `json:"content,omitempty"`
Name string `json:"name,omitempty"`
CallID string `json:"call_id,omitempty"`
Arguments string `json:"arguments,omitempty"`
Output string `json:"output,omitempty"`
}
// openAIRealtimeError is the error object within an OpenAI Realtime event.
// All four fields are optional strings decoded from the "error" JSON object.
type openAIRealtimeError struct {
Type string `json:"type,omitempty"`
Code string `json:"code,omitempty"`
Message string `json:"message,omitempty"`
Param string `json:"param,omitempty"`
}
// ToBifrostRealtimeEvent converts an OpenAI Realtime event (raw JSON) to the unified Bifrost format.
//
// The original bytes are kept in RawData. An error is returned only when the
// top-level event cannot be unmarshalled; failures decoding a nested session,
// item or error object are ignored and leave that field unset.
func (provider *OpenAIProvider) ToBifrostRealtimeEvent(providerEvent json.RawMessage) (*schemas.BifrostRealtimeEvent, error) {
var raw openAIRealtimeEvent
if err := json.Unmarshal(providerEvent, &raw); err != nil {
return nil, fmt.Errorf("failed to unmarshal OpenAI realtime event: %w", err)
}
event := &schemas.BifrostRealtimeEvent{
Type: schemas.RealtimeEventType(raw.Type),
EventID: raw.EventID,
RawData: providerEvent,
}
// Top-level fields are passed to setRealtimeExtraParam under their wire names.
// NOTE(review): setRealtimeExtraParam is defined elsewhere; presumably it
// skips empty or nil values. Confirm.
setRealtimeExtraParam(event, "item_id", raw.ItemID)
setRealtimeExtraParam(event, "previous_item_id", raw.PreviousItemID)
setRealtimeExtraParam(event, "output_index", raw.OutputIndex)
setRealtimeExtraParam(event, "content_index", raw.ContentIndex)
setRealtimeExtraParam(event, "response_id", raw.ResponseID)
setRealtimeExtraParam(event, "audio_end_ms", raw.AudioEndMS)
setRealtimeExtraParam(event, "transcript", raw.Transcript)
setRealtimeExtraParam(event, "text", raw.Text)
setRealtimeExtraParam(event, "conversation", raw.Conversation)
setRealtimeExtraParam(event, "response", raw.Response)
setRealtimeExtraParam(event, "part", raw.Part)
// The cases are mutually exclusive: only the first non-nil of session, item
// and error is decoded, in that order.
switch {
case raw.Session != nil:
var sess openAIRealtimeSession
if err := json.Unmarshal(raw.Session, &sess); err == nil {
event.Session = &schemas.RealtimeSession{
ID: sess.ID,
Model: sess.Model,
Modalities: sess.Modalities,
Instructions: sess.Instructions,
Voice: sess.Voice,
Temperature: sess.Temperature,
MaxOutputTokens: sess.MaxOutputTokens,
TurnDetection: sess.TurnDetection,
InputAudioFormat: sess.InputAudioFormat,
OutputAudioType: sess.OutputAudioType,
Tools: sess.Tools,
}
// The key list passed here mirrors the typed fields mapped above.
// NOTE(review): extractRealtimeNestedParams is defined elsewhere;
// presumably it returns the session keys not in this list. Confirm.
if extra := extractRealtimeNestedParams(raw.Session, "id", "model", "modalities", "instructions", "voice", "temperature", "max_output_tokens", "turn_detection", "input_audio_format", "output_audio_type", "tools"); len(extra) > 0 {
event.Session.ExtraParams = extra
}
}
case raw.Item != nil:
var item openAIRealtimeItem
if err := json.Unmarshal(raw.Item, &item); err == nil {
event.Item = &schemas.RealtimeItem{
ID: item.ID,
Type: item.Type,
Role: item.Role,
Status: item.Status,
Content: item.Content,
Name: item.Name,
CallID: item.CallID,
Arguments: item.Arguments,
Output: item.Output,
}
if extra := extractRealtimeNestedParams(raw.Item, "id", "type", "role", "status", "content", "name", "call_id", "arguments", "output"); len(extra) > 0 {
event.Item.ExtraParams = extra
}
}
case raw.Error != nil:
var rtErr openAIRealtimeError
if err := json.Unmarshal(raw.Error, &rtErr); err == nil {
event.Error = &schemas.RealtimeError{
Type: rtErr.Type,
Code: rtErr.Code,
Message: rtErr.Message,
Param: rtErr.Param,
}
if extra := extractRealtimeNestedParams(raw.Error, "type", "code", "message", "param"); len(extra) > 0 {
event.Error.ExtraParams = extra
}
}
}
// Delta is populated independently of the switch above, for event types
// that isRealtimeDeltaEvent accepts.
if isRealtimeDeltaEvent(raw.Type) {
event.Delta = &schemas.RealtimeDelta{
Text: raw.Text,
Audio: raw.Audio,
Transcript: raw.Transcript,
ItemID: raw.ItemID,
OutputIdx: raw.OutputIndex,
ContentIdx: raw.ContentIndex,
ResponseID: raw.ResponseID,
}
// The wire "delta" string is used as Delta.Text only when the event carried
// no explicit "text" value.
if raw.Delta != "" {
if event.Delta.Text == "" {
event.Delta.Text = raw.Delta
}
}
}
return event, nil
}
// ToProviderRealtimeEvent converts a unified Bifrost Realtime event back to OpenAI's native JSON.
//
// The output object always carries "type", plus "event_id" when set and the
// event-level extra params. Session, item, error and delta sections are
// emitted only when the corresponding field on the event is non-nil, and
// within each section empty values are omitted. The result is encoded with
// providerUtils.MarshalSorted.
//
// A nil bifrostEvent returns an error instead of panicking.
func (provider *OpenAIProvider) ToProviderRealtimeEvent(bifrostEvent *schemas.BifrostRealtimeEvent) (json.RawMessage, error) {
	if bifrostEvent == nil {
		return nil, fmt.Errorf("bifrost realtime event is nil")
	}

	out := map[string]interface{}{
		"type": string(bifrostEvent.Type),
	}
	if bifrostEvent.EventID != "" {
		out["event_id"] = bifrostEvent.EventID
	}
	mergeRealtimeExtraParams(out, bifrostEvent.ExtraParams)

	if bifrostEvent.Session != nil {
		sess := map[string]interface{}{}
		// The session id is never sent on a session.update event.
		if bifrostEvent.Session.ID != "" && bifrostEvent.Type != schemas.RTEventSessionUpdate {
			sess["id"] = bifrostEvent.Session.ID
		}
		if bifrostEvent.Session.Model != "" {
			sess["model"] = bifrostEvent.Session.Model
		}
		if len(bifrostEvent.Session.Modalities) > 0 {
			sess["modalities"] = bifrostEvent.Session.Modalities
		}
		if bifrostEvent.Session.Instructions != "" {
			sess["instructions"] = bifrostEvent.Session.Instructions
		}
		if bifrostEvent.Session.Voice != "" {
			sess["voice"] = bifrostEvent.Session.Voice
		}
		if bifrostEvent.Session.Temperature != nil {
			sess["temperature"] = *bifrostEvent.Session.Temperature
		}
		if bifrostEvent.Session.MaxOutputTokens != nil {
			sess["max_output_tokens"] = bifrostEvent.Session.MaxOutputTokens
		}
		if bifrostEvent.Session.TurnDetection != nil {
			sess["turn_detection"] = bifrostEvent.Session.TurnDetection
		}
		if bifrostEvent.Session.InputAudioFormat != "" {
			sess["input_audio_format"] = bifrostEvent.Session.InputAudioFormat
		}
		if bifrostEvent.Session.OutputAudioType != "" {
			sess["output_audio_type"] = bifrostEvent.Session.OutputAudioType
		}
		if bifrostEvent.Session.Tools != nil {
			sess["tools"] = bifrostEvent.Session.Tools
		}
		// Session extras go through the session-aware merge, which is given the
		// event type.
		mergeRealtimeSessionExtraParams(sess, bifrostEvent.Session.ExtraParams, bifrostEvent.Type)
		out["session"] = sess
	}

	if bifrostEvent.Item != nil {
		// Unlike the other item fields, "type" is written even when empty.
		item := map[string]interface{}{
			"type": bifrostEvent.Item.Type,
		}
		if bifrostEvent.Item.ID != "" {
			item["id"] = bifrostEvent.Item.ID
		}
		if bifrostEvent.Item.Role != "" {
			item["role"] = bifrostEvent.Item.Role
		}
		if bifrostEvent.Item.Status != "" {
			item["status"] = bifrostEvent.Item.Status
		}
		if bifrostEvent.Item.Content != nil {
			item["content"] = bifrostEvent.Item.Content
		}
		if bifrostEvent.Item.Name != "" {
			item["name"] = bifrostEvent.Item.Name
		}
		if bifrostEvent.Item.CallID != "" {
			item["call_id"] = bifrostEvent.Item.CallID
		}
		if bifrostEvent.Item.Arguments != "" {
			item["arguments"] = bifrostEvent.Item.Arguments
		}
		if bifrostEvent.Item.Output != "" {
			item["output"] = bifrostEvent.Item.Output
		}
		mergeRealtimeExtraParams(item, bifrostEvent.Item.ExtraParams)
		out["item"] = item
	}

	if bifrostEvent.Error != nil {
		rtErr := map[string]interface{}{}
		if bifrostEvent.Error.Type != "" {
			rtErr["type"] = bifrostEvent.Error.Type
		}
		if bifrostEvent.Error.Code != "" {
			rtErr["code"] = bifrostEvent.Error.Code
		}
		if bifrostEvent.Error.Message != "" {
			rtErr["message"] = bifrostEvent.Error.Message
		}
		if bifrostEvent.Error.Param != "" {
			rtErr["param"] = bifrostEvent.Error.Param
		}
		mergeRealtimeExtraParams(rtErr, bifrostEvent.Error.ExtraParams)
		out["error"] = rtErr
	}

	if bifrostEvent.Delta != nil {
		// Delta.Text is written under the wire key "delta".
		if bifrostEvent.Delta.Text != "" {
			out["delta"] = bifrostEvent.Delta.Text
		}
		if bifrostEvent.Delta.Audio != "" {
			out["audio"] = bifrostEvent.Delta.Audio
		}
		if bifrostEvent.Delta.Transcript != "" {
			out["transcript"] = bifrostEvent.Delta.Transcript
		}
		// Identifier and index fields from the delta are written only when the
		// event-level extra params do not already carry the same key.
		if bifrostEvent.Delta.ItemID != "" && !hasRealtimeExtraParam(bifrostEvent.ExtraParams, "item_id") {
			out["item_id"] = bifrostEvent.Delta.ItemID
		}
		if bifrostEvent.Delta.OutputIdx != nil && !hasRealtimeExtraParam(bifrostEvent.ExtraParams, "output_index") {
			out["output_index"] = *bifrostEvent.Delta.OutputIdx
		}
		if bifrostEvent.Delta.ContentIdx != nil && !hasRealtimeExtraParam(bifrostEvent.ExtraParams, "content_index") {
			out["content_index"] = *bifrostEvent.Delta.ContentIdx
		}
		if bifrostEvent.Delta.ResponseID != "" && !hasRealtimeExtraParam(bifrostEvent.ExtraParams, "response_id") {
			out["response_id"] = bifrostEvent.Delta.ResponseID
		}
	}

	return providerUtils.MarshalSorted(out)
}
// mergeRealtimeSessionExtraParams merges session-level extra params into out.
// For session.update events the keys "id", "object", "expires_at" and
// "client_secret" are dropped before merging; for every other event type, or
// when params is empty, params is merged unchanged.
func mergeRealtimeSessionExtraParams(out map[string]interface{}, params map[string]json.RawMessage, eventType schemas.RealtimeEventType) {
	if eventType != schemas.RTEventSessionUpdate || len(params) == 0 {
		mergeRealtimeExtraParams(out, params)
		return
	}

	kept := make(map[string]json.RawMessage, len(params))
	for name, raw := range params {
		if name == "id" || name == "object" || name == "expires_at" || name == "client_secret" {
			continue
		}
		kept[name] = raw
	}
	mergeRealtimeExtraParams(out, kept)
}
// ExtractRealtimeTurnUsage reads token usage from the raw terminal event of a
// realtime turn. It returns nil when the payload is empty, is not valid JSON
// for the response-done envelope, or carries no usage object. Input and
// output token detail breakdowns are copied when present.
func (provider *OpenAIProvider) ExtractRealtimeTurnUsage(terminalEventRaw []byte) *schemas.BifrostLLMUsage {
	if len(terminalEventRaw) == 0 {
		return nil
	}

	var envelope openAIRealtimeResponseDoneEnvelope
	if err := json.Unmarshal(terminalEventRaw, &envelope); err != nil {
		return nil
	}
	reported := envelope.Response.Usage
	if reported == nil {
		return nil
	}

	result := &schemas.BifrostLLMUsage{
		PromptTokens:     reported.InputTokens,
		CompletionTokens: reported.OutputTokens,
		TotalTokens:      reported.TotalTokens,
	}

	if input := reported.InputTokenDetails; input != nil {
		result.PromptTokensDetails = &schemas.ChatPromptTokensDetails{
			TextTokens:       input.TextTokens,
			AudioTokens:      input.AudioTokens,
			ImageTokens:      input.ImageTokens,
			CachedReadTokens: input.CachedTokens,
		}
	}

	if output := reported.OutputTokenDetails; output != nil {
		result.CompletionTokensDetails = &schemas.ChatCompletionTokensDetails{
			TextTokens:               output.TextTokens,
			AudioTokens:              output.AudioTokens,
			ReasoningTokens:          output.ReasoningTokens,
			ImageTokens:              output.ImageTokens,
			CitationTokens:           output.CitationTokens,
			NumSearchQueries:         output.NumSearchQueries,
			AcceptedPredictionTokens: output.AcceptedPredictionTokens,
			RejectedPredictionTokens: output.RejectedPredictionTokens,
		}
	}

	return result
}
// ExtractRealtimeTurnOutput decodes the raw terminal event of a realtime turn
// and builds the assistant message it describes.
//
// The message content is the concatenated assistant text of the response's
// output items and the tool calls are its function_call items. It returns nil
// when the payload is empty, cannot be decoded, or yields neither text nor
// tool calls.
func (provider *OpenAIProvider) ExtractRealtimeTurnOutput(terminalEventRaw []byte) *schemas.ChatMessage {
	if len(terminalEventRaw) == 0 {
		return nil
	}

	var envelope openAIRealtimeResponseDoneEnvelope
	if err := json.Unmarshal(terminalEventRaw, &envelope); err != nil {
		return nil
	}

	outputs := envelope.Response.Output
	text := extractOpenAIRealtimeResponseDoneAssistantText(outputs)
	calls := extractOpenAIRealtimeResponseDoneToolCalls(outputs)

	hasText, hasCalls := text != "", len(calls) > 0
	if !hasText && !hasCalls {
		return nil
	}

	assistant := &schemas.ChatMessage{Role: schemas.ChatMessageRoleAssistant}
	if hasText {
		assistant.Content = &schemas.ChatMessageContent{ContentStr: schemas.Ptr(text)}
	}
	if hasCalls {
		assistant.ChatAssistantMessage = &schemas.ChatAssistantMessage{ToolCalls: calls}
	}
	return assistant
}
// openAIRealtimeResponseDoneEnvelope is the subset of a realtime terminal
// event that ExtractRealtimeTurnUsage and ExtractRealtimeTurnOutput decode:
// the output items and the usage object nested under "response".
type openAIRealtimeResponseDoneEnvelope struct {
Response struct {
// Output lists the items produced by the response.
Output []openAIRealtimeResponseDoneOutput `json:"output"`
// Usage stays nil when the event carries no "usage" object.
Usage *openAIRealtimeResponseDoneUsage `json:"usage"`
} `json:"response"`
}
// openAIRealtimeResponseDoneOutput is one entry of response.output.
// The extraction helpers in this file read Content for items whose Type is
// "message" and Name/CallID/ID/Arguments for items whose Type is
// "function_call".
type openAIRealtimeResponseDoneOutput struct {
ID string `json:"id"`
Type string `json:"type"`
Name string `json:"name"`
CallID string `json:"call_id"`
Arguments string `json:"arguments"`
Content []openAIRealtimeResponseDoneBlock `json:"content"`
}
// openAIRealtimeResponseDoneBlock is one content block of a message output
// item. Only the three textual fields are decoded; any other keys in the
// block are ignored.
type openAIRealtimeResponseDoneBlock struct {
Text string `json:"text"`
Transcript string `json:"transcript"`
Refusal string `json:"refusal"`
}
// openAIRealtimeResponseDoneUsage mirrors the response.usage object of a
// realtime terminal event. The detail pointers stay nil when the
// corresponding object is absent from the payload.
type openAIRealtimeResponseDoneUsage struct {
TotalTokens int `json:"total_tokens"`
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
InputTokenDetails *openAIRealtimeResponseDoneInputTokenUsage `json:"input_token_details"`
OutputTokenDetails *openAIRealtimeResponseDoneOutputTokenUsage `json:"output_token_details"`
}
// openAIRealtimeResponseDoneInputTokenUsage mirrors usage.input_token_details.
// ExtractRealtimeTurnUsage copies CachedTokens into CachedReadTokens of the
// Bifrost prompt-token details.
type openAIRealtimeResponseDoneInputTokenUsage struct {
TextTokens int `json:"text_tokens"`
AudioTokens int `json:"audio_tokens"`
ImageTokens int `json:"image_tokens"`
CachedTokens int `json:"cached_tokens"`
}
// openAIRealtimeResponseDoneOutputTokenUsage mirrors
// usage.output_token_details. The pointer-typed counters remain nil when the
// key is missing from the payload, while the plain ints default to zero.
type openAIRealtimeResponseDoneOutputTokenUsage struct {
TextTokens int `json:"text_tokens"`
AudioTokens int `json:"audio_tokens"`
ReasoningTokens int `json:"reasoning_tokens"`
ImageTokens *int `json:"image_tokens"`
CitationTokens *int `json:"citation_tokens"`
NumSearchQueries *int `json:"num_search_queries"`
AcceptedPredictionTokens int `json:"accepted_prediction_tokens"`
RejectedPredictionTokens int `json:"rejected_prediction_tokens"`
}
// extractOpenAIRealtimeResponseDoneAssistantText concatenates the assistant
// text found in the "message" items of outputs.
//
// For each content block the first of Text, Transcript and Refusal that is
// not blank is appended verbatim (no separator is inserted between blocks).
// The combined string is returned with leading and trailing whitespace
// trimmed; it is empty when no block contributes text.
func extractOpenAIRealtimeResponseDoneAssistantText(outputs []openAIRealtimeResponseDoneOutput) string {
	var text strings.Builder
	for i := range outputs {
		if outputs[i].Type != "message" {
			continue
		}
		for _, part := range outputs[i].Content {
			// Priority order: text, then transcript, then refusal.
			for _, candidate := range [...]string{part.Text, part.Transcript, part.Refusal} {
				if strings.TrimSpace(candidate) != "" {
					text.WriteString(candidate)
					break
				}
			}
		}
	}
	return strings.TrimSpace(text.String())
}
// extractOpenAIRealtimeResponseDoneToolCalls converts the "function_call"
// items of outputs into chat tool calls.
//
// Items with a blank name are dropped. The call ID is taken from call_id,
// falling back to the item id, and is left unset when both are blank. Index
// is the position of the call among the calls that were kept. The returned
// slice is always non-nil.
func extractOpenAIRealtimeResponseDoneToolCalls(outputs []openAIRealtimeResponseDoneOutput) []schemas.ChatAssistantMessageToolCall {
	calls := []schemas.ChatAssistantMessageToolCall{}
	for _, item := range outputs {
		fnName := strings.TrimSpace(item.Name)
		if item.Type != "function_call" || fnName == "" {
			continue
		}

		callID := strings.TrimSpace(item.CallID)
		if callID == "" {
			callID = strings.TrimSpace(item.ID)
		}

		// Declared per iteration so every call points at its own string.
		kind := "function"
		call := schemas.ChatAssistantMessageToolCall{
			Index: uint16(len(calls)),
			Type:  &kind,
			Function: schemas.ChatAssistantMessageToolCallFunction{
				Name:      schemas.Ptr(fnName),
				Arguments: item.Arguments,
			},
		}
		if callID != "" {
			call.ID = schemas.Ptr(callID)
		}
		calls = append(calls, call)
	}
	return calls
}
// setRealtimeExtraParam JSON-encodes value and stores it under key in
// event.ExtraParams, allocating the map on first use.
//
// Nothing is stored when event is nil, key is empty, value is a nil
// interface, an empty string, a nil *int, or a json.RawMessage that is empty
// or the literal null. Values that fail to marshal are skipped as well.
func setRealtimeExtraParam(event *schemas.BifrostRealtimeEvent, key string, value any) {
	if event == nil || key == "" || value == nil {
		return
	}

	// Detect "empty" values of the types this helper knows about.
	skip := false
	switch typed := value.(type) {
	case string:
		skip = typed == ""
	case *int:
		skip = typed == nil
	case json.RawMessage:
		skip = len(typed) == 0 || string(typed) == "null"
	}
	if skip {
		return
	}

	encoded, err := json.Marshal(value)
	if err != nil {
		return
	}
	if event.ExtraParams == nil {
		event.ExtraParams = map[string]json.RawMessage{}
	}
	event.ExtraParams[key] = encoded
}
// mergeRealtimeExtraParams decodes every raw JSON value in params and stores
// the decoded value under the same key in out, overwriting existing entries.
//
// The merge is best-effort: empty raw values and values that are not valid
// JSON are skipped silently. A nil out is treated as "nothing to merge into"
// and returns immediately, because assigning to a nil map would panic.
func mergeRealtimeExtraParams(out map[string]interface{}, params map[string]json.RawMessage) {
	if out == nil {
		return
	}
	for key, raw := range params {
		if len(raw) == 0 {
			continue
		}
		var value any
		if err := json.Unmarshal(raw, &value); err != nil {
			continue
		}
		out[key] = value
	}
}
// hasRealtimeExtraParam reports whether params holds a non-empty raw value
// for key. A nil map, a missing key and a zero-length value all yield false.
func hasRealtimeExtraParam(params map[string]json.RawMessage, key string) bool {
	// Indexing a nil map or a missing key yields a nil RawMessage of length 0.
	return len(params[key]) > 0
}
// extractRealtimeNestedParams decodes raw as a JSON object and returns its
// members with every key listed in knownKeys removed.
//
// It returns nil when raw is empty, is not a JSON object, or has no members
// left once the known keys are dropped.
func extractRealtimeNestedParams(raw json.RawMessage, knownKeys ...string) map[string]json.RawMessage {
	if len(raw) == 0 {
		return nil
	}

	var fields map[string]json.RawMessage
	if json.Unmarshal(raw, &fields) != nil {
		return nil
	}

	for i := range knownKeys {
		delete(fields, knownKeys[i])
	}

	if len(fields) > 0 {
		return fields
	}
	return nil
}
// isRealtimeDeltaEvent reports whether eventType is one of the streaming
// delta event names recognised by this file: text, audio and audio-transcript
// deltas in both their legacy and "output_"-prefixed spellings, plus the
// input-audio transcription delta. The comparison is exact and
// case-sensitive.
func isRealtimeDeltaEvent(eventType string) bool {
	switch eventType {
	case "response.text.delta", "response.output_text.delta":
		return true
	case "response.audio.delta", "response.output_audio.delta":
		return true
	case "response.audio_transcript.delta", "response.output_audio_transcript.delta":
		return true
	case "conversation.item.input_audio_transcription.delta":
		return true
	default:
		return false
	}
}

View File

@@ -0,0 +1,561 @@
package openai
import (
"encoding/json"
"testing"
"github.com/maximhq/bifrost/core/schemas"
)
// TestNormalizeRealtimeClientSecretRequest checks that, for the client-secrets
// endpoint, a provider-prefixed top-level model is returned without its
// prefix, removed from the top level of the body, and placed under "session"
// together with type "realtime".
func TestNormalizeRealtimeClientSecretRequest(t *testing.T) {
	t.Parallel()

	const wantModel = "gpt-4o-realtime-preview"
	request := json.RawMessage(`{"model":"openai/gpt-4o-realtime-preview","voice":"alloy"}`)

	normalized, gotModel, bifrostErr := normalizeRealtimeClientSecretRequest(
		request,
		schemas.OpenAI,
		schemas.RealtimeSessionEndpointClientSecrets,
	)
	if bifrostErr != nil {
		t.Fatalf("normalizeRealtimeClientSecretRequest() error = %v", bifrostErr)
	}
	if gotModel != wantModel {
		t.Fatalf("model = %q, want %q", gotModel, wantModel)
	}

	var topLevel map[string]json.RawMessage
	if err := json.Unmarshal(normalized, &topLevel); err != nil {
		t.Fatalf("failed to unmarshal normalized body: %v", err)
	}
	if _, present := topLevel["model"]; present {
		t.Fatal("top-level model should be removed after normalization")
	}

	var session map[string]any
	if err := json.Unmarshal(topLevel["session"], &session); err != nil {
		t.Fatalf("failed to unmarshal session: %v", err)
	}
	if got := session["model"]; got != wantModel {
		t.Fatalf("session.model = %v, want %q", got, wantModel)
	}
	if got := session["type"]; got != "realtime" {
		t.Fatalf("session.type = %v, want %q", got, "realtime")
	}
}
// TestNormalizeRealtimeClientSecretRequestUsesDefaultProvider checks that a
// session model given without a provider prefix is returned unchanged and
// that the normalized session carries that model plus type "realtime".
func TestNormalizeRealtimeClientSecretRequestUsesDefaultProvider(t *testing.T) {
	t.Parallel()

	const wantModel = "gpt-4o-realtime-preview"
	request := json.RawMessage(`{"session":{"model":"gpt-4o-realtime-preview"}}`)

	normalized, gotModel, bifrostErr := normalizeRealtimeClientSecretRequest(
		request,
		schemas.OpenAI,
		schemas.RealtimeSessionEndpointClientSecrets,
	)
	if bifrostErr != nil {
		t.Fatalf("normalizeRealtimeClientSecretRequest() error = %v", bifrostErr)
	}
	if gotModel != wantModel {
		t.Fatalf("model = %q, want %q", gotModel, wantModel)
	}

	var topLevel map[string]json.RawMessage
	if err := json.Unmarshal(normalized, &topLevel); err != nil {
		t.Fatalf("failed to unmarshal normalized body: %v", err)
	}

	var session map[string]any
	if err := json.Unmarshal(topLevel["session"], &session); err != nil {
		t.Fatalf("failed to unmarshal session: %v", err)
	}
	if got := session["model"]; got != wantModel {
		t.Fatalf("session.model = %v, want %q", got, wantModel)
	}
	if got := session["type"]; got != "realtime" {
		t.Fatalf("session.type = %v, want %q", got, "realtime")
	}
}
// TestNormalizeRealtimeSessionsRequest checks that, for the legacy sessions
// endpoint, the nested session object is flattened: the body has no "session"
// key and exposes the prefix-free model and the voice at the top level.
func TestNormalizeRealtimeSessionsRequest(t *testing.T) {
	t.Parallel()

	const wantModel = "gpt-4o-realtime-preview"
	request := json.RawMessage(`{"session":{"model":"openai/gpt-4o-realtime-preview","voice":"alloy"}}`)

	normalized, gotModel, bifrostErr := normalizeRealtimeClientSecretRequest(
		request,
		schemas.OpenAI,
		schemas.RealtimeSessionEndpointSessions,
	)
	if bifrostErr != nil {
		t.Fatalf("normalizeRealtimeClientSecretRequest() error = %v", bifrostErr)
	}
	if gotModel != wantModel {
		t.Fatalf("model = %q, want %q", gotModel, wantModel)
	}

	var flat map[string]any
	if err := json.Unmarshal(normalized, &flat); err != nil {
		t.Fatalf("failed to unmarshal normalized body: %v", err)
	}
	if _, present := flat["session"]; present {
		t.Fatal("legacy sessions endpoint should not forward nested session object")
	}
	if got := flat["model"]; got != wantModel {
		t.Fatalf("model = %v, want %q", got, wantModel)
	}
	if got := flat["voice"]; got != "alloy" {
		t.Fatalf("voice = %v, want %q", got, "alloy")
	}
}
func TestToProviderRealtimeEventSerializesTopLevelClientFields(t *testing.T) {
t.Parallel()
provider := &OpenAIProvider{}
contentIndex, err := json.Marshal(0)
if err != nil {
t.Fatalf("json.Marshal() error = %v", err)
}
audioEndMS, err := json.Marshal(640)
if err != nil {
t.Fatalf("json.Marshal() error = %v", err)
}
out, err := provider.ToProviderRealtimeEvent(&schemas.BifrostRealtimeEvent{
Type: schemas.RealtimeEventType("conversation.item.truncate"),
ExtraParams: map[string]json.RawMessage{
"item_id": json.RawMessage(`"item_123"`),
"content_index": contentIndex,
"audio_end_ms": audioEndMS,
},
})
if err != nil {
t.Fatalf("ToProviderRealtimeEvent() error = %v", err)
}
var payload map[string]any
if err := json.Unmarshal(out, &payload); err != nil {
t.Fatalf("json.Unmarshal() error = %v", err)
}
if payload["type"] != "conversation.item.truncate" {
t.Fatalf("type = %v, want %q", payload["type"], "conversation.item.truncate")
}
if payload["item_id"] != "item_123" {
t.Fatalf("item_id = %v, want %q", payload["item_id"], "item_123")
}
if payload["content_index"] != float64(0) {
t.Fatalf("content_index = %v, want 0", payload["content_index"])
}
if payload["audio_end_ms"] != float64(640) {
t.Fatalf("audio_end_ms = %v, want 640", payload["audio_end_ms"])
}
}
// TestToBifrostRealtimeEventParsesTopLevelClientFields checks that the
// top-level item_id, content_index and audio_end_ms of a
// conversation.item.truncate payload end up in the event's ExtraParams.
func TestToBifrostRealtimeEventParsesTopLevelClientFields(t *testing.T) {
	t.Parallel()

	wire := json.RawMessage(`{"type":"conversation.item.truncate","item_id":"item_123","content_index":0,"audio_end_ms":640}`)

	provider := &OpenAIProvider{}
	parsed, err := provider.ToBifrostRealtimeEvent(wire)
	if err != nil {
		t.Fatalf("ToBifrostRealtimeEvent() error = %v", err)
	}

	var gotItemID string
	if err := json.Unmarshal(parsed.ExtraParams["item_id"], &gotItemID); err != nil {
		t.Fatalf("json.Unmarshal(item_id) error = %v", err)
	}
	if gotItemID != "item_123" {
		t.Fatalf("item_id = %q, want %q", gotItemID, "item_123")
	}

	var gotContentIndex int
	if err := json.Unmarshal(parsed.ExtraParams["content_index"], &gotContentIndex); err != nil {
		t.Fatalf("json.Unmarshal(content_index) error = %v", err)
	}
	if gotContentIndex != 0 {
		t.Fatalf("content_index = %d, want 0", gotContentIndex)
	}

	var gotAudioEndMS int
	if err := json.Unmarshal(parsed.ExtraParams["audio_end_ms"], &gotAudioEndMS); err != nil {
		t.Fatalf("json.Unmarshal(audio_end_ms) error = %v", err)
	}
	if gotAudioEndMS != 640 {
		t.Fatalf("audio_end_ms = %d, want 640", gotAudioEndMS)
	}
}
// TestToBifrostRealtimeEventParsesCompletedInputAudioTranscript checks that
// the transcript of a completed input-audio transcription event is kept in
// the event's ExtraParams.
func TestToBifrostRealtimeEventParsesCompletedInputAudioTranscript(t *testing.T) {
	t.Parallel()

	wire := json.RawMessage(`{"type":"conversation.item.input_audio_transcription.completed","event_id":"evt_123","item_id":"item_123","content_index":0,"transcript":"Who are you?"}`)

	provider := &OpenAIProvider{}
	parsed, err := provider.ToBifrostRealtimeEvent(wire)
	if err != nil {
		t.Fatalf("ToBifrostRealtimeEvent() error = %v", err)
	}

	var gotTranscript string
	if err := json.Unmarshal(parsed.ExtraParams["transcript"], &gotTranscript); err != nil {
		t.Fatalf("json.Unmarshal(transcript) error = %v", err)
	}
	if gotTranscript != "Who are you?" {
		t.Fatalf("transcript = %q, want %q", gotTranscript, "Who are you?")
	}
}
// TestToBifrostRealtimeEventParsesModernOutputTextDelta checks that a
// response.output_text.delta payload is parsed into an event whose Delta
// carries the delta text.
func TestToBifrostRealtimeEventParsesModernOutputTextDelta(t *testing.T) {
	t.Parallel()

	wire := json.RawMessage(`{
"type":"response.output_text.delta",
"event_id":"evt_123",
"item_id":"item_123",
"output_index":0,
"content_index":0,
"response_id":"resp_123",
"delta":"hello"
}`)

	provider := &OpenAIProvider{}
	parsed, err := provider.ToBifrostRealtimeEvent(wire)
	if err != nil {
		t.Fatalf("ToBifrostRealtimeEvent() error = %v", err)
	}

	delta := parsed.Delta
	if delta == nil || delta.Text != "hello" {
		t.Fatalf("Delta = %+v, want text=hello", delta)
	}
}
func TestShouldStartRealtimeTurn(t *testing.T) {
t.Parallel()
provider := &OpenAIProvider{}
tests := []struct {
name string
event *schemas.BifrostRealtimeEvent
want bool
}{
{
name: "response create starts turn",
event: &schemas.BifrostRealtimeEvent{Type: schemas.RTEventResponseCreate},
want: true,
},
{
name: "audio buffer committed starts turn",
event: &schemas.BifrostRealtimeEvent{Type: schemas.RTEventInputAudioBufferCommitted},
want: true,
},
{
name: "response done does not start turn",
event: &schemas.BifrostRealtimeEvent{Type: schemas.RTEventResponseDone},
want: false,
},
{
name: "nil event does not start turn",
event: nil,
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
if got := provider.ShouldStartRealtimeTurn(tt.event); got != tt.want {
t.Fatalf("ShouldStartRealtimeTurn() = %v, want %v", got, tt.want)
}
})
}
}
func TestToProviderRealtimeEventSerializesModernOutputTextDelta(t *testing.T) {
t.Parallel()
provider := &OpenAIProvider{}
outputIndex := 0
contentIndex := 0
out, err := provider.ToProviderRealtimeEvent(&schemas.BifrostRealtimeEvent{
Type: schemas.RealtimeEventType("response.output_text.delta"),
Delta: &schemas.RealtimeDelta{
Text: "hello",
ItemID: "item_123",
OutputIdx: &outputIndex,
ContentIdx: &contentIndex,
ResponseID: "resp_123",
},
})
if err != nil {
t.Fatalf("ToProviderRealtimeEvent() error = %v", err)
}
var payload map[string]any
if err := json.Unmarshal(out, &payload); err != nil {
t.Fatalf("json.Unmarshal() error = %v", err)
}
if payload["type"] != "response.output_text.delta" {
t.Fatalf("type = %v, want response.output_text.delta", payload["type"])
}
if payload["delta"] != "hello" {
t.Fatalf("delta = %v, want hello", payload["delta"])
}
}
func TestToProviderRealtimeEventSerializesSessionID(t *testing.T) {
t.Parallel()
provider := &OpenAIProvider{}
out, err := provider.ToProviderRealtimeEvent(&schemas.BifrostRealtimeEvent{
Type: schemas.RTEventSessionCreated,
Session: &schemas.RealtimeSession{
ID: "sess_123",
Model: "gpt-realtime",
},
})
if err != nil {
t.Fatalf("ToProviderRealtimeEvent() error = %v", err)
}
var payload map[string]any
if err := json.Unmarshal(out, &payload); err != nil {
t.Fatalf("json.Unmarshal() error = %v", err)
}
session, ok := payload["session"].(map[string]any)
if !ok {
t.Fatalf("session = %T, want object", payload["session"])
}
if session["id"] != "sess_123" {
t.Fatalf("session.id = %v, want sess_123", session["id"])
}
}
func TestToProviderRealtimeEventSerializesMessageItemStatus(t *testing.T) {
t.Parallel()
provider := &OpenAIProvider{}
content := json.RawMessage(`[{"type":"input_audio","transcript":"hello"}]`)
out, err := provider.ToProviderRealtimeEvent(&schemas.BifrostRealtimeEvent{
Type: schemas.RealtimeEventType("conversation.item.retrieved"),
Item: &schemas.RealtimeItem{
ID: "item_123",
Type: "message",
Role: "user",
Status: "completed",
Content: content,
},
})
if err != nil {
t.Fatalf("ToProviderRealtimeEvent() error = %v", err)
}
var payload map[string]any
if err := json.Unmarshal(out, &payload); err != nil {
t.Fatalf("json.Unmarshal() error = %v", err)
}
item, ok := payload["item"].(map[string]any)
if !ok {
t.Fatalf("item = %T, want object", payload["item"])
}
if item["status"] != "completed" {
t.Fatalf("item.status = %v, want completed", item["status"])
}
}
// TestToBifrostRealtimeEventPreservesTopLevelResponsePayload checks that the
// "response" object of a response.done payload is kept in the event's
// ExtraParams with its id intact.
func TestToBifrostRealtimeEventPreservesTopLevelResponsePayload(t *testing.T) {
	t.Parallel()

	wire := json.RawMessage(`{
"type":"response.done",
"event_id":"evt_123",
"response":{
"id":"resp_123",
"output":[{"type":"message","content":[{"type":"output_text","text":"hello"}]}]
}
}`)

	provider := &OpenAIProvider{}
	parsed, err := provider.ToBifrostRealtimeEvent(wire)
	if err != nil {
		t.Fatalf("ToBifrostRealtimeEvent() error = %v", err)
	}

	var responseObj map[string]any
	if err := json.Unmarshal(parsed.ExtraParams["response"], &responseObj); err != nil {
		t.Fatalf("json.Unmarshal(response) error = %v", err)
	}
	if got := responseObj["id"]; got != "resp_123" {
		t.Fatalf("response.id = %v, want resp_123", got)
	}
}
func TestToProviderRealtimeEventSerializesTopLevelResponsePayload(t *testing.T) {
t.Parallel()
provider := &OpenAIProvider{}
out, err := provider.ToProviderRealtimeEvent(&schemas.BifrostRealtimeEvent{
Type: schemas.RTEventResponseDone,
ExtraParams: map[string]json.RawMessage{
"response": json.RawMessage(`{"id":"resp_123","output":[{"type":"message","content":[{"type":"output_text","text":"hello"}]}]}`),
},
})
if err != nil {
t.Fatalf("ToProviderRealtimeEvent() error = %v", err)
}
var payload map[string]any
if err := json.Unmarshal(out, &payload); err != nil {
t.Fatalf("json.Unmarshal() error = %v", err)
}
response, ok := payload["response"].(map[string]any)
if !ok {
t.Fatalf("response = %T, want object", payload["response"])
}
if response["id"] != "resp_123" {
t.Fatalf("response.id = %v, want resp_123", response["id"])
}
}
// TestToBifrostRealtimeEventPreservesTopLevelPartPayload checks that the
// "part" object of a response.content_part.added payload is kept in the
// event's ExtraParams with its type intact.
func TestToBifrostRealtimeEventPreservesTopLevelPartPayload(t *testing.T) {
	t.Parallel()

	wire := json.RawMessage(`{
"type":"response.content_part.added",
"event_id":"evt_123",
"item_id":"item_123",
"output_index":0,
"content_index":0,
"part":{
"type":"text",
"text":"hello"
}
}`)

	provider := &OpenAIProvider{}
	parsed, err := provider.ToBifrostRealtimeEvent(wire)
	if err != nil {
		t.Fatalf("ToBifrostRealtimeEvent() error = %v", err)
	}

	var partObj map[string]any
	if err := json.Unmarshal(parsed.ExtraParams["part"], &partObj); err != nil {
		t.Fatalf("json.Unmarshal(part) error = %v", err)
	}
	if got := partObj["type"]; got != "text" {
		t.Fatalf("part.type = %v, want text", got)
	}
}
func TestToProviderRealtimeEventSerializesTopLevelPartPayload(t *testing.T) {
t.Parallel()
provider := &OpenAIProvider{}
out, err := provider.ToProviderRealtimeEvent(&schemas.BifrostRealtimeEvent{
Type: schemas.RTEventResponseContentPartAdded,
ExtraParams: map[string]json.RawMessage{
"part": json.RawMessage(`{"type":"text","text":"hello"}`),
},
})
if err != nil {
t.Fatalf("ToProviderRealtimeEvent() error = %v", err)
}
var payload map[string]any
if err := json.Unmarshal(out, &payload); err != nil {
t.Fatalf("json.Unmarshal() error = %v", err)
}
part, ok := payload["part"].(map[string]any)
if !ok {
t.Fatalf("part = %T, want object", payload["part"])
}
if part["type"] != "text" {
t.Fatalf("part.type = %v, want text", part["type"])
}
}
// TestParseRealtimeEventPreservesNestedSessionExtraParams checks that
// schemas.ParseRealtimeEvent parses the session of a session.update payload
// and keeps its output_modalities in the session's ExtraParams.
func TestParseRealtimeEventPreservesNestedSessionExtraParams(t *testing.T) {
	t.Parallel()

	wire := []byte(`{
"type":"session.update",
"session":{
"type":"realtime",
"model":"gpt-4o-realtime-preview",
"output_modalities":["text"]
}
}`)

	parsed, err := schemas.ParseRealtimeEvent(wire)
	if err != nil {
		t.Fatalf("ParseRealtimeEvent() error = %v", err)
	}
	if parsed.Session == nil {
		t.Fatal("expected session to be parsed")
	}

	var gotModalities []string
	if err := json.Unmarshal(parsed.Session.ExtraParams["output_modalities"], &gotModalities); err != nil {
		t.Fatalf("json.Unmarshal(output_modalities) error = %v", err)
	}
	if len(gotModalities) != 1 || gotModalities[0] != "text" {
		t.Fatalf("output_modalities = %v, want [text]", gotModalities)
	}
}
func TestToProviderRealtimeEventSerializesNestedSessionExtraParams(t *testing.T) {
t.Parallel()
provider := &OpenAIProvider{}
out, err := provider.ToProviderRealtimeEvent(&schemas.BifrostRealtimeEvent{
Type: schemas.RTEventSessionUpdate,
Session: &schemas.RealtimeSession{
Model: "gpt-4o-realtime-preview",
ExtraParams: map[string]json.RawMessage{
"type": json.RawMessage(`"realtime"`),
"output_modalities": json.RawMessage(`["text"]`),
},
},
})
if err != nil {
t.Fatalf("ToProviderRealtimeEvent() error = %v", err)
}
var payload struct {
Type string `json:"type"`
Session map[string]any `json:"session"`
}
if err := json.Unmarshal(out, &payload); err != nil {
t.Fatalf("json.Unmarshal() error = %v", err)
}
if payload.Type != "session.update" {
t.Fatalf("type = %q, want %q", payload.Type, "session.update")
}
if payload.Session["type"] != "realtime" {
t.Fatalf("session.type = %v, want realtime", payload.Session["type"])
}
outputModalities, ok := payload.Session["output_modalities"].([]any)
if !ok || len(outputModalities) != 1 || outputModalities[0] != "text" {
t.Fatalf("session.output_modalities = %v, want [text]", payload.Session["output_modalities"])
}
}
func TestToProviderRealtimeEventOmitsReadOnlySessionFieldsOnSessionUpdate(t *testing.T) {
t.Parallel()
provider := &OpenAIProvider{}
out, err := provider.ToProviderRealtimeEvent(&schemas.BifrostRealtimeEvent{
Type: schemas.RTEventSessionUpdate,
Session: &schemas.RealtimeSession{
ID: "sess_123",
Model: "gpt-realtime",
ExtraParams: map[string]json.RawMessage{
"type": json.RawMessage(`"realtime"`),
"object": json.RawMessage(`"realtime.session"`),
"expires_at": json.RawMessage(`1774614381`),
"client_secret": json.RawMessage(`{"value":"secret"}`),
"modalities": json.RawMessage(`["text","audio"]`),
},
},
})
if err != nil {
t.Fatalf("ToProviderRealtimeEvent() error = %v", err)
}
var payload struct {
Session map[string]any `json:"session"`
}
if err := json.Unmarshal(out, &payload); err != nil {
t.Fatalf("json.Unmarshal() error = %v", err)
}
for _, key := range []string{"id", "object", "expires_at", "client_secret"} {
if _, ok := payload.Session[key]; ok {
t.Fatalf("session.%s unexpectedly present in session.update payload", key)
}
}
if payload.Session["type"] != "realtime" {
t.Fatalf("session.type = %v, want realtime", payload.Session["type"])
}
if payload.Session["model"] != "gpt-realtime" {
t.Fatalf("session.model = %v, want gpt-realtime", payload.Session["model"])
}
}

View File

@@ -0,0 +1,376 @@
package openai
import (
"strings"
"github.com/maximhq/bifrost/core/providers/utils"
"github.com/maximhq/bifrost/core/schemas"
)
// ToBifrostResponsesRequest converts an OpenAI responses request to Bifrost format.
//
// The provider and model are parsed from resp.Model. The default provider is
// OpenAI, or Azure when the context flags the caller as an Azure user agent.
// When the request carries no input array, the string input is wrapped in a
// single user message. A nil receiver yields nil. The returned request's
// Params points at the receiver's embedded ResponsesParameters rather than a
// copy.
func (resp *OpenAIResponsesRequest) ToBifrostResponsesRequest(ctx *schemas.BifrostContext) *schemas.BifrostResponsesRequest {
	if resp == nil {
		return nil
	}

	// Requests coming from the Azure SDK carry no provider prefix, so they
	// fall back to Azure instead of OpenAI.
	fallbackProvider := schemas.OpenAI
	if ctx != nil {
		azureAgent, isBool := ctx.Value(schemas.BifrostContextKeyIsAzureUserAgent).(bool)
		if isBool && azureAgent {
			fallbackProvider = schemas.Azure
		}
	}
	provider, model := schemas.ParseModelString(resp.Model, utils.CheckAndSetDefaultProvider(ctx, fallbackProvider))

	messages := resp.Input.OpenAIResponsesRequestInputArray
	if len(messages) == 0 {
		userMessage := schemas.ResponsesMessage{
			Role:    schemas.Ptr(schemas.ResponsesInputMessageRoleUser),
			Content: &schemas.ResponsesMessageContent{ContentStr: resp.Input.OpenAIResponsesRequestInputStr},
		}
		messages = []schemas.ResponsesMessage{userMessage}
	}

	converted := &schemas.BifrostResponsesRequest{
		Provider:  provider,
		Model:     model,
		Input:     messages,
		Params:    &resp.ResponsesParameters,
		Fallbacks: schemas.ParseFallbacks(resp.Fallbacks),
	}
	return converted
}
// ToOpenAIResponsesRequest converts a Bifrost responses request to OpenAI format.
//
// It returns nil when bifrostReq or its Input is nil. Otherwise it builds a new
// request in two phases:
//
//  1. Input messages are copied one by one. Compaction content blocks are
//     replaced by text blocks holding their summary (blocks with an empty
//     summary are dropped, and a message left without blocks is skipped).
//     Reasoning messages are skipped, converted to content blocks (gpt-oss) or
//     cloned with Role cleared and, for non-reasoning models, EncryptedContent
//     stripped. Computer tool-call actions of type "zoom" or with a Region are
//     rewritten on copies ("zoom" becomes "screenshot", Region is cleared).
//  2. Params are copied by value and adjusted: MaxOutputTokens is raised to
//     MinMaxCompletionTokens when below it, User is sanitized, Reasoning is
//     normalized on a clone, TopP is stripped for reasoning models where
//     required, function tool parameters are normalized on a copied Tools
//     slice, and unsupported tool types are filtered out.
//
// Every rewrite of data reachable through pointers is done on a copy so the
// caller's bifrostReq is not mutated. ExtraParams is taken from
// bifrostReq.Params as-is.
func ToOpenAIResponsesRequest(bifrostReq *schemas.BifrostResponsesRequest) *OpenAIResponsesRequest {
if bifrostReq == nil || bifrostReq.Input == nil {
return nil
}
var messages []schemas.ResponsesMessage
// OpenAI models (except for gpt-oss) do not support reasoning content blocks, so we need to convert them to summaries, if there are any
// OpenAI also doesn't support compaction content blocks, so we need to convert them to text blocks
messages = make([]schemas.ResponsesMessage, 0, len(bifrostReq.Input))
for _, message := range bifrostReq.Input {
// First, check if message has compaction content blocks and convert them to text
if message.Content != nil && len(message.Content.ContentBlocks) > 0 {
hasCompaction := false
for _, block := range message.Content.ContentBlocks {
if block.Type == schemas.ResponsesOutputMessageContentTypeCompaction {
hasCompaction = true
break
}
}
if hasCompaction {
// Create a new message with converted content blocks
newMessage := message
newContentBlocks := make([]schemas.ResponsesMessageContentBlock, 0, len(message.Content.ContentBlocks))
for _, block := range message.Content.ContentBlocks {
if block.Type == schemas.ResponsesOutputMessageContentTypeCompaction {
// Convert compaction block to text block
if block.ResponsesOutputMessageContentCompaction != nil && block.ResponsesOutputMessageContentCompaction.Summary != "" {
newContentBlocks = append(newContentBlocks, schemas.ResponsesMessageContentBlock{
Type: schemas.ResponsesOutputMessageContentTypeText,
Text: schemas.Ptr(block.ResponsesOutputMessageContentCompaction.Summary),
})
}
// If summary is empty, skip the block entirely
} else {
// Keep non-compaction blocks as-is
newContentBlocks = append(newContentBlocks, block)
}
}
// Only update if we have blocks remaining after conversion
if len(newContentBlocks) > 0 {
// A fresh Content value is assigned so the input message's Content is left untouched.
newMessage.Content = &schemas.ResponsesMessageContent{
ContentBlocks: newContentBlocks,
}
message = newMessage
} else {
// If all blocks were compaction with empty summaries, skip message
continue
}
}
}
if message.ResponsesReasoning != nil {
isGptOss := strings.Contains(bifrostReq.Model, "gpt-oss")
isReasoning := isOpenAIReasoningModel(bifrostReq.Model)
// For non-gpt-oss models, skip reasoning-only messages that have content blocks but no summaries.
// For non-reasoning models (e.g., gpt-4o), also skip when EncryptedContent is present since
// these models don't produce encrypted reasoning — any encrypted content is cross-provider
// (e.g., Gemini ThoughtSignatures) and cannot be decrypted by OpenAI.
if len(message.ResponsesReasoning.Summary) == 0 &&
message.Content != nil &&
len(message.Content.ContentBlocks) > 0 &&
!isGptOss &&
(message.ResponsesReasoning.EncryptedContent == nil || !isReasoning) {
continue
}
// If the message has summaries but no content blocks and the model is gpt-oss, then convert the summaries to content blocks
if len(message.ResponsesReasoning.Summary) > 0 && isGptOss &&
(message.Content == nil || len(message.Content.ContentBlocks) == 0) {
// Only ID, Type, Status and Role are carried over; the reasoning payload itself is replaced by content blocks.
var newMessage schemas.ResponsesMessage
newMessage.ID = message.ID
newMessage.Type = message.Type
newMessage.Status = message.Status
newMessage.Role = message.Role
// Convert summaries to content blocks
contentBlocks := make([]schemas.ResponsesMessageContentBlock, 0, len(message.ResponsesReasoning.Summary))
for _, summary := range message.ResponsesReasoning.Summary {
contentBlocks = append(contentBlocks, schemas.ResponsesMessageContentBlock{
Type: schemas.ResponsesOutputMessageContentTypeReasoning,
Text: schemas.Ptr(summary.Text),
})
}
newMessage.Content = &schemas.ResponsesMessageContent{
ContentBlocks: contentBlocks,
}
messages = append(messages, newMessage)
} else {
// Clone the embedded pointer to avoid mutating the original input
reasoningCopy := *message.ResponsesReasoning
message.ResponsesReasoning = &reasoningCopy
// OpenAI's Responses API does not accept 'role' on reasoning items
message.Role = nil
// Strip cross-provider encrypted content that non-reasoning models cannot decrypt.
// Reasoning models (o1/o3/o4/GPT-5) may use EncryptedContent for multi-turn state.
if !isReasoning {
message.ResponsesReasoning.EncryptedContent = nil
}
messages = append(messages, message)
}
} else if message.ResponsesToolMessage != nil &&
message.ResponsesToolMessage.Action != nil &&
message.ResponsesToolMessage.Action.ResponsesComputerToolCallAction != nil {
action := message.ResponsesToolMessage.Action.ResponsesComputerToolCallAction
if action.Type == "zoom" || action.Region != nil {
// Copy action and modify
newAction := *action
newAction.Region = nil
if newAction.Type == "zoom" {
newAction.Type = "screenshot"
}
// Copy each level of the pointer chain so the rewritten action never aliases the input message.
actionStructCopy := *message.ResponsesToolMessage.Action
actionStructCopy.ResponsesComputerToolCallAction = &newAction
toolMsgCopy := *message.ResponsesToolMessage
toolMsgCopy.Action = &actionStructCopy
message.ResponsesToolMessage = &toolMsgCopy
}
messages = append(messages, message)
} else {
messages = append(messages, message)
}
}
// Updating params
params := bifrostReq.Params
// Create the responses request with properly mapped parameters
req := &OpenAIResponsesRequest{
Model: bifrostReq.Model,
Input: OpenAIResponsesRequestInput{
OpenAIResponsesRequestInputArray: messages,
},
}
if params != nil {
// Shallow copy: pointer and slice fields still alias params until they are replaced below.
req.ResponsesParameters = *params
if req.ResponsesParameters.MaxOutputTokens != nil && *req.ResponsesParameters.MaxOutputTokens < MinMaxCompletionTokens {
req.ResponsesParameters.MaxOutputTokens = schemas.Ptr(MinMaxCompletionTokens)
}
// Drop user field if it exceeds OpenAI's 64 character limit
req.ResponsesParameters.User = SanitizeUserField(req.ResponsesParameters.User)
// Handle reasoning parameter: OpenAI uses effort-based reasoning
// Priority: effort (native) > max_tokens (estimated)
if req.ResponsesParameters.Reasoning != nil {
// Clone the Reasoning pointer to avoid mutating the original params
reasoningCopy := *req.ResponsesParameters.Reasoning
req.ResponsesParameters.Reasoning = &reasoningCopy
if req.ResponsesParameters.Reasoning.Effort != nil {
// Native field is provided, use it (and clear max_tokens)
effort := *req.ResponsesParameters.Reasoning.Effort
// Convert "minimal" to "low"; cap "xhigh"/"max" to "high" — OpenAI tops out at high.
switch effort {
case "minimal":
req.ResponsesParameters.Reasoning.Effort = schemas.Ptr("low")
case "xhigh", "max":
req.ResponsesParameters.Reasoning.Effort = schemas.Ptr("high")
}
// Clear max_tokens since OpenAI doesn't use it
req.ResponsesParameters.Reasoning.MaxTokens = nil
} else if req.ResponsesParameters.Reasoning.MaxTokens != nil {
// Estimate effort from max_tokens
maxTokens := *req.ResponsesParameters.Reasoning.MaxTokens
maxOutputTokens := utils.GetMaxOutputTokensOrDefault(req.Model, DefaultCompletionMaxTokens)
if req.ResponsesParameters.MaxOutputTokens != nil {
maxOutputTokens = *req.ResponsesParameters.MaxOutputTokens
}
effort := utils.GetReasoningEffortFromBudgetTokens(maxTokens, MinReasoningMaxTokens, maxOutputTokens)
req.ResponsesParameters.Reasoning.Effort = schemas.Ptr(effort)
// Clear max_tokens since OpenAI doesn't use it
req.ResponsesParameters.Reasoning.MaxTokens = nil
}
// summary:"none" is Anthropic-specific (maps to display:"omitted"); strip it for OpenAI.
if req.ResponsesParameters.Reasoning.Summary != nil && *req.ResponsesParameters.Reasoning.Summary == "none" {
req.ResponsesParameters.Reasoning.Summary = nil
}
// Handle xAI-specific parameter filtering
// Only grok-3-mini supports reasoning_effort
if bifrostReq.Provider == schemas.XAI &&
schemas.IsGrokReasoningModel(bifrostReq.Model) &&
!strings.Contains(bifrostReq.Model, "grok-3-mini") {
// Clear reasoning_effort for non-grok-3-mini xAI reasoning models
req.ResponsesParameters.Reasoning.Effort = nil
}
// Handle OpenAI-specific parameter filtering
// Only models recognised by isOpenAIReasoningModel keep the reasoning parameter
// Regular models like gpt-4o, gpt-4, gpt-3.5-turbo don't support it
if bifrostReq.Provider == schemas.OpenAI && !isOpenAIReasoningModel(bifrostReq.Model) {
// Clear reasoning for non-reasoning OpenAI models to avoid API errors
req.ResponsesParameters.Reasoning = nil
}
}
// Strip top_p for OpenAI reasoning models (o1/o3 series) which reject it
// GPT-5.x accept top_p when reasoning.effort is "none" (defaults to "none" when omitted)
if isOpenAIReasoningModel(bifrostReq.Model) {
stripTopP := true
_, parsedModel := schemas.ParseModelString(bifrostReq.Model, schemas.OpenAI)
modelLower := strings.ToLower(parsedModel)
effort := ""
if req.ResponsesParameters.Reasoning != nil &&
req.ResponsesParameters.Reasoning.Effort != nil {
effort = *req.ResponsesParameters.Reasoning.Effort
}
// GPT-5.x: reasoning defaults to "none" when omitted, and top_p is allowed in that case
// Exception: -pro and -codex variants always reason (no "none" mode), so top_p must be stripped
if strings.HasPrefix(modelLower, "gpt-5.") &&
(effort == "" || effort == "none") &&
!strings.Contains(modelLower, "-pro") &&
!strings.Contains(modelLower, "-codex") {
stripTopP = false
}
if stripTopP {
req.ResponsesParameters.TopP = nil
}
}
// Normalize function tool parameters for deterministic JSON serialization.
// We must copy the Tools slice since it shares the backing array with bifrostReq.Params.Tools.
if len(req.Tools) > 0 {
normalizedTools := make([]schemas.ResponsesTool, len(req.Tools))
copy(normalizedTools, req.Tools)
for i, tool := range normalizedTools {
if tool.Type == schemas.ResponsesToolTypeFunction &&
tool.ResponsesToolFunction != nil &&
tool.ResponsesToolFunction.Parameters != nil {
// Replace the function pointer with a copy so the normalized parameters do not leak into the caller's tool.
funcCopy := *tool.ResponsesToolFunction
funcCopy.Parameters = tool.ResponsesToolFunction.Parameters.Normalized()
normalizedTools[i].ResponsesToolFunction = &funcCopy
}
}
req.Tools = normalizedTools
}
// Filter out tools that OpenAI doesn't support
req.filterUnsupportedTools()
}
// ExtraParams is assigned directly from the caller's params, not copied.
if bifrostReq.Params != nil {
req.ExtraParams = bifrostReq.Params.ExtraParams
}
return req
}
// filterUnsupportedTools rewrites resp.Tools so that it only contains tool
// types on the allow-list below, and strips per-tool options that the comments
// in this file describe as unsupported by OpenAI:
//   - computer_use_preview: EnableZoom is removed (only when it was set);
//   - web_search: MaxUses, BlockedDomains and the time-range filter are removed,
//     keeping AllowedDomains, UserLocation and SearchContextSize.
//
// Tools are copied before being adjusted, so the elements of the original
// slice are never modified in place. A request without tools is left untouched.
//
// NOTE(review): the allow-list includes web_fetch and memory, while a test in
// this package expects those types to be dropped at marshal time — confirm the
// pass-through here is intentional.
func (resp *OpenAIResponsesRequest) filterUnsupportedTools() {
	if len(resp.Tools) == 0 {
		return
	}

	kept := make([]schemas.ResponsesTool, 0, len(resp.Tools))
	for _, candidate := range resp.Tools {
		// Allow-list of tool types that are forwarded; anything else is dropped.
		switch candidate.Type {
		case schemas.ResponsesToolTypeFunction,
			schemas.ResponsesToolTypeFileSearch,
			schemas.ResponsesToolTypeComputerUsePreview,
			schemas.ResponsesToolTypeWebSearch,
			schemas.ResponsesToolTypeWebFetch,
			schemas.ResponsesToolTypeMCP,
			schemas.ResponsesToolTypeCodeInterpreter,
			schemas.ResponsesToolTypeImageGeneration,
			schemas.ResponsesToolTypeLocalShell,
			schemas.ResponsesToolTypeCustom,
			schemas.ResponsesToolTypeWebSearchPreview,
			schemas.ResponsesToolTypeMemory,
			schemas.ResponsesToolTypeToolSearch:
		default:
			continue
		}

		// candidate is already a per-iteration copy of the slice element, so
		// swapping its nested pointer never touches the caller's original tool.
		switch {
		case candidate.Type == schemas.ResponsesToolTypeComputerUsePreview &&
			candidate.ResponsesToolComputerUsePreview != nil &&
			candidate.ResponsesToolComputerUsePreview.EnableZoom != nil:
			original := candidate.ResponsesToolComputerUsePreview
			// EnableZoom is intentionally omitted (nil) - OpenAI doesn't support it
			candidate.ResponsesToolComputerUsePreview = &schemas.ResponsesToolComputerUsePreview{
				DisplayHeight: original.DisplayHeight,
				DisplayWidth:  original.DisplayWidth,
				Environment:   original.Environment,
			}

		case candidate.Type == schemas.ResponsesToolTypeWebSearch &&
			candidate.ResponsesToolWebSearch != nil:
			original := candidate.ResponsesToolWebSearch
			// MaxUses is intentionally omitted (nil) - OpenAI doesn't support it
			sanitized := &schemas.ResponsesToolWebSearch{}
			// Only AllowedDomains survives from Filters; with no allowed domains
			// (blocked-only or empty filters) Filters stays nil.
			if original.Filters != nil && len(original.Filters.AllowedDomains) > 0 {
				sanitized.Filters = &schemas.ResponsesToolWebSearchFilters{
					// Fresh slice so the copy does not share a backing array.
					AllowedDomains: append([]string(nil), original.Filters.AllowedDomains...),
				}
			}
			sanitized.UserLocation = original.UserLocation
			sanitized.SearchContextSize = original.SearchContextSize
			candidate.ResponsesToolWebSearch = sanitized
		}

		kept = append(kept, candidate)
	}
	resp.Tools = kept
}

View File

@@ -0,0 +1,872 @@
package openai
import (
"encoding/json"
"strings"
"testing"
"github.com/bytedance/sonic"
"github.com/maximhq/bifrost/core/schemas"
)
func TestOpenAIResponsesRequest_MarshalJSON_ReasoningMaxTokensAbsent(t *testing.T) {
tests := []struct {
name string
request *OpenAIResponsesRequest
description string
}{
{
name: "reasoning with MaxTokens set should omit max_tokens from output",
request: &OpenAIResponsesRequest{
Model: "gpt-4o",
Input: OpenAIResponsesRequestInput{
OpenAIResponsesRequestInputStr: schemas.Ptr("test input"),
},
ResponsesParameters: schemas.ResponsesParameters{
Reasoning: &schemas.ResponsesParametersReasoning{
Effort: schemas.Ptr("high"),
MaxTokens: schemas.Ptr(1000),
Summary: schemas.Ptr("detailed"),
},
},
},
description: "When Reasoning.MaxTokens is set, it should be absent from JSON output",
},
{
name: "reasoning with all fields set should omit only max_tokens",
request: &OpenAIResponsesRequest{
Model: "gpt-4o",
Input: OpenAIResponsesRequestInput{
OpenAIResponsesRequestInputStr: schemas.Ptr("test"),
},
ResponsesParameters: schemas.ResponsesParameters{
Reasoning: &schemas.ResponsesParametersReasoning{
Effort: schemas.Ptr("medium"),
GenerateSummary: schemas.Ptr("auto"),
Summary: schemas.Ptr("concise"),
MaxTokens: schemas.Ptr(500),
},
},
},
description: "All reasoning fields except MaxTokens should be present in output",
},
{
name: "reasoning with nil MaxTokens should not include max_tokens",
request: &OpenAIResponsesRequest{
Model: "gpt-4o",
Input: OpenAIResponsesRequestInput{
OpenAIResponsesRequestInputStr: schemas.Ptr("test"),
},
ResponsesParameters: schemas.ResponsesParameters{
Reasoning: &schemas.ResponsesParametersReasoning{
Effort: schemas.Ptr("low"),
MaxTokens: nil,
},
},
},
description: "When Reasoning.MaxTokens is nil, max_tokens should not appear in output",
},
{
name: "nil reasoning should not include reasoning field",
request: &OpenAIResponsesRequest{
Model: "gpt-4o",
Input: OpenAIResponsesRequestInput{
OpenAIResponsesRequestInputStr: schemas.Ptr("test"),
},
ResponsesParameters: schemas.ResponsesParameters{
Reasoning: nil,
},
},
description: "When Reasoning is nil, reasoning field should not appear in output",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
jsonBytes, err := tt.request.MarshalJSON()
if err != nil {
t.Fatalf("Failed to marshal JSON: %v", err)
}
// Parse the JSON to check structure
var jsonMap map[string]interface{}
if err := sonic.Unmarshal(jsonBytes, &jsonMap); err != nil {
t.Fatalf("Failed to unmarshal marshaled JSON: %v", err)
}
// Check that reasoning.max_tokens is absent
if reasoning, ok := jsonMap["reasoning"].(map[string]interface{}); ok {
if maxTokens, exists := reasoning["max_tokens"]; exists {
t.Errorf("%s: reasoning.max_tokens should be absent from JSON output, but found: %v", tt.description, maxTokens)
}
// Verify other reasoning fields are present when they should be
if tt.request.Reasoning != nil {
if tt.request.Reasoning.Effort != nil {
if _, exists := reasoning["effort"]; !exists {
t.Error("reasoning.effort should be present in output")
}
}
if tt.request.Reasoning.Summary != nil {
if _, exists := reasoning["summary"]; !exists {
t.Error("reasoning.summary should be present in output")
}
}
if tt.request.Reasoning.GenerateSummary != nil {
if _, exists := reasoning["generate_summary"]; !exists {
t.Error("reasoning.generate_summary should be present in output")
}
}
}
} else if tt.request.Reasoning != nil {
// If reasoning is set, it should appear in JSON (unless all fields are nil/omitted)
if tt.request.Reasoning.Effort != nil || tt.request.Reasoning.Summary != nil || tt.request.Reasoning.GenerateSummary != nil {
t.Error("reasoning field should be present in JSON when Reasoning is set with non-nil fields")
}
}
})
}
}
// TestOpenAIResponsesRequest_MarshalJSON_InputStringForm verifies that a
// request whose input is set through OpenAIResponsesRequestInputStr serializes
// the "input" key as a JSON string carrying exactly the original text.
func TestOpenAIResponsesRequest_MarshalJSON_InputStringForm(t *testing.T) {
	tests := []struct {
		name        string
		request     *OpenAIResponsesRequest
		expected    string
		description string
	}{
		{
			name: "input as string is correctly marshaled",
			request: &OpenAIResponsesRequest{
				Model: "gpt-4o",
				Input: OpenAIResponsesRequestInput{
					OpenAIResponsesRequestInputStr: schemas.Ptr("Hello, world!"),
				},
			},
			expected:    "Hello, world!",
			description: "Input field should be marshaled as a string when OpenAIResponsesRequestInputStr is set",
		},
		{
			name: "input as empty string is correctly marshaled",
			request: &OpenAIResponsesRequest{
				Model: "gpt-4o",
				Input: OpenAIResponsesRequestInput{
					OpenAIResponsesRequestInputStr: schemas.Ptr(""),
				},
			},
			expected:    "",
			description: "Input field should be marshaled as empty string when set to empty string",
		},
		{
			name: "input as string with special characters",
			request: &OpenAIResponsesRequest{
				Model: "gpt-4o",
				Input: OpenAIResponsesRequestInput{
					OpenAIResponsesRequestInputStr: schemas.Ptr(`{"key": "value"}`),
				},
			},
			expected:    `{"key": "value"}`,
			description: "Input field should correctly marshal strings with special characters",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			jsonBytes, err := tt.request.MarshalJSON()
			if err != nil {
				t.Fatalf("Failed to marshal JSON: %v", err)
			}

			// Parse the JSON to check input field
			var jsonMap map[string]any
			if err := sonic.Unmarshal(jsonBytes, &jsonMap); err != nil {
				t.Fatalf("Failed to unmarshal marshaled JSON: %v", err)
			}

			inputValue, exists := jsonMap["input"]
			if !exists {
				t.Fatalf("%s: input field should be present in JSON", tt.description)
			}

			// Stop immediately on a type mismatch: continuing would compare the
			// zero-value string and either report a misleading second failure or
			// (for the empty-string case) appear to match.
			inputStr, ok := inputValue.(string)
			if !ok {
				t.Fatalf("%s: input field should be a string, got type %T", tt.description, inputValue)
			}
			if inputStr != tt.expected {
				t.Errorf("%s: expected input to be %q, got %q", tt.description, tt.expected, inputStr)
			}
		})
	}
}
func TestOpenAIResponsesRequest_MarshalJSON_InputArrayForm(t *testing.T) {
tests := []struct {
name string
request *OpenAIResponsesRequest
validate func(t *testing.T, inputValue interface{})
description string
}{
{
name: "input as array is correctly marshaled",
request: &OpenAIResponsesRequest{
Model: "gpt-4o",
Input: OpenAIResponsesRequestInput{
OpenAIResponsesRequestInputArray: []schemas.ResponsesMessage{
{
Role: schemas.Ptr(schemas.ResponsesInputMessageRoleUser),
Content: &schemas.ResponsesMessageContent{
ContentStr: schemas.Ptr("Hello"),
},
},
},
},
},
validate: func(t *testing.T, inputValue interface{}) {
inputArray, ok := inputValue.([]interface{})
if !ok {
t.Fatalf("Expected input to be an array, got type %T", inputValue)
}
if len(inputArray) != 1 {
t.Errorf("Expected 1 message in array, got %d", len(inputArray))
}
},
description: "Input field should be marshaled as an array when OpenAIResponsesRequestInputArray is set",
},
{
name: "input as empty array is correctly marshaled",
request: &OpenAIResponsesRequest{
Model: "gpt-4o",
Input: OpenAIResponsesRequestInput{
OpenAIResponsesRequestInputArray: []schemas.ResponsesMessage{},
},
},
validate: func(t *testing.T, inputValue interface{}) {
inputArray, ok := inputValue.([]interface{})
if !ok {
t.Fatalf("Expected input to be an array, got type %T", inputValue)
}
if len(inputArray) != 0 {
t.Errorf("Expected empty array, got %d elements", len(inputArray))
}
},
description: "Input field should be marshaled as empty array when set to empty array",
},
{
name: "input as array with multiple messages",
request: &OpenAIResponsesRequest{
Model: "gpt-4o",
Input: OpenAIResponsesRequestInput{
OpenAIResponsesRequestInputArray: []schemas.ResponsesMessage{
{
Role: schemas.Ptr(schemas.ResponsesInputMessageRoleSystem),
Content: &schemas.ResponsesMessageContent{
ContentStr: schemas.Ptr("You are a helpful assistant."),
},
},
{
Role: schemas.Ptr(schemas.ResponsesInputMessageRoleUser),
Content: &schemas.ResponsesMessageContent{
ContentStr: schemas.Ptr("What is 2+2?"),
},
},
},
},
},
validate: func(t *testing.T, inputValue interface{}) {
inputArray, ok := inputValue.([]interface{})
if !ok {
t.Fatalf("Expected input to be an array, got type %T", inputValue)
}
if len(inputArray) != 2 {
t.Errorf("Expected 2 messages in array, got %d", len(inputArray))
}
},
description: "Input field should correctly marshal arrays with multiple messages",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
jsonBytes, err := tt.request.MarshalJSON()
if err != nil {
t.Fatalf("Failed to marshal JSON: %v", err)
}
// Parse the JSON to check input field
var jsonMap map[string]interface{}
if err := sonic.Unmarshal(jsonBytes, &jsonMap); err != nil {
t.Fatalf("Failed to unmarshal marshaled JSON: %v", err)
}
// Check that input is present
inputValue, exists := jsonMap["input"]
if !exists {
t.Fatalf("%s: input field should be present in JSON", tt.description)
}
// Validate using the provided function
tt.validate(t, inputValue)
})
}
}
// TestToOpenAIResponsesRequest_FireworksPreservesNativeFields verifies that
// converting a Fireworks-targeted Bifrost request keeps previous_response_id,
// max_tool_calls and store in the serialized payload.
func TestToOpenAIResponsesRequest_FireworksPreservesNativeFields(t *testing.T) {
	userTurn := schemas.ResponsesMessage{
		Role: schemas.Ptr(schemas.ResponsesInputMessageRoleUser),
		Content: &schemas.ResponsesMessageContent{
			ContentStr: schemas.Ptr("hello"),
		},
	}
	source := &schemas.BifrostResponsesRequest{
		Provider: schemas.Fireworks,
		Model:    "accounts/fireworks/models/deepseek-v3p2",
		Input:    []schemas.ResponsesMessage{userTurn},
		Params: &schemas.ResponsesParameters{
			PreviousResponseID: schemas.Ptr("resp_previous"),
			MaxToolCalls:       schemas.Ptr(2),
			Store:              schemas.Ptr(true),
		},
	}

	converted := ToOpenAIResponsesRequest(source)
	if converted == nil {
		t.Fatal("expected non-nil request")
	}

	payload, err := converted.MarshalJSON()
	if err != nil {
		t.Fatalf("failed to marshal responses request: %v", err)
	}

	var decoded map[string]any
	if err := sonic.Unmarshal(payload, &decoded); err != nil {
		t.Fatalf("failed to parse marshaled JSON: %v", err)
	}

	// Each field must be present with both the expected JSON type and value.
	prevID, isString := decoded["previous_response_id"].(string)
	if !isString || prevID != "resp_previous" {
		t.Fatalf("expected previous_response_id to be preserved, got %#v", decoded["previous_response_id"])
	}

	// JSON numbers decode to float64 in a generic map.
	maxCalls, isNumber := decoded["max_tool_calls"].(float64)
	if !isNumber || maxCalls != 2 {
		t.Fatalf("expected max_tool_calls to be preserved, got %#v", decoded["max_tool_calls"])
	}

	store, isBool := decoded["store"].(bool)
	if !isBool || !store {
		t.Fatalf("expected store=true to be preserved, got %#v", decoded["store"])
	}
}
func TestOpenAIResponsesRequest_MarshalJSON_FieldShadowingBehavior(t *testing.T) {
// This test verifies that the field shadowing pattern works correctly
// by ensuring that the aux struct properly shadows Input and Reasoning fields
t.Run("field shadowing preserves other fields", func(t *testing.T) {
request := &OpenAIResponsesRequest{
Model: "gpt-4o",
Input: OpenAIResponsesRequestInput{
OpenAIResponsesRequestInputStr: schemas.Ptr("test input"),
},
ResponsesParameters: schemas.ResponsesParameters{
MaxOutputTokens: schemas.Ptr(100),
Temperature: schemas.Ptr(0.7),
Reasoning: &schemas.ResponsesParametersReasoning{
Effort: schemas.Ptr("high"),
MaxTokens: schemas.Ptr(500), // This should be omitted
Summary: schemas.Ptr("detailed"),
},
},
Stream: schemas.Ptr(true),
Fallbacks: []string{"fallback1", "fallback2"},
}
jsonBytes, err := request.MarshalJSON()
if err != nil {
t.Fatalf("Failed to marshal JSON: %v", err)
}
var jsonMap map[string]interface{}
if err := sonic.Unmarshal(jsonBytes, &jsonMap); err != nil {
t.Fatalf("Failed to unmarshal marshaled JSON: %v", err)
}
// Verify base fields are present
if jsonMap["model"] != "gpt-4o" {
t.Errorf("Expected model to be 'gpt-4o', got %v", jsonMap["model"])
}
if jsonMap["stream"] != true {
t.Errorf("Expected stream to be true, got %v", jsonMap["stream"])
}
fallbacks, ok := jsonMap["fallbacks"].([]interface{})
if !ok || len(fallbacks) != 2 {
t.Errorf("Expected fallbacks to have 2 elements, got %v", jsonMap["fallbacks"])
}
// Verify ResponsesParameters fields are present
if jsonMap["max_output_tokens"] != float64(100) {
t.Errorf("Expected max_output_tokens to be 100, got %v", jsonMap["max_output_tokens"])
}
if jsonMap["temperature"] != 0.7 {
t.Errorf("Expected temperature to be 0.7, got %v", jsonMap["temperature"])
}
// Verify reasoning.max_tokens is absent
if reasoning, ok := jsonMap["reasoning"].(map[string]interface{}); ok {
if _, exists := reasoning["max_tokens"]; exists {
t.Error("reasoning.max_tokens should be absent from JSON output")
}
if reasoning["effort"] != "high" {
t.Errorf("Expected reasoning.effort to be 'high', got %v", reasoning["effort"])
}
if reasoning["summary"] != "detailed" {
t.Errorf("Expected reasoning.summary to be 'detailed', got %v", reasoning["summary"])
}
} else {
t.Error("reasoning field should be present in JSON")
}
// Verify input is correctly marshaled
if jsonMap["input"] != "test input" {
t.Errorf("Expected input to be 'test input', got %v", jsonMap["input"])
}
})
}
// TestOpenAIResponsesRequest_MarshalJSON_RoundTrip verifies that marshaling a
// request and unmarshaling it again preserves model, stream, max output
// tokens, temperature and the reasoning effort/summary, while
// reasoning.max_tokens is dropped on the way out and therefore nil afterwards.
func TestOpenAIResponsesRequest_MarshalJSON_RoundTrip(t *testing.T) {
	t.Run("round trip preserves fields except reasoning.max_tokens", func(t *testing.T) {
		original := &OpenAIResponsesRequest{
			Model: "gpt-4o",
			Input: OpenAIResponsesRequestInput{
				OpenAIResponsesRequestInputArray: []schemas.ResponsesMessage{
					{
						Role: schemas.Ptr(schemas.ResponsesInputMessageRoleUser),
						Content: &schemas.ResponsesMessageContent{
							ContentStr: schemas.Ptr("Test message"),
						},
					},
				},
			},
			ResponsesParameters: schemas.ResponsesParameters{
				MaxOutputTokens: schemas.Ptr(200),
				Temperature:     schemas.Ptr(0.8),
				Reasoning: &schemas.ResponsesParametersReasoning{
					Effort:    schemas.Ptr("medium"),
					MaxTokens: schemas.Ptr(1000), // Should be omitted
					Summary:   schemas.Ptr("auto"),
				},
			},
			Stream: schemas.Ptr(false),
		}

		// Marshal
		jsonBytes, err := original.MarshalJSON()
		if err != nil {
			t.Fatalf("Failed to marshal: %v", err)
		}

		// Verify reasoning.max_tokens is absent from the wire payload. The
		// payload is always decoded and a decode failure is fatal, so this
		// check can never be silently skipped.
		var wire map[string]any
		if err := sonic.Unmarshal(jsonBytes, &wire); err != nil {
			t.Fatalf("Failed to unmarshal marshaled JSON: %v", err)
		}
		if reasoning, ok := wire["reasoning"].(map[string]any); ok {
			if _, exists := reasoning["max_tokens"]; exists {
				t.Error("reasoning.max_tokens should not be present in marshaled JSON")
			}
		}

		// Unmarshal back
		var unmarshaled OpenAIResponsesRequest
		if err := sonic.Unmarshal(jsonBytes, &unmarshaled); err != nil {
			t.Fatalf("Failed to unmarshal: %v", err)
		}

		// Verify fields are preserved
		if unmarshaled.Model != original.Model {
			t.Errorf("Model not preserved: expected %q, got %q", original.Model, unmarshaled.Model)
		}
		if unmarshaled.Stream == nil || *unmarshaled.Stream != *original.Stream {
			t.Error("Stream not preserved")
		}
		if unmarshaled.MaxOutputTokens == nil || *unmarshaled.MaxOutputTokens != *original.MaxOutputTokens {
			t.Error("MaxOutputTokens not preserved")
		}
		if unmarshaled.Temperature == nil || *unmarshaled.Temperature != *original.Temperature {
			t.Error("Temperature not preserved")
		}

		// Verify reasoning fields except MaxTokens; Fatal here because the
		// following checks dereference unmarshaled.Reasoning.
		if unmarshaled.Reasoning == nil {
			t.Fatal("Reasoning should be present")
		}
		if unmarshaled.Reasoning.Effort == nil || *unmarshaled.Reasoning.Effort != *original.Reasoning.Effort {
			t.Error("Reasoning.Effort not preserved")
		}
		if unmarshaled.Reasoning.Summary == nil || *unmarshaled.Reasoning.Summary != *original.Reasoning.Summary {
			t.Error("Reasoning.Summary not preserved")
		}
		// MaxTokens should be nil after unmarshaling (since it wasn't in JSON)
		if unmarshaled.Reasoning.MaxTokens != nil {
			t.Error("Reasoning.MaxTokens should be nil after unmarshaling (was omitted from JSON)")
		}
	})
}
// Regression test for multi-turn Anthropic tool_result with array-form content.
// The OpenAI Responses API defines function_call_output.output as a string (see
// https://platform.openai.com/docs/api-reference/responses/create). When an
// Anthropic client sends a tool_result whose content is an array of text blocks,
// Bifrost's Anthropic→Responses translator populates
// ResponsesToolMessageOutputStruct.ResponsesFunctionToolCallOutputBlocks.
// Historically, that array was marshaled verbatim onto the wire, which some
// strict OpenAI-compat upstreams (e.g. Ollama Cloud) reject with an error like
//
// json: cannot unmarshal array into Go struct field ResponsesFunctionCallOutput.output of type string
//
// The outgoing OpenAI Responses request must emit `output` as a string for
// text-only tool outputs.
func TestOpenAIResponsesRequestInput_MarshalJSON_FunctionCallOutputFlattensTextBlocksToString(t *testing.T) {
outputText := "line1"
callID := "toolu_abc123"
functionName := "read_file"
input := &OpenAIResponsesRequestInput{
OpenAIResponsesRequestInputArray: []schemas.ResponsesMessage{
{
Role: schemas.Ptr(schemas.ResponsesInputMessageRoleUser),
Content: &schemas.ResponsesMessageContent{
ContentStr: schemas.Ptr("Read /tmp/test.txt and tell me what it contains."),
},
},
{
Type: schemas.Ptr(schemas.ResponsesMessageTypeFunctionCall),
Status: schemas.Ptr("completed"),
ResponsesToolMessage: &schemas.ResponsesToolMessage{
CallID: schemas.Ptr(callID),
Name: schemas.Ptr(functionName),
Arguments: schemas.Ptr(`{"path":"/tmp/test.txt"}`),
},
},
{
Type: schemas.Ptr(schemas.ResponsesMessageTypeFunctionCallOutput),
Status: schemas.Ptr("completed"),
ResponsesToolMessage: &schemas.ResponsesToolMessage{
CallID: schemas.Ptr(callID),
Output: &schemas.ResponsesToolMessageOutputStruct{
ResponsesFunctionToolCallOutputBlocks: []schemas.ResponsesMessageContentBlock{
{
Type: schemas.ResponsesInputMessageContentBlockTypeText,
Text: schemas.Ptr(outputText),
},
},
},
},
},
},
}
jsonBytes, err := input.MarshalJSON()
if err != nil {
t.Fatalf("Failed to marshal OpenAIResponsesRequestInput: %v", err)
}
var messages []map[string]interface{}
if err := sonic.Unmarshal(jsonBytes, &messages); err != nil {
t.Fatalf("Failed to unmarshal marshaled input as array: %v\nraw=%s", err, string(jsonBytes))
}
var fcoMsg map[string]interface{}
for _, m := range messages {
if t, ok := m["type"].(string); ok && t == string(schemas.ResponsesMessageTypeFunctionCallOutput) {
fcoMsg = m
break
}
}
if fcoMsg == nil {
t.Fatalf("did not find function_call_output message in marshaled JSON: %s", string(jsonBytes))
}
outputVal, ok := fcoMsg["output"]
if !ok {
t.Fatalf("function_call_output message has no `output` field: %s", string(jsonBytes))
}
outputStr, isString := outputVal.(string)
if !isString {
t.Fatalf("function_call_output.output must be a string (OpenAI Responses API spec); got %T: %v\nraw=%s", outputVal, outputVal, string(jsonBytes))
}
if outputStr != outputText {
t.Fatalf("function_call_output.output mismatch: want %q, got %q", outputText, outputStr)
}
}
// Flattening must concatenate multiple text blocks with newline separators so
// every character from the upstream tool response reaches the model.
func TestOpenAIResponsesRequestInput_MarshalJSON_FunctionCallOutputConcatenatesMultipleTextBlocks(t *testing.T) {
callID := "toolu_multi"
input := &OpenAIResponsesRequestInput{
OpenAIResponsesRequestInputArray: []schemas.ResponsesMessage{
{
Type: schemas.Ptr(schemas.ResponsesMessageTypeFunctionCallOutput),
Status: schemas.Ptr("completed"),
ResponsesToolMessage: &schemas.ResponsesToolMessage{
CallID: schemas.Ptr(callID),
Output: &schemas.ResponsesToolMessageOutputStruct{
ResponsesFunctionToolCallOutputBlocks: []schemas.ResponsesMessageContentBlock{
{Type: schemas.ResponsesInputMessageContentBlockTypeText, Text: schemas.Ptr("line1")},
{Type: schemas.ResponsesInputMessageContentBlockTypeText, Text: schemas.Ptr("line2")},
},
},
},
},
},
}
jsonBytes, err := input.MarshalJSON()
if err != nil {
t.Fatalf("Failed to marshal: %v", err)
}
var messages []map[string]interface{}
if err := sonic.Unmarshal(jsonBytes, &messages); err != nil {
t.Fatalf("Failed to unmarshal: %v\nraw=%s", err, string(jsonBytes))
}
if len(messages) != 1 {
t.Fatalf("expected 1 message, got %d", len(messages))
}
got, ok := messages[0]["output"].(string)
if !ok {
t.Fatalf("output must be string, got %T", messages[0]["output"])
}
if want := "line1\nline2"; got != want {
t.Fatalf("flattened output mismatch: want %q, got %q", want, got)
}
}
// When the tool result contains a non-text block (e.g. an image), flattening is
// unsafe — preserve the array form and let the upstream handle it. This keeps
// the fix scoped to the common text-only case without dropping rich content.
func TestOpenAIResponsesRequestInput_MarshalJSON_FunctionCallOutputPreservesNonTextBlocks(t *testing.T) {
	callID := "toolu_with_image"
	imageURL := "https://example.com/screenshot.png"
	input := &OpenAIResponsesRequestInput{
		OpenAIResponsesRequestInputArray: []schemas.ResponsesMessage{
			{
				Type:   schemas.Ptr(schemas.ResponsesMessageTypeFunctionCallOutput),
				Status: schemas.Ptr("completed"),
				ResponsesToolMessage: &schemas.ResponsesToolMessage{
					CallID: schemas.Ptr(callID),
					Output: &schemas.ResponsesToolMessageOutputStruct{
						ResponsesFunctionToolCallOutputBlocks: []schemas.ResponsesMessageContentBlock{
							{Type: schemas.ResponsesInputMessageContentBlockTypeText, Text: schemas.Ptr("here is the screenshot:")},
							{
								Type: schemas.ResponsesInputMessageContentBlockTypeImage,
								ResponsesInputMessageContentBlockImage: &schemas.ResponsesInputMessageContentBlockImage{
									ImageURL: &imageURL,
								},
							},
						},
					},
				},
			},
		},
	}

	jsonBytes, err := input.MarshalJSON()
	if err != nil {
		t.Fatalf("Failed to marshal: %v", err)
	}

	var messages []map[string]any
	if err := sonic.Unmarshal(jsonBytes, &messages); err != nil {
		t.Fatalf("Failed to unmarshal: %v\nraw=%s", err, string(jsonBytes))
	}
	// Guard the index below: an empty result must fail the test cleanly
	// instead of panicking with an out-of-range access.
	if len(messages) != 1 {
		t.Fatalf("expected 1 message, got %d", len(messages))
	}

	if _, isString := messages[0]["output"].(string); isString {
		t.Fatalf("non-text blocks must not be flattened to string; raw=%s", string(jsonBytes))
	}
}
// TestOpenAIResponsesRequest_MarshalJSON_StripsAnthropicToolFlags ensures the
// Responses serializer drops the four Anthropic-native tool flags
// (defer_loading, allowed_callers, input_examples, eager_input_streaming)
// along with CacheControl before forwarding to OpenAI — mirroring the Chat
// path's behavior so Anthropic-flavored tools cannot 400 OpenAI via Responses.
func TestOpenAIResponsesRequest_MarshalJSON_StripsAnthropicToolFlags(t *testing.T) {
req := &OpenAIResponsesRequest{
Model: "gpt-4o",
Input: OpenAIResponsesRequestInput{
OpenAIResponsesRequestInputArray: []schemas.ResponsesMessage{
{
Role: schemas.Ptr(schemas.ResponsesInputMessageRoleUser),
Content: &schemas.ResponsesMessageContent{
ContentStr: schemas.Ptr("hello"),
},
},
},
},
ResponsesParameters: schemas.ResponsesParameters{
Tools: []schemas.ResponsesTool{
{
Type: schemas.ResponsesToolTypeFunction,
Name: schemas.Ptr("lookup"),
Description: schemas.Ptr("lookup something"),
CacheControl: &schemas.CacheControl{Type: "ephemeral"},
DeferLoading: schemas.Ptr(true),
AllowedCallers: []string{"direct", "agent"},
EagerInputStreaming: schemas.Ptr(false),
InputExamples: []schemas.ChatToolInputExample{
{Input: json.RawMessage(`{"q":"hi"}`)},
},
ResponsesToolFunction: &schemas.ResponsesToolFunction{},
},
},
},
}
jsonBytes, err := req.MarshalJSON()
if err != nil {
t.Fatalf("marshal failed: %v", err)
}
raw := string(jsonBytes)
// None of the five Anthropic-only tool keys must survive on the wire.
for _, key := range []string{`"cache_control"`, `"defer_loading"`, `"allowed_callers"`, `"input_examples"`, `"eager_input_streaming"`} {
if strings.Contains(raw, key) {
t.Errorf("OpenAI Responses serializer must strip %s; raw=%s", key, raw)
}
}
// Function tool identity should be preserved.
if !strings.Contains(raw, `"name":"lookup"`) {
t.Errorf("tool identity lost after strip; raw=%s", raw)
}
}
// TestOpenAIResponsesRequest_MarshalJSON_DropsAnthropicOnlyToolTypes verifies
// that Anthropic-only tool types (web_fetch, memory) are dropped entirely when
// serializing for OpenAI Responses. Per OpenAI's OpenAPI spec the Responses
// Tool discriminator union does not include web_fetch or memory, so forwarding
// them would trigger a 400 schema-validation error. Mirrors the Chat path's
// isAnthropicServerToolShape drop behavior.
func TestOpenAIResponsesRequest_MarshalJSON_DropsAnthropicOnlyToolTypes(t *testing.T) {
req := &OpenAIResponsesRequest{
Model: "gpt-4o",
Input: OpenAIResponsesRequestInput{
OpenAIResponsesRequestInputArray: []schemas.ResponsesMessage{
{
Role: schemas.Ptr(schemas.ResponsesInputMessageRoleUser),
Content: &schemas.ResponsesMessageContent{
ContentStr: schemas.Ptr("hello"),
},
},
},
},
ResponsesParameters: schemas.ResponsesParameters{
Tools: []schemas.ResponsesTool{
// Kept: function (OpenAI-native).
{
Type: schemas.ResponsesToolTypeFunction,
Name: schemas.Ptr("keeper_func"),
ResponsesToolFunction: &schemas.ResponsesToolFunction{},
},
// Dropped: web_fetch (Anthropic-only).
{
Type: schemas.ResponsesToolTypeWebFetch,
Name: schemas.Ptr("anthropic_webfetch"),
ResponsesToolWebFetch: &schemas.ResponsesToolWebFetch{},
},
// Kept: web_search (both support).
{
Type: schemas.ResponsesToolTypeWebSearch,
ResponsesToolWebSearch: &schemas.ResponsesToolWebSearch{},
},
// Dropped: memory (Anthropic-only).
{
Type: schemas.ResponsesToolTypeMemory,
Name: schemas.Ptr("anthropic_memory"),
},
// Kept: tool_search (both support per OpenAI OpenAPI spec).
{
Type: schemas.ResponsesToolTypeToolSearch,
},
},
},
}
jsonBytes, err := req.MarshalJSON()
if err != nil {
t.Fatalf("marshal failed: %v", err)
}
raw := string(jsonBytes)
// Dropped types must not appear on the wire.
for _, dropped := range []string{`"web_fetch"`, `"memory"`, `"anthropic_webfetch"`, `"anthropic_memory"`} {
if strings.Contains(raw, dropped) {
t.Errorf("Anthropic-only tool must be dropped; found %s in raw=%s", dropped, raw)
}
}
// Kept types must still appear.
for _, kept := range []string{`"function"`, `"web_search"`, `"tool_search"`, `"keeper_func"`} {
if !strings.Contains(raw, kept) {
t.Errorf("supported tool %s should be preserved; raw=%s", kept, raw)
}
}
// Confirm the tools array is present and has exactly 3 entries (2 dropped of 5).
var decoded struct {
Tools []map[string]interface{} `json:"tools"`
}
if err := json.Unmarshal(jsonBytes, &decoded); err != nil {
t.Fatalf("decode failed: %v", err)
}
if len(decoded.Tools) != 3 {
t.Errorf("expected 3 tools after drop (function, web_search, tool_search), got %d; tools=%+v", len(decoded.Tools), decoded.Tools)
}
}
// TestOpenAIResponsesRequest_MarshalJSON_KeepsAllWhenAllSupported verifies the
// no-reshape fast path: if every tool is OpenAI-compatible with no
// Anthropic-only flags, the tools slice passes through unchanged (no copy,
// no drop).
func TestOpenAIResponsesRequest_MarshalJSON_KeepsAllWhenAllSupported(t *testing.T) {
req := &OpenAIResponsesRequest{
Model: "gpt-4o",
Input: OpenAIResponsesRequestInput{
OpenAIResponsesRequestInputArray: []schemas.ResponsesMessage{
{
Role: schemas.Ptr(schemas.ResponsesInputMessageRoleUser),
Content: &schemas.ResponsesMessageContent{ContentStr: schemas.Ptr("hi")},
},
},
},
ResponsesParameters: schemas.ResponsesParameters{
Tools: []schemas.ResponsesTool{
{Type: schemas.ResponsesToolTypeFunction, Name: schemas.Ptr("f"), ResponsesToolFunction: &schemas.ResponsesToolFunction{}},
{Type: schemas.ResponsesToolTypeWebSearch, ResponsesToolWebSearch: &schemas.ResponsesToolWebSearch{}},
{Type: schemas.ResponsesToolTypeCodeInterpreter, ResponsesToolCodeInterpreter: &schemas.ResponsesToolCodeInterpreter{}},
},
},
}
jsonBytes, err := req.MarshalJSON()
if err != nil {
t.Fatalf("marshal failed: %v", err)
}
var decoded struct {
Tools []map[string]interface{} `json:"tools"`
}
if err := json.Unmarshal(jsonBytes, &decoded); err != nil {
t.Fatalf("decode failed: %v", err)
}
if len(decoded.Tools) != 3 {
t.Errorf("expected 3 tools preserved, got %d", len(decoded.Tools))
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,43 @@
package openai
import (
"github.com/maximhq/bifrost/core/providers/utils"
"github.com/maximhq/bifrost/core/schemas"
)
// ToBifrostSpeechRequest converts an OpenAI speech request to Bifrost format.
//
// The provider and model are obtained by parsing request.Model, with the
// fallback provider taken from utils.CheckAndSetDefaultProvider(ctx,
// schemas.OpenAI). The returned Params points at the receiver's embedded
// SpeechParameters (it is not a copy), and Fallbacks is parsed from
// request.Fallbacks.
func (request *OpenAISpeechRequest) ToBifrostSpeechRequest(ctx *schemas.BifrostContext) *schemas.BifrostSpeechRequest {
	defaultProvider := utils.CheckAndSetDefaultProvider(ctx, schemas.OpenAI)
	provider, model := schemas.ParseModelString(request.Model, defaultProvider)

	converted := &schemas.BifrostSpeechRequest{
		Provider: provider,
		Model:    model,
	}
	converted.Input = &schemas.SpeechInput{Input: request.Input}
	converted.Params = &request.SpeechParameters
	converted.Fallbacks = schemas.ParseFallbacks(request.Fallbacks)
	return converted
}
// ToOpenAISpeechRequest converts a Bifrost speech request to OpenAI format.
//
// It returns nil when the request is nil, when its Input is nil, or when the
// input text is empty. Input is a pointer, so it must be checked before the
// text is read; without that check a request lacking Input would panic.
//
// When Params is set, its value is copied into the OpenAI request and its
// ExtraParams map is also exposed on the request's own ExtraParams field
// (the map itself is shared, not cloned).
func ToOpenAISpeechRequest(bifrostReq *schemas.BifrostSpeechRequest) *OpenAISpeechRequest {
	if bifrostReq == nil || bifrostReq.Input == nil || bifrostReq.Input.Input == "" {
		return nil
	}

	openaiReq := &OpenAISpeechRequest{
		Model: bifrostReq.Model,
		Input: bifrostReq.Input.Input,
	}

	if params := bifrostReq.Params; params != nil {
		openaiReq.SpeechParameters = *params
		openaiReq.ExtraParams = params.ExtraParams
	}

	return openaiReq
}

View File

@@ -0,0 +1,75 @@
package openai
import (
"maps"
"github.com/maximhq/bifrost/core/providers/utils"
"github.com/maximhq/bifrost/core/schemas"
)
// ToOpenAITextCompletionRequest builds the OpenAI text completion payload for a
// Bifrost text completion request. A nil request yields nil.
//
// Parameters are copied by value; the user field is passed through
// SanitizeUserField, and any extra params are cloned so later edits to the
// outgoing request do not touch the caller's map. Requests addressed to the
// Fireworks provider additionally go through the Fireworks compatibility step.
func ToOpenAITextCompletionRequest(bifrostReq *schemas.BifrostTextCompletionRequest) *OpenAITextCompletionRequest {
	if bifrostReq == nil {
		return nil
	}

	out := &OpenAITextCompletionRequest{
		Model:  bifrostReq.Model,
		Prompt: bifrostReq.Input,
	}

	if p := bifrostReq.Params; p != nil {
		out.TextCompletionParameters = *p
		// OpenAI limits the user field to 64 characters; oversized values are dropped.
		out.TextCompletionParameters.User = SanitizeUserField(out.TextCompletionParameters.User)

		if p.ExtraParams != nil {
			out.ExtraParams = maps.Clone(p.ExtraParams)
			out.TextCompletionParameters.ExtraParams = out.ExtraParams
		}
	}

	if bifrostReq.Provider == schemas.Fireworks {
		out.applyFireworksTextCompletionCompatibility()
	}

	return out
}
// applyFireworksTextCompletionCompatibility rewrites Fireworks-specific text
// completion fields in place. It is a no-op for a nil request or a request
// without extra params.
//
// If PromptCacheIsolationKey is not already set, a non-empty
// "prompt_cache_key" extra param (string or *string) is promoted into it.
// The "prompt_cache_key" entry is always removed from ExtraParams afterwards,
// and the embedded parameters are pointed at the same map.
func (req *OpenAITextCompletionRequest) applyFireworksTextCompletionCompatibility() {
	if req == nil || req.ExtraParams == nil {
		return
	}

	// Fireworks uses prompt_cache_isolation_key for text-completion cache isolation.
	if req.PromptCacheIsolationKey == nil {
		// A missing key yields a nil interface, which matches neither case.
		switch key := req.ExtraParams["prompt_cache_key"].(type) {
		case string:
			if key != "" {
				req.PromptCacheIsolationKey = &key
			}
		case *string:
			if key != nil && *key != "" {
				req.PromptCacheIsolationKey = key
			}
		}
	}

	delete(req.ExtraParams, "prompt_cache_key")
	req.TextCompletionParameters.ExtraParams = req.ExtraParams
}
// ToBifrostTextCompletionRequest turns an OpenAI text completion request into
// its Bifrost counterpart. A nil receiver yields nil.
//
// Provider and model come from schemas.ParseModelString applied to the
// request's model string, with the default provider chosen by
// utils.CheckAndSetDefaultProvider(ctx, schemas.OpenAI). Params points at the
// receiver's TextCompletionParameters rather than a copy.
func (req *OpenAITextCompletionRequest) ToBifrostTextCompletionRequest(ctx *schemas.BifrostContext) *schemas.BifrostTextCompletionRequest {
	if req == nil {
		return nil
	}

	defaultProvider := utils.CheckAndSetDefaultProvider(ctx, schemas.OpenAI)
	provider, model := schemas.ParseModelString(req.Model, defaultProvider)

	bifrostReq := &schemas.BifrostTextCompletionRequest{
		Provider:  provider,
		Model:     model,
		Input:     req.Prompt,
		Params:    &req.TextCompletionParameters,
		Fallbacks: schemas.ParseFallbacks(req.Fallbacks),
	}
	return bifrostReq
}

View File

@@ -0,0 +1,73 @@
package openai
import (
"testing"
"github.com/bytedance/sonic"
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
"github.com/maximhq/bifrost/core/schemas"
)
// TestToOpenAITextCompletionRequest_FireworksUsesCacheIsolation checks that a
// Fireworks text completion request has its "prompt_cache_key" extra param
// moved to prompt_cache_isolation_key, both on the converted struct and in the
// serialized request body, while other extra params and the caller's original
// map are left intact.
func TestToOpenAITextCompletionRequest_FireworksUsesCacheIsolation(t *testing.T) {
	const wantKey = "cache-key-1"

	ctx, cancel := schemas.NewBifrostContextWithCancel(nil)
	defer cancel()
	ctx.SetValue(schemas.BifrostContextKeyPassthroughExtraParams, true)

	promptText := "A is for apple and B is for"
	bifrostReq := &schemas.BifrostTextCompletionRequest{
		Provider: schemas.Fireworks,
		Model:    "accounts/fireworks/models/deepseek-v3p2",
		Input:    &schemas.TextCompletionInput{PromptStr: &promptText},
		Params: &schemas.TextCompletionParameters{
			ExtraParams: map[string]interface{}{
				"prompt_cache_key": wantKey,
				"top_k":            4,
			},
		},
	}

	// Struct-level expectations.
	converted := ToOpenAITextCompletionRequest(bifrostReq)
	if converted == nil {
		t.Fatal("expected non-nil result")
	}
	if key := converted.PromptCacheIsolationKey; key == nil || *key != wantKey {
		t.Fatalf("expected prompt_cache_isolation_key %q, got %v", wantKey, key)
	}
	if _, present := converted.ExtraParams["prompt_cache_key"]; present {
		t.Fatalf("expected prompt_cache_key to be removed from extra params, got %#v", converted.ExtraParams)
	}
	if _, present := bifrostReq.Params.ExtraParams["prompt_cache_key"]; !present {
		t.Fatalf("expected original extra params to remain unchanged, got %#v", bifrostReq.Params.ExtraParams)
	}

	// Wire-level expectations: build the body the same way the provider would.
	buildBody := func() (providerUtils.RequestBodyWithExtraParams, error) {
		return ToOpenAITextCompletionRequest(bifrostReq), nil
	}
	wireBody, bifrostErr := providerUtils.CheckContextAndGetRequestBody(ctx, bifrostReq, buildBody)
	if bifrostErr != nil {
		t.Fatalf("failed to build request body: %v", bifrostErr.Error.Message)
	}

	var wire map[string]interface{}
	if err := sonic.Unmarshal(wireBody, &wire); err != nil {
		t.Fatalf("failed to parse marshaled request body: %v", err)
	}

	if got, ok := wire["prompt_cache_isolation_key"].(string); !ok || got != wantKey {
		t.Fatalf("expected prompt_cache_isolation_key %q in wire payload, got %#v", wantKey, wire["prompt_cache_isolation_key"])
	}
	if _, present := wire["prompt_cache_key"]; present {
		t.Fatalf("expected prompt_cache_key to be absent from wire payload, got %#v", wire["prompt_cache_key"])
	}
	// JSON numbers decode into float64 when the target is interface{}.
	if got, ok := wire["top_k"].(float64); !ok || got != 4 {
		t.Fatalf("expected top_k extra param to be preserved, got %#v", wire["top_k"])
	}
}

View File

@@ -0,0 +1,117 @@
package openai
import (
"fmt"
"mime/multipart"
"github.com/maximhq/bifrost/core/providers/utils"
"github.com/maximhq/bifrost/core/schemas"
)
// ToBifrostTranscriptionRequest converts an OpenAI transcription request to
// Bifrost format.
//
// The request's model string is passed through schemas.ParseModelString to
// obtain the provider and model, using the provider returned by
// utils.CheckAndSetDefaultProvider(ctx, schemas.OpenAI) as the default.
// The audio bytes are carried over as the transcription input, and Params
// points at the receiver's TranscriptionParameters (shared, not copied).
//
// A nil receiver yields nil, matching the other ToBifrost* converters in this
// package that guard against nil requests.
func (request *OpenAITranscriptionRequest) ToBifrostTranscriptionRequest(ctx *schemas.BifrostContext) *schemas.BifrostTranscriptionRequest {
	if request == nil {
		return nil
	}

	provider, model := schemas.ParseModelString(request.Model, utils.CheckAndSetDefaultProvider(ctx, schemas.OpenAI))

	return &schemas.BifrostTranscriptionRequest{
		Provider: provider,
		Model:    model,
		Input: &schemas.TranscriptionInput{
			File: request.File,
		},
		Params:    &request.TranscriptionParameters,
		Fallbacks: schemas.ParseFallbacks(request.Fallbacks),
	}
}
// ToOpenAITranscriptionRequest converts a Bifrost transcription request to
// OpenAI format.
//
// It returns nil when the request is nil, when its Input is nil, or when the
// input carries no file. Input is a pointer, so it must be checked before the
// file is read; without that check a request lacking Input would panic.
//
// When Params is set, its value is copied into the OpenAI request.
func ToOpenAITranscriptionRequest(bifrostReq *schemas.BifrostTranscriptionRequest) *OpenAITranscriptionRequest {
	if bifrostReq == nil || bifrostReq.Input == nil || bifrostReq.Input.File == nil {
		return nil
	}

	input := bifrostReq.Input
	openaiReq := &OpenAITranscriptionRequest{
		Model:    bifrostReq.Model,
		File:     input.File,
		Filename: input.Filename,
	}

	if params := bifrostReq.Params; params != nil {
		openaiReq.TranscriptionParameters = *params
	}

	return openaiReq
}
// ParseTranscriptionFormDataBodyFromRequest serializes openaiReq into writer as
// multipart form fields and then closes the writer.
//
// Fields are emitted in a fixed order: model first, then the optional
// language, prompt, response_format, temperature, timestamp_granularities[],
// include[] and stream fields, and the audio file last. The first write that
// fails is reported as a Bifrost operation error; nil is returned on success.
// providerName is accepted but not used by this function.
func ParseTranscriptionFormDataBodyFromRequest(writer *multipart.Writer, openaiReq *OpenAITranscriptionRequest, providerName schemas.ModelProvider) *schemas.BifrostError {
	// put writes one text field; label is the name used in the error message,
	// which differs from the wire name for the "[]"-suffixed array fields.
	put := func(name, value, label string) *schemas.BifrostError {
		if err := writer.WriteField(name, value); err != nil {
			return utils.NewBifrostOperationError("failed to write "+label+" field", err)
		}
		return nil
	}

	// Add model field before the file so upstreams can route without buffering the audio payload.
	if bifrostErr := put("model", openaiReq.Model, "model"); bifrostErr != nil {
		return bifrostErr
	}

	// Optional scalar fields.
	if language := openaiReq.Language; language != nil {
		if bifrostErr := put("language", *language, "language"); bifrostErr != nil {
			return bifrostErr
		}
	}
	if prompt := openaiReq.Prompt; prompt != nil {
		if bifrostErr := put("prompt", *prompt, "prompt"); bifrostErr != nil {
			return bifrostErr
		}
	}
	if format := openaiReq.ResponseFormat; format != nil {
		if bifrostErr := put("response_format", *format, "response_format"); bifrostErr != nil {
			return bifrostErr
		}
	}
	if temperature := openaiReq.Temperature; temperature != nil {
		if bifrostErr := put("temperature", fmt.Sprintf("%g", *temperature), "temperature"); bifrostErr != nil {
			return bifrostErr
		}
	}

	// Optional repeated fields, one form entry per element.
	for _, granularity := range openaiReq.TimestampGranularities {
		if bifrostErr := put("timestamp_granularities[]", granularity, "timestamp_granularities"); bifrostErr != nil {
			return bifrostErr
		}
	}
	for _, include := range openaiReq.Include {
		if bifrostErr := put("include[]", include, "include"); bifrostErr != nil {
			return bifrostErr
		}
	}

	// The stream flag is only sent when explicitly enabled.
	if openaiReq.Stream != nil && *openaiReq.Stream {
		if bifrostErr := put("stream", "true", "stream"); bifrostErr != nil {
			return bifrostErr
		}
	}

	// Add file field last so large multipart uploads don't block model discovery upstream.
	name := openaiReq.Filename
	if name == "" {
		name = utils.AudioFilenameFromBytes(openaiReq.File)
	}
	filePart, err := writer.CreateFormFile("file", name)
	if err != nil {
		return utils.NewBifrostOperationError("failed to create form file", err)
	}
	if _, err = filePart.Write(openaiReq.File); err != nil {
		return utils.NewBifrostOperationError("failed to write file data", err)
	}

	// Closing the writer emits the terminating boundary.
	if err = writer.Close(); err != nil {
		return utils.NewBifrostOperationError("failed to close multipart writer", err)
	}
	return nil
}

View File

@@ -0,0 +1,79 @@
package openai
import (
"bytes"
"io"
"mime"
"mime/multipart"
"testing"
"github.com/maximhq/bifrost/core/schemas"
)
// multipartPartOrder is a test helper that parses a multipart body and returns
// the form names of its parts in the order they appear. The boundary is taken
// from contentType. Any parse failure aborts the calling test via t.Fatalf.
func multipartPartOrder(t *testing.T, contentType string, body []byte) []string {
	t.Helper()

	_, mediaParams, parseErr := mime.ParseMediaType(contentType)
	if parseErr != nil {
		t.Fatalf("ParseMediaType(%q): %v", contentType, parseErr)
	}
	boundary := mediaParams["boundary"]
	if boundary == "" {
		t.Fatalf("missing boundary in %q", contentType)
	}

	var names []string
	mr := multipart.NewReader(bytes.NewReader(body), boundary)
	for part, err := mr.NextPart(); err != io.EOF; part, err = mr.NextPart() {
		if err != nil {
			t.Fatalf("NextPart(): %v", err)
		}
		names = append(names, part.FormName())
		// Drain and close the part; errors here are irrelevant to ordering.
		_, _ = io.Copy(io.Discard, part)
		_ = part.Close()
	}
	return names
}
// TestParseTranscriptionFormDataBodyFromRequest_OrdersMetadataBeforeFile writes
// a fully populated transcription request to a multipart body and checks that
// the "model" part comes first and the "file" part comes last.
func TestParseTranscriptionFormDataBodyFromRequest_OrdersMetadataBeforeFile(t *testing.T) {
	req := &OpenAITranscriptionRequest{
		Model:    "whisper-1",
		File:     []byte("audio-bytes"),
		Filename: "sample.mp3",
		TranscriptionParameters: schemas.TranscriptionParameters{
			Language:               schemas.Ptr("en"),
			Prompt:                 schemas.Ptr("transcribe this"),
			ResponseFormat:         schemas.Ptr("verbose_json"),
			Temperature:            schemas.Ptr(0.2),
			TimestampGranularities: []string{"word"},
			Include:                []string{"logprobs"},
		},
		Stream: schemas.Ptr(true),
	}

	var encoded bytes.Buffer
	mw := multipart.NewWriter(&encoded)
	if bifrostErr := ParseTranscriptionFormDataBodyFromRequest(mw, req, schemas.OpenAI); bifrostErr != nil {
		t.Fatalf("unexpected bifrost error: %v", bifrostErr.Error.Message)
	}

	order := multipartPartOrder(t, mw.FormDataContentType(), encoded.Bytes())
	switch {
	case len(order) == 0:
		t.Fatal("expected multipart parts to be written")
	case order[len(order)-1] != "file":
		t.Fatalf("expected file part last, got order %v", order)
	case order[0] != "model":
		t.Fatalf("expected model part first, got order %v", order)
	}
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,468 @@
package openai
import (
"testing"
"github.com/bytedance/sonic"
"github.com/maximhq/bifrost/core/schemas"
)
// TestOpenAIChatRequest_UnmarshalJSON_BaseFieldsPreserved is a table-driven test
// that unmarshals several chat request payloads with sonic and checks that the
// request's base fields (model, messages, stream, max_tokens, fallbacks, ...)
// and the chat parameter fields (temperature, top_p, ...) are all populated
// from the same JSON object.
func TestOpenAIChatRequest_UnmarshalJSON_BaseFieldsPreserved(t *testing.T) {
// Each case pairs a JSON payload with a validator run against the decoded request.
tests := []struct {
name string
jsonPayload string
validate func(t *testing.T, req *OpenAIChatRequest)
}{
{
name: "all base fields preserved with ChatParameters",
jsonPayload: `{
"model": "gpt-4o",
"messages": [
{
"role": "user",
"content": "Hello, world!"
}
],
"stream": true,
"max_tokens": 100,
"prompt_cache_isolation_key": "cache-key-1",
"fallbacks": ["gpt-3.5-turbo"],
"temperature": 0.7,
"top_p": 0.9
}`,
validate: func(t *testing.T, req *OpenAIChatRequest) {
// Assert base fields are preserved
if req.Model != "gpt-4o" {
t.Errorf("Expected Model to be 'gpt-4o', got %q", req.Model)
}
if len(req.Messages) != 1 {
t.Fatalf("Expected 1 message, got %d", len(req.Messages))
}
if req.Messages[0].Role != schemas.ChatMessageRoleUser {
t.Errorf("Expected message role to be 'user', got %q", req.Messages[0].Role)
}
if req.Messages[0].Content == nil || req.Messages[0].Content.ContentStr == nil {
t.Fatal("Expected message content to be set")
}
if *req.Messages[0].Content.ContentStr != "Hello, world!" {
t.Errorf("Expected message content to be 'Hello, world!', got %q", *req.Messages[0].Content.ContentStr)
}
if req.Stream == nil || !*req.Stream {
t.Error("Expected Stream to be true")
}
if req.MaxTokens == nil || *req.MaxTokens != 100 {
t.Errorf("Expected MaxTokens to be 100, got %v", req.MaxTokens)
}
if req.PromptCacheIsolationKey == nil || *req.PromptCacheIsolationKey != "cache-key-1" {
t.Errorf("Expected PromptCacheIsolationKey to be %q, got %v", "cache-key-1", req.PromptCacheIsolationKey)
}
if len(req.Fallbacks) != 1 || req.Fallbacks[0] != "gpt-3.5-turbo" {
t.Errorf("Expected Fallbacks to be ['gpt-3.5-turbo'], got %v", req.Fallbacks)
}
// Assert ChatParameters fields are populated
if req.Temperature == nil || *req.Temperature != 0.7 {
t.Errorf("Expected Temperature to be 0.7, got %v", req.Temperature)
}
if req.TopP == nil || *req.TopP != 0.9 {
t.Errorf("Expected TopP to be 0.9, got %v", req.TopP)
}
},
},
{
name: "base fields with multiple ChatParameters fields",
jsonPayload: `{
"model": "gpt-3.5-turbo",
"messages": [
{
"role": "system",
"content": "You are a helpful assistant."
},
{
"role": "user",
"content": "What is 2+2?"
}
],
"stream": false,
"max_tokens": 500,
"fallbacks": ["gpt-4o", "gpt-4"],
"temperature": 0.5,
"top_p": 0.95,
"frequency_penalty": 0.2,
"presence_penalty": 0.3,
"seed": 42,
"stop": ["STOP", "END"]
}`,
validate: func(t *testing.T, req *OpenAIChatRequest) {
// Assert base fields
if req.Model != "gpt-3.5-turbo" {
t.Errorf("Expected Model to be 'gpt-3.5-turbo', got %q", req.Model)
}
if len(req.Messages) != 2 {
t.Fatalf("Expected 2 messages, got %d", len(req.Messages))
}
// An explicit "stream": false must decode to a non-nil pointer to false.
if req.Stream == nil || *req.Stream {
t.Error("Expected Stream to be false")
}
if req.MaxTokens == nil || *req.MaxTokens != 500 {
t.Errorf("Expected MaxTokens to be 500, got %v", req.MaxTokens)
}
if len(req.Fallbacks) != 2 {
t.Errorf("Expected 2 fallbacks, got %d", len(req.Fallbacks))
}
// Assert multiple ChatParameters fields
if req.Temperature == nil || *req.Temperature != 0.5 {
t.Errorf("Expected Temperature to be 0.5, got %v", req.Temperature)
}
if req.TopP == nil || *req.TopP != 0.95 {
t.Errorf("Expected TopP to be 0.95, got %v", req.TopP)
}
if req.FrequencyPenalty == nil || *req.FrequencyPenalty != 0.2 {
t.Errorf("Expected FrequencyPenalty to be 0.2, got %v", req.FrequencyPenalty)
}
if req.PresencePenalty == nil || *req.PresencePenalty != 0.3 {
t.Errorf("Expected PresencePenalty to be 0.3, got %v", req.PresencePenalty)
}
if req.Seed == nil || *req.Seed != 42 {
t.Errorf("Expected Seed to be 42, got %v", req.Seed)
}
if len(req.Stop) != 2 {
t.Errorf("Expected Stop to have 2 elements, got %d", len(req.Stop))
}
},
},
{
name: "base fields with optional fields omitted",
jsonPayload: `{
"model": "gpt-4",
"messages": [
{
"role": "user",
"content": "Test"
}
],
"temperature": 1.0,
"top_p": 1.0
}`,
validate: func(t *testing.T, req *OpenAIChatRequest) {
if req.Model != "gpt-4" {
t.Errorf("Expected Model to be 'gpt-4', got %q", req.Model)
}
if len(req.Messages) != 1 {
t.Fatalf("Expected 1 message, got %d", len(req.Messages))
}
// Optional fields should be nil/empty when omitted
if req.Stream != nil {
t.Error("Expected Stream to be nil when omitted")
}
if req.MaxTokens != nil {
t.Error("Expected MaxTokens to be nil when omitted")
}
if len(req.Fallbacks) != 0 {
t.Errorf("Expected Fallbacks to be empty when omitted, got %v", req.Fallbacks)
}
// ChatParameters fields should still be populated
if req.Temperature == nil || *req.Temperature != 1.0 {
t.Errorf("Expected Temperature to be 1.0, got %v", req.Temperature)
}
if req.TopP == nil || *req.TopP != 1.0 {
t.Errorf("Expected TopP to be 1.0, got %v", req.TopP)
}
},
},
}
// Run every case as a subtest: decode the payload, then hand the result to the validator.
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var req OpenAIChatRequest
if err := sonic.Unmarshal([]byte(tt.jsonPayload), &req); err != nil {
t.Fatalf("Failed to unmarshal JSON: %v", err)
}
tt.validate(t, &req)
})
}
}
// TestOpenAIChatRequest_UnmarshalJSON_ChatParametersCustomLogic is a table-driven
// test for the reasoning-related handling exercised when decoding a chat
// request: a top-level "reasoning_effort" is expected to surface as
// Reasoning.Effort, and supplying both "reasoning" and "reasoning_effort" is
// expected to make unmarshaling fail.
func TestOpenAIChatRequest_UnmarshalJSON_ChatParametersCustomLogic(t *testing.T) {
// expectError marks cases where sonic.Unmarshal must return an error; their
// validate function is never invoked.
tests := []struct {
name string
jsonPayload string
validate func(t *testing.T, req *OpenAIChatRequest)
expectError bool
}{
{
name: "reasoning_effort converted to Reasoning.Effort",
jsonPayload: `{
"model": "gpt-4o",
"messages": [
{
"role": "user",
"content": "Think step by step"
}
],
"reasoning_effort": "high",
"temperature": 0.8
}`,
validate: func(t *testing.T, req *OpenAIChatRequest) {
// Assert base fields are preserved
if req.Model != "gpt-4o" {
t.Errorf("Expected Model to be 'gpt-4o', got %q", req.Model)
}
// Assert reasoning_effort was converted to Reasoning.Effort
if req.Reasoning == nil {
t.Fatal("Expected Reasoning to be set from reasoning_effort")
}
if req.Reasoning.Effort == nil {
t.Fatal("Expected Reasoning.Effort to be set")
}
if *req.Reasoning.Effort != "high" {
t.Errorf("Expected Reasoning.Effort to be 'high', got %q", *req.Reasoning.Effort)
}
// Assert other ChatParameters fields are still populated
if req.Temperature == nil || *req.Temperature != 0.8 {
t.Errorf("Expected Temperature to be 0.8, got %v", req.Temperature)
}
},
expectError: false,
},
{
name: "both reasoning and reasoning_effort should error",
jsonPayload: `{
"model": "gpt-4o",
"messages": [
{
"role": "user",
"content": "Test"
}
],
"reasoning": {
"effort": "medium"
},
"reasoning_effort": "high"
}`,
validate: func(t *testing.T, req *OpenAIChatRequest) {
// This should have failed during unmarshaling
},
expectError: true,
},
{
name: "reasoning_effort with multiple ChatParameters fields",
jsonPayload: `{
"model": "gpt-4o",
"messages": [
{
"role": "user",
"content": "Analyze this"
}
],
"reasoning_effort": "medium",
"temperature": 0.6,
"top_p": 0.85,
"max_completion_tokens": 2000
}`,
validate: func(t *testing.T, req *OpenAIChatRequest) {
// Assert base fields
if req.Model != "gpt-4o" {
t.Errorf("Expected Model to be 'gpt-4o', got %q", req.Model)
}
// Assert reasoning_effort conversion
if req.Reasoning == nil || req.Reasoning.Effort == nil {
t.Fatal("Expected Reasoning.Effort to be set from reasoning_effort")
}
if *req.Reasoning.Effort != "medium" {
t.Errorf("Expected Reasoning.Effort to be 'medium', got %q", *req.Reasoning.Effort)
}
// Assert other ChatParameters fields
if req.Temperature == nil || *req.Temperature != 0.6 {
t.Errorf("Expected Temperature to be 0.6, got %v", req.Temperature)
}
if req.TopP == nil || *req.TopP != 0.85 {
t.Errorf("Expected TopP to be 0.85, got %v", req.TopP)
}
if req.MaxCompletionTokens == nil || *req.MaxCompletionTokens != 2000 {
t.Errorf("Expected MaxCompletionTokens to be 2000, got %v", req.MaxCompletionTokens)
}
},
expectError: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var req OpenAIChatRequest
err := sonic.Unmarshal([]byte(tt.jsonPayload), &req)
// Error cases stop here: only the presence of an error is checked.
if tt.expectError {
if err == nil {
t.Error("Expected error but got none")
}
return
}
if err != nil {
t.Fatalf("Unexpected error during unmarshaling: %v", err)
}
tt.validate(t, &req)
})
}
}
// TestOpenAIChatRequest_UnmarshalJSON_PresenceAssertions decodes a single chat
// request payload and checks only that each supplied field ends up set
// (non-empty or non-nil) on the decoded request, without comparing values.
func TestOpenAIChatRequest_UnmarshalJSON_PresenceAssertions(t *testing.T) {
// Test that verifies presence of fields (not just values)
jsonPayload := `{
"model": "gpt-4o-mini",
"messages": [
{
"role": "assistant",
"content": "Hello!"
}
],
"stream": false,
"max_tokens": 150,
"fallbacks": ["model1", "model2"],
"temperature": 0.3,
"top_p": 0.7,
"user": "test-user-123"
}`
var req OpenAIChatRequest
if err := sonic.Unmarshal([]byte(jsonPayload), &req); err != nil {
t.Fatalf("Failed to unmarshal JSON: %v", err)
}
// Presence assertions for base fields
if req.Model == "" {
t.Error("Model field should be present")
}
if len(req.Messages) == 0 {
t.Error("Messages field should be present and non-empty")
}
// "stream": false was sent explicitly, so the pointer must be non-nil.
if req.Stream == nil {
t.Error("Stream field should be present (even if false)")
}
if req.MaxTokens == nil {
t.Error("MaxTokens field should be present")
}
if len(req.Fallbacks) == 0 {
t.Error("Fallbacks field should be present and non-empty")
}
// Presence assertions for ChatParameters fields
if req.Temperature == nil {
t.Error("Temperature field should be present")
}
if req.TopP == nil {
t.Error("TopP field should be present")
}
if req.User == nil {
t.Error("User field should be present")
}
}
// TestOpenAIChatRequest_UnmarshalJSON_ValueAssertions decodes a single chat
// request payload and compares every supplied field on the decoded request
// against the exact value present in the JSON.
func TestOpenAIChatRequest_UnmarshalJSON_ValueAssertions(t *testing.T) {
// Test that verifies exact values match expectations
jsonPayload := `{
"model": "gpt-4-turbo",
"messages": [
{
"role": "system",
"content": "System message"
},
{
"role": "user",
"content": "User message"
}
],
"stream": true,
"max_tokens": 250,
"fallbacks": ["fallback1"],
"temperature": 0.9,
"top_p": 0.95,
"seed": 12345,
"stop": ["END", "STOP"]
}`
var req OpenAIChatRequest
if err := sonic.Unmarshal([]byte(jsonPayload), &req); err != nil {
t.Fatalf("Failed to unmarshal JSON: %v", err)
}
// Value assertions for base fields
if req.Model != "gpt-4-turbo" {
t.Errorf("Expected Model value 'gpt-4-turbo', got %q", req.Model)
}
// Fatalf here guards the indexed message accesses below.
if len(req.Messages) != 2 {
t.Fatalf("Expected 2 messages, got %d", len(req.Messages))
}
if req.Messages[0].Role != schemas.ChatMessageRoleSystem {
t.Errorf("Expected first message role 'system', got %q", req.Messages[0].Role)
}
if req.Messages[1].Role != schemas.ChatMessageRoleUser {
t.Errorf("Expected second message role 'user', got %q", req.Messages[1].Role)
}
if req.Stream == nil || !*req.Stream {
t.Error("Expected Stream value to be true")
}
if req.MaxTokens == nil || *req.MaxTokens != 250 {
t.Errorf("Expected MaxTokens value 250, got %v", req.MaxTokens)
}
if len(req.Fallbacks) != 1 || req.Fallbacks[0] != "fallback1" {
t.Errorf("Expected Fallbacks value ['fallback1'], got %v", req.Fallbacks)
}
// Value assertions for ChatParameters fields
if req.Temperature == nil || *req.Temperature != 0.9 {
t.Errorf("Expected Temperature value 0.9, got %v", req.Temperature)
}
if req.TopP == nil || *req.TopP != 0.95 {
t.Errorf("Expected TopP value 0.95, got %v", req.TopP)
}
if req.Seed == nil || *req.Seed != 12345 {
t.Errorf("Expected Seed value 12345, got %v", req.Seed)
}
if len(req.Stop) != 2 || req.Stop[0] != "END" || req.Stop[1] != "STOP" {
t.Errorf("Expected Stop value ['END', 'STOP'], got %v", req.Stop)
}
}

View File

@@ -0,0 +1,108 @@
package openai
import (
"strings"
"github.com/maximhq/bifrost/core/schemas"
)
// responseHandler is the signature of a callback that processes a provider response
// for a given request. It receives the raw response body, a pointer to the typed
// Bifrost response (presumably filled in by the handler — confirm against callers),
// the raw request body, and two flags saying whether the raw request and raw
// response should be sent back. It returns the raw request, the raw response, and
// a Bifrost error.
// T is the concrete Bifrost response type (e.g. BifrostEmbeddingResponse, BifrostTextCompletionResponse, BifrostChatResponse, BifrostResponsesResponse, BifrostImageGenerationResponse, BifrostTranscriptionResponse).
type responseHandler[T any] func(responseBody []byte, response *T, requestBody []byte, sendBackRawRequest bool, sendBackRawResponse bool) (rawRequest interface{}, rawResponse interface{}, bifrostErr *schemas.BifrostError)
// ConvertOpenAIMessagesToBifrostMessages maps each OpenAI message onto a
// Bifrost chat message, preserving order. Name, role, content and the tool
// message are carried over as-is; when an assistant section is present, its
// refusal, reasoning, annotations and tool calls are copied into a new Bifrost
// assistant message. The result is always a non-nil slice of the same length.
func ConvertOpenAIMessagesToBifrostMessages(messages []OpenAIMessage) []schemas.ChatMessage {
	converted := make([]schemas.ChatMessage, 0, len(messages))
	for idx := range messages {
		src := &messages[idx]

		dst := schemas.ChatMessage{
			Name:            src.Name,
			Role:            src.Role,
			Content:         src.Content,
			ChatToolMessage: src.ChatToolMessage,
		}
		if assistant := src.OpenAIChatAssistantMessage; assistant != nil {
			dst.ChatAssistantMessage = &schemas.ChatAssistantMessage{
				Refusal:     assistant.Refusal,
				Reasoning:   assistant.Reasoning,
				Annotations: assistant.Annotations,
				ToolCalls:   assistant.ToolCalls,
			}
		}

		converted = append(converted, dst)
	}
	return converted
}
// ConvertBifrostMessagesToOpenAIMessages maps each Bifrost chat message onto an
// OpenAI message, preserving order. Name, role, content and the tool message
// are carried over as-is; when an assistant section is present, its refusal,
// reasoning, annotations and tool calls are copied into a new OpenAI assistant
// message. The result is always a non-nil slice of the same length.
func ConvertBifrostMessagesToOpenAIMessages(messages []schemas.ChatMessage) []OpenAIMessage {
	converted := make([]OpenAIMessage, 0, len(messages))
	for idx := range messages {
		src := &messages[idx]

		dst := OpenAIMessage{
			Name:            src.Name,
			Role:            src.Role,
			Content:         src.Content,
			ChatToolMessage: src.ChatToolMessage,
		}
		if assistant := src.ChatAssistantMessage; assistant != nil {
			dst.OpenAIChatAssistantMessage = &OpenAIChatAssistantMessage{
				Refusal:     assistant.Refusal,
				Reasoning:   assistant.Reasoning,
				Annotations: assistant.Annotations,
				ToolCalls:   assistant.ToolCalls,
			}
		}

		converted = append(converted, dst)
	}
	return converted
}
// isOpenAIReasoningModel reports whether the given model name looks like an
// OpenAI reasoning model, i.e. one that supports the reasoning.effort parameter.
//
// The name is first run through schemas.ParseModelString; a non-empty parsed
// model replaces the input. Matching is case-insensitive and name-based:
//   - any name containing "gpt-oss";
//   - "o1", "o3" or "o4" as the whole name, as a prefix followed by '-' or '_',
//     as a "-oN-" / "_oN_" segment, or as a "-oN" / "_oN" suffix;
//   - any name starting with "gpt-5".
//
// Note: -pro and -codex variants (e.g. gpt-5.2-pro, gpt-5.2-codex) are always-reasoning
// models that do NOT support effort "none" — callers must handle top_p stripping separately.
// TODO we need to find a better way to check if a model is an OpenAI reasoning model
func isOpenAIReasoningModel(model string) bool {
	if _, parsed := schemas.ParseModelString(model, schemas.OpenAI); parsed != "" {
		model = parsed
	}
	name := strings.ToLower(model)

	// gpt-oss models support reasoning, as does the GPT-5 family.
	if strings.Contains(name, "gpt-oss") || strings.HasPrefix(name, "gpt-5") {
		return true
	}

	for _, series := range []string{"o1", "o3", "o4"} {
		// Series at the start: exact match or followed by a separator.
		if strings.HasPrefix(name, series) {
			rest := name[len(series):]
			if rest == "" || rest[0] == '-' || rest[0] == '_' {
				return true
			}
		}

		// Series embedded later in the name, e.g. "openai-o1-mini".
		switch {
		case strings.Contains(name, "-"+series+"-"),
			strings.Contains(name, "_"+series+"_"),
			strings.HasSuffix(name, "-"+series),
			strings.HasSuffix(name, "_"+series):
			return true
		}
	}

	return false
}
// MaxUserFieldLength is the maximum length OpenAI accepts for the user field.
const MaxUserFieldLength = 64

// SanitizeUserField drops a user value that is too long for OpenAI.
// It returns nil when user is nil or when len(*user) exceeds
// MaxUserFieldLength; otherwise the same pointer is returned unchanged.
func SanitizeUserField(user *string) *string {
	switch {
	case user == nil:
		return nil
	case len(*user) > MaxUserFieldLength:
		return nil
	default:
		return user
	}
}

View File

@@ -0,0 +1,212 @@
package openai
import (
"encoding/base64"
"fmt"
"mime/multipart"
"net/http"
"github.com/maximhq/bifrost/core/providers/utils"
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
"github.com/maximhq/bifrost/core/schemas"
)
// ToOpenAIVideoGenerationRequest builds the OpenAI video generation payload for
// a Bifrost video generation request.
//
// An error is returned when the request or its input is nil, when the prompt
// is empty, or when an input reference is supplied that fails sanitization,
// is not a base64 data URL, or cannot be base64-decoded.
//
// When Params is set: Seconds is carried over if present, Size is taken from
// the params only if it is non-empty and listed in ValidOpenAIVideoSizes
// (otherwise DefaultOpenAIVideoSize is used), and ExtraParams is passed along.
// When Params is nil none of these fields are touched.
func ToOpenAIVideoGenerationRequest(bifrostReq *schemas.BifrostVideoGenerationRequest) (*OpenAIVideoGenerationRequest, error) {
	if bifrostReq == nil || bifrostReq.Input == nil || bifrostReq.Input.Prompt == "" {
		return nil, fmt.Errorf("bifrost request, input, or prompt is nil/empty")
	}

	out := &OpenAIVideoGenerationRequest{
		Model:  bifrostReq.Model,
		Prompt: bifrostReq.Input.Prompt,
	}

	// The input reference must be a base64 data URL; it is decoded to raw bytes.
	if ref := bifrostReq.Input.InputReference; ref != nil {
		sanitized, err := schemas.SanitizeImageURL(*ref)
		if err != nil {
			return nil, fmt.Errorf("invalid input reference: %w", err)
		}

		info := schemas.ExtractURLTypeInfo(sanitized)
		if info.DataURLWithoutPrefix == nil {
			return nil, fmt.Errorf("input_reference must be a base64 data URL (e.g. data:image/png;base64,...)")
		}

		decoded, err := base64.StdEncoding.DecodeString(*info.DataURLWithoutPrefix)
		if err != nil {
			return nil, fmt.Errorf("failed to decode base64 input reference: %w", err)
		}
		out.InputReference = decoded
	}

	if p := bifrostReq.Params; p != nil {
		if p.Seconds != nil {
			out.Seconds = p.Seconds
		}

		// Start from the default size and keep the requested one only if it is
		// both provided and recognized.
		out.Size = string(DefaultOpenAIVideoSize)
		if p.Size != "" && ValidOpenAIVideoSizes[p.Size] {
			out.Size = p.Size
		}

		out.ExtraParams = p.ExtraParams
	}

	return out, nil
}
// ToOpenAIVideoRemixRequest builds the OpenAI video remix payload for a Bifrost
// remix request. Only the prompt is carried over. An error is returned when
// the request or its input is nil, or when the prompt is empty.
func ToOpenAIVideoRemixRequest(bifrostReq *schemas.BifrostVideoRemixRequest) (*OpenAIVideoRemixRequest, error) {
	switch {
	case bifrostReq == nil, bifrostReq.Input == nil, bifrostReq.Input.Prompt == "":
		return nil, fmt.Errorf("bifrost request, input, or prompt is nil/empty")
	}

	return &OpenAIVideoRemixRequest{Prompt: bifrostReq.Input.Prompt}, nil
}
// ToBifrostVideoRemixRequest turns an OpenAI video remix request into its
// Bifrost counterpart. It returns nil when the request is nil or has an empty
// prompt. The ID and prompt are carried over; an empty provider falls back to
// schemas.OpenAI.
func ToBifrostVideoRemixRequest(openaiReq *OpenAIVideoRemixRequest) *schemas.BifrostVideoRemixRequest {
	if openaiReq == nil || openaiReq.Prompt == "" {
		return nil
	}

	bifrostReq := &schemas.BifrostVideoRemixRequest{
		ID:       openaiReq.ID,
		Provider: openaiReq.Provider,
		Input:    &schemas.VideoGenerationInput{Prompt: openaiReq.Prompt},
	}
	if bifrostReq.Provider == "" {
		bifrostReq.Provider = schemas.OpenAI
	}
	return bifrostReq
}
// ToBifrostVideoGenerationRequest turns an OpenAI video generation request into
// its Bifrost counterpart. A nil receiver yields nil.
//
// The default provider is OpenAI, switched to Azure when the context flags the
// caller as an Azure user agent; that default is then passed through
// utils.CheckAndSetDefaultProvider and used when parsing the model string.
// A non-nil input reference is re-encoded as a base64 data URL. Params points
// at the receiver's VideoGenerationParameters rather than a copy.
func (req *OpenAIVideoGenerationRequest) ToBifrostVideoGenerationRequest(ctx *schemas.BifrostContext) *schemas.BifrostVideoGenerationRequest {
	if req == nil {
		return nil
	}

	// for requests coming from azure sdk without provider prefix, we need to set the default provider to azure
	fallbackProvider := schemas.OpenAI
	if ctx != nil {
		isAzure, ok := ctx.Value(schemas.BifrostContextKeyIsAzureUserAgent).(bool)
		if ok && isAzure {
			fallbackProvider = schemas.Azure
		}
	}
	provider, model := schemas.ParseModelString(req.Model, utils.CheckAndSetDefaultProvider(ctx, fallbackProvider))

	videoInput := &schemas.VideoGenerationInput{Prompt: req.Prompt}
	if req.InputReference != nil {
		videoInput.InputReference = schemas.Ptr(providerUtils.FileBytesToBase64DataURL(req.InputReference))
	}

	return &schemas.BifrostVideoGenerationRequest{
		Provider:  provider,
		Model:     model,
		Input:     videoInput,
		Params:    &req.VideoGenerationParameters,
		Fallbacks: schemas.ParseFallbacks(req.Fallbacks),
	}
}
// parseVideoGenerationFormDataBodyFromRequest serializes openaiReq into writer
// as multipart form fields and then closes the writer.
//
// The prompt is required and written first, followed by the optional model,
// seconds and size fields. A non-empty input reference is attached as a file
// part named "input_reference" whose Content-Type is sniffed from the bytes;
// anything other than JPEG, PNG, WebP or MP4 is sent as image/png. The first
// failure is reported as a Bifrost operation error; nil is returned on
// success. providerName is accepted but not used by this function.
func parseVideoGenerationFormDataBodyFromRequest(writer *multipart.Writer, openaiReq *OpenAIVideoGenerationRequest, providerName schemas.ModelProvider) *schemas.BifrostError {
	if openaiReq.Prompt == "" {
		return providerUtils.NewBifrostOperationError("prompt is required", nil)
	}

	// Collect the text fields in wire order, skipping unset optional ones.
	type formField struct{ name, value string }
	fields := []formField{{"prompt", openaiReq.Prompt}}
	if openaiReq.Model != "" {
		fields = append(fields, formField{"model", openaiReq.Model})
	}
	if openaiReq.Seconds != nil {
		fields = append(fields, formField{"seconds", *openaiReq.Seconds})
	}
	if openaiReq.Size != "" {
		fields = append(fields, formField{"size", openaiReq.Size})
	}
	for _, field := range fields {
		if err := writer.WriteField(field.name, field.value); err != nil {
			return providerUtils.NewBifrostOperationError("failed to write "+field.name+" field", err)
		}
	}

	// Optional input_reference file (image or video).
	if len(openaiReq.InputReference) > 0 {
		mimeType := http.DetectContentType(openaiReq.InputReference)

		// Pick the filename for the accepted types; everything else is
		// treated as PNG.
		filename := "input_reference.png"
		switch mimeType {
		case "image/jpeg":
			filename = "input_reference.jpg"
		case "image/webp":
			filename = "input_reference.webp"
		case "video/mp4":
			filename = "input_reference.mp4"
		case "image/png":
			// already the default filename
		default:
			mimeType = "image/png"
		}

		// CreatePart is used instead of CreateFormFile so the Content-Type
		// header can be set explicitly.
		filePart, err := writer.CreatePart(map[string][]string{
			"Content-Disposition": {fmt.Sprintf(`form-data; name="input_reference"; filename="%s"`, filename)},
			"Content-Type":        {mimeType},
		})
		if err != nil {
			return providerUtils.NewBifrostOperationError("failed to create form part for input_reference", err)
		}
		if _, err = filePart.Write(openaiReq.InputReference); err != nil {
			return providerUtils.NewBifrostOperationError("failed to write input_reference file data", err)
		}
	}

	// Closing the writer emits the terminating boundary.
	if err := writer.Close(); err != nil {
		return providerUtils.NewBifrostOperationError("failed to close multipart writer", err)
	}
	return nil
}

View File

@@ -0,0 +1,35 @@
package openai
import (
"strings"
"github.com/maximhq/bifrost/core/schemas"
)
// SupportsWebSocketMode reports whether this provider supports the Responses API
// WebSocket Mode. It always returns true for the OpenAI provider, which supports
// that mode natively.
func (provider *OpenAIProvider) SupportsWebSocketMode() bool {
return true
}
// WebSocketResponsesURL returns the WebSocket URL for the OpenAI Responses API.
// It converts the configured HTTP base URL to its WebSocket equivalent and
// appends the Responses path:
// https://api.openai.com -> wss://api.openai.com/v1/responses
//
// Only a scheme at the very start of the base URL is rewritten ("https://" to
// "wss://", "http://" to "ws://"). An "http://" or "https://" that appears
// later in the URL, for example inside a proxy path, is left untouched. A base
// URL with any other prefix is used as-is. The key parameter is not used.
func (provider *OpenAIProvider) WebSocketResponsesURL(key schemas.Key) string {
	base := provider.networkConfig.BaseURL

	switch {
	case strings.HasPrefix(base, "https://"):
		base = "wss://" + strings.TrimPrefix(base, "https://")
	case strings.HasPrefix(base, "http://"):
		base = "ws://" + strings.TrimPrefix(base, "http://")
	}

	return base + "/v1/responses"
}
// WebSocketHeaders builds the header set for the upstream WebSocket connection
// to OpenAI. The Authorization header is always "Bearer " followed by the key's
// value; configured extra headers are added on top, except any whose name
// equals "Authorization" case-insensitively, which are skipped.
func (provider *OpenAIProvider) WebSocketHeaders(key schemas.Key) map[string]string {
	extra := provider.networkConfig.ExtraHeaders

	result := make(map[string]string, len(extra)+1)
	result["Authorization"] = "Bearer " + key.Value.GetValue()

	for name, value := range extra {
		if !strings.EqualFold(name, "Authorization") {
			result[name] = value
		}
	}
	return result
}