first commit
This commit is contained in:
172
core/providers/openai/batch.go
Normal file
172
core/providers/openai/batch.go
Normal file
@@ -0,0 +1,172 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// OpenAI Batch API Types
|
||||
|
||||
// OpenAIBatchRequest represents the request body for creating a batch.
|
||||
type OpenAIBatchRequest struct {
|
||||
InputFileID string `json:"input_file_id"`
|
||||
Endpoint string `json:"endpoint"`
|
||||
CompletionWindow string `json:"completion_window"`
|
||||
Metadata map[string]string `json:"metadata,omitempty"`
|
||||
OutputExpiresAfter *schemas.BatchExpiresAfter `json:"output_expires_after,omitempty"`
|
||||
}
|
||||
|
||||
// OpenAIBatchResponse represents an OpenAI batch response.
|
||||
type OpenAIBatchResponse struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Endpoint string `json:"endpoint"`
|
||||
Errors *schemas.BatchErrors `json:"errors,omitempty"`
|
||||
InputFileID string `json:"input_file_id"`
|
||||
CompletionWindow string `json:"completion_window"`
|
||||
Status string `json:"status"`
|
||||
OutputFileID *string `json:"output_file_id,omitempty"`
|
||||
ErrorFileID *string `json:"error_file_id,omitempty"`
|
||||
CreatedAt int64 `json:"created_at"`
|
||||
InProgressAt *int64 `json:"in_progress_at,omitempty"`
|
||||
ExpiresAt *int64 `json:"expires_at,omitempty"`
|
||||
FinalizingAt *int64 `json:"finalizing_at,omitempty"`
|
||||
CompletedAt *int64 `json:"completed_at,omitempty"`
|
||||
FailedAt *int64 `json:"failed_at,omitempty"`
|
||||
ExpiredAt *int64 `json:"expired_at,omitempty"`
|
||||
CancellingAt *int64 `json:"cancelling_at,omitempty"`
|
||||
CancelledAt *int64 `json:"cancelled_at,omitempty"`
|
||||
RequestCounts *OpenAIBatchRequestCounts `json:"request_counts,omitempty"`
|
||||
Metadata map[string]string `json:"metadata,omitempty"`
|
||||
}
|
||||
|
||||
// OpenAIBatchRequestCounts represents the request counts for a batch.
|
||||
type OpenAIBatchRequestCounts struct {
|
||||
Total int `json:"total"`
|
||||
Completed int `json:"completed"`
|
||||
Failed int `json:"failed"`
|
||||
}
|
||||
|
||||
// OpenAIBatchListResponse represents the response from listing batches.
|
||||
type OpenAIBatchListResponse struct {
|
||||
Object string `json:"object"`
|
||||
Data []OpenAIBatchResponse `json:"data"`
|
||||
FirstID *string `json:"first_id,omitempty"`
|
||||
LastID *string `json:"last_id,omitempty"`
|
||||
HasMore bool `json:"has_more"`
|
||||
}
|
||||
|
||||
// ToBifrostBatchStatus converts OpenAI status to Bifrost status.
|
||||
func ToBifrostBatchStatus(status string) schemas.BatchStatus {
|
||||
switch status {
|
||||
case "validating":
|
||||
return schemas.BatchStatusValidating
|
||||
case "failed":
|
||||
return schemas.BatchStatusFailed
|
||||
case "in_progress":
|
||||
return schemas.BatchStatusInProgress
|
||||
case "finalizing":
|
||||
return schemas.BatchStatusFinalizing
|
||||
case "completed":
|
||||
return schemas.BatchStatusCompleted
|
||||
case "expired":
|
||||
return schemas.BatchStatusExpired
|
||||
case "cancelling":
|
||||
return schemas.BatchStatusCancelling
|
||||
case "cancelled":
|
||||
return schemas.BatchStatusCancelled
|
||||
default:
|
||||
return schemas.BatchStatus(status)
|
||||
}
|
||||
}
|
||||
|
||||
// ToBifrostBatchCreateResponse converts OpenAI batch response to Bifrost batch response.
|
||||
func (r *OpenAIBatchResponse) ToBifrostBatchCreateResponse(latency time.Duration, sendBackRawRequest bool, sendBackRawResponse bool, rawRequest interface{}, rawResponse interface{}) *schemas.BifrostBatchCreateResponse {
|
||||
resp := &schemas.BifrostBatchCreateResponse{
|
||||
ID: r.ID,
|
||||
Object: r.Object,
|
||||
Endpoint: r.Endpoint,
|
||||
InputFileID: r.InputFileID,
|
||||
CompletionWindow: r.CompletionWindow,
|
||||
Status: ToBifrostBatchStatus(r.Status),
|
||||
Metadata: r.Metadata,
|
||||
CreatedAt: r.CreatedAt,
|
||||
OutputFileID: r.OutputFileID,
|
||||
ErrorFileID: r.ErrorFileID,
|
||||
ExtraFields: schemas.BifrostResponseExtraFields{
|
||||
Latency: latency.Milliseconds(),
|
||||
},
|
||||
}
|
||||
|
||||
if sendBackRawRequest {
|
||||
resp.ExtraFields.RawRequest = rawRequest
|
||||
}
|
||||
|
||||
if r.ExpiresAt != nil {
|
||||
resp.ExpiresAt = r.ExpiresAt
|
||||
}
|
||||
|
||||
if r.RequestCounts != nil {
|
||||
resp.RequestCounts = schemas.BatchRequestCounts{
|
||||
Total: r.RequestCounts.Total,
|
||||
Completed: r.RequestCounts.Completed,
|
||||
Failed: r.RequestCounts.Failed,
|
||||
}
|
||||
}
|
||||
|
||||
if sendBackRawResponse {
|
||||
resp.ExtraFields.RawResponse = rawResponse
|
||||
}
|
||||
|
||||
return resp
|
||||
}
|
||||
|
||||
// ToBifrostBatchRetrieveResponse converts OpenAI batch response to Bifrost batch retrieve response.
|
||||
func (r *OpenAIBatchResponse) ToBifrostBatchRetrieveResponse(latency time.Duration, sendBackRawRequest bool, sendBackRawResponse bool, rawRequest interface{}, rawResponse interface{}) *schemas.BifrostBatchRetrieveResponse {
|
||||
resp := &schemas.BifrostBatchRetrieveResponse{
|
||||
ID: r.ID,
|
||||
Object: r.Object,
|
||||
Endpoint: r.Endpoint,
|
||||
InputFileID: r.InputFileID,
|
||||
CompletionWindow: r.CompletionWindow,
|
||||
Status: ToBifrostBatchStatus(r.Status),
|
||||
Metadata: r.Metadata,
|
||||
CreatedAt: r.CreatedAt,
|
||||
InProgressAt: r.InProgressAt,
|
||||
FinalizingAt: r.FinalizingAt,
|
||||
CompletedAt: r.CompletedAt,
|
||||
FailedAt: r.FailedAt,
|
||||
ExpiredAt: r.ExpiredAt,
|
||||
CancellingAt: r.CancellingAt,
|
||||
CancelledAt: r.CancelledAt,
|
||||
OutputFileID: r.OutputFileID,
|
||||
ErrorFileID: r.ErrorFileID,
|
||||
Errors: r.Errors,
|
||||
ExtraFields: schemas.BifrostResponseExtraFields{
|
||||
Latency: latency.Milliseconds(),
|
||||
},
|
||||
}
|
||||
|
||||
if sendBackRawRequest {
|
||||
resp.ExtraFields.RawRequest = rawRequest
|
||||
}
|
||||
|
||||
if r.ExpiresAt != nil {
|
||||
resp.ExpiresAt = r.ExpiresAt
|
||||
}
|
||||
|
||||
if r.RequestCounts != nil {
|
||||
resp.RequestCounts = schemas.BatchRequestCounts{
|
||||
Total: r.RequestCounts.Total,
|
||||
Completed: r.RequestCounts.Completed,
|
||||
Failed: r.RequestCounts.Failed,
|
||||
}
|
||||
}
|
||||
|
||||
if sendBackRawResponse {
|
||||
resp.ExtraFields.RawResponse = rawResponse
|
||||
}
|
||||
|
||||
return resp
|
||||
}
|
||||
192
core/providers/openai/chat.go
Normal file
192
core/providers/openai/chat.go
Normal file
@@ -0,0 +1,192 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/maximhq/bifrost/core/providers/utils"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// ToBifrostChatRequest converts an OpenAI chat request to Bifrost format
|
||||
func (req *OpenAIChatRequest) ToBifrostChatRequest(ctx *schemas.BifrostContext) *schemas.BifrostChatRequest {
|
||||
provider, model := schemas.ParseModelString(req.Model, utils.CheckAndSetDefaultProvider(ctx, schemas.OpenAI))
|
||||
|
||||
return &schemas.BifrostChatRequest{
|
||||
Provider: provider,
|
||||
Model: model,
|
||||
Input: ConvertOpenAIMessagesToBifrostMessages(req.Messages),
|
||||
Params: &req.ChatParameters,
|
||||
Fallbacks: schemas.ParseFallbacks(req.Fallbacks),
|
||||
}
|
||||
}
|
||||
|
||||
// ToOpenAIChatRequest converts a Bifrost chat completion request to OpenAI format
|
||||
func ToOpenAIChatRequest(ctx *schemas.BifrostContext, bifrostReq *schemas.BifrostChatRequest) *OpenAIChatRequest {
|
||||
if bifrostReq == nil || bifrostReq.Input == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
openaiReq := &OpenAIChatRequest{
|
||||
Model: bifrostReq.Model,
|
||||
Messages: ConvertBifrostMessagesToOpenAIMessages(bifrostReq.Input),
|
||||
}
|
||||
|
||||
if bifrostReq.Params != nil {
|
||||
openaiReq.ChatParameters = *bifrostReq.Params
|
||||
if openaiReq.ChatParameters.MaxCompletionTokens != nil && *openaiReq.ChatParameters.MaxCompletionTokens < MinMaxCompletionTokens {
|
||||
openaiReq.ChatParameters.MaxCompletionTokens = schemas.Ptr(MinMaxCompletionTokens)
|
||||
}
|
||||
// Drop user field if it exceeds OpenAI's 64 character limit
|
||||
openaiReq.ChatParameters.User = SanitizeUserField(openaiReq.ChatParameters.User)
|
||||
openaiReq.ExtraParams = bifrostReq.Params.ExtraParams
|
||||
|
||||
// Normalize tool parameters for deterministic JSON serialization (improves prompt caching)
|
||||
if len(openaiReq.ChatParameters.Tools) > 0 {
|
||||
normalizedTools := make([]schemas.ChatTool, len(openaiReq.ChatParameters.Tools))
|
||||
for i, tool := range openaiReq.ChatParameters.Tools {
|
||||
normalizedTools[i] = tool
|
||||
if tool.Function != nil && tool.Function.Parameters != nil {
|
||||
funcCopy := *tool.Function
|
||||
funcCopy.Parameters = tool.Function.Parameters.Normalized()
|
||||
normalizedTools[i].Function = &funcCopy
|
||||
}
|
||||
}
|
||||
openaiReq.ChatParameters.Tools = normalizedTools
|
||||
}
|
||||
}
|
||||
switch bifrostReq.Provider {
|
||||
case schemas.OpenAI, schemas.Azure:
|
||||
return openaiReq
|
||||
case schemas.XAI:
|
||||
openaiReq.filterOpenAISpecificParameters()
|
||||
openaiReq.applyXAICompatibility(bifrostReq.Model)
|
||||
return openaiReq
|
||||
case schemas.Gemini:
|
||||
openaiReq.filterOpenAISpecificParameters()
|
||||
// Removing extra parameters that are not supported by Gemini
|
||||
openaiReq.ServiceTier = nil
|
||||
return openaiReq
|
||||
case schemas.Mistral:
|
||||
openaiReq.filterOpenAISpecificParameters()
|
||||
openaiReq.applyMistralCompatibility()
|
||||
return openaiReq
|
||||
case schemas.Vertex:
|
||||
openaiReq.filterOpenAISpecificParameters()
|
||||
|
||||
// Apply Mistral-specific transformations for Vertex Mistral models
|
||||
if schemas.IsMistralModel(bifrostReq.Model) {
|
||||
openaiReq.applyMistralCompatibility()
|
||||
}
|
||||
return openaiReq
|
||||
case schemas.Fireworks:
|
||||
// Fireworks uses prompt_cache_isolation_key for cache isolation on chat/completions.
|
||||
// Preserve it before the generic filter strips prompt_cache_key.
|
||||
if openaiReq.ChatParameters.PromptCacheKey != nil && openaiReq.PromptCacheIsolationKey == nil {
|
||||
openaiReq.PromptCacheIsolationKey = openaiReq.ChatParameters.PromptCacheKey
|
||||
}
|
||||
// Fireworks supports predicted outputs; save before the filter strips them.
|
||||
prediction := openaiReq.ChatParameters.Prediction
|
||||
openaiReq.filterOpenAISpecificParameters()
|
||||
openaiReq.ChatParameters.Prediction = prediction
|
||||
return openaiReq
|
||||
default:
|
||||
// Check if provider is a custom provider
|
||||
if isCustomProvider, ok := ctx.Value(schemas.BifrostContextKeyIsCustomProvider).(bool); ok && isCustomProvider {
|
||||
return openaiReq
|
||||
}
|
||||
openaiReq.filterOpenAISpecificParameters()
|
||||
return openaiReq
|
||||
}
|
||||
}
|
||||
|
||||
// Filter OpenAI Specific Parameters
|
||||
func (req *OpenAIChatRequest) filterOpenAISpecificParameters() {
|
||||
// Handle reasoning parameter: OpenAI uses effort-based reasoning
|
||||
// Priority: effort (native) > max_tokens (estimated)
|
||||
if req.ChatParameters.Reasoning != nil {
|
||||
reasoningCopy := *req.ChatParameters.Reasoning
|
||||
req.ChatParameters.Reasoning = &reasoningCopy
|
||||
if req.ChatParameters.Reasoning.Effort != nil {
|
||||
// Native field is provided, use it (and clear max_tokens)
|
||||
effort := *req.ChatParameters.Reasoning.Effort
|
||||
// Convert "minimal" to "low"; cap "xhigh"/"max" to "high" — OpenAI tops out at high.
|
||||
switch effort {
|
||||
case "minimal":
|
||||
req.ChatParameters.Reasoning.Effort = schemas.Ptr("low")
|
||||
case "xhigh", "max":
|
||||
req.ChatParameters.Reasoning.Effort = schemas.Ptr("high")
|
||||
}
|
||||
// Clear max_tokens since OpenAI doesn't use it
|
||||
req.ChatParameters.Reasoning.MaxTokens = nil
|
||||
} else if req.ChatParameters.Reasoning.MaxTokens != nil {
|
||||
// Estimate effort from max_tokens
|
||||
maxTokens := *req.ChatParameters.Reasoning.MaxTokens
|
||||
maxCompletionTokens := utils.GetMaxOutputTokensOrDefault(req.Model, DefaultCompletionMaxTokens)
|
||||
if req.ChatParameters.MaxCompletionTokens != nil {
|
||||
maxCompletionTokens = *req.ChatParameters.MaxCompletionTokens
|
||||
}
|
||||
effort := utils.GetReasoningEffortFromBudgetTokens(maxTokens, MinReasoningMaxTokens, maxCompletionTokens)
|
||||
req.ChatParameters.Reasoning.Effort = schemas.Ptr(effort)
|
||||
// Clear max_tokens since OpenAI doesn't use it
|
||||
req.ChatParameters.Reasoning.MaxTokens = nil
|
||||
}
|
||||
}
|
||||
|
||||
if req.ChatParameters.Prediction != nil {
|
||||
req.ChatParameters.Prediction = nil
|
||||
}
|
||||
if req.ChatParameters.PromptCacheKey != nil {
|
||||
req.ChatParameters.PromptCacheKey = nil
|
||||
}
|
||||
if req.ChatParameters.PromptCacheRetention != nil {
|
||||
req.ChatParameters.PromptCacheRetention = nil
|
||||
}
|
||||
if req.ChatParameters.Verbosity != nil {
|
||||
req.ChatParameters.Verbosity = nil
|
||||
}
|
||||
if req.ChatParameters.Store != nil {
|
||||
req.ChatParameters.Store = nil
|
||||
}
|
||||
if req.ChatParameters.WebSearchOptions != nil {
|
||||
req.ChatParameters.WebSearchOptions = nil
|
||||
}
|
||||
}
|
||||
|
||||
// applyMistralCompatibility applies Mistral-specific transformations to the request
|
||||
func (req *OpenAIChatRequest) applyMistralCompatibility() {
|
||||
// Mistral uses max_tokens instead of max_completion_tokens
|
||||
if req.MaxCompletionTokens != nil {
|
||||
req.MaxTokens = req.MaxCompletionTokens
|
||||
req.MaxCompletionTokens = nil
|
||||
}
|
||||
|
||||
// Mistral does not support ToolChoiceStruct, only simple tool choice strings are supported
|
||||
if req.ToolChoice != nil && req.ToolChoice.ChatToolChoiceStruct != nil {
|
||||
req.ToolChoice.ChatToolChoiceStr = schemas.Ptr("any")
|
||||
req.ToolChoice.ChatToolChoiceStruct = nil
|
||||
}
|
||||
}
|
||||
|
||||
// applyXAICompatibility applies xAI-specific transformations to the request
|
||||
func (req *OpenAIChatRequest) applyXAICompatibility(model string) {
|
||||
// Only apply filters if this is a grok reasoning model
|
||||
if !schemas.IsGrokReasoningModel(model) {
|
||||
return
|
||||
}
|
||||
|
||||
req.ChatParameters.PresencePenalty = nil
|
||||
|
||||
// Only non-mini grok-3 models support frequency_penalty and stop
|
||||
// grok-3-mini only supports reasoning_effort in reasoning mode
|
||||
if !strings.Contains(model, "grok-3") || strings.Contains(model, "grok-3-mini") {
|
||||
req.ChatParameters.FrequencyPenalty = nil
|
||||
req.ChatParameters.Stop = nil
|
||||
}
|
||||
|
||||
// Only grok-3-mini supports reasoning_effort
|
||||
if req.ChatParameters.Reasoning != nil &&
|
||||
!strings.Contains(model, "grok-3-mini") {
|
||||
// Clear reasoning_effort for non-grok-3-mini models
|
||||
req.ChatParameters.Reasoning.Effort = nil
|
||||
}
|
||||
}
|
||||
707
core/providers/openai/chat_test.go
Normal file
707
core/providers/openai/chat_test.go
Normal file
@@ -0,0 +1,707 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/bytedance/sonic"
|
||||
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestToOpenAIChatRequest_ToolNormalization(t *testing.T) {
|
||||
// Create tool parameters with keys in non-alphabetical order:
|
||||
// "required" before "properties" before "type" — Normalized() should reorder to
|
||||
// type → description → properties → required, then alphabetical.
|
||||
unsortedParams := &schemas.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: schemas.NewOrderedMapFromPairs(
|
||||
schemas.KV("zebra", map[string]interface{}{"type": "string"}),
|
||||
schemas.KV("alpha", map[string]interface{}{"type": "number"}),
|
||||
),
|
||||
Required: []string{"zebra"},
|
||||
}
|
||||
|
||||
bifrostReq := &schemas.BifrostChatRequest{
|
||||
Provider: schemas.OpenAI,
|
||||
Model: "gpt-4o",
|
||||
Input: []schemas.ChatMessage{{Role: schemas.ChatMessageRoleUser}},
|
||||
Params: &schemas.ChatParameters{
|
||||
Tools: []schemas.ChatTool{
|
||||
{
|
||||
Type: "function",
|
||||
Function: &schemas.ChatToolFunction{
|
||||
Name: "test_func",
|
||||
Parameters: unsortedParams,
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: "function",
|
||||
Function: &schemas.ChatToolFunction{Name: "no_params_func"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
ctx, cancel := schemas.NewBifrostContextWithCancel(nil)
|
||||
defer cancel()
|
||||
result := ToOpenAIChatRequest(ctx, bifrostReq)
|
||||
if result == nil {
|
||||
t.Fatal("expected non-nil result")
|
||||
}
|
||||
|
||||
// Verify parameters are normalized: Properties keys should preserve original order
|
||||
// (user-defined property names are kept in client order for LLM generation quality)
|
||||
normalizedParams := result.ChatParameters.Tools[0].Function.Parameters
|
||||
if normalizedParams == nil {
|
||||
t.Fatal("expected normalized parameters to be non-nil")
|
||||
}
|
||||
keys := normalizedParams.Properties.Keys()
|
||||
if len(keys) != 2 || keys[0] != "zebra" || keys[1] != "alpha" {
|
||||
t.Errorf("expected Properties keys preserved as [zebra, alpha], got %v", keys)
|
||||
}
|
||||
|
||||
// Verify tool without parameters is unaffected
|
||||
if result.ChatParameters.Tools[1].Function.Parameters != nil {
|
||||
t.Error("expected nil parameters for tool without parameters")
|
||||
}
|
||||
|
||||
// Verify original bifrostReq.Params.Tools was NOT mutated
|
||||
origKeys := bifrostReq.Params.Tools[0].Function.Parameters.Properties.Keys()
|
||||
if len(origKeys) != 2 || origKeys[0] != "zebra" || origKeys[1] != "alpha" {
|
||||
t.Errorf("original parameters were mutated: expected [zebra, alpha], got %v", origKeys)
|
||||
}
|
||||
|
||||
// Verify the Function pointer is a different object (deep copy)
|
||||
if result.ChatParameters.Tools[0].Function == bifrostReq.Params.Tools[0].Function {
|
||||
t.Error("expected Function pointer to be a copy, not the original")
|
||||
}
|
||||
}
|
||||
|
||||
func TestToOpenAIChatRequest_PreservesN(t *testing.T) {
|
||||
req := &schemas.BifrostChatRequest{
|
||||
Provider: schemas.OpenAI,
|
||||
Model: "gpt-4.1",
|
||||
Input: []schemas.ChatMessage{
|
||||
{
|
||||
Role: schemas.ChatMessageRoleUser,
|
||||
Content: &schemas.ChatMessageContent{
|
||||
ContentStr: schemas.Ptr("hello"),
|
||||
},
|
||||
},
|
||||
},
|
||||
Params: &schemas.ChatParameters{
|
||||
N: schemas.Ptr(2),
|
||||
},
|
||||
}
|
||||
|
||||
out := ToOpenAIChatRequest(schemas.NewBifrostContext(nil, schemas.NoDeadline), req)
|
||||
if out == nil {
|
||||
t.Fatal("expected request")
|
||||
}
|
||||
if out.N == nil || *out.N != 2 {
|
||||
t.Fatalf("expected n=2, got %#v", out.N)
|
||||
}
|
||||
}
|
||||
|
||||
func TestToOpenAIChatRequest_PreservesPropertyOrder(t *testing.T) {
|
||||
params := &schemas.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: schemas.NewOrderedMapFromPairs(
|
||||
schemas.KV("reasoning", map[string]interface{}{"type": "string", "description": "Step by step"}),
|
||||
schemas.KV("answer", map[string]interface{}{"type": "string", "description": "Final answer"}),
|
||||
schemas.KV("confidence", map[string]interface{}{"type": "number", "description": "Score"}),
|
||||
),
|
||||
Required: []string{"reasoning", "answer"},
|
||||
}
|
||||
|
||||
bifrostReq := &schemas.BifrostChatRequest{
|
||||
Provider: schemas.OpenAI,
|
||||
Model: "gpt-4o",
|
||||
Input: []schemas.ChatMessage{{Role: schemas.ChatMessageRoleUser}},
|
||||
Params: &schemas.ChatParameters{
|
||||
Tools: []schemas.ChatTool{{
|
||||
Type: "function",
|
||||
Function: &schemas.ChatToolFunction{Name: "test_func", Parameters: params},
|
||||
}},
|
||||
},
|
||||
}
|
||||
|
||||
ctx, cancel := schemas.NewBifrostContextWithCancel(nil)
|
||||
defer cancel()
|
||||
result := ToOpenAIChatRequest(ctx, bifrostReq)
|
||||
|
||||
// CoT: property order preserved
|
||||
normalizedParams := result.ChatParameters.Tools[0].Function.Parameters
|
||||
keys := normalizedParams.Properties.Keys()
|
||||
if len(keys) != 3 || keys[0] != "reasoning" || keys[1] != "answer" || keys[2] != "confidence" {
|
||||
t.Errorf("expected property order [reasoning, answer, confidence], got %v", keys)
|
||||
}
|
||||
}
|
||||
|
||||
func TestToOpenAIChatRequest_PreservesExplicitEmptyToolParameters(t *testing.T) {
|
||||
var tool schemas.ChatTool
|
||||
err := json.Unmarshal([]byte(`{"type":"function","function":{"name":"empty_schema","parameters":{},"strict":false}}`), &tool)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to unmarshal tool: %v", err)
|
||||
}
|
||||
|
||||
bifrostReq := &schemas.BifrostChatRequest{
|
||||
Provider: schemas.OpenAI,
|
||||
Model: "gpt-4o",
|
||||
Input: []schemas.ChatMessage{{Role: schemas.ChatMessageRoleUser}},
|
||||
Params: &schemas.ChatParameters{
|
||||
Tools: []schemas.ChatTool{tool},
|
||||
},
|
||||
}
|
||||
|
||||
ctx, cancel := schemas.NewBifrostContextWithCancel(nil)
|
||||
defer cancel()
|
||||
result := ToOpenAIChatRequest(ctx, bifrostReq)
|
||||
if result == nil {
|
||||
t.Fatal("expected non-nil result")
|
||||
}
|
||||
|
||||
params := result.ChatParameters.Tools[0].Function.Parameters
|
||||
if params == nil {
|
||||
t.Fatal("expected tool parameters to be preserved")
|
||||
}
|
||||
|
||||
marshaled, err := schemas.Marshal(params)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to marshal parameters: %v", err)
|
||||
}
|
||||
if string(marshaled) != `{}` {
|
||||
t.Fatalf("expected parameters to remain {}, got %s", marshaled)
|
||||
}
|
||||
}
|
||||
|
||||
func TestToOpenAIChatRequest_CachingDeterminism(t *testing.T) {
|
||||
// Same properties, different structural key orders within property definitions
|
||||
makeReq := func(props *schemas.OrderedMap) *schemas.BifrostChatRequest {
|
||||
return &schemas.BifrostChatRequest{
|
||||
Provider: schemas.OpenAI,
|
||||
Model: "gpt-4o",
|
||||
Input: []schemas.ChatMessage{{Role: schemas.ChatMessageRoleUser}},
|
||||
Params: &schemas.ChatParameters{
|
||||
Tools: []schemas.ChatTool{{
|
||||
Type: "function",
|
||||
Function: &schemas.ChatToolFunction{
|
||||
Name: "test",
|
||||
Parameters: &schemas.ToolFunctionParameters{Type: "object", Properties: props},
|
||||
},
|
||||
}},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Version A: type before description
|
||||
propsA := schemas.NewOrderedMapFromPairs(
|
||||
schemas.KV("reasoning", schemas.NewOrderedMapFromPairs(
|
||||
schemas.KV("type", "string"),
|
||||
schemas.KV("description", "Step by step"),
|
||||
)),
|
||||
schemas.KV("answer", schemas.NewOrderedMapFromPairs(
|
||||
schemas.KV("type", "string"),
|
||||
schemas.KV("description", "Final answer"),
|
||||
)),
|
||||
)
|
||||
|
||||
// Version B: description before type (different structural order)
|
||||
propsB := schemas.NewOrderedMapFromPairs(
|
||||
schemas.KV("reasoning", schemas.NewOrderedMapFromPairs(
|
||||
schemas.KV("description", "Step by step"),
|
||||
schemas.KV("type", "string"),
|
||||
)),
|
||||
schemas.KV("answer", schemas.NewOrderedMapFromPairs(
|
||||
schemas.KV("description", "Final answer"),
|
||||
schemas.KV("type", "string"),
|
||||
)),
|
||||
)
|
||||
|
||||
ctx, cancel := schemas.NewBifrostContextWithCancel(nil)
|
||||
defer cancel()
|
||||
resultA := ToOpenAIChatRequest(ctx, makeReq(propsA))
|
||||
resultB := ToOpenAIChatRequest(ctx, makeReq(propsB))
|
||||
|
||||
jsonA, err := schemas.Marshal(resultA.ChatParameters.Tools[0].Function.Parameters)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to marshal params A: %v", err)
|
||||
}
|
||||
jsonB, err := schemas.Marshal(resultB.ChatParameters.Tools[0].Function.Parameters)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to marshal params B: %v", err)
|
||||
}
|
||||
|
||||
if string(jsonA) != string(jsonB) {
|
||||
t.Errorf("caching broken: same schema produced different JSON\nA: %s\nB: %s", jsonA, jsonB)
|
||||
}
|
||||
}
|
||||
|
||||
func TestToOpenAIChatRequest_FireworksPreservesReasoningAndCacheIsolation(t *testing.T) {
|
||||
ctx, cancel := schemas.NewBifrostContextWithCancel(nil)
|
||||
defer cancel()
|
||||
|
||||
cacheKey := "cache-key-1"
|
||||
reasoning := "step by step"
|
||||
predictionContent := "fireworks ok"
|
||||
userContent := "Reply with exactly: fireworks ok"
|
||||
|
||||
bifrostReq := &schemas.BifrostChatRequest{
|
||||
Provider: schemas.Fireworks,
|
||||
Model: "accounts/fireworks/models/deepseek-v3p2",
|
||||
Input: []schemas.ChatMessage{
|
||||
{
|
||||
Role: schemas.ChatMessageRoleUser,
|
||||
Content: &schemas.ChatMessageContent{
|
||||
ContentStr: &userContent,
|
||||
},
|
||||
},
|
||||
{
|
||||
Role: schemas.ChatMessageRoleAssistant,
|
||||
Content: &schemas.ChatMessageContent{
|
||||
ContentStr: &predictionContent,
|
||||
},
|
||||
ChatAssistantMessage: &schemas.ChatAssistantMessage{
|
||||
Reasoning: &reasoning,
|
||||
},
|
||||
},
|
||||
},
|
||||
Params: &schemas.ChatParameters{
|
||||
PromptCacheKey: &cacheKey,
|
||||
Prediction: &schemas.ChatPrediction{
|
||||
Type: "content",
|
||||
Content: predictionContent,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result := ToOpenAIChatRequest(ctx, bifrostReq)
|
||||
if result == nil {
|
||||
t.Fatal("expected non-nil result")
|
||||
}
|
||||
if result.PromptCacheIsolationKey == nil || *result.PromptCacheIsolationKey != cacheKey {
|
||||
t.Fatalf("expected prompt_cache_isolation_key %q, got %v", cacheKey, result.PromptCacheIsolationKey)
|
||||
}
|
||||
if result.PromptCacheKey != nil {
|
||||
t.Fatalf("expected prompt_cache_key to be stripped, got %v", *result.PromptCacheKey)
|
||||
}
|
||||
if result.Prediction == nil || result.Prediction.Content != predictionContent {
|
||||
t.Fatalf("expected prediction to be preserved, got %#v", result.Prediction)
|
||||
}
|
||||
if len(result.Messages) != 2 || result.Messages[1].OpenAIChatAssistantMessage == nil {
|
||||
t.Fatalf("expected assistant message with OpenAI assistant payload, got %#v", result.Messages)
|
||||
}
|
||||
if result.Messages[1].OpenAIChatAssistantMessage.Reasoning == nil || *result.Messages[1].OpenAIChatAssistantMessage.Reasoning != reasoning {
|
||||
t.Fatalf("expected assistant reasoning_content %q, got %#v", reasoning, result.Messages[1].OpenAIChatAssistantMessage)
|
||||
}
|
||||
|
||||
ctx.SetValue(schemas.BifrostContextKeyPassthroughExtraParams, true)
|
||||
wireBody, bifrostErr := providerUtils.CheckContextAndGetRequestBody(
|
||||
ctx,
|
||||
bifrostReq,
|
||||
func() (providerUtils.RequestBodyWithExtraParams, error) {
|
||||
return ToOpenAIChatRequest(ctx, bifrostReq), nil
|
||||
},
|
||||
)
|
||||
if bifrostErr != nil {
|
||||
t.Fatalf("failed to build request body: %v", bifrostErr.Error.Message)
|
||||
}
|
||||
|
||||
var jsonMap map[string]interface{}
|
||||
if err := sonic.Unmarshal(wireBody, &jsonMap); err != nil {
|
||||
t.Fatalf("failed to parse marshaled request body: %v", err)
|
||||
}
|
||||
if got, ok := jsonMap["prompt_cache_isolation_key"].(string); !ok || got != cacheKey {
|
||||
t.Fatalf("expected prompt_cache_isolation_key %q in wire payload, got %#v", cacheKey, jsonMap["prompt_cache_isolation_key"])
|
||||
}
|
||||
if _, ok := jsonMap["prompt_cache_key"]; ok {
|
||||
t.Fatalf("expected prompt_cache_key to be absent from wire payload, got %#v", jsonMap["prompt_cache_key"])
|
||||
}
|
||||
|
||||
messages, ok := jsonMap["messages"].([]interface{})
|
||||
if !ok || len(messages) != 2 {
|
||||
t.Fatalf("expected 2 messages in wire payload, got %#v", jsonMap["messages"])
|
||||
}
|
||||
assistantMessage, ok := messages[1].(map[string]interface{})
|
||||
if !ok {
|
||||
t.Fatalf("expected assistant message object, got %#v", messages[1])
|
||||
}
|
||||
if got, ok := assistantMessage["reasoning_content"].(string); !ok || got != reasoning {
|
||||
t.Fatalf("expected reasoning_content %q in assistant payload, got %#v", reasoning, assistantMessage["reasoning_content"])
|
||||
}
|
||||
}
|
||||
|
||||
// TestToOpenAIChatRequest_AnnotationsNotInWirePayload verifies that MCPToolAnnotations
|
||||
// (stored on ChatTool with json:"-") are never included in the JSON body sent to OpenAI.
|
||||
func TestToOpenAIChatRequest_AnnotationsNotInWirePayload(t *testing.T) {
|
||||
readOnly := true
|
||||
|
||||
bifrostReq := &schemas.BifrostChatRequest{
|
||||
Provider: schemas.OpenAI,
|
||||
Model: "gpt-4o",
|
||||
Input: []schemas.ChatMessage{
|
||||
{Role: schemas.ChatMessageRoleUser, Content: &schemas.ChatMessageContent{ContentStr: schemas.Ptr("hello")}},
|
||||
},
|
||||
Params: &schemas.ChatParameters{
|
||||
Tools: []schemas.ChatTool{
|
||||
{
|
||||
Type: schemas.ChatToolTypeFunction,
|
||||
Function: &schemas.ChatToolFunction{
|
||||
Name: "read_file",
|
||||
Description: schemas.Ptr("Read a file"),
|
||||
Parameters: &schemas.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: schemas.NewOrderedMapFromPairs(
|
||||
schemas.KV("path", map[string]interface{}{"type": "string"}),
|
||||
),
|
||||
Required: []string{"path"},
|
||||
},
|
||||
},
|
||||
Annotations: &schemas.MCPToolAnnotations{
|
||||
Title: "File Reader",
|
||||
ReadOnlyHint: &readOnly,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
ctx, cancel := schemas.NewBifrostContextWithCancel(nil)
|
||||
defer cancel()
|
||||
|
||||
result := ToOpenAIChatRequest(ctx, bifrostReq)
|
||||
require.NotNil(t, result)
|
||||
|
||||
wireBody, err := json.Marshal(result)
|
||||
require.NoError(t, err)
|
||||
s := string(wireBody)
|
||||
|
||||
// Annotations must be absent from the wire payload
|
||||
if strings.Contains(s, "annotations") {
|
||||
t.Errorf("annotations field leaked into OpenAI wire payload: %s", s)
|
||||
}
|
||||
if strings.Contains(s, "readOnlyHint") {
|
||||
t.Errorf("readOnlyHint leaked into OpenAI wire payload: %s", s)
|
||||
}
|
||||
if strings.Contains(s, "File Reader") {
|
||||
t.Errorf("annotation title leaked into OpenAI wire payload: %s", s)
|
||||
}
|
||||
|
||||
// The function definition must still be intact
|
||||
if !strings.Contains(s, "read_file") {
|
||||
t.Errorf("function name missing from OpenAI wire payload: %s", s)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyXAICompatibility(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
model string
|
||||
request *OpenAIChatRequest
|
||||
validate func(t *testing.T, req *OpenAIChatRequest)
|
||||
}{
|
||||
{
|
||||
name: "grok-3: preserves frequency_penalty and stop, clears presence_penalty and reasoning_effort",
|
||||
model: "grok-3",
|
||||
request: &OpenAIChatRequest{
|
||||
Model: "grok-3",
|
||||
Messages: []OpenAIMessage{},
|
||||
ChatParameters: schemas.ChatParameters{
|
||||
FrequencyPenalty: schemas.Ptr(0.5),
|
||||
PresencePenalty: schemas.Ptr(0.3),
|
||||
Stop: []string{"STOP"},
|
||||
Reasoning: &schemas.ChatReasoning{
|
||||
Effort: schemas.Ptr("high"),
|
||||
},
|
||||
},
|
||||
},
|
||||
validate: func(t *testing.T, req *OpenAIChatRequest) {
|
||||
// frequency_penalty should be preserved
|
||||
if req.FrequencyPenalty == nil || *req.FrequencyPenalty != 0.5 {
|
||||
t.Errorf("Expected FrequencyPenalty to be preserved at 0.5, got %v", req.FrequencyPenalty)
|
||||
}
|
||||
|
||||
// stop should be preserved
|
||||
if len(req.Stop) != 1 || req.Stop[0] != "STOP" {
|
||||
t.Errorf("Expected Stop to be preserved as ['STOP'], got %v", req.Stop)
|
||||
}
|
||||
|
||||
// presence_penalty should be cleared
|
||||
if req.PresencePenalty != nil {
|
||||
t.Errorf("Expected PresencePenalty to be cleared (nil), got %v", *req.PresencePenalty)
|
||||
}
|
||||
|
||||
// reasoning_effort should be cleared for non-mini grok-3
|
||||
if req.Reasoning == nil {
|
||||
t.Fatal("Expected Reasoning to remain non-nil")
|
||||
}
|
||||
if req.Reasoning.Effort != nil {
|
||||
t.Errorf("Expected Reasoning.Effort to be cleared (nil) for grok-3, got %v", *req.Reasoning.Effort)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "grok-3-mini: clears all penalties and stop, preserves reasoning_effort",
|
||||
model: "grok-3-mini",
|
||||
request: &OpenAIChatRequest{
|
||||
Model: "grok-3-mini",
|
||||
Messages: []OpenAIMessage{},
|
||||
ChatParameters: schemas.ChatParameters{
|
||||
FrequencyPenalty: schemas.Ptr(0.5),
|
||||
PresencePenalty: schemas.Ptr(0.3),
|
||||
Stop: []string{"STOP"},
|
||||
Reasoning: &schemas.ChatReasoning{
|
||||
Effort: schemas.Ptr("medium"),
|
||||
},
|
||||
},
|
||||
},
|
||||
validate: func(t *testing.T, req *OpenAIChatRequest) {
|
||||
// presence_penalty should be cleared
|
||||
if req.PresencePenalty != nil {
|
||||
t.Errorf("Expected PresencePenalty to be cleared (nil), got %v", *req.PresencePenalty)
|
||||
}
|
||||
|
||||
// frequency_penalty should be cleared for grok-3-mini
|
||||
if req.FrequencyPenalty != nil {
|
||||
t.Errorf("Expected FrequencyPenalty to be cleared (nil) for grok-3-mini, got %v", *req.FrequencyPenalty)
|
||||
}
|
||||
|
||||
// stop should be cleared for grok-3-mini
|
||||
if req.Stop != nil {
|
||||
t.Errorf("Expected Stop to be cleared (nil) for grok-3-mini, got %v", req.Stop)
|
||||
}
|
||||
|
||||
// reasoning_effort should be preserved for grok-3-mini
|
||||
if req.Reasoning == nil || req.Reasoning.Effort == nil {
|
||||
t.Fatal("Expected Reasoning.Effort to be preserved for grok-3-mini")
|
||||
}
|
||||
if *req.Reasoning.Effort != "medium" {
|
||||
t.Errorf("Expected Reasoning.Effort to be 'medium', got %v", *req.Reasoning.Effort)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "grok-4: clears all penalties, stop, and reasoning_effort",
|
||||
model: "grok-4",
|
||||
request: &OpenAIChatRequest{
|
||||
Model: "grok-4",
|
||||
Messages: []OpenAIMessage{},
|
||||
ChatParameters: schemas.ChatParameters{
|
||||
FrequencyPenalty: schemas.Ptr(0.5),
|
||||
PresencePenalty: schemas.Ptr(0.3),
|
||||
Stop: []string{"STOP"},
|
||||
Reasoning: &schemas.ChatReasoning{
|
||||
Effort: schemas.Ptr("high"),
|
||||
},
|
||||
},
|
||||
},
|
||||
validate: func(t *testing.T, req *OpenAIChatRequest) {
|
||||
// presence_penalty should be cleared
|
||||
if req.PresencePenalty != nil {
|
||||
t.Errorf("Expected PresencePenalty to be cleared (nil), got %v", *req.PresencePenalty)
|
||||
}
|
||||
|
||||
// frequency_penalty should be cleared for grok-4
|
||||
if req.FrequencyPenalty != nil {
|
||||
t.Errorf("Expected FrequencyPenalty to be cleared (nil) for grok-4, got %v", *req.FrequencyPenalty)
|
||||
}
|
||||
|
||||
// stop should be cleared for grok-4
|
||||
if req.Stop != nil {
|
||||
t.Errorf("Expected Stop to be cleared (nil) for grok-4, got %v", req.Stop)
|
||||
}
|
||||
|
||||
// reasoning_effort should be cleared for grok-4
|
||||
if req.Reasoning == nil {
|
||||
t.Fatal("Expected Reasoning to remain non-nil")
|
||||
}
|
||||
if req.Reasoning.Effort != nil {
|
||||
t.Errorf("Expected Reasoning.Effort to be cleared (nil) for grok-4, got %v", *req.Reasoning.Effort)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "grok-4-fast-reasoning: clears all penalties, stop, and reasoning_effort",
|
||||
model: "grok-4-fast-reasoning",
|
||||
request: &OpenAIChatRequest{
|
||||
Model: "grok-4-fast-reasoning",
|
||||
Messages: []OpenAIMessage{},
|
||||
ChatParameters: schemas.ChatParameters{
|
||||
FrequencyPenalty: schemas.Ptr(0.5),
|
||||
PresencePenalty: schemas.Ptr(0.3),
|
||||
Stop: []string{"STOP", "END"},
|
||||
Reasoning: &schemas.ChatReasoning{
|
||||
Effort: schemas.Ptr("high"),
|
||||
},
|
||||
},
|
||||
},
|
||||
validate: func(t *testing.T, req *OpenAIChatRequest) {
|
||||
// presence_penalty should be cleared
|
||||
if req.PresencePenalty != nil {
|
||||
t.Errorf("Expected PresencePenalty to be cleared (nil), got %v", *req.PresencePenalty)
|
||||
}
|
||||
|
||||
// frequency_penalty should be cleared
|
||||
if req.FrequencyPenalty != nil {
|
||||
t.Errorf("Expected FrequencyPenalty to be cleared (nil), got %v", *req.FrequencyPenalty)
|
||||
}
|
||||
|
||||
// stop should be cleared
|
||||
if req.Stop != nil {
|
||||
t.Errorf("Expected Stop to be cleared (nil), got %v", req.Stop)
|
||||
}
|
||||
|
||||
// reasoning_effort should be cleared
|
||||
if req.Reasoning == nil {
|
||||
t.Fatal("Expected Reasoning to remain non-nil")
|
||||
}
|
||||
if req.Reasoning.Effort != nil {
|
||||
t.Errorf("Expected Reasoning.Effort to be cleared (nil), got %v", *req.Reasoning.Effort)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "grok-code-fast-1: clears all penalties, stop, and reasoning_effort",
|
||||
model: "grok-code-fast-1",
|
||||
request: &OpenAIChatRequest{
|
||||
Model: "grok-code-fast-1",
|
||||
Messages: []OpenAIMessage{},
|
||||
ChatParameters: schemas.ChatParameters{
|
||||
FrequencyPenalty: schemas.Ptr(0.2),
|
||||
PresencePenalty: schemas.Ptr(0.1),
|
||||
Stop: []string{"END"},
|
||||
Reasoning: &schemas.ChatReasoning{
|
||||
Effort: schemas.Ptr("low"),
|
||||
},
|
||||
},
|
||||
},
|
||||
validate: func(t *testing.T, req *OpenAIChatRequest) {
|
||||
// presence_penalty should be cleared
|
||||
if req.PresencePenalty != nil {
|
||||
t.Errorf("Expected PresencePenalty to be cleared (nil), got %v", *req.PresencePenalty)
|
||||
}
|
||||
|
||||
// frequency_penalty should be cleared
|
||||
if req.FrequencyPenalty != nil {
|
||||
t.Errorf("Expected FrequencyPenalty to be cleared (nil), got %v", *req.FrequencyPenalty)
|
||||
}
|
||||
|
||||
// stop should be cleared
|
||||
if req.Stop != nil {
|
||||
t.Errorf("Expected Stop to be cleared (nil), got %v", req.Stop)
|
||||
}
|
||||
|
||||
// reasoning_effort should be cleared
|
||||
if req.Reasoning == nil {
|
||||
t.Fatal("Expected Reasoning to remain non-nil")
|
||||
}
|
||||
if req.Reasoning.Effort != nil {
|
||||
t.Errorf("Expected Reasoning.Effort to be cleared (nil), got %v", *req.Reasoning.Effort)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "non-reasoning grok model: no changes applied",
|
||||
model: "grok-2-latest",
|
||||
request: &OpenAIChatRequest{
|
||||
Model: "grok-2-latest",
|
||||
Messages: []OpenAIMessage{},
|
||||
ChatParameters: schemas.ChatParameters{
|
||||
FrequencyPenalty: schemas.Ptr(0.5),
|
||||
PresencePenalty: schemas.Ptr(0.3),
|
||||
Stop: []string{"STOP"},
|
||||
Reasoning: &schemas.ChatReasoning{
|
||||
Effort: schemas.Ptr("high"),
|
||||
},
|
||||
},
|
||||
},
|
||||
validate: func(t *testing.T, req *OpenAIChatRequest) {
|
||||
// All parameters should be preserved for non-reasoning models
|
||||
if req.FrequencyPenalty == nil || *req.FrequencyPenalty != 0.5 {
|
||||
t.Errorf("Expected FrequencyPenalty to be preserved at 0.5, got %v", req.FrequencyPenalty)
|
||||
}
|
||||
|
||||
if req.PresencePenalty == nil || *req.PresencePenalty != 0.3 {
|
||||
t.Errorf("Expected PresencePenalty to be preserved at 0.3, got %v", req.PresencePenalty)
|
||||
}
|
||||
|
||||
if len(req.Stop) != 1 || req.Stop[0] != "STOP" {
|
||||
t.Errorf("Expected Stop to be preserved as ['STOP'], got %v", req.Stop)
|
||||
}
|
||||
|
||||
if req.Reasoning == nil || req.Reasoning.Effort == nil {
|
||||
t.Fatal("Expected Reasoning.Effort to be preserved")
|
||||
}
|
||||
if *req.Reasoning.Effort != "high" {
|
||||
t.Errorf("Expected Reasoning.Effort to be 'high', got %v", *req.Reasoning.Effort)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "grok-3: handles nil reasoning gracefully",
|
||||
model: "grok-3",
|
||||
request: &OpenAIChatRequest{
|
||||
Model: "grok-3",
|
||||
Messages: []OpenAIMessage{},
|
||||
ChatParameters: schemas.ChatParameters{
|
||||
FrequencyPenalty: schemas.Ptr(0.5),
|
||||
PresencePenalty: schemas.Ptr(0.3),
|
||||
Stop: []string{"STOP"},
|
||||
Reasoning: nil,
|
||||
},
|
||||
},
|
||||
validate: func(t *testing.T, req *OpenAIChatRequest) {
|
||||
// Should handle nil reasoning without panicking
|
||||
if req.Reasoning != nil {
|
||||
t.Errorf("Expected Reasoning to remain nil, got %v", req.Reasoning)
|
||||
}
|
||||
|
||||
// Other parameters should still be processed
|
||||
if req.PresencePenalty != nil {
|
||||
t.Errorf("Expected PresencePenalty to be cleared (nil), got %v", *req.PresencePenalty)
|
||||
}
|
||||
|
||||
if req.FrequencyPenalty == nil || *req.FrequencyPenalty != 0.5 {
|
||||
t.Errorf("Expected FrequencyPenalty to be preserved at 0.5, got %v", req.FrequencyPenalty)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "grok-3: preserves other parameters like temperature",
|
||||
model: "grok-3",
|
||||
request: &OpenAIChatRequest{
|
||||
Model: "grok-3",
|
||||
Messages: []OpenAIMessage{},
|
||||
ChatParameters: schemas.ChatParameters{
|
||||
Temperature: schemas.Ptr(0.8),
|
||||
TopP: schemas.Ptr(0.9),
|
||||
FrequencyPenalty: schemas.Ptr(0.5),
|
||||
PresencePenalty: schemas.Ptr(0.3),
|
||||
},
|
||||
},
|
||||
validate: func(t *testing.T, req *OpenAIChatRequest) {
|
||||
// Unrelated parameters should be preserved
|
||||
if req.Temperature == nil || *req.Temperature != 0.8 {
|
||||
t.Errorf("Expected Temperature to be preserved at 0.8, got %v", req.Temperature)
|
||||
}
|
||||
|
||||
if req.TopP == nil || *req.TopP != 0.9 {
|
||||
t.Errorf("Expected TopP to be preserved at 0.9, got %v", req.TopP)
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Apply the compatibility function
|
||||
tt.request.applyXAICompatibility(tt.model)
|
||||
|
||||
// Validate the results
|
||||
tt.validate(t, tt.request)
|
||||
})
|
||||
}
|
||||
}
|
||||
40
core/providers/openai/embedding.go
Normal file
40
core/providers/openai/embedding.go
Normal file
@@ -0,0 +1,40 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"github.com/maximhq/bifrost/core/providers/utils"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// ToBifrostEmbeddingRequest converts an OpenAI embedding request to Bifrost format
|
||||
func (request *OpenAIEmbeddingRequest) ToBifrostEmbeddingRequest(ctx *schemas.BifrostContext) *schemas.BifrostEmbeddingRequest {
|
||||
provider, model := schemas.ParseModelString(request.Model, utils.CheckAndSetDefaultProvider(ctx, schemas.OpenAI))
|
||||
|
||||
return &schemas.BifrostEmbeddingRequest{
|
||||
Provider: provider,
|
||||
Model: model,
|
||||
Input: request.Input,
|
||||
Params: &request.EmbeddingParameters,
|
||||
Fallbacks: schemas.ParseFallbacks(request.Fallbacks),
|
||||
}
|
||||
}
|
||||
|
||||
// ToOpenAIEmbeddingRequest converts a Bifrost embedding request to OpenAI format
|
||||
func ToOpenAIEmbeddingRequest(bifrostReq *schemas.BifrostEmbeddingRequest) *OpenAIEmbeddingRequest {
|
||||
if bifrostReq == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
params := bifrostReq.Params
|
||||
|
||||
openaiReq := &OpenAIEmbeddingRequest{
|
||||
Model: bifrostReq.Model,
|
||||
Input: bifrostReq.Input,
|
||||
}
|
||||
|
||||
// Map parameters
|
||||
if params != nil {
|
||||
openaiReq.EmbeddingParameters = *params
|
||||
openaiReq.ExtraParams = params.ExtraParams
|
||||
}
|
||||
return openaiReq
|
||||
}
|
||||
54
core/providers/openai/errors.go
Normal file
54
core/providers/openai/errors.go
Normal file
@@ -0,0 +1,54 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// ErrorConverter is a function that converts provider-specific error responses to BifrostError.
|
||||
type ErrorConverter func(resp *fasthttp.Response) *schemas.BifrostError
|
||||
|
||||
// ParseOpenAIError parses OpenAI error responses.
|
||||
func ParseOpenAIError(resp *fasthttp.Response) *schemas.BifrostError {
|
||||
var errorResp schemas.BifrostError
|
||||
|
||||
bifrostErr := providerUtils.HandleProviderAPIError(resp, &errorResp)
|
||||
|
||||
if errorResp.EventID != nil {
|
||||
bifrostErr.EventID = errorResp.EventID
|
||||
}
|
||||
|
||||
if errorResp.Error != nil {
|
||||
if bifrostErr.Error == nil {
|
||||
bifrostErr.Error = &schemas.ErrorField{}
|
||||
}
|
||||
bifrostErr.Error.Type = errorResp.Error.Type
|
||||
bifrostErr.Error.Code = errorResp.Error.Code
|
||||
if errorResp.Error.Message != "" {
|
||||
bifrostErr.Error.Message = errorResp.Error.Message
|
||||
}
|
||||
bifrostErr.Error.Param = errorResp.Error.Param
|
||||
if errorResp.Error.EventID != nil {
|
||||
bifrostErr.Error.EventID = errorResp.Error.EventID
|
||||
}
|
||||
}
|
||||
|
||||
if bifrostErr.Error == nil {
|
||||
bifrostErr.Error = &schemas.ErrorField{}
|
||||
}
|
||||
if strings.TrimSpace(bifrostErr.Error.Message) == "" {
|
||||
if bifrostErr.StatusCode != nil {
|
||||
bifrostErr.Error.Message = fmt.Sprintf("provider API error (status %d)", *bifrostErr.StatusCode)
|
||||
} else {
|
||||
bifrostErr.Error.Message = "provider API error"
|
||||
}
|
||||
}
|
||||
|
||||
// Set ExtraFields unconditionally so provider/model/request metadata is always attached
|
||||
|
||||
return bifrostErr
|
||||
}
|
||||
83
core/providers/openai/errors_test.go
Normal file
83
core/providers/openai/errors_test.go
Normal file
@@ -0,0 +1,83 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
func TestParseOpenAIError_FallbackMessageWhenProviderBodyIsNonOpenAIShape(t *testing.T) {
|
||||
var resp fasthttp.Response
|
||||
resp.SetStatusCode(fasthttp.StatusUnprocessableEntity)
|
||||
resp.SetBodyString(`{"detail":[{"loc":["body","messages",0,"role"],"msg":"value is not a valid enumeration member"}]}`)
|
||||
|
||||
errResp := ParseOpenAIError(&resp)
|
||||
if errResp == nil || errResp.Error == nil {
|
||||
t.Fatal("expected non-nil error response")
|
||||
}
|
||||
if errResp.Error.Message == "" {
|
||||
t.Fatal("expected non-empty error message")
|
||||
}
|
||||
if errResp.Error.Message != "provider API error (status 422)" {
|
||||
t.Fatalf("expected fallback message, got %q", errResp.Error.Message)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseOpenAIError_PreservesProviderMessageWhenPresent(t *testing.T) {
|
||||
var resp fasthttp.Response
|
||||
resp.SetStatusCode(fasthttp.StatusUnprocessableEntity)
|
||||
resp.SetBodyString(`{"error":{"message":"unsupported role: developer","type":"invalid_request_error","param":"messages.0.role","code":"invalid_value"}}`)
|
||||
|
||||
errResp := ParseOpenAIError(&resp)
|
||||
if errResp == nil || errResp.Error == nil {
|
||||
t.Fatal("expected non-nil error response")
|
||||
}
|
||||
if errResp.Error.Message != "unsupported role: developer" {
|
||||
t.Fatalf("expected provider message, got %q", errResp.Error.Message)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseOpenAIError_FallbackMessageWhenBodyIsEmpty(t *testing.T) {
|
||||
var resp fasthttp.Response
|
||||
resp.SetStatusCode(fasthttp.StatusBadRequest)
|
||||
resp.SetBody(nil)
|
||||
|
||||
errResp := ParseOpenAIError(&resp)
|
||||
if errResp == nil || errResp.Error == nil {
|
||||
t.Fatal("expected non-nil error response")
|
||||
}
|
||||
// HandleProviderAPIError returns ErrProviderResponseEmpty with HTTP status for empty bodies.
|
||||
expectedMsg := schemas.ErrProviderResponseEmpty + " (HTTP 400)"
|
||||
if errResp.Error.Message != expectedMsg {
|
||||
t.Fatalf("expected %q, got %q", expectedMsg, errResp.Error.Message)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseOpenAIError_WhitespaceProviderMessageFallsBack(t *testing.T) {
|
||||
var resp fasthttp.Response
|
||||
resp.SetStatusCode(fasthttp.StatusBadRequest)
|
||||
resp.SetBodyString(`{"error":{"message":" ","type":"invalid_request_error"}}`)
|
||||
|
||||
errResp := ParseOpenAIError(&resp)
|
||||
if errResp == nil || errResp.Error == nil {
|
||||
t.Fatal("expected non-nil error response")
|
||||
}
|
||||
if errResp.Error.Message != "provider API error (status 400)" {
|
||||
t.Fatalf("expected fallback message, got %q", errResp.Error.Message)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseOpenAIError_DefaultStatusCodeFallsBackWithStatusNumber(t *testing.T) {
|
||||
var resp fasthttp.Response
|
||||
// fasthttp defaults zero-value response status code to 200.
|
||||
resp.SetBodyString(`{"error":{"message":""}}`)
|
||||
|
||||
errResp := ParseOpenAIError(&resp)
|
||||
if errResp == nil || errResp.Error == nil {
|
||||
t.Fatal("expected non-nil error response")
|
||||
}
|
||||
if errResp.Error.Message != "provider API error (status 200)" {
|
||||
t.Fatalf("expected fallback message with default status, got %q", errResp.Error.Message)
|
||||
}
|
||||
}
|
||||
124
core/providers/openai/files.go
Normal file
124
core/providers/openai/files.go
Normal file
@@ -0,0 +1,124 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"time"
|
||||
|
||||
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// OpenAI File API Types
|
||||
|
||||
// OpenAIFileResponse represents an OpenAI file response.
|
||||
type OpenAIFileResponse struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Bytes int64 `json:"bytes"`
|
||||
CreatedAt int64 `json:"created_at"`
|
||||
Filename string `json:"filename"`
|
||||
Purpose schemas.FilePurpose `json:"purpose"`
|
||||
Status string `json:"status,omitempty"`
|
||||
StatusDetails *string `json:"status_details,omitempty"`
|
||||
}
|
||||
|
||||
// OpenAIFileListResponse represents the response from listing files.
|
||||
type OpenAIFileListResponse struct {
|
||||
Object string `json:"object"`
|
||||
Data []OpenAIFileResponse `json:"data"`
|
||||
HasMore bool `json:"has_more,omitempty"`
|
||||
}
|
||||
|
||||
// OpenAIFileDeleteResponse represents the response from deleting a file.
|
||||
type OpenAIFileDeleteResponse struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Deleted bool `json:"deleted"`
|
||||
}
|
||||
|
||||
// ToBifrostFileStatus converts OpenAI status to Bifrost status.
|
||||
func ToBifrostFileStatus(status string) schemas.FileStatus {
|
||||
switch status {
|
||||
case "uploaded":
|
||||
return schemas.FileStatusUploaded
|
||||
case "processed", "completed":
|
||||
return schemas.FileStatusProcessed
|
||||
case "processing", "in_progress":
|
||||
return schemas.FileStatusProcessing
|
||||
case "error", "failed":
|
||||
return schemas.FileStatusError
|
||||
case "deleted", "cancelled":
|
||||
return schemas.FileStatusDeleted
|
||||
default:
|
||||
return schemas.FileStatus(status)
|
||||
}
|
||||
}
|
||||
|
||||
// ToBifrostFileUploadResponse converts OpenAI file response to Bifrost file upload response.
|
||||
func (r *OpenAIFileResponse) ToBifrostFileUploadResponse(latency time.Duration, sendBackRawRequest bool, sendBackRawResponse bool, rawRequest interface{}, rawResponse interface{}) *schemas.BifrostFileUploadResponse {
|
||||
resp := &schemas.BifrostFileUploadResponse{
|
||||
ID: r.ID,
|
||||
Object: r.Object,
|
||||
Bytes: r.Bytes,
|
||||
CreatedAt: r.CreatedAt,
|
||||
Filename: r.Filename,
|
||||
Purpose: r.Purpose,
|
||||
Status: ToBifrostFileStatus(r.Status),
|
||||
StatusDetails: r.StatusDetails,
|
||||
StorageBackend: schemas.FileStorageAPI,
|
||||
ExtraFields: schemas.BifrostResponseExtraFields{
|
||||
Latency: latency.Milliseconds(),
|
||||
},
|
||||
}
|
||||
|
||||
if sendBackRawRequest {
|
||||
resp.ExtraFields.RawRequest = rawRequest
|
||||
}
|
||||
|
||||
if sendBackRawResponse {
|
||||
resp.ExtraFields.RawResponse = rawResponse
|
||||
}
|
||||
|
||||
return resp
|
||||
}
|
||||
|
||||
// ToBifrostFileRetrieveResponse converts OpenAI file response to Bifrost file retrieve response.
|
||||
func (r *OpenAIFileResponse) ToBifrostFileRetrieveResponse(providerName schemas.ModelProvider, latency time.Duration, sendBackRawRequest bool, sendBackRawResponse bool, rawRequest interface{}, rawResponse interface{}) *schemas.BifrostFileRetrieveResponse {
|
||||
resp := &schemas.BifrostFileRetrieveResponse{
|
||||
ID: r.ID,
|
||||
Object: r.Object,
|
||||
Bytes: r.Bytes,
|
||||
CreatedAt: r.CreatedAt,
|
||||
Filename: r.Filename,
|
||||
Purpose: r.Purpose,
|
||||
Status: ToBifrostFileStatus(r.Status),
|
||||
StatusDetails: r.StatusDetails,
|
||||
StorageBackend: schemas.FileStorageAPI,
|
||||
ExtraFields: schemas.BifrostResponseExtraFields{
|
||||
Latency: latency.Milliseconds(),
|
||||
},
|
||||
}
|
||||
|
||||
if sendBackRawRequest {
|
||||
resp.ExtraFields.RawRequest = rawRequest
|
||||
}
|
||||
|
||||
if sendBackRawResponse {
|
||||
resp.ExtraFields.RawResponse = rawResponse
|
||||
}
|
||||
return resp
|
||||
}
|
||||
|
||||
// ConvertRequestsToJSONL converts batch request items to JSONL format.
|
||||
func ConvertRequestsToJSONL(requests []schemas.BatchRequestItem) ([]byte, error) {
|
||||
var buf bytes.Buffer
|
||||
for _, req := range requests {
|
||||
line, err := providerUtils.MarshalSorted(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
buf.Write(line)
|
||||
buf.WriteByte('\n')
|
||||
}
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
361
core/providers/openai/images.go
Normal file
361
core/providers/openai/images.go
Normal file
@@ -0,0 +1,361 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
||||
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// ToOpenAIImageGenerationRequest converts a Bifrost Image Request to OpenAI format
|
||||
func ToOpenAIImageGenerationRequest(bifrostReq *schemas.BifrostImageGenerationRequest) *OpenAIImageGenerationRequest {
|
||||
if bifrostReq == nil || bifrostReq.Input == nil || bifrostReq.Input.Prompt == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
req := &OpenAIImageGenerationRequest{
|
||||
Model: bifrostReq.Model,
|
||||
Prompt: bifrostReq.Input.Prompt,
|
||||
}
|
||||
|
||||
if bifrostReq.Params != nil {
|
||||
req.ImageGenerationParameters = *bifrostReq.Params
|
||||
}
|
||||
|
||||
switch bifrostReq.Provider {
|
||||
case schemas.XAI:
|
||||
filterXAISpecificParameters(req)
|
||||
case schemas.OpenAI, schemas.Azure:
|
||||
filterOpenAISpecificParameters(req)
|
||||
}
|
||||
if bifrostReq.Params != nil {
|
||||
req.ExtraParams = bifrostReq.Params.ExtraParams
|
||||
}
|
||||
return req
|
||||
}
|
||||
|
||||
func filterXAISpecificParameters(req *OpenAIImageGenerationRequest) {
|
||||
req.ImageGenerationParameters.Quality = nil
|
||||
req.ImageGenerationParameters.Style = nil
|
||||
req.ImageGenerationParameters.Size = nil
|
||||
req.ImageGenerationParameters.OutputCompression = nil
|
||||
}
|
||||
|
||||
func filterOpenAISpecificParameters(req *OpenAIImageGenerationRequest) {
|
||||
req.ImageGenerationParameters.Seed = nil
|
||||
req.NumInferenceSteps = nil
|
||||
req.NegativePrompt = nil
|
||||
}
|
||||
|
||||
// ToBifrostImageGenerationRequest converts an OpenAI image generation request to Bifrost format
|
||||
func (request *OpenAIImageGenerationRequest) ToBifrostImageGenerationRequest(ctx *schemas.BifrostContext) *schemas.BifrostImageGenerationRequest {
|
||||
if request == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
provider, model := schemas.ParseModelString(request.Model, providerUtils.CheckAndSetDefaultProvider(ctx, schemas.OpenAI))
|
||||
|
||||
return &schemas.BifrostImageGenerationRequest{
|
||||
Provider: provider,
|
||||
Model: model,
|
||||
Input: &schemas.ImageGenerationInput{
|
||||
Prompt: request.Prompt,
|
||||
},
|
||||
Params: &request.ImageGenerationParameters,
|
||||
Fallbacks: schemas.ParseFallbacks(request.Fallbacks),
|
||||
}
|
||||
}
|
||||
|
||||
func (request *OpenAIImageEditRequest) ToBifrostImageEditRequest(ctx *schemas.BifrostContext) *schemas.BifrostImageEditRequest {
|
||||
if request == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
provider, model := schemas.ParseModelString(request.Model, providerUtils.CheckAndSetDefaultProvider(ctx, schemas.OpenAI))
|
||||
|
||||
return &schemas.BifrostImageEditRequest{
|
||||
Provider: provider,
|
||||
Model: model,
|
||||
Input: request.Input,
|
||||
Params: &request.ImageEditParameters,
|
||||
Fallbacks: schemas.ParseFallbacks(request.Fallbacks),
|
||||
}
|
||||
}
|
||||
|
||||
func (request *OpenAIImageVariationRequest) ToBifrostImageVariationRequest(ctx *schemas.BifrostContext) *schemas.BifrostImageVariationRequest {
|
||||
if request == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
provider, model := schemas.ParseModelString(request.Model, providerUtils.CheckAndSetDefaultProvider(ctx, schemas.OpenAI))
|
||||
|
||||
return &schemas.BifrostImageVariationRequest{
|
||||
Provider: provider,
|
||||
Model: model,
|
||||
Input: request.Input,
|
||||
Params: &request.ImageVariationParameters,
|
||||
Fallbacks: schemas.ParseFallbacks(request.Fallbacks),
|
||||
}
|
||||
}
|
||||
|
||||
func ToOpenAIImageEditRequest(bifrostReq *schemas.BifrostImageEditRequest) *OpenAIImageEditRequest {
|
||||
if bifrostReq == nil || bifrostReq.Input == nil || bifrostReq.Input.Images == nil || bifrostReq.Input.Prompt == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
req := &OpenAIImageEditRequest{
|
||||
Model: bifrostReq.Model,
|
||||
Input: bifrostReq.Input,
|
||||
}
|
||||
|
||||
if bifrostReq.Params != nil {
|
||||
req.ImageEditParameters = *bifrostReq.Params
|
||||
}
|
||||
|
||||
if bifrostReq.Params != nil {
|
||||
req.ExtraParams = bifrostReq.Params.ExtraParams
|
||||
}
|
||||
|
||||
return req
|
||||
}
|
||||
|
||||
func parseImageEditFormDataBodyFromRequest(writer *multipart.Writer, openaiReq *OpenAIImageEditRequest, providerName schemas.ModelProvider) *schemas.BifrostError {
|
||||
// Add model field (required)
|
||||
if err := writer.WriteField("model", openaiReq.Model); err != nil {
|
||||
return providerUtils.NewBifrostOperationError("failed to write model field", err)
|
||||
}
|
||||
|
||||
// Add prompt field (required)
|
||||
if err := writer.WriteField("prompt", openaiReq.Input.Prompt); err != nil {
|
||||
return providerUtils.NewBifrostOperationError("failed to write prompt field", err)
|
||||
}
|
||||
|
||||
// Add stream field when requesting streaming
|
||||
if openaiReq.Stream != nil && *openaiReq.Stream {
|
||||
if err := writer.WriteField("stream", "true"); err != nil {
|
||||
return providerUtils.NewBifrostOperationError("failed to write stream field", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Add optional parameters before file parts so routing metadata arrives first upstream.
|
||||
if openaiReq.N != nil {
|
||||
if err := writer.WriteField("n", strconv.Itoa(*openaiReq.N)); err != nil {
|
||||
return providerUtils.NewBifrostOperationError("failed to write n field", err)
|
||||
}
|
||||
}
|
||||
|
||||
if openaiReq.Size != nil {
|
||||
if err := writer.WriteField("size", *openaiReq.Size); err != nil {
|
||||
return providerUtils.NewBifrostOperationError("failed to write size field", err)
|
||||
}
|
||||
}
|
||||
|
||||
if openaiReq.ResponseFormat != nil {
|
||||
if err := writer.WriteField("response_format", *openaiReq.ResponseFormat); err != nil {
|
||||
return providerUtils.NewBifrostOperationError("failed to write response_format field", err)
|
||||
}
|
||||
}
|
||||
|
||||
if openaiReq.Quality != nil {
|
||||
if err := writer.WriteField("quality", *openaiReq.Quality); err != nil {
|
||||
return providerUtils.NewBifrostOperationError("failed to write quality field", err)
|
||||
}
|
||||
}
|
||||
|
||||
if openaiReq.Background != nil {
|
||||
if err := writer.WriteField("background", *openaiReq.Background); err != nil {
|
||||
return providerUtils.NewBifrostOperationError("failed to write background field", err)
|
||||
}
|
||||
}
|
||||
|
||||
if openaiReq.InputFidelity != nil {
|
||||
if err := writer.WriteField("input_fidelity", *openaiReq.InputFidelity); err != nil {
|
||||
return providerUtils.NewBifrostOperationError("failed to write input_fidelity field", err)
|
||||
}
|
||||
}
|
||||
|
||||
if openaiReq.PartialImages != nil {
|
||||
if err := writer.WriteField("partial_images", strconv.Itoa(*openaiReq.PartialImages)); err != nil {
|
||||
return providerUtils.NewBifrostOperationError("failed to write partial_images field", err)
|
||||
}
|
||||
}
|
||||
|
||||
if openaiReq.OutputFormat != nil {
|
||||
if err := writer.WriteField("output_format", *openaiReq.OutputFormat); err != nil {
|
||||
return providerUtils.NewBifrostOperationError("failed to write output_format field", err)
|
||||
}
|
||||
}
|
||||
|
||||
if openaiReq.OutputCompression != nil {
|
||||
if err := writer.WriteField("output_compression", strconv.Itoa(*openaiReq.OutputCompression)); err != nil {
|
||||
return providerUtils.NewBifrostOperationError("failed to write output_compression field", err)
|
||||
}
|
||||
}
|
||||
|
||||
if openaiReq.User != nil {
|
||||
if err := writer.WriteField("user", *openaiReq.User); err != nil {
|
||||
return providerUtils.NewBifrostOperationError("failed to write user field", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Add image[] fields (one for each image)
|
||||
for i, imageInput := range openaiReq.Input.Images {
|
||||
fieldName := "image[]"
|
||||
|
||||
// Detect and validate MIME type
|
||||
mimeType := http.DetectContentType(imageInput.Image)
|
||||
// Fallback to PNG if content type is undetectable or generic
|
||||
if mimeType == "" || mimeType == "application/octet-stream" {
|
||||
mimeType = "image/png"
|
||||
}
|
||||
|
||||
// Determine filename based on MIME type
|
||||
var filename string
|
||||
switch mimeType {
|
||||
case "image/jpeg":
|
||||
filename = fmt.Sprintf("image%d.jpg", i)
|
||||
case "image/webp":
|
||||
filename = fmt.Sprintf("image%d.webp", i)
|
||||
default:
|
||||
filename = fmt.Sprintf("image%d.png", i)
|
||||
}
|
||||
|
||||
// Create form part with proper Content-Type header (not CreateFormFile which defaults to application/octet-stream)
|
||||
part, err := writer.CreatePart(map[string][]string{
|
||||
"Content-Disposition": {fmt.Sprintf(`form-data; name="%s"; filename="%s"`, fieldName, filename)},
|
||||
"Content-Type": {mimeType},
|
||||
})
|
||||
if err != nil {
|
||||
return providerUtils.NewBifrostOperationError(fmt.Sprintf("failed to create form part for image %d", i), err)
|
||||
}
|
||||
if _, err := part.Write(imageInput.Image); err != nil {
|
||||
return providerUtils.NewBifrostOperationError(fmt.Sprintf("failed to write image %d data", i), err)
|
||||
}
|
||||
}
|
||||
|
||||
// Add mask if present
|
||||
if len(openaiReq.Mask) > 0 {
|
||||
// Detect MIME type for mask
|
||||
maskMimeType := http.DetectContentType(openaiReq.Mask)
|
||||
if maskMimeType != "image/png" && maskMimeType != "image/jpeg" && maskMimeType != "image/webp" {
|
||||
maskMimeType = "image/png"
|
||||
}
|
||||
|
||||
var maskFilename string
|
||||
switch maskMimeType {
|
||||
case "image/jpeg":
|
||||
maskFilename = "mask.jpg"
|
||||
case "image/webp":
|
||||
maskFilename = "mask.webp"
|
||||
default:
|
||||
maskFilename = "mask.png"
|
||||
}
|
||||
|
||||
// Create form part with proper Content-Type header
|
||||
maskPart, err := writer.CreatePart(map[string][]string{
|
||||
"Content-Disposition": {`form-data; name="mask"; filename="` + maskFilename + `"`},
|
||||
"Content-Type": {maskMimeType},
|
||||
})
|
||||
if err != nil {
|
||||
return providerUtils.NewBifrostOperationError("failed to create mask form part", err)
|
||||
}
|
||||
if _, err := maskPart.Write(openaiReq.Mask); err != nil {
|
||||
return providerUtils.NewBifrostOperationError("failed to write mask data", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Close the multipart writer
|
||||
if err := writer.Close(); err != nil {
|
||||
return providerUtils.NewBifrostOperationError("failed to close multipart writer", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func ToOpenAIImageVariationRequest(bifrostReq *schemas.BifrostImageVariationRequest) *OpenAIImageVariationRequest {
|
||||
if bifrostReq == nil || bifrostReq.Input == nil || bifrostReq.Input.Image.Image == nil || len(bifrostReq.Input.Image.Image) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
req := &OpenAIImageVariationRequest{
|
||||
Model: bifrostReq.Model,
|
||||
Input: bifrostReq.Input,
|
||||
}
|
||||
|
||||
if bifrostReq.Params != nil {
|
||||
req.ImageVariationParameters = *bifrostReq.Params
|
||||
}
|
||||
|
||||
if bifrostReq.Params != nil {
|
||||
req.ExtraParams = bifrostReq.Params.ExtraParams
|
||||
}
|
||||
|
||||
return req
|
||||
}
|
||||
|
||||
func parseImageVariationFormDataBodyFromRequest(writer *multipart.Writer, openaiReq *OpenAIImageVariationRequest, providerName schemas.ModelProvider) *schemas.BifrostError {
|
||||
// Add model field (required)
|
||||
if err := writer.WriteField("model", openaiReq.Model); err != nil {
|
||||
return providerUtils.NewBifrostOperationError("failed to write model field", err)
|
||||
}
|
||||
|
||||
// Add image file (required)
|
||||
if openaiReq.Input == nil || openaiReq.Input.Image.Image == nil || len(openaiReq.Input.Image.Image) == 0 {
|
||||
return providerUtils.NewBifrostOperationError("image is required", nil)
|
||||
}
|
||||
|
||||
// Add optional parameters before the image part so metadata arrives first upstream.
|
||||
if openaiReq.N != nil {
|
||||
if err := writer.WriteField("n", strconv.Itoa(*openaiReq.N)); err != nil {
|
||||
return providerUtils.NewBifrostOperationError("failed to write n field", err)
|
||||
}
|
||||
}
|
||||
|
||||
if openaiReq.ResponseFormat != nil {
|
||||
if err := writer.WriteField("response_format", *openaiReq.ResponseFormat); err != nil {
|
||||
return providerUtils.NewBifrostOperationError("failed to write response_format field", err)
|
||||
}
|
||||
}
|
||||
|
||||
if openaiReq.Size != nil {
|
||||
if err := writer.WriteField("size", *openaiReq.Size); err != nil {
|
||||
return providerUtils.NewBifrostOperationError("failed to write size field", err)
|
||||
}
|
||||
}
|
||||
|
||||
if openaiReq.User != nil {
|
||||
if err := writer.WriteField("user", *openaiReq.User); err != nil {
|
||||
return providerUtils.NewBifrostOperationError("failed to write user field", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Detect MIME type
|
||||
mimeType := http.DetectContentType(openaiReq.Input.Image.Image)
|
||||
// If still not detected, default to PNG
|
||||
if mimeType == "application/octet-stream" || mimeType == "" {
|
||||
mimeType = "image/png"
|
||||
}
|
||||
|
||||
filename := "image"
|
||||
part, err := writer.CreatePart(map[string][]string{
|
||||
"Content-Disposition": {fmt.Sprintf(`form-data; name="image"; filename="%s"`, filename)},
|
||||
"Content-Type": {mimeType},
|
||||
})
|
||||
if err != nil {
|
||||
return providerUtils.NewBifrostOperationError("failed to create image part", err)
|
||||
}
|
||||
|
||||
if _, err := part.Write(openaiReq.Input.Image.Image); err != nil {
|
||||
return providerUtils.NewBifrostOperationError("failed to write image data", err)
|
||||
}
|
||||
|
||||
// Close the multipart writer
|
||||
if err := writer.Close(); err != nil {
|
||||
return providerUtils.NewBifrostOperationError("failed to close multipart writer", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
156
core/providers/openai/large_payload.go
Normal file
156
core/providers/openai/large_payload.go
Normal file
@@ -0,0 +1,156 @@
|
||||
// Package openai provides the OpenAI provider implementation for the Bifrost framework.
|
||||
package openai
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/bytedance/sonic"
|
||||
"github.com/valyala/fasthttp"
|
||||
|
||||
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
|
||||
schemas "github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// largePayloadResult holds the lightweight metadata extracted from a large payload passthrough.
|
||||
type largePayloadResult struct {
|
||||
Usage *schemas.BifrostLLMUsage
|
||||
Latency int64
|
||||
ResponseBody []byte // non-nil for request types that need the raw upstream response (transcription, speech, etc.)
|
||||
}
|
||||
|
||||
// setStreamingRequestBody sets the request body for streaming handlers.
|
||||
// In normal mode it uses the marshaled jsonBody. In large payload mode it delegates to
|
||||
// ApplyLargePayloadRequestBodyWithModelNormalization which streams the original request
|
||||
// body to upstream with model prefix rewriting.
|
||||
func setStreamingRequestBody(ctx *schemas.BifrostContext, req *fasthttp.Request, jsonBody []byte, providerName schemas.ModelProvider) {
|
||||
if !providerUtils.ApplyLargePayloadRequestBodyWithModelNormalization(ctx, req, providerName) {
|
||||
req.SetBody(jsonBody)
|
||||
}
|
||||
}
|
||||
|
||||
// handleOpenAILargePayloadPassthrough handles a complete large payload request-response cycle
|
||||
// for OpenAI-compatible providers. When large payload mode is active, it streams the request
|
||||
// body to upstream and optionally streams the response back without full materialization.
|
||||
//
|
||||
// Returns (result, nil, true) on success, (nil, err, true) on error, or (nil, nil, false) when
|
||||
// large payload mode is not active and the caller should use the normal path.
|
||||
func handleOpenAILargePayloadPassthrough(
|
||||
ctx *schemas.BifrostContext,
|
||||
client *fasthttp.Client,
|
||||
url string,
|
||||
key schemas.Key,
|
||||
extraHeaders map[string]string,
|
||||
providerName schemas.ModelProvider,
|
||||
logger schemas.Logger,
|
||||
) (*largePayloadResult, *schemas.BifrostError, bool) {
|
||||
isLargePayload, _ := ctx.Value(schemas.BifrostContextKeyLargePayloadMode).(bool)
|
||||
if !isLargePayload {
|
||||
return nil, nil, false
|
||||
}
|
||||
|
||||
req := fasthttp.AcquireRequest()
|
||||
resp := fasthttp.AcquireResponse()
|
||||
defer fasthttp.ReleaseRequest(req)
|
||||
// resp lifecycle: managed manually when large response streaming is active
|
||||
|
||||
providerUtils.SetExtraHeaders(ctx, req, extraHeaders, nil)
|
||||
req.SetRequestURI(url)
|
||||
req.Header.SetMethod(http.MethodPost)
|
||||
if key.Value.GetValue() != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+key.Value.GetValue())
|
||||
}
|
||||
|
||||
// Rewrite model prefix and stream request body to upstream.
|
||||
// Sets content-type from context; falls back to JSON if not set.
|
||||
if !providerUtils.ApplyLargePayloadRequestBodyWithModelNormalization(ctx, req, providerName) {
|
||||
fasthttp.ReleaseResponse(resp)
|
||||
return nil, nil, false
|
||||
}
|
||||
if len(req.Header.ContentType()) == 0 {
|
||||
req.Header.SetContentType("application/json")
|
||||
}
|
||||
|
||||
// Choose client: enable response body streaming when threshold is configured
|
||||
activeClient := providerUtils.PrepareResponseStreaming(ctx, client, resp)
|
||||
|
||||
latency, bifrostErr, wait := providerUtils.MakeRequestWithContext(ctx, activeClient, req, resp)
|
||||
wait()
|
||||
if bifrostErr != nil {
|
||||
fasthttp.ReleaseResponse(resp)
|
||||
return nil, bifrostErr, true
|
||||
}
|
||||
|
||||
// Extract provider response headers early so they're available on error and large-response paths
|
||||
if headers := providerUtils.ExtractProviderResponseHeaders(resp); headers != nil {
|
||||
ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, headers)
|
||||
}
|
||||
|
||||
// Error responses are always small — materialize stream body for error parsing
|
||||
if resp.StatusCode() != fasthttp.StatusOK {
|
||||
providerUtils.MaterializeStreamErrorBody(ctx, resp)
|
||||
parsedErr := ParseOpenAIError(resp)
|
||||
fasthttp.ReleaseResponse(resp)
|
||||
return nil, parsedErr, true
|
||||
}
|
||||
|
||||
// Delegate response body handling (large detection + resp lifecycle) to finalizeOpenAIResponse
|
||||
body, result, respErr := finalizeOpenAIResponse(ctx, resp, latency, providerName, logger)
|
||||
if respErr != nil {
|
||||
return nil, respErr, true
|
||||
}
|
||||
if result != nil {
|
||||
return result, nil, true
|
||||
}
|
||||
// Normal path — extract usage from raw bytes (passthrough doesn't parse structured response)
|
||||
usage := extractOpenAIUsageFromBytes(body)
|
||||
return &largePayloadResult{Usage: usage, Latency: latency.Milliseconds(), ResponseBody: body}, nil, true
|
||||
}
|
||||
|
||||
// finalizeOpenAIResponse handles response body processing with optional large response detection.
|
||||
// Delegates to FinalizeResponseWithLargeDetection for the core branching logic.
|
||||
// Takes ownership of resp — caller must NOT defer ReleaseResponse and must set respOwned = false
|
||||
// after this call returns.
|
||||
//
|
||||
// Returns:
|
||||
// - (body, nil, nil) — normal path; body ready for parsing; resp released.
|
||||
// - (nil, result, nil) — large response detected; context flags set for streaming; resp
|
||||
// wrapped in reader (released on reader Close).
|
||||
// - (nil, nil, err) — error; resp released.
|
||||
func finalizeOpenAIResponse(
|
||||
ctx *schemas.BifrostContext,
|
||||
resp *fasthttp.Response,
|
||||
latency time.Duration,
|
||||
providerName schemas.ModelProvider,
|
||||
logger schemas.Logger,
|
||||
) ([]byte, *largePayloadResult, *schemas.BifrostError) {
|
||||
body, isLarge, bifrostErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, logger)
|
||||
if bifrostErr != nil {
|
||||
fasthttp.ReleaseResponse(resp)
|
||||
return nil, nil, bifrostErr
|
||||
}
|
||||
if isLarge {
|
||||
// Extract usage from the response preview stored in context by FinalizeResponseWithLargeDetection
|
||||
preview, _ := ctx.Value(schemas.BifrostContextKeyLargePayloadResponsePreview).(string)
|
||||
usage := extractOpenAIUsageFromBytes([]byte(preview))
|
||||
// resp owned by LargeResponseReader in context — don't release
|
||||
return nil, &largePayloadResult{Usage: usage, Latency: latency.Milliseconds()}, nil
|
||||
}
|
||||
// Normal path — body already copied by shared utility, safe to release resp
|
||||
fasthttp.ReleaseResponse(resp)
|
||||
return body, nil, nil
|
||||
}
|
||||
|
||||
// extractOpenAIUsageFromBytes extracts usage metadata from OpenAI response bytes using sonic.Get.
|
||||
// OpenAI responses have "usage" at the top level with prompt_tokens, completion_tokens, total_tokens.
|
||||
func extractOpenAIUsageFromBytes(data []byte) *schemas.BifrostLLMUsage {
|
||||
node, err := sonic.Get(data, "usage")
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
raw, err := node.Raw()
|
||||
if err != nil || raw == "" {
|
||||
return nil
|
||||
}
|
||||
return providerUtils.ParseOpenAIUsageFromBytes([]byte(raw))
|
||||
}
|
||||
80
core/providers/openai/models.go
Normal file
80
core/providers/openai/models.go
Normal file
@@ -0,0 +1,80 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// ToBifrostListModelsResponse converts an OpenAI list models response to a Bifrost list models response
|
||||
func (response *OpenAIListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, allowedModels schemas.WhiteList, blacklistedModels schemas.BlackList, aliases map[string]string, unfiltered bool) *schemas.BifrostListModelsResponse {
|
||||
if response == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
bifrostResponse := &schemas.BifrostListModelsResponse{
|
||||
Data: make([]schemas.Model, 0, len(response.Data)),
|
||||
}
|
||||
|
||||
pipeline := &providerUtils.ListModelsPipeline{
|
||||
AllowedModels: allowedModels,
|
||||
BlacklistedModels: blacklistedModels,
|
||||
Aliases: aliases,
|
||||
Unfiltered: unfiltered,
|
||||
ProviderKey: providerKey,
|
||||
MatchFns: providerUtils.DefaultMatchFns(),
|
||||
}
|
||||
if pipeline.ShouldEarlyExit() {
|
||||
return bifrostResponse
|
||||
}
|
||||
|
||||
included := make(map[string]bool)
|
||||
|
||||
for _, model := range response.Data {
|
||||
for _, result := range pipeline.FilterModel(model.ID) {
|
||||
entry := schemas.Model{
|
||||
ID: string(providerKey) + "/" + result.ResolvedID,
|
||||
Created: model.Created,
|
||||
OwnedBy: schemas.Ptr(model.OwnedBy),
|
||||
ContextLength: model.ContextWindow,
|
||||
}
|
||||
if result.AliasValue != "" {
|
||||
entry.Alias = schemas.Ptr(result.AliasValue)
|
||||
}
|
||||
bifrostResponse.Data = append(bifrostResponse.Data, entry)
|
||||
included[strings.ToLower(result.ResolvedID)] = true
|
||||
}
|
||||
}
|
||||
|
||||
bifrostResponse.Data = append(bifrostResponse.Data,
|
||||
pipeline.BackfillModels(included)...)
|
||||
|
||||
return bifrostResponse
|
||||
}
|
||||
|
||||
// ToOpenAIListModelsResponse converts a Bifrost list models response to an OpenAI list models response
|
||||
func ToOpenAIListModelsResponse(response *schemas.BifrostListModelsResponse) *OpenAIListModelsResponse {
|
||||
if response == nil {
|
||||
return nil
|
||||
}
|
||||
openaiResponse := &OpenAIListModelsResponse{
|
||||
Data: make([]OpenAIModel, 0, len(response.Data)),
|
||||
}
|
||||
for _, model := range response.Data {
|
||||
openaiModel := OpenAIModel{
|
||||
ID: model.ID,
|
||||
Object: "model",
|
||||
}
|
||||
if model.Created != nil {
|
||||
openaiModel.Created = model.Created
|
||||
}
|
||||
if model.OwnedBy != nil {
|
||||
openaiModel.OwnedBy = *model.OwnedBy
|
||||
}
|
||||
|
||||
openaiResponse.Data = append(openaiResponse.Data, openaiModel)
|
||||
|
||||
}
|
||||
return openaiResponse
|
||||
}
|
||||
7065
core/providers/openai/openai.go
Normal file
7065
core/providers/openai/openai.go
Normal file
File diff suppressed because it is too large
Load Diff
119
core/providers/openai/openai_test.go
Normal file
119
core/providers/openai/openai_test.go
Normal file
@@ -0,0 +1,119 @@
|
||||
package openai_test
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/maximhq/bifrost/core/internal/llmtests"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
func TestOpenAI(t *testing.T) {
|
||||
t.Parallel()
|
||||
if strings.TrimSpace(os.Getenv("OPENAI_API_KEY")) == "" {
|
||||
t.Skip("Skipping OpenAI tests because OPENAI_API_KEY is not set")
|
||||
}
|
||||
|
||||
client, ctx, cancel, err := llmtests.SetupTest()
|
||||
if err != nil {
|
||||
t.Fatalf("Error initializing test setup: %v", err)
|
||||
}
|
||||
defer cancel()
|
||||
defer client.Shutdown()
|
||||
|
||||
testConfig := llmtests.ComprehensiveTestConfig{
|
||||
Provider: schemas.OpenAI,
|
||||
TextModel: "gpt-3.5-turbo-instruct",
|
||||
ChatModel: "gpt-4o",
|
||||
PromptCachingModel: "gpt-4.1",
|
||||
Fallbacks: []schemas.Fallback{
|
||||
{Provider: schemas.OpenAI, Model: "gpt-4o"},
|
||||
},
|
||||
VisionModel: "gpt-4o",
|
||||
EmbeddingModel: "text-embedding-3-small",
|
||||
TranscriptionModel: "gpt-4o-transcribe",
|
||||
TranscriptionFallbacks: []schemas.Fallback{
|
||||
{Provider: schemas.OpenAI, Model: "whisper-1"},
|
||||
},
|
||||
SpeechSynthesisModel: "gpt-4o-mini-tts",
|
||||
ReasoningModel: "o4-mini", // o4-mini properly returns both reasoning items and message output
|
||||
ImageGenerationModel: "gpt-image-1",
|
||||
ImageEditModel: "gpt-image-1",
|
||||
ImageVariationModel: "", // dall-e-2 is deprecated and no other OpenAI model supports image variations
|
||||
VideoGenerationModel: "sora-2",
|
||||
ChatAudioModel: "gpt-4o-mini-audio-preview",
|
||||
PassthroughModel: "gpt-4o",
|
||||
Scenarios: llmtests.TestScenarios{
|
||||
TextCompletion: true,
|
||||
TextCompletionStream: true,
|
||||
SimpleChat: true,
|
||||
CompletionStream: true,
|
||||
MultiTurnConversation: true,
|
||||
ToolCalls: true,
|
||||
ToolCallsStreaming: true,
|
||||
MultipleToolCalls: true,
|
||||
MultipleToolCallsStreaming: true,
|
||||
End2EndToolCalling: true,
|
||||
AutomaticFunctionCall: true,
|
||||
WebSearchTool: true,
|
||||
ImageURL: true,
|
||||
ImageBase64: true,
|
||||
MultipleImages: true,
|
||||
FileBase64: true,
|
||||
FileURL: true,
|
||||
CompleteEnd2End: true,
|
||||
SpeechSynthesis: true,
|
||||
SpeechSynthesisStream: true,
|
||||
Transcription: true,
|
||||
TranscriptionStream: true,
|
||||
Embedding: true,
|
||||
Reasoning: true,
|
||||
ListModels: true,
|
||||
ImageGeneration: true,
|
||||
ImageGenerationStream: true,
|
||||
ImageEdit: true,
|
||||
ImageEditStream: true,
|
||||
ImageVariation: false, // dall-e-2 is deprecated and no other OpenAI model supports image variations
|
||||
VideoGeneration: false, // disabled for now because of long running operations
|
||||
VideoRetrieve: false,
|
||||
VideoRemix: false,
|
||||
VideoDownload: false,
|
||||
VideoList: false,
|
||||
VideoDelete: false,
|
||||
BatchCreate: true,
|
||||
BatchList: true,
|
||||
BatchRetrieve: true,
|
||||
BatchCancel: true,
|
||||
BatchResults: true,
|
||||
FileUpload: true,
|
||||
FileList: true,
|
||||
FileRetrieve: true,
|
||||
FileDelete: true,
|
||||
FileContent: true,
|
||||
FileBatchInput: true,
|
||||
CountTokens: true,
|
||||
ChatAudio: true,
|
||||
StructuredOutputs: true, // Structured outputs with nullable enum support
|
||||
ContainerCreate: true,
|
||||
ContainerList: true,
|
||||
ContainerRetrieve: true,
|
||||
ContainerDelete: true,
|
||||
ContainerFileCreate: true,
|
||||
ContainerFileList: true,
|
||||
ContainerFileRetrieve: true,
|
||||
ContainerFileContent: true,
|
||||
ContainerFileDelete: true,
|
||||
PromptCaching: true,
|
||||
PassthroughAPI: true,
|
||||
WebSocketResponses: true,
|
||||
Realtime: false,
|
||||
},
|
||||
RealtimeModel: "gpt-4o-realtime-preview",
|
||||
}
|
||||
|
||||
t.Run("OpenAITests", func(t *testing.T) {
|
||||
llmtests.RunAllComprehensiveTests(t, client, ctx, testConfig)
|
||||
})
|
||||
}
|
||||
121
core/providers/openai/payload_ordering_test.go
Normal file
121
core/providers/openai/payload_ordering_test.go
Normal file
@@ -0,0 +1,121 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"mime/multipart"
|
||||
"testing"
|
||||
|
||||
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
|
||||
schemas "github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestPayloadOrdering_OpenAIChatRequest(t *testing.T) {
|
||||
req := &OpenAIChatRequest{
|
||||
Model: "gpt-4o",
|
||||
Messages: []OpenAIMessage{
|
||||
{
|
||||
Role: schemas.ChatMessageRoleUser,
|
||||
Content: &schemas.ChatMessageContent{ContentStr: schemas.Ptr("hello")},
|
||||
},
|
||||
},
|
||||
ChatParameters: schemas.ChatParameters{
|
||||
Temperature: schemas.Ptr(0.7),
|
||||
Tools: []schemas.ChatTool{
|
||||
{
|
||||
Type: "function",
|
||||
Function: &schemas.ChatToolFunction{
|
||||
Name: "get_weather",
|
||||
Description: schemas.Ptr("Get weather"),
|
||||
Parameters: &schemas.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: schemas.NewOrderedMapFromPairs(
|
||||
schemas.KV("location", map[string]interface{}{"type": "string"}),
|
||||
),
|
||||
Required: []string{"location"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
Reasoning: &schemas.ChatReasoning{
|
||||
Effort: schemas.Ptr("high"),
|
||||
},
|
||||
},
|
||||
Stream: schemas.Ptr(true),
|
||||
}
|
||||
|
||||
result, err := providerUtils.MarshalSorted(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
golden := `{"model":"gpt-4o","temperature":0.7,"stream":true,"messages":[{"role":"user","content":"hello"}],"tools":[{"type":"function","function":{"name":"get_weather","description":"Get weather","parameters":{"type":"object","properties":{"location":{"type":"string"}},"required":["location"]}}}],"reasoning_effort":"high"}`
|
||||
|
||||
assert.Equal(t, golden, string(result), "payload field ordering changed — if intentional, update the golden string")
|
||||
|
||||
// Determinism: 100 iterations must produce identical bytes
|
||||
for i := 0; i < 100; i++ {
|
||||
iter, err := providerUtils.MarshalSorted(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, string(result), string(iter), "non-deterministic marshal output on iteration %d", i)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseImageEditFormDataBodyFromRequest_OrdersMetadataBeforeFiles(t *testing.T) {
|
||||
req := &OpenAIImageEditRequest{
|
||||
Model: "gpt-image-1",
|
||||
Input: &schemas.ImageEditInput{
|
||||
Prompt: "edit this",
|
||||
Images: []schemas.ImageInput{{Image: []byte("image-one")}, {Image: []byte("image-two")}},
|
||||
},
|
||||
ImageEditParameters: schemas.ImageEditParameters{
|
||||
N: schemas.Ptr(2),
|
||||
Size: schemas.Ptr("1024x1024"),
|
||||
ResponseFormat: schemas.Ptr("b64_json"),
|
||||
Quality: schemas.Ptr("high"),
|
||||
Background: schemas.Ptr("transparent"),
|
||||
InputFidelity: schemas.Ptr("high"),
|
||||
PartialImages: schemas.Ptr(1),
|
||||
OutputFormat: schemas.Ptr("png"),
|
||||
OutputCompression: schemas.Ptr(80),
|
||||
User: schemas.Ptr("user-123"),
|
||||
Mask: []byte("mask-image"),
|
||||
},
|
||||
Stream: schemas.Ptr(true),
|
||||
}
|
||||
|
||||
var body bytes.Buffer
|
||||
writer := multipart.NewWriter(&body)
|
||||
require.Nil(t, parseImageEditFormDataBodyFromRequest(writer, req, schemas.OpenAI))
|
||||
|
||||
order := multipartPartOrder(t, writer.FormDataContentType(), body.Bytes())
|
||||
assert.Equal(t,
|
||||
[]string{"model", "prompt", "stream", "n", "size", "response_format", "quality", "background", "input_fidelity", "partial_images", "output_format", "output_compression", "user", "image[]", "image[]", "mask"},
|
||||
order,
|
||||
)
|
||||
}
|
||||
|
||||
func TestParseImageVariationFormDataBodyFromRequest_OrdersMetadataBeforeFile(t *testing.T) {
|
||||
req := &OpenAIImageVariationRequest{
|
||||
Model: "gpt-image-1",
|
||||
Input: &schemas.ImageVariationInput{
|
||||
Image: schemas.ImageInput{Image: []byte("image-variation")},
|
||||
},
|
||||
ImageVariationParameters: schemas.ImageVariationParameters{
|
||||
N: schemas.Ptr(3),
|
||||
ResponseFormat: schemas.Ptr("url"),
|
||||
Size: schemas.Ptr("512x512"),
|
||||
User: schemas.Ptr("user-456"),
|
||||
},
|
||||
}
|
||||
|
||||
var body bytes.Buffer
|
||||
writer := multipart.NewWriter(&body)
|
||||
require.Nil(t, parseImageVariationFormDataBodyFromRequest(writer, req, schemas.OpenAI))
|
||||
|
||||
order := multipartPartOrder(t, writer.FormDataContentType(), body.Bytes())
|
||||
assert.Equal(t,
|
||||
[]string{"model", "n", "response_format", "size", "user", "image"},
|
||||
order,
|
||||
)
|
||||
}
|
||||
|
||||
967
core/providers/openai/realtime.go
Normal file
967
core/providers/openai/realtime.go
Normal file
@@ -0,0 +1,967 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// SupportsRealtimeAPI returns true since OpenAI natively supports the Realtime API.
|
||||
func (provider *OpenAIProvider) SupportsRealtimeAPI() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// RealtimeWebSocketURL returns the WSS URL for the OpenAI Realtime API.
|
||||
// Format: wss://api.openai.com/v1/realtime?model=<model>
|
||||
func (provider *OpenAIProvider) RealtimeWebSocketURL(key schemas.Key, model string) string {
|
||||
base := provider.networkConfig.BaseURL
|
||||
base = strings.Replace(base, "https://", "wss://", 1)
|
||||
base = strings.Replace(base, "http://", "ws://", 1)
|
||||
return base + "/v1/realtime?model=" + url.QueryEscape(model)
|
||||
}
|
||||
|
||||
// RealtimeHeaders returns the headers required for the OpenAI Realtime WebSocket connection.
|
||||
func (provider *OpenAIProvider) RealtimeHeaders(key schemas.Key) map[string]string {
|
||||
headers := map[string]string{
|
||||
"Authorization": "Bearer " + key.Value.GetValue(),
|
||||
}
|
||||
for k, v := range provider.networkConfig.ExtraHeaders {
|
||||
headers[k] = v
|
||||
}
|
||||
return headers
|
||||
}
|
||||
|
||||
// SupportsRealtimeWebRTC reports that OpenAI supports WebRTC SDP exchange.
|
||||
func (provider *OpenAIProvider) SupportsRealtimeWebRTC() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// ExchangeRealtimeWebRTCSDP performs the GA SDP exchange via multipart POST to /v1/realtime/calls.
|
||||
func (provider *OpenAIProvider) ExchangeRealtimeWebRTCSDP(
|
||||
ctx *schemas.BifrostContext,
|
||||
key schemas.Key,
|
||||
model string,
|
||||
sdp string,
|
||||
session json.RawMessage,
|
||||
) (string, *schemas.BifrostError) {
|
||||
path := "/v1/realtime/calls"
|
||||
if session == nil && strings.TrimSpace(model) != "" {
|
||||
path += "?model=" + url.QueryEscape(model)
|
||||
}
|
||||
return provider.exchangeWebRTCSDP(ctx, key, path, sdp, session)
|
||||
}
|
||||
|
||||
// ExchangeLegacyRealtimeWebRTCSDP performs the beta SDP exchange via multipart POST to /v1/realtime.
|
||||
// Same multipart format but targets the legacy endpoint with model in the URL.
|
||||
func (provider *OpenAIProvider) ExchangeLegacyRealtimeWebRTCSDP(
|
||||
ctx *schemas.BifrostContext,
|
||||
key schemas.Key,
|
||||
sdp string,
|
||||
session json.RawMessage,
|
||||
model string,
|
||||
) (string, *schemas.BifrostError) {
|
||||
return provider.exchangeWebRTCSDP(ctx, key, "/v1/realtime?model="+url.QueryEscape(model), sdp, session)
|
||||
}
|
||||
|
||||
// exchangeWebRTCSDP is the shared multipart SDP exchange implementation.
|
||||
// Builds a multipart body with sdp + optional session, POSTs to the given path.
|
||||
func (provider *OpenAIProvider) exchangeWebRTCSDP(
|
||||
ctx *schemas.BifrostContext,
|
||||
key schemas.Key,
|
||||
path string,
|
||||
sdp string,
|
||||
session json.RawMessage,
|
||||
) (string, *schemas.BifrostError) {
|
||||
bodyBuf := &bytes.Buffer{}
|
||||
writer := multipart.NewWriter(bodyBuf)
|
||||
if err := writer.WriteField("sdp", sdp); err != nil {
|
||||
return "", newRealtimeWebRTCSDPError(fasthttp.StatusInternalServerError, "server_error", "failed to encode upstream SDP body", err)
|
||||
}
|
||||
if session != nil {
|
||||
if err := writer.WriteField("session", string(session)); err != nil {
|
||||
return "", newRealtimeWebRTCSDPError(fasthttp.StatusInternalServerError, "server_error", "failed to encode upstream session body", err)
|
||||
}
|
||||
}
|
||||
if err := writer.Close(); err != nil {
|
||||
return "", newRealtimeWebRTCSDPError(fasthttp.StatusInternalServerError, "server_error", "failed to finalize upstream SDP body", err)
|
||||
}
|
||||
|
||||
req := fasthttp.AcquireRequest()
|
||||
resp := fasthttp.AcquireResponse()
|
||||
defer fasthttp.ReleaseRequest(req)
|
||||
defer fasthttp.ReleaseResponse(resp)
|
||||
|
||||
req.SetRequestURI(provider.buildRequestURL(ctx, path, schemas.RealtimeRequest))
|
||||
req.Header.SetMethod(http.MethodPost)
|
||||
req.Header.SetContentType(writer.FormDataContentType())
|
||||
req.Header.Set("Authorization", "Bearer "+key.Value.GetValue())
|
||||
for k, v := range provider.networkConfig.ExtraHeaders {
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
if headers, _ := ctx.Value(schemas.BifrostContextKeyRequestHeaders).(map[string]string); headers != nil {
|
||||
if agentsSDK := headers["x-openai-agents-sdk"]; agentsSDK != "" {
|
||||
req.Header.Set("X-OpenAI-Agents-SDK", agentsSDK)
|
||||
}
|
||||
}
|
||||
req.SetBody(bodyBuf.Bytes())
|
||||
|
||||
_, bifrostErr, wait := providerUtils.MakeRequestWithContext(ctx, provider.client, req, resp)
|
||||
defer wait()
|
||||
if bifrostErr != nil {
|
||||
return "", bifrostErr
|
||||
}
|
||||
|
||||
answerBody := resp.Body()
|
||||
if resp.StatusCode() < fasthttp.StatusOK || resp.StatusCode() >= fasthttp.StatusMultipleChoices {
|
||||
return "", provider.realtimeWebRTCUpstreamError(ctx, resp.StatusCode(), answerBody)
|
||||
}
|
||||
|
||||
return string(answerBody), nil
|
||||
}
|
||||
|
||||
func (provider *OpenAIProvider) realtimeWebRTCUpstreamError(ctx *schemas.BifrostContext, statusCode int, body []byte) *schemas.BifrostError {
|
||||
bifrostErr := &schemas.BifrostError{
|
||||
IsBifrostError: false,
|
||||
StatusCode: schemas.Ptr(fasthttp.StatusBadGateway),
|
||||
Error: &schemas.ErrorField{
|
||||
Type: schemas.Ptr("upstream_connection_error"),
|
||||
Message: fmt.Sprintf("upstream realtime WebRTC handshake failed for %s", provider.GetProviderKey()),
|
||||
},
|
||||
ExtraFields: schemas.BifrostErrorExtraFields{
|
||||
RequestType: schemas.RealtimeRequest,
|
||||
Provider: provider.GetProviderKey(),
|
||||
},
|
||||
}
|
||||
if providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) {
|
||||
bifrostErr.ExtraFields.RawResponse = map[string]any{
|
||||
"status": statusCode,
|
||||
"body": string(body),
|
||||
}
|
||||
}
|
||||
return bifrostErr
|
||||
}
|
||||
|
||||
func newRealtimeWebRTCSDPError(status int, errorType, message string, err error) *schemas.BifrostError {
|
||||
bifrostErr := &schemas.BifrostError{
|
||||
IsBifrostError: true,
|
||||
StatusCode: schemas.Ptr(status),
|
||||
Error: &schemas.ErrorField{
|
||||
Type: schemas.Ptr(errorType),
|
||||
Message: message,
|
||||
},
|
||||
}
|
||||
if err != nil {
|
||||
bifrostErr.Error.Error = err
|
||||
}
|
||||
return bifrostErr
|
||||
}
|
||||
|
||||
func (provider *OpenAIProvider) ShouldStartRealtimeTurn(event *schemas.BifrostRealtimeEvent) bool {
|
||||
if event == nil {
|
||||
return false
|
||||
}
|
||||
switch event.Type {
|
||||
case schemas.RTEventResponseCreate, schemas.RTEventInputAudioBufferCommitted:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (provider *OpenAIProvider) RealtimeTurnFinalEvent() schemas.RealtimeEventType {
|
||||
return schemas.RTEventResponseDone
|
||||
}
|
||||
|
||||
func (provider *OpenAIProvider) RealtimeWebRTCDataChannelLabel() string {
|
||||
return "oai-events"
|
||||
}
|
||||
|
||||
func (provider *OpenAIProvider) RealtimeWebSocketSubprotocol() string {
|
||||
return "realtime"
|
||||
}
|
||||
|
||||
func (provider *OpenAIProvider) ShouldForwardRealtimeEvent(event *schemas.BifrostRealtimeEvent) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (provider *OpenAIProvider) ShouldAccumulateRealtimeOutput(eventType schemas.RealtimeEventType) bool {
|
||||
switch eventType {
|
||||
case schemas.RTEventResponseTextDelta,
|
||||
schemas.RTEventResponseAudioTransDelta,
|
||||
schemas.RealtimeEventType("response.output_text.delta"),
|
||||
schemas.RealtimeEventType("response.output_audio_transcript.delta"):
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// CreateRealtimeClientSecret mints an OpenAI Realtime client secret and returns
|
||||
// the native OpenAI response body unchanged.
|
||||
func (provider *OpenAIProvider) CreateRealtimeClientSecret(
|
||||
ctx *schemas.BifrostContext,
|
||||
key schemas.Key,
|
||||
endpointType schemas.RealtimeSessionEndpointType,
|
||||
rawRequest json.RawMessage,
|
||||
) (*schemas.BifrostPassthroughResponse, *schemas.BifrostError) {
|
||||
if err := providerUtils.CheckOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.RealtimeRequest); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
normalizedBody, requestedModel, bifrostErr := normalizeRealtimeClientSecretRequest(rawRequest, provider.GetProviderKey(), endpointType)
|
||||
if bifrostErr != nil {
|
||||
return nil, bifrostErr
|
||||
}
|
||||
req := fasthttp.AcquireRequest()
|
||||
resp := fasthttp.AcquireResponse()
|
||||
defer fasthttp.ReleaseRequest(req)
|
||||
defer fasthttp.ReleaseResponse(resp)
|
||||
|
||||
req.SetRequestURI(provider.buildRequestURL(ctx, realtimeSessionUpstreamPath(endpointType), schemas.RealtimeRequest))
|
||||
req.Header.SetMethod(http.MethodPost)
|
||||
req.Header.SetContentType("application/json")
|
||||
for k, v := range provider.realtimeSessionHeaders(key, endpointType) {
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
req.SetBody(normalizedBody)
|
||||
|
||||
latency, bifrostErr, wait := providerUtils.MakeRequestWithContext(ctx, provider.client, req, resp)
|
||||
defer wait()
|
||||
if bifrostErr != nil {
|
||||
return nil, bifrostErr
|
||||
}
|
||||
|
||||
headers := providerUtils.ExtractProviderResponseHeaders(resp)
|
||||
ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, headers)
|
||||
|
||||
if resp.StatusCode() < fasthttp.StatusOK || resp.StatusCode() >= fasthttp.StatusMultipleChoices {
|
||||
return nil, ParseOpenAIError(resp)
|
||||
}
|
||||
|
||||
body, err := providerUtils.CheckAndDecodeBody(resp)
|
||||
if err != nil {
|
||||
return nil, providerUtils.NewBifrostOperationError("failed to decode response body", err)
|
||||
}
|
||||
for k := range headers {
|
||||
if strings.EqualFold(k, "Content-Encoding") || strings.EqualFold(k, "Content-Length") {
|
||||
delete(headers, k)
|
||||
}
|
||||
}
|
||||
|
||||
out := &schemas.BifrostPassthroughResponse{
|
||||
StatusCode: resp.StatusCode(),
|
||||
Headers: headers,
|
||||
Body: body,
|
||||
}
|
||||
out.ExtraFields.Provider = provider.GetProviderKey()
|
||||
out.ExtraFields.OriginalModelRequested = requestedModel
|
||||
out.ExtraFields.RequestType = schemas.RealtimeRequest
|
||||
out.ExtraFields.Latency = latency.Milliseconds()
|
||||
if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) {
|
||||
providerUtils.ParseAndSetRawRequestIfJSON(req, &out.ExtraFields)
|
||||
}
|
||||
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func normalizeRealtimeClientSecretRequest(
|
||||
rawRequest json.RawMessage,
|
||||
defaultProvider schemas.ModelProvider,
|
||||
endpointType schemas.RealtimeSessionEndpointType,
|
||||
) ([]byte, string, *schemas.BifrostError) {
|
||||
root, bifrostErr := schemas.ParseRealtimeClientSecretBody(rawRequest)
|
||||
if bifrostErr != nil {
|
||||
return nil, "", bifrostErr
|
||||
}
|
||||
|
||||
modelValue, bifrostErr := schemas.ExtractRealtimeClientSecretModel(root)
|
||||
if bifrostErr != nil {
|
||||
return nil, "", bifrostErr
|
||||
}
|
||||
providerKey, normalizedModel := schemas.ParseModelString(modelValue, defaultProvider)
|
||||
if normalizedModel == "" {
|
||||
return nil, "", newRealtimeClientSecretError(fasthttp.StatusBadRequest, "invalid_request_error", "session.model is required", nil)
|
||||
}
|
||||
if providerKey == "" {
|
||||
providerKey = defaultProvider
|
||||
}
|
||||
if providerKey == "" {
|
||||
return nil, "", newRealtimeClientSecretError(fasthttp.StatusBadRequest, "invalid_request_error", "unable to determine provider from model", nil)
|
||||
}
|
||||
|
||||
if endpointType == schemas.RealtimeSessionEndpointSessions {
|
||||
return normalizeRealtimeSessionsRequest(root, normalizedModel)
|
||||
}
|
||||
|
||||
return normalizeRealtimeClientSecretsRequest(root, normalizedModel)
|
||||
}
|
||||
|
||||
func normalizeRealtimeClientSecretsRequest(
|
||||
root map[string]json.RawMessage,
|
||||
normalizedModel string,
|
||||
) ([]byte, string, *schemas.BifrostError) {
|
||||
session := map[string]json.RawMessage{}
|
||||
if existingSession, ok := root["session"]; ok && len(existingSession) > 0 && !bytes.Equal(existingSession, []byte("null")) {
|
||||
if err := json.Unmarshal(existingSession, &session); err != nil {
|
||||
return nil, "", newRealtimeClientSecretError(fasthttp.StatusBadRequest, "invalid_request_error", "session must be an object", err)
|
||||
}
|
||||
}
|
||||
|
||||
modelJSON, marshalErr := json.Marshal(normalizedModel)
|
||||
if marshalErr != nil {
|
||||
return nil, "", newRealtimeClientSecretError(fasthttp.StatusInternalServerError, "server_error", "failed to encode normalized model", marshalErr)
|
||||
}
|
||||
session["model"] = modelJSON
|
||||
if _, ok := session["type"]; !ok {
|
||||
typeJSON, marshalErr := json.Marshal("realtime")
|
||||
if marshalErr != nil {
|
||||
return nil, "", newRealtimeClientSecretError(fasthttp.StatusInternalServerError, "server_error", "failed to encode realtime session type", marshalErr)
|
||||
}
|
||||
session["type"] = typeJSON
|
||||
}
|
||||
delete(root, "model")
|
||||
|
||||
sessionJSON, marshalErr := json.Marshal(session)
|
||||
if marshalErr != nil {
|
||||
return nil, "", newRealtimeClientSecretError(fasthttp.StatusInternalServerError, "server_error", "failed to encode realtime session", marshalErr)
|
||||
}
|
||||
root["session"] = sessionJSON
|
||||
|
||||
normalizedBody, marshalErr := json.Marshal(root)
|
||||
if marshalErr != nil {
|
||||
return nil, "", newRealtimeClientSecretError(fasthttp.StatusInternalServerError, "server_error", "failed to encode realtime request", marshalErr)
|
||||
}
|
||||
|
||||
return normalizedBody, normalizedModel, nil
|
||||
}
|
||||
|
||||
func normalizeRealtimeSessionsRequest(
|
||||
root map[string]json.RawMessage,
|
||||
normalizedModel string,
|
||||
) ([]byte, string, *schemas.BifrostError) {
|
||||
if existingSession, ok := root["session"]; ok && len(existingSession) > 0 && !bytes.Equal(existingSession, []byte("null")) {
|
||||
session := map[string]json.RawMessage{}
|
||||
if err := json.Unmarshal(existingSession, &session); err != nil {
|
||||
return nil, "", newRealtimeClientSecretError(fasthttp.StatusBadRequest, "invalid_request_error", "session must be an object", err)
|
||||
}
|
||||
for key, value := range session {
|
||||
if _, exists := root[key]; !exists {
|
||||
root[key] = value
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
modelJSON, marshalErr := json.Marshal(normalizedModel)
|
||||
if marshalErr != nil {
|
||||
return nil, "", newRealtimeClientSecretError(fasthttp.StatusInternalServerError, "server_error", "failed to encode normalized model", marshalErr)
|
||||
}
|
||||
root["model"] = modelJSON
|
||||
delete(root, "session")
|
||||
|
||||
normalizedBody, marshalErr := json.Marshal(root)
|
||||
if marshalErr != nil {
|
||||
return nil, "", newRealtimeClientSecretError(fasthttp.StatusInternalServerError, "server_error", "failed to encode realtime request", marshalErr)
|
||||
}
|
||||
|
||||
return normalizedBody, normalizedModel, nil
|
||||
}
|
||||
|
||||
func (provider *OpenAIProvider) realtimeSessionHeaders(
|
||||
key schemas.Key,
|
||||
endpointType schemas.RealtimeSessionEndpointType,
|
||||
) map[string]string {
|
||||
headers := map[string]string{
|
||||
"Authorization": "Bearer " + key.Value.GetValue(),
|
||||
}
|
||||
if endpointType == schemas.RealtimeSessionEndpointSessions {
|
||||
headers["OpenAI-Beta"] = "realtime=v1"
|
||||
}
|
||||
for k, v := range provider.networkConfig.ExtraHeaders {
|
||||
headers[k] = v
|
||||
}
|
||||
return headers
|
||||
}
|
||||
|
||||
func realtimeSessionUpstreamPath(endpointType schemas.RealtimeSessionEndpointType) string {
|
||||
if endpointType == schemas.RealtimeSessionEndpointSessions {
|
||||
return "/v1/realtime/sessions"
|
||||
}
|
||||
return "/v1/realtime/client_secrets"
|
||||
}
|
||||
|
||||
func newRealtimeClientSecretError(status int, errorType, message string, err error) *schemas.BifrostError {
|
||||
return &schemas.BifrostError{
|
||||
IsBifrostError: false,
|
||||
StatusCode: schemas.Ptr(status),
|
||||
Error: &schemas.ErrorField{
|
||||
Type: schemas.Ptr(errorType),
|
||||
Message: message,
|
||||
Error: err,
|
||||
},
|
||||
ExtraFields: schemas.BifrostErrorExtraFields{
|
||||
RequestType: schemas.RealtimeRequest,
|
||||
Provider: schemas.OpenAI,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// openAIRealtimeEvent is the raw shape of an OpenAI Realtime protocol event.
|
||||
type openAIRealtimeEvent struct {
|
||||
Type string `json:"type"`
|
||||
EventID string `json:"event_id,omitempty"`
|
||||
Session json.RawMessage `json:"session,omitempty"`
|
||||
Conversation json.RawMessage `json:"conversation,omitempty"`
|
||||
Item json.RawMessage `json:"item,omitempty"`
|
||||
Response json.RawMessage `json:"response,omitempty"`
|
||||
Part json.RawMessage `json:"part,omitempty"`
|
||||
Delta string `json:"delta,omitempty"`
|
||||
Audio string `json:"audio,omitempty"`
|
||||
Transcript string `json:"transcript,omitempty"`
|
||||
Text string `json:"text,omitempty"`
|
||||
Error json.RawMessage `json:"error,omitempty"`
|
||||
ItemID string `json:"item_id,omitempty"`
|
||||
OutputIndex *int `json:"output_index,omitempty"`
|
||||
ContentIndex *int `json:"content_index,omitempty"`
|
||||
ResponseID string `json:"response_id,omitempty"`
|
||||
AudioEndMS *int `json:"audio_end_ms,omitempty"`
|
||||
|
||||
PreviousItemID string `json:"previous_item_id,omitempty"`
|
||||
}
|
||||
|
||||
// openAIRealtimeSession is the session object within an OpenAI Realtime event.
|
||||
type openAIRealtimeSession struct {
|
||||
ID string `json:"id,omitempty"`
|
||||
Model string `json:"model,omitempty"`
|
||||
Modalities []string `json:"modalities,omitempty"`
|
||||
Instructions string `json:"instructions,omitempty"`
|
||||
Voice string `json:"voice,omitempty"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
MaxOutputTokens json.RawMessage `json:"max_output_tokens,omitempty"`
|
||||
TurnDetection json.RawMessage `json:"turn_detection,omitempty"`
|
||||
InputAudioFormat string `json:"input_audio_format,omitempty"`
|
||||
OutputAudioType string `json:"output_audio_type,omitempty"`
|
||||
Tools json.RawMessage `json:"tools,omitempty"`
|
||||
}
|
||||
|
||||
// openAIRealtimeItem is the item object within an OpenAI Realtime event.
|
||||
type openAIRealtimeItem struct {
|
||||
ID string `json:"id,omitempty"`
|
||||
Type string `json:"type,omitempty"`
|
||||
Role string `json:"role,omitempty"`
|
||||
Status string `json:"status,omitempty"`
|
||||
Content json.RawMessage `json:"content,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
CallID string `json:"call_id,omitempty"`
|
||||
Arguments string `json:"arguments,omitempty"`
|
||||
Output string `json:"output,omitempty"`
|
||||
}
|
||||
|
||||
// openAIRealtimeError is the error object within an OpenAI Realtime event.
|
||||
type openAIRealtimeError struct {
|
||||
Type string `json:"type,omitempty"`
|
||||
Code string `json:"code,omitempty"`
|
||||
Message string `json:"message,omitempty"`
|
||||
Param string `json:"param,omitempty"`
|
||||
}
|
||||
|
||||
// ToBifrostRealtimeEvent converts an OpenAI Realtime event (raw JSON) to the unified Bifrost format.
|
||||
func (provider *OpenAIProvider) ToBifrostRealtimeEvent(providerEvent json.RawMessage) (*schemas.BifrostRealtimeEvent, error) {
|
||||
var raw openAIRealtimeEvent
|
||||
if err := json.Unmarshal(providerEvent, &raw); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal OpenAI realtime event: %w", err)
|
||||
}
|
||||
|
||||
event := &schemas.BifrostRealtimeEvent{
|
||||
Type: schemas.RealtimeEventType(raw.Type),
|
||||
EventID: raw.EventID,
|
||||
RawData: providerEvent,
|
||||
}
|
||||
setRealtimeExtraParam(event, "item_id", raw.ItemID)
|
||||
setRealtimeExtraParam(event, "previous_item_id", raw.PreviousItemID)
|
||||
setRealtimeExtraParam(event, "output_index", raw.OutputIndex)
|
||||
setRealtimeExtraParam(event, "content_index", raw.ContentIndex)
|
||||
setRealtimeExtraParam(event, "response_id", raw.ResponseID)
|
||||
setRealtimeExtraParam(event, "audio_end_ms", raw.AudioEndMS)
|
||||
setRealtimeExtraParam(event, "transcript", raw.Transcript)
|
||||
setRealtimeExtraParam(event, "text", raw.Text)
|
||||
setRealtimeExtraParam(event, "conversation", raw.Conversation)
|
||||
setRealtimeExtraParam(event, "response", raw.Response)
|
||||
setRealtimeExtraParam(event, "part", raw.Part)
|
||||
|
||||
switch {
|
||||
case raw.Session != nil:
|
||||
var sess openAIRealtimeSession
|
||||
if err := json.Unmarshal(raw.Session, &sess); err == nil {
|
||||
event.Session = &schemas.RealtimeSession{
|
||||
ID: sess.ID,
|
||||
Model: sess.Model,
|
||||
Modalities: sess.Modalities,
|
||||
Instructions: sess.Instructions,
|
||||
Voice: sess.Voice,
|
||||
Temperature: sess.Temperature,
|
||||
MaxOutputTokens: sess.MaxOutputTokens,
|
||||
TurnDetection: sess.TurnDetection,
|
||||
InputAudioFormat: sess.InputAudioFormat,
|
||||
OutputAudioType: sess.OutputAudioType,
|
||||
Tools: sess.Tools,
|
||||
}
|
||||
if extra := extractRealtimeNestedParams(raw.Session, "id", "model", "modalities", "instructions", "voice", "temperature", "max_output_tokens", "turn_detection", "input_audio_format", "output_audio_type", "tools"); len(extra) > 0 {
|
||||
event.Session.ExtraParams = extra
|
||||
}
|
||||
}
|
||||
case raw.Item != nil:
|
||||
var item openAIRealtimeItem
|
||||
if err := json.Unmarshal(raw.Item, &item); err == nil {
|
||||
event.Item = &schemas.RealtimeItem{
|
||||
ID: item.ID,
|
||||
Type: item.Type,
|
||||
Role: item.Role,
|
||||
Status: item.Status,
|
||||
Content: item.Content,
|
||||
Name: item.Name,
|
||||
CallID: item.CallID,
|
||||
Arguments: item.Arguments,
|
||||
Output: item.Output,
|
||||
}
|
||||
if extra := extractRealtimeNestedParams(raw.Item, "id", "type", "role", "status", "content", "name", "call_id", "arguments", "output"); len(extra) > 0 {
|
||||
event.Item.ExtraParams = extra
|
||||
}
|
||||
}
|
||||
|
||||
case raw.Error != nil:
|
||||
var rtErr openAIRealtimeError
|
||||
if err := json.Unmarshal(raw.Error, &rtErr); err == nil {
|
||||
event.Error = &schemas.RealtimeError{
|
||||
Type: rtErr.Type,
|
||||
Code: rtErr.Code,
|
||||
Message: rtErr.Message,
|
||||
Param: rtErr.Param,
|
||||
}
|
||||
if extra := extractRealtimeNestedParams(raw.Error, "type", "code", "message", "param"); len(extra) > 0 {
|
||||
event.Error.ExtraParams = extra
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if isRealtimeDeltaEvent(raw.Type) {
|
||||
event.Delta = &schemas.RealtimeDelta{
|
||||
Text: raw.Text,
|
||||
Audio: raw.Audio,
|
||||
Transcript: raw.Transcript,
|
||||
ItemID: raw.ItemID,
|
||||
OutputIdx: raw.OutputIndex,
|
||||
ContentIdx: raw.ContentIndex,
|
||||
ResponseID: raw.ResponseID,
|
||||
}
|
||||
if raw.Delta != "" {
|
||||
if event.Delta.Text == "" {
|
||||
event.Delta.Text = raw.Delta
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return event, nil
|
||||
}
|
||||
|
||||
// ToProviderRealtimeEvent converts a unified Bifrost Realtime event back to OpenAI's native JSON.
|
||||
func (provider *OpenAIProvider) ToProviderRealtimeEvent(bifrostEvent *schemas.BifrostRealtimeEvent) (json.RawMessage, error) {
|
||||
out := map[string]interface{}{
|
||||
"type": string(bifrostEvent.Type),
|
||||
}
|
||||
if bifrostEvent.EventID != "" {
|
||||
out["event_id"] = bifrostEvent.EventID
|
||||
}
|
||||
mergeRealtimeExtraParams(out, bifrostEvent.ExtraParams)
|
||||
|
||||
if bifrostEvent.Session != nil {
|
||||
sess := map[string]interface{}{}
|
||||
if bifrostEvent.Session.ID != "" && bifrostEvent.Type != schemas.RTEventSessionUpdate {
|
||||
sess["id"] = bifrostEvent.Session.ID
|
||||
}
|
||||
if bifrostEvent.Session.Model != "" {
|
||||
sess["model"] = bifrostEvent.Session.Model
|
||||
}
|
||||
if len(bifrostEvent.Session.Modalities) > 0 {
|
||||
sess["modalities"] = bifrostEvent.Session.Modalities
|
||||
}
|
||||
if bifrostEvent.Session.Instructions != "" {
|
||||
sess["instructions"] = bifrostEvent.Session.Instructions
|
||||
}
|
||||
if bifrostEvent.Session.Voice != "" {
|
||||
sess["voice"] = bifrostEvent.Session.Voice
|
||||
}
|
||||
if bifrostEvent.Session.Temperature != nil {
|
||||
sess["temperature"] = *bifrostEvent.Session.Temperature
|
||||
}
|
||||
if bifrostEvent.Session.MaxOutputTokens != nil {
|
||||
sess["max_output_tokens"] = bifrostEvent.Session.MaxOutputTokens
|
||||
}
|
||||
if bifrostEvent.Session.TurnDetection != nil {
|
||||
sess["turn_detection"] = bifrostEvent.Session.TurnDetection
|
||||
}
|
||||
if bifrostEvent.Session.InputAudioFormat != "" {
|
||||
sess["input_audio_format"] = bifrostEvent.Session.InputAudioFormat
|
||||
}
|
||||
if bifrostEvent.Session.OutputAudioType != "" {
|
||||
sess["output_audio_type"] = bifrostEvent.Session.OutputAudioType
|
||||
}
|
||||
if bifrostEvent.Session.Tools != nil {
|
||||
sess["tools"] = bifrostEvent.Session.Tools
|
||||
}
|
||||
mergeRealtimeSessionExtraParams(sess, bifrostEvent.Session.ExtraParams, bifrostEvent.Type)
|
||||
out["session"] = sess
|
||||
}
|
||||
|
||||
if bifrostEvent.Item != nil {
|
||||
item := map[string]interface{}{
|
||||
"type": bifrostEvent.Item.Type,
|
||||
}
|
||||
if bifrostEvent.Item.ID != "" {
|
||||
item["id"] = bifrostEvent.Item.ID
|
||||
}
|
||||
if bifrostEvent.Item.Role != "" {
|
||||
item["role"] = bifrostEvent.Item.Role
|
||||
}
|
||||
if bifrostEvent.Item.Status != "" {
|
||||
item["status"] = bifrostEvent.Item.Status
|
||||
}
|
||||
if bifrostEvent.Item.Content != nil {
|
||||
item["content"] = bifrostEvent.Item.Content
|
||||
}
|
||||
if bifrostEvent.Item.Name != "" {
|
||||
item["name"] = bifrostEvent.Item.Name
|
||||
}
|
||||
if bifrostEvent.Item.CallID != "" {
|
||||
item["call_id"] = bifrostEvent.Item.CallID
|
||||
}
|
||||
if bifrostEvent.Item.Arguments != "" {
|
||||
item["arguments"] = bifrostEvent.Item.Arguments
|
||||
}
|
||||
if bifrostEvent.Item.Output != "" {
|
||||
item["output"] = bifrostEvent.Item.Output
|
||||
}
|
||||
mergeRealtimeExtraParams(item, bifrostEvent.Item.ExtraParams)
|
||||
out["item"] = item
|
||||
}
|
||||
|
||||
if bifrostEvent.Error != nil {
|
||||
rtErr := map[string]interface{}{}
|
||||
if bifrostEvent.Error.Type != "" {
|
||||
rtErr["type"] = bifrostEvent.Error.Type
|
||||
}
|
||||
if bifrostEvent.Error.Code != "" {
|
||||
rtErr["code"] = bifrostEvent.Error.Code
|
||||
}
|
||||
if bifrostEvent.Error.Message != "" {
|
||||
rtErr["message"] = bifrostEvent.Error.Message
|
||||
}
|
||||
if bifrostEvent.Error.Param != "" {
|
||||
rtErr["param"] = bifrostEvent.Error.Param
|
||||
}
|
||||
mergeRealtimeExtraParams(rtErr, bifrostEvent.Error.ExtraParams)
|
||||
out["error"] = rtErr
|
||||
}
|
||||
|
||||
if bifrostEvent.Delta != nil {
|
||||
if bifrostEvent.Delta.Text != "" {
|
||||
out["delta"] = bifrostEvent.Delta.Text
|
||||
}
|
||||
if bifrostEvent.Delta.Audio != "" {
|
||||
out["audio"] = bifrostEvent.Delta.Audio
|
||||
}
|
||||
if bifrostEvent.Delta.Transcript != "" {
|
||||
out["transcript"] = bifrostEvent.Delta.Transcript
|
||||
}
|
||||
if bifrostEvent.Delta.ItemID != "" && !hasRealtimeExtraParam(bifrostEvent.ExtraParams, "item_id") {
|
||||
out["item_id"] = bifrostEvent.Delta.ItemID
|
||||
}
|
||||
if bifrostEvent.Delta.OutputIdx != nil && !hasRealtimeExtraParam(bifrostEvent.ExtraParams, "output_index") {
|
||||
out["output_index"] = *bifrostEvent.Delta.OutputIdx
|
||||
}
|
||||
if bifrostEvent.Delta.ContentIdx != nil && !hasRealtimeExtraParam(bifrostEvent.ExtraParams, "content_index") {
|
||||
out["content_index"] = *bifrostEvent.Delta.ContentIdx
|
||||
}
|
||||
if bifrostEvent.Delta.ResponseID != "" && !hasRealtimeExtraParam(bifrostEvent.ExtraParams, "response_id") {
|
||||
out["response_id"] = bifrostEvent.Delta.ResponseID
|
||||
}
|
||||
}
|
||||
|
||||
return providerUtils.MarshalSorted(out)
|
||||
}
|
||||
|
||||
func mergeRealtimeSessionExtraParams(out map[string]interface{}, params map[string]json.RawMessage, eventType schemas.RealtimeEventType) {
|
||||
filtered := params
|
||||
if eventType == schemas.RTEventSessionUpdate && len(params) > 0 {
|
||||
filtered = make(map[string]json.RawMessage, len(params))
|
||||
for key, value := range params {
|
||||
switch key {
|
||||
case "id", "object", "expires_at", "client_secret":
|
||||
continue
|
||||
default:
|
||||
filtered[key] = value
|
||||
}
|
||||
}
|
||||
}
|
||||
mergeRealtimeExtraParams(out, filtered)
|
||||
}
|
||||
|
||||
func (provider *OpenAIProvider) ExtractRealtimeTurnUsage(terminalEventRaw []byte) *schemas.BifrostLLMUsage {
|
||||
if len(terminalEventRaw) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
var parsed openAIRealtimeResponseDoneEnvelope
|
||||
if err := json.Unmarshal(terminalEventRaw, &parsed); err != nil || parsed.Response.Usage == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
usage := &schemas.BifrostLLMUsage{
|
||||
PromptTokens: parsed.Response.Usage.InputTokens,
|
||||
CompletionTokens: parsed.Response.Usage.OutputTokens,
|
||||
TotalTokens: parsed.Response.Usage.TotalTokens,
|
||||
}
|
||||
|
||||
if parsed.Response.Usage.InputTokenDetails != nil {
|
||||
usage.PromptTokensDetails = &schemas.ChatPromptTokensDetails{
|
||||
TextTokens: parsed.Response.Usage.InputTokenDetails.TextTokens,
|
||||
AudioTokens: parsed.Response.Usage.InputTokenDetails.AudioTokens,
|
||||
ImageTokens: parsed.Response.Usage.InputTokenDetails.ImageTokens,
|
||||
CachedReadTokens: parsed.Response.Usage.InputTokenDetails.CachedTokens,
|
||||
}
|
||||
}
|
||||
|
||||
if parsed.Response.Usage.OutputTokenDetails != nil {
|
||||
usage.CompletionTokensDetails = &schemas.ChatCompletionTokensDetails{
|
||||
TextTokens: parsed.Response.Usage.OutputTokenDetails.TextTokens,
|
||||
AudioTokens: parsed.Response.Usage.OutputTokenDetails.AudioTokens,
|
||||
ReasoningTokens: parsed.Response.Usage.OutputTokenDetails.ReasoningTokens,
|
||||
ImageTokens: parsed.Response.Usage.OutputTokenDetails.ImageTokens,
|
||||
CitationTokens: parsed.Response.Usage.OutputTokenDetails.CitationTokens,
|
||||
NumSearchQueries: parsed.Response.Usage.OutputTokenDetails.NumSearchQueries,
|
||||
AcceptedPredictionTokens: parsed.Response.Usage.OutputTokenDetails.AcceptedPredictionTokens,
|
||||
RejectedPredictionTokens: parsed.Response.Usage.OutputTokenDetails.RejectedPredictionTokens,
|
||||
}
|
||||
}
|
||||
|
||||
return usage
|
||||
}
|
||||
|
||||
func (provider *OpenAIProvider) ExtractRealtimeTurnOutput(terminalEventRaw []byte) *schemas.ChatMessage {
|
||||
if len(terminalEventRaw) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
var parsed openAIRealtimeResponseDoneEnvelope
|
||||
if err := json.Unmarshal(terminalEventRaw, &parsed); err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
content := extractOpenAIRealtimeResponseDoneAssistantText(parsed.Response.Output)
|
||||
toolCalls := extractOpenAIRealtimeResponseDoneToolCalls(parsed.Response.Output)
|
||||
if content == "" && len(toolCalls) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
message := &schemas.ChatMessage{Role: schemas.ChatMessageRoleAssistant}
|
||||
if content != "" {
|
||||
message.Content = &schemas.ChatMessageContent{ContentStr: schemas.Ptr(content)}
|
||||
}
|
||||
if len(toolCalls) > 0 {
|
||||
message.ChatAssistantMessage = &schemas.ChatAssistantMessage{ToolCalls: toolCalls}
|
||||
}
|
||||
|
||||
return message
|
||||
}
|
||||
|
||||
type openAIRealtimeResponseDoneEnvelope struct {
|
||||
Response struct {
|
||||
Output []openAIRealtimeResponseDoneOutput `json:"output"`
|
||||
Usage *openAIRealtimeResponseDoneUsage `json:"usage"`
|
||||
} `json:"response"`
|
||||
}
|
||||
|
||||
type openAIRealtimeResponseDoneOutput struct {
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type"`
|
||||
Name string `json:"name"`
|
||||
CallID string `json:"call_id"`
|
||||
Arguments string `json:"arguments"`
|
||||
Content []openAIRealtimeResponseDoneBlock `json:"content"`
|
||||
}
|
||||
|
||||
type openAIRealtimeResponseDoneBlock struct {
|
||||
Text string `json:"text"`
|
||||
Transcript string `json:"transcript"`
|
||||
Refusal string `json:"refusal"`
|
||||
}
|
||||
|
||||
type openAIRealtimeResponseDoneUsage struct {
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
InputTokenDetails *openAIRealtimeResponseDoneInputTokenUsage `json:"input_token_details"`
|
||||
OutputTokenDetails *openAIRealtimeResponseDoneOutputTokenUsage `json:"output_token_details"`
|
||||
}
|
||||
|
||||
type openAIRealtimeResponseDoneInputTokenUsage struct {
|
||||
TextTokens int `json:"text_tokens"`
|
||||
AudioTokens int `json:"audio_tokens"`
|
||||
ImageTokens int `json:"image_tokens"`
|
||||
CachedTokens int `json:"cached_tokens"`
|
||||
}
|
||||
|
||||
type openAIRealtimeResponseDoneOutputTokenUsage struct {
|
||||
TextTokens int `json:"text_tokens"`
|
||||
AudioTokens int `json:"audio_tokens"`
|
||||
ReasoningTokens int `json:"reasoning_tokens"`
|
||||
ImageTokens *int `json:"image_tokens"`
|
||||
CitationTokens *int `json:"citation_tokens"`
|
||||
NumSearchQueries *int `json:"num_search_queries"`
|
||||
AcceptedPredictionTokens int `json:"accepted_prediction_tokens"`
|
||||
RejectedPredictionTokens int `json:"rejected_prediction_tokens"`
|
||||
}
|
||||
|
||||
func extractOpenAIRealtimeResponseDoneAssistantText(outputs []openAIRealtimeResponseDoneOutput) string {
|
||||
var sb strings.Builder
|
||||
for _, output := range outputs {
|
||||
if output.Type != "message" {
|
||||
continue
|
||||
}
|
||||
for _, block := range output.Content {
|
||||
switch {
|
||||
case strings.TrimSpace(block.Text) != "":
|
||||
sb.WriteString(block.Text)
|
||||
case strings.TrimSpace(block.Transcript) != "":
|
||||
sb.WriteString(block.Transcript)
|
||||
case strings.TrimSpace(block.Refusal) != "":
|
||||
sb.WriteString(block.Refusal)
|
||||
}
|
||||
}
|
||||
}
|
||||
return strings.TrimSpace(sb.String())
|
||||
}
|
||||
|
||||
func extractOpenAIRealtimeResponseDoneToolCalls(outputs []openAIRealtimeResponseDoneOutput) []schemas.ChatAssistantMessageToolCall {
|
||||
toolCalls := make([]schemas.ChatAssistantMessageToolCall, 0)
|
||||
for _, output := range outputs {
|
||||
if output.Type != "function_call" {
|
||||
continue
|
||||
}
|
||||
|
||||
name := strings.TrimSpace(output.Name)
|
||||
if name == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
toolType := "function"
|
||||
id := strings.TrimSpace(output.CallID)
|
||||
if id == "" {
|
||||
id = strings.TrimSpace(output.ID)
|
||||
}
|
||||
|
||||
toolCall := schemas.ChatAssistantMessageToolCall{
|
||||
Index: uint16(len(toolCalls)),
|
||||
Type: &toolType,
|
||||
Function: schemas.ChatAssistantMessageToolCallFunction{
|
||||
Name: schemas.Ptr(name),
|
||||
Arguments: output.Arguments,
|
||||
},
|
||||
}
|
||||
if id != "" {
|
||||
toolCall.ID = schemas.Ptr(id)
|
||||
}
|
||||
|
||||
toolCalls = append(toolCalls, toolCall)
|
||||
}
|
||||
return toolCalls
|
||||
}
|
||||
|
||||
func setRealtimeExtraParam(event *schemas.BifrostRealtimeEvent, key string, value any) {
|
||||
if event == nil || key == "" || value == nil {
|
||||
return
|
||||
}
|
||||
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
if v == "" {
|
||||
return
|
||||
}
|
||||
case *int:
|
||||
if v == nil {
|
||||
return
|
||||
}
|
||||
case json.RawMessage:
|
||||
if len(v) == 0 || string(v) == "null" {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
raw, err := json.Marshal(value)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if event.ExtraParams == nil {
|
||||
event.ExtraParams = make(map[string]json.RawMessage)
|
||||
}
|
||||
event.ExtraParams[key] = raw
|
||||
}
|
||||
|
||||
func mergeRealtimeExtraParams(out map[string]interface{}, params map[string]json.RawMessage) {
|
||||
for key, raw := range params {
|
||||
if len(raw) == 0 {
|
||||
continue
|
||||
}
|
||||
var value any
|
||||
if err := json.Unmarshal(raw, &value); err != nil {
|
||||
continue
|
||||
}
|
||||
out[key] = value
|
||||
}
|
||||
}
|
||||
|
||||
func hasRealtimeExtraParam(params map[string]json.RawMessage, key string) bool {
|
||||
if params == nil {
|
||||
return false
|
||||
}
|
||||
raw, ok := params[key]
|
||||
return ok && len(raw) > 0
|
||||
}
|
||||
|
||||
func extractRealtimeNestedParams(raw json.RawMessage, knownKeys ...string) map[string]json.RawMessage {
|
||||
if len(raw) == 0 {
|
||||
return nil
|
||||
}
|
||||
root := map[string]json.RawMessage{}
|
||||
if err := json.Unmarshal(raw, &root); err != nil {
|
||||
return nil
|
||||
}
|
||||
for _, key := range knownKeys {
|
||||
delete(root, key)
|
||||
}
|
||||
if len(root) == 0 {
|
||||
return nil
|
||||
}
|
||||
return root
|
||||
}
|
||||
|
||||
func isRealtimeDeltaEvent(eventType string) bool {
|
||||
switch eventType {
|
||||
case "response.text.delta",
|
||||
"response.output_text.delta",
|
||||
"response.audio.delta",
|
||||
"response.output_audio.delta",
|
||||
"response.audio_transcript.delta",
|
||||
"response.output_audio_transcript.delta",
|
||||
"conversation.item.input_audio_transcription.delta":
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
561
core/providers/openai/realtime_test.go
Normal file
561
core/providers/openai/realtime_test.go
Normal file
@@ -0,0 +1,561 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
func TestNormalizeRealtimeClientSecretRequest(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
body, model, bifrostErr := normalizeRealtimeClientSecretRequest(
|
||||
json.RawMessage(`{"model":"openai/gpt-4o-realtime-preview","voice":"alloy"}`),
|
||||
schemas.OpenAI,
|
||||
schemas.RealtimeSessionEndpointClientSecrets,
|
||||
)
|
||||
if bifrostErr != nil {
|
||||
t.Fatalf("normalizeRealtimeClientSecretRequest() error = %v", bifrostErr)
|
||||
}
|
||||
if model != "gpt-4o-realtime-preview" {
|
||||
t.Fatalf("model = %q, want %q", model, "gpt-4o-realtime-preview")
|
||||
}
|
||||
|
||||
var payload map[string]json.RawMessage
|
||||
if err := json.Unmarshal(body, &payload); err != nil {
|
||||
t.Fatalf("failed to unmarshal normalized body: %v", err)
|
||||
}
|
||||
if _, ok := payload["model"]; ok {
|
||||
t.Fatal("top-level model should be removed after normalization")
|
||||
}
|
||||
|
||||
var session map[string]any
|
||||
if err := json.Unmarshal(payload["session"], &session); err != nil {
|
||||
t.Fatalf("failed to unmarshal session: %v", err)
|
||||
}
|
||||
if session["model"] != "gpt-4o-realtime-preview" {
|
||||
t.Fatalf("session.model = %v, want %q", session["model"], "gpt-4o-realtime-preview")
|
||||
}
|
||||
if session["type"] != "realtime" {
|
||||
t.Fatalf("session.type = %v, want %q", session["type"], "realtime")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeRealtimeClientSecretRequestUsesDefaultProvider(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
body, model, bifrostErr := normalizeRealtimeClientSecretRequest(
|
||||
json.RawMessage(`{"session":{"model":"gpt-4o-realtime-preview"}}`),
|
||||
schemas.OpenAI,
|
||||
schemas.RealtimeSessionEndpointClientSecrets,
|
||||
)
|
||||
if bifrostErr != nil {
|
||||
t.Fatalf("normalizeRealtimeClientSecretRequest() error = %v", bifrostErr)
|
||||
}
|
||||
if model != "gpt-4o-realtime-preview" {
|
||||
t.Fatalf("model = %q, want %q", model, "gpt-4o-realtime-preview")
|
||||
}
|
||||
|
||||
var payload map[string]json.RawMessage
|
||||
if err := json.Unmarshal(body, &payload); err != nil {
|
||||
t.Fatalf("failed to unmarshal normalized body: %v", err)
|
||||
}
|
||||
|
||||
var session map[string]any
|
||||
if err := json.Unmarshal(payload["session"], &session); err != nil {
|
||||
t.Fatalf("failed to unmarshal session: %v", err)
|
||||
}
|
||||
if session["model"] != "gpt-4o-realtime-preview" {
|
||||
t.Fatalf("session.model = %v, want %q", session["model"], "gpt-4o-realtime-preview")
|
||||
}
|
||||
if session["type"] != "realtime" {
|
||||
t.Fatalf("session.type = %v, want %q", session["type"], "realtime")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeRealtimeSessionsRequest(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
body, model, bifrostErr := normalizeRealtimeClientSecretRequest(
|
||||
json.RawMessage(`{"session":{"model":"openai/gpt-4o-realtime-preview","voice":"alloy"}}`),
|
||||
schemas.OpenAI,
|
||||
schemas.RealtimeSessionEndpointSessions,
|
||||
)
|
||||
if bifrostErr != nil {
|
||||
t.Fatalf("normalizeRealtimeClientSecretRequest() error = %v", bifrostErr)
|
||||
}
|
||||
if model != "gpt-4o-realtime-preview" {
|
||||
t.Fatalf("model = %q, want %q", model, "gpt-4o-realtime-preview")
|
||||
}
|
||||
|
||||
var payload map[string]any
|
||||
if err := json.Unmarshal(body, &payload); err != nil {
|
||||
t.Fatalf("failed to unmarshal normalized body: %v", err)
|
||||
}
|
||||
if _, ok := payload["session"]; ok {
|
||||
t.Fatal("legacy sessions endpoint should not forward nested session object")
|
||||
}
|
||||
if payload["model"] != "gpt-4o-realtime-preview" {
|
||||
t.Fatalf("model = %v, want %q", payload["model"], "gpt-4o-realtime-preview")
|
||||
}
|
||||
if payload["voice"] != "alloy" {
|
||||
t.Fatalf("voice = %v, want %q", payload["voice"], "alloy")
|
||||
}
|
||||
}
|
||||
|
||||
func TestToProviderRealtimeEventSerializesTopLevelClientFields(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
provider := &OpenAIProvider{}
|
||||
contentIndex, err := json.Marshal(0)
|
||||
if err != nil {
|
||||
t.Fatalf("json.Marshal() error = %v", err)
|
||||
}
|
||||
audioEndMS, err := json.Marshal(640)
|
||||
if err != nil {
|
||||
t.Fatalf("json.Marshal() error = %v", err)
|
||||
}
|
||||
out, err := provider.ToProviderRealtimeEvent(&schemas.BifrostRealtimeEvent{
|
||||
Type: schemas.RealtimeEventType("conversation.item.truncate"),
|
||||
ExtraParams: map[string]json.RawMessage{
|
||||
"item_id": json.RawMessage(`"item_123"`),
|
||||
"content_index": contentIndex,
|
||||
"audio_end_ms": audioEndMS,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("ToProviderRealtimeEvent() error = %v", err)
|
||||
}
|
||||
|
||||
var payload map[string]any
|
||||
if err := json.Unmarshal(out, &payload); err != nil {
|
||||
t.Fatalf("json.Unmarshal() error = %v", err)
|
||||
}
|
||||
if payload["type"] != "conversation.item.truncate" {
|
||||
t.Fatalf("type = %v, want %q", payload["type"], "conversation.item.truncate")
|
||||
}
|
||||
if payload["item_id"] != "item_123" {
|
||||
t.Fatalf("item_id = %v, want %q", payload["item_id"], "item_123")
|
||||
}
|
||||
if payload["content_index"] != float64(0) {
|
||||
t.Fatalf("content_index = %v, want 0", payload["content_index"])
|
||||
}
|
||||
if payload["audio_end_ms"] != float64(640) {
|
||||
t.Fatalf("audio_end_ms = %v, want 640", payload["audio_end_ms"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestToBifrostRealtimeEventParsesTopLevelClientFields(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
provider := &OpenAIProvider{}
|
||||
event, err := provider.ToBifrostRealtimeEvent(json.RawMessage(`{"type":"conversation.item.truncate","item_id":"item_123","content_index":0,"audio_end_ms":640}`))
|
||||
if err != nil {
|
||||
t.Fatalf("ToBifrostRealtimeEvent() error = %v", err)
|
||||
}
|
||||
var itemID string
|
||||
if err := json.Unmarshal(event.ExtraParams["item_id"], &itemID); err != nil {
|
||||
t.Fatalf("json.Unmarshal(item_id) error = %v", err)
|
||||
}
|
||||
if itemID != "item_123" {
|
||||
t.Fatalf("item_id = %q, want %q", itemID, "item_123")
|
||||
}
|
||||
var contentIndex int
|
||||
if err := json.Unmarshal(event.ExtraParams["content_index"], &contentIndex); err != nil {
|
||||
t.Fatalf("json.Unmarshal(content_index) error = %v", err)
|
||||
}
|
||||
if contentIndex != 0 {
|
||||
t.Fatalf("content_index = %d, want 0", contentIndex)
|
||||
}
|
||||
var audioEndMS int
|
||||
if err := json.Unmarshal(event.ExtraParams["audio_end_ms"], &audioEndMS); err != nil {
|
||||
t.Fatalf("json.Unmarshal(audio_end_ms) error = %v", err)
|
||||
}
|
||||
if audioEndMS != 640 {
|
||||
t.Fatalf("audio_end_ms = %d, want 640", audioEndMS)
|
||||
}
|
||||
}
|
||||
|
||||
func TestToBifrostRealtimeEventParsesCompletedInputAudioTranscript(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
provider := &OpenAIProvider{}
|
||||
event, err := provider.ToBifrostRealtimeEvent(json.RawMessage(`{"type":"conversation.item.input_audio_transcription.completed","event_id":"evt_123","item_id":"item_123","content_index":0,"transcript":"Who are you?"}`))
|
||||
if err != nil {
|
||||
t.Fatalf("ToBifrostRealtimeEvent() error = %v", err)
|
||||
}
|
||||
|
||||
var transcript string
|
||||
if err := json.Unmarshal(event.ExtraParams["transcript"], &transcript); err != nil {
|
||||
t.Fatalf("json.Unmarshal(transcript) error = %v", err)
|
||||
}
|
||||
if transcript != "Who are you?" {
|
||||
t.Fatalf("transcript = %q, want %q", transcript, "Who are you?")
|
||||
}
|
||||
}
|
||||
|
||||
func TestToBifrostRealtimeEventParsesModernOutputTextDelta(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
provider := &OpenAIProvider{}
|
||||
event, err := provider.ToBifrostRealtimeEvent(json.RawMessage(`{
|
||||
"type":"response.output_text.delta",
|
||||
"event_id":"evt_123",
|
||||
"item_id":"item_123",
|
||||
"output_index":0,
|
||||
"content_index":0,
|
||||
"response_id":"resp_123",
|
||||
"delta":"hello"
|
||||
}`))
|
||||
if err != nil {
|
||||
t.Fatalf("ToBifrostRealtimeEvent() error = %v", err)
|
||||
}
|
||||
if event.Delta == nil || event.Delta.Text != "hello" {
|
||||
t.Fatalf("Delta = %+v, want text=hello", event.Delta)
|
||||
}
|
||||
}
|
||||
|
||||
func TestShouldStartRealtimeTurn(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
provider := &OpenAIProvider{}
|
||||
tests := []struct {
|
||||
name string
|
||||
event *schemas.BifrostRealtimeEvent
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "response create starts turn",
|
||||
event: &schemas.BifrostRealtimeEvent{Type: schemas.RTEventResponseCreate},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "audio buffer committed starts turn",
|
||||
event: &schemas.BifrostRealtimeEvent{Type: schemas.RTEventInputAudioBufferCommitted},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "response done does not start turn",
|
||||
event: &schemas.BifrostRealtimeEvent{Type: schemas.RTEventResponseDone},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "nil event does not start turn",
|
||||
event: nil,
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
if got := provider.ShouldStartRealtimeTurn(tt.event); got != tt.want {
|
||||
t.Fatalf("ShouldStartRealtimeTurn() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestToProviderRealtimeEventSerializesModernOutputTextDelta(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
provider := &OpenAIProvider{}
|
||||
outputIndex := 0
|
||||
contentIndex := 0
|
||||
out, err := provider.ToProviderRealtimeEvent(&schemas.BifrostRealtimeEvent{
|
||||
Type: schemas.RealtimeEventType("response.output_text.delta"),
|
||||
Delta: &schemas.RealtimeDelta{
|
||||
Text: "hello",
|
||||
ItemID: "item_123",
|
||||
OutputIdx: &outputIndex,
|
||||
ContentIdx: &contentIndex,
|
||||
ResponseID: "resp_123",
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("ToProviderRealtimeEvent() error = %v", err)
|
||||
}
|
||||
|
||||
var payload map[string]any
|
||||
if err := json.Unmarshal(out, &payload); err != nil {
|
||||
t.Fatalf("json.Unmarshal() error = %v", err)
|
||||
}
|
||||
if payload["type"] != "response.output_text.delta" {
|
||||
t.Fatalf("type = %v, want response.output_text.delta", payload["type"])
|
||||
}
|
||||
if payload["delta"] != "hello" {
|
||||
t.Fatalf("delta = %v, want hello", payload["delta"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestToProviderRealtimeEventSerializesSessionID(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
provider := &OpenAIProvider{}
|
||||
out, err := provider.ToProviderRealtimeEvent(&schemas.BifrostRealtimeEvent{
|
||||
Type: schemas.RTEventSessionCreated,
|
||||
Session: &schemas.RealtimeSession{
|
||||
ID: "sess_123",
|
||||
Model: "gpt-realtime",
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("ToProviderRealtimeEvent() error = %v", err)
|
||||
}
|
||||
|
||||
var payload map[string]any
|
||||
if err := json.Unmarshal(out, &payload); err != nil {
|
||||
t.Fatalf("json.Unmarshal() error = %v", err)
|
||||
}
|
||||
session, ok := payload["session"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("session = %T, want object", payload["session"])
|
||||
}
|
||||
if session["id"] != "sess_123" {
|
||||
t.Fatalf("session.id = %v, want sess_123", session["id"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestToProviderRealtimeEventSerializesMessageItemStatus(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
provider := &OpenAIProvider{}
|
||||
content := json.RawMessage(`[{"type":"input_audio","transcript":"hello"}]`)
|
||||
out, err := provider.ToProviderRealtimeEvent(&schemas.BifrostRealtimeEvent{
|
||||
Type: schemas.RealtimeEventType("conversation.item.retrieved"),
|
||||
Item: &schemas.RealtimeItem{
|
||||
ID: "item_123",
|
||||
Type: "message",
|
||||
Role: "user",
|
||||
Status: "completed",
|
||||
Content: content,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("ToProviderRealtimeEvent() error = %v", err)
|
||||
}
|
||||
|
||||
var payload map[string]any
|
||||
if err := json.Unmarshal(out, &payload); err != nil {
|
||||
t.Fatalf("json.Unmarshal() error = %v", err)
|
||||
}
|
||||
item, ok := payload["item"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("item = %T, want object", payload["item"])
|
||||
}
|
||||
if item["status"] != "completed" {
|
||||
t.Fatalf("item.status = %v, want completed", item["status"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestToBifrostRealtimeEventPreservesTopLevelResponsePayload(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
provider := &OpenAIProvider{}
|
||||
event, err := provider.ToBifrostRealtimeEvent(json.RawMessage(`{
|
||||
"type":"response.done",
|
||||
"event_id":"evt_123",
|
||||
"response":{
|
||||
"id":"resp_123",
|
||||
"output":[{"type":"message","content":[{"type":"output_text","text":"hello"}]}]
|
||||
}
|
||||
}`))
|
||||
if err != nil {
|
||||
t.Fatalf("ToBifrostRealtimeEvent() error = %v", err)
|
||||
}
|
||||
|
||||
var response map[string]any
|
||||
if err := json.Unmarshal(event.ExtraParams["response"], &response); err != nil {
|
||||
t.Fatalf("json.Unmarshal(response) error = %v", err)
|
||||
}
|
||||
if response["id"] != "resp_123" {
|
||||
t.Fatalf("response.id = %v, want resp_123", response["id"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestToProviderRealtimeEventSerializesTopLevelResponsePayload(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
provider := &OpenAIProvider{}
|
||||
out, err := provider.ToProviderRealtimeEvent(&schemas.BifrostRealtimeEvent{
|
||||
Type: schemas.RTEventResponseDone,
|
||||
ExtraParams: map[string]json.RawMessage{
|
||||
"response": json.RawMessage(`{"id":"resp_123","output":[{"type":"message","content":[{"type":"output_text","text":"hello"}]}]}`),
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("ToProviderRealtimeEvent() error = %v", err)
|
||||
}
|
||||
|
||||
var payload map[string]any
|
||||
if err := json.Unmarshal(out, &payload); err != nil {
|
||||
t.Fatalf("json.Unmarshal() error = %v", err)
|
||||
}
|
||||
response, ok := payload["response"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("response = %T, want object", payload["response"])
|
||||
}
|
||||
if response["id"] != "resp_123" {
|
||||
t.Fatalf("response.id = %v, want resp_123", response["id"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestToBifrostRealtimeEventPreservesTopLevelPartPayload(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
provider := &OpenAIProvider{}
|
||||
event, err := provider.ToBifrostRealtimeEvent(json.RawMessage(`{
|
||||
"type":"response.content_part.added",
|
||||
"event_id":"evt_123",
|
||||
"item_id":"item_123",
|
||||
"output_index":0,
|
||||
"content_index":0,
|
||||
"part":{
|
||||
"type":"text",
|
||||
"text":"hello"
|
||||
}
|
||||
}`))
|
||||
if err != nil {
|
||||
t.Fatalf("ToBifrostRealtimeEvent() error = %v", err)
|
||||
}
|
||||
|
||||
var part map[string]any
|
||||
if err := json.Unmarshal(event.ExtraParams["part"], &part); err != nil {
|
||||
t.Fatalf("json.Unmarshal(part) error = %v", err)
|
||||
}
|
||||
if part["type"] != "text" {
|
||||
t.Fatalf("part.type = %v, want text", part["type"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestToProviderRealtimeEventSerializesTopLevelPartPayload(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
provider := &OpenAIProvider{}
|
||||
out, err := provider.ToProviderRealtimeEvent(&schemas.BifrostRealtimeEvent{
|
||||
Type: schemas.RTEventResponseContentPartAdded,
|
||||
ExtraParams: map[string]json.RawMessage{
|
||||
"part": json.RawMessage(`{"type":"text","text":"hello"}`),
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("ToProviderRealtimeEvent() error = %v", err)
|
||||
}
|
||||
|
||||
var payload map[string]any
|
||||
if err := json.Unmarshal(out, &payload); err != nil {
|
||||
t.Fatalf("json.Unmarshal() error = %v", err)
|
||||
}
|
||||
part, ok := payload["part"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("part = %T, want object", payload["part"])
|
||||
}
|
||||
if part["type"] != "text" {
|
||||
t.Fatalf("part.type = %v, want text", part["type"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseRealtimeEventPreservesNestedSessionExtraParams(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
event, err := schemas.ParseRealtimeEvent([]byte(`{
|
||||
"type":"session.update",
|
||||
"session":{
|
||||
"type":"realtime",
|
||||
"model":"gpt-4o-realtime-preview",
|
||||
"output_modalities":["text"]
|
||||
}
|
||||
}`))
|
||||
if err != nil {
|
||||
t.Fatalf("ParseRealtimeEvent() error = %v", err)
|
||||
}
|
||||
if event.Session == nil {
|
||||
t.Fatal("expected session to be parsed")
|
||||
}
|
||||
var outputModalities []string
|
||||
if err := json.Unmarshal(event.Session.ExtraParams["output_modalities"], &outputModalities); err != nil {
|
||||
t.Fatalf("json.Unmarshal(output_modalities) error = %v", err)
|
||||
}
|
||||
if len(outputModalities) != 1 || outputModalities[0] != "text" {
|
||||
t.Fatalf("output_modalities = %v, want [text]", outputModalities)
|
||||
}
|
||||
}
|
||||
|
||||
func TestToProviderRealtimeEventSerializesNestedSessionExtraParams(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
provider := &OpenAIProvider{}
|
||||
out, err := provider.ToProviderRealtimeEvent(&schemas.BifrostRealtimeEvent{
|
||||
Type: schemas.RTEventSessionUpdate,
|
||||
Session: &schemas.RealtimeSession{
|
||||
Model: "gpt-4o-realtime-preview",
|
||||
ExtraParams: map[string]json.RawMessage{
|
||||
"type": json.RawMessage(`"realtime"`),
|
||||
"output_modalities": json.RawMessage(`["text"]`),
|
||||
},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("ToProviderRealtimeEvent() error = %v", err)
|
||||
}
|
||||
|
||||
var payload struct {
|
||||
Type string `json:"type"`
|
||||
Session map[string]any `json:"session"`
|
||||
}
|
||||
if err := json.Unmarshal(out, &payload); err != nil {
|
||||
t.Fatalf("json.Unmarshal() error = %v", err)
|
||||
}
|
||||
if payload.Type != "session.update" {
|
||||
t.Fatalf("type = %q, want %q", payload.Type, "session.update")
|
||||
}
|
||||
if payload.Session["type"] != "realtime" {
|
||||
t.Fatalf("session.type = %v, want realtime", payload.Session["type"])
|
||||
}
|
||||
outputModalities, ok := payload.Session["output_modalities"].([]any)
|
||||
if !ok || len(outputModalities) != 1 || outputModalities[0] != "text" {
|
||||
t.Fatalf("session.output_modalities = %v, want [text]", payload.Session["output_modalities"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestToProviderRealtimeEventOmitsReadOnlySessionFieldsOnSessionUpdate(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
provider := &OpenAIProvider{}
|
||||
out, err := provider.ToProviderRealtimeEvent(&schemas.BifrostRealtimeEvent{
|
||||
Type: schemas.RTEventSessionUpdate,
|
||||
Session: &schemas.RealtimeSession{
|
||||
ID: "sess_123",
|
||||
Model: "gpt-realtime",
|
||||
ExtraParams: map[string]json.RawMessage{
|
||||
"type": json.RawMessage(`"realtime"`),
|
||||
"object": json.RawMessage(`"realtime.session"`),
|
||||
"expires_at": json.RawMessage(`1774614381`),
|
||||
"client_secret": json.RawMessage(`{"value":"secret"}`),
|
||||
"modalities": json.RawMessage(`["text","audio"]`),
|
||||
},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("ToProviderRealtimeEvent() error = %v", err)
|
||||
}
|
||||
|
||||
var payload struct {
|
||||
Session map[string]any `json:"session"`
|
||||
}
|
||||
if err := json.Unmarshal(out, &payload); err != nil {
|
||||
t.Fatalf("json.Unmarshal() error = %v", err)
|
||||
}
|
||||
for _, key := range []string{"id", "object", "expires_at", "client_secret"} {
|
||||
if _, ok := payload.Session[key]; ok {
|
||||
t.Fatalf("session.%s unexpectedly present in session.update payload", key)
|
||||
}
|
||||
}
|
||||
if payload.Session["type"] != "realtime" {
|
||||
t.Fatalf("session.type = %v, want realtime", payload.Session["type"])
|
||||
}
|
||||
if payload.Session["model"] != "gpt-realtime" {
|
||||
t.Fatalf("session.model = %v, want gpt-realtime", payload.Session["model"])
|
||||
}
|
||||
}
|
||||
376
core/providers/openai/responses.go
Normal file
376
core/providers/openai/responses.go
Normal file
@@ -0,0 +1,376 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/maximhq/bifrost/core/providers/utils"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// ToBifrostResponsesRequest converts an OpenAI responses request to Bifrost format
|
||||
func (resp *OpenAIResponsesRequest) ToBifrostResponsesRequest(ctx *schemas.BifrostContext) *schemas.BifrostResponsesRequest {
|
||||
if resp == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
defaultProvider := schemas.OpenAI
|
||||
|
||||
// for requests coming from azure sdk without provider prefix, we need to set the default provider to azure
|
||||
if ctx != nil {
|
||||
if isAzureUser, ok := ctx.Value(schemas.BifrostContextKeyIsAzureUserAgent).(bool); ok && isAzureUser {
|
||||
defaultProvider = schemas.Azure
|
||||
}
|
||||
}
|
||||
|
||||
provider, model := schemas.ParseModelString(resp.Model, utils.CheckAndSetDefaultProvider(ctx, defaultProvider))
|
||||
|
||||
input := resp.Input.OpenAIResponsesRequestInputArray
|
||||
if len(input) == 0 {
|
||||
input = []schemas.ResponsesMessage{
|
||||
{
|
||||
Role: schemas.Ptr(schemas.ResponsesInputMessageRoleUser),
|
||||
Content: &schemas.ResponsesMessageContent{ContentStr: resp.Input.OpenAIResponsesRequestInputStr},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
return &schemas.BifrostResponsesRequest{
|
||||
Provider: provider,
|
||||
Model: model,
|
||||
Input: input,
|
||||
Params: &resp.ResponsesParameters,
|
||||
Fallbacks: schemas.ParseFallbacks(resp.Fallbacks),
|
||||
}
|
||||
}
|
||||
|
||||
// ToOpenAIResponsesRequest converts a Bifrost responses request to OpenAI format
|
||||
func ToOpenAIResponsesRequest(bifrostReq *schemas.BifrostResponsesRequest) *OpenAIResponsesRequest {
|
||||
if bifrostReq == nil || bifrostReq.Input == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var messages []schemas.ResponsesMessage
|
||||
// OpenAI models (except for gpt-oss) do not support reasoning content blocks, so we need to convert them to summaries, if there are any
|
||||
// OpenAI also doesn't support compaction content blocks, so we need to convert them to text blocks
|
||||
messages = make([]schemas.ResponsesMessage, 0, len(bifrostReq.Input))
|
||||
for _, message := range bifrostReq.Input {
|
||||
// First, check if message has compaction content blocks and convert them to text
|
||||
if message.Content != nil && len(message.Content.ContentBlocks) > 0 {
|
||||
hasCompaction := false
|
||||
for _, block := range message.Content.ContentBlocks {
|
||||
if block.Type == schemas.ResponsesOutputMessageContentTypeCompaction {
|
||||
hasCompaction = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if hasCompaction {
|
||||
// Create a new message with converted content blocks
|
||||
newMessage := message
|
||||
newContentBlocks := make([]schemas.ResponsesMessageContentBlock, 0, len(message.Content.ContentBlocks))
|
||||
|
||||
for _, block := range message.Content.ContentBlocks {
|
||||
if block.Type == schemas.ResponsesOutputMessageContentTypeCompaction {
|
||||
// Convert compaction block to text block
|
||||
if block.ResponsesOutputMessageContentCompaction != nil && block.ResponsesOutputMessageContentCompaction.Summary != "" {
|
||||
newContentBlocks = append(newContentBlocks, schemas.ResponsesMessageContentBlock{
|
||||
Type: schemas.ResponsesOutputMessageContentTypeText,
|
||||
Text: schemas.Ptr(block.ResponsesOutputMessageContentCompaction.Summary),
|
||||
})
|
||||
}
|
||||
// If summary is empty, skip the block entirely
|
||||
} else {
|
||||
// Keep non-compaction blocks as-is
|
||||
newContentBlocks = append(newContentBlocks, block)
|
||||
}
|
||||
}
|
||||
|
||||
// Only update if we have blocks remaining after conversion
|
||||
if len(newContentBlocks) > 0 {
|
||||
newMessage.Content = &schemas.ResponsesMessageContent{
|
||||
ContentBlocks: newContentBlocks,
|
||||
}
|
||||
message = newMessage
|
||||
} else {
|
||||
// If all blocks were compaction with empty summaries, skip message
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if message.ResponsesReasoning != nil {
|
||||
isGptOss := strings.Contains(bifrostReq.Model, "gpt-oss")
|
||||
isReasoning := isOpenAIReasoningModel(bifrostReq.Model)
|
||||
|
||||
// For non-gpt-oss models, skip reasoning-only messages that have content blocks but no summaries.
|
||||
// For non-reasoning models (e.g., gpt-4o), also skip when EncryptedContent is present since
|
||||
// these models don't produce encrypted reasoning — any encrypted content is cross-provider
|
||||
// (e.g., Gemini ThoughtSignatures) and cannot be decrypted by OpenAI.
|
||||
if len(message.ResponsesReasoning.Summary) == 0 &&
|
||||
message.Content != nil &&
|
||||
len(message.Content.ContentBlocks) > 0 &&
|
||||
!isGptOss &&
|
||||
(message.ResponsesReasoning.EncryptedContent == nil || !isReasoning) {
|
||||
continue
|
||||
}
|
||||
|
||||
// If the message has summaries but no content blocks and the model is gpt-oss, then convert the summaries to content blocks
|
||||
if len(message.ResponsesReasoning.Summary) > 0 && isGptOss &&
|
||||
(message.Content == nil || len(message.Content.ContentBlocks) == 0) {
|
||||
var newMessage schemas.ResponsesMessage
|
||||
newMessage.ID = message.ID
|
||||
newMessage.Type = message.Type
|
||||
newMessage.Status = message.Status
|
||||
newMessage.Role = message.Role
|
||||
|
||||
// Convert summaries to content blocks
|
||||
contentBlocks := make([]schemas.ResponsesMessageContentBlock, 0, len(message.ResponsesReasoning.Summary))
|
||||
for _, summary := range message.ResponsesReasoning.Summary {
|
||||
contentBlocks = append(contentBlocks, schemas.ResponsesMessageContentBlock{
|
||||
Type: schemas.ResponsesOutputMessageContentTypeReasoning,
|
||||
Text: schemas.Ptr(summary.Text),
|
||||
})
|
||||
}
|
||||
newMessage.Content = &schemas.ResponsesMessageContent{
|
||||
ContentBlocks: contentBlocks,
|
||||
}
|
||||
messages = append(messages, newMessage)
|
||||
} else {
|
||||
// Clone the embedded pointer to avoid mutating the original input
|
||||
reasoningCopy := *message.ResponsesReasoning
|
||||
message.ResponsesReasoning = &reasoningCopy
|
||||
// OpenAI's Responses API does not accept 'role' on reasoning items
|
||||
message.Role = nil
|
||||
// Strip cross-provider encrypted content that non-reasoning models cannot decrypt.
|
||||
// Reasoning models (o1/o3/o4/GPT-5) may use EncryptedContent for multi-turn state.
|
||||
if !isReasoning {
|
||||
message.ResponsesReasoning.EncryptedContent = nil
|
||||
}
|
||||
messages = append(messages, message)
|
||||
}
|
||||
} else if message.ResponsesToolMessage != nil &&
|
||||
message.ResponsesToolMessage.Action != nil &&
|
||||
message.ResponsesToolMessage.Action.ResponsesComputerToolCallAction != nil {
|
||||
action := message.ResponsesToolMessage.Action.ResponsesComputerToolCallAction
|
||||
if action.Type == "zoom" || action.Region != nil {
|
||||
// Copy action and modify
|
||||
newAction := *action
|
||||
newAction.Region = nil
|
||||
if newAction.Type == "zoom" {
|
||||
newAction.Type = "screenshot"
|
||||
}
|
||||
|
||||
actionStructCopy := *message.ResponsesToolMessage.Action
|
||||
actionStructCopy.ResponsesComputerToolCallAction = &newAction
|
||||
|
||||
toolMsgCopy := *message.ResponsesToolMessage
|
||||
toolMsgCopy.Action = &actionStructCopy
|
||||
|
||||
message.ResponsesToolMessage = &toolMsgCopy
|
||||
}
|
||||
|
||||
messages = append(messages, message)
|
||||
} else {
|
||||
messages = append(messages, message)
|
||||
}
|
||||
}
|
||||
// Updating params
|
||||
params := bifrostReq.Params
|
||||
// Create the responses request with properly mapped parameters
|
||||
req := &OpenAIResponsesRequest{
|
||||
Model: bifrostReq.Model,
|
||||
Input: OpenAIResponsesRequestInput{
|
||||
OpenAIResponsesRequestInputArray: messages,
|
||||
},
|
||||
}
|
||||
|
||||
if params != nil {
|
||||
req.ResponsesParameters = *params
|
||||
if req.ResponsesParameters.MaxOutputTokens != nil && *req.ResponsesParameters.MaxOutputTokens < MinMaxCompletionTokens {
|
||||
req.ResponsesParameters.MaxOutputTokens = schemas.Ptr(MinMaxCompletionTokens)
|
||||
}
|
||||
// Drop user field if it exceeds OpenAI's 64 character limit
|
||||
req.ResponsesParameters.User = SanitizeUserField(req.ResponsesParameters.User)
|
||||
|
||||
// Handle reasoning parameter: OpenAI uses effort-based reasoning
|
||||
// Priority: effort (native) > max_tokens (estimated)
|
||||
if req.ResponsesParameters.Reasoning != nil {
|
||||
// Clone the Reasoning pointer to avoid mutating the original params
|
||||
reasoningCopy := *req.ResponsesParameters.Reasoning
|
||||
req.ResponsesParameters.Reasoning = &reasoningCopy
|
||||
if req.ResponsesParameters.Reasoning.Effort != nil {
|
||||
// Native field is provided, use it (and clear max_tokens)
|
||||
effort := *req.ResponsesParameters.Reasoning.Effort
|
||||
// Convert "minimal" to "low"; cap "xhigh"/"max" to "high" — OpenAI tops out at high.
|
||||
switch effort {
|
||||
case "minimal":
|
||||
req.ResponsesParameters.Reasoning.Effort = schemas.Ptr("low")
|
||||
case "xhigh", "max":
|
||||
req.ResponsesParameters.Reasoning.Effort = schemas.Ptr("high")
|
||||
}
|
||||
// Clear max_tokens since OpenAI doesn't use it
|
||||
req.ResponsesParameters.Reasoning.MaxTokens = nil
|
||||
} else if req.ResponsesParameters.Reasoning.MaxTokens != nil {
|
||||
// Estimate effort from max_tokens
|
||||
maxTokens := *req.ResponsesParameters.Reasoning.MaxTokens
|
||||
maxOutputTokens := utils.GetMaxOutputTokensOrDefault(req.Model, DefaultCompletionMaxTokens)
|
||||
if req.ResponsesParameters.MaxOutputTokens != nil {
|
||||
maxOutputTokens = *req.ResponsesParameters.MaxOutputTokens
|
||||
}
|
||||
effort := utils.GetReasoningEffortFromBudgetTokens(maxTokens, MinReasoningMaxTokens, maxOutputTokens)
|
||||
req.ResponsesParameters.Reasoning.Effort = schemas.Ptr(effort)
|
||||
// Clear max_tokens since OpenAI doesn't use it
|
||||
req.ResponsesParameters.Reasoning.MaxTokens = nil
|
||||
}
|
||||
|
||||
// summary:"none" is Anthropic-specific (maps to display:"omitted"); strip it for OpenAI.
|
||||
if req.ResponsesParameters.Reasoning.Summary != nil && *req.ResponsesParameters.Reasoning.Summary == "none" {
|
||||
req.ResponsesParameters.Reasoning.Summary = nil
|
||||
}
|
||||
|
||||
// Handle xAI-specific parameter filtering
|
||||
// Only grok-3-mini supports reasoning_effort
|
||||
if bifrostReq.Provider == schemas.XAI &&
|
||||
schemas.IsGrokReasoningModel(bifrostReq.Model) &&
|
||||
!strings.Contains(bifrostReq.Model, "grok-3-mini") {
|
||||
// Clear reasoning_effort for non-grok-3-mini xAI reasoning models
|
||||
req.ResponsesParameters.Reasoning.Effort = nil
|
||||
}
|
||||
|
||||
// Handle OpenAI-specific parameter filtering
|
||||
// Only o1/o3 series models support reasoning.effort
|
||||
// Regular models like gpt-4o, gpt-4, gpt-3.5-turbo don't support it
|
||||
if bifrostReq.Provider == schemas.OpenAI && !isOpenAIReasoningModel(bifrostReq.Model) {
|
||||
// Clear reasoning for non-reasoning OpenAI models to avoid API errors
|
||||
req.ResponsesParameters.Reasoning = nil
|
||||
}
|
||||
}
|
||||
|
||||
// Strip top_p for OpenAI reasoning models (o1/o3 series) which reject it
|
||||
// GPT-5.x accept top_p when reasoning.effort is "none" (defaults to "none" when omitted)
|
||||
if isOpenAIReasoningModel(bifrostReq.Model) {
|
||||
stripTopP := true
|
||||
_, parsedModel := schemas.ParseModelString(bifrostReq.Model, schemas.OpenAI)
|
||||
modelLower := strings.ToLower(parsedModel)
|
||||
effort := ""
|
||||
if req.ResponsesParameters.Reasoning != nil &&
|
||||
req.ResponsesParameters.Reasoning.Effort != nil {
|
||||
effort = *req.ResponsesParameters.Reasoning.Effort
|
||||
}
|
||||
// GPT-5.x: reasoning defaults to "none" when omitted, and top_p is allowed in that case
|
||||
// Exception: -pro and -codex variants always reason (no "none" mode), so top_p must be stripped
|
||||
if strings.HasPrefix(modelLower, "gpt-5.") &&
|
||||
(effort == "" || effort == "none") &&
|
||||
!strings.Contains(modelLower, "-pro") &&
|
||||
!strings.Contains(modelLower, "-codex") {
|
||||
stripTopP = false
|
||||
}
|
||||
if stripTopP {
|
||||
req.ResponsesParameters.TopP = nil
|
||||
}
|
||||
}
|
||||
|
||||
// Normalize function tool parameters for deterministic JSON serialization.
|
||||
// We must copy the Tools slice since it shares the backing array with bifrostReq.Params.Tools.
|
||||
if len(req.Tools) > 0 {
|
||||
normalizedTools := make([]schemas.ResponsesTool, len(req.Tools))
|
||||
copy(normalizedTools, req.Tools)
|
||||
for i, tool := range normalizedTools {
|
||||
if tool.Type == schemas.ResponsesToolTypeFunction &&
|
||||
tool.ResponsesToolFunction != nil &&
|
||||
tool.ResponsesToolFunction.Parameters != nil {
|
||||
funcCopy := *tool.ResponsesToolFunction
|
||||
funcCopy.Parameters = tool.ResponsesToolFunction.Parameters.Normalized()
|
||||
normalizedTools[i].ResponsesToolFunction = &funcCopy
|
||||
}
|
||||
}
|
||||
req.Tools = normalizedTools
|
||||
}
|
||||
|
||||
// Filter out tools that OpenAI doesn't support
|
||||
req.filterUnsupportedTools()
|
||||
}
|
||||
|
||||
if bifrostReq.Params != nil {
|
||||
req.ExtraParams = bifrostReq.Params.ExtraParams
|
||||
}
|
||||
return req
|
||||
}
|
||||
|
||||
// filterUnsupportedTools removes tool types that OpenAI doesn't support
|
||||
func (resp *OpenAIResponsesRequest) filterUnsupportedTools() {
|
||||
if len(resp.Tools) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// Define OpenAI-supported tool types
|
||||
supportedTypes := map[schemas.ResponsesToolType]bool{
|
||||
schemas.ResponsesToolTypeFunction: true,
|
||||
schemas.ResponsesToolTypeFileSearch: true,
|
||||
schemas.ResponsesToolTypeComputerUsePreview: true,
|
||||
schemas.ResponsesToolTypeWebSearch: true,
|
||||
schemas.ResponsesToolTypeWebFetch: true,
|
||||
schemas.ResponsesToolTypeMCP: true,
|
||||
schemas.ResponsesToolTypeCodeInterpreter: true,
|
||||
schemas.ResponsesToolTypeImageGeneration: true,
|
||||
schemas.ResponsesToolTypeLocalShell: true,
|
||||
schemas.ResponsesToolTypeCustom: true,
|
||||
schemas.ResponsesToolTypeWebSearchPreview: true,
|
||||
schemas.ResponsesToolTypeMemory: true,
|
||||
schemas.ResponsesToolTypeToolSearch: true,
|
||||
}
|
||||
|
||||
// Filter tools to only include supported types
|
||||
filteredTools := make([]schemas.ResponsesTool, 0, len(resp.Tools))
|
||||
for _, tool := range resp.Tools {
|
||||
if supportedTypes[tool.Type] {
|
||||
// check for computer use preview
|
||||
if tool.Type == schemas.ResponsesToolTypeComputerUsePreview && tool.ResponsesToolComputerUsePreview != nil && tool.ResponsesToolComputerUsePreview.EnableZoom != nil {
|
||||
newTool := tool
|
||||
newComputerUse := &schemas.ResponsesToolComputerUsePreview{
|
||||
DisplayHeight: tool.ResponsesToolComputerUsePreview.DisplayHeight,
|
||||
DisplayWidth: tool.ResponsesToolComputerUsePreview.DisplayWidth,
|
||||
Environment: tool.ResponsesToolComputerUsePreview.Environment,
|
||||
// EnableZoom is intentionally omitted (nil) - OpenAI doesn't support it
|
||||
}
|
||||
newTool.ResponsesToolComputerUsePreview = newComputerUse
|
||||
filteredTools = append(filteredTools, newTool)
|
||||
} else if tool.Type == schemas.ResponsesToolTypeWebSearch && tool.ResponsesToolWebSearch != nil {
|
||||
// Create a proper deep copy with new nested pointers to avoid mutating the original
|
||||
newTool := tool
|
||||
newWebSearch := &schemas.ResponsesToolWebSearch{}
|
||||
|
||||
// MaxUses is intentionally omitted (nil) - OpenAI doesn't support it
|
||||
|
||||
// Handle Filters: OpenAI doesn't support BlockedDomains or TimeRangeFilter
|
||||
if tool.ResponsesToolWebSearch.Filters != nil {
|
||||
hasAllowedDomains := len(tool.ResponsesToolWebSearch.Filters.AllowedDomains) > 0
|
||||
|
||||
if hasAllowedDomains {
|
||||
// Keep only AllowedDomains (copy the slice to avoid sharing)
|
||||
newWebSearch.Filters = &schemas.ResponsesToolWebSearchFilters{
|
||||
AllowedDomains: append([]string(nil), tool.ResponsesToolWebSearch.Filters.AllowedDomains...),
|
||||
// BlockedDomains and TimeRangeFilter are intentionally omitted - OpenAI doesn't support it
|
||||
}
|
||||
}
|
||||
// If only blocked domains or both empty, Filters stays nil
|
||||
}
|
||||
|
||||
// Copy other fields if they exist
|
||||
if tool.ResponsesToolWebSearch.UserLocation != nil {
|
||||
newWebSearch.UserLocation = tool.ResponsesToolWebSearch.UserLocation
|
||||
}
|
||||
if tool.ResponsesToolWebSearch.SearchContextSize != nil {
|
||||
newWebSearch.SearchContextSize = tool.ResponsesToolWebSearch.SearchContextSize
|
||||
}
|
||||
|
||||
newTool.ResponsesToolWebSearch = newWebSearch
|
||||
filteredTools = append(filteredTools, newTool)
|
||||
} else {
|
||||
filteredTools = append(filteredTools, tool)
|
||||
}
|
||||
}
|
||||
}
|
||||
resp.Tools = filteredTools
|
||||
}
|
||||
|
||||
872
core/providers/openai/responses_marshal_test.go
Normal file
872
core/providers/openai/responses_marshal_test.go
Normal file
@@ -0,0 +1,872 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/bytedance/sonic"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
func TestOpenAIResponsesRequest_MarshalJSON_ReasoningMaxTokensAbsent(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
request *OpenAIResponsesRequest
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "reasoning with MaxTokens set should omit max_tokens from output",
|
||||
request: &OpenAIResponsesRequest{
|
||||
Model: "gpt-4o",
|
||||
Input: OpenAIResponsesRequestInput{
|
||||
OpenAIResponsesRequestInputStr: schemas.Ptr("test input"),
|
||||
},
|
||||
ResponsesParameters: schemas.ResponsesParameters{
|
||||
Reasoning: &schemas.ResponsesParametersReasoning{
|
||||
Effort: schemas.Ptr("high"),
|
||||
MaxTokens: schemas.Ptr(1000),
|
||||
Summary: schemas.Ptr("detailed"),
|
||||
},
|
||||
},
|
||||
},
|
||||
description: "When Reasoning.MaxTokens is set, it should be absent from JSON output",
|
||||
},
|
||||
{
|
||||
name: "reasoning with all fields set should omit only max_tokens",
|
||||
request: &OpenAIResponsesRequest{
|
||||
Model: "gpt-4o",
|
||||
Input: OpenAIResponsesRequestInput{
|
||||
OpenAIResponsesRequestInputStr: schemas.Ptr("test"),
|
||||
},
|
||||
ResponsesParameters: schemas.ResponsesParameters{
|
||||
Reasoning: &schemas.ResponsesParametersReasoning{
|
||||
Effort: schemas.Ptr("medium"),
|
||||
GenerateSummary: schemas.Ptr("auto"),
|
||||
Summary: schemas.Ptr("concise"),
|
||||
MaxTokens: schemas.Ptr(500),
|
||||
},
|
||||
},
|
||||
},
|
||||
description: "All reasoning fields except MaxTokens should be present in output",
|
||||
},
|
||||
{
|
||||
name: "reasoning with nil MaxTokens should not include max_tokens",
|
||||
request: &OpenAIResponsesRequest{
|
||||
Model: "gpt-4o",
|
||||
Input: OpenAIResponsesRequestInput{
|
||||
OpenAIResponsesRequestInputStr: schemas.Ptr("test"),
|
||||
},
|
||||
ResponsesParameters: schemas.ResponsesParameters{
|
||||
Reasoning: &schemas.ResponsesParametersReasoning{
|
||||
Effort: schemas.Ptr("low"),
|
||||
MaxTokens: nil,
|
||||
},
|
||||
},
|
||||
},
|
||||
description: "When Reasoning.MaxTokens is nil, max_tokens should not appear in output",
|
||||
},
|
||||
{
|
||||
name: "nil reasoning should not include reasoning field",
|
||||
request: &OpenAIResponsesRequest{
|
||||
Model: "gpt-4o",
|
||||
Input: OpenAIResponsesRequestInput{
|
||||
OpenAIResponsesRequestInputStr: schemas.Ptr("test"),
|
||||
},
|
||||
ResponsesParameters: schemas.ResponsesParameters{
|
||||
Reasoning: nil,
|
||||
},
|
||||
},
|
||||
description: "When Reasoning is nil, reasoning field should not appear in output",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
jsonBytes, err := tt.request.MarshalJSON()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal JSON: %v", err)
|
||||
}
|
||||
|
||||
// Parse the JSON to check structure
|
||||
var jsonMap map[string]interface{}
|
||||
if err := sonic.Unmarshal(jsonBytes, &jsonMap); err != nil {
|
||||
t.Fatalf("Failed to unmarshal marshaled JSON: %v", err)
|
||||
}
|
||||
|
||||
// Check that reasoning.max_tokens is absent
|
||||
if reasoning, ok := jsonMap["reasoning"].(map[string]interface{}); ok {
|
||||
if maxTokens, exists := reasoning["max_tokens"]; exists {
|
||||
t.Errorf("%s: reasoning.max_tokens should be absent from JSON output, but found: %v", tt.description, maxTokens)
|
||||
}
|
||||
|
||||
// Verify other reasoning fields are present when they should be
|
||||
if tt.request.Reasoning != nil {
|
||||
if tt.request.Reasoning.Effort != nil {
|
||||
if _, exists := reasoning["effort"]; !exists {
|
||||
t.Error("reasoning.effort should be present in output")
|
||||
}
|
||||
}
|
||||
if tt.request.Reasoning.Summary != nil {
|
||||
if _, exists := reasoning["summary"]; !exists {
|
||||
t.Error("reasoning.summary should be present in output")
|
||||
}
|
||||
}
|
||||
if tt.request.Reasoning.GenerateSummary != nil {
|
||||
if _, exists := reasoning["generate_summary"]; !exists {
|
||||
t.Error("reasoning.generate_summary should be present in output")
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if tt.request.Reasoning != nil {
|
||||
// If reasoning is set, it should appear in JSON (unless all fields are nil/omitted)
|
||||
if tt.request.Reasoning.Effort != nil || tt.request.Reasoning.Summary != nil || tt.request.Reasoning.GenerateSummary != nil {
|
||||
t.Error("reasoning field should be present in JSON when Reasoning is set with non-nil fields")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAIResponsesRequest_MarshalJSON_InputStringForm(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
request *OpenAIResponsesRequest
|
||||
expected string
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "input as string is correctly marshaled",
|
||||
request: &OpenAIResponsesRequest{
|
||||
Model: "gpt-4o",
|
||||
Input: OpenAIResponsesRequestInput{
|
||||
OpenAIResponsesRequestInputStr: schemas.Ptr("Hello, world!"),
|
||||
},
|
||||
},
|
||||
expected: "Hello, world!",
|
||||
description: "Input field should be marshaled as a string when OpenAIResponsesRequestInputStr is set",
|
||||
},
|
||||
{
|
||||
name: "input as empty string is correctly marshaled",
|
||||
request: &OpenAIResponsesRequest{
|
||||
Model: "gpt-4o",
|
||||
Input: OpenAIResponsesRequestInput{
|
||||
OpenAIResponsesRequestInputStr: schemas.Ptr(""),
|
||||
},
|
||||
},
|
||||
expected: "",
|
||||
description: "Input field should be marshaled as empty string when set to empty string",
|
||||
},
|
||||
{
|
||||
name: "input as string with special characters",
|
||||
request: &OpenAIResponsesRequest{
|
||||
Model: "gpt-4o",
|
||||
Input: OpenAIResponsesRequestInput{
|
||||
OpenAIResponsesRequestInputStr: schemas.Ptr(`{"key": "value"}`),
|
||||
},
|
||||
},
|
||||
expected: `{"key": "value"}`,
|
||||
description: "Input field should correctly marshal strings with special characters",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
jsonBytes, err := tt.request.MarshalJSON()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal JSON: %v", err)
|
||||
}
|
||||
|
||||
// Parse the JSON to check input field
|
||||
var jsonMap map[string]interface{}
|
||||
if err := sonic.Unmarshal(jsonBytes, &jsonMap); err != nil {
|
||||
t.Fatalf("Failed to unmarshal marshaled JSON: %v", err)
|
||||
}
|
||||
|
||||
// Check that input is a string
|
||||
inputValue, exists := jsonMap["input"]
|
||||
if !exists {
|
||||
t.Fatalf("%s: input field should be present in JSON", tt.description)
|
||||
}
|
||||
|
||||
inputStr, ok := inputValue.(string)
|
||||
if !ok {
|
||||
t.Errorf("%s: input field should be a string, got type %T", tt.description, inputValue)
|
||||
}
|
||||
|
||||
if inputStr != tt.expected {
|
||||
t.Errorf("%s: expected input to be %q, got %q", tt.description, tt.expected, inputStr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAIResponsesRequest_MarshalJSON_InputArrayForm(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
request *OpenAIResponsesRequest
|
||||
validate func(t *testing.T, inputValue interface{})
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "input as array is correctly marshaled",
|
||||
request: &OpenAIResponsesRequest{
|
||||
Model: "gpt-4o",
|
||||
Input: OpenAIResponsesRequestInput{
|
||||
OpenAIResponsesRequestInputArray: []schemas.ResponsesMessage{
|
||||
{
|
||||
Role: schemas.Ptr(schemas.ResponsesInputMessageRoleUser),
|
||||
Content: &schemas.ResponsesMessageContent{
|
||||
ContentStr: schemas.Ptr("Hello"),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
validate: func(t *testing.T, inputValue interface{}) {
|
||||
inputArray, ok := inputValue.([]interface{})
|
||||
if !ok {
|
||||
t.Fatalf("Expected input to be an array, got type %T", inputValue)
|
||||
}
|
||||
if len(inputArray) != 1 {
|
||||
t.Errorf("Expected 1 message in array, got %d", len(inputArray))
|
||||
}
|
||||
},
|
||||
description: "Input field should be marshaled as an array when OpenAIResponsesRequestInputArray is set",
|
||||
},
|
||||
{
|
||||
name: "input as empty array is correctly marshaled",
|
||||
request: &OpenAIResponsesRequest{
|
||||
Model: "gpt-4o",
|
||||
Input: OpenAIResponsesRequestInput{
|
||||
OpenAIResponsesRequestInputArray: []schemas.ResponsesMessage{},
|
||||
},
|
||||
},
|
||||
validate: func(t *testing.T, inputValue interface{}) {
|
||||
inputArray, ok := inputValue.([]interface{})
|
||||
if !ok {
|
||||
t.Fatalf("Expected input to be an array, got type %T", inputValue)
|
||||
}
|
||||
if len(inputArray) != 0 {
|
||||
t.Errorf("Expected empty array, got %d elements", len(inputArray))
|
||||
}
|
||||
},
|
||||
description: "Input field should be marshaled as empty array when set to empty array",
|
||||
},
|
||||
{
|
||||
name: "input as array with multiple messages",
|
||||
request: &OpenAIResponsesRequest{
|
||||
Model: "gpt-4o",
|
||||
Input: OpenAIResponsesRequestInput{
|
||||
OpenAIResponsesRequestInputArray: []schemas.ResponsesMessage{
|
||||
{
|
||||
Role: schemas.Ptr(schemas.ResponsesInputMessageRoleSystem),
|
||||
Content: &schemas.ResponsesMessageContent{
|
||||
ContentStr: schemas.Ptr("You are a helpful assistant."),
|
||||
},
|
||||
},
|
||||
{
|
||||
Role: schemas.Ptr(schemas.ResponsesInputMessageRoleUser),
|
||||
Content: &schemas.ResponsesMessageContent{
|
||||
ContentStr: schemas.Ptr("What is 2+2?"),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
validate: func(t *testing.T, inputValue interface{}) {
|
||||
inputArray, ok := inputValue.([]interface{})
|
||||
if !ok {
|
||||
t.Fatalf("Expected input to be an array, got type %T", inputValue)
|
||||
}
|
||||
if len(inputArray) != 2 {
|
||||
t.Errorf("Expected 2 messages in array, got %d", len(inputArray))
|
||||
}
|
||||
},
|
||||
description: "Input field should correctly marshal arrays with multiple messages",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
jsonBytes, err := tt.request.MarshalJSON()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal JSON: %v", err)
|
||||
}
|
||||
|
||||
// Parse the JSON to check input field
|
||||
var jsonMap map[string]interface{}
|
||||
if err := sonic.Unmarshal(jsonBytes, &jsonMap); err != nil {
|
||||
t.Fatalf("Failed to unmarshal marshaled JSON: %v", err)
|
||||
}
|
||||
|
||||
// Check that input is present
|
||||
inputValue, exists := jsonMap["input"]
|
||||
if !exists {
|
||||
t.Fatalf("%s: input field should be present in JSON", tt.description)
|
||||
}
|
||||
|
||||
// Validate using the provided function
|
||||
tt.validate(t, inputValue)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestToOpenAIResponsesRequest_FireworksPreservesNativeFields(t *testing.T) {
|
||||
bifrostReq := &schemas.BifrostResponsesRequest{
|
||||
Provider: schemas.Fireworks,
|
||||
Model: "accounts/fireworks/models/deepseek-v3p2",
|
||||
Input: []schemas.ResponsesMessage{
|
||||
{
|
||||
Role: schemas.Ptr(schemas.ResponsesInputMessageRoleUser),
|
||||
Content: &schemas.ResponsesMessageContent{
|
||||
ContentStr: schemas.Ptr("hello"),
|
||||
},
|
||||
},
|
||||
},
|
||||
Params: &schemas.ResponsesParameters{
|
||||
PreviousResponseID: schemas.Ptr("resp_previous"),
|
||||
MaxToolCalls: schemas.Ptr(2),
|
||||
Store: schemas.Ptr(true),
|
||||
},
|
||||
}
|
||||
|
||||
request := ToOpenAIResponsesRequest(bifrostReq)
|
||||
if request == nil {
|
||||
t.Fatal("expected non-nil request")
|
||||
}
|
||||
|
||||
jsonBytes, err := request.MarshalJSON()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to marshal responses request: %v", err)
|
||||
}
|
||||
|
||||
var jsonMap map[string]interface{}
|
||||
if err := sonic.Unmarshal(jsonBytes, &jsonMap); err != nil {
|
||||
t.Fatalf("failed to parse marshaled JSON: %v", err)
|
||||
}
|
||||
|
||||
if got, ok := jsonMap["previous_response_id"].(string); !ok || got != "resp_previous" {
|
||||
t.Fatalf("expected previous_response_id to be preserved, got %#v", jsonMap["previous_response_id"])
|
||||
}
|
||||
if got, ok := jsonMap["max_tool_calls"].(float64); !ok || got != 2 {
|
||||
t.Fatalf("expected max_tool_calls to be preserved, got %#v", jsonMap["max_tool_calls"])
|
||||
}
|
||||
if got, ok := jsonMap["store"].(bool); !ok || !got {
|
||||
t.Fatalf("expected store=true to be preserved, got %#v", jsonMap["store"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAIResponsesRequest_MarshalJSON_FieldShadowingBehavior(t *testing.T) {
|
||||
// This test verifies that the field shadowing pattern works correctly
|
||||
// by ensuring that the aux struct properly shadows Input and Reasoning fields
|
||||
t.Run("field shadowing preserves other fields", func(t *testing.T) {
|
||||
request := &OpenAIResponsesRequest{
|
||||
Model: "gpt-4o",
|
||||
Input: OpenAIResponsesRequestInput{
|
||||
OpenAIResponsesRequestInputStr: schemas.Ptr("test input"),
|
||||
},
|
||||
ResponsesParameters: schemas.ResponsesParameters{
|
||||
MaxOutputTokens: schemas.Ptr(100),
|
||||
Temperature: schemas.Ptr(0.7),
|
||||
Reasoning: &schemas.ResponsesParametersReasoning{
|
||||
Effort: schemas.Ptr("high"),
|
||||
MaxTokens: schemas.Ptr(500), // This should be omitted
|
||||
Summary: schemas.Ptr("detailed"),
|
||||
},
|
||||
},
|
||||
Stream: schemas.Ptr(true),
|
||||
Fallbacks: []string{"fallback1", "fallback2"},
|
||||
}
|
||||
|
||||
jsonBytes, err := request.MarshalJSON()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal JSON: %v", err)
|
||||
}
|
||||
|
||||
var jsonMap map[string]interface{}
|
||||
if err := sonic.Unmarshal(jsonBytes, &jsonMap); err != nil {
|
||||
t.Fatalf("Failed to unmarshal marshaled JSON: %v", err)
|
||||
}
|
||||
|
||||
// Verify base fields are present
|
||||
if jsonMap["model"] != "gpt-4o" {
|
||||
t.Errorf("Expected model to be 'gpt-4o', got %v", jsonMap["model"])
|
||||
}
|
||||
|
||||
if jsonMap["stream"] != true {
|
||||
t.Errorf("Expected stream to be true, got %v", jsonMap["stream"])
|
||||
}
|
||||
|
||||
fallbacks, ok := jsonMap["fallbacks"].([]interface{})
|
||||
if !ok || len(fallbacks) != 2 {
|
||||
t.Errorf("Expected fallbacks to have 2 elements, got %v", jsonMap["fallbacks"])
|
||||
}
|
||||
|
||||
// Verify ResponsesParameters fields are present
|
||||
if jsonMap["max_output_tokens"] != float64(100) {
|
||||
t.Errorf("Expected max_output_tokens to be 100, got %v", jsonMap["max_output_tokens"])
|
||||
}
|
||||
|
||||
if jsonMap["temperature"] != 0.7 {
|
||||
t.Errorf("Expected temperature to be 0.7, got %v", jsonMap["temperature"])
|
||||
}
|
||||
|
||||
// Verify reasoning.max_tokens is absent
|
||||
if reasoning, ok := jsonMap["reasoning"].(map[string]interface{}); ok {
|
||||
if _, exists := reasoning["max_tokens"]; exists {
|
||||
t.Error("reasoning.max_tokens should be absent from JSON output")
|
||||
}
|
||||
if reasoning["effort"] != "high" {
|
||||
t.Errorf("Expected reasoning.effort to be 'high', got %v", reasoning["effort"])
|
||||
}
|
||||
if reasoning["summary"] != "detailed" {
|
||||
t.Errorf("Expected reasoning.summary to be 'detailed', got %v", reasoning["summary"])
|
||||
}
|
||||
} else {
|
||||
t.Error("reasoning field should be present in JSON")
|
||||
}
|
||||
|
||||
// Verify input is correctly marshaled
|
||||
if jsonMap["input"] != "test input" {
|
||||
t.Errorf("Expected input to be 'test input', got %v", jsonMap["input"])
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestOpenAIResponsesRequest_MarshalJSON_RoundTrip(t *testing.T) {
|
||||
// Test that marshaling and unmarshaling preserves all fields except reasoning.max_tokens
|
||||
t.Run("round trip preserves fields except reasoning.max_tokens", func(t *testing.T) {
|
||||
original := &OpenAIResponsesRequest{
|
||||
Model: "gpt-4o",
|
||||
Input: OpenAIResponsesRequestInput{
|
||||
OpenAIResponsesRequestInputArray: []schemas.ResponsesMessage{
|
||||
{
|
||||
Role: schemas.Ptr(schemas.ResponsesInputMessageRoleUser),
|
||||
Content: &schemas.ResponsesMessageContent{
|
||||
ContentStr: schemas.Ptr("Test message"),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
ResponsesParameters: schemas.ResponsesParameters{
|
||||
MaxOutputTokens: schemas.Ptr(200),
|
||||
Temperature: schemas.Ptr(0.8),
|
||||
Reasoning: &schemas.ResponsesParametersReasoning{
|
||||
Effort: schemas.Ptr("medium"),
|
||||
MaxTokens: schemas.Ptr(1000), // Should be omitted
|
||||
Summary: schemas.Ptr("auto"),
|
||||
},
|
||||
},
|
||||
Stream: schemas.Ptr(false),
|
||||
}
|
||||
|
||||
// Marshal
|
||||
jsonBytes, err := original.MarshalJSON()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal: %v", err)
|
||||
}
|
||||
|
||||
// Verify reasoning.max_tokens is absent in the JSON string
|
||||
jsonStr := string(jsonBytes)
|
||||
if strings.Contains(jsonStr, `"max_tokens"`) {
|
||||
// Check if it's inside reasoning object
|
||||
if strings.Contains(jsonStr, `"reasoning"`) {
|
||||
// Parse to verify it's not in reasoning
|
||||
var jsonMap map[string]interface{}
|
||||
if err := json.Unmarshal(jsonBytes, &jsonMap); err == nil {
|
||||
if reasoning, ok := jsonMap["reasoning"].(map[string]interface{}); ok {
|
||||
if _, exists := reasoning["max_tokens"]; exists {
|
||||
t.Error("reasoning.max_tokens should not be present in marshaled JSON")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Unmarshal back
|
||||
var unmarshaled OpenAIResponsesRequest
|
||||
if err := sonic.Unmarshal(jsonBytes, &unmarshaled); err != nil {
|
||||
t.Fatalf("Failed to unmarshal: %v", err)
|
||||
}
|
||||
|
||||
// Verify fields are preserved
|
||||
if unmarshaled.Model != original.Model {
|
||||
t.Errorf("Model not preserved: expected %q, got %q", original.Model, unmarshaled.Model)
|
||||
}
|
||||
|
||||
if unmarshaled.Stream == nil || *unmarshaled.Stream != *original.Stream {
|
||||
t.Error("Stream not preserved")
|
||||
}
|
||||
|
||||
if unmarshaled.MaxOutputTokens == nil || *unmarshaled.MaxOutputTokens != *original.MaxOutputTokens {
|
||||
t.Error("MaxOutputTokens not preserved")
|
||||
}
|
||||
|
||||
if unmarshaled.Temperature == nil || *unmarshaled.Temperature != *original.Temperature {
|
||||
t.Error("Temperature not preserved")
|
||||
}
|
||||
|
||||
// Verify reasoning fields except MaxTokens
|
||||
if unmarshaled.Reasoning == nil {
|
||||
t.Fatal("Reasoning should be present")
|
||||
}
|
||||
if unmarshaled.Reasoning.Effort == nil || *unmarshaled.Reasoning.Effort != *original.Reasoning.Effort {
|
||||
t.Error("Reasoning.Effort not preserved")
|
||||
}
|
||||
if unmarshaled.Reasoning.Summary == nil || *unmarshaled.Reasoning.Summary != *original.Reasoning.Summary {
|
||||
t.Error("Reasoning.Summary not preserved")
|
||||
}
|
||||
// MaxTokens should be nil after unmarshaling (since it wasn't in JSON)
|
||||
if unmarshaled.Reasoning.MaxTokens != nil {
|
||||
t.Error("Reasoning.MaxTokens should be nil after unmarshaling (was omitted from JSON)")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Regression test for multi-turn Anthropic tool_result with array-form content.
|
||||
// The OpenAI Responses API defines function_call_output.output as a string (see
|
||||
// https://platform.openai.com/docs/api-reference/responses/create). When an
|
||||
// Anthropic client sends a tool_result whose content is an array of text blocks,
|
||||
// Bifrost's Anthropic→Responses translator populates
|
||||
// ResponsesToolMessageOutputStruct.ResponsesFunctionToolCallOutputBlocks.
|
||||
// Historically, that array was marshaled verbatim onto the wire, which some
|
||||
// strict OpenAI-compat upstreams (e.g. Ollama Cloud) reject with an error like
|
||||
//
|
||||
// json: cannot unmarshal array into Go struct field ResponsesFunctionCallOutput.output of type string
|
||||
//
|
||||
// The outgoing OpenAI Responses request must emit `output` as a string for
|
||||
// text-only tool outputs.
|
||||
func TestOpenAIResponsesRequestInput_MarshalJSON_FunctionCallOutputFlattensTextBlocksToString(t *testing.T) {
|
||||
outputText := "line1"
|
||||
callID := "toolu_abc123"
|
||||
functionName := "read_file"
|
||||
|
||||
input := &OpenAIResponsesRequestInput{
|
||||
OpenAIResponsesRequestInputArray: []schemas.ResponsesMessage{
|
||||
{
|
||||
Role: schemas.Ptr(schemas.ResponsesInputMessageRoleUser),
|
||||
Content: &schemas.ResponsesMessageContent{
|
||||
ContentStr: schemas.Ptr("Read /tmp/test.txt and tell me what it contains."),
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: schemas.Ptr(schemas.ResponsesMessageTypeFunctionCall),
|
||||
Status: schemas.Ptr("completed"),
|
||||
ResponsesToolMessage: &schemas.ResponsesToolMessage{
|
||||
CallID: schemas.Ptr(callID),
|
||||
Name: schemas.Ptr(functionName),
|
||||
Arguments: schemas.Ptr(`{"path":"/tmp/test.txt"}`),
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: schemas.Ptr(schemas.ResponsesMessageTypeFunctionCallOutput),
|
||||
Status: schemas.Ptr("completed"),
|
||||
ResponsesToolMessage: &schemas.ResponsesToolMessage{
|
||||
CallID: schemas.Ptr(callID),
|
||||
Output: &schemas.ResponsesToolMessageOutputStruct{
|
||||
ResponsesFunctionToolCallOutputBlocks: []schemas.ResponsesMessageContentBlock{
|
||||
{
|
||||
Type: schemas.ResponsesInputMessageContentBlockTypeText,
|
||||
Text: schemas.Ptr(outputText),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
jsonBytes, err := input.MarshalJSON()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal OpenAIResponsesRequestInput: %v", err)
|
||||
}
|
||||
|
||||
var messages []map[string]interface{}
|
||||
if err := sonic.Unmarshal(jsonBytes, &messages); err != nil {
|
||||
t.Fatalf("Failed to unmarshal marshaled input as array: %v\nraw=%s", err, string(jsonBytes))
|
||||
}
|
||||
|
||||
var fcoMsg map[string]interface{}
|
||||
for _, m := range messages {
|
||||
if t, ok := m["type"].(string); ok && t == string(schemas.ResponsesMessageTypeFunctionCallOutput) {
|
||||
fcoMsg = m
|
||||
break
|
||||
}
|
||||
}
|
||||
if fcoMsg == nil {
|
||||
t.Fatalf("did not find function_call_output message in marshaled JSON: %s", string(jsonBytes))
|
||||
}
|
||||
|
||||
outputVal, ok := fcoMsg["output"]
|
||||
if !ok {
|
||||
t.Fatalf("function_call_output message has no `output` field: %s", string(jsonBytes))
|
||||
}
|
||||
|
||||
outputStr, isString := outputVal.(string)
|
||||
if !isString {
|
||||
t.Fatalf("function_call_output.output must be a string (OpenAI Responses API spec); got %T: %v\nraw=%s", outputVal, outputVal, string(jsonBytes))
|
||||
}
|
||||
if outputStr != outputText {
|
||||
t.Fatalf("function_call_output.output mismatch: want %q, got %q", outputText, outputStr)
|
||||
}
|
||||
}
|
||||
|
||||
// Flattening must concatenate multiple text blocks with newline separators so
|
||||
// every character from the upstream tool response reaches the model.
|
||||
func TestOpenAIResponsesRequestInput_MarshalJSON_FunctionCallOutputConcatenatesMultipleTextBlocks(t *testing.T) {
|
||||
callID := "toolu_multi"
|
||||
input := &OpenAIResponsesRequestInput{
|
||||
OpenAIResponsesRequestInputArray: []schemas.ResponsesMessage{
|
||||
{
|
||||
Type: schemas.Ptr(schemas.ResponsesMessageTypeFunctionCallOutput),
|
||||
Status: schemas.Ptr("completed"),
|
||||
ResponsesToolMessage: &schemas.ResponsesToolMessage{
|
||||
CallID: schemas.Ptr(callID),
|
||||
Output: &schemas.ResponsesToolMessageOutputStruct{
|
||||
ResponsesFunctionToolCallOutputBlocks: []schemas.ResponsesMessageContentBlock{
|
||||
{Type: schemas.ResponsesInputMessageContentBlockTypeText, Text: schemas.Ptr("line1")},
|
||||
{Type: schemas.ResponsesInputMessageContentBlockTypeText, Text: schemas.Ptr("line2")},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
jsonBytes, err := input.MarshalJSON()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal: %v", err)
|
||||
}
|
||||
var messages []map[string]interface{}
|
||||
if err := sonic.Unmarshal(jsonBytes, &messages); err != nil {
|
||||
t.Fatalf("Failed to unmarshal: %v\nraw=%s", err, string(jsonBytes))
|
||||
}
|
||||
if len(messages) != 1 {
|
||||
t.Fatalf("expected 1 message, got %d", len(messages))
|
||||
}
|
||||
got, ok := messages[0]["output"].(string)
|
||||
if !ok {
|
||||
t.Fatalf("output must be string, got %T", messages[0]["output"])
|
||||
}
|
||||
if want := "line1\nline2"; got != want {
|
||||
t.Fatalf("flattened output mismatch: want %q, got %q", want, got)
|
||||
}
|
||||
}
|
||||
|
||||
// When the tool result contains a non-text block (e.g. an image), flattening is
|
||||
// unsafe — preserve the array form and let the upstream handle it. This keeps
|
||||
// the fix scoped to the common text-only case without dropping rich content.
|
||||
func TestOpenAIResponsesRequestInput_MarshalJSON_FunctionCallOutputPreservesNonTextBlocks(t *testing.T) {
|
||||
callID := "toolu_with_image"
|
||||
imageURL := "https://example.com/screenshot.png"
|
||||
input := &OpenAIResponsesRequestInput{
|
||||
OpenAIResponsesRequestInputArray: []schemas.ResponsesMessage{
|
||||
{
|
||||
Type: schemas.Ptr(schemas.ResponsesMessageTypeFunctionCallOutput),
|
||||
Status: schemas.Ptr("completed"),
|
||||
ResponsesToolMessage: &schemas.ResponsesToolMessage{
|
||||
CallID: schemas.Ptr(callID),
|
||||
Output: &schemas.ResponsesToolMessageOutputStruct{
|
||||
ResponsesFunctionToolCallOutputBlocks: []schemas.ResponsesMessageContentBlock{
|
||||
{Type: schemas.ResponsesInputMessageContentBlockTypeText, Text: schemas.Ptr("here is the screenshot:")},
|
||||
{
|
||||
Type: schemas.ResponsesInputMessageContentBlockTypeImage,
|
||||
ResponsesInputMessageContentBlockImage: &schemas.ResponsesInputMessageContentBlockImage{
|
||||
ImageURL: &imageURL,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
jsonBytes, err := input.MarshalJSON()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal: %v", err)
|
||||
}
|
||||
var messages []map[string]interface{}
|
||||
if err := sonic.Unmarshal(jsonBytes, &messages); err != nil {
|
||||
t.Fatalf("Failed to unmarshal: %v\nraw=%s", err, string(jsonBytes))
|
||||
}
|
||||
if _, isString := messages[0]["output"].(string); isString {
|
||||
t.Fatalf("non-text blocks must not be flattened to string; raw=%s", string(jsonBytes))
|
||||
}
|
||||
}
|
||||
|
||||
// TestOpenAIResponsesRequest_MarshalJSON_StripsAnthropicToolFlags ensures the
|
||||
// Responses serializer drops the four Anthropic-native tool flags
|
||||
// (defer_loading, allowed_callers, input_examples, eager_input_streaming)
|
||||
// along with CacheControl before forwarding to OpenAI — mirroring the Chat
|
||||
// path's behavior so Anthropic-flavored tools cannot 400 OpenAI via Responses.
|
||||
func TestOpenAIResponsesRequest_MarshalJSON_StripsAnthropicToolFlags(t *testing.T) {
|
||||
req := &OpenAIResponsesRequest{
|
||||
Model: "gpt-4o",
|
||||
Input: OpenAIResponsesRequestInput{
|
||||
OpenAIResponsesRequestInputArray: []schemas.ResponsesMessage{
|
||||
{
|
||||
Role: schemas.Ptr(schemas.ResponsesInputMessageRoleUser),
|
||||
Content: &schemas.ResponsesMessageContent{
|
||||
ContentStr: schemas.Ptr("hello"),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
ResponsesParameters: schemas.ResponsesParameters{
|
||||
Tools: []schemas.ResponsesTool{
|
||||
{
|
||||
Type: schemas.ResponsesToolTypeFunction,
|
||||
Name: schemas.Ptr("lookup"),
|
||||
Description: schemas.Ptr("lookup something"),
|
||||
CacheControl: &schemas.CacheControl{Type: "ephemeral"},
|
||||
DeferLoading: schemas.Ptr(true),
|
||||
AllowedCallers: []string{"direct", "agent"},
|
||||
EagerInputStreaming: schemas.Ptr(false),
|
||||
InputExamples: []schemas.ChatToolInputExample{
|
||||
{Input: json.RawMessage(`{"q":"hi"}`)},
|
||||
},
|
||||
ResponsesToolFunction: &schemas.ResponsesToolFunction{},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
jsonBytes, err := req.MarshalJSON()
|
||||
if err != nil {
|
||||
t.Fatalf("marshal failed: %v", err)
|
||||
}
|
||||
raw := string(jsonBytes)
|
||||
|
||||
// None of the five Anthropic-only tool keys must survive on the wire.
|
||||
for _, key := range []string{`"cache_control"`, `"defer_loading"`, `"allowed_callers"`, `"input_examples"`, `"eager_input_streaming"`} {
|
||||
if strings.Contains(raw, key) {
|
||||
t.Errorf("OpenAI Responses serializer must strip %s; raw=%s", key, raw)
|
||||
}
|
||||
}
|
||||
// Function tool identity should be preserved.
|
||||
if !strings.Contains(raw, `"name":"lookup"`) {
|
||||
t.Errorf("tool identity lost after strip; raw=%s", raw)
|
||||
}
|
||||
}
|
||||
|
||||
// TestOpenAIResponsesRequest_MarshalJSON_DropsAnthropicOnlyToolTypes verifies
|
||||
// that Anthropic-only tool types (web_fetch, memory) are dropped entirely when
|
||||
// serializing for OpenAI Responses. Per OpenAI's OpenAPI spec the Responses
|
||||
// Tool discriminator union does not include web_fetch or memory, so forwarding
|
||||
// them would trigger a 400 schema-validation error. Mirrors the Chat path's
|
||||
// isAnthropicServerToolShape drop behavior.
|
||||
func TestOpenAIResponsesRequest_MarshalJSON_DropsAnthropicOnlyToolTypes(t *testing.T) {
|
||||
req := &OpenAIResponsesRequest{
|
||||
Model: "gpt-4o",
|
||||
Input: OpenAIResponsesRequestInput{
|
||||
OpenAIResponsesRequestInputArray: []schemas.ResponsesMessage{
|
||||
{
|
||||
Role: schemas.Ptr(schemas.ResponsesInputMessageRoleUser),
|
||||
Content: &schemas.ResponsesMessageContent{
|
||||
ContentStr: schemas.Ptr("hello"),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
ResponsesParameters: schemas.ResponsesParameters{
|
||||
Tools: []schemas.ResponsesTool{
|
||||
// Kept: function (OpenAI-native).
|
||||
{
|
||||
Type: schemas.ResponsesToolTypeFunction,
|
||||
Name: schemas.Ptr("keeper_func"),
|
||||
ResponsesToolFunction: &schemas.ResponsesToolFunction{},
|
||||
},
|
||||
// Dropped: web_fetch (Anthropic-only).
|
||||
{
|
||||
Type: schemas.ResponsesToolTypeWebFetch,
|
||||
Name: schemas.Ptr("anthropic_webfetch"),
|
||||
ResponsesToolWebFetch: &schemas.ResponsesToolWebFetch{},
|
||||
},
|
||||
// Kept: web_search (both support).
|
||||
{
|
||||
Type: schemas.ResponsesToolTypeWebSearch,
|
||||
ResponsesToolWebSearch: &schemas.ResponsesToolWebSearch{},
|
||||
},
|
||||
// Dropped: memory (Anthropic-only).
|
||||
{
|
||||
Type: schemas.ResponsesToolTypeMemory,
|
||||
Name: schemas.Ptr("anthropic_memory"),
|
||||
},
|
||||
// Kept: tool_search (both support per OpenAI OpenAPI spec).
|
||||
{
|
||||
Type: schemas.ResponsesToolTypeToolSearch,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
jsonBytes, err := req.MarshalJSON()
|
||||
if err != nil {
|
||||
t.Fatalf("marshal failed: %v", err)
|
||||
}
|
||||
raw := string(jsonBytes)
|
||||
|
||||
// Dropped types must not appear on the wire.
|
||||
for _, dropped := range []string{`"web_fetch"`, `"memory"`, `"anthropic_webfetch"`, `"anthropic_memory"`} {
|
||||
if strings.Contains(raw, dropped) {
|
||||
t.Errorf("Anthropic-only tool must be dropped; found %s in raw=%s", dropped, raw)
|
||||
}
|
||||
}
|
||||
// Kept types must still appear.
|
||||
for _, kept := range []string{`"function"`, `"web_search"`, `"tool_search"`, `"keeper_func"`} {
|
||||
if !strings.Contains(raw, kept) {
|
||||
t.Errorf("supported tool %s should be preserved; raw=%s", kept, raw)
|
||||
}
|
||||
}
|
||||
|
||||
// Confirm the tools array is present and has exactly 3 entries (2 dropped of 5).
|
||||
var decoded struct {
|
||||
Tools []map[string]interface{} `json:"tools"`
|
||||
}
|
||||
if err := json.Unmarshal(jsonBytes, &decoded); err != nil {
|
||||
t.Fatalf("decode failed: %v", err)
|
||||
}
|
||||
if len(decoded.Tools) != 3 {
|
||||
t.Errorf("expected 3 tools after drop (function, web_search, tool_search), got %d; tools=%+v", len(decoded.Tools), decoded.Tools)
|
||||
}
|
||||
}
|
||||
|
||||
// TestOpenAIResponsesRequest_MarshalJSON_KeepsAllWhenAllSupported verifies the
|
||||
// no-reshape fast path: if every tool is OpenAI-compatible with no
|
||||
// Anthropic-only flags, the tools slice passes through unchanged (no copy,
|
||||
// no drop).
|
||||
func TestOpenAIResponsesRequest_MarshalJSON_KeepsAllWhenAllSupported(t *testing.T) {
|
||||
req := &OpenAIResponsesRequest{
|
||||
Model: "gpt-4o",
|
||||
Input: OpenAIResponsesRequestInput{
|
||||
OpenAIResponsesRequestInputArray: []schemas.ResponsesMessage{
|
||||
{
|
||||
Role: schemas.Ptr(schemas.ResponsesInputMessageRoleUser),
|
||||
Content: &schemas.ResponsesMessageContent{ContentStr: schemas.Ptr("hi")},
|
||||
},
|
||||
},
|
||||
},
|
||||
ResponsesParameters: schemas.ResponsesParameters{
|
||||
Tools: []schemas.ResponsesTool{
|
||||
{Type: schemas.ResponsesToolTypeFunction, Name: schemas.Ptr("f"), ResponsesToolFunction: &schemas.ResponsesToolFunction{}},
|
||||
{Type: schemas.ResponsesToolTypeWebSearch, ResponsesToolWebSearch: &schemas.ResponsesToolWebSearch{}},
|
||||
{Type: schemas.ResponsesToolTypeCodeInterpreter, ResponsesToolCodeInterpreter: &schemas.ResponsesToolCodeInterpreter{}},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
jsonBytes, err := req.MarshalJSON()
|
||||
if err != nil {
|
||||
t.Fatalf("marshal failed: %v", err)
|
||||
}
|
||||
var decoded struct {
|
||||
Tools []map[string]interface{} `json:"tools"`
|
||||
}
|
||||
if err := json.Unmarshal(jsonBytes, &decoded); err != nil {
|
||||
t.Fatalf("decode failed: %v", err)
|
||||
}
|
||||
if len(decoded.Tools) != 3 {
|
||||
t.Errorf("expected 3 tools preserved, got %d", len(decoded.Tools))
|
||||
}
|
||||
}
|
||||
1598
core/providers/openai/responses_test.go
Normal file
1598
core/providers/openai/responses_test.go
Normal file
File diff suppressed because it is too large
Load Diff
43
core/providers/openai/speech.go
Normal file
43
core/providers/openai/speech.go
Normal file
@@ -0,0 +1,43 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"github.com/maximhq/bifrost/core/providers/utils"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// ToBifrostSpeechRequest converts an OpenAI speech request to Bifrost format
|
||||
func (request *OpenAISpeechRequest) ToBifrostSpeechRequest(ctx *schemas.BifrostContext) *schemas.BifrostSpeechRequest {
|
||||
provider, model := schemas.ParseModelString(request.Model, utils.CheckAndSetDefaultProvider(ctx, schemas.OpenAI))
|
||||
|
||||
return &schemas.BifrostSpeechRequest{
|
||||
Provider: provider,
|
||||
Model: model,
|
||||
Input: &schemas.SpeechInput{Input: request.Input},
|
||||
Params: &request.SpeechParameters,
|
||||
Fallbacks: schemas.ParseFallbacks(request.Fallbacks),
|
||||
}
|
||||
}
|
||||
|
||||
// ToOpenAISpeechRequest converts a Bifrost speech request to OpenAI format
|
||||
func ToOpenAISpeechRequest(bifrostReq *schemas.BifrostSpeechRequest) *OpenAISpeechRequest {
|
||||
if bifrostReq == nil || bifrostReq.Input.Input == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
speechInput := bifrostReq.Input
|
||||
params := bifrostReq.Params
|
||||
|
||||
openaiReq := &OpenAISpeechRequest{
|
||||
Model: bifrostReq.Model,
|
||||
Input: speechInput.Input,
|
||||
}
|
||||
|
||||
if params != nil {
|
||||
openaiReq.SpeechParameters = *params
|
||||
}
|
||||
|
||||
if bifrostReq.Params != nil {
|
||||
openaiReq.ExtraParams = bifrostReq.Params.ExtraParams
|
||||
}
|
||||
return openaiReq
|
||||
}
|
||||
75
core/providers/openai/text.go
Normal file
75
core/providers/openai/text.go
Normal file
@@ -0,0 +1,75 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"maps"
|
||||
|
||||
"github.com/maximhq/bifrost/core/providers/utils"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// ToOpenAITextCompletionRequest converts a Bifrost text completion request to OpenAI format
|
||||
func ToOpenAITextCompletionRequest(bifrostReq *schemas.BifrostTextCompletionRequest) *OpenAITextCompletionRequest {
|
||||
if bifrostReq == nil {
|
||||
return nil
|
||||
}
|
||||
params := bifrostReq.Params
|
||||
openaiReq := &OpenAITextCompletionRequest{
|
||||
Model: bifrostReq.Model,
|
||||
Prompt: bifrostReq.Input,
|
||||
}
|
||||
if params != nil {
|
||||
openaiReq.TextCompletionParameters = *params
|
||||
// Drop user field if it exceeds OpenAI's 64 character limit
|
||||
openaiReq.TextCompletionParameters.User = SanitizeUserField(openaiReq.TextCompletionParameters.User)
|
||||
if bifrostReq.Params.ExtraParams != nil {
|
||||
openaiReq.ExtraParams = maps.Clone(bifrostReq.Params.ExtraParams)
|
||||
openaiReq.TextCompletionParameters.ExtraParams = openaiReq.ExtraParams
|
||||
}
|
||||
}
|
||||
if bifrostReq.Provider == schemas.Fireworks {
|
||||
openaiReq.applyFireworksTextCompletionCompatibility()
|
||||
}
|
||||
return openaiReq
|
||||
}
|
||||
|
||||
// applyFireworksTextCompletionCompatibility maps Fireworks-specific text fields.
|
||||
func (req *OpenAITextCompletionRequest) applyFireworksTextCompletionCompatibility() {
|
||||
if req == nil || req.ExtraParams == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Fireworks uses prompt_cache_isolation_key for text-completion cache isolation.
|
||||
if req.PromptCacheIsolationKey == nil {
|
||||
if value, ok := req.ExtraParams["prompt_cache_key"]; ok {
|
||||
switch typed := value.(type) {
|
||||
case string:
|
||||
if typed != "" {
|
||||
req.PromptCacheIsolationKey = &typed
|
||||
}
|
||||
case *string:
|
||||
if typed != nil && *typed != "" {
|
||||
req.PromptCacheIsolationKey = typed
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
delete(req.ExtraParams, "prompt_cache_key")
|
||||
req.TextCompletionParameters.ExtraParams = req.ExtraParams
|
||||
}
|
||||
|
||||
// ToBifrostTextCompletionRequest converts an OpenAI text completion request to Bifrost format
|
||||
func (req *OpenAITextCompletionRequest) ToBifrostTextCompletionRequest(ctx *schemas.BifrostContext) *schemas.BifrostTextCompletionRequest {
|
||||
if req == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
provider, model := schemas.ParseModelString(req.Model, utils.CheckAndSetDefaultProvider(ctx, schemas.OpenAI))
|
||||
|
||||
return &schemas.BifrostTextCompletionRequest{
|
||||
Provider: provider,
|
||||
Model: model,
|
||||
Input: req.Prompt,
|
||||
Params: &req.TextCompletionParameters,
|
||||
Fallbacks: schemas.ParseFallbacks(req.Fallbacks),
|
||||
}
|
||||
}
|
||||
73
core/providers/openai/text_test.go
Normal file
73
core/providers/openai/text_test.go
Normal file
@@ -0,0 +1,73 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/bytedance/sonic"
|
||||
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
func TestToOpenAITextCompletionRequest_FireworksUsesCacheIsolation(t *testing.T) {
|
||||
ctx, cancel := schemas.NewBifrostContextWithCancel(nil)
|
||||
defer cancel()
|
||||
ctx.SetValue(schemas.BifrostContextKeyPassthroughExtraParams, true)
|
||||
|
||||
cacheKey := "cache-key-1"
|
||||
prompt := "A is for apple and B is for"
|
||||
extraParams := map[string]interface{}{
|
||||
"prompt_cache_key": cacheKey,
|
||||
"top_k": 4,
|
||||
}
|
||||
|
||||
bifrostReq := &schemas.BifrostTextCompletionRequest{
|
||||
Provider: schemas.Fireworks,
|
||||
Model: "accounts/fireworks/models/deepseek-v3p2",
|
||||
Input: &schemas.TextCompletionInput{
|
||||
PromptStr: &prompt,
|
||||
},
|
||||
Params: &schemas.TextCompletionParameters{
|
||||
ExtraParams: extraParams,
|
||||
},
|
||||
}
|
||||
|
||||
result := ToOpenAITextCompletionRequest(bifrostReq)
|
||||
if result == nil {
|
||||
t.Fatal("expected non-nil result")
|
||||
}
|
||||
if result.PromptCacheIsolationKey == nil || *result.PromptCacheIsolationKey != cacheKey {
|
||||
t.Fatalf("expected prompt_cache_isolation_key %q, got %v", cacheKey, result.PromptCacheIsolationKey)
|
||||
}
|
||||
if _, ok := result.ExtraParams["prompt_cache_key"]; ok {
|
||||
t.Fatalf("expected prompt_cache_key to be removed from extra params, got %#v", result.ExtraParams)
|
||||
}
|
||||
if _, ok := bifrostReq.Params.ExtraParams["prompt_cache_key"]; !ok {
|
||||
t.Fatalf("expected original extra params to remain unchanged, got %#v", bifrostReq.Params.ExtraParams)
|
||||
}
|
||||
|
||||
wireBody, bifrostErr := providerUtils.CheckContextAndGetRequestBody(
|
||||
ctx,
|
||||
bifrostReq,
|
||||
func() (providerUtils.RequestBodyWithExtraParams, error) {
|
||||
return ToOpenAITextCompletionRequest(bifrostReq), nil
|
||||
},
|
||||
)
|
||||
if bifrostErr != nil {
|
||||
t.Fatalf("failed to build request body: %v", bifrostErr.Error.Message)
|
||||
}
|
||||
|
||||
var jsonMap map[string]interface{}
|
||||
if err := sonic.Unmarshal(wireBody, &jsonMap); err != nil {
|
||||
t.Fatalf("failed to parse marshaled request body: %v", err)
|
||||
}
|
||||
|
||||
if got, ok := jsonMap["prompt_cache_isolation_key"].(string); !ok || got != cacheKey {
|
||||
t.Fatalf("expected prompt_cache_isolation_key %q in wire payload, got %#v", cacheKey, jsonMap["prompt_cache_isolation_key"])
|
||||
}
|
||||
if _, ok := jsonMap["prompt_cache_key"]; ok {
|
||||
t.Fatalf("expected prompt_cache_key to be absent from wire payload, got %#v", jsonMap["prompt_cache_key"])
|
||||
}
|
||||
if got, ok := jsonMap["top_k"].(float64); !ok || got != 4 {
|
||||
t.Fatalf("expected top_k extra param to be preserved, got %#v", jsonMap["top_k"])
|
||||
}
|
||||
}
|
||||
117
core/providers/openai/transcription.go
Normal file
117
core/providers/openai/transcription.go
Normal file
@@ -0,0 +1,117 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"mime/multipart"
|
||||
|
||||
"github.com/maximhq/bifrost/core/providers/utils"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// ToBifrostTranscriptionRequest converts an OpenAI transcription request to Bifrost format
|
||||
func (request *OpenAITranscriptionRequest) ToBifrostTranscriptionRequest(ctx *schemas.BifrostContext) *schemas.BifrostTranscriptionRequest {
|
||||
provider, model := schemas.ParseModelString(request.Model, utils.CheckAndSetDefaultProvider(ctx, schemas.OpenAI))
|
||||
|
||||
return &schemas.BifrostTranscriptionRequest{
|
||||
Provider: provider,
|
||||
Model: model,
|
||||
Input: &schemas.TranscriptionInput{
|
||||
File: request.File,
|
||||
},
|
||||
Params: &request.TranscriptionParameters,
|
||||
Fallbacks: schemas.ParseFallbacks(request.Fallbacks),
|
||||
}
|
||||
}
|
||||
|
||||
// ToOpenAITranscriptionRequest converts a Bifrost transcription request to OpenAI format
|
||||
func ToOpenAITranscriptionRequest(bifrostReq *schemas.BifrostTranscriptionRequest) *OpenAITranscriptionRequest {
|
||||
if bifrostReq == nil || bifrostReq.Input.File == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
transcriptionInput := bifrostReq.Input
|
||||
params := bifrostReq.Params
|
||||
|
||||
openaiReq := &OpenAITranscriptionRequest{
|
||||
Model: bifrostReq.Model,
|
||||
File: transcriptionInput.File,
|
||||
Filename: transcriptionInput.Filename,
|
||||
}
|
||||
|
||||
if params != nil {
|
||||
openaiReq.TranscriptionParameters = *params
|
||||
}
|
||||
|
||||
return openaiReq
|
||||
}
|
||||
|
||||
// ParseTranscriptionFormDataBodyFromRequest parses the transcription request and writes it to the multipart form.
|
||||
func ParseTranscriptionFormDataBodyFromRequest(writer *multipart.Writer, openaiReq *OpenAITranscriptionRequest, providerName schemas.ModelProvider) *schemas.BifrostError {
|
||||
// Add model field before the file so upstreams can route without buffering the audio payload.
|
||||
if err := writer.WriteField("model", openaiReq.Model); err != nil {
|
||||
return utils.NewBifrostOperationError("failed to write model field", err)
|
||||
}
|
||||
|
||||
// Add optional fields
|
||||
if openaiReq.Language != nil {
|
||||
if err := writer.WriteField("language", *openaiReq.Language); err != nil {
|
||||
return utils.NewBifrostOperationError("failed to write language field", err)
|
||||
}
|
||||
}
|
||||
|
||||
if openaiReq.Prompt != nil {
|
||||
if err := writer.WriteField("prompt", *openaiReq.Prompt); err != nil {
|
||||
return utils.NewBifrostOperationError("failed to write prompt field", err)
|
||||
}
|
||||
}
|
||||
|
||||
if openaiReq.ResponseFormat != nil {
|
||||
if err := writer.WriteField("response_format", *openaiReq.ResponseFormat); err != nil {
|
||||
return utils.NewBifrostOperationError("failed to write response_format field", err)
|
||||
}
|
||||
}
|
||||
|
||||
if openaiReq.Temperature != nil {
|
||||
if err := writer.WriteField("temperature", fmt.Sprintf("%g", *openaiReq.Temperature)); err != nil {
|
||||
return utils.NewBifrostOperationError("failed to write temperature field", err)
|
||||
}
|
||||
}
|
||||
|
||||
for _, granularity := range openaiReq.TimestampGranularities {
|
||||
if err := writer.WriteField("timestamp_granularities[]", granularity); err != nil {
|
||||
return utils.NewBifrostOperationError("failed to write timestamp_granularities field", err)
|
||||
}
|
||||
}
|
||||
|
||||
for _, include := range openaiReq.Include {
|
||||
if err := writer.WriteField("include[]", include); err != nil {
|
||||
return utils.NewBifrostOperationError("failed to write include field", err)
|
||||
}
|
||||
}
|
||||
|
||||
if openaiReq.Stream != nil && *openaiReq.Stream {
|
||||
if err := writer.WriteField("stream", "true"); err != nil {
|
||||
return utils.NewBifrostOperationError("failed to write stream field", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Add file field last so large multipart uploads don't block model discovery upstream.
|
||||
filename := openaiReq.Filename
|
||||
if filename == "" {
|
||||
filename = utils.AudioFilenameFromBytes(openaiReq.File)
|
||||
}
|
||||
fileWriter, err := writer.CreateFormFile("file", filename)
|
||||
if err != nil {
|
||||
return utils.NewBifrostOperationError("failed to create form file", err)
|
||||
}
|
||||
if _, err := fileWriter.Write(openaiReq.File); err != nil {
|
||||
return utils.NewBifrostOperationError("failed to write file data", err)
|
||||
}
|
||||
|
||||
// Close the multipart writer
|
||||
if err := writer.Close(); err != nil {
|
||||
return utils.NewBifrostOperationError("failed to close multipart writer", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
79
core/providers/openai/transcription_test.go
Normal file
79
core/providers/openai/transcription_test.go
Normal file
@@ -0,0 +1,79 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"mime"
|
||||
"mime/multipart"
|
||||
"testing"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
func multipartPartOrder(t *testing.T, contentType string, body []byte) []string {
|
||||
t.Helper()
|
||||
_, params, err := mime.ParseMediaType(contentType)
|
||||
if err != nil {
|
||||
t.Fatalf("ParseMediaType(%q): %v", contentType, err)
|
||||
}
|
||||
boundary := params["boundary"]
|
||||
if boundary == "" {
|
||||
t.Fatalf("missing boundary in %q", contentType)
|
||||
}
|
||||
|
||||
reader := multipart.NewReader(bytes.NewReader(body), boundary)
|
||||
var order []string
|
||||
for {
|
||||
part, err := reader.NextPart()
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("NextPart(): %v", err)
|
||||
}
|
||||
order = append(order, part.FormName())
|
||||
_, _ = io.Copy(io.Discard, part)
|
||||
_ = part.Close()
|
||||
}
|
||||
return order
|
||||
}
|
||||
|
||||
func TestParseTranscriptionFormDataBodyFromRequest_OrdersMetadataBeforeFile(t *testing.T) {
|
||||
language := "en"
|
||||
prompt := "transcribe this"
|
||||
responseFormat := "verbose_json"
|
||||
temperature := 0.2
|
||||
stream := true
|
||||
|
||||
var body bytes.Buffer
|
||||
writer := multipart.NewWriter(&body)
|
||||
req := &OpenAITranscriptionRequest{
|
||||
Model: "whisper-1",
|
||||
File: []byte("audio-bytes"),
|
||||
Filename: "sample.mp3",
|
||||
TranscriptionParameters: schemas.TranscriptionParameters{
|
||||
Language: &language,
|
||||
Prompt: &prompt,
|
||||
ResponseFormat: &responseFormat,
|
||||
Temperature: &temperature,
|
||||
TimestampGranularities: []string{"word"},
|
||||
Include: []string{"logprobs"},
|
||||
},
|
||||
Stream: &stream,
|
||||
}
|
||||
|
||||
if bifrostErr := ParseTranscriptionFormDataBodyFromRequest(writer, req, schemas.OpenAI); bifrostErr != nil {
|
||||
t.Fatalf("unexpected bifrost error: %v", bifrostErr.Error.Message)
|
||||
}
|
||||
|
||||
order := multipartPartOrder(t, writer.FormDataContentType(), body.Bytes())
|
||||
if len(order) == 0 {
|
||||
t.Fatal("expected multipart parts to be written")
|
||||
}
|
||||
if order[len(order)-1] != "file" {
|
||||
t.Fatalf("expected file part last, got order %v", order)
|
||||
}
|
||||
if order[0] != "model" {
|
||||
t.Fatalf("expected model part first, got order %v", order)
|
||||
}
|
||||
}
|
||||
1025
core/providers/openai/types.go
Normal file
1025
core/providers/openai/types.go
Normal file
File diff suppressed because it is too large
Load Diff
468
core/providers/openai/types_test.go
Normal file
468
core/providers/openai/types_test.go
Normal file
@@ -0,0 +1,468 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/bytedance/sonic"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
func TestOpenAIChatRequest_UnmarshalJSON_BaseFieldsPreserved(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
jsonPayload string
|
||||
validate func(t *testing.T, req *OpenAIChatRequest)
|
||||
}{
|
||||
{
|
||||
name: "all base fields preserved with ChatParameters",
|
||||
jsonPayload: `{
|
||||
"model": "gpt-4o",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Hello, world!"
|
||||
}
|
||||
],
|
||||
"stream": true,
|
||||
"max_tokens": 100,
|
||||
"prompt_cache_isolation_key": "cache-key-1",
|
||||
"fallbacks": ["gpt-3.5-turbo"],
|
||||
"temperature": 0.7,
|
||||
"top_p": 0.9
|
||||
}`,
|
||||
validate: func(t *testing.T, req *OpenAIChatRequest) {
|
||||
// Assert base fields are preserved
|
||||
if req.Model != "gpt-4o" {
|
||||
t.Errorf("Expected Model to be 'gpt-4o', got %q", req.Model)
|
||||
}
|
||||
|
||||
if len(req.Messages) != 1 {
|
||||
t.Fatalf("Expected 1 message, got %d", len(req.Messages))
|
||||
}
|
||||
if req.Messages[0].Role != schemas.ChatMessageRoleUser {
|
||||
t.Errorf("Expected message role to be 'user', got %q", req.Messages[0].Role)
|
||||
}
|
||||
if req.Messages[0].Content == nil || req.Messages[0].Content.ContentStr == nil {
|
||||
t.Fatal("Expected message content to be set")
|
||||
}
|
||||
if *req.Messages[0].Content.ContentStr != "Hello, world!" {
|
||||
t.Errorf("Expected message content to be 'Hello, world!', got %q", *req.Messages[0].Content.ContentStr)
|
||||
}
|
||||
|
||||
if req.Stream == nil || !*req.Stream {
|
||||
t.Error("Expected Stream to be true")
|
||||
}
|
||||
|
||||
if req.MaxTokens == nil || *req.MaxTokens != 100 {
|
||||
t.Errorf("Expected MaxTokens to be 100, got %v", req.MaxTokens)
|
||||
}
|
||||
|
||||
if req.PromptCacheIsolationKey == nil || *req.PromptCacheIsolationKey != "cache-key-1" {
|
||||
t.Errorf("Expected PromptCacheIsolationKey to be %q, got %v", "cache-key-1", req.PromptCacheIsolationKey)
|
||||
}
|
||||
|
||||
if len(req.Fallbacks) != 1 || req.Fallbacks[0] != "gpt-3.5-turbo" {
|
||||
t.Errorf("Expected Fallbacks to be ['gpt-3.5-turbo'], got %v", req.Fallbacks)
|
||||
}
|
||||
|
||||
// Assert ChatParameters fields are populated
|
||||
if req.Temperature == nil || *req.Temperature != 0.7 {
|
||||
t.Errorf("Expected Temperature to be 0.7, got %v", req.Temperature)
|
||||
}
|
||||
|
||||
if req.TopP == nil || *req.TopP != 0.9 {
|
||||
t.Errorf("Expected TopP to be 0.9, got %v", req.TopP)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "base fields with multiple ChatParameters fields",
|
||||
jsonPayload: `{
|
||||
"model": "gpt-3.5-turbo",
|
||||
"messages": [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful assistant."
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What is 2+2?"
|
||||
}
|
||||
],
|
||||
"stream": false,
|
||||
"max_tokens": 500,
|
||||
"fallbacks": ["gpt-4o", "gpt-4"],
|
||||
"temperature": 0.5,
|
||||
"top_p": 0.95,
|
||||
"frequency_penalty": 0.2,
|
||||
"presence_penalty": 0.3,
|
||||
"seed": 42,
|
||||
"stop": ["STOP", "END"]
|
||||
}`,
|
||||
validate: func(t *testing.T, req *OpenAIChatRequest) {
|
||||
// Assert base fields
|
||||
if req.Model != "gpt-3.5-turbo" {
|
||||
t.Errorf("Expected Model to be 'gpt-3.5-turbo', got %q", req.Model)
|
||||
}
|
||||
|
||||
if len(req.Messages) != 2 {
|
||||
t.Fatalf("Expected 2 messages, got %d", len(req.Messages))
|
||||
}
|
||||
|
||||
if req.Stream == nil || *req.Stream {
|
||||
t.Error("Expected Stream to be false")
|
||||
}
|
||||
|
||||
if req.MaxTokens == nil || *req.MaxTokens != 500 {
|
||||
t.Errorf("Expected MaxTokens to be 500, got %v", req.MaxTokens)
|
||||
}
|
||||
|
||||
if len(req.Fallbacks) != 2 {
|
||||
t.Errorf("Expected 2 fallbacks, got %d", len(req.Fallbacks))
|
||||
}
|
||||
|
||||
// Assert multiple ChatParameters fields
|
||||
if req.Temperature == nil || *req.Temperature != 0.5 {
|
||||
t.Errorf("Expected Temperature to be 0.5, got %v", req.Temperature)
|
||||
}
|
||||
|
||||
if req.TopP == nil || *req.TopP != 0.95 {
|
||||
t.Errorf("Expected TopP to be 0.95, got %v", req.TopP)
|
||||
}
|
||||
|
||||
if req.FrequencyPenalty == nil || *req.FrequencyPenalty != 0.2 {
|
||||
t.Errorf("Expected FrequencyPenalty to be 0.2, got %v", req.FrequencyPenalty)
|
||||
}
|
||||
|
||||
if req.PresencePenalty == nil || *req.PresencePenalty != 0.3 {
|
||||
t.Errorf("Expected PresencePenalty to be 0.3, got %v", req.PresencePenalty)
|
||||
}
|
||||
|
||||
if req.Seed == nil || *req.Seed != 42 {
|
||||
t.Errorf("Expected Seed to be 42, got %v", req.Seed)
|
||||
}
|
||||
|
||||
if len(req.Stop) != 2 {
|
||||
t.Errorf("Expected Stop to have 2 elements, got %d", len(req.Stop))
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "base fields with optional fields omitted",
|
||||
jsonPayload: `{
|
||||
"model": "gpt-4",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Test"
|
||||
}
|
||||
],
|
||||
"temperature": 1.0,
|
||||
"top_p": 1.0
|
||||
}`,
|
||||
validate: func(t *testing.T, req *OpenAIChatRequest) {
|
||||
if req.Model != "gpt-4" {
|
||||
t.Errorf("Expected Model to be 'gpt-4', got %q", req.Model)
|
||||
}
|
||||
|
||||
if len(req.Messages) != 1 {
|
||||
t.Fatalf("Expected 1 message, got %d", len(req.Messages))
|
||||
}
|
||||
|
||||
// Optional fields should be nil/empty when omitted
|
||||
if req.Stream != nil {
|
||||
t.Error("Expected Stream to be nil when omitted")
|
||||
}
|
||||
|
||||
if req.MaxTokens != nil {
|
||||
t.Error("Expected MaxTokens to be nil when omitted")
|
||||
}
|
||||
|
||||
if len(req.Fallbacks) != 0 {
|
||||
t.Errorf("Expected Fallbacks to be empty when omitted, got %v", req.Fallbacks)
|
||||
}
|
||||
|
||||
// ChatParameters fields should still be populated
|
||||
if req.Temperature == nil || *req.Temperature != 1.0 {
|
||||
t.Errorf("Expected Temperature to be 1.0, got %v", req.Temperature)
|
||||
}
|
||||
|
||||
if req.TopP == nil || *req.TopP != 1.0 {
|
||||
t.Errorf("Expected TopP to be 1.0, got %v", req.TopP)
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var req OpenAIChatRequest
|
||||
if err := sonic.Unmarshal([]byte(tt.jsonPayload), &req); err != nil {
|
||||
t.Fatalf("Failed to unmarshal JSON: %v", err)
|
||||
}
|
||||
|
||||
tt.validate(t, &req)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAIChatRequest_UnmarshalJSON_ChatParametersCustomLogic(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
jsonPayload string
|
||||
validate func(t *testing.T, req *OpenAIChatRequest)
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "reasoning_effort converted to Reasoning.Effort",
|
||||
jsonPayload: `{
|
||||
"model": "gpt-4o",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Think step by step"
|
||||
}
|
||||
],
|
||||
"reasoning_effort": "high",
|
||||
"temperature": 0.8
|
||||
}`,
|
||||
validate: func(t *testing.T, req *OpenAIChatRequest) {
|
||||
// Assert base fields are preserved
|
||||
if req.Model != "gpt-4o" {
|
||||
t.Errorf("Expected Model to be 'gpt-4o', got %q", req.Model)
|
||||
}
|
||||
|
||||
// Assert reasoning_effort was converted to Reasoning.Effort
|
||||
if req.Reasoning == nil {
|
||||
t.Fatal("Expected Reasoning to be set from reasoning_effort")
|
||||
}
|
||||
if req.Reasoning.Effort == nil {
|
||||
t.Fatal("Expected Reasoning.Effort to be set")
|
||||
}
|
||||
if *req.Reasoning.Effort != "high" {
|
||||
t.Errorf("Expected Reasoning.Effort to be 'high', got %q", *req.Reasoning.Effort)
|
||||
}
|
||||
|
||||
// Assert other ChatParameters fields are still populated
|
||||
if req.Temperature == nil || *req.Temperature != 0.8 {
|
||||
t.Errorf("Expected Temperature to be 0.8, got %v", req.Temperature)
|
||||
}
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "both reasoning and reasoning_effort should error",
|
||||
jsonPayload: `{
|
||||
"model": "gpt-4o",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Test"
|
||||
}
|
||||
],
|
||||
"reasoning": {
|
||||
"effort": "medium"
|
||||
},
|
||||
"reasoning_effort": "high"
|
||||
}`,
|
||||
validate: func(t *testing.T, req *OpenAIChatRequest) {
|
||||
// This should have failed during unmarshaling
|
||||
},
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "reasoning_effort with multiple ChatParameters fields",
|
||||
jsonPayload: `{
|
||||
"model": "gpt-4o",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Analyze this"
|
||||
}
|
||||
],
|
||||
"reasoning_effort": "medium",
|
||||
"temperature": 0.6,
|
||||
"top_p": 0.85,
|
||||
"max_completion_tokens": 2000
|
||||
}`,
|
||||
validate: func(t *testing.T, req *OpenAIChatRequest) {
|
||||
// Assert base fields
|
||||
if req.Model != "gpt-4o" {
|
||||
t.Errorf("Expected Model to be 'gpt-4o', got %q", req.Model)
|
||||
}
|
||||
|
||||
// Assert reasoning_effort conversion
|
||||
if req.Reasoning == nil || req.Reasoning.Effort == nil {
|
||||
t.Fatal("Expected Reasoning.Effort to be set from reasoning_effort")
|
||||
}
|
||||
if *req.Reasoning.Effort != "medium" {
|
||||
t.Errorf("Expected Reasoning.Effort to be 'medium', got %q", *req.Reasoning.Effort)
|
||||
}
|
||||
|
||||
// Assert other ChatParameters fields
|
||||
if req.Temperature == nil || *req.Temperature != 0.6 {
|
||||
t.Errorf("Expected Temperature to be 0.6, got %v", req.Temperature)
|
||||
}
|
||||
if req.TopP == nil || *req.TopP != 0.85 {
|
||||
t.Errorf("Expected TopP to be 0.85, got %v", req.TopP)
|
||||
}
|
||||
if req.MaxCompletionTokens == nil || *req.MaxCompletionTokens != 2000 {
|
||||
t.Errorf("Expected MaxCompletionTokens to be 2000, got %v", req.MaxCompletionTokens)
|
||||
}
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var req OpenAIChatRequest
|
||||
err := sonic.Unmarshal([]byte(tt.jsonPayload), &req)
|
||||
|
||||
if tt.expectError {
|
||||
if err == nil {
|
||||
t.Error("Expected error but got none")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error during unmarshaling: %v", err)
|
||||
}
|
||||
|
||||
tt.validate(t, &req)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAIChatRequest_UnmarshalJSON_PresenceAssertions(t *testing.T) {
|
||||
// Test that verifies presence of fields (not just values)
|
||||
jsonPayload := `{
|
||||
"model": "gpt-4o-mini",
|
||||
"messages": [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "Hello!"
|
||||
}
|
||||
],
|
||||
"stream": false,
|
||||
"max_tokens": 150,
|
||||
"fallbacks": ["model1", "model2"],
|
||||
"temperature": 0.3,
|
||||
"top_p": 0.7,
|
||||
"user": "test-user-123"
|
||||
}`
|
||||
|
||||
var req OpenAIChatRequest
|
||||
if err := sonic.Unmarshal([]byte(jsonPayload), &req); err != nil {
|
||||
t.Fatalf("Failed to unmarshal JSON: %v", err)
|
||||
}
|
||||
|
||||
// Presence assertions for base fields
|
||||
if req.Model == "" {
|
||||
t.Error("Model field should be present")
|
||||
}
|
||||
|
||||
if len(req.Messages) == 0 {
|
||||
t.Error("Messages field should be present and non-empty")
|
||||
}
|
||||
|
||||
if req.Stream == nil {
|
||||
t.Error("Stream field should be present (even if false)")
|
||||
}
|
||||
|
||||
if req.MaxTokens == nil {
|
||||
t.Error("MaxTokens field should be present")
|
||||
}
|
||||
|
||||
if len(req.Fallbacks) == 0 {
|
||||
t.Error("Fallbacks field should be present and non-empty")
|
||||
}
|
||||
|
||||
// Presence assertions for ChatParameters fields
|
||||
if req.Temperature == nil {
|
||||
t.Error("Temperature field should be present")
|
||||
}
|
||||
|
||||
if req.TopP == nil {
|
||||
t.Error("TopP field should be present")
|
||||
}
|
||||
|
||||
if req.User == nil {
|
||||
t.Error("User field should be present")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAIChatRequest_UnmarshalJSON_ValueAssertions(t *testing.T) {
|
||||
// Test that verifies exact values match expectations
|
||||
jsonPayload := `{
|
||||
"model": "gpt-4-turbo",
|
||||
"messages": [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "System message"
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "User message"
|
||||
}
|
||||
],
|
||||
"stream": true,
|
||||
"max_tokens": 250,
|
||||
"fallbacks": ["fallback1"],
|
||||
"temperature": 0.9,
|
||||
"top_p": 0.95,
|
||||
"seed": 12345,
|
||||
"stop": ["END", "STOP"]
|
||||
}`
|
||||
|
||||
var req OpenAIChatRequest
|
||||
if err := sonic.Unmarshal([]byte(jsonPayload), &req); err != nil {
|
||||
t.Fatalf("Failed to unmarshal JSON: %v", err)
|
||||
}
|
||||
|
||||
// Value assertions for base fields
|
||||
if req.Model != "gpt-4-turbo" {
|
||||
t.Errorf("Expected Model value 'gpt-4-turbo', got %q", req.Model)
|
||||
}
|
||||
|
||||
if len(req.Messages) != 2 {
|
||||
t.Fatalf("Expected 2 messages, got %d", len(req.Messages))
|
||||
}
|
||||
if req.Messages[0].Role != schemas.ChatMessageRoleSystem {
|
||||
t.Errorf("Expected first message role 'system', got %q", req.Messages[0].Role)
|
||||
}
|
||||
if req.Messages[1].Role != schemas.ChatMessageRoleUser {
|
||||
t.Errorf("Expected second message role 'user', got %q", req.Messages[1].Role)
|
||||
}
|
||||
|
||||
if req.Stream == nil || !*req.Stream {
|
||||
t.Error("Expected Stream value to be true")
|
||||
}
|
||||
|
||||
if req.MaxTokens == nil || *req.MaxTokens != 250 {
|
||||
t.Errorf("Expected MaxTokens value 250, got %v", req.MaxTokens)
|
||||
}
|
||||
|
||||
if len(req.Fallbacks) != 1 || req.Fallbacks[0] != "fallback1" {
|
||||
t.Errorf("Expected Fallbacks value ['fallback1'], got %v", req.Fallbacks)
|
||||
}
|
||||
|
||||
// Value assertions for ChatParameters fields
|
||||
if req.Temperature == nil || *req.Temperature != 0.9 {
|
||||
t.Errorf("Expected Temperature value 0.9, got %v", req.Temperature)
|
||||
}
|
||||
|
||||
if req.TopP == nil || *req.TopP != 0.95 {
|
||||
t.Errorf("Expected TopP value 0.95, got %v", req.TopP)
|
||||
}
|
||||
|
||||
if req.Seed == nil || *req.Seed != 12345 {
|
||||
t.Errorf("Expected Seed value 12345, got %v", req.Seed)
|
||||
}
|
||||
|
||||
if len(req.Stop) != 2 || req.Stop[0] != "END" || req.Stop[1] != "STOP" {
|
||||
t.Errorf("Expected Stop value ['END', 'STOP'], got %v", req.Stop)
|
||||
}
|
||||
}
|
||||
|
||||
108
core/providers/openai/utils.go
Normal file
108
core/providers/openai/utils.go
Normal file
@@ -0,0 +1,108 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// CustomResponseHandler is a function that produces a Bifrost response from a Bifrost request.
|
||||
// T is the concrete Bifrost response type (e.g. BifrostEmbeddingResponse, BifrostTextCompletionResponse, BifrostChatResponse, BifrostResponsesResponse, BifrostImageGenerationResponse, BifrostTranscriptionResponse).
|
||||
type responseHandler[T any] func(responseBody []byte, response *T, requestBody []byte, sendBackRawRequest bool, sendBackRawResponse bool) (rawRequest interface{}, rawResponse interface{}, bifrostErr *schemas.BifrostError)
|
||||
|
||||
func ConvertOpenAIMessagesToBifrostMessages(messages []OpenAIMessage) []schemas.ChatMessage {
|
||||
bifrostMessages := make([]schemas.ChatMessage, len(messages))
|
||||
for i, message := range messages {
|
||||
bifrostMessages[i] = schemas.ChatMessage{
|
||||
Name: message.Name,
|
||||
Role: message.Role,
|
||||
Content: message.Content,
|
||||
ChatToolMessage: message.ChatToolMessage,
|
||||
}
|
||||
if message.OpenAIChatAssistantMessage != nil {
|
||||
bifrostMessages[i].ChatAssistantMessage = &schemas.ChatAssistantMessage{
|
||||
Refusal: message.OpenAIChatAssistantMessage.Refusal,
|
||||
Reasoning: message.OpenAIChatAssistantMessage.Reasoning,
|
||||
Annotations: message.OpenAIChatAssistantMessage.Annotations,
|
||||
ToolCalls: message.OpenAIChatAssistantMessage.ToolCalls,
|
||||
}
|
||||
}
|
||||
}
|
||||
return bifrostMessages
|
||||
}
|
||||
|
||||
func ConvertBifrostMessagesToOpenAIMessages(messages []schemas.ChatMessage) []OpenAIMessage {
|
||||
openaiMessages := make([]OpenAIMessage, len(messages))
|
||||
for i, message := range messages {
|
||||
openaiMessages[i] = OpenAIMessage{
|
||||
Name: message.Name,
|
||||
Role: message.Role,
|
||||
Content: message.Content,
|
||||
ChatToolMessage: message.ChatToolMessage,
|
||||
}
|
||||
if message.ChatAssistantMessage != nil {
|
||||
openaiMessages[i].OpenAIChatAssistantMessage = &OpenAIChatAssistantMessage{
|
||||
Refusal: message.ChatAssistantMessage.Refusal,
|
||||
Reasoning: message.ChatAssistantMessage.Reasoning,
|
||||
Annotations: message.ChatAssistantMessage.Annotations,
|
||||
ToolCalls: message.ChatAssistantMessage.ToolCalls,
|
||||
}
|
||||
}
|
||||
}
|
||||
return openaiMessages
|
||||
}
|
||||
|
||||
// isOpenAIReasoningModel checks if the given model is an OpenAI reasoning model
|
||||
// that supports the reasoning.effort parameter.
|
||||
// OpenAI reasoning models include o1, o3, o4 series and GPT-5.x variants.
|
||||
// Note: -pro and -codex variants (e.g. gpt-5.2-pro, gpt-5.2-codex) are always-reasoning
|
||||
// models that do NOT support effort "none" — callers must handle top_p stripping separately.
|
||||
// TODO we need to find a better way to check if a model is an OpenAI reasoning model
|
||||
func isOpenAIReasoningModel(model string) bool {
|
||||
_, parsedModel := schemas.ParseModelString(model, schemas.OpenAI)
|
||||
if parsedModel != "" {
|
||||
model = parsedModel
|
||||
}
|
||||
modelLower := strings.ToLower(model)
|
||||
// Check for o1 or o3 series models
|
||||
// Match patterns like: o1, o1-mini, o1-preview, o3, o3-mini, etc.
|
||||
// Also match gpt-oss models which support reasoning
|
||||
if strings.Contains(modelLower, "gpt-oss") {
|
||||
return true
|
||||
}
|
||||
// Check for o1/o3/o4 series - these are reasoning models
|
||||
// The pattern matches "o1", "o3", or "o4" followed by end of string, hyphen, or underscore
|
||||
for _, prefix := range []string{"o1", "o3", "o4"} {
|
||||
if strings.HasPrefix(modelLower, prefix) {
|
||||
// Check if it's exactly the prefix or followed by a separator
|
||||
if len(modelLower) == len(prefix) ||
|
||||
modelLower[len(prefix)] == '-' ||
|
||||
modelLower[len(prefix)] == '_' {
|
||||
return true
|
||||
}
|
||||
}
|
||||
// Also check for models like "openai-o1-mini" where prefix is not at start
|
||||
if strings.Contains(modelLower, "-"+prefix+"-") ||
|
||||
strings.Contains(modelLower, "_"+prefix+"_") ||
|
||||
strings.HasSuffix(modelLower, "-"+prefix) ||
|
||||
strings.HasSuffix(modelLower, "_"+prefix) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
// Check for GPT-5 series models which support reasoning.effort
|
||||
if strings.HasPrefix(modelLower, "gpt-5") {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// MaxUserFieldLength for OpenAI enforces a 64 character maximum on the user field
|
||||
const MaxUserFieldLength = 64
|
||||
|
||||
// SanitizeUserField returns nil if user exceeds MaxUserFieldLength, otherwise returns the original value
|
||||
func SanitizeUserField(user *string) *string {
|
||||
if user != nil && len(*user) > MaxUserFieldLength {
|
||||
return nil
|
||||
}
|
||||
return user
|
||||
}
|
||||
212
core/providers/openai/videos.go
Normal file
212
core/providers/openai/videos.go
Normal file
@@ -0,0 +1,212 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
|
||||
"github.com/maximhq/bifrost/core/providers/utils"
|
||||
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// ToOpenAIVideoGenerationRequest converts a Bifrost Video Request to OpenAI format
|
||||
func ToOpenAIVideoGenerationRequest(bifrostReq *schemas.BifrostVideoGenerationRequest) (*OpenAIVideoGenerationRequest, error) {
|
||||
if bifrostReq == nil || bifrostReq.Input == nil || bifrostReq.Input.Prompt == "" {
|
||||
return nil, fmt.Errorf("bifrost request, input, or prompt is nil/empty")
|
||||
}
|
||||
|
||||
req := &OpenAIVideoGenerationRequest{
|
||||
Model: bifrostReq.Model,
|
||||
Prompt: bifrostReq.Input.Prompt,
|
||||
}
|
||||
|
||||
if bifrostReq.Input.InputReference != nil {
|
||||
// convert base64 to bytes
|
||||
sanitizedURL, err := schemas.SanitizeImageURL(*bifrostReq.Input.InputReference)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid input reference: %w", err)
|
||||
}
|
||||
urlInfo := schemas.ExtractURLTypeInfo(sanitizedURL)
|
||||
if urlInfo.DataURLWithoutPrefix != nil {
|
||||
bytes, err := base64.StdEncoding.DecodeString(*urlInfo.DataURLWithoutPrefix)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode base64 input reference: %w", err)
|
||||
}
|
||||
req.InputReference = bytes
|
||||
} else {
|
||||
return nil, fmt.Errorf("input_reference must be a base64 data URL (e.g. data:image/png;base64,...)")
|
||||
}
|
||||
}
|
||||
|
||||
if bifrostReq.Params != nil {
|
||||
if bifrostReq.Params.Seconds != nil {
|
||||
req.Seconds = bifrostReq.Params.Seconds
|
||||
}
|
||||
|
||||
// Validate and set size
|
||||
if bifrostReq.Params.Size != "" {
|
||||
// Check if the provided size is valid
|
||||
if ValidOpenAIVideoSizes[bifrostReq.Params.Size] {
|
||||
req.Size = bifrostReq.Params.Size
|
||||
} else {
|
||||
// Invalid size provided, use default
|
||||
req.Size = string(DefaultOpenAIVideoSize)
|
||||
}
|
||||
} else {
|
||||
// No size provided, use default
|
||||
req.Size = string(DefaultOpenAIVideoSize)
|
||||
}
|
||||
|
||||
req.ExtraParams = bifrostReq.Params.ExtraParams
|
||||
}
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func ToOpenAIVideoRemixRequest(bifrostReq *schemas.BifrostVideoRemixRequest) (*OpenAIVideoRemixRequest, error) {
|
||||
if bifrostReq == nil || bifrostReq.Input == nil || bifrostReq.Input.Prompt == "" {
|
||||
return nil, fmt.Errorf("bifrost request, input, or prompt is nil/empty")
|
||||
}
|
||||
|
||||
req := &OpenAIVideoRemixRequest{
|
||||
Prompt: bifrostReq.Input.Prompt,
|
||||
}
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func ToBifrostVideoRemixRequest(openaiReq *OpenAIVideoRemixRequest) *schemas.BifrostVideoRemixRequest {
|
||||
if openaiReq == nil || openaiReq.Prompt == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
provider := openaiReq.Provider
|
||||
if provider == "" {
|
||||
provider = schemas.OpenAI
|
||||
}
|
||||
|
||||
return &schemas.BifrostVideoRemixRequest{
|
||||
ID: openaiReq.ID,
|
||||
Provider: provider,
|
||||
Input: &schemas.VideoGenerationInput{
|
||||
Prompt: openaiReq.Prompt,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (req *OpenAIVideoGenerationRequest) ToBifrostVideoGenerationRequest(ctx *schemas.BifrostContext) *schemas.BifrostVideoGenerationRequest {
|
||||
if req == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
defaultProvider := schemas.OpenAI
|
||||
|
||||
// for requests coming from azure sdk without provider prefix, we need to set the default provider to azure
|
||||
if ctx != nil {
|
||||
if isAzureUser, ok := ctx.Value(schemas.BifrostContextKeyIsAzureUserAgent).(bool); ok && isAzureUser {
|
||||
defaultProvider = schemas.Azure
|
||||
}
|
||||
}
|
||||
|
||||
provider, model := schemas.ParseModelString(req.Model, utils.CheckAndSetDefaultProvider(ctx, defaultProvider))
|
||||
|
||||
input := &schemas.VideoGenerationInput{
|
||||
Prompt: req.Prompt,
|
||||
}
|
||||
if req.InputReference != nil {
|
||||
input.InputReference = schemas.Ptr(providerUtils.FileBytesToBase64DataURL(req.InputReference))
|
||||
}
|
||||
|
||||
return &schemas.BifrostVideoGenerationRequest{
|
||||
Provider: provider,
|
||||
Model: model,
|
||||
Input: input,
|
||||
Params: &req.VideoGenerationParameters,
|
||||
Fallbacks: schemas.ParseFallbacks(req.Fallbacks),
|
||||
}
|
||||
}
|
||||
|
||||
// parseVideoGenerationFormDataBodyFromRequest parses the video generation request and writes it to the multipart form.
|
||||
func parseVideoGenerationFormDataBodyFromRequest(writer *multipart.Writer, openaiReq *OpenAIVideoGenerationRequest, providerName schemas.ModelProvider) *schemas.BifrostError {
|
||||
// Add prompt field (required)
|
||||
if openaiReq.Prompt == "" {
|
||||
return providerUtils.NewBifrostOperationError("prompt is required", nil)
|
||||
}
|
||||
if err := writer.WriteField("prompt", openaiReq.Prompt); err != nil {
|
||||
return providerUtils.NewBifrostOperationError("failed to write prompt field", err)
|
||||
}
|
||||
|
||||
// Add optional model field
|
||||
if openaiReq.Model != "" {
|
||||
if err := writer.WriteField("model", openaiReq.Model); err != nil {
|
||||
return providerUtils.NewBifrostOperationError("failed to write model field", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Add optional seconds field
|
||||
if openaiReq.Seconds != nil {
|
||||
if err := writer.WriteField("seconds", *openaiReq.Seconds); err != nil {
|
||||
return providerUtils.NewBifrostOperationError("failed to write seconds field", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Add optional size field
|
||||
if openaiReq.Size != "" {
|
||||
if err := writer.WriteField("size", openaiReq.Size); err != nil {
|
||||
return providerUtils.NewBifrostOperationError("failed to write size field", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Add optional input_reference field (image or video file)
|
||||
if len(openaiReq.InputReference) > 0 {
|
||||
// Detect MIME type
|
||||
mimeType := http.DetectContentType(openaiReq.InputReference)
|
||||
|
||||
// Validate and set proper MIME type
|
||||
validMimeTypes := map[string]bool{
|
||||
"image/jpeg": true,
|
||||
"image/png": true,
|
||||
"image/webp": true,
|
||||
"video/mp4": true,
|
||||
}
|
||||
|
||||
if !validMimeTypes[mimeType] {
|
||||
// Default to image/png if not detected properly
|
||||
mimeType = "image/png"
|
||||
}
|
||||
|
||||
// Determine filename based on MIME type
|
||||
var filename string
|
||||
switch mimeType {
|
||||
case "image/jpeg":
|
||||
filename = "input_reference.jpg"
|
||||
case "image/webp":
|
||||
filename = "input_reference.webp"
|
||||
case "video/mp4":
|
||||
filename = "input_reference.mp4"
|
||||
default:
|
||||
filename = "input_reference.png"
|
||||
}
|
||||
|
||||
// Create form part with proper Content-Type header
|
||||
part, err := writer.CreatePart(map[string][]string{
|
||||
"Content-Disposition": {fmt.Sprintf(`form-data; name="input_reference"; filename="%s"`, filename)},
|
||||
"Content-Type": {mimeType},
|
||||
})
|
||||
if err != nil {
|
||||
return providerUtils.NewBifrostOperationError("failed to create form part for input_reference", err)
|
||||
}
|
||||
if _, err := part.Write(openaiReq.InputReference); err != nil {
|
||||
return providerUtils.NewBifrostOperationError("failed to write input_reference file data", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Close the multipart writer
|
||||
if err := writer.Close(); err != nil {
|
||||
return providerUtils.NewBifrostOperationError("failed to close multipart writer", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
35
core/providers/openai/websocket.go
Normal file
35
core/providers/openai/websocket.go
Normal file
@@ -0,0 +1,35 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// SupportsWebSocketMode returns true since OpenAI natively supports the Responses API WebSocket Mode.
|
||||
func (provider *OpenAIProvider) SupportsWebSocketMode() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// WebSocketResponsesURL returns the WebSocket URL for the OpenAI Responses API.
|
||||
// Converts the HTTP base URL to a WSS URL: https://api.openai.com -> wss://api.openai.com/v1/responses
|
||||
func (provider *OpenAIProvider) WebSocketResponsesURL(key schemas.Key) string {
|
||||
base := provider.networkConfig.BaseURL
|
||||
base = strings.Replace(base, "https://", "wss://", 1)
|
||||
base = strings.Replace(base, "http://", "ws://", 1)
|
||||
return base + "/v1/responses"
|
||||
}
|
||||
|
||||
// WebSocketHeaders returns the headers required for the upstream WebSocket connection to OpenAI.
|
||||
func (provider *OpenAIProvider) WebSocketHeaders(key schemas.Key) map[string]string {
|
||||
headers := map[string]string{
|
||||
"Authorization": "Bearer " + key.Value.GetValue(),
|
||||
}
|
||||
for k, v := range provider.networkConfig.ExtraHeaders {
|
||||
if strings.EqualFold(k, "Authorization") {
|
||||
continue
|
||||
}
|
||||
headers[k] = v
|
||||
}
|
||||
return headers
|
||||
}
|
||||
Reference in New Issue
Block a user