first commit

This commit is contained in:
Beyhan Oğur
2026-04-26 21:52:23 +03:00
commit 880f412e2c
2662 changed files with 866266 additions and 0 deletions

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,84 @@
package anthropic_test
import (
"os"
"strings"
"testing"
"github.com/maximhq/bifrost/core/internal/llmtests"
"github.com/maximhq/bifrost/core/schemas"
)
// TestAnthropic runs the shared comprehensive llmtests suite against the
// Anthropic provider. The test is skipped when ANTHROPIC_API_KEY is unset or
// contains only whitespace.
func TestAnthropic(t *testing.T) {
	t.Parallel()

	apiKey := strings.TrimSpace(os.Getenv("ANTHROPIC_API_KEY"))
	if apiKey == "" {
		t.Skip("Skipping Anthropic tests because ANTHROPIC_API_KEY is not set")
	}

	client, ctx, cancel, setupErr := llmtests.SetupTest()
	if setupErr != nil {
		t.Fatalf("Error initializing test setup: %v", setupErr)
	}
	// Deferred calls run last-in-first-out: the client is shut down first,
	// then the context is cancelled.
	defer cancel()
	defer client.Shutdown()

	// Models tried when the primary chat model is unavailable.
	fallbacks := []schemas.Fallback{
		{Provider: schemas.Anthropic, Model: "claude-3-7-sonnet-20250219"},
		{Provider: schemas.Anthropic, Model: "claude-sonnet-4-20250514"},
	}

	// Scenario switches for the comprehensive suite; false entries are
	// explicitly disabled for this provider.
	scenarios := llmtests.TestScenarios{
		TextCompletion:               false, // Not supported
		SimpleChat:                   true,
		CompletionStream:             true,
		MultiTurnConversation:        true,
		ToolCalls:                    true,
		ToolCallsStreaming:           true,
		MultipleToolCalls:            true,
		MultipleToolCallsStreaming:   true,
		End2EndToolCalling:           true,
		AutomaticFunctionCall:        true,
		WebSearchTool:                true,
		ImageURL:                     true,
		ImageBase64:                  true,
		MultipleImages:               true,
		FileBase64:                   true,
		FileURL:                      true,
		CompleteEnd2End:              true,
		Embedding:                    false,
		Reasoning:                    true,
		PromptCaching:                true,
		ListModels:                   true,
		BatchCreate:                  true,
		BatchList:                    true,
		BatchRetrieve:                true,
		BatchCancel:                  true,
		BatchResults:                 true,
		FileUpload:                   true,
		FileList:                     true,
		FileRetrieve:                 true,
		FileDelete:                   true,
		FileContent:                  false,
		FileBatchInput:               false, // Anthropic batch API only supports inline requests, not file-based input
		CountTokens:                  true,
		StructuredOutputs:            true, // Structured outputs with nullable enum support
		PassthroughAPI:               true,
		Compaction:                   true,
		InterleavedThinking:          true,
		FastMode:                     false, // Enable when test API key has Opus 4.6 access
		EagerInputStreaming:          true,  // fine-grained-tool-streaming-2025-05-14 (GA on Anthropic)
		ServerToolsViaOpenAIEndpoint: true,  // web_search / web_fetch / code_execution via /v1/chat/completions
	}

	cfg := llmtests.ComprehensiveTestConfig{
		Provider:           schemas.Anthropic,
		ChatModel:          "claude-sonnet-4-5",
		Fallbacks:          fallbacks,
		VisionModel:        "claude-sonnet-4-5", // Same model supports vision
		ReasoningModel:     "claude-opus-4-5",
		PromptCachingModel: "claude-sonnet-4-20250514",
		PassthroughModel:   "claude-sonnet-4-5",
		Scenarios:          scenarios,
	}

	t.Run("AnthropicTests", func(t *testing.T) {
		llmtests.RunAllComprehensiveTests(t, client, ctx, cfg)
	})
}

View File

@@ -0,0 +1,359 @@
package anthropic
import (
"time"
"github.com/maximhq/bifrost/core/schemas"
)
// Anthropic Batch API Types
// AnthropicBatchRequestItem represents a single request in a batch: an
// identifier paired with a free-form parameter object.
type AnthropicBatchRequestItem struct {
// CustomID is serialized as "custom_id".
// NOTE(review): presumably chosen by the caller to correlate results — confirm.
CustomID string `json:"custom_id"`
// Params is an untyped JSON object serialized as "params".
// NOTE(review): presumably the per-request message parameters — confirm against the Anthropic API.
Params map[string]any `json:"params"`
}
// AnthropicBatchCreateRequest represents the request body for creating a batch.
type AnthropicBatchCreateRequest struct {
// Requests holds the individual batch entries, serialized as "requests".
Requests []AnthropicBatchRequestItem `json:"requests"`
}
// AnthropicBatchCancelRequest represents the request body for canceling a batch.
type AnthropicBatchCancelRequest struct {
// BatchID identifies the batch to cancel, serialized as "batch_id".
BatchID string `json:"batch_id"`
}
// AnthropicBatchRetrieveRequest represents the request body for retrieving a batch.
type AnthropicBatchRetrieveRequest struct {
// BatchID identifies the batch to retrieve, serialized as "batch_id".
BatchID string `json:"batch_id"`
}
// AnthropicBatchListRequest represents the request body for listing batches.
// Neither field carries omitempty, so both keys are always present when the
// struct is encoded.
type AnthropicBatchListRequest struct {
// PageToken is an optional pagination cursor serialized as "page_token".
// NOTE(review): a nil pointer presumably encodes as JSON null — confirm this is intended.
PageToken *string `json:"page_token"`
// PageSize is serialized as "page_size".
// NOTE(review): presumably the maximum number of batches per page — confirm.
PageSize int `json:"page_size"`
}
// AnthropicBatchResultsRequest represents the request body for retrieving batch results.
type AnthropicBatchResultsRequest struct {
// BatchID identifies the batch whose results are requested, serialized as "batch_id".
BatchID string `json:"batch_id"`
}
// AnthropicBatchResponse represents an Anthropic batch response.
// Timestamps are carried as strings; the converters in this file turn them
// into Unix seconds with parseAnthropicTimestamp and back with
// formatAnthropicTimestamp.
type AnthropicBatchResponse struct {
// ID is the batch identifier.
ID string `json:"id"`
// Type is the object type; converters in this file emit "message_batch".
Type string `json:"type"`
// ProcessingStatus is the raw status string, mapped to a Bifrost status by
// ToBifrostBatchStatus.
ProcessingStatus string `json:"processing_status"`
// RequestCounts is optional and omitted from JSON when nil.
RequestCounts *AnthropicBatchRequestCounts `json:"request_counts,omitempty"`
// EndedAt is optional; the retrieve converter maps it to CompletedAt.
EndedAt *string `json:"ended_at,omitempty"`
// CreatedAt and ExpiresAt are always serialized, even when empty.
CreatedAt string `json:"created_at"`
ExpiresAt string `json:"expires_at"`
// ArchivedAt is optional and omitted from JSON when nil.
ArchivedAt *string `json:"archived_at,omitempty"`
// CancelInitiatedAt is optional; the retrieve converter maps it to CancellingAt.
CancelInitiatedAt *string `json:"cancel_initiated_at,omitempty"`
// ResultsURL is optional and passed through unchanged by the converters.
ResultsURL *string `json:"results_url,omitempty"`
}
// AnthropicBatchRequestCounts represents the request counts for a batch.
// The Bifrost converters in this file compute the total as the sum of all
// five fields.
type AnthropicBatchRequestCounts struct {
// Processing is mapped to the Bifrost Pending count.
Processing int `json:"processing"`
// Succeeded is mapped to both the Bifrost Completed and Succeeded counts.
Succeeded int `json:"succeeded"`
// Errored is mapped to the Bifrost Failed count.
Errored int `json:"errored"`
// Canceled is mapped to the Bifrost Canceled count.
Canceled int `json:"canceled"`
// Expired is mapped to the Bifrost Expired count.
Expired int `json:"expired"`
}
// AnthropicBatchListResponse represents the response from listing batches.
type AnthropicBatchListResponse struct {
// Data holds one entry per listed batch.
Data []AnthropicBatchResponse `json:"data"`
// HasMore is serialized as "has_more".
// NOTE(review): presumably signals that further pages exist — confirm.
HasMore bool `json:"has_more"`
// FirstID and LastID are optional and omitted from JSON when nil.
// NOTE(review): presumably pagination cursors for the returned page — confirm.
FirstID *string `json:"first_id,omitempty"`
LastID *string `json:"last_id,omitempty"`
}
// AnthropicBatchResultItem represents a single result from a batch.
type AnthropicBatchResultItem struct {
// CustomID is serialized as "custom_id".
// NOTE(review): presumably echoes the CustomID of the originating request item — confirm.
CustomID string `json:"custom_id"`
// Result carries the outcome type plus either a message or an error.
Result AnthropicBatchResultData `json:"result"`
}
// AnthropicBatchResultData represents the result data for one batch entry.
type AnthropicBatchResultData struct {
// Type is the outcome discriminator.
Type string `json:"type"` // "succeeded", "errored", "expired", "canceled"
// Message is an untyped JSON object, omitted when empty.
// NOTE(review): presumably populated only for "succeeded" results — confirm.
Message map[string]interface{} `json:"message,omitempty"`
// Error is optional and omitted when nil.
// NOTE(review): presumably populated only for "errored" results — confirm.
Error *AnthropicBatchError `json:"error,omitempty"`
}
// AnthropicBatchError represents an error in batch results.
type AnthropicBatchError struct {
// Type is the error category string, serialized as "type".
Type string `json:"type"`
// Message is the error text, serialized as "message".
Message string `json:"message"`
}
// ToBifrostBatchStatus maps an Anthropic processing_status string onto the
// corresponding Bifrost batch status. The three known values ("in_progress",
// "canceling", "ended") map to their dedicated constants; any other value is
// passed through verbatim as a schemas.BatchStatus.
func ToBifrostBatchStatus(status string) schemas.BatchStatus {
	if status == "in_progress" {
		return schemas.BatchStatusInProgress
	}
	if status == "canceling" {
		return schemas.BatchStatusCancelling
	}
	if status == "ended" {
		return schemas.BatchStatusEnded
	}
	// Unknown statuses are preserved rather than rejected.
	return schemas.BatchStatus(status)
}
// parseAnthropicTimestamp turns an RFC 3339 timestamp string (with or without
// fractional seconds) into Unix seconds. An empty or unparseable input yields
// 0 instead of an error.
func parseAnthropicTimestamp(timestamp string) int64 {
	if len(timestamp) == 0 {
		return 0
	}
	if parsed, err := time.Parse(time.RFC3339Nano, timestamp); err == nil {
		return parsed.Unix()
	}
	// Parse failures are deliberately collapsed to the zero timestamp.
	return 0
}
// ToBifrostObjectType maps an Anthropic object type onto the Bifrost object
// type: "message_batch" becomes "batch"; every other value is returned
// unchanged.
func ToBifrostObjectType(anthropicType string) string {
	if anthropicType == "message_batch" {
		return "batch"
	}
	return anthropicType
}
// ToBifrostBatchCreateResponse converts Anthropic batch response to Bifrost batch create response.
//
// latency is reported in milliseconds in ExtraFields.Latency. rawRequest and
// rawResponse are attached to ExtraFields only when the matching sendBack
// flag is true.
//
// ExpiresAt is populated only when r.ExpiresAt parses to a positive Unix
// time, matching ToBifrostBatchRetrieveResponse. Leaving it nil for a missing
// or unparseable timestamp lets ToAnthropicBatchCreateResponse apply its
// fallback instead of emitting an empty expires_at string.
func (r *AnthropicBatchResponse) ToBifrostBatchCreateResponse(latency time.Duration, sendBackRawRequest bool, sendBackRawResponse bool, rawRequest any, rawResponse any) *schemas.BifrostBatchCreateResponse {
	resp := &schemas.BifrostBatchCreateResponse{
		ID:               r.ID,
		Object:           ToBifrostObjectType(r.Type),
		Status:           ToBifrostBatchStatus(r.ProcessingStatus),
		ProcessingStatus: &r.ProcessingStatus,
		ResultsURL:       r.ResultsURL,
		CreatedAt:        parseAnthropicTimestamp(r.CreatedAt),
		ExtraFields: schemas.BifrostResponseExtraFields{
			Latency: latency.Milliseconds(),
		},
	}

	// parseAnthropicTimestamp returns 0 for empty/invalid input; do not
	// surface that as a real expiry.
	if expiresAt := parseAnthropicTimestamp(r.ExpiresAt); expiresAt > 0 {
		resp.ExpiresAt = &expiresAt
	}

	if r.RequestCounts != nil {
		resp.RequestCounts = schemas.BatchRequestCounts{
			Total:     r.RequestCounts.Processing + r.RequestCounts.Succeeded + r.RequestCounts.Errored + r.RequestCounts.Canceled + r.RequestCounts.Expired,
			Completed: r.RequestCounts.Succeeded,
			Failed:    r.RequestCounts.Errored,
			Succeeded: r.RequestCounts.Succeeded,
			Expired:   r.RequestCounts.Expired,
			Canceled:  r.RequestCounts.Canceled,
			Pending:   r.RequestCounts.Processing,
		}
	}

	if sendBackRawRequest {
		resp.ExtraFields.RawRequest = rawRequest
	}
	if sendBackRawResponse {
		resp.ExtraFields.RawResponse = rawResponse
	}
	return resp
}
// ToBifrostBatchRetrieveResponse builds the Bifrost retrieve response for this
// Anthropic batch.
//
// Optional Anthropic timestamps (ended_at, archived_at, cancel_initiated_at)
// are converted to Unix seconds whenever they are present; expires_at is only
// set when it parses to a positive value. latency is reported in
// milliseconds, and the raw request/response are attached only when the
// matching sendBack flag is true.
func (r *AnthropicBatchResponse) ToBifrostBatchRetrieveResponse(latency time.Duration, sendBackRawRequest bool, sendBackRawResponse bool, rawRequest interface{}, rawResponse interface{}) *schemas.BifrostBatchRetrieveResponse {
	// optionalUnix converts a present timestamp and keeps an absent one nil.
	optionalUnix := func(ts *string) *int64 {
		if ts == nil {
			return nil
		}
		unix := parseAnthropicTimestamp(*ts)
		return &unix
	}

	out := &schemas.BifrostBatchRetrieveResponse{
		ID:               r.ID,
		Object:           ToBifrostObjectType(r.Type),
		Status:           ToBifrostBatchStatus(r.ProcessingStatus),
		ProcessingStatus: &r.ProcessingStatus,
		ResultsURL:       r.ResultsURL,
		CreatedAt:        parseAnthropicTimestamp(r.CreatedAt),
		CompletedAt:      optionalUnix(r.EndedAt),
		ArchivedAt:       optionalUnix(r.ArchivedAt),
		CancellingAt:     optionalUnix(r.CancelInitiatedAt),
		ExtraFields: schemas.BifrostResponseExtraFields{
			Latency: latency.Milliseconds(),
		},
	}

	// A zero result means the timestamp was empty or unparseable.
	if unix := parseAnthropicTimestamp(r.ExpiresAt); unix > 0 {
		out.ExpiresAt = &unix
	}

	if counts := r.RequestCounts; counts != nil {
		out.RequestCounts = schemas.BatchRequestCounts{
			Total:     counts.Processing + counts.Succeeded + counts.Errored + counts.Canceled + counts.Expired,
			Completed: counts.Succeeded,
			Failed:    counts.Errored,
			Succeeded: counts.Succeeded,
			Expired:   counts.Expired,
			Canceled:  counts.Canceled,
			Pending:   counts.Processing,
		}
	}

	if sendBackRawRequest {
		out.ExtraFields.RawRequest = rawRequest
	}
	if sendBackRawResponse {
		out.ExtraFields.RawResponse = rawResponse
	}
	return out
}
// ToAnthropicBatchCreateResponse renders a Bifrost batch create response in
// Anthropic's message_batch shape. Request counts are included only when the
// Bifrost total is greater than zero.
func ToAnthropicBatchCreateResponse(resp *schemas.BifrostBatchCreateResponse) *AnthropicBatchResponse {
	var expiresAt string
	if resp.ExpiresAt == nil {
		// This is a fallback for worst case scenario where expires_at is not available
		// Which is never expected to happen, but just in case.
		expiresAt = formatAnthropicTimestamp(time.Now().Add(24 * time.Hour).Unix())
	} else {
		expiresAt = formatAnthropicTimestamp(*resp.ExpiresAt)
	}

	out := &AnthropicBatchResponse{
		ID:               resp.ID,
		Type:             "message_batch",
		ProcessingStatus: toAnthropicProcessingStatus(resp.Status),
		CreatedAt:        formatAnthropicTimestamp(resp.CreatedAt),
		ExpiresAt:        expiresAt,
		ResultsURL:       resp.ResultsURL,
	}

	if counts := resp.RequestCounts; counts.Total > 0 {
		out.RequestCounts = &AnthropicBatchRequestCounts{
			Processing: counts.Pending,
			Succeeded:  counts.Succeeded,
			Errored:    counts.Failed,
			Canceled:   counts.Canceled,
			Expired:    counts.Expired,
		}
	}
	return out
}
// ToAnthropicBatchListResponse renders a Bifrost batch list response in
// Anthropic format. Each entry is converted with
// ToAnthropicBatchRetrieveResponse; the pagination fields are copied as-is.
func ToAnthropicBatchListResponse(resp *schemas.BifrostBatchListResponse) *AnthropicBatchListResponse {
	batches := make([]AnthropicBatchResponse, 0, len(resp.Data))
	for i := range resp.Data {
		batches = append(batches, *ToAnthropicBatchRetrieveResponse(&resp.Data[i]))
	}
	return &AnthropicBatchListResponse{
		Data:    batches,
		HasMore: resp.HasMore,
		FirstID: resp.FirstID,
		LastID:  resp.LastID,
	}
}
// ToAnthropicBatchRetrieveResponse renders a Bifrost batch retrieve response
// in Anthropic's message_batch shape. Optional Bifrost timestamps are
// formatted only when present; when ExpiresAt is nil the Anthropic expires_at
// stays the empty string. Request counts are included only when the Bifrost
// total is greater than zero.
func ToAnthropicBatchRetrieveResponse(resp *schemas.BifrostBatchRetrieveResponse) *AnthropicBatchResponse {
	// optionalISO formats a present Unix timestamp and keeps an absent one nil.
	optionalISO := func(unix *int64) *string {
		if unix == nil {
			return nil
		}
		formatted := formatAnthropicTimestamp(*unix)
		return &formatted
	}

	out := &AnthropicBatchResponse{
		ID:                resp.ID,
		Type:              "message_batch",
		ProcessingStatus:  toAnthropicProcessingStatus(resp.Status),
		CreatedAt:         formatAnthropicTimestamp(resp.CreatedAt),
		ResultsURL:        resp.ResultsURL,
		EndedAt:           optionalISO(resp.CompletedAt),
		ArchivedAt:        optionalISO(resp.ArchivedAt),
		CancelInitiatedAt: optionalISO(resp.CancellingAt),
	}

	if resp.ExpiresAt != nil {
		out.ExpiresAt = formatAnthropicTimestamp(*resp.ExpiresAt)
	}

	if counts := resp.RequestCounts; counts.Total > 0 {
		out.RequestCounts = &AnthropicBatchRequestCounts{
			Processing: counts.Pending,
			Succeeded:  counts.Succeeded,
			Errored:    counts.Failed,
			Canceled:   counts.Canceled,
			Expired:    counts.Expired,
		}
	}
	return out
}
// ToAnthropicBatchCancelResponse renders a Bifrost batch cancel response in
// Anthropic's message_batch shape. Only the ID, status, cancel-initiated
// timestamp (when present) and request counts (when the total is greater than
// zero) are populated; created_at and expires_at are left empty.
func ToAnthropicBatchCancelResponse(resp *schemas.BifrostBatchCancelResponse) *AnthropicBatchResponse {
	out := &AnthropicBatchResponse{
		Type:             "message_batch",
		ID:               resp.ID,
		ProcessingStatus: toAnthropicProcessingStatus(resp.Status),
	}

	if at := resp.CancellingAt; at != nil {
		formatted := formatAnthropicTimestamp(*at)
		out.CancelInitiatedAt = &formatted
	}

	if counts := resp.RequestCounts; counts.Total > 0 {
		out.RequestCounts = &AnthropicBatchRequestCounts{
			Processing: counts.Pending,
			Succeeded:  counts.Succeeded,
			Errored:    counts.Failed,
			Canceled:   counts.Canceled,
			Expired:    counts.Expired,
		}
	}
	return out
}
// toAnthropicProcessingStatus maps a Bifrost batch status onto Anthropic's
// processing_status vocabulary:
//   - in-progress and validating -> "in_progress"
//   - cancelling                 -> "canceling"
//   - ended, completed, cancelled -> "ended"
//
// Any other status is returned as its raw string value.
func toAnthropicProcessingStatus(status schemas.BatchStatus) string {
	switch status {
	case schemas.BatchStatusInProgress, schemas.BatchStatusValidating:
		return "in_progress"
	case schemas.BatchStatusCancelling:
		return "canceling"
	case schemas.BatchStatusEnded, schemas.BatchStatusCompleted, schemas.BatchStatusCancelled:
		return "ended"
	}
	return string(status)
}
// formatAnthropicTimestamp renders Unix seconds as an RFC 3339 timestamp in
// UTC. A zero input is treated as "no timestamp" and yields the empty string.
func formatAnthropicTimestamp(unixTime int64) string {
	if unixTime != 0 {
		return time.Unix(unixTime, 0).UTC().Format(time.RFC3339)
	}
	return ""
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,366 @@
package anthropic
import (
"encoding/json"
"testing"
"github.com/bytedance/sonic"
"github.com/maximhq/bifrost/core/schemas"
)
// TestChatTool_ServerToolRoundTrip verifies that every Anthropic server-tool
// variant survives Marshal/Unmarshal through the neutral ChatTool schema.
// This locks in the fix for the user-reported bug where a raw JSON tool like
// {"type":"web_search_20260209","name":"web_search","max_uses":5} was being
// dropped at the neutral-schema layer because ChatTool had no slots for the
// server-tool metadata.
//
// Each case is decoded with sonic, checked, re-encoded with
// schemas.MarshalSorted, decoded again, and checked a second time.
func TestChatTool_ServerToolRoundTrip(t *testing.T) {
	// Expected values for the variant-specific fields; they mirror the
	// numbers embedded in the raw JSON fixtures below.
	const (
		wantMaxUses          = 5
		wantDisplayWidthPx   = 1280
		wantDisplayHeightPx  = 800
		wantMaxCharacters    = 16000
		wantMaxContentTokens = 32000
	)

	cases := []struct {
		name string
		raw  string
	}{
		{
			name: "web_search_20260209",
			raw:  `{"type":"web_search_20260209","name":"web_search","max_uses":5,"allowed_callers":["direct"]}`,
		},
		{
			name: "web_search_with_domains",
			raw:  `{"type":"web_search_20250305","name":"web_search","allowed_domains":["example.com","docs.example.com"]}`,
		},
		{
			name: "web_search_with_user_location",
			raw:  `{"type":"web_search_20250305","name":"web_search","user_location":{"type":"approximate","city":"San Francisco","country":"US","timezone":"America/Los_Angeles"}}`,
		},
		{
			name: "web_fetch_20260309",
			raw:  `{"type":"web_fetch_20260309","name":"web_fetch","max_uses":5,"max_content_tokens":32000,"citations":{"enabled":true},"use_cache":true}`,
		},
		{
			name: "computer_20251124",
			raw:  `{"type":"computer_20251124","name":"computer","display_width_px":1280,"display_height_px":800,"display_number":1,"enable_zoom":true}`,
		},
		{
			name: "text_editor_20250728",
			raw:  `{"type":"text_editor_20250728","name":"str_replace_based_edit_tool","max_characters":16000}`,
		},
		{
			name: "bash_20250124",
			raw:  `{"type":"bash_20250124","name":"bash"}`,
		},
		{
			name: "memory_20250818",
			raw:  `{"type":"memory_20250818","name":"memory"}`,
		},
		{
			name: "code_execution_20250825",
			raw:  `{"type":"code_execution_20250825","name":"code_execution"}`,
		},
		{
			name: "tool_search_tool_bm25",
			raw:  `{"type":"tool_search_tool_bm25","name":"tool_search_tool_bm25"}`,
		},
		{
			name: "mcp_toolset",
			raw:  `{"type":"mcp_toolset","name":"my_mcp","mcp_server_name":"notion","configs":{"search":{"enabled":true}}}`,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			// Variant-specific field assertions. Invoked twice — once after
			// initial decode, once after round-trip — so that a regression in
			// MarshalSorted that silently drops any variant-specific field
			// fails this test instead of sneaking through.
			assertVariantFields := func(label string, tl schemas.ChatTool) {
				t.Helper()
				switch tc.name {
				case "web_search_20260209":
					if tl.MaxUses == nil || *tl.MaxUses != wantMaxUses {
						t.Errorf("%s: MaxUses not preserved, got %v", label, tl.MaxUses)
					}
					if len(tl.AllowedCallers) != 1 || tl.AllowedCallers[0] != "direct" {
						t.Errorf("%s: AllowedCallers not preserved, got %v", label, tl.AllowedCallers)
					}
				case "web_fetch_20260309":
					if tl.MaxContentTokens == nil || *tl.MaxContentTokens != wantMaxContentTokens {
						t.Errorf("%s: MaxContentTokens not preserved, got %v", label, tl.MaxContentTokens)
					}
					if tl.Citations == nil || tl.Citations.Enabled == nil || !*tl.Citations.Enabled {
						t.Errorf("%s: Citations not preserved, got %v", label, tl.Citations)
					}
					if tl.UseCache == nil || !*tl.UseCache {
						t.Errorf("%s: UseCache not preserved", label)
					}
				case "computer_20251124":
					if tl.DisplayWidthPx == nil || *tl.DisplayWidthPx != wantDisplayWidthPx {
						t.Errorf("%s: DisplayWidthPx not preserved, got %v", label, tl.DisplayWidthPx)
					}
					if tl.DisplayHeightPx == nil || *tl.DisplayHeightPx != wantDisplayHeightPx {
						t.Errorf("%s: DisplayHeightPx not preserved, got %v", label, tl.DisplayHeightPx)
					}
				case "text_editor_20250728":
					if tl.MaxCharacters == nil || *tl.MaxCharacters != wantMaxCharacters {
						t.Errorf("%s: MaxCharacters not preserved, got %v", label, tl.MaxCharacters)
					}
				case "mcp_toolset":
					if tl.MCPServerName != "notion" {
						t.Errorf("%s: MCPServerName not preserved, got %q", label, tl.MCPServerName)
					}
					if len(tl.Configs) != 1 {
						t.Errorf("%s: Configs not preserved, got %v", label, tl.Configs)
					}
				}
			}

			var tool schemas.ChatTool
			if err := sonic.Unmarshal([]byte(tc.raw), &tool); err != nil {
				t.Fatalf("unmarshal failed: %v", err)
			}
			if string(tool.Type) == "" {
				t.Errorf("Type should be preserved, got empty")
			}
			if tool.Name == "" {
				t.Errorf("Name should be preserved, got empty")
			}
			assertVariantFields("first decode", tool)

			// Re-marshal and re-decode — all preserved fields should survive round trip.
			out, err := schemas.MarshalSorted(tool)
			if err != nil {
				t.Fatalf("marshal failed: %v", err)
			}
			var tool2 schemas.ChatTool
			if err := sonic.Unmarshal(out, &tool2); err != nil {
				t.Fatalf("second unmarshal failed: %v\njson: %s", err, string(out))
			}
			if tool.Name != tool2.Name || tool.Type != tool2.Type {
				t.Errorf("round-trip mismatch\n in: %s\n out: %s", tc.raw, string(out))
			}
			assertVariantFields("round trip", tool2)
		})
	}
}
// TestToAnthropicChatRequest_ServerTools verifies every ChatTool server-tool
// shape converts correctly through ToAnthropicChatRequest.
//
// Each case decodes a raw JSON tool into the neutral ChatTool, wraps it in a
// minimal chat request, converts it, and checks the single resulting tool's
// name, type and variant-specific payload.
func TestToAnthropicChatRequest_ServerTools(t *testing.T) {
	// mk builds a one-message chat request carrying the decoded tool. It takes
	// the calling (sub)test's t so that a setup failure aborts that subtest:
	// FailNow must run on the goroutine of the test it belongs to, and each
	// t.Run body executes on its own goroutine.
	mk := func(t *testing.T, rawTool string) *schemas.BifrostChatRequest {
		t.Helper()
		var tool schemas.ChatTool
		if err := sonic.Unmarshal([]byte(rawTool), &tool); err != nil {
			t.Fatalf("test setup: %v", err)
		}
		return &schemas.BifrostChatRequest{
			Provider: schemas.Anthropic,
			Model:    "claude-sonnet-4-6",
			Input:    []schemas.ChatMessage{{Role: schemas.ChatMessageRoleUser, Content: &schemas.ChatMessageContent{ContentStr: schemas.Ptr("hi")}}},
			Params:   &schemas.ChatParameters{Tools: []schemas.ChatTool{tool}},
		}
	}

	// check lists the expectations for one converted tool. For mcp_toolset
	// only the MCPToolset payload is checked; name and type are skipped.
	type check struct {
		expectName       string
		expectType       AnthropicToolType
		expectWebSearch  bool
		expectWebFetch   bool
		expectComputer   bool
		expectTextEditor bool
		expectMCPToolset bool
	}

	cases := []struct {
		name string
		raw  string
		want check
	}{
		{
			name: "web_search",
			raw:  `{"type":"web_search_20260209","name":"web_search","max_uses":5}`,
			want: check{expectName: "web_search", expectType: "web_search_20260209", expectWebSearch: true},
		},
		{
			name: "web_fetch",
			raw:  `{"type":"web_fetch_20260309","name":"web_fetch","max_uses":3,"use_cache":true}`,
			want: check{expectName: "web_fetch", expectType: "web_fetch_20260309", expectWebFetch: true},
		},
		{
			name: "computer_20251124",
			raw:  `{"type":"computer_20251124","name":"computer","display_width_px":1280,"display_height_px":800}`,
			want: check{expectName: "computer", expectType: "computer_20251124", expectComputer: true},
		},
		{
			name: "text_editor_20250728",
			raw:  `{"type":"text_editor_20250728","name":"str_replace_based_edit_tool","max_characters":16000}`,
			want: check{expectName: "str_replace_based_edit_tool", expectType: "text_editor_20250728", expectTextEditor: true},
		},
		{
			name: "bash_20250124",
			raw:  `{"type":"bash_20250124","name":"bash"}`,
			want: check{expectName: "bash", expectType: "bash_20250124"},
		},
		{
			name: "mcp_toolset",
			raw:  `{"type":"mcp_toolset","name":"notion","mcp_server_name":"notion"}`,
			want: check{expectMCPToolset: true},
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			req := mk(t, tc.raw)
			out, err := ToAnthropicChatRequest(nil, req)
			if err != nil {
				t.Fatalf("conversion failed: %v", err)
			}
			if len(out.Tools) != 1 {
				t.Fatalf("expected 1 tool, got %d (raw: %s)", len(out.Tools), tc.raw)
			}
			at := out.Tools[0]

			if tc.want.expectMCPToolset {
				if at.MCPToolset == nil {
					t.Errorf("expected MCPToolset to be set")
				}
				return
			}

			if at.Name != tc.want.expectName {
				t.Errorf("Name: got %q want %q", at.Name, tc.want.expectName)
			}
			if at.Type == nil || *at.Type != tc.want.expectType {
				t.Errorf("Type: got %v want %q", at.Type, tc.want.expectType)
			}
			if tc.want.expectWebSearch && at.AnthropicToolWebSearch == nil {
				t.Errorf("expected AnthropicToolWebSearch populated")
			}
			if tc.want.expectWebFetch && at.AnthropicToolWebFetch == nil {
				t.Errorf("expected AnthropicToolWebFetch populated")
			}
			if tc.want.expectComputer && at.AnthropicToolComputerUse == nil {
				t.Errorf("expected AnthropicToolComputerUse populated")
			}
			if tc.want.expectTextEditor && at.AnthropicToolTextEditor == nil {
				t.Errorf("expected AnthropicToolTextEditor populated")
			}
		})
	}
}
// TestToBifrostResponsesRequest_MCPToolsetPreservesAnthropicFlags verifies
// that when an Anthropic request carries an mcp_toolset tool with the four
// Anthropic-native flags (DeferLoading, AllowedCallers, InputExamples,
// EagerInputStreaming), those flags survive the inbound conversion into the
// neutral ResponsesTool on the mcp_servers merge path. Before the fix, the
// merge path only applied MCP configs (allowlist/cache-control) and dropped
// the flags because convertAnthropicToolToBifrost skips mcp_toolset entries.
func TestToBifrostResponsesRequest_MCPToolsetPreservesAnthropicFlags(t *testing.T) {
	// Shape note: AnthropicTool.Type is pointer-to-enum and is left nil for
	// mcp_toolset; the "mcp_toolset" discriminator lives on MCPToolset.Type.
	req := &AnthropicMessageRequest{
		Model: "claude-sonnet-4-6",
		Tools: []AnthropicTool{
			{
				Name:                "notion",
				DeferLoading:        schemas.Ptr(true),
				AllowedCallers:      []string{"direct", "agent"},
				EagerInputStreaming: schemas.Ptr(false),
				InputExamples: []AnthropicToolInputExample{
					{Input: json.RawMessage(`{"q":"hello"}`), Description: schemas.Ptr("basic")},
				},
				MCPToolset: &AnthropicMCPToolsetTool{
					Type:          "mcp_toolset",
					MCPServerName: "notion",
					DefaultConfig: &AnthropicMCPToolsetConfig{Enabled: schemas.Ptr(true)},
				},
			},
		},
		MCPServers: []AnthropicMCPServerV2{
			{Type: "url", URL: "https://mcp.example.com", Name: "notion"},
		},
	}

	got := req.ToBifrostResponsesRequest(nil)
	if got == nil || got.Params == nil {
		t.Fatalf("ToBifrostResponsesRequest returned nil params")
	}

	// The mcp_toolset tool should have been dropped by convertAnthropicToolToBifrost
	// and re-created on the mcp_servers merge path — end result: exactly one tool,
	// of type mcp, carrying the Anthropic flags we set.
	if len(got.Params.Tools) != 1 {
		t.Fatalf("expected 1 mcp tool after merge, got %d", len(got.Params.Tools))
	}
	mcp := got.Params.Tools[0]
	if mcp.Type != schemas.ResponsesToolTypeMCP {
		t.Errorf("expected MCP tool, got type=%q", mcp.Type)
	}
	if mcp.DeferLoading == nil || !*mcp.DeferLoading {
		t.Errorf("DeferLoading dropped on mcp_toolset merge path")
	}
	if len(mcp.AllowedCallers) != 2 || mcp.AllowedCallers[0] != "direct" {
		t.Errorf("AllowedCallers dropped on mcp_toolset merge path, got %v", mcp.AllowedCallers)
	}
	if len(mcp.InputExamples) != 1 {
		t.Errorf("InputExamples dropped on mcp_toolset merge path, got len=%d", len(mcp.InputExamples))
	}
	// The flag was set to an explicit false, so it must be present AND false.
	if mcp.EagerInputStreaming == nil || *mcp.EagerInputStreaming {
		t.Errorf("EagerInputStreaming dropped on mcp_toolset merge path, got %v", mcp.EagerInputStreaming)
	}
}
// TestToAnthropicChatRequest_ServerTools_ReproUserBug replays the request
// shape from the reported curl — a web_search_20260209 tool carrying max_uses
// and allowed_callers — and checks that the tool is still present, with its
// name and max_uses intact, after ToAnthropicChatRequest (it used to be
// silently dropped).
//
// Only the model and tools from the raw payload feed the conversion; the
// chat input is a fixed "hi" user message.
func TestToAnthropicChatRequest_ServerTools_ReproUserBug(t *testing.T) {
	payload := []byte(`{
"model":"claude-sonnet-4-6",
"messages":[{"role":"user","content":"What is the weather in SF?"}],
"tools":[{"name":"web_search","type":"web_search_20260209","max_uses":5,"allowed_callers":["direct"]}]
}`)

	// Decode via the neutral schema, mirroring what the OpenAI endpoint does.
	var decoded struct {
		Model    string             `json:"model"`
		Messages []json.RawMessage  `json:"messages"`
		Tools    []schemas.ChatTool `json:"tools"`
	}
	if err := sonic.Unmarshal(payload, &decoded); err != nil {
		t.Fatalf("outer unmarshal: %v", err)
	}
	if n := len(decoded.Tools); n != 1 {
		t.Fatalf("setup: expected 1 tool in raw JSON, got %d", n)
	}
	if decoded.Tools[0].Name == "" {
		t.Errorf("Name lost at neutral-schema decode (was the bug). Got: %+v", decoded.Tools[0])
	}
	if decoded.Tools[0].MaxUses == nil {
		t.Errorf("MaxUses lost at neutral-schema decode (was the bug)")
	}

	userMsg := schemas.ChatMessage{
		Role:    schemas.ChatMessageRoleUser,
		Content: &schemas.ChatMessageContent{ContentStr: schemas.Ptr("hi")},
	}
	chatReq := &schemas.BifrostChatRequest{
		Provider: schemas.Anthropic,
		Model:    decoded.Model,
		Input:    []schemas.ChatMessage{userMsg},
		Params:   &schemas.ChatParameters{Tools: decoded.Tools},
	}

	converted, err := ToAnthropicChatRequest(nil, chatReq)
	if err != nil {
		t.Fatalf("conversion failed: %v", err)
	}
	if n := len(converted.Tools); n != 1 {
		t.Fatalf("repro bug: expected 1 tool after conversion, got %d (tools array was empty — this was the bug)", n)
	}

	if got := converted.Tools[0].Name; got != "web_search" {
		t.Errorf("tool Name: got %q, want %q", got, "web_search")
	}
	search := converted.Tools[0].AnthropicToolWebSearch
	if !(search != nil && search.MaxUses != nil && *search.MaxUses == 5) {
		t.Errorf("tool max_uses lost: %+v", converted.Tools[0])
	}
}

View File

@@ -0,0 +1,752 @@
package anthropic
import (
"encoding/json"
"strings"
"testing"
"github.com/maximhq/bifrost/core/schemas"
)
// TestToAnthropicChatRequest_PreservesPropertyOrder checks that the top-level
// property keys of a function tool's parameter schema come out of
// ToAnthropicChatRequest in the same order they were inserted.
func TestToAnthropicChatRequest_PreservesPropertyOrder(t *testing.T) {
	wantOrder := []string{"chain_of_thought", "answer", "citations", "is_unanswered"}

	// Properties are inserted in exactly the order listed in wantOrder.
	properties := schemas.NewOrderedMapFromPairs(
		schemas.KV("chain_of_thought", schemas.NewOrderedMapFromPairs(
			schemas.KV("type", "string"),
			schemas.KV("description", "Reasoning steps"),
		)),
		schemas.KV("answer", schemas.NewOrderedMapFromPairs(
			schemas.KV("type", "string"),
			schemas.KV("description", "The answer"),
		)),
		schemas.KV("citations", schemas.NewOrderedMapFromPairs(
			schemas.KV("type", "array"),
		)),
		schemas.KV("is_unanswered", schemas.NewOrderedMapFromPairs(
			schemas.KV("type", "boolean"),
		)),
	)
	toolParams := &schemas.ToolFunctionParameters{
		Type:       "object",
		Properties: properties,
		Required:   []string{"answer", "is_unanswered"},
	}

	userMsg := schemas.ChatMessage{
		Role:    schemas.ChatMessageRoleUser,
		Content: &schemas.ChatMessageContent{ContentStr: schemas.Ptr("test")},
	}
	tool := schemas.ChatTool{
		Type: schemas.ChatToolTypeFunction,
		Function: &schemas.ChatToolFunction{
			Name:        "AnswerResponseModel",
			Description: schemas.Ptr("Extract answer"),
			Parameters:  toolParams,
		},
	}
	req := &schemas.BifrostChatRequest{
		Provider: schemas.Anthropic,
		Model:    "claude-sonnet-4-20250514",
		Input:    []schemas.ChatMessage{userMsg},
		Params:   &schemas.ChatParameters{Tools: []schemas.ChatTool{tool}},
	}

	ctx, cancel := schemas.NewBifrostContextWithCancel(nil)
	defer cancel()

	converted, err := ToAnthropicChatRequest(ctx, req)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if len(converted.Tools) == 0 {
		t.Fatal("expected at least one tool")
	}
	schema := converted.Tools[0].InputSchema
	if schema == nil {
		t.Fatal("expected InputSchema to be non-nil")
	}

	// CoT: property order preserved
	gotOrder := schema.Properties.Keys()
	if len(gotOrder) != len(wantOrder) {
		t.Fatalf("expected %d properties, got %d: %v", len(wantOrder), len(gotOrder), gotOrder)
	}
	for idx := range wantOrder {
		if gotOrder[idx] != wantOrder[idx] {
			t.Errorf("property %d: expected %q, got %q (full order: %v)", idx, wantOrder[idx], gotOrder[idx], gotOrder)
		}
	}
}
// TestToAnthropicChatRequest_CachingDeterminism converts two requests whose
// tool schemas differ only in the key order inside each property definition
// ("type" before "description" versus the reverse) and asserts that the
// marshalled InputSchema JSON is byte-identical for both.
func TestToAnthropicChatRequest_CachingDeterminism(t *testing.T) {
	// Version A: type before description
	typeFirst := schemas.NewOrderedMapFromPairs(
		schemas.KV("reasoning", schemas.NewOrderedMapFromPairs(
			schemas.KV("type", "string"),
			schemas.KV("description", "Step by step"),
		)),
		schemas.KV("answer", schemas.NewOrderedMapFromPairs(
			schemas.KV("type", "string"),
			schemas.KV("description", "Final answer"),
		)),
	)
	// Version B: description before type
	descriptionFirst := schemas.NewOrderedMapFromPairs(
		schemas.KV("reasoning", schemas.NewOrderedMapFromPairs(
			schemas.KV("description", "Step by step"),
			schemas.KV("type", "string"),
		)),
		schemas.KV("answer", schemas.NewOrderedMapFromPairs(
			schemas.KV("description", "Final answer"),
			schemas.KV("type", "string"),
		)),
	)

	// buildRequest wraps a property map in a single-tool, single-message request.
	buildRequest := func(props *schemas.OrderedMap) *schemas.BifrostChatRequest {
		tool := schemas.ChatTool{
			Type: schemas.ChatToolTypeFunction,
			Function: &schemas.ChatToolFunction{
				Name: "test",
				Parameters: &schemas.ToolFunctionParameters{
					Type:       "object",
					Properties: props,
				},
			},
		}
		userMsg := schemas.ChatMessage{
			Role:    schemas.ChatMessageRoleUser,
			Content: &schemas.ChatMessageContent{ContentStr: new("test")},
		}
		return &schemas.BifrostChatRequest{
			Provider: schemas.Anthropic,
			Model:    "claude-sonnet-4-20250514",
			Input:    []schemas.ChatMessage{userMsg},
			Params:   &schemas.ChatParameters{Tools: []schemas.ChatTool{tool}},
		}
	}

	ctx, cancel := schemas.NewBifrostContextWithCancel(nil)
	defer cancel()

	// Both conversions run before either result is marshalled.
	convertedA, err := ToAnthropicChatRequest(ctx, buildRequest(typeFirst))
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	convertedB, err := ToAnthropicChatRequest(ctx, buildRequest(descriptionFirst))
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	encodedA, err := schemas.Marshal(convertedA.Tools[0].InputSchema)
	if err != nil {
		t.Fatalf("failed to marshal params A: %v", err)
	}
	encodedB, err := schemas.Marshal(convertedB.Tools[0].InputSchema)
	if err != nil {
		t.Fatalf("failed to marshal params B: %v", err)
	}

	if a, b := string(encodedA), string(encodedB); a != b {
		t.Errorf("caching broken: same schema produced different JSON\nA: %s\nB: %s", a, b)
	}
}
// TestToAnthropicChatRequest_NestedProperties_Preserved checks that
// ToAnthropicChatRequest keeps property insertion order both at the top level
// of a tool's parameter schema and inside a nested "properties" object, and
// that the nested values remain *schemas.OrderedMap after conversion.
func TestToAnthropicChatRequest_NestedProperties_Preserved(t *testing.T) {
	params := &schemas.ToolFunctionParameters{
		Type: "object",
		Properties: schemas.NewOrderedMapFromPairs(
			schemas.KV("output", schemas.NewOrderedMapFromPairs(
				schemas.KV("type", "object"),
				schemas.KV("properties", schemas.NewOrderedMapFromPairs(
					schemas.KV("verdict", schemas.NewOrderedMapFromPairs(schemas.KV("type", "string"))),
					schemas.KV("score", schemas.NewOrderedMapFromPairs(schemas.KV("type", "number"))),
					schemas.KV("explanation", schemas.NewOrderedMapFromPairs(schemas.KV("type", "string"))),
				)),
			)),
			schemas.KV("reasoning", schemas.NewOrderedMapFromPairs(schemas.KV("type", "string"))),
		),
	}
	bifrostReq := &schemas.BifrostChatRequest{
		Provider: schemas.Anthropic,
		Model:    "claude-sonnet-4-20250514",
		Input: []schemas.ChatMessage{{
			Role:    schemas.ChatMessageRoleUser,
			Content: &schemas.ChatMessageContent{ContentStr: schemas.Ptr("test")},
		}},
		Params: &schemas.ChatParameters{
			Tools: []schemas.ChatTool{{
				Type: schemas.ChatToolTypeFunction,
				Function: &schemas.ChatToolFunction{
					Name:       "nested_tool",
					Parameters: params,
				},
			}},
		},
	}

	ctx, cancel := schemas.NewBifrostContextWithCancel(nil)
	defer cancel()

	result, err := ToAnthropicChatRequest(ctx, bifrostReq)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if len(result.Tools) == 0 {
		t.Fatal("expected at least one tool")
	}
	inputSchema := result.Tools[0].InputSchema
	// Fail with a clear message instead of panicking on the dereference
	// below if the conversion produced no schema.
	if inputSchema == nil {
		t.Fatal("expected InputSchema to be non-nil")
	}

	// CoT: top-level property order preserved
	keys := inputSchema.Properties.Keys()
	if len(keys) != 2 || keys[0] != "output" || keys[1] != "reasoning" {
		t.Errorf("expected top-level property order [output, reasoning], got %v", keys)
	}

	// CoT: nested property order preserved
	output, ok := inputSchema.Properties.Get("output")
	if !ok {
		t.Fatal("expected output property")
	}
	outputOM, ok := output.(*schemas.OrderedMap)
	if !ok {
		t.Fatalf("expected output to be *schemas.OrderedMap, got %T", output)
	}
	nestedProps, ok := outputOM.Get("properties")
	if !ok {
		t.Fatal("expected nested properties in output")
	}
	nestedPropsOM, ok := nestedProps.(*schemas.OrderedMap)
	if !ok {
		t.Fatalf("expected nested properties to be *schemas.OrderedMap, got %T", nestedProps)
	}
	nestedKeys := nestedPropsOM.Keys()
	if len(nestedKeys) != 3 || nestedKeys[0] != "verdict" || nestedKeys[1] != "score" || nestedKeys[2] != "explanation" {
		t.Errorf("expected nested property order [verdict, score, explanation], got %v", nestedKeys)
	}
}
// TestToAnthropicChatRequest_ToolInputKeyOrderPreservation verifies that tool_use input
// arguments preserve the client's original key ordering after conversion to Anthropic format.
// This is critical for prompt caching, which relies on exact byte-for-byte prefix matching.
// The test uses multiple parallel tool calls in a single assistant message — each with
// its own key ordering — matching real-world Claude Code usage patterns.
func TestToAnthropicChatRequest_ToolInputKeyOrderPreservation(t *testing.T) {
	bifrostReq := &schemas.BifrostChatRequest{
		Provider: schemas.Anthropic,
		Model:    "claude-sonnet-4-20250514",
		Input: []schemas.ChatMessage{
			{
				Role:    schemas.ChatMessageRoleUser,
				Content: &schemas.ChatMessageContent{ContentStr: schemas.Ptr("test")},
			},
			{
				// Multiple parallel tool calls with different key orderings per block
				Role: schemas.ChatMessageRoleAssistant,
				ChatAssistantMessage: &schemas.ChatAssistantMessage{
					ToolCalls: []schemas.ChatAssistantMessageToolCall{
						{
							Index: 0,
							Type:  schemas.Ptr("function"),
							ID:    schemas.Ptr("toolu_vrtx_013t7gabfKz98BKpdwrnS6LP"),
							Function: schemas.ChatAssistantMessageToolCallFunction{
								Name:      schemas.Ptr("bash"),
								Arguments: `{"description":"Find references to auth_injector quickly","timeout":30000,"command":"grep -r \"auth_injector\" . --include=\"Makefile\" -l 2>/dev/null"}`,
							},
						},
						{
							Index: 1,
							Type:  schemas.Ptr("function"),
							ID:    schemas.Ptr("toolu_vrtx_01K2kr3wi7M4RriLgE7Kq3vJ"),
							Function: schemas.ChatAssistantMessageToolCallFunction{
								Name:      schemas.Ptr("bash"),
								Arguments: `{"command":"git diff main...HEAD --stat","description":"Show diff of commits in branch"}`,
							},
						},
						{
							Index: 2,
							Type:  schemas.Ptr("function"),
							ID:    schemas.Ptr("toolu_vrtx_01D1mMkcvpfqGrEhkcxUQpGc"),
							Function: schemas.ChatAssistantMessageToolCallFunction{
								Name:      schemas.Ptr("bash"),
								Arguments: `{"command":"git log main..HEAD --format=\"%H %s\" | head -20","description":"Show detailed commits in branch"}`,
							},
						},
					},
				},
			},
		},
	}

	ctx, cancel := schemas.NewBifrostContextWithCancel(nil)
	defer cancel()

	result, err := ToAnthropicChatRequest(ctx, bifrostReq)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	// Collect all tool_use content blocks
	var toolUseBlocks []AnthropicContentBlock
	for _, msg := range result.Messages {
		for _, block := range msg.Content.ContentBlocks {
			if block.Type == AnthropicContentBlockTypeToolUse {
				toolUseBlocks = append(toolUseBlocks, block)
			}
		}
	}

	// Expected key order per tool_use block, exactly as written in the Arguments
	// strings above. Block 0 is the discriminating case: its order differs from
	// alphabetical order, so a marshaler that sorts keys fails there. Blocks 1 and 2
	// (command, description) happen to be alphabetical as well.
	expectedOrders := [][]string{
		{"description", "timeout", "command"},
		{"command", "description"},
		{"command", "description"},
	}
	if len(toolUseBlocks) != len(expectedOrders) {
		t.Fatalf("expected %d tool_use blocks, got %d", len(expectedOrders), len(toolUseBlocks))
	}

	for i, wantKeys := range expectedOrders {
		raw, err := json.Marshal(toolUseBlocks[i].Input)
		if err != nil {
			// A marshal failure would otherwise surface as a confusing "missing key" error.
			t.Fatalf("block %d: failed to marshal tool_use input: %v", i, err)
		}
		serialized := string(raw)

		// Locate every expected key as a quoted JSON token.
		positions := make([]int, len(wantKeys))
		for k, key := range wantKeys {
			positions[k] = strings.Index(serialized, `"`+key+`"`)
			if positions[k] < 0 {
				t.Fatalf("block %d: missing expected key(s) in: %s", i, serialized)
			}
		}

		// Keys must appear at strictly increasing offsets.
		for k := 1; k < len(positions); k++ {
			if positions[k-1] >= positions[k] {
				t.Errorf("block %d: key order not preserved, expected %s in: %s",
					i, strings.Join(wantKeys, " < "), serialized)
				break
			}
		}
	}
}
// TestToBifrostChatResponse_MultipleTextBlocksWithThinking converts an Anthropic response
// holding one thinking block followed by two text blocks and checks that the text blocks
// are kept as two separate content blocks while the thinking block is reported through
// ReasoningDetails with its signature.
func TestToBifrostChatResponse_MultipleTextBlocksWithThinking(t *testing.T) {
	thinkingText := "Let me reason step by step about this problem."
	textBlock1 := "The answer is 42."
	textBlock2 := "Here is why that is the case."
	signature := "sig_abc123"
	response := &AnthropicMessageResponse{
		ID:    "msg_test123",
		Type:  "message",
		Role:  "assistant",
		Model: "claude-opus-4-6-20250514",
		Content: []AnthropicContentBlock{
			{
				Type:      AnthropicContentBlockTypeThinking,
				Thinking:  &thinkingText,
				Signature: &signature,
			},
			{
				Type: AnthropicContentBlockTypeText,
				Text: &textBlock1,
			},
			{
				Type: AnthropicContentBlockTypeText,
				Text: &textBlock2,
			},
		},
		StopReason: "end_turn",
		Usage: &AnthropicUsage{
			InputTokens:  100,
			OutputTokens: 50,
		},
	}

	ctx, cancel := schemas.NewBifrostContextWithCancel(nil)
	defer cancel()

	result := response.ToBifrostChatResponse(ctx)
	if result == nil {
		t.Fatal("expected non-nil result")
	}
	// Fail with a message instead of panicking on an index out of range.
	if len(result.Choices) == 0 {
		t.Fatal("expected at least one choice")
	}

	// With multiple text blocks, ToBifrostChatResponse preserves them as ContentBlocks
	// (only a single text block collapses to ContentStr — see chat.go:812-815).
	// Thinking flows through ReasoningDetails below, not ContentStr.
	choice := result.Choices[0]
	msg := choice.ChatNonStreamResponseChoice.Message
	if msg.Content == nil {
		t.Fatal("expected message content to be non-nil")
	}
	if msg.Content.ContentStr != nil {
		t.Errorf("expected ContentStr to be nil with multiple text blocks, got %q", *msg.Content.ContentStr)
	}
	if len(msg.Content.ContentBlocks) != 2 {
		t.Fatalf("expected 2 content blocks (one per text block), got %d", len(msg.Content.ContentBlocks))
	}
	for i, want := range []string{textBlock1, textBlock2} {
		got := msg.Content.ContentBlocks[i].Text
		// Report the text itself; formatting the pointer with %v would print an address.
		switch {
		case got == nil:
			t.Errorf("block %d text mismatch: got nil, want %q", i, want)
		case *got != want:
			t.Errorf("block %d text mismatch: got %q, want %q", i, *got, want)
		}
	}

	// Thinking is surfaced via ReasoningDetails with the signature preserved
	// (see chat.go:798-807).
	if msg.ChatAssistantMessage == nil {
		t.Fatal("expected ChatAssistantMessage to be non-nil")
	}
	rd := msg.ChatAssistantMessage.ReasoningDetails
	if len(rd) != 1 {
		t.Fatalf("expected 1 reasoning details entry (the thinking block), got %d", len(rd))
	}
	if rd[0].Type != schemas.BifrostReasoningDetailsTypeText {
		t.Errorf("expected reasoning detail type %s, got %s", schemas.BifrostReasoningDetailsTypeText, rd[0].Type)
	}
	if rd[0].Signature == nil || *rd[0].Signature != signature {
		t.Error("expected thinking signature to be preserved on reasoning detail")
	}
	if rd[0].Text == nil || *rd[0].Text != thinkingText {
		t.Errorf("expected reasoning text to match thinking text")
	}
}
// TestToBifrostChatResponse_SingleTextBlockNoThinking checks that a response with exactly
// one text block and no thinking block is converted to a plain ContentStr, with no
// ContentBlocks and no reasoning details.
func TestToBifrostChatResponse_SingleTextBlockNoThinking(t *testing.T) {
	// Verify existing behavior: single text block without thinking collapses to string
	text := "Simple response"
	response := &AnthropicMessageResponse{
		ID:    "msg_simple",
		Type:  "message",
		Role:  "assistant",
		Model: "claude-sonnet-4-6-20250514",
		Content: []AnthropicContentBlock{
			{Type: AnthropicContentBlockTypeText, Text: &text},
		},
		StopReason: "end_turn",
		Usage:      &AnthropicUsage{InputTokens: 10, OutputTokens: 5},
	}

	ctx, cancel := schemas.NewBifrostContextWithCancel(nil)
	defer cancel()

	result := response.ToBifrostChatResponse(ctx)
	// Guard the dereferences below so a regression fails with a message rather than a panic.
	if result == nil {
		t.Fatal("expected non-nil result")
	}
	if len(result.Choices) == 0 {
		t.Fatal("expected at least one choice")
	}
	msg := result.Choices[0].ChatNonStreamResponseChoice.Message
	if msg.Content == nil {
		t.Fatal("expected message content to be non-nil")
	}

	if msg.Content.ContentStr == nil || *msg.Content.ContentStr != text {
		t.Error("expected ContentStr to be the text")
	}
	if msg.Content.ContentBlocks != nil {
		t.Error("expected ContentBlocks to be nil")
	}
	// No reasoning details for plain text
	if msg.ChatAssistantMessage != nil && len(msg.ChatAssistantMessage.ReasoningDetails) > 0 {
		t.Error("expected no reasoning details for single text block without thinking")
	}
}
// TestToAnthropicChatRequest_BoundaryMismatchFallback sends a conversation whose assistant
// turn carries client-modified content plus a signed reasoning detail, and checks that the
// converted assistant message contains a text block holding the full modified content.
func TestToAnthropicChatRequest_BoundaryMismatchFallback(t *testing.T) {
	sig := "sig_fallback"
	editedText := "The user edited this content"

	// Assistant turn: content and reasoning text both point at the edited string.
	assistantTurn := schemas.ChatMessage{
		Role:    schemas.ChatMessageRoleAssistant,
		Content: &schemas.ChatMessageContent{ContentStr: &editedText},
		ChatAssistantMessage: &schemas.ChatAssistantMessage{
			ReasoningDetails: []schemas.ChatReasoningDetails{
				{Index: 0, Type: schemas.BifrostReasoningDetailsTypeText, Text: &editedText, Signature: &sig},
			},
		},
	}
	req := &schemas.BifrostChatRequest{
		Provider: schemas.Anthropic,
		Model:    "claude-opus-4-6-20250514",
		Input: []schemas.ChatMessage{
			{
				Role:    schemas.ChatMessageRoleUser,
				Content: &schemas.ChatMessageContent{ContentStr: schemas.Ptr("Hi")},
			},
			assistantTurn,
			{
				Role:    schemas.ChatMessageRoleUser,
				Content: &schemas.ChatMessageContent{ContentStr: schemas.Ptr("Continue")},
			},
		},
	}

	ctx, cancel := schemas.NewBifrostContextWithCancel(nil)
	defer cancel()

	converted, err := ToAnthropicChatRequest(ctx, req)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	// Pick the first assistant message of the converted request.
	var assistant *AnthropicMessage
	for idx := range converted.Messages {
		if converted.Messages[idx].Role != "assistant" {
			continue
		}
		assistant = &converted.Messages[idx]
		break
	}
	if assistant == nil {
		t.Fatal("expected assistant message")
	}

	// Only the text fallback is asserted here: some text block must carry the
	// complete edited content.
	hasFallbackText := false
	for _, blk := range assistant.Content.ContentBlocks {
		if blk.Type == AnthropicContentBlockTypeText && blk.Text != nil && *blk.Text == editedText {
			hasFallbackText = true
			break
		}
	}
	if !hasFallbackText {
		t.Error("expected fallback to single text block with full content")
	}
}
// TestToAnthropicChatRequest_NormalFlowUnchanged verifies the normal multi-turn flow: an
// assistant turn with a string response and one signed text reasoning detail must convert
// to exactly two content blocks, a thinking block (text + signature) followed by a text block.
func TestToAnthropicChatRequest_NormalFlowUnchanged(t *testing.T) {
	// Verify that the normal multi-turn flow (reasoning_details with text + signature,
	// no bifrost.content_blocks) produces the same output as before.
	thinkingText := "I need to think about this carefully"
	signature := "sig_normal"
	responseText := "Here is my answer"
	bifrostReq := &schemas.BifrostChatRequest{
		Provider: schemas.Anthropic,
		Model:    "claude-opus-4-6-20250514",
		Input: []schemas.ChatMessage{
			{
				Role:    schemas.ChatMessageRoleUser,
				Content: &schemas.ChatMessageContent{ContentStr: schemas.Ptr("What is 2+2?")},
			},
			{
				Role:    schemas.ChatMessageRoleAssistant,
				Content: &schemas.ChatMessageContent{ContentStr: &responseText},
				ChatAssistantMessage: &schemas.ChatAssistantMessage{
					ReasoningDetails: []schemas.ChatReasoningDetails{
						{
							Index:     0,
							Type:      schemas.BifrostReasoningDetailsTypeText,
							Text:      &thinkingText,
							Signature: &signature,
						},
					},
				},
			},
			{
				Role:    schemas.ChatMessageRoleUser,
				Content: &schemas.ChatMessageContent{ContentStr: schemas.Ptr("Are you sure?")},
			},
		},
	}

	ctx, cancel := schemas.NewBifrostContextWithCancel(nil)
	defer cancel()

	result, err := ToAnthropicChatRequest(ctx, bifrostReq)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	var assistantMsg *AnthropicMessage
	for i := range result.Messages {
		if result.Messages[i].Role == "assistant" {
			assistantMsg = &result.Messages[i]
			break
		}
	}
	if assistantMsg == nil {
		t.Fatal("expected assistant message")
	}

	blocks := assistantMsg.Content.ContentBlocks
	if len(blocks) != 2 {
		t.Fatalf("expected 2 content blocks (thinking + text), got %d", len(blocks))
	}

	// checkStr reports a mismatch using the pointed-to text; formatting the pointer
	// itself with %v would print an address.
	checkStr := func(label string, got *string, want string) {
		switch {
		case got == nil:
			t.Errorf("%s %q, got nil", label, want)
		case *got != want:
			t.Errorf("%s %q, got %q", label, want, *got)
		}
	}

	// Block 0: thinking with original text and signature
	if blocks[0].Type != AnthropicContentBlockTypeThinking {
		t.Errorf("block 0: expected thinking, got %s", blocks[0].Type)
	}
	checkStr("block 0: expected thinking text", blocks[0].Thinking, thinkingText)
	checkStr("block 0: expected signature", blocks[0].Signature, signature)

	// Block 1: text with response
	if blocks[1].Type != AnthropicContentBlockTypeText {
		t.Errorf("block 1: expected text, got %s", blocks[1].Type)
	}
	checkStr("block 1: expected text", blocks[1].Text, responseText)
}
// TestToAnthropicChatRequest_Opus47_StripsTemperatureTopPTopK sends temperature, top_p and
// top_k (the latter via ExtraParams) for the model "claude-opus-4-7-20260401" and asserts
// that none of the three sampling parameters is set on the converted request.
func TestToAnthropicChatRequest_Opus47_StripsTemperatureTopPTopK(t *testing.T) {
	temp := 0.7
	topP := 0.9
	bifrostReq := &schemas.BifrostChatRequest{
		Provider: schemas.Anthropic,
		Model:    "claude-opus-4-7-20260401",
		Input: []schemas.ChatMessage{
			{Role: schemas.ChatMessageRoleUser, Content: &schemas.ChatMessageContent{ContentStr: schemas.Ptr("hi")}},
		},
		Params: &schemas.ChatParameters{
			Temperature: &temp,
			TopP:        &topP,
			ExtraParams: map[string]interface{}{"top_k": 40},
		},
	}

	ctx, cancel := schemas.NewBifrostContextWithCancel(nil)
	defer cancel()

	result, err := ToAnthropicChatRequest(ctx, bifrostReq)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	// Each branch only runs when the field is non-nil, so dereferencing is safe and
	// the message shows the leaked value rather than a pointer address.
	if result.Temperature != nil {
		t.Errorf("expected Temperature to be nil for Opus 4.7, got %v", *result.Temperature)
	}
	if result.TopP != nil {
		t.Errorf("expected TopP to be nil for Opus 4.7, got %v", *result.TopP)
	}
	if result.TopK != nil {
		t.Errorf("expected TopK to be nil for Opus 4.7, got %v", *result.TopK)
	}
}
// TestToAnthropicChatRequest_NonOpus47_PreservesTemperature sends a temperature for the
// model "claude-opus-4-6-20250514" and asserts that the converted request carries the
// same temperature value.
func TestToAnthropicChatRequest_NonOpus47_PreservesTemperature(t *testing.T) {
	temp := 0.7
	bifrostReq := &schemas.BifrostChatRequest{
		Provider: schemas.Anthropic,
		Model:    "claude-opus-4-6-20250514",
		Input: []schemas.ChatMessage{
			{Role: schemas.ChatMessageRoleUser, Content: &schemas.ChatMessageContent{ContentStr: schemas.Ptr("hi")}},
		},
		Params: &schemas.ChatParameters{
			Temperature: &temp,
		},
	}

	ctx, cancel := schemas.NewBifrostContextWithCancel(nil)
	defer cancel()

	result, err := ToAnthropicChatRequest(ctx, bifrostReq)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	// Report the actual value (or nil) instead of a pointer address.
	switch {
	case result.Temperature == nil:
		t.Errorf("expected Temperature %v, got nil", temp)
	case *result.Temperature != temp:
		t.Errorf("expected Temperature %v, got %v", temp, *result.Temperature)
	}
}
// TestToAnthropicChatRequest_Opus47_ReasoningMaxTokens_AdaptiveOnly requests reasoning with
// a max-token budget for the model "claude-opus-4-7-20260401" and asserts that the converted
// request uses thinking type "adaptive" with no BudgetTokens.
func TestToAnthropicChatRequest_Opus47_ReasoningMaxTokens_AdaptiveOnly(t *testing.T) {
	maxTok := 2048
	bifrostReq := &schemas.BifrostChatRequest{
		Provider: schemas.Anthropic,
		Model:    "claude-opus-4-7-20260401",
		Input: []schemas.ChatMessage{
			{Role: schemas.ChatMessageRoleUser, Content: &schemas.ChatMessageContent{ContentStr: schemas.Ptr("think")}},
		},
		Params: &schemas.ChatParameters{
			MaxCompletionTokens: schemas.Ptr(8192),
			Reasoning:           &schemas.ChatReasoning{MaxTokens: &maxTok},
		},
	}

	ctx, cancel := schemas.NewBifrostContextWithCancel(nil)
	defer cancel()

	result, err := ToAnthropicChatRequest(ctx, bifrostReq)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if result.Thinking == nil {
		t.Fatal("expected Thinking to be set")
	}
	if result.Thinking.Type != "adaptive" {
		t.Errorf("expected thinking type 'adaptive' for Opus 4.7, got %q", result.Thinking.Type)
	}
	if result.Thinking.BudgetTokens != nil {
		// Non-nil in this branch, so print the budget itself rather than its address.
		t.Errorf("expected BudgetTokens to be nil for Opus 4.7, got %v", *result.Thinking.BudgetTokens)
	}
}
// TestToAnthropicChatRequest_NonOpus47_ReasoningMaxTokens_EnabledWithBudget requests
// reasoning with a max-token budget for the model "claude-opus-4-6-20250514" and asserts
// that the converted request uses thinking type "enabled" with BudgetTokens equal to the
// requested budget.
func TestToAnthropicChatRequest_NonOpus47_ReasoningMaxTokens_EnabledWithBudget(t *testing.T) {
	maxTok := 2048
	bifrostReq := &schemas.BifrostChatRequest{
		Provider: schemas.Anthropic,
		Model:    "claude-opus-4-6-20250514",
		Input: []schemas.ChatMessage{
			{Role: schemas.ChatMessageRoleUser, Content: &schemas.ChatMessageContent{ContentStr: schemas.Ptr("think")}},
		},
		Params: &schemas.ChatParameters{
			MaxCompletionTokens: schemas.Ptr(8192),
			Reasoning:           &schemas.ChatReasoning{MaxTokens: &maxTok},
		},
	}

	ctx, cancel := schemas.NewBifrostContextWithCancel(nil)
	defer cancel()

	result, err := ToAnthropicChatRequest(ctx, bifrostReq)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if result.Thinking == nil {
		t.Fatal("expected Thinking to be set")
	}
	if result.Thinking.Type != "enabled" {
		t.Errorf("expected thinking type 'enabled' for Opus 4.6, got %q", result.Thinking.Type)
	}

	// Report the actual budget (or nil) instead of a pointer address.
	switch budget := result.Thinking.BudgetTokens; {
	case budget == nil:
		t.Errorf("expected BudgetTokens %d, got nil", maxTok)
	case *budget != maxTok:
		t.Errorf("expected BudgetTokens %d, got %d", maxTok, *budget)
	}
}
// TestToAnthropicChatRequest_Opus47_ReasoningEffort_AdaptiveWithEffort requests effort-based
// reasoning ("high") for the model "claude-opus-4-7-20260401" and asserts that the converted
// request uses thinking type "adaptive" and has OutputConfig.Effort populated.
func TestToAnthropicChatRequest_Opus47_ReasoningEffort_AdaptiveWithEffort(t *testing.T) {
	reasoningEffort := "high"
	prompt := schemas.ChatMessage{
		Role:    schemas.ChatMessageRoleUser,
		Content: &schemas.ChatMessageContent{ContentStr: schemas.Ptr("think")},
	}
	req := &schemas.BifrostChatRequest{
		Provider: schemas.Anthropic,
		Model:    "claude-opus-4-7-20260401",
		Input:    []schemas.ChatMessage{prompt},
		Params: &schemas.ChatParameters{
			MaxCompletionTokens: schemas.Ptr(8192),
			Reasoning:           &schemas.ChatReasoning{Effort: &reasoningEffort},
		},
	}

	ctx, cancel := schemas.NewBifrostContextWithCancel(nil)
	defer cancel()

	converted, err := ToAnthropicChatRequest(ctx, req)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	thinking := converted.Thinking
	if thinking == nil {
		t.Fatal("expected Thinking to be set")
	}
	if thinking.Type != "adaptive" {
		t.Errorf("expected thinking type 'adaptive' for Opus 4.7 effort-based, got %q", thinking.Type)
	}

	// Only presence of the effort is asserted, not its value.
	if cfg := converted.OutputConfig; cfg == nil || cfg.Effort == nil {
		t.Error("expected OutputConfig.Effort to be set for Opus 4.7 effort-based reasoning")
	}
}

View File

@@ -0,0 +1,703 @@
package anthropic
import (
"context"
"testing"
"github.com/maximhq/bifrost/core/schemas"
)
// --- isCompactionItem tests ---

// TestIsCompactionItem runs isCompactionItem over a table of messages and compares the
// result with the expected flag. In this table only a message-typed item whose content
// holds a compaction block is expected to return true.
func TestIsCompactionItem(t *testing.T) {
	t.Parallel()

	type testCase struct {
		name     string
		item     *schemas.ResponsesMessage
		expected bool
	}

	cases := []testCase{
		{name: "nil item", item: nil, expected: false},
		{
			// A compaction block is present, but the message Type is left unset.
			name: "nil type",
			item: &schemas.ResponsesMessage{
				Content: &schemas.ResponsesMessageContent{
					ContentBlocks: []schemas.ResponsesMessageContentBlock{
						{Type: schemas.ResponsesOutputMessageContentTypeCompaction},
					},
				},
			},
			expected: false,
		},
		{
			name: "message type with compaction content block",
			item: &schemas.ResponsesMessage{
				Type: schemas.Ptr(schemas.ResponsesMessageTypeMessage),
				Content: &schemas.ResponsesMessageContent{
					ContentBlocks: []schemas.ResponsesMessageContentBlock{
						{
							Type: schemas.ResponsesOutputMessageContentTypeCompaction,
							ResponsesOutputMessageContentCompaction: &schemas.ResponsesOutputMessageContentCompaction{
								Summary: "Summary of conversation",
							},
						},
					},
				},
			},
			expected: true,
		},
		{
			name: "message type with text content block",
			item: &schemas.ResponsesMessage{
				Type: schemas.Ptr(schemas.ResponsesMessageTypeMessage),
				Content: &schemas.ResponsesMessageContent{
					ContentBlocks: []schemas.ResponsesMessageContentBlock{
						{
							Type: schemas.ResponsesOutputMessageContentTypeText,
							Text: schemas.Ptr("Hello"),
						},
					},
				},
			},
			expected: false,
		},
		{
			name: "function call type",
			item: &schemas.ResponsesMessage{
				Type: schemas.Ptr(schemas.ResponsesMessageTypeFunctionCall),
			},
			expected: false,
		},
		{
			name: "message type with nil content",
			item: &schemas.ResponsesMessage{
				Type:    schemas.Ptr(schemas.ResponsesMessageTypeMessage),
				Content: nil,
			},
			expected: false,
		},
		{
			name: "message type with empty content blocks",
			item: &schemas.ResponsesMessage{
				Type: schemas.Ptr(schemas.ResponsesMessageTypeMessage),
				Content: &schemas.ResponsesMessageContent{
					ContentBlocks: []schemas.ResponsesMessageContentBlock{},
				},
			},
			expected: false,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			if got := isCompactionItem(tc.item); got != tc.expected {
				t.Errorf("isCompactionItem() = %v, want %v", got, tc.expected)
			}
		})
	}
}
// --- Streaming: Anthropic → Bifrost (inbound) ---

// TestToBifrostResponsesStream_CompactionContentBlockStart feeds a content_block_start
// event of type compaction into ToBifrostResponsesStream and asserts that nothing is
// emitted while the stream state records the block under content index 0.
func TestToBifrostResponsesStream_CompactionContentBlockStart(t *testing.T) {
	t.Parallel()

	// Fresh state: all lookup tables empty, created/in_progress already emitted.
	state := &AnthropicResponsesStreamState{
		CurrentOutputIndex:        0,
		CreatedAt:                 1234567890,
		HasEmittedCreated:         true,
		HasEmittedInProgress:      true,
		ContentIndexToOutputIndex: make(map[int]int),
		ContentIndexToBlockType:   make(map[int]AnthropicContentBlockType),
		ToolArgumentBuffers:       make(map[int]string),
		MCPCallOutputIndices:      make(map[int]bool),
		ItemIDs:                   make(map[int]string),
		OutputItems:               make(map[int]*schemas.ResponsesMessage),
		ReasoningSignatures:       make(map[int]string),
		TextContentIndices:        make(map[int]bool),
		ReasoningContentIndices:   make(map[int]bool),
		CompactionContentIndices:  make(map[int]*schemas.CacheControl),
	}

	startEvent := &AnthropicStreamEvent{
		Type:  AnthropicStreamEventTypeContentBlockStart,
		Index: schemas.Ptr(0),
		ContentBlock: &AnthropicContentBlock{
			Type:         AnthropicContentBlockTypeCompaction,
			CacheControl: &schemas.CacheControl{Type: "ephemeral"},
		},
	}

	emitted, err, isLast := startEvent.ToBifrostResponsesStream(context.Background(), 0, state)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if isLast {
		t.Error("should not be last chunk")
	}
	// The start event itself produces no output; emission is expected with the delta.
	if n := len(emitted); n != 0 {
		t.Errorf("expected 0 responses for compaction content_block_start, got %d", n)
	}

	// The block must now be tracked in both state maps.
	_, tracked := state.CompactionContentIndices[0]
	if !tracked {
		t.Error("expected compaction to be tracked in CompactionContentIndices")
	}
	recordedType, known := state.ContentIndexToBlockType[0]
	if !(known && recordedType == AnthropicContentBlockTypeCompaction) {
		t.Error("expected compaction block type tracked in ContentIndexToBlockType")
	}
}
// TestToBifrostResponsesStream_CompactionDelta feeds a compaction delta for an already
// tracked compaction block into ToBifrostResponsesStream and asserts that two events come
// out: output_item.added carrying the summary and cache control, then output_item.done.
func TestToBifrostResponsesStream_CompactionDelta(t *testing.T) {
	t.Parallel()

	// State as it would look after content_block_start for content index 0.
	state := &AnthropicResponsesStreamState{
		CurrentOutputIndex:        1,
		CreatedAt:                 1234567890,
		HasEmittedCreated:         true,
		HasEmittedInProgress:      true,
		ContentIndexToOutputIndex: map[int]int{0: 0},
		ContentIndexToBlockType:   map[int]AnthropicContentBlockType{0: AnthropicContentBlockTypeCompaction},
		ItemIDs:                   map[int]string{0: "cmp_0"},
		CompactionContentIndices:  map[int]*schemas.CacheControl{0: {Type: "ephemeral"}},
		ToolArgumentBuffers:       make(map[int]string),
		MCPCallOutputIndices:      make(map[int]bool),
		OutputItems:               make(map[int]*schemas.ResponsesMessage),
		ReasoningSignatures:       make(map[int]string),
		TextContentIndices:        make(map[int]bool),
		ReasoningContentIndices:   make(map[int]bool),
	}

	summary := "The user asked about building a website. We discussed HTML, CSS, and JavaScript."
	deltaEvent := &AnthropicStreamEvent{
		Type:  AnthropicStreamEventTypeContentBlockDelta,
		Index: schemas.Ptr(0),
		Delta: &AnthropicStreamDelta{
			Type:    AnthropicStreamDeltaTypeCompaction,
			Content: &summary,
		},
	}

	emitted, err, isLast := deltaEvent.ToBifrostResponsesStream(context.Background(), 0, state)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if isLast {
		t.Error("should not be last chunk")
	}
	// Expect exactly output_item.added followed by output_item.done.
	if n := len(emitted); n != 2 {
		t.Fatalf("expected 2 responses for compaction delta, got %d", n)
	}
	addedEvt, doneEvt := emitted[0], emitted[1]

	// output_item.added must carry the compaction block.
	if addedEvt.Type != schemas.ResponsesStreamResponseTypeOutputItemAdded {
		t.Errorf("first response type = %v, want %v", addedEvt.Type, schemas.ResponsesStreamResponseTypeOutputItemAdded)
	}
	hasBlocks := addedEvt.Item != nil &&
		addedEvt.Item.Content != nil &&
		len(addedEvt.Item.Content.ContentBlocks) != 0
	if !hasBlocks {
		t.Fatal("output_item.added should have content blocks")
	}
	first := addedEvt.Item.Content.ContentBlocks[0]
	if first.Type != schemas.ResponsesOutputMessageContentTypeCompaction {
		t.Errorf("content block type = %v, want compaction", first.Type)
	}
	compaction := first.ResponsesOutputMessageContentCompaction
	if compaction == nil {
		t.Fatal("expected compaction content to be non-nil")
	}
	if compaction.Summary != summary {
		t.Errorf("summary = %q, want %q", compaction.Summary, summary)
	}
	// The cache control seeded in the state must reach the emitted block.
	if cc := first.CacheControl; cc == nil || cc.Type != "ephemeral" {
		t.Error("expected cache control to be preserved")
	}

	// output_item.done closes the item.
	if doneEvt.Type != schemas.ResponsesStreamResponseTypeOutputItemDone {
		t.Errorf("second response type = %v, want %v", doneEvt.Type, schemas.ResponsesStreamResponseTypeOutputItemDone)
	}
}
// TestToBifrostResponsesStream_CompactionContentBlockStop feeds a content_block_stop event
// for a block recorded as compaction into ToBifrostResponsesStream and asserts that no
// events are emitted and the chunk is not flagged as last.
func TestToBifrostResponsesStream_CompactionContentBlockStop(t *testing.T) {
	t.Parallel()

	// Content index 0 is recorded as a compaction block; the compaction index map is empty.
	state := &AnthropicResponsesStreamState{
		CurrentOutputIndex:        1,
		CreatedAt:                 1234567890,
		HasEmittedCreated:         true,
		HasEmittedInProgress:      true,
		ContentIndexToOutputIndex: map[int]int{0: 0},
		ContentIndexToBlockType:   map[int]AnthropicContentBlockType{0: AnthropicContentBlockTypeCompaction},
		ItemIDs:                   map[int]string{0: "cmp_0"},
		ToolArgumentBuffers:       make(map[int]string),
		MCPCallOutputIndices:      make(map[int]bool),
		OutputItems:               make(map[int]*schemas.ResponsesMessage),
		ReasoningSignatures:       make(map[int]string),
		TextContentIndices:        make(map[int]bool),
		ReasoningContentIndices:   make(map[int]bool),
		CompactionContentIndices:  make(map[int]*schemas.CacheControl),
	}

	stopEvent := &AnthropicStreamEvent{
		Type:  AnthropicStreamEventTypeContentBlockStop,
		Index: schemas.Ptr(0),
	}

	emitted, err, isLast := stopEvent.ToBifrostResponsesStream(context.Background(), 0, state)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if isLast {
		t.Error("should not be last chunk")
	}
	// Nothing is expected here: the done event belongs to the delta step.
	if n := len(emitted); n != 0 {
		t.Errorf("expected 0 responses for compaction content_block_stop, got %d", n)
	}
}
// --- Streaming: Bifrost → Anthropic (outbound, non-passthrough) ---

// TestToAnthropicResponsesStreamResponse_CompactionOutputItemAdded converts an
// output_item.added event whose item holds a compaction block and asserts that at least
// two Anthropic events are produced: a content_block_start of type compaction (with the
// cache control) followed by a content_block_delta carrying the summary.
func TestToAnthropicResponsesStreamResponse_CompactionOutputItemAdded(t *testing.T) {
	t.Parallel()

	ctx, cancel := schemas.NewBifrostContextWithCancel(nil)
	defer cancel()

	summary := "Summary of the conversation about building a website"
	bifrostResp := &schemas.BifrostResponsesStreamResponse{
		Type:        schemas.ResponsesStreamResponseTypeOutputItemAdded,
		OutputIndex: schemas.Ptr(0),
		Item: &schemas.ResponsesMessage{
			ID:     schemas.Ptr("cmp_test123"),
			Type:   schemas.Ptr(schemas.ResponsesMessageTypeMessage),
			Status: schemas.Ptr("completed"),
			Role:   schemas.Ptr(schemas.ResponsesInputMessageRoleAssistant),
			Content: &schemas.ResponsesMessageContent{
				ContentBlocks: []schemas.ResponsesMessageContentBlock{
					{
						Type: schemas.ResponsesOutputMessageContentTypeCompaction,
						ResponsesOutputMessageContentCompaction: &schemas.ResponsesOutputMessageContentCompaction{
							Summary: summary,
						},
						CacheControl: &schemas.CacheControl{Type: "ephemeral"},
					},
				},
			},
		},
	}

	events := ToAnthropicResponsesStreamResponse(ctx, bifrostResp)

	// Should emit: content_block_start (compaction) + content_block_delta (compaction_delta)
	if len(events) < 2 {
		t.Fatalf("expected at least 2 events, got %d", len(events))
	}

	// Event 1: content_block_start
	start := events[0]
	if start.Type != AnthropicStreamEventTypeContentBlockStart {
		t.Errorf("event[0] type = %v, want content_block_start", start.Type)
	}
	if start.ContentBlock == nil {
		t.Fatal("content_block_start should have ContentBlock")
	}
	if start.ContentBlock.Type != AnthropicContentBlockTypeCompaction {
		t.Errorf("ContentBlock.Type = %v, want compaction", start.ContentBlock.Type)
	}
	if start.ContentBlock.CacheControl == nil || start.ContentBlock.CacheControl.Type != "ephemeral" {
		t.Error("expected cache control to be preserved on content_block_start")
	}

	// Event 2: content_block_delta with compaction_delta
	delta := events[1]
	if delta.Type != AnthropicStreamEventTypeContentBlockDelta {
		t.Errorf("event[1] type = %v, want content_block_delta", delta.Type)
	}
	if delta.Delta == nil {
		t.Fatal("content_block_delta should have Delta")
	}
	if delta.Delta.Type != AnthropicStreamDeltaTypeCompaction {
		t.Errorf("Delta.Type = %v, want compaction_delta", delta.Delta.Type)
	}
	// Report the delta text itself (or nil); formatting the pointer with %v would
	// print an address.
	switch content := delta.Delta.Content; {
	case content == nil:
		t.Errorf("Delta.Content = nil, want %q", summary)
	case *content != summary:
		t.Errorf("Delta.Content = %q, want %q", *content, summary)
	}
}
// TestToAnthropicResponsesStreamResponse_CompactionOutputItemDone converts an
// output_item.done event for a compaction item and asserts that exactly one Anthropic
// event, a content_block_stop, is produced.
func TestToAnthropicResponsesStreamResponse_CompactionOutputItemDone(t *testing.T) {
	t.Parallel()

	ctx, cancel := schemas.NewBifrostContextWithCancel(nil)
	defer cancel()

	compactionBlock := schemas.ResponsesMessageContentBlock{
		Type: schemas.ResponsesOutputMessageContentTypeCompaction,
		ResponsesOutputMessageContentCompaction: &schemas.ResponsesOutputMessageContentCompaction{
			Summary: "Summary text",
		},
	}
	doneResp := &schemas.BifrostResponsesStreamResponse{
		Type:        schemas.ResponsesStreamResponseTypeOutputItemDone,
		OutputIndex: schemas.Ptr(0),
		ItemID:      schemas.Ptr("cmp_test123"),
		Item: &schemas.ResponsesMessage{
			ID:     schemas.Ptr("cmp_test123"),
			Type:   schemas.Ptr(schemas.ResponsesMessageTypeMessage),
			Status: schemas.Ptr("completed"),
			Role:   schemas.Ptr(schemas.ResponsesInputMessageRoleAssistant),
			Content: &schemas.ResponsesMessageContent{
				ContentBlocks: []schemas.ResponsesMessageContentBlock{compactionBlock},
			},
		},
	}

	out := ToAnthropicResponsesStreamResponse(ctx, doneResp)

	// A single content_block_stop is the whole expected output.
	if n := len(out); n != 1 {
		t.Fatalf("expected 1 event for output_item.done, got %d", n)
	}
	if got := out[0].Type; got != AnthropicStreamEventTypeContentBlockStop {
		t.Errorf("event type = %v, want content_block_stop", got)
	}
}
// TestToAnthropicResponsesStreamResponse_TextOutputItemAdded_NotAffectedByCompactionCheck
// converts an output_item.added event for an ordinary text message and asserts that the
// first Anthropic event is a content_block_start whose content block has type text.
func TestToAnthropicResponsesStreamResponse_TextOutputItemAdded_NotAffectedByCompactionCheck(t *testing.T) {
	t.Parallel()

	ctx, cancel := schemas.NewBifrostContextWithCancel(nil)
	defer cancel()

	// An in-progress assistant message with one empty text block.
	textItem := &schemas.ResponsesMessage{
		ID:     schemas.Ptr("msg_test123"),
		Type:   schemas.Ptr(schemas.ResponsesMessageTypeMessage),
		Status: schemas.Ptr("in_progress"),
		Role:   schemas.Ptr(schemas.ResponsesInputMessageRoleAssistant),
		Content: &schemas.ResponsesMessageContent{
			ContentBlocks: []schemas.ResponsesMessageContentBlock{
				{
					Type: schemas.ResponsesOutputMessageContentTypeText,
					Text: schemas.Ptr(""),
				},
			},
		},
	}
	addedResp := &schemas.BifrostResponsesStreamResponse{
		Type:        schemas.ResponsesStreamResponseTypeOutputItemAdded,
		OutputIndex: schemas.Ptr(0),
		Item:        textItem,
	}

	out := ToAnthropicResponsesStreamResponse(ctx, addedResp)
	if len(out) == 0 {
		t.Fatal("expected at least 1 event")
	}

	first := out[0]
	if first.Type != AnthropicStreamEventTypeContentBlockStart {
		t.Errorf("event type = %v, want content_block_start", first.Type)
	}
	block := first.ContentBlock
	if block == nil {
		t.Fatal("expected ContentBlock to be non-nil")
	}
	if block.Type != AnthropicContentBlockTypeText {
		t.Errorf("ContentBlock.Type = %v, want text", block.Type)
	}
}
// --- Non-Streaming: stop_reason mapping ---

// TestToBifrostResponsesResponse_PreservesStopReason converts Anthropic responses with
// several stop reasons and asserts that the Bifrost response exposes each one as the
// expected string.
func TestToBifrostResponsesResponse_PreservesStopReason(t *testing.T) {
	t.Parallel()

	cases := []struct {
		name string
		in   AnthropicStopReason
		want string
	}{
		{name: "compaction stop reason", in: AnthropicStopReasonCompaction, want: "compaction"},
		{name: "end_turn stop reason", in: AnthropicStopReasonEndTurn, want: "end_turn"},
		{name: "tool_use stop reason", in: AnthropicStopReasonToolUse, want: "tool_use"},
		{name: "max_tokens stop reason", in: AnthropicStopReasonMaxTokens, want: "max_tokens"},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			ctx, cancel := schemas.NewBifrostContextWithCancel(nil)
			defer cancel()

			// Minimal response: one text block plus the stop reason under test.
			anthropicResp := &AnthropicMessageResponse{
				ID:         "msg_test",
				Type:       "message",
				Role:       "assistant",
				Model:      "claude-sonnet-4-6",
				StopReason: tc.in,
				Content: []AnthropicContentBlock{
					{Type: AnthropicContentBlockTypeText, Text: schemas.Ptr("Hello")},
				},
			}

			converted := anthropicResp.ToBifrostResponsesResponse(ctx)
			if converted.StopReason == nil {
				t.Fatal("expected StopReason to be non-nil")
			}
			if got := *converted.StopReason; got != tc.want {
				t.Errorf("StopReason = %q, want %q", got, tc.want)
			}
		})
	}
}
// TestToBifrostResponsesResponse_EmptyStopReason converts an Anthropic response that has
// no stop reason set and asserts that the Bifrost response leaves StopReason nil.
func TestToBifrostResponsesResponse_EmptyStopReason(t *testing.T) {
	t.Parallel()

	ctx, cancel := schemas.NewBifrostContextWithCancel(nil)
	defer cancel()

	// StopReason is intentionally omitted; content is an empty slice.
	anthropicResp := &AnthropicMessageResponse{
		ID:      "msg_test",
		Type:    "message",
		Role:    "assistant",
		Model:   "claude-sonnet-4-6",
		Content: []AnthropicContentBlock{},
	}

	converted := anthropicResp.ToBifrostResponsesResponse(ctx)
	if got := converted.StopReason; got != nil {
		t.Errorf("expected nil StopReason for empty stop_reason, got %q", *got)
	}
}
// TestToAnthropicResponsesResponse_StopReasonFromBifrost checks the stop reason
// produced when converting a Bifrost responses payload back to Anthropic format,
// covering explicit values, a nil value, and a nil value alongside a
// function-call output item.
func TestToAnthropicResponsesResponse_StopReasonFromBifrost(t *testing.T) {
	t.Parallel()

	// Output consisting of a single function-call item, used by the last case.
	toolCallOutput := []schemas.ResponsesMessage{{
		Type: schemas.Ptr(schemas.ResponsesMessageTypeFunctionCall),
		ResponsesToolMessage: &schemas.ResponsesToolMessage{
			CallID: schemas.Ptr("call_123"),
			Name:   schemas.Ptr("my_tool"),
		},
	}}

	cases := []struct {
		label  string
		reason *string
		output []schemas.ResponsesMessage
		want   AnthropicStopReason
	}{
		{
			label:  "compaction stop reason from bifrost",
			reason: schemas.Ptr("compaction"),
			want:   AnthropicStopReasonCompaction,
		},
		{
			label:  "end_turn mapped from stop",
			reason: schemas.Ptr("stop"),
			want:   AnthropicStopReasonEndTurn,
		},
		{
			label:  "tool_use mapped from tool_calls",
			reason: schemas.Ptr("tool_calls"),
			want:   AnthropicStopReasonToolUse,
		},
		{
			label:  "nil stop_reason defaults to end_turn",
			reason: nil,
			want:   AnthropicStopReasonEndTurn,
		},
		{
			label:  "nil stop_reason with tool_use content defaults to tool_use",
			reason: nil,
			output: toolCallOutput,
			want:   AnthropicStopReasonToolUse,
		},
	}

	for _, tc := range cases {
		t.Run(tc.label, func(t *testing.T) {
			ctx, cancel := schemas.NewBifrostContextWithCancel(nil)
			defer cancel()

			source := &schemas.BifrostResponsesResponse{
				ID:         schemas.Ptr("resp_test"),
				Model:      "claude-sonnet-4-6",
				StopReason: tc.reason,
				Output:     tc.output,
			}

			if got := ToAnthropicResponsesResponse(ctx, source).StopReason; got != tc.want {
				t.Errorf("StopReason = %v, want %v", got, tc.want)
			}
		})
	}
}
// --- Non-Streaming: compaction content block round-trip ---
// TestCompactionContentBlock_NonStreamingRoundTrip converts an Anthropic message
// holding a compaction block to Bifrost format and back again, checking that the
// stop reason, the summary text and the cache-control marker survive both hops.
func TestCompactionContentBlock_NonStreamingRoundTrip(t *testing.T) {
	t.Parallel()

	ctx, cancel := schemas.NewBifrostContextWithCancel(nil)
	defer cancel()

	summary := "The user requested help building a web scraper using Python with BeautifulSoup."

	// Anthropic-side message carrying one compaction block.
	source := &AnthropicMessageResponse{
		ID:         "msg_compaction_test",
		Type:       "message",
		Role:       "assistant",
		Model:      "claude-opus-4-6",
		StopReason: AnthropicStopReasonCompaction,
		Content: []AnthropicContentBlock{
			{
				Type:         AnthropicContentBlockTypeCompaction,
				Content:      &AnthropicContent{ContentStr: &summary},
				CacheControl: &schemas.CacheControl{Type: "ephemeral"},
			},
		},
	}

	// Hop 1: Anthropic -> Bifrost.
	bifrostResp := source.ToBifrostResponsesResponse(ctx)
	if bifrostResp.StopReason == nil || *bifrostResp.StopReason != "compaction" {
		t.Fatalf("expected stop_reason='compaction', got %v", bifrostResp.StopReason)
	}
	if len(bifrostResp.Output) == 0 {
		t.Fatal("expected at least one output message")
	}

	// Scan every output message for a compaction content block.
	compactionSeen := false
	for _, msg := range bifrostResp.Output {
		if msg.Content == nil {
			continue
		}
		for _, block := range msg.Content.ContentBlocks {
			if block.Type != schemas.ResponsesOutputMessageContentTypeCompaction {
				continue
			}
			compactionSeen = true
			payload := block.ResponsesOutputMessageContentCompaction
			if payload == nil {
				t.Fatal("expected compaction content to be non-nil")
			}
			if payload.Summary != summary {
				t.Errorf("summary = %q, want %q", payload.Summary, summary)
			}
		}
	}
	if !compactionSeen {
		t.Error("compaction block not found in Bifrost output")
	}

	// Hop 2: Bifrost -> Anthropic.
	roundTripped := ToAnthropicResponsesResponse(ctx, bifrostResp)
	if roundTripped.StopReason != AnthropicStopReasonCompaction {
		t.Errorf("result StopReason = %v, want compaction", roundTripped.StopReason)
	}

	// Locate the compaction block again on the Anthropic side.
	resultCompactionSeen := false
	for _, block := range roundTripped.Content {
		if block.Type != AnthropicContentBlockTypeCompaction {
			continue
		}
		resultCompactionSeen = true
		if block.Content == nil || block.Content.ContentStr == nil {
			t.Fatal("expected compaction content string")
		}
		if got := *block.Content.ContentStr; got != summary {
			t.Errorf("result summary = %q, want %q", got, summary)
		}
		if block.CacheControl == nil || block.CacheControl.Type != "ephemeral" {
			t.Error("expected cache control to be preserved")
		}
	}
	if !resultCompactionSeen {
		t.Error("compaction block not found in Anthropic result")
	}
}
// --- Streaming: compaction stop_reason in response.completed ---
// TestToAnthropicResponsesStreamResponse_CompletedWithCompactionStopReason feeds
// a response.completed event with stop reason "compaction" through the stream
// converter and expects exactly two events: a message_delta carrying the
// compaction stop reason, then a message_stop.
func TestToAnthropicResponsesStreamResponse_CompletedWithCompactionStopReason(t *testing.T) {
	t.Parallel()

	ctx, cancel := schemas.NewBifrostContextWithCancel(nil)
	defer cancel()

	completed := &schemas.BifrostResponsesStreamResponse{
		Type: schemas.ResponsesStreamResponseTypeCompleted,
		Response: &schemas.BifrostResponsesResponse{
			ID:         schemas.Ptr("resp_test"),
			Model:      "claude-opus-4-6",
			StopReason: schemas.Ptr("compaction"),
			Usage: &schemas.ResponsesResponseUsage{
				InputTokens:  1000,
				OutputTokens: 500,
				TotalTokens:  1500,
			},
		},
	}

	events := ToAnthropicResponsesStreamResponse(ctx, completed)
	if n := len(events); n != 2 {
		t.Fatalf("expected 2 events for response.completed, got %d", n)
	}
	delta, stop := events[0], events[1]

	// First event: message_delta with stop_reason=compaction.
	if delta.Type != AnthropicStreamEventTypeMessageDelta {
		t.Errorf("event[0] type = %v, want message_delta", delta.Type)
	}
	if delta.Delta == nil || delta.Delta.StopReason == nil {
		t.Fatal("expected Delta.StopReason in message_delta")
	}
	if reason := *delta.Delta.StopReason; reason != AnthropicStopReasonCompaction {
		t.Errorf("StopReason = %v, want compaction", reason)
	}

	// Second event: message_stop.
	if stop.Type != AnthropicStreamEventTypeMessageStop {
		t.Errorf("event[1] type = %v, want message_stop", stop.Type)
	}
}

View File

@@ -0,0 +1,34 @@
package anthropic
import (
"github.com/maximhq/bifrost/core/schemas"
)
// ToBifrostCountTokensResponse maps an Anthropic count-tokens response onto the
// Bifrost representation for the given model name.
//
// The Anthropic input token count is used for both InputTokens and TotalTokens.
// A nil receiver yields nil.
func (resp *AnthropicCountTokensResponse) ToBifrostCountTokensResponse(model string) *schemas.BifrostCountTokensResponse {
	if resp == nil {
		return nil
	}

	out := new(schemas.BifrostCountTokensResponse)
	out.Object = "response.input_tokens"
	out.Model = model
	out.InputTokens = resp.InputTokens

	// TotalTokens is a pointer field; point it at a local copy rather than at
	// the receiver's own field.
	total := resp.InputTokens
	out.TotalTokens = &total

	return out
}
// ToAnthropicCountTokensResponse maps a Bifrost count-tokens response onto the
// Anthropic representation. Only InputTokens is carried across; a nil input
// yields nil.
func ToAnthropicCountTokensResponse(bifrostResp *schemas.BifrostCountTokensResponse) *AnthropicCountTokensResponse {
	if bifrostResp == nil {
		return nil
	}

	out := new(AnthropicCountTokensResponse)
	out.InputTokens = bifrostResp.InputTokens
	return out
}

View File

@@ -0,0 +1,68 @@
package anthropic
import (
"fmt"
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
schemas "github.com/maximhq/bifrost/core/schemas"
"github.com/valyala/fasthttp"
)
// ToAnthropicChatCompletionError builds an AnthropicMessageError from a
// BifrostError.
//
// The top-level Type is always "error". The nested error type falls back to
// "api_error" when the Bifrost error has no nested error, or when its type is
// nil or empty; the nested message is copied when present. A nil input yields
// nil.
func ToAnthropicChatCompletionError(bifrostErr *schemas.BifrostError) *AnthropicMessageError {
	if bifrostErr == nil {
		return nil
	}

	// Start from the defaults and overwrite with whatever the nested error supplies.
	detail := AnthropicMessageErrorStruct{Type: "api_error"}
	if inner := bifrostErr.Error; inner != nil {
		detail.Message = inner.Message
		if inner.Type != nil && *inner.Type != "" {
			detail.Type = *inner.Type
		}
	}

	return &AnthropicMessageError{
		Type:  "error", // always "error" for Anthropic
		Error: detail,
	}
}
// ToAnthropicResponsesStreamError renders a BifrostError as an Anthropic SSE
// "error" event string ("event: error" line, "data: <json>" line, blank line).
//
// It returns the empty string when bifrostErr is nil or when the converted
// error cannot be marshalled.
func ToAnthropicResponsesStreamError(bifrostErr *schemas.BifrostError) string {
	if bifrostErr == nil {
		return ""
	}

	payload, marshalErr := providerUtils.MarshalSorted(ToAnthropicChatCompletionError(bifrostErr))
	if marshalErr != nil {
		// Marshal failures are reported to the caller as an empty event.
		return ""
	}

	return fmt.Sprintf("event: error\ndata: %s\n\n", payload)
}
// parseAnthropicError converts a provider HTTP response into a BifrostError.
//
// The response is first handed to providerUtils.HandleProviderAPIError, which
// also receives an AnthropicError to decode into. If that decoded value has a
// nested error, its type and message overwrite those on the returned
// BifrostError (allocating the ErrorField if it is missing).
func parseAnthropicError(resp *fasthttp.Response) *schemas.BifrostError {
	var decoded AnthropicError
	out := providerUtils.HandleProviderAPIError(resp, &decoded)

	detail := decoded.Error
	if detail == nil {
		return out
	}

	if out.Error == nil {
		out.Error = new(schemas.ErrorField)
	}
	out.Error.Type = &detail.Type
	out.Error.Message = detail.Message
	return out
}

View File

@@ -0,0 +1,96 @@
package anthropic
import (
"testing"
schemas "github.com/maximhq/bifrost/core/schemas"
)
// TestToAnthropicChatCompletionError covers the nil input case, the "api_error"
// fallback for a missing or empty error type, preservation of a supplied type,
// and that the nested message is copied through unchanged.
func TestToAnthropicChatCompletionError(t *testing.T) {
	strPtr := func(s string) *string { return &s }

	tests := []struct {
		name            string
		input           *schemas.BifrostError
		expectNil       bool
		expectedType    string
		expectedMessage string
	}{
		{
			name:      "nil BifrostError returns nil",
			input:     nil,
			expectNil: true,
		},
		{
			name: "nil ErrorField.Type defaults to api_error",
			input: &schemas.BifrostError{
				Error: &schemas.ErrorField{
					Type:    nil,
					Message: "connection failed",
				},
			},
			expectedType:    "api_error",
			expectedMessage: "connection failed",
		},
		{
			name: "empty string Type defaults to api_error",
			input: &schemas.BifrostError{
				Error: &schemas.ErrorField{
					Type:    strPtr(""),
					Message: "rate limited",
				},
			},
			expectedType:    "api_error",
			expectedMessage: "rate limited",
		},
		{
			name: "valid Type is preserved",
			input: &schemas.BifrostError{
				Error: &schemas.ErrorField{
					Type:    strPtr("rate_limit_error"),
					Message: "rate limited",
				},
			},
			expectedType:    "rate_limit_error",
			expectedMessage: "rate limited",
		},
		{
			name: "internal Type is preserved",
			input: &schemas.BifrostError{
				Error: &schemas.ErrorField{
					Type:    strPtr("request_cancelled"),
					Message: "cancelled",
				},
			},
			expectedType:    "request_cancelled",
			expectedMessage: "cancelled",
		},
		{
			name: "nil Error field defaults to api_error",
			input: &schemas.BifrostError{
				Error: nil,
			},
			expectedType:    "api_error",
			expectedMessage: "", // no nested error, so there is no message to copy
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result := ToAnthropicChatCompletionError(tt.input)

			if tt.expectNil {
				if result != nil {
					t.Fatalf("expected nil, got %+v", result)
				}
				return
			}

			if result == nil {
				t.Fatal("expected non-nil result")
			}
			if result.Type != "error" {
				t.Errorf("expected top-level Type %q, got %q", "error", result.Type)
			}
			if result.Error.Type != tt.expectedType {
				t.Errorf("expected error Type %q, got %q", tt.expectedType, result.Error.Type)
			}
			// The message must be carried over verbatim from the nested error.
			if result.Error.Message != tt.expectedMessage {
				t.Errorf("expected error Message %q, got %q", tt.expectedMessage, result.Error.Message)
			}
		})
	}
}

View File

@@ -0,0 +1,71 @@
package anthropic
import (
"time"
"github.com/maximhq/bifrost/core/schemas"
)
// ToAnthropicFileUploadResponse converts a Bifrost file upload response to Anthropic format.
//
// ID, object type, filename and byte size are copied across; CreatedAt is
// rendered via formatAnthropicFileTimestamp and MimeType is left empty.
// A nil resp yields nil instead of panicking.
func ToAnthropicFileUploadResponse(resp *schemas.BifrostFileUploadResponse) *AnthropicFileResponse {
	if resp == nil {
		return nil
	}
	return &AnthropicFileResponse{
		ID:        resp.ID,
		Type:      resp.Object,
		Filename:  resp.Filename,
		MimeType:  "",
		SizeBytes: resp.Bytes,
		CreatedAt: formatAnthropicFileTimestamp(resp.CreatedAt),
	}
}
// ToAnthropicFileListResponse converts a Bifrost file list response to Anthropic format.
//
// Each entry in resp.Data becomes one AnthropicFileResponse (MimeType left
// empty, CreatedAt rendered via formatAnthropicFileTimestamp) and HasMore is
// copied through. A nil resp yields nil instead of panicking.
func ToAnthropicFileListResponse(resp *schemas.BifrostFileListResponse) *AnthropicFileListResponse {
	if resp == nil {
		return nil
	}

	data := make([]AnthropicFileResponse, len(resp.Data))
	for i := range resp.Data {
		file := &resp.Data[i] // index access avoids copying each element
		data[i] = AnthropicFileResponse{
			ID:        file.ID,
			Type:      file.Object,
			Filename:  file.Filename,
			MimeType:  "",
			SizeBytes: file.Bytes,
			CreatedAt: formatAnthropicFileTimestamp(file.CreatedAt),
		}
	}

	return &AnthropicFileListResponse{
		Data:    data,
		HasMore: resp.HasMore,
	}
}
// ToAnthropicFileRetrieveResponse converts a Bifrost file retrieve response to Anthropic format.
//
// ID, object type, filename and byte size are copied across; CreatedAt is
// rendered via formatAnthropicFileTimestamp. A nil resp yields nil instead of
// panicking.
func ToAnthropicFileRetrieveResponse(resp *schemas.BifrostFileRetrieveResponse) *AnthropicFileResponse {
	if resp == nil {
		return nil
	}
	return &AnthropicFileResponse{
		ID:        resp.ID,
		Type:      resp.Object,
		Filename:  resp.Filename,
		MimeType:  "", // Not supported in Bifrost responses
		SizeBytes: resp.Bytes,
		CreatedAt: formatAnthropicFileTimestamp(resp.CreatedAt),
	}
}
// ToAnthropicFileDeleteResponse converts a Bifrost file delete response to Anthropic format.
//
// The resulting Type is "file_deleted" when resp.Deleted is true and "file"
// otherwise. A nil resp yields nil instead of panicking.
func ToAnthropicFileDeleteResponse(resp *schemas.BifrostFileDeleteResponse) *AnthropicFileDeleteResponse {
	if resp == nil {
		return nil
	}

	respType := "file"
	if resp.Deleted {
		respType = "file_deleted"
	}

	return &AnthropicFileDeleteResponse{
		ID:   resp.ID,
		Type: respType,
	}
}
// formatAnthropicFileTimestamp renders a Unix timestamp (seconds) as an RFC 3339
// string in UTC. A zero timestamp is treated as "unset" and yields the empty
// string.
func formatAnthropicFileTimestamp(unixTime int64) string {
	if unixTime != 0 {
		return time.Unix(unixTime, 0).In(time.UTC).Format(time.RFC3339)
	}
	return ""
}

View File

@@ -0,0 +1,108 @@
package anthropic
import (
"strings"
"time"
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
"github.com/maximhq/bifrost/core/schemas"
)
// ToBifrostListModelsResponse converts an Anthropic list-models response into
// the Bifrost representation for providerKey.
//
// Pagination fields are copied across, and NextPageToken is set to the last ID
// when more results are available. Each model ID is passed through a
// providerUtils.ListModelsPipeline built from the allow list, block list,
// aliases and the unfiltered flag; every result it returns becomes one entry,
// de-duplicated on the lower-cased resolved ID. Entries returned by the
// pipeline's BackfillModels are appended at the end. A nil receiver yields nil.
func (response *AnthropicListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, allowedModels schemas.WhiteList, blacklistedModels schemas.BlackList, aliases map[string]string, unfiltered bool) *schemas.BifrostListModelsResponse {
	if response == nil {
		return nil
	}

	out := &schemas.BifrostListModelsResponse{
		Data:    make([]schemas.Model, 0, len(response.Data)),
		FirstID: response.FirstID,
		LastID:  response.LastID,
		HasMore: schemas.Ptr(response.HasMore),
	}

	// Map Anthropic's cursor-based pagination to Bifrost's token-based pagination:
	// the last ID doubles as the token for the next page.
	if response.HasMore && response.LastID != nil {
		out.NextPageToken = *response.LastID
	}

	filter := &providerUtils.ListModelsPipeline{
		AllowedModels:     allowedModels,
		BlacklistedModels: blacklistedModels,
		Aliases:           aliases,
		Unfiltered:        unfiltered,
		ProviderKey:       providerKey,
		MatchFns:          providerUtils.DefaultMatchFns(),
	}
	if filter.ShouldEarlyExit() {
		return out
	}

	idPrefix := string(providerKey) + "/"
	seen := make(map[string]bool)
	for i := range response.Data {
		src := &response.Data[i]
		for _, match := range filter.FilterModel(src.ID) {
			// De-duplicate case-insensitively on the resolved model ID.
			key := strings.ToLower(match.ResolvedID)
			if seen[key] {
				continue
			}
			seen[key] = true

			item := schemas.Model{
				ID:              idPrefix + match.ResolvedID,
				Name:            schemas.Ptr(src.DisplayName),
				Created:         schemas.Ptr(src.CreatedAt.Unix()),
				MaxInputTokens:  src.MaxInputTokens,
				MaxOutputTokens: src.MaxTokens,
				ProviderExtra:   src.Capabilities,
			}
			if match.AliasValue != "" {
				item.Alias = schemas.Ptr(match.AliasValue)
			}
			out.Data = append(out.Data, item)
		}
	}

	out.Data = append(out.Data, filter.BackfillModels(seen)...)
	return out
}
// ToAnthropicListModelsResponse converts a Bifrost list-models response into
// the Anthropic representation.
//
// Pagination fields are copied across (HasMore only when set). Each model ID is
// run through schemas.ParseModelString and the model part is used as the
// Anthropic ID; display name and creation time are filled in only when present
// on the Bifrost model. A nil input yields nil.
func ToAnthropicListModelsResponse(response *schemas.BifrostListModelsResponse) *AnthropicListModelsResponse {
	if response == nil {
		return nil
	}

	out := &AnthropicListModelsResponse{
		Data:    make([]AnthropicModel, 0, len(response.Data)),
		FirstID: response.FirstID,
		LastID:  response.LastID,
	}
	if response.HasMore != nil {
		out.HasMore = *response.HasMore
	}

	for i := range response.Data {
		src := &response.Data[i]
		_, bareID := schemas.ParseModelString(src.ID, schemas.Anthropic)

		converted := AnthropicModel{
			ID:             bareID,
			Type:           "model",
			MaxInputTokens: src.MaxInputTokens,
			MaxTokens:      src.MaxOutputTokens,
			Capabilities:   src.ProviderExtra,
		}
		if src.Name != nil {
			converted.DisplayName = *src.Name
		}
		if src.Created != nil {
			converted.CreatedAt = time.Unix(*src.Created, 0)
		}

		out.Data = append(out.Data, converted)
	}

	return out
}

View File

@@ -0,0 +1,52 @@
package anthropic
import (
"testing"
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
schemas "github.com/maximhq/bifrost/core/schemas"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestPayloadOrdering_AnthropicMessageRequest(t *testing.T) {
req := &AnthropicMessageRequest{
Model: "claude-sonnet-4-20250514",
MaxTokens: 1024,
Messages: []AnthropicMessage{
{
Role: "user",
Content: AnthropicContent{ContentStr: schemas.Ptr("hello")},
},
},
Temperature: schemas.Ptr(0.7),
Stream: schemas.Ptr(true),
Tools: []AnthropicTool{
{
Name: "get_weather",
Description: schemas.Ptr("Get weather"),
InputSchema: &schemas.ToolFunctionParameters{
Type: "object",
Properties: schemas.NewOrderedMapFromPairs(
schemas.KV("location", map[string]interface{}{"type": "string"}),
),
Required: []string{"location"},
},
},
},
}
result, err := providerUtils.MarshalSorted(req)
require.NoError(t, err)
golden := `{"model":"claude-sonnet-4-20250514","max_tokens":1024,"messages":[{"role":"user","content":"hello"}],"temperature":0.7,"stream":true,"tools":[{"name":"get_weather","description":"Get weather","input_schema":{"type":"object","properties":{"location":{"type":"string"}},"required":["location"]}}]}`
assert.Equal(t, golden, string(result), "payload field ordering changed — if intentional, update the golden string")
// Determinism: 100 iterations must produce identical bytes
for i := 0; i < 100; i++ {
iter, err := providerUtils.MarshalSorted(req)
require.NoError(t, err)
assert.Equal(t, string(result), string(iter), "non-deterministic marshal output on iteration %d", i)
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,137 @@
package anthropic
import (
"fmt"
"strings"
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
"github.com/maximhq/bifrost/core/schemas"
)
// ToAnthropicTextCompletionRequest converts a Bifrost text completion request to Anthropic format.
//
// The prompt is taken from Input.PromptStr, or else from Input.PromptArray
// joined with blank lines, and wrapped in the "\n\nHuman: …\n\nAssistant:"
// template. A nil Input is treated as an empty prompt. MaxTokensToSample starts
// from providerUtils.GetMaxOutputTokensOrDefault and is overridden by
// Params.MaxTokens when set.
//
// Params.ExtraParams is copied rather than aliased: a "top_k" entry that
// converts to an int is moved into TopK and left out of the copy, while the
// caller's map is not modified. This keeps the input request reusable if it is
// converted more than once.
//
// A nil bifrostReq yields nil.
func ToAnthropicTextCompletionRequest(bifrostReq *schemas.BifrostTextCompletionRequest) *AnthropicTextRequest {
	if bifrostReq == nil {
		return nil
	}

	prompt := ""
	if input := bifrostReq.Input; input != nil {
		if input.PromptStr != nil {
			prompt = *input.PromptStr
		} else if len(input.PromptArray) > 0 {
			prompt = strings.Join(input.PromptArray, "\n\n")
		}
	}

	anthropicReq := &AnthropicTextRequest{
		Model:             bifrostReq.Model,
		Prompt:            fmt.Sprintf("\n\nHuman: %s\n\nAssistant:", prompt),
		MaxTokensToSample: providerUtils.GetMaxOutputTokensOrDefault(bifrostReq.Model, AnthropicDefaultMaxTokens),
	}

	// Convert parameters
	params := bifrostReq.Params
	if params == nil {
		return anthropicReq
	}

	if params.MaxTokens != nil {
		anthropicReq.MaxTokensToSample = *params.MaxTokens
	}
	anthropicReq.Temperature = params.Temperature
	anthropicReq.TopP = params.TopP
	anthropicReq.StopSequences = params.Stop

	if extra := params.ExtraParams; extra != nil {
		topK, hasTopK := schemas.SafeExtractIntPointer(extra["top_k"])
		if hasTopK {
			anthropicReq.TopK = topK
		}

		// Build a private copy so that dropping "top_k" does not mutate the
		// map owned by the incoming request.
		cloned := make(map[string]interface{}, len(extra))
		for key, value := range extra {
			if hasTopK && key == "top_k" {
				continue
			}
			cloned[key] = value
		}
		anthropicReq.ExtraParams = cloned
	}

	return anthropicReq
}
// ToBifrostTextCompletionRequest converts an Anthropic text request into the
// Bifrost text completion request shape.
//
// Provider and model come from schemas.ParseModelString applied to req.Model,
// with the default provider obtained from providerUtils.CheckAndSetDefaultProvider.
// The prompt and max-token fields of the result point at the receiver's own
// fields. TopK, when set, is carried in ExtraParams under "top_k". A nil
// receiver yields nil.
func (req *AnthropicTextRequest) ToBifrostTextCompletionRequest(ctx *schemas.BifrostContext) *schemas.BifrostTextCompletionRequest {
	if req == nil {
		return nil
	}

	defaultProvider := providerUtils.CheckAndSetDefaultProvider(ctx, schemas.Anthropic)
	provider, model := schemas.ParseModelString(req.Model, defaultProvider)

	params := &schemas.TextCompletionParameters{
		MaxTokens:   &req.MaxTokensToSample,
		Temperature: req.Temperature,
		TopP:        req.TopP,
		Stop:        req.StopSequences,
	}
	// top_k has no dedicated Bifrost field, so it travels as an extra param.
	if req.TopK != nil {
		params.ExtraParams = map[string]interface{}{
			"top_k": *req.TopK,
		}
	}

	return &schemas.BifrostTextCompletionRequest{
		Provider:  provider,
		Model:     model,
		Input:     &schemas.TextCompletionInput{PromptStr: &req.Prompt},
		Params:    params,
		Fallbacks: schemas.ParseFallbacks(req.Fallbacks),
	}
}
// ToBifrostTextCompletionResponse converts an Anthropic text response into the
// Bifrost text completion response shape.
//
// The result has a single choice at index 0 whose text points at the
// receiver's Completion field, and usage whose total is the sum of input and
// output tokens. A nil receiver yields nil.
func (response *AnthropicTextResponse) ToBifrostTextCompletionResponse() *schemas.BifrostTextCompletionResponse {
	if response == nil {
		return nil
	}

	promptTokens := response.Usage.InputTokens
	completionTokens := response.Usage.OutputTokens

	onlyChoice := schemas.BifrostResponseChoice{
		Index: 0,
		TextCompletionResponseChoice: &schemas.TextCompletionResponseChoice{
			Text: &response.Completion,
		},
	}

	return &schemas.BifrostTextCompletionResponse{
		ID:      response.ID,
		Object:  "text_completion",
		Model:   response.Model,
		Choices: []schemas.BifrostResponseChoice{onlyChoice},
		Usage: &schemas.BifrostLLMUsage{
			PromptTokens:     promptTokens,
			CompletionTokens: completionTokens,
			TotalTokens:      promptTokens + completionTokens,
		},
	}
}
// ToAnthropicTextCompletionResponse converts a Bifrost text completion response
// into the Anthropic text completion shape.
//
// Only the first choice is considered for the completion text; usage is copied
// when present. A nil input yields nil.
func ToAnthropicTextCompletionResponse(bifrostResp *schemas.BifrostTextCompletionResponse) *AnthropicTextResponse {
	if bifrostResp == nil {
		return nil
	}

	out := &AnthropicTextResponse{
		ID:    bifrostResp.ID,
		Type:  "completion",
		Model: bifrostResp.Model,
	}

	// Usage is optional on the Bifrost side.
	if usage := bifrostResp.Usage; usage != nil {
		out.Usage.InputTokens = usage.PromptTokens
		out.Usage.OutputTokens = usage.CompletionTokens
	}

	if len(bifrostResp.Choices) == 0 {
		return out
	}

	// The completion text comes from the first choice, when it has text.
	if first := bifrostResp.Choices[0].TextCompletionResponseChoice; first != nil && first.Text != nil {
		out.Completion = *first.Text
	}
	return out
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,138 @@
package anthropic
import (
"testing"
"github.com/maximhq/bifrost/core/schemas"
)
// TestValidateChatToolsForProvider pins how ValidateChatToolsForProvider splits
// a tool list into kept and dropped tools per provider. The table expects
// function tools to be kept everywhere, specific server tools to be dropped
// for Bedrock and Vertex, and unknown providers or unknown tool types to keep
// everything.
func TestValidateChatToolsForProvider(t *testing.T) {
	functionTool := schemas.ChatTool{
		Type:     schemas.ChatToolTypeFunction,
		Function: &schemas.ChatToolFunction{Name: "get_weather"},
	}
	// server builds a non-function tool with the given type string and name.
	server := func(toolType, toolName string) schemas.ChatTool {
		return schemas.ChatTool{Type: schemas.ChatToolType(toolType), Name: toolName}
	}

	table := []struct {
		desc          string
		target        schemas.ModelProvider
		tools         []schemas.ChatTool
		keepCount     int
		expectDropped []string
		rationale     string
	}{
		{
			desc:      "function tools always survive on any provider",
			target:    schemas.Bedrock,
			tools:     []schemas.ChatTool{functionTool, functionTool},
			keepCount: 2,
		},
		{
			desc:          "bedrock drops web_search",
			target:        schemas.Bedrock,
			tools:         []schemas.ChatTool{server("web_search_20260209", "web_search")},
			keepCount:     0,
			expectDropped: []string{"web_search_20260209"},
			rationale:     "Bedrock has WebSearch=false per Table 20 (AWS user guide beta-header list + Anthropic overview)",
		},
		{
			desc:   "bedrock drops web_fetch + code_execution + mcp_toolset",
			target: schemas.Bedrock,
			tools: []schemas.ChatTool{
				server("web_fetch_20260309", "web_fetch"),
				server("code_execution_20250825", "code_execution"),
				server("mcp_toolset", "notion"),
			},
			keepCount:     0,
			expectDropped: []string{"web_fetch_20260309", "code_execution_20250825", "mcp_toolset"},
		},
		{
			desc:   "bedrock keeps computer/bash/memory/text_editor/tool_search",
			target: schemas.Bedrock,
			tools: []schemas.ChatTool{
				server("computer_20251124", "computer"),
				server("bash_20250124", "bash"),
				server("memory_20250818", "memory"),
				server("text_editor_20250728", "str_replace_based_edit_tool"),
				server("tool_search_tool_bm25", "tool_search_tool_bm25"),
			},
			keepCount: 5,
		},
		{
			desc:   "bedrock partial drop mixes function + server tools",
			target: schemas.Bedrock,
			tools: []schemas.ChatTool{
				functionTool,
				server("web_search_20260209", "web_search"),
				server("bash_20250124", "bash"),
			},
			keepCount:     2, // function tool + bash
			expectDropped: []string{"web_search_20260209"},
		},
		{
			desc:          "vertex drops web_fetch",
			target:        schemas.Vertex,
			tools:         []schemas.ChatTool{server("web_fetch_20260309", "web_fetch")},
			keepCount:     0,
			expectDropped: []string{"web_fetch_20260309"},
			rationale:     "Vertex has WebFetch=false per Table 20",
		},
		{
			desc:          "vertex drops mcp_toolset",
			target:        schemas.Vertex,
			tools:         []schemas.ChatTool{server("mcp_toolset", "notion")},
			keepCount:     0,
			expectDropped: []string{"mcp_toolset"},
			rationale:     "Vertex has MCP=false per MCP-excl (explicit exclusion in Anthropic docs)",
		},
		{
			desc:   "anthropic keeps everything",
			target: schemas.Anthropic,
			tools: []schemas.ChatTool{
				server("web_search_20260209", "web_search"),
				server("web_fetch_20260309", "web_fetch"),
				server("code_execution_20250825", "code_execution"),
				server("mcp_toolset", "x"),
				server("computer_20251124", "computer"),
			},
			keepCount: 5,
		},
		{
			desc:      "unknown provider keeps everything (forward-compat)",
			target:    schemas.ModelProvider("custom-new-provider"),
			tools:     []schemas.ChatTool{server("web_search_20260209", "web_search")},
			keepCount: 1,
		},
		{
			desc:      "unknown tool type on known provider is kept (forward-compat)",
			target:    schemas.Bedrock,
			tools:     []schemas.ChatTool{server("future_tool_20270101", "future")},
			keepCount: 1,
		},
	}

	for _, row := range table {
		t.Run(row.desc, func(t *testing.T) {
			kept, dropped := ValidateChatToolsForProvider(row.tools, row.target)

			if got := len(kept); got != row.keepCount {
				t.Errorf("keep count: got %d, want %d (%s)", got, row.keepCount, row.rationale)
			}
			if len(dropped) != len(row.expectDropped) {
				t.Errorf("dropped count: got %v, want %v", dropped, row.expectDropped)
			}

			// Compare element-wise over the overlap only; a length mismatch
			// has already been reported above.
			for i := 0; i < len(row.expectDropped) && i < len(dropped); i++ {
				if dropped[i] != row.expectDropped[i] {
					t.Errorf("dropped[%d]: got %q, want %q", i, dropped[i], row.expectDropped[i])
				}
			}
		})
	}
}

View File

@@ -0,0 +1,270 @@
package anthropic
import (
"encoding/json"
"testing"
"github.com/maximhq/bifrost/core/schemas"
)
// TestWebSearch_OutputItemAdded_StoresID sends an output_item.added event for a
// WebSearch function call and expects a content_block_start event with "{}" as
// its input, plus the item ID recorded in the per-request stream state.
func TestWebSearch_OutputItemAdded_StoresID(t *testing.T) {
	t.Parallel()

	const itemID = "toolu_ws_storesid_test"

	ctx, cancel := schemas.NewBifrostContextWithCancel(nil)
	defer cancel()

	added := &schemas.BifrostResponsesStreamResponse{
		Type:        schemas.ResponsesStreamResponseTypeOutputItemAdded,
		OutputIndex: schemas.Ptr(0),
		Item: &schemas.ResponsesMessage{
			ID:   schemas.Ptr(itemID),
			Type: schemas.Ptr(schemas.ResponsesMessageTypeFunctionCall),
			ResponsesToolMessage: &schemas.ResponsesToolMessage{
				CallID:    schemas.Ptr(itemID),
				Name:      schemas.Ptr("WebSearch"),
				Arguments: schemas.Ptr(""),
			},
		},
	}

	events := ToAnthropicResponsesStreamResponse(ctx, added)
	if len(events) == 0 {
		t.Fatal("expected at least one event")
	}

	// The first event must open a content block whose input is an empty object.
	first := events[0]
	if first.Type != AnthropicStreamEventTypeContentBlockStart {
		t.Errorf("event[0].Type = %v, want content_block_start", first.Type)
	}
	if first.ContentBlock == nil || first.ContentBlock.Input == nil {
		t.Fatal("expected ContentBlock with Input")
	}
	if string(first.ContentBlock.Input) != "{}" {
		t.Errorf("ContentBlock.Input = %s, want {}", first.ContentBlock.Input)
	}

	// The converter should have recorded the item ID in the state tied to ctx.
	if tracked := getOrCreateAnthropicToResponsesStreamState(ctx).webSearchItemIDs; !tracked[itemID] {
		t.Error("expected item ID to be stored in per-request stream state after output_item.added")
	}
}
// TestWebSearch_FunctionCallArgumentsDelta_Skipped marks an item ID as a tracked
// WebSearch call in the per-request state and then sends an argument delta for
// it, expecting the converter to emit no events.
func TestWebSearch_FunctionCallArgumentsDelta_Skipped(t *testing.T) {
	t.Parallel()

	const itemID = "toolu_ws_skip_test"

	ctx, cancel := schemas.NewBifrostContextWithCancel(nil)
	defer cancel()

	// Seed the state as though output_item.added had already been processed.
	getOrCreateAnthropicToResponsesStreamState(ctx).webSearchItemIDs = map[string]bool{itemID: true}

	fragment := `{"query": "world news"`
	delta := &schemas.BifrostResponsesStreamResponse{
		Type:        schemas.ResponsesStreamResponseTypeFunctionCallArgumentsDelta,
		OutputIndex: schemas.Ptr(0),
		ItemID:      schemas.Ptr(itemID),
		Delta:       &fragment,
	}

	if n := len(ToAnthropicResponsesStreamResponse(ctx, delta)); n != 0 {
		t.Errorf("expected deltas to be skipped (0 events), got %d", n)
	}
}
// TestWebSearch_OutputItemDone_GeneratesSyntheticDeltas sends output_item.done
// for a tracked WebSearch item and expects one or more input_json_delta events
// followed by a content_block_stop. The concatenated delta fragments must parse
// as JSON containing the original query, and the item ID must be removed from
// the per-request state afterwards.
func TestWebSearch_OutputItemDone_GeneratesSyntheticDeltas(t *testing.T) {
	t.Parallel()

	const itemID = "toolu_ws_synth_test"
	const wantQuery = "world news today"

	ctx, cancel := schemas.NewBifrostContextWithCancel(nil)
	defer cancel()

	// Seed the state as though output_item.added had already been processed.
	state := getOrCreateAnthropicToResponsesStreamState(ctx)
	state.webSearchItemIDs = map[string]bool{itemID: true}

	args := `{"query":"world news today"}`
	done := &schemas.BifrostResponsesStreamResponse{
		Type:        schemas.ResponsesStreamResponseTypeOutputItemDone,
		OutputIndex: schemas.Ptr(1),
		Item: &schemas.ResponsesMessage{
			ID:   schemas.Ptr(itemID),
			Type: schemas.Ptr(schemas.ResponsesMessageTypeFunctionCall),
			ResponsesToolMessage: &schemas.ResponsesToolMessage{
				CallID:    schemas.Ptr(itemID),
				Name:      schemas.Ptr("WebSearch"),
				Arguments: &args,
			},
		},
	}

	events := ToAnthropicResponsesStreamResponse(ctx, done)

	// At minimum: one delta and the closing stop.
	if len(events) < 2 {
		t.Fatalf("expected at least 2 events (deltas + stop), got %d", len(events))
	}
	deltas, stop := events[:len(events)-1], events[len(events)-1]

	// Every event before the last must be an input_json delta; collect the
	// fragments along the way so the full payload can be checked below.
	var rebuilt string
	for idx, ev := range deltas {
		if ev.Delta != nil && ev.Delta.PartialJSON != nil {
			rebuilt += *ev.Delta.PartialJSON
		}
		if ev.Type != AnthropicStreamEventTypeContentBlockDelta {
			t.Errorf("event[%d].Type = %v, want content_block_delta", idx, ev.Type)
			continue
		}
		if ev.Delta == nil || ev.Delta.Type != AnthropicStreamDeltaTypeInputJSON {
			t.Errorf("event[%d].Delta.Type = %v, want input_json", idx, ev.Delta)
		}
	}

	// The final event closes the content block.
	if stop.Type != AnthropicStreamEventTypeContentBlockStop {
		t.Errorf("last event.Type = %v, want content_block_stop", stop.Type)
	}

	// The concatenated fragments must form valid JSON holding the query.
	var got map[string]interface{}
	if err := json.Unmarshal([]byte(rebuilt), &got); err != nil {
		t.Fatalf("accumulated JSON invalid: %v — got %q", err, rebuilt)
	}
	if got["query"] != wantQuery {
		t.Errorf("query = %v, want %q", got["query"], wantQuery)
	}

	// The converter should have dropped the ID from the tracked set.
	if state.webSearchItemIDs[itemID] {
		t.Error("expected item ID to be removed from per-request stream state after output_item.done")
	}
}
// TestWebSearch_FullFlow_AnyUserAgent is the regression test for the original bug.
// It replays the complete upstream streaming sequence
//
//	output_item.added → FunctionCallArgumentsDelta (×N) → output_item.done
//
// through ToAnthropicResponsesStreamResponse and checks that the client-facing
// Anthropic stream is content_block_start, a run of input_json_delta events
// that reassemble into the query JSON, and finally content_block_stop. The
// context carries no user agent, so the outcome cannot depend on one.
func TestWebSearch_FullFlow_AnyUserAgent(t *testing.T) {
	t.Parallel()

	const itemID = "toolu_ws_fullflow_test"

	ctx, cancel := schemas.NewBifrostContextWithCancel(nil)
	defer cancel()

	// webSearchItem builds the WebSearch function-call item shared by the
	// output_item.added and output_item.done events; only the arguments differ.
	webSearchItem := func(args *string) *schemas.ResponsesMessage {
		return &schemas.ResponsesMessage{
			ID:   schemas.Ptr(itemID),
			Type: schemas.Ptr(schemas.ResponsesMessageTypeFunctionCall),
			ResponsesToolMessage: &schemas.ResponsesToolMessage{
				CallID:    schemas.Ptr(itemID),
				Name:      schemas.Ptr("WebSearch"),
				Arguments: args,
			},
		}
	}

	// convert feeds one upstream event through the converter and records
	// whatever Anthropic events come out, in order.
	var allEvents []*AnthropicStreamEvent
	convert := func(resp *schemas.BifrostResponsesStreamResponse) {
		allEvents = append(allEvents, ToAnthropicResponsesStreamResponse(ctx, resp)...)
	}

	// Step 1: output_item.added with empty arguments.
	convert(&schemas.BifrostResponsesStreamResponse{
		Type:        schemas.ResponsesStreamResponseTypeOutputItemAdded,
		OutputIndex: schemas.Ptr(0),
		Item:        webSearchItem(schemas.Ptr("")),
	})

	// Step 2: FunctionCallArgumentsDelta events (should be skipped).
	for _, partial := range []string{`{"query": "`, `latest AI`, `news"}`} {
		chunk := partial // fresh variable per iteration so &chunk is not shared
		convert(&schemas.BifrostResponsesStreamResponse{
			Type:        schemas.ResponsesStreamResponseTypeFunctionCallArgumentsDelta,
			OutputIndex: schemas.Ptr(0),
			ItemID:      schemas.Ptr(itemID),
			Delta:       &chunk,
		})
	}

	// Step 3: output_item.done carrying the full accumulated arguments.
	fullArgs := `{"query":"latest AI news"}`
	convert(&schemas.BifrostResponsesStreamResponse{
		Type:        schemas.ResponsesStreamResponseTypeOutputItemDone,
		OutputIndex: schemas.Ptr(0),
		Item:        webSearchItem(&fullArgs),
	})

	// Expected shape:
	//   [0]      content_block_start (input:{})
	//   [1..N-1] input_json_delta events
	//   [N]      content_block_stop
	total := len(allEvents)
	if total < 3 {
		t.Fatalf("expected at least 3 events, got %d: %v", total, allEvents)
	}
	first, last, middle := allEvents[0], allEvents[total-1], allEvents[1:total-1]

	if first.Type != AnthropicStreamEventTypeContentBlockStart {
		t.Errorf("allEvents[0].Type = %v, want content_block_start", first.Type)
	}
	if last.Type != AnthropicStreamEventTypeContentBlockStop {
		t.Errorf("last event.Type = %v, want content_block_stop", last.Type)
	}

	// Every middle event must be an input_json_delta; concatenate their
	// partial JSON while walking them.
	var accumulated string
	for offset, ev := range middle {
		idx := offset + 1
		if ev.Type != AnthropicStreamEventTypeContentBlockDelta {
			t.Errorf("allEvents[%d].Type = %v, want content_block_delta", idx, ev.Type)
		}
		if ev.Delta == nil || ev.Delta.Type != AnthropicStreamDeltaTypeInputJSON {
			t.Errorf("allEvents[%d].Delta.Type = %v, want input_json", idx, ev.Delta)
		}
		if ev.Delta != nil && ev.Delta.PartialJSON != nil {
			accumulated += *ev.Delta.PartialJSON
		}
	}

	// The reassembled deltas must decode to the original query.
	var got map[string]interface{}
	if err := json.Unmarshal([]byte(accumulated), &got); err != nil {
		t.Fatalf("reconstructed JSON is invalid: %v — got %q", err, accumulated)
	}
	if got["query"] != "latest AI news" {
		t.Errorf("reconstructed query = %v, want %q", got["query"], "latest AI news")
	}
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,198 @@
package azure_test
import (
"strings"
"testing"
"github.com/maximhq/bifrost/core/providers/openai"
"github.com/maximhq/bifrost/core/schemas"
)
// TestAzure_OpenAIModel_CachingDeterminism checks the Azure→OpenAI delegation
// path (openai.ToOpenAIChatRequest) for prompt-caching determinism: two tool
// schemas that declare the same properties but order the keys inside each
// property definition differently must marshal to byte-identical JSON, while
// the top-level property order (reasoning, answer) stays as the caller wrote it.
func TestAzure_OpenAIModel_CachingDeterminism(t *testing.T) {
	// stringProp builds one property definition. typeFirst selects whether
	// "type" or "description" is inserted first.
	stringProp := func(description string, typeFirst bool) *schemas.OrderedMap {
		typeKV := schemas.KV("type", "string")
		descKV := schemas.KV("description", description)
		if typeFirst {
			return schemas.NewOrderedMapFromPairs(typeKV, descKV)
		}
		return schemas.NewOrderedMapFromPairs(descKV, typeKV)
	}

	// schemaWith builds the two-property schema with the chosen inner key order.
	schemaWith := func(typeFirst bool) *schemas.OrderedMap {
		return schemas.NewOrderedMapFromPairs(
			schemas.KV("reasoning", stringProp("Step by step", typeFirst)),
			schemas.KV("answer", stringProp("Final answer", typeFirst)),
		)
	}

	// newRequest wraps a property map in a minimal Azure chat request with a
	// single function tool.
	newRequest := func(props *schemas.OrderedMap) *schemas.BifrostChatRequest {
		tool := schemas.ChatTool{
			Type: schemas.ChatToolTypeFunction,
			Function: &schemas.ChatToolFunction{
				Name: "test",
				Parameters: &schemas.ToolFunctionParameters{
					Type:       "object",
					Properties: props,
				},
			},
		}
		return &schemas.BifrostChatRequest{
			Provider: schemas.Azure,
			Model:    "gpt-4o",
			Input:    []schemas.ChatMessage{{Role: schemas.ChatMessageRoleUser}},
			Params:   &schemas.ChatParameters{Tools: []schemas.ChatTool{tool}},
		}
	}

	// Azure delegates OpenAI models to openai.ToOpenAIChatRequest().
	ctx, cancel := schemas.NewBifrostContextWithCancel(nil)
	defer cancel()

	// Version A: type before description. Version B: description before type.
	resultA := openai.ToOpenAIChatRequest(ctx, newRequest(schemaWith(true)))
	resultB := openai.ToOpenAIChatRequest(ctx, newRequest(schemaWith(false)))

	paramsA := resultA.ChatParameters.Tools[0].Function.Parameters
	paramsB := resultB.ChatParameters.Tools[0].Function.Parameters

	jsonA, err := schemas.Marshal(paramsA)
	if err != nil {
		t.Fatalf("failed to marshal params A: %v", err)
	}
	jsonB, err := schemas.Marshal(paramsB)
	if err != nil {
		t.Fatalf("failed to marshal params B: %v", err)
	}

	// Caching: byte-identical JSON.
	if string(jsonA) != string(jsonB) {
		t.Errorf("caching broken via Azure→OpenAI path: same schema produced different JSON\nA: %s\nB: %s", jsonA, jsonB)
	}

	// CoT: property order preserved.
	keys := paramsA.Properties.Keys()
	inOrder := len(keys) == 2 && keys[0] == "reasoning" && keys[1] == "answer"
	if !inOrder {
		t.Errorf("expected property order [reasoning, answer], got %v", keys)
	}
}
// TestAzure_OpenAIModel_PreservesPropertyOrder checks that a tool schema sent
// through the Azure→OpenAI delegation path (openai.ToOpenAIChatRequest) keeps
// its properties in the order the caller declared them.
func TestAzure_OpenAIModel_PreservesPropertyOrder(t *testing.T) {
	// typed builds a property definition containing only a "type" key.
	typed := func(kind string) *schemas.OrderedMap {
		return schemas.NewOrderedMapFromPairs(schemas.KV("type", kind))
	}

	properties := schemas.NewOrderedMapFromPairs(
		schemas.KV("chain_of_thought", typed("string")),
		schemas.KV("answer", typed("string")),
		schemas.KV("citations", typed("array")),
	)

	tool := schemas.ChatTool{
		Type: schemas.ChatToolTypeFunction,
		Function: &schemas.ChatToolFunction{
			Name: "AnswerResponseModel",
			Parameters: &schemas.ToolFunctionParameters{
				Type:       "object",
				Properties: properties,
			},
		},
	}
	bifrostReq := &schemas.BifrostChatRequest{
		Provider: schemas.Azure,
		Model:    "gpt-4o",
		Input:    []schemas.ChatMessage{{Role: schemas.ChatMessageRoleUser}},
		Params:   &schemas.ChatParameters{Tools: []schemas.ChatTool{tool}},
	}

	ctx, cancel := schemas.NewBifrostContextWithCancel(nil)
	defer cancel()

	converted := openai.ToOpenAIChatRequest(ctx, bifrostReq)
	keys := converted.ChatParameters.Tools[0].Function.Parameters.Properties.Keys()

	inOrder := len(keys) == 3 &&
		keys[0] == "chain_of_thought" &&
		keys[1] == "answer" &&
		keys[2] == "citations"
	if !inOrder {
		t.Errorf("expected property order [chain_of_thought, answer, citations], got %v", keys)
	}
}
// TestAzure_ToolInputKeyOrderPreservation verifies that the Azure→OpenAI
// delegation path (openai.ToOpenAIChatRequest) preserves the original key
// ordering of tool call arguments, which prompt caching relies on. It sends
// two parallel tool calls whose argument objects use different key orders and
// checks each block independently.
func TestAzure_ToolInputKeyOrderPreservation(t *testing.T) {
	bifrostReq := &schemas.BifrostChatRequest{
		Provider: schemas.Azure,
		Model:    "gpt-4o",
		Input: []schemas.ChatMessage{
			{
				Role:    schemas.ChatMessageRoleUser,
				Content: &schemas.ChatMessageContent{ContentStr: schemas.Ptr("test")},
			},
			{
				Role: schemas.ChatMessageRoleAssistant,
				ChatAssistantMessage: &schemas.ChatAssistantMessage{
					ToolCalls: []schemas.ChatAssistantMessageToolCall{
						{
							Index: 0,
							Type:  schemas.Ptr("function"),
							ID:    schemas.Ptr("toolu_001"),
							Function: schemas.ChatAssistantMessageToolCallFunction{
								Name:      schemas.Ptr("bash"),
								Arguments: `{"description":"Find references quickly","timeout":30000,"command":"grep -r auth_injector ."}`,
							},
						},
						{
							Index: 1,
							Type:  schemas.Ptr("function"),
							ID:    schemas.Ptr("toolu_002"),
							Function: schemas.ChatAssistantMessageToolCallFunction{
								Name:      schemas.Ptr("bash"),
								Arguments: `{"command":"git diff main...HEAD --stat","description":"Show diff"}`,
							},
						},
					},
				},
			},
		},
	}

	ctx, cancel := schemas.NewBifrostContextWithCancel(nil)
	defer cancel()

	result := openai.ToOpenAIChatRequest(ctx, bifrostReq)
	if result == nil {
		t.Fatal("expected non-nil result")
	}

	// Collect tool call arguments from the assistant message(s).
	var argsList []string
	for _, msg := range result.Messages {
		if msg.OpenAIChatAssistantMessage == nil {
			continue
		}
		for _, tc := range msg.OpenAIChatAssistantMessage.ToolCalls {
			argsList = append(argsList, tc.Function.Arguments)
		}
	}
	if len(argsList) != 2 {
		t.Fatalf("expected 2 tool call arguments, got %d", len(argsList))
	}

	// assertKeyOrder fails unless every key is present in args (as a quoted
	// JSON key) and the keys occur in the given order. A missing key is
	// reported explicitly: strings.Index returns -1 for it, and a bare
	// "index(a) < index(b)" comparison would otherwise treat the missing key
	// as correctly ordered.
	assertKeyOrder := func(block int, args string, keys ...string) {
		t.Helper()
		prev := -1
		for _, key := range keys {
			pos := strings.Index(args, `"`+key+`"`)
			if pos < 0 {
				t.Errorf("block %d: key %q missing from tool call arguments: %s", block, key, args)
				return
			}
			if pos <= prev {
				t.Errorf("block %d: key order not preserved, expected %s in: %s", block, strings.Join(keys, " < "), args)
				return
			}
			prev = pos
		}
	}

	// The OpenAI path passes Arguments through as strings, so the key order
	// of each block must match what was sent.
	assertKeyOrder(0, argsList[0], "description", "timeout", "command")
	assertKeyOrder(1, argsList[1], "command", "description")
}

View File

@@ -0,0 +1,92 @@
package azure_test
import (
"os"
"strings"
"testing"
"github.com/maximhq/bifrost/core/internal/llmtests"
"github.com/maximhq/bifrost/core/schemas"
)
// TestAzure runs the comprehensive llmtests suite against the Azure provider.
// It is skipped when AZURE_API_KEY is unset or blank.
func TestAzure(t *testing.T) {
	t.Parallel()

	apiKey := strings.TrimSpace(os.Getenv("AZURE_API_KEY"))
	if apiKey == "" {
		t.Skip("Skipping Azure tests because AZURE_API_KEY is not set")
	}

	client, ctx, cancel, err := llmtests.SetupTest()
	if err != nil {
		t.Fatalf("Error initializing test setup: %v", err)
	}
	defer cancel()
	defer client.Shutdown()

	// Scenario switches for this provider; disabled entries are kept with
	// the reason they are off.
	scenarios := llmtests.TestScenarios{
		TextCompletion:               false, // Not supported
		SimpleChat:                   true,
		CompletionStream:             true,
		MultiTurnConversation:        true,
		ToolCalls:                    true,
		ToolCallsStreaming:           true,
		MultipleToolCalls:            true,
		MultipleToolCallsStreaming:   true,
		End2EndToolCalling:           true,
		AutomaticFunctionCall:        true,
		ImageURL:                     true,
		ImageBase64:                  true,
		MultipleImages:               true,
		CompleteEnd2End:              true,
		Embedding:                    true,
		ListModels:                   true,
		Reasoning:                    true,
		ChatAudio:                    false,
		Transcription:                false, // Disabled for azure because of 3 calls/minute quota
		TranscriptionStream:          false, // Not properly supported yet by Azure
		SpeechSynthesis:              false, // Disabled for azure because of 3 calls/minute quota
		SpeechSynthesisStream:        false, // Disabled for azure because of 3 calls/minute quota
		StructuredOutputs:            true,  // Structured outputs with nullable enum support
		PromptCaching:                true,
		ImageGeneration:              false, // Skipped for Azure
		ImageGenerationStream:        false, // Skipped for Azure
		ImageEdit:                    false, // Model not deployed on Azure endpoint
		ImageEditStream:              false, // Model not deployed on Azure endpoint
		ImageVariation:               false, // Not supported by Azure
		VideoGeneration:              false, // disabled for now because of long running operations
		VideoDownload:                false,
		VideoRetrieve:                false,
		VideoRemix:                   false,
		VideoList:                    false,
		VideoDelete:                  false,
		InterleavedThinking:          true,
		PassthroughAPI:               true,
		EagerInputStreaming:          true, // fine-grained-tool-streaming-2025-05-14 (Beta on Azure Foundry)
		ServerToolsViaOpenAIEndpoint: true, // web_search / web_fetch / code_execution on Azure per Table 20
	}

	testConfig := llmtests.ComprehensiveTestConfig{
		Provider:           schemas.Azure,
		ChatModel:          "gpt-4o",
		PromptCachingModel: "gpt-4o",
		VisionModel:        "gpt-4o",
		ChatAudioModel:     "gpt-4o-mini-audio-preview",
		Fallbacks: []schemas.Fallback{
			{Provider: schemas.Azure, Model: "gpt-4o"},
		},
		TextModel:            "", // Azure doesn't support text completion in newer models
		EmbeddingModel:       "text-embedding-ada-002",
		ReasoningModel:       "claude-opus-4-5",
		SpeechSynthesisModel: "gpt-4o-mini-tts",
		TranscriptionModel:   "whisper",
		ImageGenerationModel: "gpt-image-1",
		ImageEditModel:       "gpt-image-1",
		VideoGenerationModel: "sora-2",
		PassthroughModel:     "gpt-4o",
		Scenarios:            scenarios,
		DisableParallelFor:   []string{"Transcription"}, // Azure Whisper has 3 calls/minute quota
	}

	t.Run("AzureTests", func(t *testing.T) {
		llmtests.RunAllComprehensiveTests(t, client, ctx, testConfig)
	})
}

View File

@@ -0,0 +1,122 @@
package azure
import (
"context"
"fmt"
"time"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
"github.com/maximhq/bifrost/core/providers/openai"
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
"github.com/maximhq/bifrost/core/schemas"
"github.com/valyala/fasthttp"
)
// setAzureAuth sets the Azure authentication header on the request for OpenAI models.
// Exactly one of the Authorization and api-key headers is left on the request;
// the other is removed. Credential sources are tried in this order:
//  1. Service Principal (client ID, client secret and tenant ID all set and
//     non-empty) - Bearer token
//  2. Token stored in ctx under AzureAuthorizationTokenKey - Bearer token
//  3. API key (key.Value) - api-key header
//  4. DefaultAzureCredential auto-detection (managed identity, workload identity, env vars, CLI) - Bearer token
//
// It returns nil on success, or a BifrostError when a credential cannot be
// created or a token cannot be obtained (including an empty token).
func (provider *AzureProvider) setAzureAuth(ctx context.Context, req *fasthttp.Request, key schemas.Key) *schemas.BifrostError {
	// useBearer installs a Bearer token and drops any api-key header.
	useBearer := func(token string) {
		req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
		req.Header.Del("api-key")
	}

	cfg := key.AzureKeyConfig

	// 1. Service Principal authentication.
	hasServicePrincipal := cfg != nil &&
		cfg.ClientID != nil && cfg.ClientSecret != nil && cfg.TenantID != nil &&
		cfg.ClientID.GetValue() != "" && cfg.ClientSecret.GetValue() != "" && cfg.TenantID.GetValue() != ""
	if hasServicePrincipal {
		cred, err := provider.getOrCreateAuth(cfg.TenantID.GetValue(), cfg.ClientID.GetValue(), cfg.ClientSecret.GetValue())
		if err != nil {
			return providerUtils.NewBifrostOperationError("failed to get or create Azure authentication", err)
		}
		opts := policy.TokenRequestOptions{Scopes: getAzureScopes(cfg.Scopes)}
		token, err := cred.GetToken(ctx, opts)
		if err != nil {
			return providerUtils.NewBifrostOperationError("failed to get Azure access token", err)
		}
		if token.Token == "" {
			return providerUtils.NewBifrostOperationError("azure access token is empty", fmt.Errorf("token is empty"))
		}
		useBearer(token.Token)
		return nil
	}

	// 2. Context token authentication.
	if authToken, ok := ctx.Value(AzureAuthorizationTokenKey).(string); ok && authToken != "" {
		useBearer(authToken)
		return nil
	}

	// 3. API key authentication.
	if apiKey := key.Value.GetValue(); apiKey != "" {
		req.Header.Del("Authorization")
		req.Header.Set("api-key", apiKey)
		return nil
	}

	// 4. No explicit credentials - attempt DefaultAzureCredential auto-detection.
	// Configured scopes are still honoured when a key config is present.
	var configuredScopes []string
	if cfg != nil {
		configuredScopes = cfg.Scopes
	}
	scopes := getAzureScopes(configuredScopes)

	cred, err := provider.getOrCreateDefaultAzureCredential()
	if err != nil {
		return providerUtils.NewBifrostOperationError("no credentials provided and DefaultAzureCredential unavailable", err)
	}
	token, err := cred.GetToken(ctx, policy.TokenRequestOptions{Scopes: scopes})
	if err != nil {
		return providerUtils.NewBifrostOperationError("no credentials provided and DefaultAzureCredential failed to get token", err)
	}
	if token.Token == "" {
		return providerUtils.NewBifrostOperationError("no credentials provided and DefaultAzureCredential returned empty token", fmt.Errorf("token is empty"))
	}
	useBearer(token.Token)
	return nil
}
// AzureFileResponse represents an Azure file response (same as OpenAI).
// It is the decoded JSON payload; ToBifrostFileUploadResponse copies its
// fields into a schemas.BifrostFileUploadResponse.
type AzureFileResponse struct {
ID string `json:"id"`
Object string `json:"object"`
// Bytes is the size reported by the service.
Bytes int64 `json:"bytes"`
// CreatedAt is presumably a Unix timestamp in seconds — TODO confirm.
CreatedAt int64 `json:"created_at"`
Filename string `json:"filename"`
Purpose schemas.FilePurpose `json:"purpose"`
// Status is the provider's raw status string; it is translated with
// openai.ToBifrostFileStatus when converting to the Bifrost response.
Status string `json:"status,omitempty"`
// StatusDetails is optional and omitted from JSON when nil.
StatusDetails *string `json:"status_details,omitempty"`
}
// ToBifrostFileUploadResponse converts Azure file response to Bifrost response.
// The latency is recorded in milliseconds, the storage backend is always
// schemas.FileStorageAPI, and rawResponse is attached only when
// sendBackRawResponse is true. providerName is accepted but not used here.
func (r *AzureFileResponse) ToBifrostFileUploadResponse(providerName schemas.ModelProvider, latency time.Duration, sendBackRawResponse bool, rawResponse interface{}) *schemas.BifrostFileUploadResponse {
	extra := schemas.BifrostResponseExtraFields{
		Latency: latency.Milliseconds(),
	}
	if sendBackRawResponse {
		extra.RawResponse = rawResponse
	}

	return &schemas.BifrostFileUploadResponse{
		ID:             r.ID,
		Object:         r.Object,
		Bytes:          r.Bytes,
		CreatedAt:      r.CreatedAt,
		Filename:       r.Filename,
		Purpose:        r.Purpose,
		Status:         openai.ToBifrostFileStatus(r.Status),
		StatusDetails:  r.StatusDetails,
		StorageBackend: schemas.FileStorageAPI,
		ExtraFields:    extra,
	}
}

View File

@@ -0,0 +1,51 @@
package azure
import (
"strings"
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
"github.com/maximhq/bifrost/core/schemas"
)
// ToBifrostListModelsResponse converts an Azure list-models payload into the
// Bifrost representation. Each Azure model ID is run through a
// providerUtils.ListModelsPipeline built from the allow list, black list and
// aliases; every result it yields becomes an entry whose ID is prefixed with
// "azure/". Models the pipeline backfills are appended after the listed ones.
// A nil receiver yields nil; an early pipeline exit yields an empty, non-nil
// response.
func (response *AzureListModelsResponse) ToBifrostListModelsResponse(allowedModels schemas.WhiteList, blacklistedModels schemas.BlackList, aliases map[string]string, unfiltered bool) *schemas.BifrostListModelsResponse {
	if response == nil {
		return nil
	}

	out := &schemas.BifrostListModelsResponse{
		Data: make([]schemas.Model, 0, len(response.Data)),
	}

	pipeline := &providerUtils.ListModelsPipeline{
		AllowedModels:     allowedModels,
		BlacklistedModels: blacklistedModels,
		Aliases:           aliases,
		Unfiltered:        unfiltered,
		ProviderKey:       schemas.Azure,
		MatchFns:          providerUtils.DefaultMatchFns(),
	}
	if pipeline.ShouldEarlyExit() {
		return out
	}

	// seen records (lower-cased) resolved IDs already emitted so that
	// BackfillModels can tell which ones are still missing.
	seen := make(map[string]bool)
	prefix := string(schemas.Azure) + "/"

	for i := range response.Data {
		azureModel := &response.Data[i]
		for _, match := range pipeline.FilterModel(azureModel.ID) {
			item := schemas.Model{
				ID:      prefix + match.ResolvedID,
				Created: schemas.Ptr(azureModel.CreatedAt),
			}
			if match.AliasValue != "" {
				item.Alias = schemas.Ptr(match.AliasValue)
			}
			out.Data = append(out.Data, item)
			seen[strings.ToLower(match.ResolvedID)] = true
		}
	}

	out.Data = append(out.Data, pipeline.BackfillModels(seen)...)
	return out
}

View File

@@ -0,0 +1,36 @@
package azure
// Azure API version strings used by the provider.
const (
	// AzureAPIVersionDefault is the default Azure API version to use when not specified.
	AzureAPIVersionDefault = "2024-10-21"
	// AzureAPIVersionPreview is the literal "preview" version string.
	AzureAPIVersionPreview = "preview"
	// AzureAPIVersionImageEditDefault is, judging by its name, the default API
	// version for image edit requests — confirm against its callers.
	AzureAPIVersionImageEditDefault = "2025-04-01-preview"
	// AzureAnthropicAPIVersionDefault is, judging by its name, the default API
	// version for Anthropic models served through Azure — confirm against its callers.
	AzureAnthropicAPIVersionDefault = "2023-06-01"
)
// AzureModelCapabilities is the set of boolean capability flags attached to an
// AzureModel under its "capabilities" JSON key. Every flag is always emitted
// when marshalled (no omitempty).
type AzureModelCapabilities struct {
FineTune bool `json:"fine_tune"`
Inference bool `json:"inference"`
Completion bool `json:"completion"`
ChatCompletion bool `json:"chat_completion"`
Embeddings bool `json:"embeddings"`
}
// AzureModelDeprecation holds the optional "deprecation" object of an
// AzureModel. Both fields are omitted from JSON when zero.
// NOTE(review): the int64 values are presumably Unix timestamps — confirm.
type AzureModelDeprecation struct {
FineTune int64 `json:"fine_tune,omitempty"`
Inference int64 `json:"inference,omitempty"`
}
// AzureModel is one entry of an AzureListModelsResponse. When converting to
// the Bifrost model list, ToBifrostListModelsResponse reads only ID and
// CreatedAt.
type AzureModel struct {
ID string `json:"id"`
Status string `json:"status"`
FineTune string `json:"fine_tune,omitempty"`
// NOTE(review): omitempty has no effect on a struct-typed (non-pointer)
// field in encoding/json, so "capabilities" is always emitted.
Capabilities AzureModelCapabilities `json:"capabilities,omitempty"`
LifecycleStatus string `json:"lifecycle_status"`
// Deprecation is optional and omitted from JSON when nil.
Deprecation *AzureModelDeprecation `json:"deprecation,omitempty"`
// CreatedAt is presumably a Unix timestamp in seconds — TODO confirm.
CreatedAt int64 `json:"created_at"`
Object string `json:"object"`
}
// AzureListModelsResponse is the decoded list-models payload: an object type
// marker plus the models under "data". It is converted to the Bifrost format
// by its ToBifrostListModelsResponse method.
type AzureListModelsResponse struct {
Object string `json:"object"`
Data []AzureModel `json:"data"`
}

View File

@@ -0,0 +1,94 @@
package azure
import (
"fmt"
"strings"
"github.com/maximhq/bifrost/core/providers/anthropic"
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
"github.com/maximhq/bifrost/core/schemas"
)
// getRequestBodyForAnthropicResponses builds the JSON request body for an
// Anthropic-format responses request sent to an Azure deployment.
//
// Three paths exist:
//   - Large payload passthrough enabled: returns (nil, nil); the body streams
//     directly from the LP reader, so nothing is built here (matches the
//     CheckContextAndGetRequestBody guard).
//   - Raw request body requested via the context: the caller's raw JSON is
//     patched in place — max_tokens is added if absent, model is replaced by
//     deployment, fallbacks is removed, and stream is set when isStreaming.
//   - Otherwise: request.Model is overwritten with deployment, the request is
//     converted with anthropic.ToAnthropicResponsesRequest and marshalled.
//
// Note that the conversion path mutates request (its Model field).
func getRequestBodyForAnthropicResponses(ctx *schemas.BifrostContext, request *schemas.BifrostResponsesRequest, deployment string, isStreaming bool) ([]byte, *schemas.BifrostError) {
	if providerUtils.IsLargePayloadPassthroughEnabled(ctx) {
		return nil, nil
	}

	// A missing or non-bool context value leaves useRawBody false.
	useRawBody, _ := ctx.Value(schemas.BifrostContextKeyUseRawRequestBody).(bool)

	if !useRawBody {
		// Convert request to Anthropic format, targeting the deployment.
		request.Model = deployment
		reqBody, convErr := anthropic.ToAnthropicResponsesRequest(ctx, request)
		if convErr != nil {
			return nil, providerUtils.NewBifrostOperationError(schemas.ErrRequestBodyConversion, convErr)
		}
		if reqBody == nil {
			return nil, providerUtils.NewBifrostOperationError("request body is not provided", nil)
		}
		if isStreaming {
			reqBody.Stream = schemas.Ptr(true)
		}

		// Add provider-aware beta headers for Azure.
		anthropic.AddMissingBetaHeadersToContext(ctx, reqBody, schemas.Azure)

		// MarshalSorted is expected to give a stable key order — TODO confirm.
		encoded, err := providerUtils.MarshalSorted(reqBody)
		if err != nil {
			return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, fmt.Errorf("failed to marshal request body: %w", err))
		}
		return encoded, nil
	}

	// Raw body path: patch the caller's JSON field by field.
	marshalFailure := func(err error) *schemas.BifrostError {
		return providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err)
	}

	var (
		body []byte
		err  error
	)
	body = request.GetRawRequestBody()

	// Add max_tokens if not present (using sjson to preserve key order for prompt caching).
	if !providerUtils.JSONFieldExists(body, "max_tokens") {
		maxTokens := providerUtils.GetMaxOutputTokensOrDefault(deployment, anthropic.AnthropicDefaultMaxTokens)
		if body, err = providerUtils.SetJSONField(body, "max_tokens", maxTokens); err != nil {
			return nil, marshalFailure(err)
		}
	}

	// Replace model with deployment.
	if body, err = providerUtils.SetJSONField(body, "model", deployment); err != nil {
		return nil, marshalFailure(err)
	}

	// Delete fallbacks field.
	if body, err = providerUtils.DeleteJSONField(body, "fallbacks"); err != nil {
		return nil, marshalFailure(err)
	}

	// Add stream if streaming.
	if isStreaming {
		if body, err = providerUtils.SetJSONField(body, "stream", true); err != nil {
			return nil, marshalFailure(err)
		}
	}

	return body, nil
}
// getAzureScopes returns the token scopes to request. Each configured scope is
// trimmed of surrounding whitespace and empty or whitespace-only entries are
// dropped. If nothing remains (including a nil or empty input), a single-element
// slice holding DefaultAzureScope is returned.
func getAzureScopes(configuredScopes []string) []string {
	cleaned := make([]string, 0, len(configuredScopes))
	for _, s := range configuredScopes {
		// Trim once and reuse the result for both the check and the append.
		if trimmed := strings.TrimSpace(s); trimmed != "" {
			cleaned = append(cleaned, trimmed)
		}
	}
	if len(cleaned) == 0 {
		return []string{DefaultAzureScope}
	}
	return cleaned
}

View File

@@ -0,0 +1,417 @@
package bedrock
import (
"encoding/json"
"fmt"
"time"
"github.com/bytedance/sonic"
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
"github.com/maximhq/bifrost/core/schemas"
)
// BedrockBatchJobRequest represents a request to create a batch inference job.
// Only TimeoutDurationInHours and Tags carry omitempty; every other field is
// always present in the marshalled JSON.
type BedrockBatchJobRequest struct {
JobName string `json:"jobName"`
// ModelID is a pointer without omitempty, so a nil value marshals as JSON null.
ModelID *string `json:"modelId"`
RoleArn string `json:"roleArn"`
InputDataConfig BedrockInputDataConfig `json:"inputDataConfig"`
OutputDataConfig BedrockOutputDataConfig `json:"outputDataConfig"`
// TimeoutDurationInHours is omitted from JSON when zero.
TimeoutDurationInHours int `json:"timeoutDurationInHours,omitempty"`
// Tags is omitted from JSON when empty.
Tags []BedrockTag `json:"tags,omitempty"`
}
// BedrockInputDataConfig represents the input configuration for a batch job.
// It is a thin wrapper whose only member is the S3 input configuration.
type BedrockInputDataConfig struct {
S3InputDataConfig BedrockS3InputDataConfig `json:"s3InputDataConfig"`
}
// BedrockS3InputDataConfig represents S3 input configuration.
type BedrockS3InputDataConfig struct {
S3Uri string `json:"s3Uri"`
S3InputFormat string `json:"s3InputFormat,omitempty"` // "JSONL"
// Endpoint is optional and omitted from JSON when nil.
Endpoint *string `json:"endpoint,omitempty"`
// NOTE(review): "file_id" is snake_case while the other keys are camelCase;
// presumably a Bifrost-side extension rather than part of the AWS payload — confirm.
FileID *string `json:"file_id,omitempty"`
}
// BedrockOutputDataConfig represents the output configuration for a batch job.
// It is a thin wrapper whose only member is the S3 output configuration.
type BedrockOutputDataConfig struct {
S3OutputDataConfig BedrockS3OutputDataConfig `json:"s3OutputDataConfig"`
}
// BedrockS3OutputDataConfig represents S3 output configuration.
// S3Uri is always emitted when marshalled (no omitempty).
type BedrockS3OutputDataConfig struct {
S3Uri string `json:"s3Uri"`
}
// BedrockTag represents a tag for a batch job: a single key/value string pair,
// used as the element type of BedrockBatchJobRequest.Tags.
type BedrockTag struct {
Key string `json:"key"`
Value string `json:"value"`
}
// BedrockBatchJobResponse represents a batch job response.
// JobArn and Status are always emitted; every other field is optional and
// omitted from JSON when empty or nil. The converters in this file
// (ToBedrockBatchJobResponse, ToBedrockBatchJobRetrieveResponse) populate only
// a subset of the fields.
type BedrockBatchJobResponse struct {
JobArn string `json:"jobArn"`
Status string `json:"status"`
JobName string `json:"jobName,omitempty"`
ModelID string `json:"modelId,omitempty"`
RoleArn string `json:"roleArn,omitempty"`
InputDataConfig *BedrockInputDataConfig `json:"inputDataConfig,omitempty"`
OutputDataConfig *BedrockOutputDataConfig `json:"outputDataConfig,omitempty"`
VpcConfig *BedrockVpcConfig `json:"vpcConfig,omitempty"`
// Timestamps are pointers so that an unset time is omitted rather than
// marshalled as the zero time.
SubmitTime *time.Time `json:"submitTime,omitempty"`
LastModifiedTime *time.Time `json:"lastModifiedTime,omitempty"`
EndTime *time.Time `json:"endTime,omitempty"`
Message string `json:"message,omitempty"`
ClientRequestToken string `json:"clientRequestToken,omitempty"`
JobExpirationTime *time.Time `json:"jobExpirationTime,omitempty"`
TimeoutDurationInHours int `json:"timeoutDurationInHours,omitempty"`
}
// BedrockBatchJobListResponse represents a list of batch jobs.
type BedrockBatchJobListResponse struct {
InvocationJobSummaries []BedrockBatchJobSummary `json:"invocationJobSummaries"`
// NextToken is the pagination cursor; omitted from JSON when nil.
// ToBedrockBatchJobListResponse fills it from the Bifrost response's LastID.
NextToken *string `json:"nextToken,omitempty"`
}
// BedrockBatchJobSummary represents a summary of a batch job.
// The four string identifiers are always emitted; timestamps and Message are
// omitted from JSON when unset.
type BedrockBatchJobSummary struct {
JobArn string `json:"jobArn"`
JobName string `json:"jobName"`
ModelID string `json:"modelId"`
Status string `json:"status"`
SubmitTime *time.Time `json:"submitTime,omitempty"`
LastModifiedTime *time.Time `json:"lastModifiedTime,omitempty"`
EndTime *time.Time `json:"endTime,omitempty"`
Message string `json:"message,omitempty"`
}
// BedrockBatchResultRecord represents a single result record in Bedrock batch output JSONL.
type BedrockBatchResultRecord struct {
// RecordID becomes the CustomID of the Bifrost result item.
RecordID string `json:"recordId"`
// ModelOutput is kept as raw JSON here; parseBatchResultsJSONL decodes it
// into a generic map for the Bifrost response body.
ModelOutput json.RawMessage `json:"modelOutput,omitempty"`
// Error is set when the record failed; nil otherwise.
Error *BedrockBatchError `json:"error,omitempty"`
}
// BedrockBatchError represents an error in batch processing.
// parseBatchResultsJSONL formats ErrorCode as a decimal string for the Bifrost
// error code and also uses it as the response status code when the record has
// no model output.
type BedrockBatchError struct {
ErrorCode int `json:"errorCode,omitempty"`
ErrorMessage string `json:"errorMessage,omitempty"`
}
// BedrockBatchListRequest represents a request to list batch jobs.
// ToBifrostBatchListRequest maps MaxResults to Limit, NextToken to PageToken,
// and forwards non-empty StatusEquals / NameContains as extra params.
type BedrockBatchListRequest struct {
MaxResults int `json:"maxResults,omitempty"`
NextToken *string `json:"nextToken,omitempty"`
StatusEquals string `json:"statusEquals,omitempty"`
NameContains string `json:"nameContains,omitempty"`
}
// BedrockBatchRetrieveRequest represents a request to retrieve a batch job.
// JobIdentifier is passed through as the Bifrost BatchID by
// ToBifrostBatchRetrieveRequest.
type BedrockBatchRetrieveRequest struct {
JobIdentifier string `json:"jobIdentifier"`
}
// BedrockBatchCancelRequest represents a request to cancel/stop a batch job.
// JobIdentifier is passed through as the Bifrost BatchID by
// ToBifrostBatchCancelRequest.
type BedrockBatchCancelRequest struct {
JobIdentifier string `json:"jobIdentifier"`
}
// BedrockBatchCancelResponse represents the response from stopping a batch job.
// ToBedrockBatchCancelResponse fills JobArn from the Bifrost batch ID and
// Status from the Bifrost status converted by toBedrockBatchStatus.
type BedrockBatchCancelResponse struct {
JobArn string `json:"jobArn"`
Status string `json:"status"`
}
// ToBifrostBatchStatus converts Bedrock status to Bifrost status.
// "Submitted", "Validating" and "Scheduled" all map to the validating state,
// and "PartiallyCompleted" is reported as failed. A status with no mapping is
// passed through unchanged as a schemas.BatchStatus.
func ToBifrostBatchStatus(status string) schemas.BatchStatus {
	switch status {
	case "Submitted", "Validating", "Scheduled":
		return schemas.BatchStatusValidating
	case "InProgress":
		return schemas.BatchStatusInProgress
	case "Completed":
		return schemas.BatchStatusCompleted
	case "Failed", "PartiallyCompleted":
		return schemas.BatchStatusFailed
	case "Stopping":
		return schemas.BatchStatusCancelling
	case "Stopped":
		return schemas.BatchStatusCancelled
	case "Expired":
		return schemas.BatchStatusExpired
	}
	return schemas.BatchStatus(status)
}
// parseBatchResultsJSONL parses JSONL content from Bedrock batch output into Bifrost format.
// Each line is decoded as a BedrockBatchResultRecord. A line that fails to
// decode is logged as a warning and reported through the returned parse errors
// instead of producing a result item. For a decoded record:
//   - model output that decodes to a JSON object becomes a 200 response body;
//   - model output that does not decode is recorded as a "parse_error";
//   - a record-level error overrides any error set above and, when there is no
//     response, supplies the response status code.
//
// Returns the parsed results and any parse errors encountered.
func parseBatchResultsJSONL(content []byte, provider *BedrockProvider) ([]schemas.BatchResultItem, []schemas.BatchError) {
	var items []schemas.BatchResultItem

	handleLine := func(line []byte) error {
		var record BedrockBatchResultRecord
		if err := sonic.Unmarshal(line, &record); err != nil {
			provider.logger.Warn(fmt.Sprintf("failed to parse batch result line: %v", err))
			return err
		}

		// Convert Bedrock format to Bifrost format.
		item := schemas.BatchResultItem{CustomID: record.RecordID}

		if record.ModelOutput != nil {
			var body map[string]interface{}
			if err := sonic.Unmarshal(record.ModelOutput, &body); err != nil {
				item.Error = &schemas.BatchResultError{
					Code:    "parse_error",
					Message: fmt.Sprintf("failed to parse model output: %v", err),
				}
			} else {
				item.Response = &schemas.BatchResultResponse{
					StatusCode: 200,
					Body:       body,
				}
			}
		}

		if recordErr := record.Error; recordErr != nil {
			item.Error = &schemas.BatchResultError{
				Code:    fmt.Sprintf("%d", recordErr.ErrorCode),
				Message: recordErr.ErrorMessage,
			}
			// With no response yet, surface the error code as the status code.
			if item.Response == nil {
				item.Response = &schemas.BatchResultResponse{
					StatusCode: recordErr.ErrorCode,
				}
			}
		}

		items = append(items, item)
		return nil
	}

	// Run the parse to completion before reading items: the callback appends
	// to it as lines are processed.
	parsed := providerUtils.ParseJSONL(content, handleLine)
	return items, parsed.Errors
}
// ToBedrockBatchJobResponse converts a Bifrost batch create response to Bedrock format.
// When the response did not come from Bedrock, a placeholder ARN is built from
// the batch ID and only the ARN and status are returned. For Bedrock the batch
// ID is used as the ARN as-is, and the job name (metadata key "job_name") and
// submit time (from CreatedAt, when positive) are filled in.
func ToBedrockBatchJobResponse(resp *schemas.BifrostBatchCreateResponse) *BedrockBatchJobResponse {
	status := toBedrockBatchStatus(resp.Status)

	// Non-Bedrock provider: fabricate an ARN around the batch ID.
	if resp.ExtraFields.Provider != schemas.Bedrock {
		return &BedrockBatchJobResponse{
			JobArn: fmt.Sprintf("arn:aws:bedrock:us-east-1:444444444444:batch:%s", resp.ID),
			Status: status,
		}
	}

	// Bedrock: the ID already is the job ARN.
	out := &BedrockBatchJobResponse{
		JobArn: resp.ID,
		Status: status,
	}
	// Looking up a key in a nil map is safe, so no separate nil check is needed.
	if name, found := resp.Metadata["job_name"]; found {
		out.JobName = name
	}
	if resp.CreatedAt > 0 {
		submitted := time.Unix(resp.CreatedAt, 0)
		out.SubmitTime = &submitted
	}
	return out
}
// ToBedrockBatchJobListResponse converts a Bifrost batch list response to Bedrock format.
// Each batch becomes one summary: the ID is used as the job ARN, the job name
// and model ID come from the "job_name" / "model_id" metadata keys, and the
// submit/end times are set only for positive timestamps. The Bifrost LastID is
// passed through as the next-page token.
func ToBedrockBatchJobListResponse(resp *schemas.BifrostBatchListResponse) *BedrockBatchJobListResponse {
	summaries := make([]BedrockBatchJobSummary, 0, len(resp.Data))

	for i := range resp.Data {
		batch := &resp.Data[i]

		entry := BedrockBatchJobSummary{
			JobArn: batch.ID,
			Status: toBedrockBatchStatus(batch.Status),
		}
		// Looking up a key in a nil map is safe, so no separate nil check is needed.
		if name, found := batch.Metadata["job_name"]; found {
			entry.JobName = name
		}
		if model, found := batch.Metadata["model_id"]; found {
			entry.ModelID = model
		}
		if batch.CreatedAt > 0 {
			submitted := time.Unix(batch.CreatedAt, 0)
			entry.SubmitTime = &submitted
		}
		if done := batch.CompletedAt; done != nil && *done > 0 {
			ended := time.Unix(*done, 0)
			entry.EndTime = &ended
		}

		summaries = append(summaries, entry)
	}

	return &BedrockBatchJobListResponse{
		InvocationJobSummaries: summaries,
		NextToken:              resp.LastID,
	}
}
// ToBedrockBatchJobRetrieveResponse converts a Bifrost batch retrieve response to Bedrock format.
// The batch ID is used as the job ARN. Optional parts are filled only when
// present: job name (metadata key "job_name"), submit/end times (positive
// timestamps), the input config (non-empty InputFileID, format "JSONL") and
// the output config (non-nil, non-empty OutputFileID).
func ToBedrockBatchJobRetrieveResponse(resp *schemas.BifrostBatchRetrieveResponse) *BedrockBatchJobResponse {
	out := &BedrockBatchJobResponse{
		JobArn: resp.ID,
		Status: toBedrockBatchStatus(resp.Status),
	}

	// Looking up a key in a nil map is safe, so no separate nil check is needed.
	if name, found := resp.Metadata["job_name"]; found {
		out.JobName = name
	}

	if resp.CreatedAt > 0 {
		submitted := time.Unix(resp.CreatedAt, 0)
		out.SubmitTime = &submitted
	}
	if done := resp.CompletedAt; done != nil && *done > 0 {
		ended := time.Unix(*done, 0)
		out.EndTime = &ended
	}

	// The file IDs are surfaced as S3 URIs in the Bedrock shape.
	if resp.InputFileID != "" {
		input := BedrockS3InputDataConfig{
			S3Uri:         resp.InputFileID,
			S3InputFormat: "JSONL",
		}
		out.InputDataConfig = &BedrockInputDataConfig{S3InputDataConfig: input}
	}
	if outputID := resp.OutputFileID; outputID != nil && *outputID != "" {
		output := BedrockS3OutputDataConfig{S3Uri: *outputID}
		out.OutputDataConfig = &BedrockOutputDataConfig{S3OutputDataConfig: output}
	}

	return out
}
// toBedrockBatchStatus converts Bifrost batch status to Bedrock status.
// Both the completed and ended states are reported as "Completed". A status
// with no mapping is returned as its raw string value.
func toBedrockBatchStatus(status schemas.BatchStatus) string {
	switch status {
	case schemas.BatchStatusValidating:
		return "Validating"
	case schemas.BatchStatusInProgress:
		return "InProgress"
	case schemas.BatchStatusCompleted, schemas.BatchStatusEnded:
		return "Completed"
	case schemas.BatchStatusFailed:
		return "Failed"
	case schemas.BatchStatusCancelling:
		return "Stopping"
	case schemas.BatchStatusCancelled:
		return "Stopped"
	case schemas.BatchStatusExpired:
		return "Expired"
	}
	return string(status)
}
// ToBifrostBatchListRequest converts a Bedrock batch list request to Bifrost format.
// MaxResults becomes Limit and NextToken becomes PageToken. The statusEquals
// and nameContains filters are forwarded through ExtraParams, which is left
// unset when neither filter is given.
func ToBifrostBatchListRequest(req *BedrockBatchListRequest, provider schemas.ModelProvider) *schemas.BifrostBatchListRequest {
	out := &schemas.BifrostBatchListRequest{
		Provider:  provider,
		Limit:     req.MaxResults,
		PageToken: req.NextToken,
	}

	// Gather the optional filters first; attach the map only if it has entries.
	filters := make(map[string]interface{})
	if req.StatusEquals != "" {
		filters["statusEquals"] = req.StatusEquals
	}
	if req.NameContains != "" {
		filters["nameContains"] = req.NameContains
	}
	if len(filters) > 0 {
		out.ExtraParams = filters
	}

	return out
}
// ToBifrostBatchRetrieveRequest builds a Bifrost batch retrieve request for the
// given provider, using the Bedrock job identifier as the batch ID.
func ToBifrostBatchRetrieveRequest(req *BedrockBatchRetrieveRequest, provider schemas.ModelProvider) *schemas.BifrostBatchRetrieveRequest {
	retrieveReq := new(schemas.BifrostBatchRetrieveRequest)
	retrieveReq.Provider = provider
	retrieveReq.BatchID = req.JobIdentifier
	return retrieveReq
}
// ToBifrostBatchCancelRequest builds a Bifrost batch cancel request for the
// given provider, using the Bedrock job identifier as the batch ID.
func ToBifrostBatchCancelRequest(req *BedrockBatchCancelRequest, provider schemas.ModelProvider) *schemas.BifrostBatchCancelRequest {
	cancelReq := new(schemas.BifrostBatchCancelRequest)
	cancelReq.Provider = provider
	cancelReq.BatchID = req.JobIdentifier
	return cancelReq
}
// ToBedrockBatchCancelResponse builds the Bedrock cancel response from a
// Bifrost one: the batch ID is reported as the job ARN and the status is
// translated with toBedrockBatchStatus.
func ToBedrockBatchCancelResponse(resp *schemas.BifrostBatchCancelResponse) *BedrockBatchCancelResponse {
	bedrockStatus := toBedrockBatchStatus(resp.Status)
	cancelResp := &BedrockBatchCancelResponse{JobArn: resp.ID}
	cancelResp.Status = bedrockStatus
	return cancelResp
}
// splitJSONL splits JSONL content into its individual records.
//
// Records are separated by '\n'. A trailing '\r' is removed from each line so
// that CRLF-terminated input yields clean records, and lines that are empty
// (including blank CRLF lines) are skipped. A final line without a terminating
// newline is still returned.
//
// The returned slices alias data; callers must not modify data while the
// result is in use. The result is nil when data holds no non-empty line.
func splitJSONL(data []byte) [][]byte {
	var lines [][]byte

	// appendLine records one raw line after dropping the CR of a CRLF pair.
	appendLine := func(line []byte) {
		if n := len(line); n > 0 && line[n-1] == '\r' {
			line = line[:n-1]
		}
		if len(line) > 0 {
			lines = append(lines, line)
		}
	}

	start := 0
	for i, b := range data {
		if b == '\n' {
			appendLine(data[start:i])
			start = i + 1
		}
	}
	// Whatever follows the last newline is an unterminated final record.
	appendLine(data[start:])
	return lines
}
// BedrockVpcConfig represents VPC configuration for a batch job.
// Both fields are optional in JSON: each is omitted from the encoded output
// when empty.
type BedrockVpcConfig struct {
// SecurityGroupIds is encoded under the "securityGroupIds" key.
SecurityGroupIds []string `json:"securityGroupIds,omitempty"`
// SubnetIds is encoded under the "subnetIds" key.
SubnetIds        []string `json:"subnetIds,omitempty"`
}
// BedrockBatchManifest represents the manifest.json.out file structure from S3.
// Only the three record counters below are decoded; all are always present in
// the encoded form (no omitempty).
type BedrockBatchManifest struct {
// TotalRecordCount maps to the "totalRecordCount" JSON key.
TotalRecordCount     int `json:"totalRecordCount"`
// ProcessedRecordCount maps to the "processedRecordCount" JSON key.
ProcessedRecordCount int `json:"processedRecordCount"`
// ErrorRecordCount maps to the "errorRecordCount" JSON key.
ErrorRecordCount     int `json:"errorRecordCount"`
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,449 @@
package bedrock
import (
"context"
"fmt"
"time"
"github.com/google/uuid"
"github.com/maximhq/bifrost/core/schemas"
)
// ToBedrockChatCompletionRequest turns a Bifrost chat request into a Bedrock
// Converse API request.
//
// It fails when the request or its input is nil, or when message or parameter
// conversion reports an error. On success the returned request carries the
// model ID, the converted messages, any system messages, the converted
// parameters, and whatever tool configuration
// ensureChatToolConfigForConversation decides to add.
func ToBedrockChatCompletionRequest(ctx *schemas.BifrostContext, bifrostReq *schemas.BifrostChatRequest) (*BedrockConverseRequest, error) {
	switch {
	case bifrostReq == nil:
		return nil, fmt.Errorf("bifrost request is nil")
	case bifrostReq.Input == nil:
		return nil, fmt.Errorf("only chat completion requests are supported for Bedrock Converse API")
	}

	// Split the conversation into regular messages and system messages.
	messages, systemMessages, err := convertMessages(bifrostReq.Input)
	if err != nil {
		return nil, fmt.Errorf("failed to convert messages: %w", err)
	}

	converseReq := &BedrockConverseRequest{ModelID: bifrostReq.Model}
	converseReq.Messages = messages
	if len(systemMessages) > 0 {
		converseReq.System = systemMessages
	}

	// Translate sampling parameters, tools and other request options.
	if err := convertChatParameters(ctx, bifrostReq, converseReq); err != nil {
		return nil, fmt.Errorf("failed to convert chat parameters: %w", err)
	}

	// Must run after parameter conversion so it sees the final tool config.
	ensureChatToolConfigForConversation(bifrostReq, converseReq)

	return converseReq, nil
}
// ToBifrostChatResponse converts a Bedrock Converse API response to Bifrost format.
//
// It returns an error only when the receiver is nil. Otherwise it walks the
// content blocks of the output message (if any) and produces a single
// assistant choice:
//   - non-empty text blocks and document blocks become content blocks;
//   - tool-use blocks become tool calls, except the one whose name matches the
//     structured-output tool name stored in ctx, whose input is used as the
//     message's string content instead;
//   - reasoning blocks become reasoning details plus a newline-joined
//     reasoning string.
//
// ctx is read with ctx.Value, so it must not be nil when the response contains
// a tool-use block. The response ID is a freshly generated UUID and Created is
// the current Unix time; neither comes from the Bedrock response.
func (response *BedrockConverseResponse) ToBifrostChatResponse(ctx context.Context, model string) (*schemas.BifrostChatResponse, error) {
if response == nil {
return nil, fmt.Errorf("bedrock response is nil")
}
// Convert content blocks and tool calls
var contentStr *string
var contentBlocks []schemas.ChatContentBlock
var toolCalls []schemas.ChatAssistantMessageToolCall
var reasoningDetails []schemas.ChatReasoningDetails
var reasoningText string
if response.Output.Message != nil {
for _, contentBlock := range response.Output.Message.Content {
// Handle text content
// Empty text blocks are dropped rather than emitted as empty content.
if contentBlock.Text != nil && *contentBlock.Text != "" {
chatContentBlock := schemas.ChatContentBlock{
Type: schemas.ChatContentBlockTypeText,
Text: contentBlock.Text,
}
contentBlocks = append(contentBlocks, chatContentBlock)
}
if contentBlock.ToolUse != nil {
// Check if this is the structured output tool
if structuredOutputToolName, ok := ctx.Value(schemas.BifrostContextKeyStructuredOutputToolName).(string); ok && contentBlock.ToolUse.Name == structuredOutputToolName {
// This is structured output - set contentStr and skip adding to toolCalls
if contentBlock.ToolUse.Input != nil {
jsonStr := string(contentBlock.ToolUse.Input)
contentStr = &jsonStr
}
// NOTE: this continue also skips the reasoning and document handling
// below for the same content block.
continue // Skip adding to toolCalls
}
// Regular tool call processing
// A missing input is reported as an empty JSON object.
var arguments string
if contentBlock.ToolUse.Input != nil {
arguments = string(contentBlock.ToolUse.Input)
} else {
arguments = "{}"
}
// Copy into locals so the pointers stored below do not alias the loop variable's fields.
toolUseID := contentBlock.ToolUse.ToolUseID
toolUseName := contentBlock.ToolUse.Name
toolCalls = append(toolCalls, schemas.ChatAssistantMessageToolCall{
// Index is the position among emitted tool calls, not among content blocks.
Index: uint16(len(toolCalls)),
Type: schemas.Ptr("function"),
ID: &toolUseID,
Function: schemas.ChatAssistantMessageToolCallFunction{
Name: &toolUseName,
Arguments: arguments,
},
})
}
// Handle reasoning content
if contentBlock.ReasoningContent != nil {
// A reasoning block without reasoning text is ignored; the continue
// also skips document handling for this content block.
if contentBlock.ReasoningContent.ReasoningText == nil {
continue
}
reasoningDetails = append(reasoningDetails, schemas.ChatReasoningDetails{
Index: len(reasoningDetails),
Type: schemas.BifrostReasoningDetailsTypeText,
Text: contentBlock.ReasoningContent.ReasoningText.Text,
Signature: contentBlock.ReasoningContent.ReasoningText.Signature,
})
// Each reasoning text is followed by a newline, so the aggregated
// string ends with a trailing newline.
if contentBlock.ReasoningContent.ReasoningText.Text != nil {
reasoningText += *contentBlock.ReasoningContent.ReasoningText.Text + "\n"
}
}
// Handle document content
if contentBlock.Document != nil {
fileBlock := schemas.ChatContentBlock{
Type: schemas.ChatContentBlockTypeFile,
File: &schemas.ChatInputFile{},
}
// Set filename from document name
if contentBlock.Document.Name != "" {
fileBlock.File.Filename = &contentBlock.Document.Name
}
// Set file type based on format
// Formats without an explicit mapping fall back to "application/pdf".
if contentBlock.Document.Format != "" {
var fileType string
switch contentBlock.Document.Format {
case "pdf":
fileType = "application/pdf"
case "txt":
fileType = "text/plain"
case "md":
fileType = "text/markdown"
case "html":
fileType = "text/html"
case "csv":
fileType = "text/csv"
case "doc":
fileType = "application/msword"
case "docx":
fileType = "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
case "xls":
fileType = "application/vnd.ms-excel"
case "xlsx":
fileType = "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
default:
fileType = "application/pdf"
}
fileBlock.File.FileType = &fileType
}
// Convert document source data
// Bytes take precedence over Text when both are present.
if contentBlock.Document.Source != nil {
if contentBlock.Document.Source.Bytes != nil {
fileBlock.File.FileData = contentBlock.Document.Source.Bytes
} else if contentBlock.Document.Source.Text != nil {
fileBlock.File.FileData = contentBlock.Document.Source.Text
}
}
contentBlocks = append(contentBlocks, fileBlock)
}
}
}
// A lone text block is collapsed into plain string content. This replaces
// any contentStr set earlier from the structured-output tool.
if len(contentBlocks) == 1 && contentBlocks[0].Type == schemas.ChatContentBlockTypeText {
contentStr = contentBlocks[0].Text
contentBlocks = nil
}
// Create the message content
messageContent := schemas.ChatMessageContent{
ContentStr: contentStr,
ContentBlocks: contentBlocks,
}
// Create assistant message if we have tool calls
var assistantMessage *schemas.ChatAssistantMessage
if len(toolCalls) > 0 {
assistantMessage = &schemas.ChatAssistantMessage{
ToolCalls: toolCalls,
}
}
// Reasoning details are attached to the same assistant message, creating it
// when there were no tool calls.
if len(reasoningDetails) > 0 {
if assistantMessage == nil {
assistantMessage = &schemas.ChatAssistantMessage{}
}
assistantMessage.ReasoningDetails = reasoningDetails
if reasoningText != "" {
// NOTE(review): new with an expression operand needs a recent Go
// toolchain (presumably Go 1.26+) — confirm against go.mod.
assistantMessage.Reasoning = new(reasoningText)
}
}
// Create the response choice
choices := []schemas.BifrostResponseChoice{
{
Index: 0,
ChatNonStreamResponseChoice: &schemas.ChatNonStreamResponseChoice{
Message: &schemas.ChatMessage{
Role: schemas.ChatMessageRoleAssistant,
Content: &messageContent,
ChatAssistantMessage: assistantMessage,
},
},
FinishReason: schemas.Ptr(convertBedrockStopReason(response.StopReason)),
},
}
var usage *schemas.BifrostLLMUsage
if response.Usage != nil {
// Convert usage information
usage = &schemas.BifrostLLMUsage{
PromptTokens: response.Usage.InputTokens,
CompletionTokens: response.Usage.OutputTokens,
TotalTokens: response.Usage.TotalTokens,
}
// Handle cached tokens if present
// Cache read/write tokens are added on top of PromptTokens; TotalTokens is
// left exactly as reported by Bedrock.
if response.Usage.CacheReadInputTokens > 0 {
if usage.PromptTokensDetails == nil {
usage.PromptTokensDetails = &schemas.ChatPromptTokensDetails{}
}
usage.PromptTokensDetails.CachedReadTokens = response.Usage.CacheReadInputTokens
usage.PromptTokens = usage.PromptTokens + response.Usage.CacheReadInputTokens
}
if response.Usage.CacheWriteInputTokens > 0 {
if usage.PromptTokensDetails == nil {
usage.PromptTokensDetails = &schemas.ChatPromptTokensDetails{}
}
usage.PromptTokensDetails.CachedWriteTokens = response.Usage.CacheWriteInputTokens
usage.PromptTokens = usage.PromptTokens + response.Usage.CacheWriteInputTokens
}
}
// Create the final Bifrost response
bifrostResponse := &schemas.BifrostChatResponse{
ID: uuid.New().String(),
Model: model,
Object: "chat.completion",
Choices: choices,
Usage: usage,
Created: int(time.Now().Unix()),
ExtraFields: schemas.BifrostResponseExtraFields{
},
}
// The service tier is only forwarded when Bedrock reported a non-empty type.
if response.ServiceTier != nil && response.ServiceTier.Type != "" {
bifrostResponse.ServiceTier = &response.ServiceTier.Type
}
return bifrostResponse, nil
}
// BedrockStreamState holds the tool-call bookkeeping for a single streaming
// response.
type BedrockStreamState struct {
	// nextToolCallIndex is the index handed to the next tool call that starts.
	nextToolCallIndex int
	// contentBlockToToolCallIdx remembers which tool-call index was assigned
	// to each content block index.
	contentBlockToToolCallIdx map[int]int
}

// NewBedrockStreamState creates the state for one streaming response, with an
// empty, ready-to-write content-block map and the tool-call counter at zero.
func NewBedrockStreamState() *BedrockStreamState {
	state := new(BedrockStreamState)
	state.contentBlockToToolCallIdx = map[int]int{}
	return state
}
// ToBifrostChatCompletionStream converts one Bedrock stream event into a
// Bifrost "chat.completion.chunk" response.
//
// A chunk is produced for: a role event, a tool-use start, a non-empty text
// delta, a tool-use delta, and a reasoning delta carrying non-empty text or a
// signature. Every other event, including an empty text delta, yields
// (nil, nil, false). In the code below the error result is always nil and the
// boolean result is always false.
//
// state carries the content-block to tool-call index mapping between calls.
// When state is nil a temporary state is created for this call only, so index
// assignments made here are not visible to later calls.
func (chunk *BedrockStreamEvent) ToBifrostChatCompletionStream(state *BedrockStreamState) (*schemas.BifrostChatResponse, *schemas.BifrostError, bool) {
if state == nil {
state = NewBedrockStreamState()
} else if state.contentBlockToToolCallIdx == nil {
// Tolerate a zero-value state supplied by the caller: writing to a nil map would panic.
state.contentBlockToToolCallIdx = make(map[int]int)
}
// event with metrics/usage is the last and with stop reason is the second last
switch {
case chunk.Role != nil:
// Send empty response to signal start
streamResponse := &schemas.BifrostChatResponse{
Object: "chat.completion.chunk",
Choices: []schemas.BifrostResponseChoice{
{
Index: 0,
ChatStreamResponseChoice: &schemas.ChatStreamResponseChoice{
Delta: &schemas.ChatStreamResponseChoiceDelta{
Role: chunk.Role,
},
},
},
},
}
return streamResponse, nil, false
case chunk.Start != nil && chunk.Start.ToolUse != nil:
toolUseStart := chunk.Start.ToolUse
// Without a content block index the tool call is reported at index 0 and
// the state is left untouched.
toolCallIdx := 0
if chunk.ContentBlockIndex != nil {
// Assign the next sequential tool-call index and remember it for the
// deltas that follow on the same content block.
toolCallIdx = state.nextToolCallIndex
state.contentBlockToToolCallIdx[*chunk.ContentBlockIndex] = toolCallIdx
state.nextToolCallIndex++
}
// Create tool call structure for start event
var toolCall schemas.ChatAssistantMessageToolCall
toolCall.Index = uint16(toolCallIdx)
toolCall.ID = schemas.Ptr(toolUseStart.ToolUseID)
toolCall.Type = schemas.Ptr("function")
toolCall.Function.Name = schemas.Ptr(toolUseStart.Name)
toolCall.Function.Arguments = "" // Start with empty arguments
streamResponse := &schemas.BifrostChatResponse{
Object: "chat.completion.chunk",
Choices: []schemas.BifrostResponseChoice{
{
Index: 0,
ChatStreamResponseChoice: &schemas.ChatStreamResponseChoice{
Delta: &schemas.ChatStreamResponseChoiceDelta{
ToolCalls: []schemas.ChatAssistantMessageToolCall{toolCall},
},
},
},
},
}
return streamResponse, nil, false
case chunk.Delta != nil:
switch {
case chunk.Delta.Text != nil:
// Handle text delta
// An empty text delta produces no chunk: control falls out of the switch
// to the final return below.
text := *chunk.Delta.Text
if text != "" {
streamResponse := &schemas.BifrostChatResponse{
Object: "chat.completion.chunk",
Choices: []schemas.BifrostResponseChoice{
{
Index: 0,
ChatStreamResponseChoice: &schemas.ChatStreamResponseChoice{
Delta: &schemas.ChatStreamResponseChoiceDelta{
Content: &text,
},
},
},
},
}
return streamResponse, nil, false
}
case chunk.Delta.ToolUse != nil:
// Handle tool use delta
toolUseDelta := chunk.Delta.ToolUse
// Look up the index assigned at the start event; an unknown content block
// (or a missing block index) resolves to index 0.
toolCallIdx := 0
if chunk.ContentBlockIndex != nil {
toolCallIdx = state.contentBlockToToolCallIdx[*chunk.ContentBlockIndex]
}
// Create tool call structure
var toolCall schemas.ChatAssistantMessageToolCall
toolCall.Index = uint16(toolCallIdx)
toolCall.Type = schemas.Ptr("function")
// For streaming, we need to accumulate tool use data
// This is a simplified approach - in practice, you'd need to track tool calls across chunks
// Only this event's argument fragment is emitted; nothing is accumulated here.
toolCall.Function.Arguments = toolUseDelta.Input
streamResponse := &schemas.BifrostChatResponse{
Object: "chat.completion.chunk",
Choices: []schemas.BifrostResponseChoice{
{
Index: 0,
ChatStreamResponseChoice: &schemas.ChatStreamResponseChoice{
Delta: &schemas.ChatStreamResponseChoiceDelta{
ToolCalls: []schemas.ChatAssistantMessageToolCall{toolCall},
},
},
},
},
}
return streamResponse, nil, false
case chunk.Delta.ReasoningContent != nil:
// Handle reasoning content delta
reasoningContentDelta := chunk.Delta.ReasoningContent
// Only construct and return a response when either Text or Signature is set
if (reasoningContentDelta.Text == nil || *reasoningContentDelta.Text == "") && reasoningContentDelta.Signature == nil {
return nil, nil, false
}
// Non-empty text wins over a signature: when both are present in the same
// delta only the text is forwarded.
var streamResponse *schemas.BifrostChatResponse
if reasoningContentDelta.Text != nil && *reasoningContentDelta.Text != "" {
streamResponse = &schemas.BifrostChatResponse{
Object: "chat.completion.chunk",
Choices: []schemas.BifrostResponseChoice{
{
Index: 0,
ChatStreamResponseChoice: &schemas.ChatStreamResponseChoice{
Delta: &schemas.ChatStreamResponseChoiceDelta{
Reasoning: reasoningContentDelta.Text,
ReasoningDetails: []schemas.ChatReasoningDetails{
{
Index: 0,
Type: schemas.BifrostReasoningDetailsTypeText,
Text: reasoningContentDelta.Text,
},
},
},
},
},
},
}
} else if reasoningContentDelta.Signature != nil {
streamResponse = &schemas.BifrostChatResponse{
Object: "chat.completion.chunk",
Choices: []schemas.BifrostResponseChoice{
{
Index: 0,
ChatStreamResponseChoice: &schemas.ChatStreamResponseChoice{
Delta: &schemas.ChatStreamResponseChoiceDelta{
ReasoningDetails: []schemas.ChatReasoningDetails{
{
Index: 0,
Type: schemas.BifrostReasoningDetailsTypeText,
Signature: reasoningContentDelta.Signature,
},
},
},
},
},
},
}
}
return streamResponse, nil, false
}
}
// No event type handled above matched (or the text delta was empty).
return nil, nil, false
}

View File

@@ -0,0 +1,477 @@
package bedrock
import (
"context"
"encoding/json"
"strings"
"testing"
"github.com/maximhq/bifrost/core/schemas"
)
// TestConvertToolConfig_DropsServerToolsOnBedrock is the regression test for
// the user-reported repro in which `web_search_20260209`, sent through the
// OpenAI-compatible /v1/chat/completions endpoint to Bedrock, produced a
// malformed ToolConfig that Bedrock rejected with 400 "The provided request
// is not valid". Unsupported server tools must be stripped before the
// conversion loop so that only the function tool reaches the outbound request.
func TestConvertToolConfig_DropsServerToolsOnBedrock(t *testing.T) {
	weatherTool := schemas.ChatTool{
		Type: schemas.ChatToolTypeFunction,
		Function: &schemas.ChatToolFunction{
			Name:        "get_weather",
			Description: schemas.Ptr("Get weather by city"),
			Parameters:  &schemas.ToolFunctionParameters{Type: "object"},
		},
	}
	// Server tool that Bedrock does not support; it should be stripped silently.
	webSearchTool := schemas.ChatTool{
		Type: schemas.ChatToolType("web_search_20260209"),
		Name: "web_search",
	}
	chatParams := &schemas.ChatParameters{
		Tools: []schemas.ChatTool{weatherTool, webSearchTool},
	}

	toolCfg := convertToolConfig("global.anthropic.claude-sonnet-4-6", chatParams)
	if toolCfg == nil {
		t.Fatalf("expected ToolConfig, got nil (function tool should have survived)")
	}
	if count := len(toolCfg.Tools); count != 1 {
		t.Fatalf("expected exactly 1 tool (function), got %d: %+v", count, toolCfg.Tools)
	}
	if spec := toolCfg.Tools[0].ToolSpec; spec == nil || spec.Name != "get_weather" {
		t.Errorf("expected function tool 'get_weather' to survive, got %+v", toolCfg.Tools[0])
	}
}
// TestConvertToolConfig_ReturnsNilWhenAllDropped guards the empty-slice case.
// Bedrock's Converse API answers `"toolConfig": {"tools": []}` with a 400, so
// when every requested tool is unsupported and stripped, convertToolConfig
// must return nil and no ToolConfig may be sent at all.
func TestConvertToolConfig_ReturnsNilWhenAllDropped(t *testing.T) {
	webSearch := schemas.ChatTool{
		Type: schemas.ChatToolType("web_search_20260209"),
		Name: "web_search",
	}
	webFetch := schemas.ChatTool{
		Type: schemas.ChatToolType("web_fetch_20260309"),
		Name: "web_fetch",
	}
	codeExecution := schemas.ChatTool{
		Type: schemas.ChatToolType("code_execution_20250825"),
		Name: "code_execution",
	}
	chatParams := &schemas.ChatParameters{
		Tools: []schemas.ChatTool{webSearch, webFetch, codeExecution},
	}

	if toolCfg := convertToolConfig("global.anthropic.claude-sonnet-4-6", chatParams); toolCfg != nil {
		t.Fatalf("expected nil ToolConfig (all tools unsupported on Bedrock), got %+v", toolCfg)
	}
}
// TestConvertToolConfig_KeepsBedrockSupportedServerTools verifies that
// Bedrock-supported server tools (bash, memory, text_editor, computer,
// tool_search) never show up in Converse's typed toolConfig.tools slot: they
// are tunneled via additionalModelRequestFields instead (covered by the
// TestCollectBedrockServerTools_* tests). When the only tool is a server
// tool, toolConfig must be nil so {"toolConfig": {"tools": []}} is never sent.
func TestConvertToolConfig_KeepsBedrockSupportedServerTools(t *testing.T) {
	bashTool := schemas.ChatTool{
		Type: schemas.ChatToolType("bash_20250124"),
		Name: "bash",
	}
	chatParams := &schemas.ChatParameters{Tools: []schemas.ChatTool{bashTool}}

	if toolCfg := convertToolConfig("global.anthropic.claude-sonnet-4-6", chatParams); toolCfg != nil {
		t.Fatalf("expected nil toolConfig (server tools flow via additionalModelRequestFields, not toolSpec), got %+v", toolCfg)
	}
}
// TestCollectBedrockServerTools_BashOnly checks the bash server tool, which is
// Bedrock-supported: the helper must emit it as a native-JSON tool entry and
// derive no beta header for it (bash has no high-confidence 1:1 beta-header
// mapping; callers rely on extra-headers for that).
func TestCollectBedrockServerTools_BashOnly(t *testing.T) {
	bashTool := schemas.ChatTool{
		Type: schemas.ChatToolType("bash_20250124"),
		Name: "bash",
	}
	chatParams := &schemas.ChatParameters{Tools: []schemas.ChatTool{bashTool}}

	serverTools, derivedBetas := collectBedrockServerTools(chatParams)
	if count := len(serverTools); count != 1 {
		t.Fatalf("expected 1 server tool, got %d", count)
	}

	encoded := string(serverTools[0])
	hasType := strings.Contains(encoded, `"type":"bash_20250124"`)
	hasName := strings.Contains(encoded, `"name":"bash"`)
	if !(hasType && hasName) {
		t.Errorf("expected native Anthropic bash shape, got %s", encoded)
	}
	if len(derivedBetas) != 0 {
		t.Errorf("expected no derived beta headers for bash (no 1:1 mapping), got %v", derivedBetas)
	}
}
// TestCollectBedrockServerTools_ComputerDerivesBeta checks that a
// computer_YYYYMMDD tool derives computer-use-YYYY-MM-DD as its beta header
// (gated through FilterBetaHeadersForProvider(Bedrock), which keeps
// computer-use-* headers) and that the computer-specific fields are carried
// through into the emitted tool JSON.
func TestCollectBedrockServerTools_ComputerDerivesBeta(t *testing.T) {
	computerTool := schemas.ChatTool{
		Type:            schemas.ChatToolType("computer_20251124"),
		Name:            "computer",
		DisplayWidthPx:  schemas.Ptr(1280),
		DisplayHeightPx: schemas.Ptr(800),
	}
	chatParams := &schemas.ChatParameters{Tools: []schemas.ChatTool{computerTool}}

	serverTools, derivedBetas := collectBedrockServerTools(chatParams)
	if count := len(serverTools); count != 1 {
		t.Fatalf("expected 1 server tool, got %d", count)
	}
	if encoded := string(serverTools[0]); !strings.Contains(encoded, `"display_width_px":1280`) {
		t.Errorf("expected computer variant fields to flow through, got %s", encoded)
	}
	if len(derivedBetas) != 1 || derivedBetas[0] != "computer-use-2025-11-24" {
		t.Errorf("expected [computer-use-2025-11-24], got %v", derivedBetas)
	}
}
// TestCollectBedrockServerTools_MemoryDerivesContextManagement checks that the
// memory tool activates via the context-management-2025-06-27 bundle on
// Bedrock (cite: anthropic/types.go:179).
func TestCollectBedrockServerTools_MemoryDerivesContextManagement(t *testing.T) {
	memoryTool := schemas.ChatTool{
		Type: schemas.ChatToolType("memory_20250818"),
		Name: "memory",
	}
	chatParams := &schemas.ChatParameters{Tools: []schemas.ChatTool{memoryTool}}

	_, derivedBetas := collectBedrockServerTools(chatParams)
	if len(derivedBetas) != 1 || derivedBetas[0] != "context-management-2025-06-27" {
		t.Errorf("expected [context-management-2025-06-27], got %v", derivedBetas)
	}
}
// TestCollectBedrockServerTools_StripsUnsupported covers web_search, which is
// not in Bedrock's ProviderFeatures (WebSearch=false):
// ValidateChatToolsForProvider drops it, so the helper must emit neither a
// tool entry nor a beta header.
func TestCollectBedrockServerTools_StripsUnsupported(t *testing.T) {
	webSearchTool := schemas.ChatTool{
		Type: schemas.ChatToolType("web_search_20260209"),
		Name: "web_search",
	}
	chatParams := &schemas.ChatParameters{Tools: []schemas.ChatTool{webSearchTool}}

	serverTools, derivedBetas := collectBedrockServerTools(chatParams)
	if count := len(serverTools); count != 0 {
		t.Errorf("expected no server tools (web_search unsupported on Bedrock), got %d", count)
	}
	if len(derivedBetas) != 0 {
		t.Errorf("expected no betas when all tools filtered, got %v", derivedBetas)
	}
}
// TestCollectBedrockServerTools_FunctionToolsIgnored checks that function and
// custom tools are left to convertToolConfig: the server-tool helper must
// return nothing for them.
func TestCollectBedrockServerTools_FunctionToolsIgnored(t *testing.T) {
	weatherTool := schemas.ChatTool{
		Type: schemas.ChatToolTypeFunction,
		Function: &schemas.ChatToolFunction{
			Name:       "get_weather",
			Parameters: &schemas.ToolFunctionParameters{Type: "object"},
		},
	}
	chatParams := &schemas.ChatParameters{Tools: []schemas.ChatTool{weatherTool}}

	serverTools, derivedBetas := collectBedrockServerTools(chatParams)
	if toolCount := len(serverTools); toolCount != 0 || len(derivedBetas) != 0 {
		t.Errorf("function tools should not flow through server-tool helper, got tools=%d betas=%v", toolCount, derivedBetas)
	}
}
// TestBuildBedrockServerToolChoice_PinnedServerTool covers a caller pinning a
// kept server tool (computer) by name. Converse's typed
// toolConfig.toolChoice cannot carry that pin because toolConfig.tools does
// not include server tools, and the existing reconciliation silently drops
// it. The tunneled path must therefore emit {"type":"tool","name":"computer"}
// into additionalModelRequestFields.
func TestBuildBedrockServerToolChoice_PinnedServerTool(t *testing.T) {
	computerTool := schemas.ChatTool{
		Type:           schemas.ChatToolType("computer_20251124"),
		Name:           "computer",
		DisplayWidthPx: schemas.Ptr(1280),
	}
	pinComputer := &schemas.ChatToolChoice{
		ChatToolChoiceStruct: &schemas.ChatToolChoiceStruct{
			Type:     schemas.ChatToolChoiceTypeFunction,
			Function: &schemas.ChatToolChoiceFunction{Name: "computer"},
		},
	}
	chatParams := &schemas.ChatParameters{
		Tools:      []schemas.ChatTool{computerTool},
		ToolChoice: pinComputer,
	}

	tunneled, emitted := buildBedrockServerToolChoice(chatParams, []schemas.ChatTool{computerTool})
	if !emitted {
		t.Fatalf("expected tunneled tool_choice for pinned server tool, got (nil, false)")
	}

	encoded := string(tunneled)
	hasType := strings.Contains(encoded, `"type":"tool"`)
	hasName := strings.Contains(encoded, `"name":"computer"`)
	if !(hasType && hasName) {
		t.Errorf("expected Anthropic-native {type:tool,name:computer}, got %s", encoded)
	}
}
// TestBuildBedrockServerToolChoice_PinnedFunctionTool_NotTunneled checks that
// a pin on a function tool stays on Converse's typed path
// (toolConfig.toolChoice.tool); the helper must not emit a second, tunneled
// tool_choice for it.
func TestBuildBedrockServerToolChoice_PinnedFunctionTool_NotTunneled(t *testing.T) {
	weatherTool := schemas.ChatTool{
		Type: schemas.ChatToolTypeFunction,
		Function: &schemas.ChatToolFunction{
			Name:       "get_weather",
			Parameters: &schemas.ToolFunctionParameters{Type: "object"},
		},
	}
	pinWeather := &schemas.ChatToolChoice{
		ChatToolChoiceStruct: &schemas.ChatToolChoiceStruct{
			Type:     schemas.ChatToolChoiceTypeFunction,
			Function: &schemas.ChatToolChoiceFunction{Name: "get_weather"},
		},
	}
	chatParams := &schemas.ChatParameters{
		Tools:      []schemas.ChatTool{weatherTool},
		ToolChoice: pinWeather,
	}

	_, emitted := buildBedrockServerToolChoice(chatParams, []schemas.ChatTool{weatherTool})
	if emitted {
		t.Errorf("expected no tunneling for function-tool pin (typed Converse path handles it)")
	}
}
// TestBuildBedrockServerToolChoice_AnyWithOnlyServerTools covers
// tool_choice:any when only server tools are present. convertToolConfig
// returns nil in that case (bedrockTools is empty), so the typed any-contract
// is lost; the tunneled path must emit {"type":"any"} to keep the forcing
// semantics.
func TestBuildBedrockServerToolChoice_AnyWithOnlyServerTools(t *testing.T) {
	bashTool := schemas.ChatTool{
		Type: schemas.ChatToolType("bash_20250124"),
		Name: "bash",
	}
	anyChoice := string(schemas.ChatToolChoiceTypeAny)
	chatParams := &schemas.ChatParameters{
		Tools:      []schemas.ChatTool{bashTool},
		ToolChoice: &schemas.ChatToolChoice{ChatToolChoiceStr: &anyChoice},
	}

	tunneled, emitted := buildBedrockServerToolChoice(chatParams, []schemas.ChatTool{bashTool})
	if !emitted {
		t.Fatalf("expected tunneled any-contract when only server tools are present, got (nil, false)")
	}
	if encoded := string(tunneled); !strings.Contains(encoded, `"type":"any"`) {
		t.Errorf("expected {type:any}, got %s", encoded)
	}
}
// TestBuildBedrockServerToolChoice_AnyWithFunctionTool_NotTunneled checks that
// once at least one function/custom tool is present, the any-contract is left
// to Converse's typed toolConfig.toolChoice.any and is not emitted a second
// time through the tunneled path.
func TestBuildBedrockServerToolChoice_AnyWithFunctionTool_NotTunneled(t *testing.T) {
	weatherTool := schemas.ChatTool{
		Type: schemas.ChatToolTypeFunction,
		Function: &schemas.ChatToolFunction{
			Name:       "get_weather",
			Parameters: &schemas.ToolFunctionParameters{Type: "object"},
		},
	}
	bashTool := schemas.ChatTool{
		Type: schemas.ChatToolType("bash_20250124"),
		Name: "bash",
	}
	anyChoice := string(schemas.ChatToolChoiceTypeAny)
	chatParams := &schemas.ChatParameters{
		Tools:      []schemas.ChatTool{weatherTool, bashTool},
		ToolChoice: &schemas.ChatToolChoice{ChatToolChoiceStr: &anyChoice},
	}

	_, emitted := buildBedrockServerToolChoice(chatParams, []schemas.ChatTool{weatherTool, bashTool})
	if emitted {
		t.Errorf("expected no tunneling when function/custom tool is present (typed Converse path handles any)")
	}
}
// TestBuildBedrockServerToolChoice_UnsupportedServerToolPin_NotTunneled covers
// a caller pinning web_search, which ValidateChatToolsForProvider strips on
// Bedrock. The pinned name is then missing from the filtered tool set, and the
// helper must not fabricate a tunneled tool_choice for a tool that is no
// longer part of the request.
func TestBuildBedrockServerToolChoice_UnsupportedServerToolPin_NotTunneled(t *testing.T) {
	// The caller's original request still lists web_search and pins it.
	webSearchTool := schemas.ChatTool{Type: schemas.ChatToolType("web_search_20260209"), Name: "web_search"}
	pinWebSearch := &schemas.ChatToolChoice{
		ChatToolChoiceStruct: &schemas.ChatToolChoiceStruct{
			Type:     schemas.ChatToolChoiceTypeFunction,
			Function: &schemas.ChatToolChoiceFunction{Name: "web_search"},
		},
	}
	chatParams := &schemas.ChatParameters{
		Tools:      []schemas.ChatTool{webSearchTool},
		ToolChoice: pinWebSearch,
	}

	// Mimic the convertChatParameters call path: after
	// ValidateChatToolsForProvider(Bedrock) the filtered slice is empty
	// because web_search has been dropped.
	survivingTools := make([]schemas.ChatTool, 0)

	_, emitted := buildBedrockServerToolChoice(chatParams, survivingTools)
	if emitted {
		t.Errorf("expected no tunneling when pinned name was stripped by provider validation")
	}
}
// TestConvertChatParameters_PinnedServerToolE2E is an end-to-end check that
// convertChatParameters composes convertToolConfig, collectBedrockServerTools
// and buildBedrockServerToolChoice so that a request pinning a kept server
// tool yields:
//   - AdditionalModelRequestFields.tools containing the server tool
//   - AdditionalModelRequestFields.tool_choice with Anthropic-native shape
//   - ToolConfig nil (no function tools → Converse's typed path is inactive)
func TestConvertChatParameters_PinnedServerToolE2E(t *testing.T) {
	computerTool := schemas.ChatTool{
		Type:           schemas.ChatToolType("computer_20251124"),
		Name:           "computer",
		DisplayWidthPx: schemas.Ptr(1280),
	}
	pinComputer := &schemas.ChatToolChoice{
		ChatToolChoiceStruct: &schemas.ChatToolChoiceStruct{
			Type:     schemas.ChatToolChoiceTypeFunction,
			Function: &schemas.ChatToolChoiceFunction{Name: "computer"},
		},
	}
	chatReq := &schemas.BifrostChatRequest{
		Model: "global.anthropic.claude-sonnet-4-6",
		Params: &schemas.ChatParameters{
			Tools:      []schemas.ChatTool{computerTool},
			ToolChoice: pinComputer,
		},
	}

	converseReq := &BedrockConverseRequest{}
	if err := convertChatParameters(nil, chatReq, converseReq); err != nil {
		t.Fatalf("convertChatParameters failed: %v", err)
	}

	if converseReq.ToolConfig != nil {
		t.Errorf("expected nil ToolConfig (no function/custom tools), got %+v", converseReq.ToolConfig)
	}
	if converseReq.AdditionalModelRequestFields == nil {
		t.Fatalf("expected AdditionalModelRequestFields to carry server-tool payload, got nil")
	}

	// The server tool itself must be tunneled as exactly one raw JSON entry.
	rawTools, haveTools := converseReq.AdditionalModelRequestFields.Get("tools")
	if haveTools {
		if toolList, isList := rawTools.([]json.RawMessage); !isList || len(toolList) != 1 {
			t.Errorf("expected 1 server tool in additionalModelRequestFields.tools, got %+v", rawTools)
		}
	} else {
		t.Errorf("expected additionalModelRequestFields.tools to be set for server tool")
	}

	// The pin must be tunneled alongside it in Anthropic-native form.
	rawChoice, haveChoice := converseReq.AdditionalModelRequestFields.Get("tool_choice")
	if !haveChoice {
		t.Fatalf("expected additionalModelRequestFields.tool_choice to carry pinned server-tool contract")
	}
	choiceJSON, isRaw := rawChoice.(json.RawMessage)
	if !isRaw {
		t.Fatalf("expected tool_choice value to be json.RawMessage, got %T", rawChoice)
	}
	encoded := string(choiceJSON)
	hasType := strings.Contains(encoded, `"type":"tool"`)
	hasName := strings.Contains(encoded, `"name":"computer"`)
	if !(hasType && hasName) {
		t.Errorf("expected {type:tool,name:computer}, got %s", encoded)
	}
}
// TestConvertChatParameters_ResponseFormatWithPinnedServerTool_NoConflictingChoice
// is the regression test for the "two conflicting tool-choice directives"
// hazard. When response_format forces the synthetic bf_so_* tool through
// ToolConfig.ToolChoice, the tunneled
// additionalModelRequestFields.tool_choice (which would pin a server tool)
// must be suppressed so Bedrock never receives both pins in one Converse
// call. A Nova model is used because Anthropic models route response_format
// through native output_config.format (no synthetic tool), so the conflict
// only surfaces on non-Anthropic Bedrock targets.
func TestConvertChatParameters_ResponseFormatWithPinnedServerTool_NoConflictingChoice(t *testing.T) {
	classificationSchema := map[string]any{
		"type": "object",
		"properties": map[string]any{
			"topic": map[string]any{"type": "string"},
		},
		"required": []any{"topic"},
	}
	var responseFormat any = map[string]any{
		"type": "json_schema",
		"json_schema": map[string]any{
			"name":   "classification",
			"schema": classificationSchema,
		},
	}

	bashTool := schemas.ChatTool{
		Type: schemas.ChatToolType("bash_20250124"),
		Name: "bash",
	}
	pinBash := &schemas.ChatToolChoice{
		ChatToolChoiceStruct: &schemas.ChatToolChoiceStruct{
			Type:     schemas.ChatToolChoiceTypeFunction,
			Function: &schemas.ChatToolChoiceFunction{Name: "bash"},
		},
	}
	chatReq := &schemas.BifrostChatRequest{
		Model: "amazon.nova-pro-v1:0",
		Params: &schemas.ChatParameters{
			ResponseFormat: &responseFormat,
			Tools:          []schemas.ChatTool{bashTool},
			ToolChoice:     pinBash,
		},
	}

	bifrostCtx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
	converseReq := &BedrockConverseRequest{}
	if err := convertChatParameters(bifrostCtx, chatReq, converseReq); err != nil {
		t.Fatalf("convertChatParameters failed: %v", err)
	}

	// The synthetic bf_so_* tool must be injected and pinned via Converse's typed path.
	if converseReq.ToolConfig == nil {
		t.Fatalf("expected ToolConfig with synthetic bf_so_* tool, got nil")
	}
	typedChoice := converseReq.ToolConfig.ToolChoice
	if typedChoice == nil || typedChoice.Tool == nil {
		t.Fatalf("expected ToolConfig.ToolChoice.Tool to pin synthetic structured-output tool, got %+v", typedChoice)
	}
	if pinnedName := typedChoice.Tool.Name; !strings.HasPrefix(pinnedName, "bf_so_") {
		t.Errorf("expected ToolConfig.ToolChoice.Tool.Name to start with bf_so_, got %q", pinnedName)
	}

	// The server tool must still be tunneled so the model has it available.
	if converseReq.AdditionalModelRequestFields == nil {
		t.Fatalf("expected AdditionalModelRequestFields to carry tunneled server-tool payload, got nil")
	}
	if _, haveTools := converseReq.AdditionalModelRequestFields.Get("tools"); !haveTools {
		t.Errorf("expected additionalModelRequestFields.tools to still carry bash server tool")
	}

	// Guarded field: the tunneled tool_choice MUST be absent because
	// response_format forces the synthetic tool. Two tool-choice directives in
	// the same request would let Bedrock pick one and silently violate the
	// structured-output contract.
	if _, haveChoice := converseReq.AdditionalModelRequestFields.Get("tool_choice"); haveChoice {
		t.Errorf("expected NO additionalModelRequestFields.tool_choice when response_format pins bf_so_* (conflict hazard)")
	}
}

View File

@@ -0,0 +1,57 @@
package bedrock
import (
"strings"
"github.com/maximhq/bifrost/core/schemas"
)
// estimatedBytesPerToken is the bytes-per-token ratio that estimateTokenCount
// divides the serialized request size by when producing its rough estimate.
const estimatedBytesPerToken = 4
// ToBifrostCountTokensResponse builds the Bifrost count-tokens response for the
// given model from a Bedrock count-tokens response. The Bedrock input token
// count is reported as both InputTokens and TotalTokens. A nil receiver yields
// a nil result.
func (resp *BedrockCountTokensResponse) ToBifrostCountTokensResponse(model string) *schemas.BifrostCountTokensResponse {
	if resp == nil {
		return nil
	}

	// Take a copy so the returned pointer does not alias the receiver's field.
	total := resp.InputTokens

	out := &schemas.BifrostCountTokensResponse{
		Object:      "response.input_tokens",
		Model:       model,
		InputTokens: resp.InputTokens,
	}
	out.TotalTokens = &total
	return out
}
// ToBedrockCountTokensResponse maps a Bifrost count-tokens response onto the
// Bedrock native shape, carrying over only the input token count. A nil input
// yields a nil result.
func ToBedrockCountTokensResponse(resp *schemas.BifrostCountTokensResponse) *BedrockCountTokensResponse {
	var out *BedrockCountTokensResponse
	if resp != nil {
		out = &BedrockCountTokensResponse{InputTokens: resp.InputTokens}
	}
	return out
}
// isCountTokensUnsupported reports whether err carries a message saying the
// model "doesn't support counting tokens". The match is case-insensitive and
// substring-based; a nil error or a nil Error field is never a match.
func isCountTokensUnsupported(err *schemas.BifrostError) bool {
	if err != nil && err.Error != nil {
		msg := strings.ToLower(err.Error.Message)
		return strings.Contains(msg, "doesn't support counting tokens")
	}
	return false
}
// estimateTokenCount returns a rough token count for a serialized request body:
// its byte length divided by estimatedBytesPerToken, rounded up. Rounding up
// keeps the estimate on the conservative (larger) side. An empty or nil body
// yields 0.
func estimateTokenCount(requestBody []byte) int {
	size := len(requestBody)
	tokens := size / estimatedBytesPerToken
	if size%estimatedBytesPerToken != 0 {
		// A partial trailing group still counts as one token.
		tokens++
	}
	return tokens
}

View File

@@ -0,0 +1,105 @@
package bedrock
import (
"testing"
"github.com/maximhq/bifrost/core/schemas"
"github.com/stretchr/testify/assert"
)
// TestIsCountTokensUnsupported checks isCountTokensUnsupported against nil
// inputs, the matching message in two casings, and an unrelated message.
func TestIsCountTokensUnsupported(t *testing.T) {
	// withMessage wraps a message in the error shape the function inspects.
	withMessage := func(msg string) *schemas.BifrostError {
		return &schemas.BifrostError{
			Error: &schemas.ErrorField{Message: msg},
		}
	}

	cases := []struct {
		label string
		in    *schemas.BifrostError
		want  bool
	}{
		{label: "nil error", in: nil, want: false},
		{label: "nil error field", in: &schemas.BifrostError{}, want: false},
		{
			label: "matching bedrock error message",
			in:    withMessage("The provided model doesn't support counting tokens."),
			want:  true,
		},
		{
			label: "matching message with different casing",
			in:    withMessage("the provided model DOESN'T SUPPORT COUNTING TOKENS."),
			want:  true,
		},
		{label: "unrelated error message", in: withMessage("access denied"), want: false},
	}

	for _, c := range cases {
		t.Run(c.label, func(t *testing.T) {
			got := isCountTokensUnsupported(c.in)
			assert.Equal(t, c.want, got)
		})
	}
}
func TestEstimateTokenCount(t *testing.T) {
tests := []struct {
name string
input []byte
expected int
}{
{
name: "empty input",
input: []byte{},
expected: 0,
},
{
name: "nil input",
input: nil,
expected: 0,
},
{
name: "exact multiple of 4",
input: make([]byte, 100),
expected: 25,
},
{
name: "rounds up",
input: make([]byte, 101),
expected: 26,
},
{
name: "single byte",
input: []byte("x"),
expected: 1,
},
{
name: "realistic json body",
input: []byte(`{"messages":[{"role":"user","content":"Hello, how are you today?"}],"model":"us.anthropic.claude-sonnet-4-6"}`),
expected: 28,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
assert.Equal(t, tc.expected, estimateTokenCount(tc.input))
})
}
}

View File

@@ -0,0 +1,271 @@
package bedrock
import (
"encoding/json"
"fmt"
"strings"
"github.com/maximhq/bifrost/core/schemas"
)
// ToBedrockTitanEmbeddingRequest converts a Bifrost embedding request to the
// Bedrock Titan embedding request format.
//
// The Titan request carries a single InputText string. When the Bifrost
// request provides Input.Text it is used as-is; otherwise every entry of
// Input.Texts is appended followed by " \n" and the concatenation is sent.
//
// From Params, Dimensions is copied through, a boolean "normalize" extra param
// is promoted to the typed Normalize field, and all other extra params are
// forwarded in ExtraParams.
//
// It returns an error when the request is nil or when it carries no input
// (nil Input, or neither Text nor Texts set).
func ToBedrockTitanEmbeddingRequest(bifrostReq *schemas.BifrostEmbeddingRequest) (*BedrockTitanEmbeddingRequest, error) {
	if bifrostReq == nil {
		return nil, fmt.Errorf("bifrost embedding request is nil")
	}

	// Input is a pointer and may be nil; guard it before touching its fields.
	if bifrostReq.Input == nil || (bifrostReq.Input.Text == nil && len(bifrostReq.Input.Texts) == 0) {
		return nil, fmt.Errorf("no input text provided for embedding")
	}

	titanReq := &BedrockTitanEmbeddingRequest{}

	// Set input text
	if bifrostReq.Input.Text != nil {
		titanReq.InputText = *bifrostReq.Input.Text
	} else {
		// Multiple texts are collapsed into one string, each followed by " \n".
		var sb strings.Builder
		for _, text := range bifrostReq.Input.Texts {
			sb.WriteString(text)
			sb.WriteString(" \n")
		}
		titanReq.InputText = sb.String()
	}

	if bifrostReq.Params != nil {
		titanReq.Dimensions = bifrostReq.Params.Dimensions
		if normalize, ok := bifrostReq.Params.ExtraParams["normalize"]; ok {
			if b, ok := normalize.(bool); ok {
				titanReq.Normalize = &b
			}
		}
		// Forward remaining extra params (excluding normalize which is now a first-class field)
		if len(bifrostReq.Params.ExtraParams) > 0 {
			extra := make(map[string]interface{})
			for k, v := range bifrostReq.Params.ExtraParams {
				if k != "normalize" {
					extra[k] = v
				}
			}
			if len(extra) > 0 {
				titanReq.ExtraParams = extra
			}
		}
	}
	return titanReq, nil
}
// ToBifrostEmbeddingResponse wraps a Bedrock Titan embedding response in the
// Bifrost list format: a single "embedding" entry at index 0 holding the
// vector, with the input token count reported as both prompt and total usage.
// A nil receiver yields a nil result.
func (response *BedrockTitanEmbeddingResponse) ToBifrostEmbeddingResponse() *schemas.BifrostEmbeddingResponse {
	if response == nil {
		return nil
	}

	entry := schemas.EmbeddingData{
		Object: "embedding",
		Index:  0,
		Embedding: schemas.EmbeddingStruct{
			EmbeddingArray: response.Embedding,
		},
	}
	usage := &schemas.BifrostLLMUsage{
		PromptTokens: response.InputTextTokenCount,
		TotalTokens:  response.InputTextTokenCount,
	}

	return &schemas.BifrostEmbeddingResponse{
		Object: "list",
		Data:   []schemas.EmbeddingData{entry},
		Usage:  usage,
	}
}
// ToBedrockCohereEmbeddingRequest converts a Bifrost embedding request to the
// Bedrock Cohere request format. The model name is not copied into the
// request body.
//
// Input.Text (if set) becomes a one-element Texts slice; otherwise Input.Texts
// is used directly. Recognised extra params are promoted to typed fields when
// their dynamic type matches (input_type, truncate: string; embedding_types,
// images: []string; inputs: []BedrockCohereEmbeddingInput; max_tokens: int or
// float64). Anything not promoted is forwarded in ExtraParams. The caller's
// ExtraParams map is copied, never modified.
//
// It returns an error for a nil request or a request without any input.
func ToBedrockCohereEmbeddingRequest(bifrostReq *schemas.BifrostEmbeddingRequest) (*BedrockCohereEmbeddingRequest, error) {
	if bifrostReq == nil {
		return nil, fmt.Errorf("bifrost embedding request is nil")
	}
	input := bifrostReq.Input
	if input == nil || (input.Text == nil && len(input.Texts) == 0) {
		return nil, fmt.Errorf("no input provided for embedding")
	}

	req := &BedrockCohereEmbeddingRequest{}
	if input.Text != nil {
		req.Texts = []string{*input.Text}
	} else {
		// Validation above guarantees Texts is non-empty here.
		req.Texts = input.Texts
	}

	params := bifrostReq.Params
	if params == nil {
		return req, nil
	}

	// Work on a private copy so promoted keys can be removed safely.
	rest := make(map[string]interface{}, len(params.ExtraParams))
	for name, value := range params.ExtraParams {
		rest[name] = value
	}

	if inputType, ok := rest["input_type"].(string); ok {
		req.InputType = inputType
		delete(rest, "input_type")
	}
	if truncate, ok := rest["truncate"].(string); ok {
		req.Truncate = &truncate
		delete(rest, "truncate")
	}
	if embeddingTypes, ok := rest["embedding_types"].([]string); ok {
		req.EmbeddingTypes = embeddingTypes
		delete(rest, "embedding_types")
	}
	if images, ok := rest["images"].([]string); ok {
		req.Images = images
		delete(rest, "images")
	}
	if inputs, ok := rest["inputs"].([]BedrockCohereEmbeddingInput); ok {
		req.Inputs = inputs
		delete(rest, "inputs")
	}

	// max_tokens may arrive as an int or as a float64.
	switch maxTokens := rest["max_tokens"].(type) {
	case int:
		req.MaxTokens = &maxTokens
		delete(rest, "max_tokens")
	case float64:
		asInt := int(maxTokens)
		req.MaxTokens = &asInt
		delete(rest, "max_tokens")
	}

	if params.Dimensions != nil {
		req.OutputDimension = params.Dimensions
	}
	if len(rest) > 0 {
		req.ExtraParams = rest
	}
	return req, nil
}
// DetermineEmbeddingModelType classifies an embedding model by name: "titan"
// when the name contains "amazon.titan-embed-text", "cohere" when it contains
// "cohere.embed". Any other name is reported as an unsupported-model error.
func DetermineEmbeddingModelType(model string) (string, error) {
	if strings.Contains(model, "amazon.titan-embed-text") {
		return "titan", nil
	}
	if strings.Contains(model, "cohere.embed") {
		return "cohere", nil
	}
	return "", fmt.Errorf("unsupported embedding model: %s", model)
}
// ToBifrostEmbeddingResponse converts a BedrockCohereEmbeddingResponse to the
// Bifrost list format.
//
// When ResponseType is "embeddings_by_type" the Embeddings payload is decoded
// as an object keyed by encoding ("float", "base64", "int8", "uint8",
// "binary", "ubinary") and one entry is emitted per vector of every encoding
// present, in the order float, base64, int8, binary, uint8, ubinary. Indexes
// restart at 0 for each encoding. For any other ResponseType the payload is
// decoded as a raw array of float vectors.
//
// It returns an error for a nil receiver or an undecodable payload.
func (r *BedrockCohereEmbeddingResponse) ToBifrostEmbeddingResponse() (*schemas.BifrostEmbeddingResponse, error) {
	if r == nil {
		return nil, fmt.Errorf("nil Bedrock Cohere embedding response")
	}

	out := &schemas.BifrostEmbeddingResponse{Object: "list"}

	// push appends one embedding entry to the response.
	push := func(index int, emb schemas.EmbeddingStruct) {
		out.Data = append(out.Data, schemas.EmbeddingData{
			Object:    "embedding",
			Index:     index,
			Embedding: emb,
		})
	}
	// widen converts a float32 vector to the float64 form Bifrost carries.
	widen := func(src []float32) []float64 {
		dst := make([]float64, len(src))
		for i := range src {
			dst[i] = float64(src[i])
		}
		return dst
	}

	if r.ResponseType != "embeddings_by_type" {
		// Default / "embeddings_floats": raw array form [[...], [...]]
		var rows [][]float32
		if err := json.Unmarshal(r.Embeddings, &rows); err != nil {
			return nil, fmt.Errorf("error parsing embeddings_floats: %w", err)
		}
		for i, row := range rows {
			push(i, schemas.EmbeddingStruct{EmbeddingArray: widen(row)})
		}
		return out, nil
	}

	// Object form keyed by encoding name.
	var byType struct {
		Float   [][]float32 `json:"float"`
		Base64  []string    `json:"base64"`
		Int8    [][]int8    `json:"int8"`
		Uint8   [][]int32   `json:"uint8"` // int32 rather than []byte, which encoding/json treats as base64
		Binary  [][]int8    `json:"binary"`
		Ubinary [][]int32   `json:"ubinary"` // int32 rather than []byte, which encoding/json treats as base64
	}
	if err := json.Unmarshal(r.Embeddings, &byType); err != nil {
		return nil, fmt.Errorf("error parsing embeddings_by_type: %w", err)
	}

	for i, row := range byType.Float {
		push(i, schemas.EmbeddingStruct{EmbeddingArray: widen(row)})
	}
	for i := range byType.Base64 {
		// Copy so each entry points at its own string.
		encoded := byType.Base64[i]
		push(i, schemas.EmbeddingStruct{EmbeddingStr: &encoded})
	}
	for i, row := range byType.Int8 {
		push(i, schemas.EmbeddingStruct{EmbeddingInt8Array: row})
	}
	for i, row := range byType.Binary {
		push(i, schemas.EmbeddingStruct{EmbeddingInt8Array: row})
	}
	for i, row := range byType.Uint8 {
		push(i, schemas.EmbeddingStruct{EmbeddingInt32Array: row})
	}
	for i, row := range byType.Ubinary {
		push(i, schemas.EmbeddingStruct{EmbeddingInt32Array: row})
	}
	return out, nil
}

View File

@@ -0,0 +1,114 @@
package bedrock
import (
"context"
"testing"
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
"github.com/maximhq/bifrost/core/schemas"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestToBedrockCohereEmbeddingRequest covers the error paths (nil request,
// missing input, empty input) and the conversion of single-text and
// multi-text requests, including promotion of typed extra params.
func TestToBedrockCohereEmbeddingRequest(t *testing.T) {
	t.Run("returns error for nil request", func(t *testing.T) {
		got, err := ToBedrockCohereEmbeddingRequest(nil)

		require.Error(t, err)
		assert.Nil(t, got)
		assert.Contains(t, err.Error(), "nil")
	})

	t.Run("returns error for missing input", func(t *testing.T) {
		in := &schemas.BifrostEmbeddingRequest{}
		got, err := ToBedrockCohereEmbeddingRequest(in)

		require.Error(t, err)
		assert.Nil(t, got)
		assert.Contains(t, err.Error(), "no input")
	})

	t.Run("returns error for non-nil but empty input", func(t *testing.T) {
		in := &schemas.BifrostEmbeddingRequest{Input: &schemas.EmbeddingInput{}}
		got, err := ToBedrockCohereEmbeddingRequest(in)

		require.Error(t, err)
		assert.Nil(t, got)
		assert.Contains(t, err.Error(), "no input")
	})

	t.Run("single text strips model and extracts typed params", func(t *testing.T) {
		text := "hello"
		truncate := "RIGHT"
		dimensions := 512
		extras := map[string]interface{}{
			"input_type":      "search_query",
			"embedding_types": []string{"float"},
			"truncate":        truncate,
			"max_tokens":      float64(128),
			"trace_id":        "req-123",
		}
		in := &schemas.BifrostEmbeddingRequest{
			Model: "cohere.embed-english-v3",
			Input: &schemas.EmbeddingInput{Text: &text},
			Params: &schemas.EmbeddingParameters{
				Dimensions:  &dimensions,
				ExtraParams: extras,
			},
		}

		got, err := ToBedrockCohereEmbeddingRequest(in)

		require.NoError(t, err)
		require.NotNil(t, got)
		assert.Equal(t, "search_query", got.InputType)
		assert.Equal(t, []string{"hello"}, got.Texts)
		assert.Equal(t, []string{"float"}, got.EmbeddingTypes)
		assert.Equal(t, &dimensions, got.OutputDimension)
		assert.Equal(t, 128, *got.MaxTokens)
		require.NotNil(t, got.Truncate)
		assert.Equal(t, truncate, *got.Truncate)
		// Only the unrecognised key is left to be forwarded.
		assert.Equal(t, map[string]interface{}{"trace_id": "req-123"}, got.ExtraParams)
	})

	t.Run("multiple texts preserve bedrock body shape", func(t *testing.T) {
		in := &schemas.BifrostEmbeddingRequest{
			Model: "cohere.embed-multilingual-v3",
			Input: &schemas.EmbeddingInput{Texts: []string{"hello", "world"}},
			Params: &schemas.EmbeddingParameters{
				ExtraParams: map[string]interface{}{"input_type": "search_document"},
			},
		}

		got, err := ToBedrockCohereEmbeddingRequest(in)

		require.NoError(t, err)
		assert.Equal(t, []string{"hello", "world"}, got.Texts)
		assert.Equal(t, "search_document", got.InputType)
	})
}
// TestToBedrockCohereEmbeddingRequestBodyOmitsModel serializes a converted
// Cohere request through CheckContextAndGetRequestBody and asserts that the
// resulting body has no "model" key and matches the expected JSON exactly.
func TestToBedrockCohereEmbeddingRequestBodyOmitsModel(t *testing.T) {
	text := "hello"
	in := &schemas.BifrostEmbeddingRequest{
		Model: "cohere.embed-english-v3",
		Input: &schemas.EmbeddingInput{Text: &text},
		Params: &schemas.EmbeddingParameters{
			ExtraParams: map[string]interface{}{
				"input_type":      "search_document",
				"embedding_types": []string{"float"},
			},
		},
	}
	convert := func() (providerUtils.RequestBodyWithExtraParams, error) {
		return ToBedrockCohereEmbeddingRequest(in)
	}

	body, bifrostErr := providerUtils.CheckContextAndGetRequestBody(context.Background(), in, convert)

	require.Nil(t, bifrostErr)
	serialized := string(body)
	assert.NotContains(t, serialized, `"model"`)
	assert.JSONEq(t, `{
"input_type": "search_document",
"texts": ["hello"],
"embedding_types": ["float"]
}`, serialized)
}

View File

@@ -0,0 +1,34 @@
package bedrock
import (
"net/http"
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
"github.com/maximhq/bifrost/core/schemas"
"github.com/valyala/fasthttp"
)
// parseBedrockHTTPError turns a net/http style error response (status code,
// headers, body) into a BifrostError. It replays the response into a pooled
// fasthttp.Response so the shared HandleProviderAPIError helper can process
// it, then, if the body decoded into a BedrockError with a non-empty Message,
// overrides the error's Message and Code with the Bedrock values.
func parseBedrockHTTPError(statusCode int, headers http.Header, body []byte) *schemas.BifrostError {
	resp := fasthttp.AcquireResponse()
	defer fasthttp.ReleaseResponse(resp)

	resp.SetStatusCode(statusCode)
	for name, vals := range headers {
		for i := range vals {
			resp.Header.Add(name, vals[i])
		}
	}
	resp.SetBody(body)

	var parsed BedrockError
	result := providerUtils.HandleProviderAPIError(resp, &parsed)
	if parsed.Message == "" {
		// Nothing Bedrock-specific was decoded; keep the generic error.
		return result
	}

	if result.Error == nil {
		result.Error = &schemas.ErrorField{}
	}
	result.Error.Message = parsed.Message
	result.Error.Code = parsed.Code
	return result
}

View File

@@ -0,0 +1,276 @@
package bedrock
import (
"fmt"
"html"
"net/url"
"strings"
"time"
"github.com/bytedance/sonic"
"github.com/maximhq/bifrost/core/schemas"
)
// escapeS3KeyForURL percent-escapes an S3 object key for use in a URL path.
// The key is split on "/" and each segment is passed through url.PathEscape
// separately, so the separators themselves stay literal (escaping the whole
// key at once would turn them into "%2F"). An empty key yields "".
func escapeS3KeyForURL(key string) string {
	if len(key) == 0 {
		return ""
	}
	segments := strings.Split(key, "/")
	escaped := make([]string, 0, len(segments))
	for _, seg := range segments {
		escaped = append(escaped, url.PathEscape(seg))
	}
	return strings.Join(escaped, "/")
}
// parseS3URI splits an S3 reference into bucket and key. For the form
// "s3://bucket/key" the bucket is everything up to the first "/" after the
// scheme and the key is the remainder (empty when there is no "/"). Any input
// without the "s3://" prefix is treated as a bare bucket name with an empty
// key.
func parseS3URI(uri string) (bucket, key string) {
	const scheme = "s3://"
	if !strings.HasPrefix(uri, scheme) {
		return uri, ""
	}
	bucket, key, _ = strings.Cut(uri[len(scheme):], "/")
	return bucket, key
}
// S3ListObjectsResponse represents S3 ListObjectsV2 response.
// parseS3ListResponse fills it either by JSON-decoding the body (using the
// json tags below) or by scanning an XML body for the corresponding elements.
type S3ListObjectsResponse struct {
// Contents holds one entry per listed object.
Contents []S3Object `json:"contents"`
// IsTruncated is true when the listing reports more results are available.
IsTruncated bool `json:"isTruncated"`
// NextContinuationToken is the token reported by the listing for fetching
// the next page; empty when none was present.
NextContinuationToken string `json:"nextContinuationToken,omitempty"`
}
// S3Object represents an S3 object in list response.
type S3Object struct {
// Key is the object's key within the bucket.
Key string `json:"key"`
// Size is the object size as reported by the listing (presumably bytes).
Size int64 `json:"size"`
// LastModified is the object's last-modified timestamp; it stays the zero
// time when the listing value is absent or cannot be parsed.
LastModified time.Time `json:"lastModified"`
}
// parseS3ListResponse fills resp from an S3 ListObjectsV2 response body.
//
// The body is first tried as JSON; that result is accepted only if it decodes
// without error and yields at least one object. Otherwise the body is treated
// as XML and scanned with plain string matching (no XML decoder): the
// IsTruncated flag, the NextContinuationToken, and every <Contents> block's
// Key, Size and LastModified are extracted. Keys are entity-unescaped;
// unparsable sizes and timestamps are left at their zero values; blocks
// without a key are dropped. The function currently always returns nil.
func parseS3ListResponse(body []byte, resp *S3ListObjectsResponse) error {
	// Some S3-compatible services answer with JSON rather than XML.
	if jsonErr := sonic.Unmarshal(body, resp); jsonErr == nil && len(resp.Contents) > 0 {
		return nil
	}

	// element returns the text between the first <tag> and the next </tag>
	// in doc, and whether both tags were found.
	element := func(doc, tag string) (string, bool) {
		_, rest, opened := strings.Cut(doc, "<"+tag+">")
		if !opened {
			return "", false
		}
		text, _, closed := strings.Cut(rest, "</"+tag+">")
		if !closed {
			return "", false
		}
		return text, true
	}

	xmlDoc := string(body)

	if strings.Contains(xmlDoc, "<IsTruncated>true</IsTruncated>") {
		resp.IsTruncated = true
	}
	if token, ok := element(xmlDoc, "NextContinuationToken"); ok {
		resp.NextContinuationToken = token
	}

	// Walk the <Contents> blocks in document order.
	remaining := xmlDoc
	for {
		_, afterOpen, opened := strings.Cut(remaining, "<Contents>")
		if !opened {
			break
		}
		block, afterClose, closed := strings.Cut(afterOpen, "</Contents>")
		if !closed {
			break
		}
		remaining = afterClose

		var obj S3Object
		if rawKey, ok := element(block, "Key"); ok {
			obj.Key = html.UnescapeString(rawKey)
		}
		if rawSize, ok := element(block, "Size"); ok {
			// Best effort: a malformed size leaves obj.Size at 0.
			fmt.Sscanf(rawSize, "%d", &obj.Size)
		}
		if rawTime, ok := element(block, "LastModified"); ok {
			if parsed, err := time.Parse(time.RFC3339Nano, rawTime); err == nil {
				obj.LastModified = parsed
			}
		}

		if obj.Key != "" {
			resp.Contents = append(resp.Contents, obj)
		}
	}
	return nil
}
// ==================== BEDROCK FILE TYPE CONVERTERS ====================
// ToBedrockFileUploadResponse maps a Bifrost file upload response onto the
// Bedrock shape. The Bifrost ID is used as the S3 URI and is also split into
// bucket and key; the content type is always "application/jsonl". A nil input
// yields a nil result.
func ToBedrockFileUploadResponse(resp *schemas.BifrostFileUploadResponse) *BedrockFileUploadResponse {
	if resp == nil {
		return nil
	}

	out := &BedrockFileUploadResponse{
		S3Uri:       resp.ID,
		SizeBytes:   resp.Bytes,
		ContentType: "application/jsonl",
		CreatedAt:   resp.CreatedAt,
	}
	// The ID doubles as the S3 URI, so bucket and key come from parsing it.
	out.Bucket, out.Key = parseS3URI(resp.ID)
	return out
}
// ToBedrockFileListResponse maps a Bifrost file list response onto the Bedrock
// shape. Each file's ID becomes its S3 URI (and is parsed for the key), its
// CreatedAt is reported as LastModified, and HasMore is reported as
// IsTruncated. A nil input yields a nil result.
func ToBedrockFileListResponse(resp *schemas.BifrostFileListResponse) *BedrockFileListResponse {
	if resp == nil {
		return nil
	}

	infos := make([]BedrockFileInfo, 0, len(resp.Data))
	for _, file := range resp.Data {
		_, objectKey := parseS3URI(file.ID)
		infos = append(infos, BedrockFileInfo{
			S3Uri:        file.ID,
			Key:          objectKey,
			SizeBytes:    file.Bytes,
			LastModified: file.CreatedAt,
		})
	}

	out := &BedrockFileListResponse{IsTruncated: resp.HasMore}
	out.Files = infos
	return out
}
// ToBedrockFileRetrieveResponse maps a Bifrost file retrieve response onto the
// Bedrock shape. The ID becomes the S3 URI and is parsed for the key,
// CreatedAt is reported as LastModified, and the content type is always
// "application/jsonl". A nil input yields a nil result.
func ToBedrockFileRetrieveResponse(resp *schemas.BifrostFileRetrieveResponse) *BedrockFileRetrieveResponse {
	if resp == nil {
		return nil
	}

	_, objectKey := parseS3URI(resp.ID)
	out := &BedrockFileRetrieveResponse{
		S3Uri:       resp.ID,
		Key:         objectKey,
		ContentType: "application/jsonl",
	}
	out.SizeBytes = resp.Bytes
	out.LastModified = resp.CreatedAt
	return out
}
// ToBedrockFileDeleteResponse maps a Bifrost file delete response onto the
// Bedrock shape, using the ID as the S3 URI and passing the Deleted flag
// through. A nil input yields a nil result.
func ToBedrockFileDeleteResponse(resp *schemas.BifrostFileDeleteResponse) *BedrockFileDeleteResponse {
	var out *BedrockFileDeleteResponse
	if resp != nil {
		out = &BedrockFileDeleteResponse{
			S3Uri:   resp.ID,
			Deleted: resp.Deleted,
		}
	}
	return out
}
// ToBedrockFileContentResponse maps a Bifrost file content response onto the
// Bedrock shape. The file ID becomes the S3 URI, content and content type are
// passed through, and SizeBytes is the length of the content. A nil input
// yields a nil result.
func ToBedrockFileContentResponse(resp *schemas.BifrostFileContentResponse) *BedrockFileContentResponse {
	if resp == nil {
		return nil
	}

	size := int64(len(resp.Content))
	out := &BedrockFileContentResponse{
		S3Uri:       resp.FileID,
		Content:     resp.Content,
		ContentType: resp.ContentType,
	}
	out.SizeBytes = size
	return out
}
// ==================== S3 API XML FORMATTERS ====================
// ToS3ListObjectsV2XML renders a Bifrost file list response as an S3
// ListObjectsV2 XML document (a ListBucketResult element).
//
// bucket, prefix and maxKeys are echoed into Name, Prefix and MaxKeys;
// KeyCount is the number of files. When resp.HasMore is set, IsTruncated is
// "true" and a non-empty resp.After is emitted as NextContinuationToken. Each
// file becomes a Contents element whose Key is taken from the file's S3 URI;
// LastModified is only written for positive CreatedAt values and is formatted
// as RFC 3339 in UTC.
//
// All caller- and data-supplied text (bucket, prefix, continuation token,
// object keys) is escaped before being written, so values containing
// characters such as "&" or "<" still produce well-formed XML.
//
// A nil resp yields an empty ListBucketResult document.
func ToS3ListObjectsV2XML(resp *schemas.BifrostFileListResponse, bucket, prefix string, maxKeys int) []byte {
	if resp == nil {
		return []byte(`<?xml version="1.0" encoding="UTF-8"?><ListBucketResult xmlns="http://s3.amazonaws.com/doc/2006-03-01/"></ListBucketResult>`)
	}

	var sb strings.Builder
	sb.WriteString(`<?xml version="1.0" encoding="UTF-8"?>`)
	sb.WriteString(`<ListBucketResult xmlns="http://s3.amazonaws.com/doc/2006-03-01/">`)
	fmt.Fprintf(&sb, "<Name>%s</Name>", html.EscapeString(bucket))
	fmt.Fprintf(&sb, "<Prefix>%s</Prefix>", html.EscapeString(prefix))
	fmt.Fprintf(&sb, "<KeyCount>%d</KeyCount>", len(resp.Data))
	fmt.Fprintf(&sb, "<MaxKeys>%d</MaxKeys>", maxKeys)

	if resp.HasMore {
		sb.WriteString("<IsTruncated>true</IsTruncated>")
		if resp.After != nil && *resp.After != "" {
			fmt.Fprintf(&sb, "<NextContinuationToken>%s</NextContinuationToken>", html.EscapeString(*resp.After))
		}
	} else {
		sb.WriteString("<IsTruncated>false</IsTruncated>")
	}

	for _, f := range resp.Data {
		// Extract key from S3 URI
		_, key := parseS3URI(f.ID)
		sb.WriteString("<Contents>")
		// Object keys may legally contain XML metacharacters.
		fmt.Fprintf(&sb, "<Key>%s</Key>", html.EscapeString(key))
		fmt.Fprintf(&sb, "<Size>%d</Size>", f.Bytes)
		if f.CreatedAt > 0 {
			fmt.Fprintf(&sb, "<LastModified>%s</LastModified>", time.Unix(f.CreatedAt, 0).UTC().Format(time.RFC3339))
		}
		sb.WriteString("<StorageClass>STANDARD</StorageClass>")
		sb.WriteString("</Contents>")
	}

	sb.WriteString("</ListBucketResult>")
	return []byte(sb.String())
}
// ToS3ErrorXML renders an S3-style error document: an Error element with
// Code, Message, Resource and RequestId children, preceded by the XML
// declaration.
//
// Every argument is escaped before being written, so messages or resource
// paths containing characters such as "&", "<" or ">" still produce
// well-formed XML.
func ToS3ErrorXML(code, message, resource, requestID string) []byte {
	var sb strings.Builder
	sb.WriteString(`<?xml version="1.0" encoding="UTF-8"?>`)
	sb.WriteString("<Error>")
	fmt.Fprintf(&sb, "<Code>%s</Code>", html.EscapeString(code))
	fmt.Fprintf(&sb, "<Message>%s</Message>", html.EscapeString(message))
	fmt.Fprintf(&sb, "<Resource>%s</Resource>", html.EscapeString(resource))
	fmt.Fprintf(&sb, "<RequestId>%s</RequestId>", html.EscapeString(requestID))
	sb.WriteString("</Error>")
	return []byte(sb.String())
}

View File

@@ -0,0 +1,742 @@
package bedrock
import (
"encoding/base64"
"fmt"
"strconv"
"strings"
"github.com/maximhq/bifrost/core/schemas"
)
// mapQualityToBedrock normalises a quality value to Bedrock's vocabulary.
// Matching ignores case and surrounding whitespace:
//   - "low", "medium", "standard" -> "standard"
//   - "high", "premium"           -> "premium"
//
// Any other value is returned unchanged (the original pointer), and a nil
// input yields nil.
func mapQualityToBedrock(quality *string) *string {
	if quality == nil {
		return nil
	}

	normalized := strings.TrimSpace(*quality)
	normalized = strings.ToLower(normalized)

	switch normalized {
	case "low", "medium", "standard":
		return schemas.Ptr("standard")
	case "high", "premium":
		return schemas.Ptr("premium")
	}
	// Unknown values pass through untouched.
	return quality
}
// isStabilityAIModel reports whether the model name contains "stability."
// (compared case-insensitively).
func isStabilityAIModel(model string) bool {
	lowered := strings.ToLower(model)
	return strings.Contains(lowered, "stability.")
}
// isPromptOnlyImageGenerationModel reports whether the model name contains
// "image" (compared case-insensitively). It is intended to identify image
// generation models that take a flat {"prompt": "..."} payload with no
// taskType field.
//
// NOTE(review): the check is a plain substring match and does not itself
// exclude Stability AI names such as "stability.stable-image-*"; presumably
// callers test isStabilityAIModel first — confirm at the call sites.
func isPromptOnlyImageGenerationModel(model string) bool {
	return strings.Contains(strings.ToLower(model), "image")
}
// ToStabilityAIImageGenerationRequest converts a Bifrost image generation
// request to the flat Stability AI request format used on Bedrock.
//
// The prompt is always copied. AspectRatio, OutputFormat, Seed and
// NegativePrompt are copied from Params when set. If AspectRatio is still
// unset and ExtraParams carries a string "aspect_ratio", that value is used
// and the key is removed from ExtraParams (the caller's map is modified).
// ExtraParams is then forwarded as-is.
//
// It returns an error when the request or its Input is nil.
func ToStabilityAIImageGenerationRequest(request *schemas.BifrostImageGenerationRequest) (*StabilityAIImageGenerationRequest, error) {
	if request == nil {
		return nil, fmt.Errorf("request is nil")
	}
	if request.Input == nil {
		return nil, fmt.Errorf("request input is required")
	}

	req := &StabilityAIImageGenerationRequest{Prompt: request.Input.Prompt}

	params := request.Params
	if params == nil {
		return req, nil
	}

	if params.AspectRatio != nil {
		req.AspectRatio = params.AspectRatio
	}
	if params.OutputFormat != nil {
		req.OutputFormat = params.OutputFormat
	}
	if params.Seed != nil {
		req.Seed = params.Seed
	}
	if params.NegativePrompt != nil {
		req.NegativePrompt = params.NegativePrompt
	}

	if extras := params.ExtraParams; extras != nil {
		// A typed AspectRatio wins; fall back to the extra param otherwise.
		if req.AspectRatio == nil {
			if ratio, ok := schemas.SafeExtractStringPointer(extras["aspect_ratio"]); ok {
				delete(extras, "aspect_ratio")
				req.AspectRatio = ratio
			}
		}
		req.ExtraParams = extras
	}
	return req, nil
}
// ToBedrockImageGenerationRequest converts a Bifrost image generation request
// to a Bedrock text-to-image request (task type TaskTypeTextImage).
//
// From Params: N, Seed and Quality (mapped through mapQualityToBedrock) go
// into the image generation config; NegativePrompt and Style go into the
// text-to-image params. Size, unless it is "auto" (case/whitespace
// insensitive), must have the form WIDTHxHEIGHT and is split into Width and
// Height. A numeric "cfgScale" extra param is promoted to CfgScale and removed
// from ExtraParams (the caller's map is modified); ExtraParams is then
// forwarded.
//
// It returns an error when the request or its Input is nil, or when Size is
// malformed.
func ToBedrockImageGenerationRequest(request *schemas.BifrostImageGenerationRequest) (*BedrockImageGenerationRequest, error) {
	if request == nil {
		return nil, fmt.Errorf("request is nil")
	}
	if request.Input == nil {
		return nil, fmt.Errorf("request input is required")
	}

	cfg := &ImageGenerationConfig{}
	textParams := &BedrockTextToImageParams{Text: request.Input.Prompt}
	bedrockReq := &BedrockImageGenerationRequest{
		TaskType:              schemas.Ptr(TaskTypeTextImage),
		TextToImageParams:     textParams,
		ImageGenerationConfig: cfg,
	}

	params := request.Params
	if params == nil {
		return bedrockReq, nil
	}

	if params.N != nil {
		cfg.NumberOfImages = params.N
	}
	if params.NegativePrompt != nil {
		textParams.NegativeText = params.NegativePrompt
	}
	if params.Seed != nil {
		cfg.Seed = params.Seed
	}
	if params.Quality != nil {
		cfg.Quality = mapQualityToBedrock(params.Quality)
	}
	if params.Style != nil {
		textParams.Style = params.Style
	}

	if params.Size != nil {
		// "auto" means: let Bedrock pick, so no explicit dimensions are sent.
		if normalized := strings.TrimSpace(strings.ToLower(*params.Size)); normalized != "auto" {
			dims := strings.Split(normalized, "x")
			if len(dims) != 2 {
				return nil, fmt.Errorf("invalid size format: expected 'WIDTHxHEIGHT', got %q", *params.Size)
			}
			width, err := strconv.Atoi(dims[0])
			if err != nil {
				return nil, fmt.Errorf("invalid width in size %q: %w", *params.Size, err)
			}
			height, err := strconv.Atoi(dims[1])
			if err != nil {
				return nil, fmt.Errorf("invalid height in size %q: %w", *params.Size, err)
			}
			cfg.Width = schemas.Ptr(width)
			cfg.Height = schemas.Ptr(height)
		}
	}

	if extras := params.ExtraParams; extras != nil {
		if scale, ok := schemas.SafeExtractFloat64Pointer(extras["cfgScale"]); ok {
			delete(extras, "cfgScale")
			cfg.CfgScale = scale
		}
		bedrockReq.ExtraParams = extras
	}
	return bedrockReq, nil
}
// ToStabilityAIImageGenerationResponse converts a BifrostImageGenerationResponse
// back into the BedrockImageGenerationResponse shape returned for Stability AI
// models. Each data entry's B64JSON becomes one image; finish reasons and
// seeds are copied when the response carries generation parameters.
//
// It returns an error when the response is nil.
func ToStabilityAIImageGenerationResponse(response *schemas.BifrostImageGenerationResponse) (*BedrockImageGenerationResponse, error) {
	if response == nil {
		return nil, fmt.Errorf("response is nil")
	}

	out := &BedrockImageGenerationResponse{}
	for i := range response.Data {
		out.Images = append(out.Images, response.Data[i].B64JSON)
	}

	if meta := response.ImageGenerationResponseParameters; meta != nil {
		out.FinishReasons = meta.FinishReasons
		out.Seeds = meta.Seeds
	}
	return out, nil
}
// ToBedrockImageVariationRequest converts a Bifrost image variation request to
// a Bedrock image variation request (task type TaskTypeImageVariation).
//
// The primary image from Input.Image is base64-encoded and sent first. Further
// images may be supplied through the "images" extra param as [][]byte; empty
// entries are skipped. The extra params "prompt", "negativeText",
// "similarityStrength" (must be within [0.2, 1.0]), "quality" (mapped through
// mapQualityToBedrock) and "cfgScale" are promoted to typed fields. From
// Params, N becomes NumberOfImages and Size, unless it is "auto", must have
// the form WIDTHxHEIGHT.
//
// Every extra param that is consumed here is removed from ExtraParams before
// the remainder is forwarded, so a promoted value is not also sent as a raw
// extra field. Note that this modifies the caller's ExtraParams map.
//
// It returns an error when the request is nil, the primary image is missing
// or empty, similarityStrength is out of range, or Size is malformed.
func ToBedrockImageVariationRequest(request *schemas.BifrostImageVariationRequest) (*BedrockImageVariationRequest, error) {
	if request == nil {
		return nil, fmt.Errorf("request is nil")
	}
	// len of a nil slice is 0, so this also rejects a missing image.
	if request.Input == nil || len(request.Input.Image.Image) == 0 {
		return nil, fmt.Errorf("request.Input.Image is required")
	}

	variation := &BedrockImageVariationParams{
		Images: []string{},
	}
	cfg := &ImageGenerationConfig{}
	bedrockReq := &BedrockImageVariationRequest{
		TaskType:              schemas.Ptr(TaskTypeImageVariation),
		ImageVariationParams:  variation,
		ImageGenerationConfig: cfg,
	}

	// Primary image from Input.Image always comes first.
	variation.Images = append(variation.Images, base64.StdEncoding.EncodeToString(request.Input.Image.Image))

	params := request.Params
	if params == nil {
		return bedrockReq, nil
	}
	extra := params.ExtraParams

	if extra != nil {
		// Additional images from ExtraParams (stored as [][]byte by the HTTP handler).
		if additionalImages, ok := extra["images"]; ok {
			delete(extra, "images")
			if imagesArray, ok := additionalImages.([][]byte); ok {
				for _, imgBytes := range imagesArray {
					if len(imgBytes) > 0 {
						variation.Images = append(variation.Images, base64.StdEncoding.EncodeToString(imgBytes))
					}
				}
			}
		}

		if prompt, ok := schemas.SafeExtractStringPointer(extra["prompt"]); ok {
			delete(extra, "prompt")
			variation.Text = prompt
		}
		if negativeText, ok := schemas.SafeExtractStringPointer(extra["negativeText"]); ok {
			delete(extra, "negativeText")
			variation.NegativeText = negativeText
		}
		if similarityStrength, ok := schemas.SafeExtractFloat64Pointer(extra["similarityStrength"]); ok {
			delete(extra, "similarityStrength")
			// Validate similarityStrength range (0.2 to 1.0)
			if *similarityStrength < 0.2 || *similarityStrength > 1.0 {
				return nil, fmt.Errorf("similarityStrength must be between 0.2 and 1.0, got %f", *similarityStrength)
			}
			variation.SimilarityStrength = similarityStrength
		}

		// The map is shared by reference, so keys deleted below are also
		// absent from the forwarded extras.
		bedrockReq.ExtraParams = extra
	}

	// Map standard params to ImageGenerationConfig
	if params.N != nil {
		cfg.NumberOfImages = params.N
	}

	if params.Size != nil {
		// "auto" means no explicit dimensions are sent.
		if normalized := strings.TrimSpace(strings.ToLower(*params.Size)); normalized != "auto" {
			size := strings.Split(normalized, "x")
			if len(size) != 2 {
				return nil, fmt.Errorf("invalid size format: expected 'WIDTHxHEIGHT', got %q", *params.Size)
			}
			width, err := strconv.Atoi(size[0])
			if err != nil {
				return nil, fmt.Errorf("invalid width in size %q: %w", *params.Size, err)
			}
			height, err := strconv.Atoi(size[1])
			if err != nil {
				return nil, fmt.Errorf("invalid height in size %q: %w", *params.Size, err)
			}
			cfg.Width = schemas.Ptr(width)
			cfg.Height = schemas.Ptr(height)
		}
	}

	// Promote quality and cfgScale, removing them from the forwarded extras
	// like every other consumed key.
	if extra != nil {
		if quality, ok := schemas.SafeExtractStringPointer(extra["quality"]); ok {
			delete(extra, "quality")
			cfg.Quality = mapQualityToBedrock(quality)
		}
		if cfgScale, ok := schemas.SafeExtractFloat64Pointer(extra["cfgScale"]); ok {
			delete(extra, "cfgScale")
			cfg.CfgScale = cfgScale
		}
	}

	return bedrockReq, nil
}
// ToBedrockImageEditRequest converts a Bifrost image edit request to a Bedrock image edit request.
//
// The edit type (request.Params.Type, compared case-insensitively) selects the Bedrock task:
// "inpainting", "outpainting" or "background_removal". Only the first input image is sent.
// ExtraParams keys consumed by the task builders are absent from the ExtraParams attached to
// the returned request; the caller's request (including its ExtraParams map) is not modified.
//
// An error is returned when the request, its input or its type is missing, when the first
// image is absent or empty, or when the type is not one of the supported values.
func ToBedrockImageEditRequest(request *schemas.BifrostImageEditRequest) (*BedrockImageEditRequest, error) {
	// Validate request
	if request == nil || request.Input == nil {
		return nil, fmt.Errorf("request or input is nil")
	}
	if len(request.Input.Images) == 0 || len(request.Input.Images[0].Image) == 0 {
		return nil, fmt.Errorf("at least one image is required")
	}
	// Validate and extract type (required)
	if request.Params == nil || request.Params.Type == nil {
		return nil, fmt.Errorf("type field is required (must be inpainting, outpainting, or background_removal)")
	}
	editType := strings.ToLower(*request.Params.Type)

	// The builders below delete the ExtraParams keys they consume. Run them against a
	// shallow copy of the parameters that owns a cloned map, so that converting a request
	// never strips keys from the caller's map as a side effect.
	paramsCopy := *request.Params
	if request.Params.ExtraParams != nil {
		paramsCopy.ExtraParams = make(map[string]interface{}, len(request.Params.ExtraParams))
		for k, v := range request.Params.ExtraParams {
			paramsCopy.ExtraParams[k] = v
		}
	}
	working := &schemas.BifrostImageEditRequest{
		Input:  request.Input,
		Params: &paramsCopy,
	}

	// Convert first image to base64
	imageBase64 := base64.StdEncoding.EncodeToString(request.Input.Images[0].Image)
	bedrockReq := &BedrockImageEditRequest{}
	switch editType {
	case "inpainting":
		bedrockReq.TaskType = schemas.Ptr(TaskTypeInpainting)
		bedrockReq.InPaintingParams = buildInPaintingParams(imageBase64, working)
		bedrockReq.ImageGenerationConfig = buildImageGenerationConfig(working.Params)
	case "outpainting":
		bedrockReq.TaskType = schemas.Ptr(TaskTypeOutpainting)
		bedrockReq.OutPaintingParams = buildOutPaintingParams(imageBase64, working)
		bedrockReq.ImageGenerationConfig = buildImageGenerationConfig(working.Params)
	case "background_removal":
		// Background removal takes only the image; no generation config is attached.
		bedrockReq.TaskType = schemas.Ptr(TaskTypeBackgroundRemoval)
		bedrockReq.BackgroundRemovalParams = &BedrockBackgroundRemovalParams{
			Image: imageBase64,
		}
	default:
		return nil, fmt.Errorf("unsupported type for Bedrock: %s (must be inpainting, outpainting, or background_removal)", editType)
	}
	// Whatever the builders did not consume is carried along as extra parameters.
	bedrockReq.ExtraParams = paramsCopy.ExtraParams
	return bedrockReq, nil
}
// Helper functions

// buildInPaintingParams assembles the Bedrock inpainting parameters from the already
// base64-encoded source image and the edit request. The "mask_prompt" and "return_mask"
// entries of request.Params.ExtraParams are moved into the result (and deleted from the
// map) when present; a non-empty request.Params.Mask is attached base64-encoded.
// request.Input and request.Params must be non-nil.
func buildInPaintingParams(imageBase64 string, request *schemas.BifrostImageEditRequest) *BedrockInPaintingParams {
	out := &BedrockInPaintingParams{
		Text:  request.Input.Prompt,
		Image: imageBase64,
	}
	if negative := request.Params.NegativePrompt; negative != nil {
		out.NegativeText = negative
	}
	if extra := request.Params.ExtraParams; extra != nil {
		if v, found := schemas.SafeExtractStringPointer(extra["mask_prompt"]); found {
			delete(extra, "mask_prompt")
			out.MaskPrompt = v
		}
		if v, found := schemas.SafeExtractBoolPointer(extra["return_mask"]); found {
			delete(extra, "return_mask")
			out.ReturnMask = v
		}
	}
	// A binary mask, when supplied, travels as base64 text.
	if mask := request.Params.Mask; len(mask) > 0 {
		encoded := base64.StdEncoding.EncodeToString(mask)
		out.MaskImage = &encoded
	}
	return out
}
// buildOutPaintingParams assembles the Bedrock outpainting parameters from the already
// base64-encoded source image and the edit request. "mask_prompt", "return_mask" and a
// valid "outpainting_mode" are moved out of request.Params.ExtraParams into the result;
// a non-empty request.Params.Mask is attached base64-encoded.
// request.Input and request.Params must be non-nil.
func buildOutPaintingParams(imageBase64 string, request *schemas.BifrostImageEditRequest) *BedrockOutPaintingParams {
	out := &BedrockOutPaintingParams{
		Image: imageBase64,
		Text:  request.Input.Prompt,
	}
	if negative := request.Params.NegativePrompt; negative != nil {
		out.NegativeText = negative
	}
	if extra := request.Params.ExtraParams; extra != nil {
		if v, found := schemas.SafeExtractStringPointer(extra["mask_prompt"]); found {
			delete(extra, "mask_prompt")
			out.MaskPrompt = v
		}
		if v, found := schemas.SafeExtractBoolPointer(extra["return_mask"]); found {
			delete(extra, "return_mask")
			out.ReturnMask = v
		}
		if raw, found := schemas.SafeExtractStringPointer(extra["outpainting_mode"]); found {
			// Only DEFAULT and PRECISE are accepted (case-insensitively); any other
			// value is left untouched in ExtraParams.
			switch upper := strings.ToUpper(*raw); upper {
			case "DEFAULT", "PRECISE":
				delete(extra, "outpainting_mode")
				out.OutPaintingMode = &upper
			}
		}
	}
	// A binary mask, when supplied, travels as base64 text.
	if mask := request.Params.Mask; len(mask) > 0 {
		encoded := base64.StdEncoding.EncodeToString(mask)
		out.MaskImage = &encoded
	}
	return out
}
// buildImageGenerationConfig derives the Bedrock image generation config from the edit
// parameters: image count, size, quality, seed and the "cfgScale" extra parameter
// (which is deleted from params.ExtraParams once consumed).
//
// Size is expected as "WIDTHxHEIGHT" (case-insensitive, surrounding whitespace ignored).
// "auto" and any value that does not parse leave width and height unset; no error is
// reported. params must be non-nil.
func buildImageGenerationConfig(params *schemas.ImageEditParameters) *ImageGenerationConfig {
	cfg := &ImageGenerationConfig{}
	if params.N != nil {
		cfg.NumberOfImages = params.N
	}
	if params.Size != nil {
		normalized := strings.TrimSpace(strings.ToLower(*params.Size))
		if dims := strings.Split(normalized, "x"); normalized != "auto" && len(dims) == 2 {
			w, wErr := strconv.Atoi(dims[0])
			h, hErr := strconv.Atoi(dims[1])
			// Both dimensions must parse; a half-valid size is ignored entirely.
			if wErr == nil && hErr == nil {
				cfg.Width = schemas.Ptr(w)
				cfg.Height = schemas.Ptr(h)
			}
		}
	}
	if params.Quality != nil {
		cfg.Quality = mapQualityToBedrock(params.Quality)
	}
	if params.Seed != nil {
		cfg.Seed = params.Seed
	}
	if extra := params.ExtraParams; extra != nil {
		if scale, found := schemas.SafeExtractFloat64Pointer(extra["cfgScale"]); found {
			delete(extra, "cfgScale")
			cfg.CfgScale = scale
		}
	}
	return cfg
}
// getStabilityAITaskTypeFromParams translates the generic BifrostImageEditParameters.Type
// value into a Stability AI task type string. Matching is case-insensitive. An
// unrecognized value yields "".
func getStabilityAITaskTypeFromParams(t string) string {
	name := strings.ToLower(t)
	switch name {
	case "inpainting", "inpaint":
		return "inpaint"
	case "outpainting", "outpaint":
		return "outpaint"
	case "background_removal", "remove_background":
		return "remove-bg"
	case "recolor":
		return name
	case "erase_object",
		"upscale_fast", "upscale_creative", "upscale_conservative",
		"search_replace",
		"control_sketch", "control_structure",
		"style_guide", "style_transfer":
		// For these values the task type is the accepted name with "_" replaced by "-".
		return strings.ReplaceAll(name, "_", "-")
	}
	return ""
}
// getStabilityAIEditTaskType infers the Stability AI edit task from the model name by
// looking for a known marker substring (case-insensitively). Markers are tried in a
// fixed order and the first one found decides the task.
// Returns an error if the model name does not contain any known marker.
func getStabilityAIEditTaskType(model string) (string, error) {
	markers := [...]struct{ substr, task string }{
		{"stable-creative-upscale", "upscale-creative"},
		{"stable-conservative-upscale", "upscale-conservative"},
		{"stable-fast-upscale", "upscale-fast"},
		{"stable-image-inpaint", "inpaint"},
		{"stable-outpaint", "outpaint"},
		{"stable-image-search-recolor", "recolor"},
		{"stable-image-search-replace", "search-replace"},
		{"stable-image-erase-object", "erase-object"},
		{"stable-image-remove-background", "remove-bg"},
		{"stable-image-control-sketch", "control-sketch"},
		{"stable-image-control-structure", "control-structure"},
		{"stable-image-style-guide", "style-guide"},
		{"stable-style-transfer", "style-transfer"},
	}
	lowered := strings.ToLower(model)
	for _, entry := range markers {
		if strings.Contains(lowered, entry.substr) {
			return entry.task, nil
		}
	}
	return "", fmt.Errorf("cannot determine task type from stability ai model name %q", model)
}
// ToStabilityAIImageEditRequest converts a Bifrost image edit request to the Stability AI flat request
// format used by Bedrock edit models. Only fields valid for the detected task type are populated.
// deployment is the resolved model identifier (after applying any deployment alias mapping); it is
// used for task-type inference so that alias-mapped models route correctly.
//
// The task type is taken from request.Params.Type when that maps to a known Stability AI task;
// otherwise it is inferred from deployment. The caller's ExtraParams map is not modified: keys are
// consumed from a copy, and the remainder of that copy is attached to the returned request.
//
// An error is returned when the request or its input is nil, when the task type cannot be
// determined, when the images required by the task are missing or empty, or when a task-specific
// required field (select_prompt for recolor, search_prompt for search-replace) is absent or empty.
func ToStabilityAIImageEditRequest(request *schemas.BifrostImageEditRequest, deployment string) (*StabilityAIImageEditRequest, error) {
if request == nil || request.Input == nil {
return nil, fmt.Errorf("request or input is nil")
}
// Resolve the task type: an explicit, recognized Params.Type wins over model-name inference.
var taskType string
if request.Params != nil && request.Params.Type != nil {
taskType = getStabilityAITaskTypeFromParams(*request.Params.Type)
}
if taskType == "" {
var err error
taskType, err = getStabilityAIEditTaskType(deployment)
if err != nil {
return nil, err
}
}
req := &StabilityAIImageEditRequest{}
// Image sourcing
// style-transfer takes two images (init first, style second); every other task takes only the first image.
if taskType == "style-transfer" {
if len(request.Input.Images) != 2 {
return nil, fmt.Errorf("style-transfer requires exactly two images: init_image and style_image")
}
if len(request.Input.Images[0].Image) == 0 || len(request.Input.Images[1].Image) == 0 {
return nil, fmt.Errorf("style-transfer requires non-empty init_image and style_image")
}
initB64 := base64.StdEncoding.EncodeToString(request.Input.Images[0].Image)
styleB64 := base64.StdEncoding.EncodeToString(request.Input.Images[1].Image)
req.InitImage = &initB64
req.StyleImage = &styleB64
} else {
if len(request.Input.Images) == 0 || len(request.Input.Images[0].Image) == 0 {
return nil, fmt.Errorf("at least one image is required")
}
imageB64 := base64.StdEncoding.EncodeToString(request.Input.Images[0].Image)
req.Image = &imageB64
}
// Common fields populated based on task allowlist
// The prompt is set for the listed tasks even when it is an empty string.
prompt := request.Input.Prompt
switch taskType {
case "inpaint", "recolor", "search-replace", "control-sketch", "control-structure",
"style-guide", "upscale-creative", "upscale-conservative", "outpaint", "style-transfer":
req.Prompt = &prompt
}
// Negative prompt
if request.Params != nil && request.Params.NegativePrompt != nil {
switch taskType {
case "inpaint", "outpaint", "recolor", "search-replace", "control-sketch",
"control-structure", "style-guide", "upscale-creative", "upscale-conservative", "style-transfer":
req.NegativePrompt = request.Params.NegativePrompt
}
}
// Seed
if request.Params != nil && request.Params.Seed != nil {
switch taskType {
case "inpaint", "outpaint", "recolor", "search-replace", "erase-object", "control-sketch",
"control-structure", "style-guide", "upscale-creative", "upscale-conservative", "style-transfer":
req.Seed = request.Params.Seed
}
}
// Mask (from Params.Mask bytes)
// Only inpaint and erase-object receive the mask; for other tasks it is dropped.
if request.Params != nil && len(request.Params.Mask) > 0 {
switch taskType {
case "inpaint", "erase-object":
maskB64 := base64.StdEncoding.EncodeToString(request.Params.Mask)
req.Mask = &maskB64
}
}
// ExtraParams
if request.Params != nil {
// Typed OutputFormat takes priority over ExtraParams
if request.Params.OutputFormat != nil {
req.OutputFormat = request.Params.OutputFormat
}
if request.Params.ExtraParams != nil {
// Work on a copy so the deletes below never touch the caller's map.
ep := make(map[string]interface{}, len(request.Params.ExtraParams))
for k, v := range request.Params.ExtraParams {
ep[k] = v
}
// output_format — all tasks (fallback if not already set by typed field)
// NOTE(review): when the typed OutputFormat is set, an "output_format" key in ExtraParams
// is not removed here and so stays in req.ExtraParams — confirm that is intended.
if req.OutputFormat == nil {
if v, ok := schemas.SafeExtractStringPointer(ep["output_format"]); ok {
delete(ep, "output_format")
req.OutputFormat = v
}
}
// style_preset
switch taskType {
case "inpaint", "outpaint", "recolor", "search-replace", "control-sketch",
"control-structure", "style-guide", "upscale-creative":
if v, ok := schemas.SafeExtractStringPointer(ep["style_preset"]); ok {
delete(ep, "style_preset")
req.StylePreset = v
}
}
// grow_mask
switch taskType {
case "inpaint", "recolor", "search-replace", "erase-object":
if v, ok := schemas.SafeExtractIntPointer(ep["grow_mask"]); ok {
delete(ep, "grow_mask")
req.GrowMask = v
}
}
// outpaint directional fields
if taskType == "outpaint" {
if v, ok := schemas.SafeExtractIntPointer(ep["left"]); ok {
delete(ep, "left")
req.Left = v
}
if v, ok := schemas.SafeExtractIntPointer(ep["right"]); ok {
delete(ep, "right")
req.Right = v
}
if v, ok := schemas.SafeExtractIntPointer(ep["up"]); ok {
delete(ep, "up")
req.Up = v
}
if v, ok := schemas.SafeExtractIntPointer(ep["down"]); ok {
delete(ep, "down")
req.Down = v
}
}
// creativity
switch taskType {
case "upscale-creative", "upscale-conservative", "outpaint":
if v, ok := schemas.SafeExtractFloat64Pointer(ep["creativity"]); ok {
delete(ep, "creativity")
req.Creativity = v
}
}
// select_prompt (recolor)
if taskType == "recolor" {
if v, ok := schemas.SafeExtractStringPointer(ep["select_prompt"]); ok {
delete(ep, "select_prompt")
req.SelectPrompt = v
}
}
// search_prompt (search-replace)
if taskType == "search-replace" {
if v, ok := schemas.SafeExtractStringPointer(ep["search_prompt"]); ok {
delete(ep, "search_prompt")
req.SearchPrompt = v
}
}
// control_strength
switch taskType {
case "control-sketch", "control-structure":
if v, ok := schemas.SafeExtractFloat64Pointer(ep["control_strength"]); ok {
delete(ep, "control_strength")
req.ControlStrength = v
}
}
// style-guide fields
if taskType == "style-guide" {
if v, ok := schemas.SafeExtractStringPointer(ep["aspect_ratio"]); ok {
delete(ep, "aspect_ratio")
req.AspectRatio = v
}
if v, ok := schemas.SafeExtractFloat64Pointer(ep["fidelity"]); ok {
delete(ep, "fidelity")
req.Fidelity = v
}
}
// style-transfer fields
if taskType == "style-transfer" {
if v, ok := schemas.SafeExtractFloat64Pointer(ep["style_strength"]); ok {
delete(ep, "style_strength")
req.StyleStrength = v
}
if v, ok := schemas.SafeExtractFloat64Pointer(ep["composition_fidelity"]); ok {
delete(ep, "composition_fidelity")
req.CompositionFidelity = v
}
if v, ok := schemas.SafeExtractFloat64Pointer(ep["change_strength"]); ok {
delete(ep, "change_strength")
req.ChangeStrength = v
}
}
// Keys not consumed above (including keys that do not apply to this task) remain in ep.
req.ExtraParams = ep
}
}
// Validate required per-task fields
// These fields can only come from ExtraParams, so they are checked after extraction.
if taskType == "recolor" && (req.SelectPrompt == nil || *req.SelectPrompt == "") {
return nil, fmt.Errorf("select_prompt is required for stability ai recolor task")
}
if taskType == "search-replace" && (req.SearchPrompt == nil || *req.SearchPrompt == "") {
return nil, fmt.Errorf("search_prompt is required for stability ai search-replace task")
}
return req, nil
}
// ToBifrostImageGenerationResponse converts a Bedrock image generation response to a
// Bifrost image generation response. A nil input yields nil. Finish reasons and seeds
// are copied into the response parameters only when at least one of them is present;
// every image becomes one data entry carrying its position as Index.
func ToBifrostImageGenerationResponse(response *BedrockImageGenerationResponse) *schemas.BifrostImageGenerationResponse {
	if response == nil {
		return nil
	}
	out := &schemas.BifrostImageGenerationResponse{}
	hasMetadata := len(response.FinishReasons) != 0 || len(response.Seeds) != 0
	if hasMetadata {
		// Copy the slices so the result does not share backing arrays with the input.
		meta := &schemas.ImageGenerationResponseParameters{}
		meta.FinishReasons = append([]*string(nil), response.FinishReasons...)
		meta.Seeds = append([]int(nil), response.Seeds...)
		out.ImageGenerationResponseParameters = meta
	}
	for position, encoded := range response.Images {
		entry := schemas.ImageData{
			Index:   position,
			B64JSON: encoded,
		}
		out.Data = append(out.Data, entry)
	}
	return out
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,130 @@
package bedrock
import (
"strings"
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
"github.com/maximhq/bifrost/core/schemas"
)
// BedrockRerankRequest is the Bedrock Agent Runtime rerank request body.
// It groups the queries, the sources to be ranked and the reranking
// configuration that identifies the model.
type BedrockRerankRequest struct {
// Queries holds the query entries to rank the sources against.
Queries []BedrockRerankQuery `json:"queries"`
// Sources holds the documents to be ranked.
Sources []BedrockRerankSource `json:"sources"`
// RerankingConfiguration selects the reranking model and its options.
RerankingConfiguration BedrockRerankingConfiguration `json:"rerankingConfiguration"`
}
// GetExtraParams implements RequestBodyWithExtraParams.
// It always returns nil: this request type has no extra-params field.
func (*BedrockRerankRequest) GetExtraParams() map[string]interface{} {
return nil
}
// Fixed discriminator values written into the "type" fields of a rerank request.
const (
// bedrockRerankQueryTypeText is the Type of a text query entry.
bedrockRerankQueryTypeText = "TEXT"
// bedrockRerankSourceTypeInline is the Type of a source whose document is embedded in the request.
bedrockRerankSourceTypeInline = "INLINE"
// bedrockRerankInlineDocumentTypeText is the Type of an inline document that carries plain text.
bedrockRerankInlineDocumentTypeText = "TEXT"
// bedrockRerankConfigurationTypeBedrock is the Type of the reranking configuration.
bedrockRerankConfigurationTypeBedrock = "BEDROCK_RERANKING_MODEL"
)
// BedrockRerankQuery is one query entry of a rerank request.
type BedrockRerankQuery struct {
Type string `json:"type"`
TextQuery BedrockRerankTextRef `json:"textQuery"`
}
// BedrockRerankSource is one document source of a rerank request.
type BedrockRerankSource struct {
Type string `json:"type"`
InlineDocumentSource BedrockRerankInlineSource `json:"inlineDocumentSource"`
}
// BedrockRerankInlineSource is a document embedded directly in the request.
type BedrockRerankInlineSource struct {
Type string `json:"type"`
TextDocument BedrockRerankTextValue `json:"textDocument"`
}
// BedrockRerankTextRef wraps the text of a query.
type BedrockRerankTextRef struct {
Text string `json:"text"`
}
// BedrockRerankTextValue wraps the text of a document; it is used in both
// request sources and response documents.
type BedrockRerankTextValue struct {
Text string `json:"text"`
}
// BedrockRerankingConfiguration is the top-level reranking configuration: a type
// discriminator plus the Bedrock-specific model configuration.
type BedrockRerankingConfiguration struct {
Type string `json:"type"`
BedrockRerankingConfiguration BedrockRerankingModelConfiguration `json:"bedrockRerankingConfiguration"`
}
// BedrockRerankingModelConfiguration carries the model configuration and the
// optional number of results to return (omitted from JSON when nil).
type BedrockRerankingModelConfiguration struct {
ModelConfiguration BedrockRerankModelConfiguration `json:"modelConfiguration"`
NumberOfResults *int `json:"numberOfResults,omitempty"`
}
// BedrockRerankModelConfiguration identifies the reranking model by ARN and
// carries optional additional model request fields (omitted from JSON when empty).
type BedrockRerankModelConfiguration struct {
ModelARN string `json:"modelArn"`
AdditionalModelRequestFields map[string]interface{} `json:"additionalModelRequestFields,omitempty"`
}
// BedrockRerankResponse is the Bedrock Agent Runtime rerank response body.
type BedrockRerankResponse struct {
// Results holds one entry per ranked source.
Results []BedrockRerankResult `json:"results"`
// NextToken is optional and omitted from JSON when nil.
NextToken *string `json:"nextToken,omitempty"`
}
// BedrockRerankResult is a single ranked entry: the index of the source it refers
// to, its relevance score and, optionally, the document itself.
type BedrockRerankResult struct {
Index int `json:"index"`
RelevanceScore float64 `json:"relevanceScore"`
Document *BedrockRerankResponseDocument `json:"document,omitempty"`
}
// BedrockRerankResponseDocument is the optional document attached to a rerank result.
type BedrockRerankResponseDocument struct {
Type string `json:"type,omitempty"`
TextDocument *BedrockRerankTextValue `json:"textDocument,omitempty"`
}
// ToBifrostListModelsResponse converts a Bedrock list-models response into the Bifrost
// format. Each model summary is passed through a ListModelsPipeline built from the
// allow list, the black list, the aliases and the unfiltered flag; every result the
// pipeline yields becomes one entry whose ID is "<providerKey>/<resolved id>". Models
// reported by the pipeline's backfill step are appended at the end.
// A nil receiver yields nil; when the pipeline signals an early exit the result is empty.
func (response *BedrockListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, allowedModels schemas.WhiteList, blacklistedModels schemas.BlackList, aliases map[string]string, unfiltered bool) *schemas.BifrostListModelsResponse {
	if response == nil {
		return nil
	}
	out := &schemas.BifrostListModelsResponse{
		Data: make([]schemas.Model, 0, len(response.ModelSummaries)),
	}
	filter := &providerUtils.ListModelsPipeline{
		ProviderKey:       providerKey,
		AllowedModels:     allowedModels,
		BlacklistedModels: blacklistedModels,
		Aliases:           aliases,
		Unfiltered:        unfiltered,
		MatchFns:          providerUtils.DefaultMatchFns(),
	}
	if filter.ShouldEarlyExit() {
		return out
	}
	// seen records the lower-cased resolved IDs already emitted, for the backfill step.
	seen := map[string]bool{}
	idPrefix := string(providerKey) + "/"
	for _, summary := range response.ModelSummaries {
		matches := filter.FilterModel(summary.ModelID)
		for _, match := range matches {
			entry := schemas.Model{
				ID:      idPrefix + match.ResolvedID,
				Name:    schemas.Ptr(summary.ModelName),
				OwnedBy: schemas.Ptr(summary.ProviderName),
				Architecture: &schemas.Architecture{
					InputModalities:  summary.InputModalities,
					OutputModalities: summary.OutputModalities,
				},
			}
			if match.AliasValue != "" {
				entry.Alias = schemas.Ptr(match.AliasValue)
			}
			out.Data = append(out.Data, entry)
			seen[strings.ToLower(match.ResolvedID)] = true
		}
	}
	out.Data = append(out.Data, filter.BackfillModels(seen)...)
	return out
}

View File

@@ -0,0 +1,55 @@
package bedrock
import (
"encoding/json"
"testing"
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
schemas "github.com/maximhq/bifrost/core/schemas"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestPayloadOrdering_BedrockConverseRequest(t *testing.T) {
req := &BedrockConverseRequest{
Messages: []BedrockMessage{
{
Role: "user",
Content: []BedrockContentBlock{
{Text: schemas.Ptr("hello")},
},
},
},
InferenceConfig: &BedrockInferenceConfig{
Temperature: schemas.Ptr(0.7),
MaxTokens: schemas.Ptr(1024),
},
ToolConfig: &BedrockToolConfig{
Tools: []BedrockTool{
{
ToolSpec: &BedrockToolSpec{
Name: "get_weather",
Description: schemas.Ptr("Get weather"),
InputSchema: BedrockToolInputSchema{
JSON: json.RawMessage(`{"type":"object","properties":{"location":{"type":"string"}},"required":["location"]}`),
},
},
},
},
},
}
result, err := providerUtils.MarshalSorted(req)
require.NoError(t, err)
golden := `{"messages":[{"role":"user","content":[{"text":"hello"}]}],"inferenceConfig":{"maxTokens":1024,"temperature":0.7},"toolConfig":{"tools":[{"toolSpec":{"name":"get_weather","description":"Get weather","inputSchema":{"json":{"type":"object","properties":{"location":{"type":"string"}},"required":["location"]}}}}]}}`
assert.Equal(t, golden, string(result), "payload field ordering changed — if intentional, update the golden string")
// Determinism: 100 iterations must produce identical bytes
for i := 0; i < 100; i++ {
iter, err := providerUtils.MarshalSorted(req)
require.NoError(t, err)
assert.Equal(t, string(result), string(iter), "non-deterministic marshal output on iteration %d", i)
}
}

View File

@@ -0,0 +1,168 @@
package bedrock
import (
"fmt"
"sort"
"strings"
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
"github.com/maximhq/bifrost/core/schemas"
)
// ToBedrockRerankRequest converts a Bifrost rerank request into Bedrock Agent Runtime format.
//
// The query becomes a single text query and every document becomes an inline text
// source. When parameters are present, TopN (clamped to the number of documents) is
// sent as the number of results, and MaxTokensPerDoc, Priority and all ExtraParams are
// forwarded as additional model request fields; an ExtraParams entry overrides a typed
// field that uses the same key.
//
// An error is returned for a nil request, a blank model ARN, an empty document list,
// or a TopN below 1.
func ToBedrockRerankRequest(bifrostReq *schemas.BifrostRerankRequest, modelARN string) (*BedrockRerankRequest, error) {
	switch {
	case bifrostReq == nil:
		return nil, fmt.Errorf("bifrost rerank request is nil")
	case strings.TrimSpace(modelARN) == "":
		return nil, fmt.Errorf("bedrock rerank model ARN is empty")
	case len(bifrostReq.Documents) == 0:
		return nil, fmt.Errorf("documents are required for rerank request")
	}

	docCount := len(bifrostReq.Documents)
	sources := make([]BedrockRerankSource, 0, docCount)
	for _, doc := range bifrostReq.Documents {
		sources = append(sources, BedrockRerankSource{
			Type: bedrockRerankSourceTypeInline,
			InlineDocumentSource: BedrockRerankInlineSource{
				Type:         bedrockRerankInlineDocumentTypeText,
				TextDocument: BedrockRerankTextValue{Text: doc.Text},
			},
		})
	}

	out := &BedrockRerankRequest{
		Queries: []BedrockRerankQuery{{
			Type:      bedrockRerankQueryTypeText,
			TextQuery: BedrockRerankTextRef{Text: bifrostReq.Query},
		}},
		Sources: sources,
		RerankingConfiguration: BedrockRerankingConfiguration{
			Type: bedrockRerankConfigurationTypeBedrock,
			BedrockRerankingConfiguration: BedrockRerankingModelConfiguration{
				ModelConfiguration: BedrockRerankModelConfiguration{ModelARN: modelARN},
			},
		},
	}

	params := bifrostReq.Params
	if params == nil {
		return out, nil
	}
	rerankCfg := &out.RerankingConfiguration.BedrockRerankingConfiguration

	if params.TopN != nil {
		limit := *params.TopN
		if limit < 1 {
			return nil, fmt.Errorf("top_n must be at least 1")
		}
		// Never ask for more results than there are documents.
		if limit > docCount {
			limit = docCount
		}
		rerankCfg.NumberOfResults = schemas.Ptr(limit)
	}

	extraFields := make(map[string]interface{})
	if params.MaxTokensPerDoc != nil {
		extraFields["max_tokens_per_doc"] = *params.MaxTokensPerDoc
	}
	if params.Priority != nil {
		extraFields["priority"] = *params.Priority
	}
	for key, value := range params.ExtraParams {
		extraFields[key] = value
	}
	if len(extraFields) > 0 {
		rerankCfg.ModelConfiguration.AdditionalModelRequestFields = extraFields
	}
	return out, nil
}
// ToBifrostRerankResponse converts a Bedrock rerank response into Bifrost format.
//
// Results are ordered by relevance score, highest first, with ties broken by ascending
// index. A document present in the Bedrock result is carried over. When returnDocuments
// is true, each result whose index falls inside documents gets a copy of that request
// document instead. A nil receiver yields nil.
func (response *BedrockRerankResponse) ToBifrostRerankResponse(documents []schemas.RerankDocument, returnDocuments bool) *schemas.BifrostRerankResponse {
	if response == nil {
		return nil
	}

	results := make([]schemas.RerankResult, 0, len(response.Results))
	for _, item := range response.Results {
		converted := schemas.RerankResult{
			Index:          item.Index,
			RelevanceScore: item.RelevanceScore,
		}
		if doc := item.Document; doc != nil && doc.TextDocument != nil {
			converted.Document = &schemas.RerankDocument{Text: doc.TextDocument.Text}
		}
		results = append(results, converted)
	}

	sort.SliceStable(results, func(a, b int) bool {
		left, right := results[a], results[b]
		if left.RelevanceScore != right.RelevanceScore {
			return left.RelevanceScore > right.RelevanceScore
		}
		return left.Index < right.Index
	})

	if returnDocuments {
		for i := range results {
			// Indexes outside the request documents keep whatever document they already have.
			if idx := results[i].Index; idx >= 0 && idx < len(documents) {
				results[i].Document = schemas.Ptr(documents[idx])
			}
		}
	}
	return &schemas.BifrostRerankResponse{Results: results}
}
// ToBifrostRerankRequest converts a Bedrock Agent Runtime rerank request to Bifrost format.
//
// Provider and model are parsed from the configured model ARN, with the default
// provider obtained from CheckAndSetDefaultProvider(ctx, schemas.Bedrock). Only the
// first query is used. Every source becomes a document, NumberOfResults becomes TopN,
// and non-empty additional model request fields become ExtraParams.
// A nil receiver yields nil.
func (req *BedrockRerankRequest) ToBifrostRerankRequest(ctx *schemas.BifrostContext) *schemas.BifrostRerankRequest {
	if req == nil {
		return nil
	}
	rerankCfg := req.RerankingConfiguration.BedrockRerankingConfiguration
	modelCfg := rerankCfg.ModelConfiguration

	defaultProvider := providerUtils.CheckAndSetDefaultProvider(ctx, schemas.Bedrock)
	provider, model := schemas.ParseModelString(modelCfg.ModelARN, defaultProvider)

	params := &schemas.RerankParameters{}
	if rerankCfg.NumberOfResults != nil {
		params.TopN = rerankCfg.NumberOfResults
	}
	if len(modelCfg.AdditionalModelRequestFields) > 0 {
		params.ExtraParams = modelCfg.AdditionalModelRequestFields
	}

	out := &schemas.BifrostRerankRequest{
		Provider: provider,
		Model:    model,
		Params:   params,
	}
	// Only the first query entry is carried over.
	if len(req.Queries) > 0 {
		out.Query = req.Queries[0].TextQuery.Text
	}
	for _, src := range req.Sources {
		doc := schemas.RerankDocument{Text: src.InlineDocumentSource.TextDocument.Text}
		out.Documents = append(out.Documents, doc)
	}
	return out
}

View File

@@ -0,0 +1,230 @@
package bedrock
import (
"context"
"testing"
"github.com/maximhq/bifrost/core/schemas"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestToBedrockRerankRequest(t *testing.T) {
topN := 10
maxTokensPerDoc := 512
priority := 3
req, err := ToBedrockRerankRequest(&schemas.BifrostRerankRequest{
Model: "arn:aws:bedrock:us-east-1::foundation-model/cohere.rerank-v3-5:0",
Query: "capital of france",
Documents: []schemas.RerankDocument{
{Text: "Paris is the capital of France."},
{Text: "Berlin is the capital of Germany."},
},
Params: &schemas.RerankParameters{
TopN: schemas.Ptr(topN),
MaxTokensPerDoc: schemas.Ptr(maxTokensPerDoc),
Priority: schemas.Ptr(priority),
ExtraParams: map[string]interface{}{
"truncate": "END",
},
},
}, "arn:aws:bedrock:us-east-1::foundation-model/cohere.rerank-v3-5:0")
require.NoError(t, err)
require.NotNil(t, req)
require.Len(t, req.Queries, 1)
assert.Equal(t, "TEXT", req.Queries[0].Type)
assert.Equal(t, "capital of france", req.Queries[0].TextQuery.Text)
require.Len(t, req.Sources, 2)
require.NotNil(t, req.RerankingConfiguration.BedrockRerankingConfiguration.NumberOfResults)
assert.Equal(t, 2, *req.RerankingConfiguration.BedrockRerankingConfiguration.NumberOfResults, "top_n must be clamped to source count")
fields := req.RerankingConfiguration.BedrockRerankingConfiguration.ModelConfiguration.AdditionalModelRequestFields
require.NotNil(t, fields)
assert.Equal(t, maxTokensPerDoc, fields["max_tokens_per_doc"])
assert.Equal(t, priority, fields["priority"])
assert.Equal(t, "END", fields["truncate"])
}
func TestBedrockRerankResponseToBifrostRerankResponse(t *testing.T) {
response := (&BedrockRerankResponse{
Results: []BedrockRerankResult{
{
Index: 2,
RelevanceScore: 0.21,
Document: &BedrockRerankResponseDocument{
TextDocument: &BedrockRerankTextValue{Text: "doc-2"},
},
},
{
Index: 1,
RelevanceScore: 0.95,
Document: &BedrockRerankResponseDocument{
TextDocument: &BedrockRerankTextValue{Text: "doc-1"},
},
},
{
Index: 0,
RelevanceScore: 0.95,
Document: &BedrockRerankResponseDocument{
TextDocument: &BedrockRerankTextValue{Text: "doc-0"},
},
},
},
}).ToBifrostRerankResponse(nil, false)
require.NotNil(t, response)
require.Len(t, response.Results, 3)
assert.Equal(t, 0, response.Results[0].Index)
assert.Equal(t, 1, response.Results[1].Index)
assert.Equal(t, 2, response.Results[2].Index)
assert.Equal(t, "doc-0", response.Results[0].Document.Text)
assert.Equal(t, "doc-1", response.Results[1].Document.Text)
}
func TestBedrockRerankResponseToBifrostRerankResponseReturnDocuments(t *testing.T) {
requestDocs := []schemas.RerankDocument{
{Text: "request-doc-0"},
{Text: "request-doc-1"},
{Text: "request-doc-2"},
}
response := (&BedrockRerankResponse{
Results: []BedrockRerankResult{
{
Index: 2,
RelevanceScore: 0.21,
Document: &BedrockRerankResponseDocument{
TextDocument: &BedrockRerankTextValue{Text: "provider-doc-2"},
},
},
{
Index: 1,
RelevanceScore: 0.95,
Document: &BedrockRerankResponseDocument{
TextDocument: &BedrockRerankTextValue{Text: "provider-doc-1"},
},
},
{
Index: 0,
RelevanceScore: 0.95,
Document: &BedrockRerankResponseDocument{
TextDocument: &BedrockRerankTextValue{Text: "provider-doc-0"},
},
},
},
}).ToBifrostRerankResponse(requestDocs, true)
require.NotNil(t, response)
require.Len(t, response.Results, 3)
require.NotNil(t, response.Results[0].Document)
require.NotNil(t, response.Results[1].Document)
require.NotNil(t, response.Results[2].Document)
assert.Equal(t, 0, response.Results[0].Index)
assert.Equal(t, 1, response.Results[1].Index)
assert.Equal(t, 2, response.Results[2].Index)
assert.Equal(t, "request-doc-0", response.Results[0].Document.Text)
assert.Equal(t, "request-doc-1", response.Results[1].Document.Text)
assert.Equal(t, "request-doc-2", response.Results[2].Document.Text)
}
func TestBedrockRerankRequestToBifrostRerankRequest(t *testing.T) {
topN := 3
bedrockReq := &BedrockRerankRequest{
Queries: []BedrockRerankQuery{
{
Type: bedrockRerankQueryTypeText,
TextQuery: BedrockRerankTextRef{Text: "capital of france"},
},
},
Sources: []BedrockRerankSource{
{
Type: bedrockRerankSourceTypeInline,
InlineDocumentSource: BedrockRerankInlineSource{
Type: bedrockRerankInlineDocumentTypeText,
TextDocument: BedrockRerankTextValue{Text: "Paris is the capital of France."},
},
},
{
Type: bedrockRerankSourceTypeInline,
InlineDocumentSource: BedrockRerankInlineSource{
Type: bedrockRerankInlineDocumentTypeText,
TextDocument: BedrockRerankTextValue{Text: "Berlin is the capital of Germany."},
},
},
},
RerankingConfiguration: BedrockRerankingConfiguration{
Type: bedrockRerankConfigurationTypeBedrock,
BedrockRerankingConfiguration: BedrockRerankingModelConfiguration{
NumberOfResults: &topN,
ModelConfiguration: BedrockRerankModelConfiguration{
ModelARN: "arn:aws:bedrock:us-east-1::foundation-model/cohere.rerank-v3-5:0",
AdditionalModelRequestFields: map[string]interface{}{
"truncate": "END",
},
},
},
},
}
bifrostCtx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
result := bedrockReq.ToBifrostRerankRequest(bifrostCtx)
require.NotNil(t, result)
assert.Equal(t, schemas.Bedrock, result.Provider)
assert.Equal(t, "arn:aws:bedrock:us-east-1::foundation-model/cohere.rerank-v3-5:0", result.Model)
assert.Equal(t, "capital of france", result.Query)
require.Len(t, result.Documents, 2)
assert.Equal(t, "Paris is the capital of France.", result.Documents[0].Text)
assert.Equal(t, "Berlin is the capital of Germany.", result.Documents[1].Text)
require.NotNil(t, result.Params)
require.NotNil(t, result.Params.TopN)
assert.Equal(t, 3, *result.Params.TopN)
require.NotNil(t, result.Params.ExtraParams)
assert.Equal(t, "END", result.Params.ExtraParams["truncate"])
}
// TestBedrockRerankRequestToBifrostRerankRequestNil checks that converting a
// nil *BedrockRerankRequest (with a nil context) yields nil.
func TestBedrockRerankRequestToBifrostRerankRequestNil(t *testing.T) {
	var missing *BedrockRerankRequest
	converted := missing.ToBifrostRerankRequest(nil)
	assert.Nil(t, converted)
}
// TestResolveBedrockDeployment checks alias resolution on a key: a configured
// alias maps to its target, while an unknown name and the empty string come
// back unchanged.
func TestResolveBedrockDeployment(t *testing.T) {
	const rerankARN = "arn:aws:bedrock:us-east-1::foundation-model/cohere.rerank-v3-5:0"

	key := schemas.Key{
		Aliases: schemas.KeyAliases{
			"cohere-rerank": rerankARN,
		},
	}

	cases := []struct {
		name string
		want string
	}{
		{name: "cohere-rerank", want: rerankARN},
		{name: "cohere.rerank-v3-5:0", want: "cohere.rerank-v3-5:0"},
		{name: "", want: ""},
	}
	for _, tc := range cases {
		assert.Equal(t, tc.want, key.Aliases.Resolve(tc.name))
	}
}
func TestBedrockRerankRequiresARNModelIdentifier(t *testing.T) {
provider := &BedrockProvider{}
ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
key := schemas.Key{
Aliases: schemas.KeyAliases{
"cohere-rerank": "cohere.rerank-v3-5:0",
},
}
response, bifrostErr := provider.Rerank(ctx, key, &schemas.BifrostRerankRequest{
Model: "cohere-rerank",
Query: "capital of france",
Documents: []schemas.RerankDocument{
{Text: "Paris is the capital of France."},
},
})
require.Nil(t, response)
require.NotNil(t, bifrostErr)
require.NotNil(t, bifrostErr.Error)
assert.Contains(t, bifrostErr.Error.Message, "requires an ARN")
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,130 @@
package bedrock
import (
"bytes"
"context"
"fmt"
"time"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/service/s3"
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
"github.com/maximhq/bifrost/core/schemas"
)
// uploadToS3 stores content as the object bucket/key with content type
// "application/jsonl".
//
// When both accessKey and secretKey are non-empty they are used as static
// credentials, together with *sessionToken when it is non-nil and non-empty.
// Otherwise only the region is passed to the AWS SDK's default configuration
// loading. It returns nil on success, or a BifrostError wrapping the
// underlying failure when loading the config or the PutObject call fails.
func uploadToS3(
	ctx context.Context,
	accessKey, secretKey string,
	sessionToken *string,
	region string,
	bucket, key string,
	content []byte,
) *schemas.BifrostError {
	// Collect the load options first so the config is loaded in one place.
	loadOpts := []func(*config.LoadOptions) error{config.WithRegion(region)}
	if accessKey != "" && secretKey != "" {
		// A nil or empty session token is passed to the provider as "".
		token := ""
		if sessionToken != nil {
			token = *sessionToken
		}
		staticCreds := credentials.NewStaticCredentialsProvider(accessKey, secretKey, token)
		loadOpts = append(loadOpts, config.WithCredentialsProvider(staticCreds))
	}

	awsCfg, err := config.LoadDefaultConfig(ctx, loadOpts...)
	if err != nil {
		return providerUtils.NewBifrostOperationError("failed to load aws config for s3", err)
	}

	putInput := &s3.PutObjectInput{
		Bucket:      aws.String(bucket),
		Key:         aws.String(key),
		Body:        bytes.NewReader(content),
		ContentType: aws.String("application/jsonl"),
	}
	if _, err = s3.NewFromConfig(awsCfg).PutObject(ctx, putInput); err != nil {
		return providerUtils.NewBifrostOperationError(fmt.Sprintf("failed to upload to s3: %s/%s", bucket, key), err)
	}
	return nil
}
// generateBatchInputS3Key builds the S3 object key for a batch input file:
// "bifrost-batch-input/<jobName>-<current Unix time in nanoseconds>.jsonl".
func generateBatchInputS3Key(jobName string) string {
	return fmt.Sprintf("bifrost-batch-input/%s-%d.jsonl", jobName, time.Now().UnixNano())
}
// deriveInputS3URIFromOutput returns "s3://<bucket>/<inputKey>", where bucket
// is the first result of parseS3URI(outputS3URI). Input files therefore live
// in the same bucket as the output, under the caller-supplied key.
func deriveInputS3URIFromOutput(outputS3URI, inputKey string) string {
	// Only the bucket is needed; the second parseS3URI result is discarded.
	outputBucket, _ := parseS3URI(outputS3URI)
	return fmt.Sprintf("s3://%s/%s", outputBucket, inputKey)
}
// ConvertBedrockRequestsToJSONL renders batch request items as JSONL, one
// record per line, each terminated by '\n'.
//
// Every record has the shape {"recordId": <CustomID>, "modelInput": {...}}.
// modelInput starts with "modelId" set to *modelID and is then filled from the
// item's Body when Body is non-nil, otherwise from its Params when those are
// non-nil; a "model" entry in either source is skipped.
//
// It returns an error when modelID is nil or empty, or when an item cannot be
// marshalled.
func ConvertBedrockRequestsToJSONL(requests []schemas.BatchRequestItem, modelID *string) ([]byte, error) {
	if modelID == nil || *modelID == "" {
		return nil, fmt.Errorf("modelID is required for Bedrock batch JSONL conversion")
	}

	var out bytes.Buffer
	for _, item := range requests {
		modelInput := map[string]interface{}{
			"modelId": *modelID,
		}

		// Body wins over Params; a non-nil but empty Body still suppresses Params.
		switch {
		case item.Body != nil:
			for field, val := range item.Body {
				if field == "model" {
					continue // the model is conveyed through modelId
				}
				modelInput[field] = val
			}
		case item.Params != nil:
			for field, val := range item.Params {
				if field == "model" {
					continue
				}
				modelInput[field] = val
			}
		}

		record := map[string]interface{}{
			"recordId":   item.CustomID,
			"modelInput": modelInput,
		}
		encoded, err := providerUtils.MarshalSorted(record)
		if err != nil {
			return nil, fmt.Errorf("failed to marshal batch request item %s: %w", item.CustomID, err)
		}
		out.Write(encoded)
		out.WriteByte('\n')
	}
	return out.Bytes(), nil
}

View File

@@ -0,0 +1,433 @@
package bedrock
import (
"context"
"crypto/hmac"
"crypto/sha256"
"encoding/hex"
"fmt"
"sort"
"strconv"
"strings"
"sync"
"time"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/smithy-go/encoding/httpbinding"
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
schemas "github.com/maximhq/bifrost/core/schemas"
"github.com/valyala/fasthttp"
)
// Constants used by the SigV4 signer in this file.
const (
// signingAlgorithm is the algorithm label written into the string to sign
// and into the Authorization header.
signingAlgorithm = "AWS4-HMAC-SHA256"
// amzDateKey is the request header that carries the signing timestamp.
amzDateKey = "X-Amz-Date"
// amzSecurityToken is the request header that carries the session token,
// set only when a non-empty token is available.
amzSecurityToken = "X-Amz-Security-Token"
// timeFormat is the Go reference-time layout for the full signing timestamp.
timeFormat = "20060102T150405Z"
// shortTimeFormat is the Go reference-time layout for the date stamp used in
// the credential scope and for signing key derivation.
shortTimeFormat = "20060102"
)
// ignoredHeaders is the set of header names that signAWSRequestFastHTTP leaves
// out of the canonical and signed headers. Keys are lower-case because the
// signer lower-cases each header name before looking it up here.
var ignoredHeaders = map[string]struct{}{
"authorization": {},
"user-agent": {},
"x-amzn-trace-id": {},
"expect": {},
"transfer-encoding": {},
}
// signingKeyCache caches derived signing keys to avoid recomputation.
// cache is guarded by mu: getSigningKey reads under the read lock and inserts
// under the write lock.
type signingKeyCache struct {
cache map[string]cachedKey
mu sync.RWMutex
}
// cachedKey is one cache entry: the derived signing key plus the date stamp and
// access key it was derived for, which getSigningKey re-checks on every hit.
type cachedKey struct {
key []byte
date string // YYYYMMDD format
accessKey string
}
// keyCache is the package-level signing key cache used by getSigningKey.
var keyCache = &signingKeyCache{
cache: make(map[string]cachedKey),
}
// hmacSHA256 returns the HMAC-SHA256 digest of data keyed with key.
func hmacSHA256(key, data []byte) []byte {
	mac := hmac.New(sha256.New, key)
	mac.Write(data)
	return mac.Sum(nil)
}

// deriveSigningKey computes the SigV4 signing key by chaining HMAC-SHA256:
// the initial key is "AWS4"+secret, and each step keys the next HMAC with the
// previous digest over, in order, dateStamp, region, service and the literal
// "aws4_request".
func deriveSigningKey(secret, dateStamp, region, service string) []byte {
	derived := []byte("AWS4" + secret)
	for _, scopePart := range [...]string{dateStamp, region, service, "aws4_request"} {
		derived = hmacSHA256(derived, []byte(scopePart))
	}
	return derived
}
// getSigningKey returns the SigV4 signing key for the given credentials and
// scope, deriving it on a cache miss and storing it in keyCache.
//
// Entries are keyed by accessKey/dateStamp/region/service. A hit is served
// under the read lock; a miss takes the write lock, re-checks the cache (another
// goroutine may have inserted the key in between) and then derives the key.
//
// A signing key is only valid for its date stamp, so entries for any other
// date are dropped whenever a new key is inserted. Without this the
// package-level cache would grow without bound over the life of the process.
// Eviction only affects caching; an evicted key is simply derived again.
func getSigningKey(accessKey, secretKey, dateStamp, region, service string) []byte {
	cacheKey := fmt.Sprintf("%s/%s/%s/%s", accessKey, dateStamp, region, service)

	keyCache.mu.RLock()
	cached, ok := keyCache.cache[cacheKey]
	keyCache.mu.RUnlock()
	if ok && cached.accessKey == accessKey && cached.date == dateStamp {
		return cached.key
	}

	keyCache.mu.Lock()
	defer keyCache.mu.Unlock()

	// Double-check after acquiring the write lock.
	if cached, ok := keyCache.cache[cacheKey]; ok && cached.accessKey == accessKey && cached.date == dateStamp {
		return cached.key
	}

	// Drop keys derived for other dates; deleting during range is safe in Go.
	for existingKey, entry := range keyCache.cache {
		if entry.date != dateStamp {
			delete(keyCache.cache, existingKey)
		}
	}

	key := deriveSigningKey(secretKey, dateStamp, region, service)
	keyCache.cache[cacheKey] = cachedKey{
		key:       key,
		date:      dateStamp,
		accessKey: accessKey,
	}
	return key
}
// stripExcessSpaces trims leading and trailing whitespace from str and
// collapses every run of consecutive ' ' characters into a single space.
// Only the space character is collapsed; other interior whitespace is kept.
func stripExcessSpaces(str string) string {
	trimmed := strings.TrimSpace(str)
	if !strings.Contains(trimmed, " ") {
		// No interior spaces, so there is nothing to collapse.
		return trimmed
	}

	var out strings.Builder
	out.Grow(len(trimmed))
	inSpaceRun := false
	for _, r := range trimmed {
		isSpace := r == ' '
		if isSpace && inSpaceRun {
			continue // second or later space of a run
		}
		out.WriteRune(r)
		inSpaceRun = isSpace
	}
	return out.String()
}
// percentEncodeRFC3986 percent-encodes s byte by byte.
// The RFC 3986 unreserved characters (A-Z, a-z, 0-9, '-', '_', '.', '~') are
// copied through unchanged; every other byte becomes %HH with upper-case hex
// digits.
func percentEncodeRFC3986(s string) string {
	var out strings.Builder
	out.Grow(len(s))
	for _, c := range []byte(s) {
		switch {
		case 'A' <= c && c <= 'Z',
			'a' <= c && c <= 'z',
			'0' <= c && c <= '9',
			c == '-', c == '_', c == '.', c == '~':
			out.WriteByte(c)
		default:
			out.WriteByte('%')
			out.WriteByte(uppercaseHex(c >> 4))
			out.WriteByte(uppercaseHex(c & 0x0F))
		}
	}
	return out.String()
}

// uppercaseHex maps a nibble (0-15) to its upper-case hexadecimal digit.
func uppercaseHex(b byte) byte {
	if b >= 10 {
		return b - 10 + 'A'
	}
	return b + '0'
}
// percentDecode replaces every %HH sequence (two hex digits, either case) in s
// with the byte it encodes. A '%' that is not followed by two hex digits is
// copied through unchanged, and '+' is never turned into a space, which is how
// this differs from url.QueryUnescape.
func percentDecode(s string) string {
	if strings.IndexByte(s, '%') < 0 {
		return s // nothing to decode
	}

	var out strings.Builder
	out.Grow(len(s))
	for pos := 0; pos < len(s); pos++ {
		if s[pos] == '%' && pos+2 < len(s) {
			hi, lo := unhex(s[pos+1]), unhex(s[pos+2])
			if hi >= 0 && lo >= 0 {
				out.WriteByte(byte(hi<<4 | lo))
				pos += 2 // skip the two hex digits just consumed
				continue
			}
		}
		out.WriteByte(s[pos])
	}
	return out.String()
}

// unhex returns the numeric value of the hexadecimal digit c, or -1 when c is
// not a hexadecimal digit.
func unhex(c byte) int {
	if c >= '0' && c <= '9' {
		return int(c - '0')
	}
	if c >= 'a' && c <= 'f' {
		return int(c-'a') + 10
	}
	if c >= 'A' && c <= 'F' {
		return int(c-'A') + 10
	}
	return -1
}
// queryPair represents a query parameter name-value pair.
// Both fields hold the percent-encoded form produced by
// buildCanonicalQueryString, which also sorts pairs by these fields.
type queryPair struct {
encodedName string
encodedValue string
}
// buildCanonicalQueryString turns a raw query string into its canonical SigV4
// form.
//
// The input is split on '&' (empty segments are dropped) and each segment on
// its first '='; a segment without '=' is a name with an empty value. Names
// and values are percent-decoded and then re-encoded per RFC 3986, so input
// that is already encoded is normalised rather than double-encoded.
// percentDecode is used instead of url.QueryUnescape so that '+' stays a plus
// sign and ends up encoded as %2B. The pairs are sorted by encoded name, then
// by encoded value, and joined as name=value with '&'.
func buildCanonicalQueryString(queryString string) string {
	if queryString == "" {
		return ""
	}

	segments := strings.Split(queryString, "&")
	params := make([]queryPair, 0, len(segments))
	for _, segment := range segments {
		if segment == "" {
			continue
		}
		rawName, rawValue, _ := strings.Cut(segment, "=")
		params = append(params, queryPair{
			encodedName:  percentEncodeRFC3986(percentDecode(rawName)),
			encodedValue: percentEncodeRFC3986(percentDecode(rawValue)),
		})
	}

	sort.Slice(params, func(a, b int) bool {
		left, right := params[a], params[b]
		if left.encodedName == right.encodedName {
			return left.encodedValue < right.encodedValue
		}
		return left.encodedName < right.encodedName
	})

	rendered := make([]string, len(params))
	for idx, param := range params {
		rendered[idx] = param.encodedName + "=" + param.encodedValue
	}
	return strings.Join(rendered, "&")
}
// signAWSRequestFastHTTP signs a fasthttp request using AWS Signature Version 4
// This is a native implementation that avoids allocating http.Request
//
// Parameters:
//   - ctx: used only when credentials have to be loaded from the default AWS config.
//   - req: the request to sign; it is modified in place (see below).
//   - body: the bytes whose SHA-256 hash is signed as the payload hash.
//     NOTE(review): presumably this must equal the body actually sent with req;
//     this function does not read the body from req itself - confirm at call sites.
//   - accessKey, secretKey, sessionToken: credentials; when accessKey and
//     secretKey are both empty they are replaced by the credentials retrieved
//     from config.LoadDefaultConfig.
//   - region, service: the credential scope components.
//
// Side effects on req: the Content-Type and Accept headers are set to
// "application/json", X-Amz-Date is set to the current UTC time,
// X-Amz-Security-Token is set when a non-empty session token is available, and
// finally the Authorization header is set.
//
// Returns nil on success, or a BifrostError when the default AWS config cannot
// be loaded or its credentials cannot be retrieved.
func signAWSRequestFastHTTP(
ctx context.Context,
req *fasthttp.Request,
body []byte,
accessKey, secretKey string,
sessionToken *string,
region, service string,
) *schemas.BifrostError {
// Get AWS credentials if not provided
// (the fallback only triggers when BOTH keys are empty).
if accessKey == "" && secretKey == "" {
cfg, err := config.LoadDefaultConfig(ctx, config.WithRegion(region))
if err != nil {
return providerUtils.NewBifrostOperationError("failed to load aws config", err)
}
creds, err := cfg.Credentials.Retrieve(ctx)
if err != nil {
return providerUtils.NewBifrostOperationError("failed to retrieve aws credentials", err)
}
accessKey = creds.AccessKeyID
secretKey = creds.SecretAccessKey
// A non-empty token from the retrieved credentials replaces the caller's token.
if creds.SessionToken != "" {
st := creds.SessionToken
sessionToken = &st
}
}
// Get current time
// Both the full timestamp and the date stamp come from the same instant.
now := time.Now().UTC()
amzDate := now.Format(timeFormat)
dateStamp := now.Format(shortTimeFormat)
// Parse URI
uri := req.URI()
host := string(uri.Host())
path := string(uri.Path())
if path == "" {
path = "/"
}
queryString := string(uri.QueryString())
// Escape path for canonical URI (Bedrock doesn't disable escaping)
canonicalURI := httpbinding.EscapePath(path, false)
// Calculate payload hash
hash := sha256.Sum256(body)
payloadHash := hex.EncodeToString(hash[:])
// Set required headers
// These must be set before the headers are collected below so that they are
// part of the signature.
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "application/json")
req.Header.Set(amzDateKey, amzDate)
if sessionToken != nil && *sessionToken != "" {
req.Header.Set(amzSecurityToken, *sessionToken)
}
// Build canonical headers
// headerNames keeps each lower-cased name once; headerMap collects all values per name.
var headerNames []string
headerMap := make(map[string][]string)
// Always include host
headerNames = append(headerNames, "host")
headerMap["host"] = []string{host}
// Include content-length if body is present
// NOTE(review): a negative ContentLength() presumably means no fixed length
// is set - confirm against the fasthttp documentation.
if cl := req.Header.ContentLength(); cl >= 0 {
headerNames = append(headerNames, "content-length")
headerMap["content-length"] = []string{strconv.Itoa(cl)}
}
// Collect other headers
for key, value := range req.Header.All() {
keyStr := strings.ToLower(string(key))
// Skip ignored headers
if _, ignore := ignoredHeaders[keyStr]; ignore {
continue
}
// Skip if already handled
if keyStr == "host" || keyStr == "content-length" {
continue
}
// Record the name only the first time it is seen; repeated headers add values.
if _, exists := headerMap[keyStr]; !exists {
headerNames = append(headerNames, keyStr)
}
headerMap[keyStr] = append(headerMap[keyStr], string(value))
}
// Sort header names
sort.Strings(headerNames)
// Build canonical headers string
// One "name:value1,value2\n" line per header, values cleaned by stripExcessSpaces.
var canonicalHeaders strings.Builder
for _, name := range headerNames {
canonicalHeaders.WriteString(name)
canonicalHeaders.WriteRune(':')
values := headerMap[name]
for i, v := range values {
cleanedValue := stripExcessSpaces(v)
canonicalHeaders.WriteString(cleanedValue)
if i < len(values)-1 {
canonicalHeaders.WriteRune(',')
}
}
canonicalHeaders.WriteRune('\n')
}
signedHeaders := strings.Join(headerNames, ";")
// Build canonical query string using RFC 3986 encoding
canonicalQueryString := buildCanonicalQueryString(queryString)
// Build canonical request
// canonicalHeaders already ends in '\n', so joining with "\n" leaves an empty
// line between the headers and the signed header list.
canonicalRequest := strings.Join([]string{
string(req.Header.Method()),
canonicalURI,
canonicalQueryString,
canonicalHeaders.String(),
signedHeaders,
payloadHash,
}, "\n")
// Build credential scope
credentialScope := strings.Join([]string{
dateStamp,
region,
service,
"aws4_request",
}, "/")
// Build string to sign
canonicalRequestHash := sha256.Sum256([]byte(canonicalRequest))
stringToSign := strings.Join([]string{
signingAlgorithm,
amzDate,
credentialScope,
hex.EncodeToString(canonicalRequestHash[:]),
}, "\n")
// Calculate signature
signingKey := getSigningKey(accessKey, secretKey, dateStamp, region, service)
signature := hex.EncodeToString(hmacSHA256(signingKey, []byte(stringToSign)))
// Build authorization header
authHeader := fmt.Sprintf("%s Credential=%s/%s, SignedHeaders=%s, Signature=%s",
signingAlgorithm,
accessKey,
credentialScope,
signedHeaders,
signature,
)
req.Header.Set("Authorization", authHeader)
return nil
}

View File

@@ -0,0 +1,229 @@
package bedrock
import (
"strings"
"github.com/maximhq/bifrost/core/providers/anthropic"
"github.com/maximhq/bifrost/core/providers/utils"
"github.com/maximhq/bifrost/core/schemas"
)
// ToBedrockTextCompletionRequest converts a Bifrost text completion request to Bedrock format.
//
// It returns nil when bifrostReq is nil, when it has no Input, or when the
// input carries neither a prompt string nor a non-empty prompt array.
//
// The prompt is PromptStr when set, otherwise the PromptArray entries joined
// with a blank line. Temperature and TopP are copied from Params. A "top_k"
// entry in Params.ExtraParams that can be read as an integer is moved into
// TopK and left out of the Bedrock request's ExtraParams.
//
// For models whose name contains "anthropic." or "claude", the prompt, the
// token limit and the stop sequences are taken from the Anthropic conversion
// of the same request. For all other models MaxTokens and Stop are copied from
// Params.
func ToBedrockTextCompletionRequest(bifrostReq *schemas.BifrostTextCompletionRequest) *BedrockTextCompletionRequest {
	// Input is a pointer, so it must be checked before its fields are read;
	// a request without input yields nil instead of a nil-pointer panic.
	if bifrostReq == nil || bifrostReq.Input == nil {
		return nil
	}
	input := bifrostReq.Input
	if input.PromptStr == nil && len(input.PromptArray) == 0 {
		return nil
	}

	// Extract the raw prompt; the guard above ensures one of the two is present.
	var prompt string
	if input.PromptStr != nil {
		prompt = *input.PromptStr
	} else {
		prompt = strings.Join(input.PromptArray, "\n\n")
	}

	bedrockReq := &BedrockTextCompletionRequest{
		Prompt: prompt,
	}

	// Apply parameters
	if bifrostReq.Params != nil {
		bedrockReq.Temperature = bifrostReq.Params.Temperature
		bedrockReq.TopP = bifrostReq.Params.TopP
		if extra := bifrostReq.Params.ExtraParams; extra != nil {
			bedrockReq.ExtraParams = extra
			if topK, ok := schemas.SafeExtractIntPointer(extra["top_k"]); ok {
				// Build a copy without top_k rather than deleting from the
				// caller's map: the request may be converted again (for
				// example on a retry) and must still carry top_k then.
				remaining := make(map[string]interface{}, len(extra))
				for name, value := range extra {
					if name != "top_k" {
						remaining[name] = value
					}
				}
				bedrockReq.ExtraParams = remaining
				bedrockReq.TopK = topK
			}
		}
	}

	// Apply model-specific formatting and field naming
	if strings.Contains(bifrostReq.Model, "anthropic.") || strings.Contains(bifrostReq.Model, "claude") {
		// For Claude models, wrap the prompt in Anthropic format and use Anthropic field names
		anthropicReq := anthropic.ToAnthropicTextCompletionRequest(bifrostReq)
		bedrockReq.Prompt = anthropicReq.Prompt
		bedrockReq.MaxTokensToSample = &anthropicReq.MaxTokensToSample
		bedrockReq.StopSequences = anthropicReq.StopSequences
	} else if bifrostReq.Params != nil {
		// For other models, use standard field names with raw prompt
		bedrockReq.MaxTokens = bifrostReq.Params.MaxTokens
		bedrockReq.Stop = bifrostReq.Params.Stop
	}
	return bedrockReq
}
// ToBifrostTextCompletionRequest converts a Bedrock text completion request to
// Bifrost format. A nil receiver yields nil.
//
// The prompt is request.Prompt; when that is empty and Messages are present,
// the text parts of all messages are joined with a blank line instead. The
// provider and model come from parsing request.ModelID with Bedrock as the
// default provider. MaxTokens wins over MaxTokensToSample, and Stop wins over
// StopSequences, when both are set.
func (request *BedrockTextCompletionRequest) ToBifrostTextCompletionRequest(ctx *schemas.BifrostContext) *schemas.BifrostTextCompletionRequest {
	if request == nil {
		return nil
	}

	promptText := request.Prompt
	if promptText == "" && len(request.Messages) > 0 {
		// Fallback for Claude 3 Messages API: collect every text block.
		var texts []string
		for _, message := range request.Messages {
			for _, block := range message.Content {
				if block.Text == nil {
					continue
				}
				texts = append(texts, *block.Text)
			}
		}
		promptText = strings.Join(texts, "\n\n")
	}

	provider, model := schemas.ParseModelString(request.ModelID, utils.CheckAndSetDefaultProvider(ctx, schemas.Bedrock))

	params := &schemas.TextCompletionParameters{
		Temperature: request.Temperature,
		TopP:        request.TopP,
	}
	switch {
	case request.MaxTokens != nil:
		params.MaxTokens = request.MaxTokens
	case request.MaxTokensToSample != nil:
		params.MaxTokens = request.MaxTokensToSample
	}
	switch {
	case len(request.Stop) > 0:
		params.Stop = request.Stop
	case len(request.StopSequences) > 0:
		params.Stop = request.StopSequences
	}

	return &schemas.BifrostTextCompletionRequest{
		Provider: provider,
		Model:    model,
		Input: &schemas.TextCompletionInput{
			PromptStr: &promptText,
		},
		Params: params,
	}
}
// ToBifrostTextCompletionResponse converts a Bedrock Anthropic text response to
// Bifrost format. A nil receiver yields nil; otherwise the result has object
// type "text_completion" and exactly one choice (index 0) whose text and
// finish reason point at the response's Completion and StopReason fields.
func (response *BedrockAnthropicTextResponse) ToBifrostTextCompletionResponse() *schemas.BifrostTextCompletionResponse {
	if response == nil {
		return nil
	}

	onlyChoice := schemas.BifrostResponseChoice{
		Index: 0,
		TextCompletionResponseChoice: &schemas.TextCompletionResponseChoice{
			Text: &response.Completion,
		},
		FinishReason: &response.StopReason,
	}
	// ExtraFields is left at its zero value.
	return &schemas.BifrostTextCompletionResponse{
		Object:  "text_completion",
		Choices: []schemas.BifrostResponseChoice{onlyChoice},
	}
}
// ToBifrostTextCompletionResponse converts a Bedrock Mistral text response to
// Bifrost format. A nil receiver yields nil; otherwise each output becomes one
// choice, indexed by its position, carrying the output's text and stop reason.
// With no outputs the Choices slice stays nil.
func (response *BedrockMistralTextResponse) ToBifrostTextCompletionResponse() *schemas.BifrostTextCompletionResponse {
	if response == nil {
		return nil
	}

	var choices []schemas.BifrostResponseChoice
	for idx := range response.Outputs {
		// Work on a per-iteration copy so every choice points at its own values.
		entry := response.Outputs[idx]
		choices = append(choices, schemas.BifrostResponseChoice{
			Index: idx,
			TextCompletionResponseChoice: &schemas.TextCompletionResponseChoice{
				Text: &entry.Text,
			},
			FinishReason: &entry.StopReason,
		})
	}

	// ExtraFields is left at its zero value.
	return &schemas.BifrostTextCompletionResponse{
		Object:  "text_completion",
		Choices: choices,
	}
}
// ToBedrockTextCompletionResponse converts a BifrostTextCompletionResponse back
// to Bedrock text completion format. It returns nil for a nil input, a
// *BedrockMistralTextResponse for Mistral models, and a
// *BedrockAnthropicTextResponse otherwise (Claude models and any model whose
// family cannot be determined).
//
// The model name used for family detection is ExtraFields.ResolvedModelUsed
// when set; otherwise bifrostResp.Model; and, only if that is empty,
// ExtraFields.OriginalModelRequested.
func ToBedrockTextCompletionResponse(bifrostResp *schemas.BifrostTextCompletionResponse) interface{} {
	if bifrostResp == nil {
		return nil
	}

	model := bifrostResp.Model
	switch {
	case bifrostResp.ExtraFields.ResolvedModelUsed != "":
		model = bifrostResp.ExtraFields.ResolvedModelUsed
	case model == "" && bifrostResp.ExtraFields.OriginalModelRequested != "":
		model = bifrostResp.ExtraFields.OriginalModelRequested
	}

	// Claude detection takes precedence over the Mistral check.
	isClaude := strings.Contains(model, "anthropic.") || strings.Contains(model, "claude")
	if !isClaude && strings.Contains(model, "mistral.") {
		// Mistral format: one output per choice.
		mistralResp := &BedrockMistralTextResponse{}
		for _, choice := range bifrostResp.Choices {
			var output struct {
				Text       string `json:"text"`
				StopReason string `json:"stop_reason"`
			}
			if textChoice := choice.TextCompletionResponseChoice; textChoice != nil && textChoice.Text != nil {
				output.Text = *textChoice.Text
			}
			if choice.FinishReason != nil {
				output.StopReason = *choice.FinishReason
			}
			mistralResp.Outputs = append(mistralResp.Outputs, output)
		}
		return mistralResp
	}

	// Anthropic format, also the default: only the first choice is used,
	// since the Anthropic text API typically returns one choice.
	anthropicResp := &BedrockAnthropicTextResponse{}
	if len(bifrostResp.Choices) == 0 {
		return anthropicResp
	}
	first := bifrostResp.Choices[0]
	if textChoice := first.TextCompletionResponseChoice; textChoice != nil && textChoice.Text != nil {
		anthropicResp.Completion = *textChoice.Text
	}
	if first.FinishReason != nil {
		anthropicResp.StopReason = *first.FinishReason
	}
	return anthropicResp
}

View File

@@ -0,0 +1,714 @@
package bedrock
import (
"context"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/json"
"encoding/pem"
"io"
"math/big"
"net"
"net/http"
"net/http/httptest"
"net/url"
"testing"
"time"
"github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream"
"github.com/maximhq/bifrost/core/schemas"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// redirectTransport is an http.RoundTripper that sends every request to a
// fixed target: it overwrites the scheme and host of the request URL (and the
// Host field) with those of target and forwards the result to transport. It is
// used to point provider requests at a local httptest.Server without changing
// provider code.
type redirectTransport struct {
	target    *url.URL
	transport http.RoundTripper
}

// RoundTrip forwards a clone of req, rewritten to the target's scheme and
// host, to the wrapped transport. The caller's request is left untouched.
func (r *redirectTransport) RoundTrip(req *http.Request) (*http.Response, error) {
	redirected := req.Clone(req.Context())
	redirected.Host = r.target.Host
	redirected.URL.Host = r.target.Host
	redirected.URL.Scheme = r.target.Scheme
	return r.transport.RoundTrip(redirected)
}
// noopLogger is a no-op schemas.Logger for use in tests.
// Every method has an empty body, so all log calls are discarded.
type noopLogger struct{}
func (noopLogger) Debug(string, ...any) {}
func (noopLogger) Info(string, ...any) {}
func (noopLogger) Warn(string, ...any) {}
func (noopLogger) Error(string, ...any) {}
func (noopLogger) Fatal(string, ...any) {}
func (noopLogger) SetLevel(schemas.LogLevel) {}
func (noopLogger) SetOutputType(schemas.LoggerOutputType) {}
// LogHTTPRequest ignores its arguments and returns the shared schemas.NoopLogEvent.
func (noopLogger) LogHTTPRequest(schemas.LogLevel, string) schemas.LogEventBuilder {
return schemas.NoopLogEvent
}
// newTestProviderWithServer builds a BedrockProvider (5 second default request
// timeout, no-op logger) whose regular and streaming HTTP clients both send
// every request to ts. Any setup failure aborts the calling test.
func newTestProviderWithServer(t *testing.T, ts *httptest.Server) *BedrockProvider {
	t.Helper()

	providerCfg := &schemas.ProviderConfig{
		NetworkConfig: schemas.NetworkConfig{
			DefaultRequestTimeoutInSeconds: 5,
		},
	}
	providerCfg.CheckAndSetDefaults()

	bedrock, err := NewBedrockProvider(providerCfg, noopLogger{})
	require.NoError(t, err)

	serverURL, err := url.Parse(ts.URL)
	require.NoError(t, err)

	// One redirecting transport is shared by both clients.
	toServer := &redirectTransport{
		target:    serverURL,
		transport: ts.Client().Transport,
	}
	bedrock.client = &http.Client{
		Transport: toServer,
		Timeout:   5 * time.Second,
	}
	// Streaming paths use streamingClient (no Timeout). It has to be
	// redirected as well, otherwise streaming tests would reach the real AWS
	// endpoint.
	bedrock.streamingClient = &http.Client{Transport: toServer}
	return bedrock
}
// testBedrockKey returns a minimal Key for tests: the value "test-api-key" and
// a Bedrock key config with region "us-east-1". The bearer value is there so
// that makeStreamingRequest skips IAM signing and goes straight to the HTTP
// call.
func testBedrockKey() schemas.Key {
	bedrockCfg := &schemas.BedrockKeyConfig{
		Region: schemas.NewEnvVar("us-east-1"),
	}
	return schemas.Key{
		Value:            *schemas.NewEnvVar("test-api-key"),
		BedrockKeyConfig: bedrockCfg,
	}
}
// testBedrockCtx returns a fresh BifrostContext for unit tests, derived from
// context.Background() with schemas.NoDeadline.
func testBedrockCtx() *schemas.BifrostContext {
	base := context.Background()
	return schemas.NewBifrostContext(base, schemas.NoDeadline)
}
// noopPostHookRunner is a PostHookRunner for tests: it ignores the context and
// hands back the response and error exactly as received.
func noopPostHookRunner(_ *schemas.BifrostContext, resp *schemas.BifrostResponse, bifrostErr *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError) {
	return resp, bifrostErr
}
// testChatRequest returns a minimal BifrostChatRequest for streaming tests: a
// single user message saying "hello", addressed to the model
// "anthropic.claude-sonnet-4-5".
func testChatRequest() *schemas.BifrostChatRequest {
	text := "hello"
	userMessage := schemas.ChatMessage{
		Role:    schemas.ChatMessageRoleUser,
		Content: &schemas.ChatMessageContent{ContentStr: &text},
	}
	return &schemas.BifrostChatRequest{
		Model: "anthropic.claude-sonnet-4-5",
		Input: []schemas.ChatMessage{userMessage},
	}
}
// TestMakeStreamingRequest_StaleConnection_IsRetryable simulates a stale
// connection: the server hijacks and closes the connection before writing any
// response. makeStreamingRequest must then return a BifrostError with
// IsBifrostError:false, so that the retry gate in executeRequestWithRetries
// retries the request.
func TestMakeStreamingRequest_StaleConnection_IsRetryable(t *testing.T) {
	// dropConnection closes the connection without sending a single byte.
	dropConnection := func(w http.ResponseWriter, _ *http.Request) {
		hijacker, supported := w.(http.Hijacker)
		if !supported {
			http.Error(w, "hijack not supported", http.StatusInternalServerError)
			return
		}
		rawConn, _, _ := hijacker.Hijack()
		rawConn.Close()
	}
	server := httptest.NewServer(http.HandlerFunc(dropConnection))
	defer server.Close()

	bedrock := newTestProviderWithServer(t, server)
	bifrostCtx := testBedrockCtx()
	apiKey := testBedrockKey()

	_, reqErr := bedrock.makeStreamingRequest(bifrostCtx, []byte(`{}`), apiKey, "anthropic.claude-sonnet-4-5", "converse-stream")
	require.NotNil(t, reqErr, "expected error when server closes connection")
	assert.False(t, reqErr.IsBifrostError,
		"stale-connection error must be IsBifrostError:false so the retry gate can retry it")
	require.NotNil(t, reqErr.Error)

	// ErrProviderNetworkError (net.OpError) and ErrProviderDoRequest
	// (EOF/connection-reset) are both retryable; what matters is
	// IsBifrostError:false.
	retryableMessages := []string{schemas.ErrProviderNetworkError, schemas.ErrProviderDoRequest}
	assert.Contains(t, retryableMessages, reqErr.Error.Message,
		"stale-connection error must use a retryable error message")
}
// TestChatCompletionStream_StaleConnection_ChunkIsRetryable simulates a
// connection that goes stale mid-stream: the server answers HTTP 200 with the
// EventStream content type and then closes the connection before any
// EventStream data is sent. The failure must surface with
// IsBifrostError:false, either synchronously or as the first stream chunk, so
// that CheckFirstStreamChunkForError and the retry gate can transparently
// retry the request.
func TestChatCompletionStream_StaleConnection_ChunkIsRetryable(t *testing.T) {
	// headersThenClose sends status and headers, then drops the connection.
	headersThenClose := func(w http.ResponseWriter, _ *http.Request) {
		w.Header().Set("Content-Type", "application/vnd.amazon.eventstream")
		w.WriteHeader(http.StatusOK)
		if flusher, canFlush := w.(http.Flusher); canFlush {
			flusher.Flush()
		}
		hijacker, supported := w.(http.Hijacker)
		if !supported {
			return
		}
		rawConn, _, _ := hijacker.Hijack()
		rawConn.Close() // no EventStream bytes are ever written
	}
	server := httptest.NewServer(http.HandlerFunc(headersThenClose))
	defer server.Close()

	bedrock := newTestProviderWithServer(t, server)
	bifrostCtx := testBedrockCtx()
	apiKey := testBedrockKey()

	stream, startErr := bedrock.ChatCompletionStream(bifrostCtx, noopPostHookRunner, nil, apiKey, testChatRequest())
	if startErr != nil {
		// The error surfaced synchronously (e.g. connection refused before HTTP 200).
		assert.False(t, startErr.IsBifrostError,
			"pre-stream network error must be IsBifrostError:false")
		return
	}

	// Otherwise the error must arrive as the first stream chunk.
	require.NotNil(t, stream)
	first, received := <-stream
	require.True(t, received, "channel must not be empty")
	require.NotNil(t, first)
	require.NotNil(t, first.BifrostError, "expected an error chunk from the stream")
	assert.False(t, first.BifrostError.IsBifrostError,
		"stream transport error must be IsBifrostError:false so the retry gate can retry it")
	require.NotNil(t, first.BifrostError.Error)
	assert.Equal(t, schemas.ErrProviderNetworkError, first.BifrostError.Error.Message,
		"stream transport error must use ErrProviderNetworkError message")

	// Drain whatever is left on the channel.
	for range stream {
	}
}
// TestChatCompletionStream_NetOpError_ChunkIsRetryable covers the
// "use of closed network connection" *net.OpError scenario from issue #2424.
// The server answers HTTP 200 with the EventStream content type, writes three
// bytes of a frame prelude (not a valid frame) and then resets the TCP
// connection. The resulting decode failure must arrive as a retryable
// IsBifrostError:false chunk.
func TestChatCompletionStream_NetOpError_ChunkIsRetryable(t *testing.T) {
	partialFrameThenReset := func(w http.ResponseWriter, _ *http.Request) {
		w.Header().Set("Content-Type", "application/vnd.amazon.eventstream")
		w.WriteHeader(http.StatusOK)
		if flusher, canFlush := w.(http.Flusher); canFlush {
			flusher.Flush()
		}
		// A partial EventStream frame header: 3 bytes, not a valid frame.
		_, _ = w.Write([]byte{0x00, 0x00, 0x00})
		if flusher, canFlush := w.(http.Flusher); canFlush {
			flusher.Flush()
		}
		hijacker, supported := w.(http.Hijacker)
		if !supported {
			return
		}
		rawConn, _, _ := hijacker.Hijack()
		// RST instead of FIN, so the client's read fails with a *net.OpError.
		if tcpConn, isTCP := rawConn.(*net.TCPConn); isTCP {
			_ = tcpConn.SetLinger(0)
		}
		rawConn.Close()
	}
	server := httptest.NewUnstartedServer(http.HandlerFunc(partialFrameThenReset))
	server.Start()
	defer server.Close()

	bedrock := newTestProviderWithServer(t, server)
	bifrostCtx := testBedrockCtx()
	apiKey := testBedrockKey()

	stream, startErr := bedrock.ChatCompletionStream(bifrostCtx, noopPostHookRunner, nil, apiKey, testChatRequest())
	if startErr != nil {
		assert.False(t, startErr.IsBifrostError,
			"pre-stream network error must be IsBifrostError:false")
		return
	}
	require.NotNil(t, stream)

	// Take the first error chunk. It may not be the very first chunk if the OS
	// buffers the partial write, but it must appear before the channel closes.
	var failure *schemas.BifrostStreamChunk
	for chunk := range stream {
		if chunk == nil || chunk.BifrostError == nil {
			continue
		}
		failure = chunk
		break
	}
	// Drain the rest.
	for range stream {
	}

	require.NotNil(t, failure, "expected an error chunk from the stream")
	assert.False(t, failure.BifrostError.IsBifrostError,
		"net.OpError during EventStream decoding must be IsBifrostError:false so the retry gate can retry it")
	require.NotNil(t, failure.BifrostError.Error)
	assert.Equal(t, schemas.ErrProviderNetworkError, failure.BifrostError.Error.Message,
		"net.OpError during EventStream decoding must use ErrProviderNetworkError message")
}
// writeEventStreamException encodes a well-formed AWS EventStream exception
// frame with the given exception type and message into w.
// The frame format is: prelude (total_len + headers_len + CRC) + headers + payload + message_CRC.
// We use the AWS SDK's eventstream.Encoder so the binary framing is correct.
func writeEventStreamException(t *testing.T, w io.Writer, excType, msg string) {
t.Helper()
enc := eventstream.NewEncoder()
payload, err := json.Marshal(map[string]string{"message": msg})
require.NoError(t, err, "failed to marshal exception payload")
headers := eventstream.Headers{
{Name: ":message-type", Value: eventstream.StringValue("exception")},
{Name: ":exception-type", Value: eventstream.StringValue(excType)},
{Name: ":content-type", Value: eventstream.StringValue("application/json")},
}
err = enc.Encode(w, eventstream.Message{Headers: headers, Payload: payload})
require.NoError(t, err, "failed to encode EventStream exception frame")
}
// TestChatCompletionStream_RetryableException_ChunkIsRetryable verifies that
// when the server sends a retryable exception (serviceUnavailableException,
// throttlingException, etc.) through the EventStream, the resulting error chunk
// has IsBifrostError:false and carries the expected HTTP StatusCode.
//
// The stream is drained and checked by assertRetryableExceptionChunk, the same
// helper the TextCompletionStream and ResponsesStream variants use, so all
// three streaming paths are held to one set of assertions.
func TestChatCompletionStream_RetryableException_ChunkIsRetryable(t *testing.T) {
	tests := []struct {
		excType        string
		expectedStatus int
	}{
		{"serviceUnavailableException", 503},
		{"throttlingException", 429},
		{"modelNotReadyException", 503},
		{"internalServerException", 500},
	}
	for _, tc := range tests {
		tc := tc // per-iteration copy for toolchains older than Go 1.22
		t.Run(tc.excType, func(t *testing.T) {
			ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				w.Header().Set("Content-Type", "application/vnd.amazon.eventstream")
				w.WriteHeader(http.StatusOK)
				// Push the 200 + headers out before the exception frame is written.
				if f, ok := w.(http.Flusher); ok {
					f.Flush()
				}
				writeEventStreamException(t, w, tc.excType, "service is unavailable, please retry")
				if f, ok := w.(http.Flusher); ok {
					f.Flush()
				}
			}))
			defer ts.Close()

			provider := newTestProviderWithServer(t, ts)
			streamChan, bifrostErr := provider.ChatCompletionStream(testBedrockCtx(), noopPostHookRunner, nil, testBedrockKey(), testChatRequest())
			assertRetryableExceptionChunk(t, streamChan, bifrostErr, tc.excType, tc.expectedStatus)
		})
	}
}
// TestChatCompletionStream_NonRetryableException_IsTerminal covers the
// terminal path: a validationException frame delivered through the
// EventStream must surface as an error chunk flagged IsBifrostError:true,
// in contrast to the retryable exception types.
func TestChatCompletionStream_NonRetryableException_IsTerminal(t *testing.T) {
	handler := func(w http.ResponseWriter, _ *http.Request) {
		flusher, canFlush := w.(http.Flusher)

		w.Header().Set("Content-Type", "application/vnd.amazon.eventstream")
		w.WriteHeader(http.StatusOK)
		if canFlush {
			flusher.Flush()
		}

		writeEventStreamException(t, w, "validationException", "input validation failed")
		if canFlush {
			flusher.Flush()
		}
	}
	server := httptest.NewServer(http.HandlerFunc(handler))
	defer server.Close()

	bedrock := newTestProviderWithServer(t, server)
	stream, preStreamErr := bedrock.ChatCompletionStream(testBedrockCtx(), noopPostHookRunner, nil, testBedrockKey(), testChatRequest())
	require.Nil(t, preStreamErr, "expected EventStream exception to surface as a stream chunk")
	require.NotNil(t, stream)

	// Single pass: remember the first error chunk and keep reading until the
	// channel is closed.
	var firstErr *schemas.BifrostStreamChunk
	for chunk := range stream {
		if firstErr == nil && chunk != nil && chunk.BifrostError != nil {
			firstErr = chunk
		}
	}

	require.NotNil(t, firstErr, "expected error chunk for validationException")
	assert.True(t, firstErr.BifrostError.IsBifrostError,
		"non-retryable validationException must remain IsBifrostError:true")
}
// testTextCompletionRequest builds the smallest text-completion request the
// streaming tests need: a fixed model name and the string prompt "hello".
func testTextCompletionRequest() *schemas.BifrostTextCompletionRequest {
	req := &schemas.BifrostTextCompletionRequest{Model: "anthropic.claude-sonnet-4-5"}

	text := "hello"
	req.Input = &schemas.TextCompletionInput{PromptStr: &text}
	return req
}
// testResponsesRequest builds the smallest responses request the streaming
// tests need: a fixed model name and one user message whose content is the
// string "hello".
func testResponsesRequest() *schemas.BifrostResponsesRequest {
	var (
		kind = schemas.ResponsesMessageType("message")
		role = schemas.ResponsesMessageRoleType("user")
		text = "hello"
	)

	userMessage := schemas.ResponsesMessage{
		Type:    &kind,
		Role:    &role,
		Content: &schemas.ResponsesMessageContent{ContentStr: &text},
	}

	return &schemas.BifrostResponsesRequest{
		Model: "anthropic.claude-sonnet-4-5",
		Input: []schemas.ResponsesMessage{userMessage},
	}
}
// assertRetryableExceptionChunk is the assertion helper shared by the
// streaming retryable-exception tests.
//
// It requires that the call produced a stream (bifrostErr nil, streamChan
// non-nil), reads streamChan until it is closed, and then checks the first
// chunk carrying a BifrostError: IsBifrostError must be false and StatusCode
// must be set and equal to expectedStatus. excType is only used to label
// failure messages.
func assertRetryableExceptionChunk(t *testing.T, streamChan chan *schemas.BifrostStreamChunk, bifrostErr *schemas.BifrostError, excType string, expectedStatus int) {
	t.Helper()

	require.Nil(t, bifrostErr, "expected EventStream exception to surface as a stream chunk, not a pre-stream error")
	require.NotNil(t, streamChan)

	// Single pass: remember the first error chunk and keep reading until the
	// channel is closed.
	var firstErr *schemas.BifrostStreamChunk
	for chunk := range streamChan {
		if firstErr == nil && chunk != nil && chunk.BifrostError != nil {
			firstErr = chunk
		}
	}
	require.NotNil(t, firstErr, "expected error chunk for %s", excType)

	streamErr := firstErr.BifrostError
	assert.False(t, streamErr.IsBifrostError,
		"%s must be IsBifrostError:false so retry gate can retry it", excType)
	require.NotNil(t, streamErr.StatusCode,
		"%s must carry a StatusCode for the retryableStatusCodes gate", excType)
	assert.Equal(t, expectedStatus, *streamErr.StatusCode,
		"%s must map to HTTP %d", excType, expectedStatus)
}
// TestTextCompletionStream_RetryableException_ChunkIsRetryable is the
// TextCompletionStream counterpart of the ChatCompletionStream retryable
// exception test: for each exception type the server emits one EventStream
// exception frame and the resulting error chunk is checked by
// assertRetryableExceptionChunk.
func TestTextCompletionStream_RetryableException_ChunkIsRetryable(t *testing.T) {
	cases := []struct {
		excType        string
		expectedStatus int
	}{
		{excType: "serviceUnavailableException", expectedStatus: 503},
		{excType: "throttlingException", expectedStatus: 429},
		{excType: "modelNotReadyException", expectedStatus: 503},
		{excType: "internalServerException", expectedStatus: 500},
	}

	for i := range cases {
		tt := cases[i]
		t.Run(tt.excType, func(t *testing.T) {
			handler := func(w http.ResponseWriter, _ *http.Request) {
				flusher, canFlush := w.(http.Flusher)

				w.Header().Set("Content-Type", "application/vnd.amazon.eventstream")
				w.WriteHeader(http.StatusOK)
				if canFlush {
					flusher.Flush()
				}

				writeEventStreamException(t, w, tt.excType, "service is unavailable, please retry")
				if canFlush {
					flusher.Flush()
				}
			}
			server := httptest.NewServer(http.HandlerFunc(handler))
			defer server.Close()

			bedrock := newTestProviderWithServer(t, server)
			stream, preStreamErr := bedrock.TextCompletionStream(testBedrockCtx(), noopPostHookRunner, nil, testBedrockKey(), testTextCompletionRequest())
			assertRetryableExceptionChunk(t, stream, preStreamErr, tt.excType, tt.expectedStatus)
		})
	}
}
// TestResponsesStream_RetryableException_ChunkIsRetryable is the
// ResponsesStream counterpart of the ChatCompletionStream retryable exception
// test: for each exception type the server emits one EventStream exception
// frame and the resulting error chunk is checked by
// assertRetryableExceptionChunk.
func TestResponsesStream_RetryableException_ChunkIsRetryable(t *testing.T) {
	cases := []struct {
		excType        string
		expectedStatus int
	}{
		{excType: "serviceUnavailableException", expectedStatus: 503},
		{excType: "throttlingException", expectedStatus: 429},
		{excType: "modelNotReadyException", expectedStatus: 503},
		{excType: "internalServerException", expectedStatus: 500},
	}

	for i := range cases {
		tt := cases[i]
		t.Run(tt.excType, func(t *testing.T) {
			handler := func(w http.ResponseWriter, _ *http.Request) {
				flusher, canFlush := w.(http.Flusher)

				w.Header().Set("Content-Type", "application/vnd.amazon.eventstream")
				w.WriteHeader(http.StatusOK)
				if canFlush {
					flusher.Flush()
				}

				writeEventStreamException(t, w, tt.excType, "service is unavailable, please retry")
				if canFlush {
					flusher.Flush()
				}
			}
			server := httptest.NewServer(http.HandlerFunc(handler))
			defer server.Close()

			bedrock := newTestProviderWithServer(t, server)
			stream, preStreamErr := bedrock.ResponsesStream(testBedrockCtx(), noopPostHookRunner, nil, testBedrockKey(), testResponsesRequest())
			assertRetryableExceptionChunk(t, stream, preStreamErr, tt.excType, tt.expectedStatus)
		})
	}
}
// generateTestCACert creates a throwaway self-signed CA certificate for the
// TLS transport tests and returns it PEM-encoded.
//
// The certificate uses a freshly generated ECDSA P-256 key, the common name
// "testca", serial number 1, and a validity window from now to roughly ten
// years ahead. Any failure aborts the calling test via require.
func generateTestCACert(t *testing.T) string {
	t.Helper()

	const validity = 10 * 365 * 24 * time.Hour

	priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	require.NoError(t, err)

	caTemplate := &x509.Certificate{
		SerialNumber:          big.NewInt(1),
		Subject:               pkix.Name{CommonName: "testca"},
		NotBefore:             time.Now(),
		NotAfter:              time.Now().Add(validity),
		IsCA:                  true,
		BasicConstraintsValid: true,
	}

	// Self-signed: the template doubles as its own parent.
	der, err := x509.CreateCertificate(rand.Reader, caTemplate, caTemplate, &priv.PublicKey, priv)
	require.NoError(t, err)

	block := &pem.Block{Type: "CERTIFICATE", Bytes: der}
	return string(pem.EncodeToMemory(block))
}
// TestBedrockTransportHTTP2Config checks that a provider built with
// MaxConnsPerHost=5000 and EnforceHTTP2=true exposes an *http.Transport with
// that per-host limit, the default idle-connection limits, and
// ForceAttemptHTTP2 enabled.
func TestBedrockTransportHTTP2Config(t *testing.T) {
	netCfg := schemas.NetworkConfig{
		DefaultRequestTimeoutInSeconds: 30,
		MaxConnsPerHost:                5000,
		EnforceHTTP2:                   true,
	}
	cfg := &schemas.ProviderConfig{NetworkConfig: netCfg}
	cfg.CheckAndSetDefaults()

	bedrock, err := NewBedrockProvider(cfg, nil)
	require.NoError(t, err)
	require.NotNil(t, bedrock)

	tr, isHTTPTransport := bedrock.client.Transport.(*http.Transport)
	require.True(t, isHTTPTransport, "transport should be *http.Transport")

	assert.Equal(t, 5000, tr.MaxConnsPerHost)
	assert.Equal(t, schemas.DefaultMaxIdleConnsPerHost, tr.MaxIdleConnsPerHost)
	assert.Equal(t, schemas.DefaultMaxIdleConnsPerHost, tr.MaxIdleConns)
	assert.True(t, tr.ForceAttemptHTTP2)
}
// TestBedrockTransportCustomMaxConns checks that an explicit MaxConnsPerHost
// of 50 reaches the transport while the idle-connection limits stay at
// schemas.DefaultMaxIdleConnsPerHost.
func TestBedrockTransportCustomMaxConns(t *testing.T) {
	netCfg := schemas.NetworkConfig{
		DefaultRequestTimeoutInSeconds: 30,
		MaxConnsPerHost:                50,
	}
	cfg := &schemas.ProviderConfig{NetworkConfig: netCfg}
	cfg.CheckAndSetDefaults()

	bedrock, err := NewBedrockProvider(cfg, nil)
	require.NoError(t, err)

	tr, isHTTPTransport := bedrock.client.Transport.(*http.Transport)
	require.True(t, isHTTPTransport)

	assert.Equal(t, 50, tr.MaxConnsPerHost)
	assert.Equal(t, schemas.DefaultMaxIdleConnsPerHost, tr.MaxIdleConnsPerHost)
	assert.Equal(t, schemas.DefaultMaxIdleConnsPerHost, tr.MaxIdleConns)
}
// TestBedrockTransportDefaultMaxConns checks that leaving MaxConnsPerHost at
// its zero value makes CheckAndSetDefaults fill in
// schemas.DefaultMaxConnsPerHost, and that the transport picks up that value
// together with the default idle-connection limits.
func TestBedrockTransportDefaultMaxConns(t *testing.T) {
	// MaxConnsPerHost is deliberately omitted so the default is applied.
	netCfg := schemas.NetworkConfig{DefaultRequestTimeoutInSeconds: 30}
	cfg := &schemas.ProviderConfig{NetworkConfig: netCfg}
	cfg.CheckAndSetDefaults()
	assert.Equal(t, schemas.DefaultMaxConnsPerHost, cfg.NetworkConfig.MaxConnsPerHost)

	bedrock, err := NewBedrockProvider(cfg, nil)
	require.NoError(t, err)

	tr, isHTTPTransport := bedrock.client.Transport.(*http.Transport)
	require.True(t, isHTTPTransport)

	assert.Equal(t, schemas.DefaultMaxConnsPerHost, tr.MaxConnsPerHost)
	assert.Equal(t, schemas.DefaultMaxIdleConnsPerHost, tr.MaxIdleConnsPerHost)
	assert.Equal(t, schemas.DefaultMaxIdleConnsPerHost, tr.MaxIdleConns)
}
// TestBedrockTransportTLSInsecureSkipVerify checks that InsecureSkipVerify
// yields a custom TLS client config with verification disabled and a TLS 1.2
// minimum version, and that ForceAttemptHTTP2 stays on when EnforceHTTP2 is
// also requested.
func TestBedrockTransportTLSInsecureSkipVerify(t *testing.T) {
	netCfg := schemas.NetworkConfig{
		DefaultRequestTimeoutInSeconds: 30,
		InsecureSkipVerify:             true,
		EnforceHTTP2:                   true,
	}
	cfg := &schemas.ProviderConfig{NetworkConfig: netCfg}
	cfg.CheckAndSetDefaults()

	bedrock, err := NewBedrockProvider(cfg, nil)
	require.NoError(t, err)

	tr, isHTTPTransport := bedrock.client.Transport.(*http.Transport)
	require.True(t, isHTTPTransport)

	tlsCfg := tr.TLSClientConfig
	require.NotNil(t, tlsCfg)
	assert.True(t, tlsCfg.InsecureSkipVerify)
	assert.Equal(t, uint16(tls.VersionTLS12), tlsCfg.MinVersion)

	// A custom TLS config must not turn off the HTTP/2 attempt.
	assert.True(t, tr.ForceAttemptHTTP2)
}
// TestBedrockTransportTLSCACert checks that supplying a PEM CA certificate
// yields a custom TLS client config with a non-nil RootCAs pool and a TLS 1.2
// minimum version, and that ForceAttemptHTTP2 stays on when EnforceHTTP2 is
// also requested.
func TestBedrockTransportTLSCACert(t *testing.T) {
	caPEM := generateTestCACert(t)

	netCfg := schemas.NetworkConfig{
		DefaultRequestTimeoutInSeconds: 30,
		CACertPEM:                      caPEM,
		EnforceHTTP2:                   true,
	}
	cfg := &schemas.ProviderConfig{NetworkConfig: netCfg}
	cfg.CheckAndSetDefaults()

	bedrock, err := NewBedrockProvider(cfg, nil)
	require.NoError(t, err)

	tr, isHTTPTransport := bedrock.client.Transport.(*http.Transport)
	require.True(t, isHTTPTransport)

	tlsCfg := tr.TLSClientConfig
	require.NotNil(t, tlsCfg)
	assert.NotNil(t, tlsCfg.RootCAs)
	assert.Equal(t, uint16(tls.VersionTLS12), tlsCfg.MinVersion)
	assert.True(t, tr.ForceAttemptHTTP2)
}
// TestBedrockTransportDefaultTLS checks the baseline: with no TLS settings and
// EnforceHTTP2 unset, the transport has no custom TLSClientConfig and
// ForceAttemptHTTP2 is false.
func TestBedrockTransportDefaultTLS(t *testing.T) {
	// Only the timeout is configured; every TLS-related field keeps its zero value.
	netCfg := schemas.NetworkConfig{DefaultRequestTimeoutInSeconds: 30}
	cfg := &schemas.ProviderConfig{NetworkConfig: netCfg}
	cfg.CheckAndSetDefaults()

	bedrock, err := NewBedrockProvider(cfg, nil)
	require.NoError(t, err)

	tr, isHTTPTransport := bedrock.client.Transport.(*http.Transport)
	require.True(t, isHTTPTransport)

	// No custom TLS config is expected.
	assert.Nil(t, tr.TLSClientConfig)
	// EnforceHTTP2 was not requested, so the HTTP/2 attempt must be off.
	assert.False(t, tr.ForceAttemptHTTP2)
}
// TestBedrockTransportEnforceHTTP2 checks that EnforceHTTP2=true turns on
// ForceAttemptHTTP2 and leaves TLSNextProto nil.
func TestBedrockTransportEnforceHTTP2(t *testing.T) {
	netCfg := schemas.NetworkConfig{
		DefaultRequestTimeoutInSeconds: 30,
		EnforceHTTP2:                   true,
	}
	cfg := &schemas.ProviderConfig{NetworkConfig: netCfg}
	cfg.CheckAndSetDefaults()

	bedrock, err := NewBedrockProvider(cfg, nil)
	require.NoError(t, err)

	tr, isHTTPTransport := bedrock.client.Transport.(*http.Transport)
	require.True(t, isHTTPTransport)

	assert.True(t, tr.ForceAttemptHTTP2)
	// TLSNextProto must stay unset when HTTP/2 is enforced so ALPN negotiation is allowed.
	assert.Nil(t, tr.TLSNextProto)
}
// TestBedrockTransportEnforceHTTP2Disabled checks that EnforceHTTP2=false
// turns off ForceAttemptHTTP2 and sets TLSNextProto to a non-nil, empty map.
func TestBedrockTransportEnforceHTTP2Disabled(t *testing.T) {
	netCfg := schemas.NetworkConfig{
		DefaultRequestTimeoutInSeconds: 30,
		EnforceHTTP2:                   false,
	}
	cfg := &schemas.ProviderConfig{NetworkConfig: netCfg}
	cfg.CheckAndSetDefaults()

	bedrock, err := NewBedrockProvider(cfg, nil)
	require.NoError(t, err)

	tr, isHTTPTransport := bedrock.client.Transport.(*http.Transport)
	require.True(t, isHTTPTransport)

	assert.False(t, tr.ForceAttemptHTTP2)
	// An empty (but non-nil) TLSNextProto map is what truly disables HTTP/2 ALPN negotiation.
	assert.NotNil(t, tr.TLSNextProto)
	assert.Empty(t, tr.TLSNextProto)
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,405 @@
// Package cerebras implements the Cerebras LLM provider.
package cerebras
import (
"context"
"strings"
"time"
"github.com/maximhq/bifrost/core/providers/openai"
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
schemas "github.com/maximhq/bifrost/core/schemas"
"github.com/valyala/fasthttp"
)
// CerebrasProvider implements the Provider interface for Cerebras's API.
// Supported operations are forwarded to the shared OpenAI-compatible handlers
// in the openai package; the rest return an unsupported-operation error.
// Instances are built by NewCerebrasProvider, which fills every field below.
type CerebrasProvider struct {
logger schemas.Logger // Logger for provider operations
client *fasthttp.Client // HTTP client for unary API requests (ReadTimeout bounds overall response)
streamingClient *fasthttp.Client // HTTP client for streaming API requests (no ReadTimeout; idle governed by NewIdleTimeoutReader)
networkConfig schemas.NetworkConfig // Network configuration: normalized BaseURL and extra headers
sendBackRawRequest bool // Provider-level default for including the raw request in responses
sendBackRawResponse bool // Provider-level default for including the raw response in responses
}
// NewCerebrasProvider builds a CerebrasProvider from config.
//
// It applies the config defaults, creates a fasthttp client whose read, write
// and connection-wait timeouts all equal DefaultRequestTimeoutInSeconds, runs
// that client through the proxy, dialer and TLS configuration helpers, and
// derives the streaming client from it with BuildStreamingClient.
//
// An empty BaseURL is replaced with "https://api.cerebras.ai" and trailing
// slashes are trimmed; the normalized value is written back to
// config.NetworkConfig. This function itself always returns a nil error.
func NewCerebrasProvider(config *schemas.ProviderConfig, logger schemas.Logger) (*CerebrasProvider, error) {
	config.CheckAndSetDefaults()

	timeout := time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds) * time.Second
	maxConnLifetime := time.Duration(schemas.DefaultMaxConnDurationInSeconds) * time.Second

	unaryClient := &fasthttp.Client{
		ReadTimeout:         timeout,
		WriteTimeout:        timeout,
		MaxConnsPerHost:     config.NetworkConfig.MaxConnsPerHost,
		MaxIdleConnDuration: 30 * time.Second,
		MaxConnWaitTimeout:  timeout,
		MaxConnDuration:     maxConnLifetime,
		ConnPoolStrategy:    fasthttp.FIFO,
	}

	// Layer proxy, dialer and TLS settings onto the base client, in that order.
	unaryClient = providerUtils.ConfigureProxy(unaryClient, config.ProxyConfig, logger)
	unaryClient = providerUtils.ConfigureDialer(unaryClient)
	unaryClient = providerUtils.ConfigureTLS(unaryClient, config.NetworkConfig, logger)
	streamClient := providerUtils.BuildStreamingClient(unaryClient)

	// Fall back to the public endpoint, then strip trailing slashes so request
	// paths can be appended directly.
	baseURL := config.NetworkConfig.BaseURL
	if baseURL == "" {
		baseURL = "https://api.cerebras.ai"
	}
	config.NetworkConfig.BaseURL = strings.TrimRight(baseURL, "/")

	provider := &CerebrasProvider{
		logger:              logger,
		client:              unaryClient,
		streamingClient:     streamClient,
		networkConfig:       config.NetworkConfig,
		sendBackRawRequest:  config.SendBackRawRequest,
		sendBackRawResponse: config.SendBackRawResponse,
	}
	return provider, nil
}
// GetProviderKey returns the provider identifier for Cerebras.
// The value is always the constant schemas.Cerebras; no receiver state is read.
func (provider *CerebrasProvider) GetProviderKey() schemas.ModelProvider {
return schemas.Cerebras
}
// ListModels lists the models available from Cerebras's API.
// The request is delegated to the shared OpenAI-compatible list-models
// handler. The endpoint is the configured base URL plus the path resolved by
// GetPathFromContext with "/v1/models" as the default.
func (provider *CerebrasProvider) ListModels(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) {
	endpoint := provider.networkConfig.BaseURL + providerUtils.GetPathFromContext(ctx, "/v1/models")
	providerKey := provider.GetProviderKey()
	includeRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest)
	includeRawResponse := providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse)

	return openai.HandleOpenAIListModelsRequest(
		ctx,
		provider.client,
		request,
		endpoint,
		keys,
		provider.networkConfig.ExtraHeaders,
		providerKey,
		includeRawRequest,
		includeRawResponse,
	)
}
// TextCompletion performs a unary text completion request against Cerebras's API.
// The request is delegated to the shared OpenAI-compatible text-completion
// handler. The endpoint is the configured base URL plus the path resolved by
// GetPathFromContext with "/v1/completions" as the default.
// Returns the completion response, or a BifrostError if the request fails.
func (provider *CerebrasProvider) TextCompletion(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (*schemas.BifrostTextCompletionResponse, *schemas.BifrostError) {
	endpoint := provider.networkConfig.BaseURL + providerUtils.GetPathFromContext(ctx, "/v1/completions")
	providerKey := provider.GetProviderKey()
	includeRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest)
	includeRawResponse := providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse)

	return openai.HandleOpenAITextCompletionRequest(
		ctx,
		provider.client,
		endpoint,
		request,
		key,
		provider.networkConfig.ExtraHeaders,
		providerKey,
		includeRawRequest,
		includeRawResponse,
		nil,
		nil,
		provider.logger,
	)
}
// TextCompletionStream performs a streaming text completion request against
// Cerebras's API using the shared OpenAI-compatible streaming handler and the
// provider's streaming client.
//
// The endpoint is the configured base URL plus the path resolved by
// GetPathFromContext with "/v1/completions" as the default, matching the unary
// TextCompletion method so a path set on the context applies to both.
//
// An "Authorization: Bearer <key>" header is sent only when the key value is
// non-empty. Returns a channel of BifrostStreamChunk values, or a BifrostError
// if the request cannot be started.
func (provider *CerebrasProvider) TextCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostTextCompletionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
	// Leave authHeader nil when no key is configured so no empty bearer token is sent.
	var authHeader map[string]string
	if apiKey := key.Value.GetValue(); apiKey != "" {
		authHeader = map[string]string{"Authorization": "Bearer " + apiKey}
	}

	// Use shared OpenAI-compatible streaming logic.
	return openai.HandleOpenAITextCompletionStreaming(
		ctx,
		provider.streamingClient,
		provider.networkConfig.BaseURL+providerUtils.GetPathFromContext(ctx, "/v1/completions"),
		request,
		authHeader,
		provider.networkConfig.ExtraHeaders,
		providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest),
		providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse),
		provider.GetProviderKey(),
		nil,
		postHookRunner,
		nil,
		nil,
		provider.logger,
		postHookSpanFinalizer,
	)
}
// ChatCompletion performs a unary chat completion request against the Cerebras API.
// The request is delegated to the shared OpenAI-compatible chat-completion
// handler. The endpoint is the configured base URL plus the path resolved by
// GetPathFromContext with "/v1/chat/completions" as the default.
func (provider *CerebrasProvider) ChatCompletion(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError) {
	endpoint := provider.networkConfig.BaseURL + providerUtils.GetPathFromContext(ctx, "/v1/chat/completions")
	includeRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest)
	includeRawResponse := providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse)
	providerKey := provider.GetProviderKey()

	return openai.HandleOpenAIChatCompletionRequest(
		ctx,
		provider.client,
		endpoint,
		request,
		key,
		provider.networkConfig.ExtraHeaders,
		includeRawRequest,
		includeRawResponse,
		providerKey,
		nil,
		nil,
		provider.logger,
	)
}
// ChatCompletionStream performs a streaming chat completion request against
// the Cerebras API using the shared OpenAI-compatible streaming handler and
// the provider's streaming client.
//
// The endpoint is the configured base URL plus the path resolved by
// GetPathFromContext with "/v1/chat/completions" as the default, matching the
// unary ChatCompletion method so a path set on the context applies to both.
// The provider identifier is taken from GetProviderKey, like every other
// method on this type.
//
// An "Authorization: Bearer <key>" header is sent only when the key value is
// non-empty. Returns a channel of BifrostStreamChunk values representing the
// stream, or a BifrostError if the request cannot be started.
func (provider *CerebrasProvider) ChatCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
	// Leave authHeader nil when no key is configured so no empty bearer token is sent.
	var authHeader map[string]string
	if apiKey := key.Value.GetValue(); apiKey != "" {
		authHeader = map[string]string{"Authorization": "Bearer " + apiKey}
	}

	// Use shared OpenAI-compatible streaming logic.
	return openai.HandleOpenAIChatCompletionStreaming(
		ctx,
		provider.streamingClient,
		provider.networkConfig.BaseURL+providerUtils.GetPathFromContext(ctx, "/v1/chat/completions"),
		request,
		authHeader,
		provider.networkConfig.ExtraHeaders,
		providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest),
		providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse),
		provider.GetProviderKey(),
		postHookRunner,
		nil,
		nil,
		nil,
		nil,
		nil,
		provider.logger,
		postHookSpanFinalizer,
	)
}
// Responses serves a responses request by converting it to a chat request,
// running it through ChatCompletion, and converting the chat response back
// with ToBifrostResponsesResponse. Any error from ChatCompletion is returned
// unchanged.
func (provider *CerebrasProvider) Responses(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) {
	chatRequest := request.ToChatRequest()

	chatResult, bifrostErr := provider.ChatCompletion(ctx, key, chatRequest)
	if bifrostErr != nil {
		return nil, bifrostErr
	}

	return chatResult.ToBifrostResponsesResponse(), nil
}
// ResponsesStream serves a streaming responses request through the chat
// completion stream. It first sets
// BifrostContextKeyIsResponsesToChatCompletionFallback to true on ctx, then
// converts the request with ToChatRequest and delegates to ChatCompletionStream.
func (provider *CerebrasProvider) ResponsesStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
	// Flag the context before delegating to the chat stream.
	ctx.SetValue(schemas.BifrostContextKeyIsResponsesToChatCompletionFallback, true)

	chatRequest := request.ToChatRequest()
	return provider.ChatCompletionStream(ctx, postHookRunner, postHookSpanFinalizer, key, chatRequest)
}
// Embedding is not supported by the Cerebras provider.
// It ignores ctx, key and request and always returns a nil response together
// with the error built by providerUtils.NewUnsupportedOperationError for
// schemas.EmbeddingRequest and this provider's key.
func (provider *CerebrasProvider) Embedding(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostEmbeddingRequest) (*schemas.BifrostEmbeddingResponse, *schemas.BifrostError) {
return nil, providerUtils.NewUnsupportedOperationError(schemas.EmbeddingRequest, provider.GetProviderKey())
}
// Speech is not supported by the Cerebras provider.
// It ignores ctx, key and request and always returns a nil response together
// with the error built by providerUtils.NewUnsupportedOperationError for
// schemas.SpeechRequest and this provider's key.
func (provider *CerebrasProvider) Speech(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostSpeechRequest) (*schemas.BifrostSpeechResponse, *schemas.BifrostError) {
return nil, providerUtils.NewUnsupportedOperationError(schemas.SpeechRequest, provider.GetProviderKey())
}
// Rerank is not supported by the Cerebras provider.
// It ignores ctx, key and request and always returns a nil response together
// with the error built by providerUtils.NewUnsupportedOperationError for
// schemas.RerankRequest and this provider's key.
func (provider *CerebrasProvider) Rerank(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostRerankRequest) (*schemas.BifrostRerankResponse, *schemas.BifrostError) {
return nil, providerUtils.NewUnsupportedOperationError(schemas.RerankRequest, provider.GetProviderKey())
}
// OCR is not supported by the Cerebras provider.
// It ignores ctx, key and request and always returns a nil response together
// with the error built by providerUtils.NewUnsupportedOperationError for
// schemas.OCRRequest and this provider's key.
func (provider *CerebrasProvider) OCR(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostOCRRequest) (*schemas.BifrostOCRResponse, *schemas.BifrostError) {
return nil, providerUtils.NewUnsupportedOperationError(schemas.OCRRequest, provider.GetProviderKey())
}
// SpeechStream is not supported by the Cerebras provider.
// It ignores every argument and always returns a nil channel together with
// the error built by providerUtils.NewUnsupportedOperationError for
// schemas.SpeechStreamRequest and this provider's key.
func (provider *CerebrasProvider) SpeechStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostSpeechRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
return nil, providerUtils.NewUnsupportedOperationError(schemas.SpeechStreamRequest, provider.GetProviderKey())
}
// Transcription is not supported by the Cerebras provider.
// It ignores ctx, key and request and always returns a nil response together
// with the error built by providerUtils.NewUnsupportedOperationError for
// schemas.TranscriptionRequest and this provider's key.
func (provider *CerebrasProvider) Transcription(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (*schemas.BifrostTranscriptionResponse, *schemas.BifrostError) {
return nil, providerUtils.NewUnsupportedOperationError(schemas.TranscriptionRequest, provider.GetProviderKey())
}
// TranscriptionStream is not supported by the Cerebras provider.
// It ignores every argument and always returns a nil channel together with
// the error built by providerUtils.NewUnsupportedOperationError for
// schemas.TranscriptionStreamRequest and this provider's key.
func (provider *CerebrasProvider) TranscriptionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
return nil, providerUtils.NewUnsupportedOperationError(schemas.TranscriptionStreamRequest, provider.GetProviderKey())
}
// ImageGeneration is not supported by the Cerebras provider.
// It ignores ctx, key and request and always returns a nil response together
// with the error built by providerUtils.NewUnsupportedOperationError for
// schemas.ImageGenerationRequest and this provider's key.
func (provider *CerebrasProvider) ImageGeneration(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostImageGenerationRequest) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) {
return nil, providerUtils.NewUnsupportedOperationError(schemas.ImageGenerationRequest, provider.GetProviderKey())
}
// ImageGenerationStream is not supported by the Cerebras provider.
// It ignores every argument and always returns a nil channel together with
// the error built by providerUtils.NewUnsupportedOperationError for
// schemas.ImageGenerationStreamRequest and this provider's key.
func (provider *CerebrasProvider) ImageGenerationStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostImageGenerationRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
return nil, providerUtils.NewUnsupportedOperationError(schemas.ImageGenerationStreamRequest, provider.GetProviderKey())
}
// ImageEdit is not supported by the Cerebras provider.
// It ignores ctx, key and request and always returns a nil response together
// with the error built by providerUtils.NewUnsupportedOperationError for
// schemas.ImageEditRequest and this provider's key.
func (provider *CerebrasProvider) ImageEdit(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostImageEditRequest) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) {
return nil, providerUtils.NewUnsupportedOperationError(schemas.ImageEditRequest, provider.GetProviderKey())
}
// ImageEditStream is not supported by the Cerebras provider.
// It ignores every argument and always returns a nil channel together with
// the error built by providerUtils.NewUnsupportedOperationError for
// schemas.ImageEditStreamRequest and this provider's key.
func (provider *CerebrasProvider) ImageEditStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostImageEditRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
return nil, providerUtils.NewUnsupportedOperationError(schemas.ImageEditStreamRequest, provider.GetProviderKey())
}
// ImageVariation is not supported by the Cerebras provider.
// It ignores ctx, key and request and always returns a nil response together
// with the error built by providerUtils.NewUnsupportedOperationError for
// schemas.ImageVariationRequest and this provider's key.
func (provider *CerebrasProvider) ImageVariation(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostImageVariationRequest) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) {
return nil, providerUtils.NewUnsupportedOperationError(schemas.ImageVariationRequest, provider.GetProviderKey())
}
// VideoGeneration is not supported by the Cerebras provider.
// All parameters are discarded; it always returns a nil response together
// with the error built by providerUtils.NewUnsupportedOperationError for
// schemas.VideoGenerationRequest and this provider's key.
func (provider *CerebrasProvider) VideoGeneration(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostVideoGenerationRequest) (*schemas.BifrostVideoGenerationResponse, *schemas.BifrostError) {
return nil, providerUtils.NewUnsupportedOperationError(schemas.VideoGenerationRequest, provider.GetProviderKey())
}
// VideoRetrieve is not supported by the Cerebras provider.
// All parameters are discarded; it always returns a nil response together
// with the error built by providerUtils.NewUnsupportedOperationError for
// schemas.VideoRetrieveRequest and this provider's key.
func (provider *CerebrasProvider) VideoRetrieve(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostVideoRetrieveRequest) (*schemas.BifrostVideoGenerationResponse, *schemas.BifrostError) {
return nil, providerUtils.NewUnsupportedOperationError(schemas.VideoRetrieveRequest, provider.GetProviderKey())
}
// VideoDownload is not supported by the Cerebras provider.
// All parameters are discarded; it always returns a nil response together
// with the error built by providerUtils.NewUnsupportedOperationError for
// schemas.VideoDownloadRequest and this provider's key.
func (provider *CerebrasProvider) VideoDownload(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostVideoDownloadRequest) (*schemas.BifrostVideoDownloadResponse, *schemas.BifrostError) {
return nil, providerUtils.NewUnsupportedOperationError(schemas.VideoDownloadRequest, provider.GetProviderKey())
}
// VideoDelete is not supported by the Cerebras provider.
// All parameters are discarded; it always returns a nil response together
// with the error built by providerUtils.NewUnsupportedOperationError for
// schemas.VideoDeleteRequest and this provider's key.
func (provider *CerebrasProvider) VideoDelete(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostVideoDeleteRequest) (*schemas.BifrostVideoDeleteResponse, *schemas.BifrostError) {
return nil, providerUtils.NewUnsupportedOperationError(schemas.VideoDeleteRequest, provider.GetProviderKey())
}
// VideoList is not supported by the Cerebras provider.
// All parameters are discarded; it always returns a nil response together
// with the error built by providerUtils.NewUnsupportedOperationError for
// schemas.VideoListRequest and this provider's key.
func (provider *CerebrasProvider) VideoList(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostVideoListRequest) (*schemas.BifrostVideoListResponse, *schemas.BifrostError) {
return nil, providerUtils.NewUnsupportedOperationError(schemas.VideoListRequest, provider.GetProviderKey())
}
// VideoRemix is not supported by the Cerebras provider.
// All parameters are discarded; it always returns a nil response together
// with the error built by providerUtils.NewUnsupportedOperationError for
// schemas.VideoRemixRequest and this provider's key.
func (provider *CerebrasProvider) VideoRemix(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostVideoRemixRequest) (*schemas.BifrostVideoGenerationResponse, *schemas.BifrostError) {
return nil, providerUtils.NewUnsupportedOperationError(schemas.VideoRemixRequest, provider.GetProviderKey())
}
// FileUpload is not supported by the Cerebras provider.
// All parameters are discarded; it always returns a nil response together
// with the error built by providerUtils.NewUnsupportedOperationError for
// schemas.FileUploadRequest and this provider's key.
func (provider *CerebrasProvider) FileUpload(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostFileUploadRequest) (*schemas.BifrostFileUploadResponse, *schemas.BifrostError) {
return nil, providerUtils.NewUnsupportedOperationError(schemas.FileUploadRequest, provider.GetProviderKey())
}
// FileList is not supported by the Cerebras provider.
// All parameters are discarded; it always returns a nil response together
// with the error built by providerUtils.NewUnsupportedOperationError for
// schemas.FileListRequest and this provider's key.
func (provider *CerebrasProvider) FileList(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostFileListRequest) (*schemas.BifrostFileListResponse, *schemas.BifrostError) {
return nil, providerUtils.NewUnsupportedOperationError(schemas.FileListRequest, provider.GetProviderKey())
}
// FileRetrieve is not supported by the Cerebras provider.
// All parameters are discarded; it always returns a nil response together
// with the error built by providerUtils.NewUnsupportedOperationError for
// schemas.FileRetrieveRequest and this provider's key.
func (provider *CerebrasProvider) FileRetrieve(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostFileRetrieveRequest) (*schemas.BifrostFileRetrieveResponse, *schemas.BifrostError) {
return nil, providerUtils.NewUnsupportedOperationError(schemas.FileRetrieveRequest, provider.GetProviderKey())
}
// FileDelete is not supported by the Cerebras provider.
// All parameters are discarded; it always returns a nil response together
// with the error built by providerUtils.NewUnsupportedOperationError for
// schemas.FileDeleteRequest and this provider's key.
func (provider *CerebrasProvider) FileDelete(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostFileDeleteRequest) (*schemas.BifrostFileDeleteResponse, *schemas.BifrostError) {
return nil, providerUtils.NewUnsupportedOperationError(schemas.FileDeleteRequest, provider.GetProviderKey())
}
// FileContent is not supported by the Cerebras provider.
// All parameters are discarded; it always returns a nil response together
// with the error built by providerUtils.NewUnsupportedOperationError for
// schemas.FileContentRequest and this provider's key.
func (provider *CerebrasProvider) FileContent(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostFileContentRequest) (*schemas.BifrostFileContentResponse, *schemas.BifrostError) {
return nil, providerUtils.NewUnsupportedOperationError(schemas.FileContentRequest, provider.GetProviderKey())
}
// BatchCreate is not supported by the Cerebras provider.
// All parameters are discarded; it always returns a nil response together
// with the error built by providerUtils.NewUnsupportedOperationError for
// schemas.BatchCreateRequest and this provider's key.
func (provider *CerebrasProvider) BatchCreate(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostBatchCreateRequest) (*schemas.BifrostBatchCreateResponse, *schemas.BifrostError) {
return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchCreateRequest, provider.GetProviderKey())
}
// BatchList is not supported by the Cerebras provider.
// All parameters are discarded; it always returns a nil response together
// with the error built by providerUtils.NewUnsupportedOperationError for
// schemas.BatchListRequest and this provider's key.
func (provider *CerebrasProvider) BatchList(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchListRequest) (*schemas.BifrostBatchListResponse, *schemas.BifrostError) {
return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchListRequest, provider.GetProviderKey())
}
// BatchRetrieve is not supported by the Cerebras provider. It always returns a nil
// response together with an unsupported-operation error for schemas.BatchRetrieveRequest.
func (provider *CerebrasProvider) BatchRetrieve(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchRetrieveRequest) (*schemas.BifrostBatchRetrieveResponse, *schemas.BifrostError) {
	providerKey := provider.GetProviderKey()
	return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchRetrieveRequest, providerKey)
}
// BatchCancel is not supported by the Cerebras provider. It always returns a nil
// response together with an unsupported-operation error for schemas.BatchCancelRequest.
func (provider *CerebrasProvider) BatchCancel(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchCancelRequest) (*schemas.BifrostBatchCancelResponse, *schemas.BifrostError) {
	providerKey := provider.GetProviderKey()
	return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchCancelRequest, providerKey)
}
// BatchDelete is not supported by the Cerebras provider. It always returns a nil
// response together with an unsupported-operation error for schemas.BatchDeleteRequest.
func (provider *CerebrasProvider) BatchDelete(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchDeleteRequest) (*schemas.BifrostBatchDeleteResponse, *schemas.BifrostError) {
	providerKey := provider.GetProviderKey()
	return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchDeleteRequest, providerKey)
}
// BatchResults is not supported by the Cerebras provider. It always returns a nil
// response together with an unsupported-operation error for schemas.BatchResultsRequest.
func (provider *CerebrasProvider) BatchResults(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchResultsRequest) (*schemas.BifrostBatchResultsResponse, *schemas.BifrostError) {
	providerKey := provider.GetProviderKey()
	return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchResultsRequest, providerKey)
}
// CountTokens is not supported by the Cerebras provider. It always returns a nil
// response together with an unsupported-operation error for schemas.CountTokensRequest.
func (provider *CerebrasProvider) CountTokens(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostResponsesRequest) (*schemas.BifrostCountTokensResponse, *schemas.BifrostError) {
	providerKey := provider.GetProviderKey()
	return nil, providerUtils.NewUnsupportedOperationError(schemas.CountTokensRequest, providerKey)
}
// ContainerCreate is not supported by the Cerebras provider. It always returns a nil
// response together with an unsupported-operation error for schemas.ContainerCreateRequest.
func (provider *CerebrasProvider) ContainerCreate(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostContainerCreateRequest) (*schemas.BifrostContainerCreateResponse, *schemas.BifrostError) {
	providerKey := provider.GetProviderKey()
	return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerCreateRequest, providerKey)
}
// ContainerList is not supported by the Cerebras provider. It always returns a nil
// response together with an unsupported-operation error for schemas.ContainerListRequest.
func (provider *CerebrasProvider) ContainerList(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostContainerListRequest) (*schemas.BifrostContainerListResponse, *schemas.BifrostError) {
	providerKey := provider.GetProviderKey()
	return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerListRequest, providerKey)
}
// ContainerRetrieve is not supported by the Cerebras provider. It always returns a nil
// response together with an unsupported-operation error for schemas.ContainerRetrieveRequest.
func (provider *CerebrasProvider) ContainerRetrieve(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostContainerRetrieveRequest) (*schemas.BifrostContainerRetrieveResponse, *schemas.BifrostError) {
	providerKey := provider.GetProviderKey()
	return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerRetrieveRequest, providerKey)
}
// ContainerDelete is not supported by the Cerebras provider. It always returns a nil
// response together with an unsupported-operation error for schemas.ContainerDeleteRequest.
func (provider *CerebrasProvider) ContainerDelete(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostContainerDeleteRequest) (*schemas.BifrostContainerDeleteResponse, *schemas.BifrostError) {
	providerKey := provider.GetProviderKey()
	return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerDeleteRequest, providerKey)
}
// ContainerFileCreate is not supported by the Cerebras provider. It always returns a nil
// response together with an unsupported-operation error for schemas.ContainerFileCreateRequest.
func (provider *CerebrasProvider) ContainerFileCreate(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostContainerFileCreateRequest) (*schemas.BifrostContainerFileCreateResponse, *schemas.BifrostError) {
	providerKey := provider.GetProviderKey()
	return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerFileCreateRequest, providerKey)
}
// ContainerFileList is not supported by the Cerebras provider. It always returns a nil
// response together with an unsupported-operation error for schemas.ContainerFileListRequest.
func (provider *CerebrasProvider) ContainerFileList(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostContainerFileListRequest) (*schemas.BifrostContainerFileListResponse, *schemas.BifrostError) {
	providerKey := provider.GetProviderKey()
	return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerFileListRequest, providerKey)
}
// ContainerFileRetrieve is not supported by the Cerebras provider. It always returns a nil
// response together with an unsupported-operation error for schemas.ContainerFileRetrieveRequest.
func (provider *CerebrasProvider) ContainerFileRetrieve(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostContainerFileRetrieveRequest) (*schemas.BifrostContainerFileRetrieveResponse, *schemas.BifrostError) {
	providerKey := provider.GetProviderKey()
	return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerFileRetrieveRequest, providerKey)
}
// ContainerFileContent is not supported by the Cerebras provider. It always returns a nil
// response together with an unsupported-operation error for schemas.ContainerFileContentRequest.
func (provider *CerebrasProvider) ContainerFileContent(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostContainerFileContentRequest) (*schemas.BifrostContainerFileContentResponse, *schemas.BifrostError) {
	providerKey := provider.GetProviderKey()
	return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerFileContentRequest, providerKey)
}
// ContainerFileDelete is not supported by the Cerebras provider. It always returns a nil
// response together with an unsupported-operation error for schemas.ContainerFileDeleteRequest.
func (provider *CerebrasProvider) ContainerFileDelete(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostContainerFileDeleteRequest) (*schemas.BifrostContainerFileDeleteResponse, *schemas.BifrostError) {
	providerKey := provider.GetProviderKey()
	return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerFileDeleteRequest, providerKey)
}
// Passthrough is not supported by the Cerebras provider. It always returns a nil
// response together with an unsupported-operation error for schemas.PassthroughRequest.
func (provider *CerebrasProvider) Passthrough(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostPassthroughRequest) (*schemas.BifrostPassthroughResponse, *schemas.BifrostError) {
	providerKey := provider.GetProviderKey()
	return nil, providerUtils.NewUnsupportedOperationError(schemas.PassthroughRequest, providerKey)
}
// PassthroughStream is not supported by the Cerebras provider. It always returns a nil
// channel together with an unsupported-operation error for schemas.PassthroughStreamRequest.
func (provider *CerebrasProvider) PassthroughStream(_ *schemas.BifrostContext, _ schemas.PostHookRunner, _ func(context.Context), _ schemas.Key, _ *schemas.BifrostPassthroughRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
	providerKey := provider.GetProviderKey()
	return nil, providerUtils.NewUnsupportedOperationError(schemas.PassthroughStreamRequest, providerKey)
}

View File

@@ -0,0 +1,60 @@
package cerebras_test
import (
"os"
"strings"
"testing"
"github.com/maximhq/bifrost/core/internal/llmtests"
"github.com/maximhq/bifrost/core/schemas"
)
// TestCerebras runs the shared comprehensive provider suite against Cerebras.
// The test is skipped when CEREBRAS_API_KEY is unset or blank.
func TestCerebras(t *testing.T) {
	t.Parallel()

	if apiKey := strings.TrimSpace(os.Getenv("CEREBRAS_API_KEY")); apiKey == "" {
		t.Skip("Skipping Cerebras tests because CEREBRAS_API_KEY is not set")
	}

	client, ctx, cancel, err := llmtests.SetupTest()
	if err != nil {
		t.Fatalf("Error initializing test setup: %v", err)
	}
	// Deferred calls run in reverse order: the client shuts down before the
	// context is cancelled.
	defer cancel()
	defer client.Shutdown()

	// Scenarios exercised by the suite for this provider.
	scenarios := llmtests.TestScenarios{
		TextCompletion:        true,
		TextCompletionStream:  true,
		SimpleChat:            true,
		CompletionStream:      true,
		MultiTurnConversation: true,
		ToolCalls:             true,
		ToolCallsStreaming:    true,
		MultipleToolCalls:     false, // llama3.1-8b doesn't reliably produce parallel tool calls
		End2EndToolCalling:    true,
		AutomaticFunctionCall: true,
		ImageURL:              false,
		ImageBase64:           false,
		MultipleImages:        false,
		CompleteEnd2End:       true,
		Embedding:             false,
		ListModels:            true,
		Reasoning:             true,
	}

	testConfig := llmtests.ComprehensiveTestConfig{
		Provider:  schemas.Cerebras,
		ChatModel: "llama3.1-8b",
		Fallbacks: []schemas.Fallback{
			{Provider: schemas.Cerebras, Model: "llama3.1-8b"},
			{Provider: schemas.Cerebras, Model: "gpt-oss-120b"},
		},
		TextModel:      "llama3.1-8b",
		EmbeddingModel: "", // Cerebras doesn't support embedding
		ReasoningModel: "gpt-oss-120b",
		Scenarios:      scenarios,
	}

	t.Run("CerebrasTests", func(t *testing.T) {
		llmtests.RunAllComprehensiveTests(t, client, ctx, testConfig)
	})
}

View File

@@ -0,0 +1,722 @@
package cohere
import (
"fmt"
"time"
"github.com/maximhq/bifrost/core/providers/anthropic"
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
"github.com/maximhq/bifrost/core/schemas"
)
// ToCohereChatCompletionRequest converts a Bifrost chat request into a Cohere v2 chat request.
//
// Messages, sampling parameters, reasoning settings, response format, extra
// params, tools and tool choice are carried over. An error is returned when
// bifrostReq or its Input is nil, or when a reasoning effort cannot be
// translated into a token budget.
//
// The caller's Params.ExtraParams map (and its nested "thinking" map, if any)
// is shallow-copied before recognised keys are removed from it, so the input
// request is left untouched and converting it again yields the same result.
func ToCohereChatCompletionRequest(bifrostReq *schemas.BifrostChatRequest) (*CohereChatRequest, error) {
	if bifrostReq == nil || bifrostReq.Input == nil {
		return nil, fmt.Errorf("bifrost request is nil")
	}

	cohereReq := &CohereChatRequest{
		Model: bifrostReq.Model,
	}

	// Convert messages to Cohere v2 format
	var cohereMessages []CohereMessage
	for _, msg := range bifrostReq.Input {
		cohereMsg := CohereMessage{
			Role: string(msg.Role),
		}

		// Convert content: a plain string takes precedence over content blocks.
		if msg.Content != nil && msg.Content.ContentStr != nil {
			cohereMsg.Content = NewStringContent(*msg.Content.ContentStr)
		} else if msg.Content != nil && msg.Content.ContentBlocks != nil {
			var contentBlocks []CohereContentBlock
			for _, block := range msg.Content.ContentBlocks {
				if block.Text != nil {
					contentBlocks = append(contentBlocks, CohereContentBlock{
						Type: CohereContentBlockTypeText,
						Text: block.Text,
					})
				} else if block.ImageURLStruct != nil {
					contentBlocks = append(contentBlocks, CohereContentBlock{
						Type: CohereContentBlockTypeImage,
						ImageURL: &CohereImageURL{
							URL: block.ImageURLStruct.URL,
						},
					})
				}
			}
			if len(contentBlocks) > 0 {
				cohereMsg.Content = NewBlocksContent(contentBlocks)
			}
		}

		// Convert tool calls for assistant messages
		if msg.ChatAssistantMessage != nil && msg.ChatAssistantMessage.ToolCalls != nil {
			var toolCalls []CohereToolCall
			for _, toolCall := range msg.ChatAssistantMessage.ToolCalls {
				// Use an empty name when none is provided.
				functionName := toolCall.Function.Name
				if functionName == nil {
					functionName = schemas.Ptr("")
				}

				// Default to "{}" if empty to ensure the field is always present.
				functionArguments := toolCall.Function.Arguments
				if functionArguments == "" {
					functionArguments = "{}"
				}

				toolCalls = append(toolCalls, CohereToolCall{
					ID:   toolCall.ID,
					Type: "function",
					Function: &CohereFunction{
						Name:      functionName,
						Arguments: functionArguments,
					},
				})
			}
			cohereMsg.ToolCalls = toolCalls
		}

		// Convert tool messages
		if msg.ChatToolMessage != nil && msg.ChatToolMessage.ToolCallID != nil {
			cohereMsg.ToolCallID = msg.ChatToolMessage.ToolCallID
		}

		cohereMessages = append(cohereMessages, cohereMsg)
	}
	cohereReq.Messages = cohereMessages

	// Convert parameters
	if bifrostReq.Params != nil {
		cohereReq.MaxTokens = bifrostReq.Params.MaxCompletionTokens
		cohereReq.Temperature = bifrostReq.Params.Temperature
		cohereReq.P = bifrostReq.Params.TopP
		cohereReq.StopSequences = bifrostReq.Params.Stop
		cohereReq.FrequencyPenalty = bifrostReq.Params.FrequencyPenalty
		cohereReq.PresencePenalty = bifrostReq.Params.PresencePenalty

		// Convert reasoning: an explicit token budget wins over an effort level.
		if bifrostReq.Params.Reasoning != nil {
			if bifrostReq.Params.Reasoning.MaxTokens != nil {
				thinking := &CohereThinking{
					Type: ThinkingTypeEnabled,
				}
				if *bifrostReq.Params.Reasoning.MaxTokens == -1 {
					// cohere does not support dynamic reasoning budget like gemini
					// setting it to minimum reasoning budget
					// NOTE(review): this uses the anthropic package's minimum while the
					// effort branch below uses this package's MinimumReasoningMaxTokens —
					// confirm which constant is intended.
					thinking.TokenBudget = schemas.Ptr(anthropic.MinimumReasoningMaxTokens)
				} else {
					thinking.TokenBudget = bifrostReq.Params.Reasoning.MaxTokens
				}
				cohereReq.Thinking = thinking
			} else if bifrostReq.Params.Reasoning.Effort != nil {
				if *bifrostReq.Params.Reasoning.Effort != "none" {
					maxCompletionTokens := providerUtils.GetMaxOutputTokensOrDefault(bifrostReq.Model, DefaultCompletionMaxTokens)
					if bifrostReq.Params.MaxCompletionTokens != nil {
						maxCompletionTokens = *bifrostReq.Params.MaxCompletionTokens
					}
					budgetTokens, err := providerUtils.GetBudgetTokensFromReasoningEffort(*bifrostReq.Params.Reasoning.Effort, MinimumReasoningMaxTokens, maxCompletionTokens)
					if err != nil {
						return nil, err
					}
					cohereReq.Thinking = &CohereThinking{
						Type:        ThinkingTypeEnabled,
						TokenBudget: schemas.Ptr(budgetTokens), // Max tokens for reasoning
					}
				} else {
					cohereReq.Thinking = &CohereThinking{
						Type: ThinkingTypeDisabled,
					}
				}
			}
		}

		// Convert response format
		if bifrostReq.Params.ResponseFormat != nil {
			cohereReq.ResponseFormat = convertResponseFormatToCohere(bifrostReq.Params.ResponseFormat)
		}

		// Convert extra params
		if bifrostReq.Params.ExtraParams != nil {
			// Work on a shallow copy: keys promoted to typed fields are deleted
			// below, and deleting from the caller's map would drop them from the
			// original request.
			extraParams := make(map[string]interface{}, len(bifrostReq.Params.ExtraParams))
			for key, value := range bifrostReq.Params.ExtraParams {
				extraParams[key] = value
			}
			cohereReq.ExtraParams = extraParams

			// Handle thinking parameter; when present as a map it replaces any
			// Thinking derived from Params.Reasoning above.
			if thinkingParam, ok := schemas.SafeExtractFromMap(bifrostReq.Params.ExtraParams, "thinking"); ok {
				if thinkingMap, ok := thinkingParam.(map[string]interface{}); ok {
					// Copy the nested map too, for the same reason as above.
					remaining := make(map[string]interface{}, len(thinkingMap))
					for key, value := range thinkingMap {
						remaining[key] = value
					}

					thinking := &CohereThinking{}
					if typeStr, ok := schemas.SafeExtractString(thinkingMap["type"]); ok {
						delete(remaining, "type")
						thinking.Type = CohereThinkingType(typeStr)
					}
					if tokenBudget, ok := schemas.SafeExtractIntPointer(thinkingMap["token_budget"]); ok {
						delete(remaining, "token_budget")
						thinking.TokenBudget = tokenBudget
					}
					cohereReq.Thinking = thinking
					cohereReq.ExtraParams["thinking"] = remaining
				}
			}

			// Handle other Cohere-specific extra params
			if safetyMode, ok := schemas.SafeExtractStringPointer(bifrostReq.Params.ExtraParams["safety_mode"]); ok {
				delete(cohereReq.ExtraParams, "safety_mode")
				cohereReq.SafetyMode = safetyMode
			}
			if logProbs, ok := schemas.SafeExtractBoolPointer(bifrostReq.Params.ExtraParams["log_probs"]); ok {
				delete(cohereReq.ExtraParams, "log_probs")
				cohereReq.LogProbs = logProbs
			}
			if strictToolChoice, ok := schemas.SafeExtractBoolPointer(bifrostReq.Params.ExtraParams["strict_tool_choice"]); ok {
				delete(cohereReq.ExtraParams, "strict_tool_choice")
				cohereReq.StrictToolChoice = strictToolChoice
			}
		}

		// Convert tools to Cohere-specific format (without "strict" field)
		if bifrostReq.Params.Tools != nil {
			cohereTools := make([]CohereChatRequestTool, len(bifrostReq.Params.Tools))
			for i, tool := range bifrostReq.Params.Tools {
				cohereTools[i] = CohereChatRequestTool{
					Type: string(tool.Type),
				}
				if tool.Function != nil {
					cohereTools[i].Function = CohereChatRequestFunction{
						Name:        tool.Function.Name,
						Description: tool.Function.Description,
						Parameters:  tool.Function.Parameters,
						// Note: No "strict" field - Cohere doesn't support it
					}
				}
			}
			cohereReq.Tools = cohereTools
		}

		// Convert tool choice
		if bifrostReq.Params.ToolChoice != nil {
			toolChoice := bifrostReq.Params.ToolChoice

			if toolChoice.ChatToolChoiceStr != nil {
				// NOTE(review): every string choice other than ChatToolChoiceTypeNone
				// maps to ToolChoiceRequired — confirm this is intended for all values.
				switch schemas.ChatToolChoiceType(*toolChoice.ChatToolChoiceStr) {
				case schemas.ChatToolChoiceTypeNone:
					choice := ToolChoiceNone
					cohereReq.ToolChoice = &choice
				default:
					choice := ToolChoiceRequired
					cohereReq.ToolChoice = &choice
				}
			} else if toolChoice.ChatToolChoiceStruct != nil {
				switch toolChoice.ChatToolChoiceStruct.Type {
				case schemas.ChatToolChoiceTypeFunction:
					choice := ToolChoiceRequired
					cohereReq.ToolChoice = &choice
				default:
					choice := ToolChoiceAuto
					cohereReq.ToolChoice = &choice
				}
			}
		}
	}

	return cohereReq, nil
}
// ToBifrostChatRequest translates a Cohere v2 chat request into the Bifrost
// chat request shape. A nil receiver yields nil. The provider and model are
// obtained from schemas.ParseModelString with Cohere as the default provider.
// Cohere-only settings (safety mode, log probs, strict tool choice, thinking)
// are additionally mirrored into Params.ExtraParams.
func (req *CohereChatRequest) ToBifrostChatRequest(ctx *schemas.BifrostContext) *schemas.BifrostChatRequest {
	if req == nil {
		return nil
	}

	provider, model := schemas.ParseModelString(req.Model, providerUtils.CheckAndSetDefaultProvider(ctx, schemas.Cohere))

	params := &schemas.ChatParameters{}
	bifrostReq := &schemas.BifrostChatRequest{
		Provider: provider,
		Model:    model,
		Params:   params,
	}

	// Messages
	if req.Messages != nil {
		converted := make([]schemas.ChatMessage, len(req.Messages))
		for i := range req.Messages {
			converted[i] = *req.Messages[i].ToBifrostChatMessage()
		}
		bifrostReq.Input = converted
	}

	// Sampling parameters: unset source fields are nil and leave the
	// freshly created (nil) target fields unchanged.
	params.MaxCompletionTokens = req.MaxTokens
	params.Temperature = req.Temperature
	params.TopP = req.P
	params.Stop = req.StopSequences
	params.FrequencyPenalty = req.FrequencyPenalty
	params.PresencePenalty = req.PresencePenalty

	// Reasoning
	if thinking := req.Thinking; thinking != nil {
		reasoning := &schemas.ChatReasoning{}
		if thinking.Type == ThinkingTypeDisabled {
			reasoning.Effort = schemas.Ptr("none")
		} else {
			reasoning.Effort = schemas.Ptr("auto")
			reasoning.MaxTokens = thinking.TokenBudget
		}
		params.Reasoning = reasoning
	}

	// Response format
	if req.ResponseFormat != nil {
		params.ResponseFormat = convertCohereResponseFormatToBifrost(req.ResponseFormat)
	}

	// Tools: every Cohere tool becomes a function tool.
	if req.Tools != nil {
		tools := make([]schemas.ChatTool, len(req.Tools))
		for i, tool := range req.Tools {
			function := &schemas.ChatToolFunction{
				Name:        tool.Function.Name,
				Description: tool.Function.Description,
				Parameters:  convertInterfaceToToolFunctionParameters(tool.Function.Parameters),
			}
			tools[i] = schemas.ChatTool{
				Type:     schemas.ChatToolTypeFunction,
				Function: function,
			}
		}
		params.Tools = tools
	}

	// Tool choice: values other than none/required/auto leave ToolChoice unset.
	if req.ToolChoice != nil {
		var choiceStr string
		recognised := true
		switch *req.ToolChoice {
		case ToolChoiceNone:
			choiceStr = string(schemas.ChatToolChoiceTypeNone)
		case ToolChoiceRequired:
			choiceStr = string(schemas.ChatToolChoiceTypeRequired)
		case ToolChoiceAuto:
			choiceStr = string(schemas.ChatToolChoiceTypeAny)
		default:
			recognised = false
		}
		if recognised {
			params.ToolChoice = &schemas.ChatToolChoice{
				ChatToolChoiceStr: schemas.Ptr(choiceStr),
			}
		}
	}

	// Extra params: only attached when at least one Cohere-specific field is set.
	extras := make(map[string]interface{})
	if req.SafetyMode != nil {
		extras["safety_mode"] = *req.SafetyMode
	}
	if req.LogProbs != nil {
		extras["log_probs"] = *req.LogProbs
	}
	if req.StrictToolChoice != nil {
		extras["strict_tool_choice"] = *req.StrictToolChoice
	}
	if req.Thinking != nil {
		thinkingExtras := map[string]interface{}{
			"type": string(req.Thinking.Type),
		}
		if req.Thinking.TokenBudget != nil {
			thinkingExtras["token_budget"] = *req.Thinking.TokenBudget
		}
		extras["thinking"] = thinkingExtras
	}
	if len(extras) > 0 {
		params.ExtraParams = extras
	}

	return bifrostReq
}
// ToBifrostChatResponse maps a Cohere v2 chat response onto a Bifrost chat
// response with a single choice at index 0. model is copied into the result
// as-is, and Created is set to the current Unix time. A nil receiver yields nil.
func (response *CohereChatResponse) ToBifrostChatResponse(model string) *schemas.BifrostChatResponse {
	if response == nil {
		return nil
	}

	bifrostResponse := &schemas.BifrostChatResponse{
		ID:     response.ID,
		Model:  model,
		Object: "chat.completion",
		Choices: []schemas.BifrostResponseChoice{
			{
				Index:                       0,
				ChatNonStreamResponseChoice: &schemas.ChatNonStreamResponseChoice{},
			},
		},
		Created: int(time.Now().Unix()),
	}
	choice := &bifrostResponse.Choices[0]

	// Message
	if response.Message != nil {
		choice.ChatNonStreamResponseChoice.Message = response.Message.ToBifrostChatMessage()
	}

	// Finish reason
	if response.FinishReason != nil {
		reason := ConvertCohereFinishReasonToBifrost(*response.FinishReason)
		choice.FinishReason = schemas.Ptr(reason)
	}

	// Usage: token counts (and cached tokens) are only read when the Tokens
	// section is present; otherwise an empty usage struct is attached.
	if response.Usage != nil {
		usage := &schemas.BifrostLLMUsage{}
		if tokens := response.Usage.Tokens; tokens != nil {
			if tokens.InputTokens != nil {
				usage.PromptTokens = *tokens.InputTokens
			}
			if tokens.OutputTokens != nil {
				usage.CompletionTokens = *tokens.OutputTokens
			}
			if response.Usage.CachedTokens != nil {
				usage.PromptTokensDetails = &schemas.ChatPromptTokensDetails{
					CachedReadTokens: *response.Usage.CachedTokens,
				}
			}
			usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
		}
		bifrostResponse.Usage = usage
	}

	return bifrostResponse
}
// ToBifrostChatCompletionStream converts one Cohere stream event into a
// Bifrost "chat.completion.chunk" response.
//
// It returns the chunk (nil when the event carries nothing to forward), an
// error value (always nil in this implementation), and a flag that is true only
// for a message-end event that carries a delta. A nil receiver is treated like
// an event with nothing to forward, matching the nil guards on the other
// converters in this file.
func (chunk *CohereStreamEvent) ToBifrostChatCompletionStream() (*schemas.BifrostChatResponse, *schemas.BifrostError, bool) {
	if chunk == nil {
		return nil, nil, false
	}

	// newChunk wraps a delta in the single-choice chunk envelope shared by
	// every event type below.
	newChunk := func(delta *schemas.ChatStreamResponseChoiceDelta) *schemas.BifrostChatResponse {
		return &schemas.BifrostChatResponse{
			Object: "chat.completion.chunk",
			Choices: []schemas.BifrostResponseChoice{
				{
					Index: 0,
					ChatStreamResponseChoice: &schemas.ChatStreamResponseChoice{
						Delta: delta,
					},
				},
			},
		}
	}

	switch chunk.Type {
	case StreamEventMessageStart:
		if chunk.Delta != nil && chunk.Delta.Message != nil && chunk.Delta.Message.Role != nil {
			return newChunk(&schemas.ChatStreamResponseChoiceDelta{
				Role: chunk.Delta.Message.Role,
			}), nil, false
		}

	case StreamEventContentDelta:
		if chunk.Delta != nil &&
			chunk.Delta.Message != nil &&
			chunk.Delta.Message.Content != nil &&
			chunk.Delta.Message.Content.CohereStreamContentObject != nil {
			contentObject := chunk.Delta.Message.Content.CohereStreamContentObject
			if contentObject.Text != nil {
				// Text takes precedence over thinking when both are present.
				return newChunk(&schemas.ChatStreamResponseChoiceDelta{
					Content: contentObject.Text,
				}), nil, false
			} else if contentObject.Thinking != nil {
				thinkingText := *contentObject.Thinking
				return newChunk(&schemas.ChatStreamResponseChoiceDelta{
					Reasoning: schemas.Ptr(thinkingText),
					ReasoningDetails: []schemas.ChatReasoningDetails{
						{
							Index: 0,
							Type:  schemas.BifrostReasoningDetailsTypeText,
							Text:  schemas.Ptr(thinkingText),
						},
					},
				}), nil, false
			}
		}

	case StreamEventToolPlanDelta:
		// The tool plan is surfaced through the Reasoning field.
		if chunk.Delta != nil && chunk.Delta.Message != nil && chunk.Delta.Message.ToolPlan != nil {
			return newChunk(&schemas.ChatStreamResponseChoiceDelta{
				Reasoning: chunk.Delta.Message.ToolPlan,
			}), nil, false
		}

	case StreamEventContentStart:
		// Content start event - just continue, actual content comes in content-delta
		return nil, nil, false

	case StreamEventToolCallStart, StreamEventToolCallDelta:
		if chunk.Delta != nil && chunk.Delta.Message != nil && chunk.Delta.Message.ToolCalls != nil && chunk.Delta.Message.ToolCalls.CohereToolCallObject != nil {
			// Handle single tool call object (tool-call-start/delta events)
			cohereToolCall := chunk.Delta.Message.ToolCalls.CohereToolCallObject
			toolCall := schemas.ChatAssistantMessageToolCall{}

			if chunk.Index != nil {
				toolCall.Index = uint16(*chunk.Index)
			}
			if cohereToolCall.ID != nil {
				toolCall.ID = cohereToolCall.ID
			}
			if cohereToolCall.Function != nil {
				if cohereToolCall.Function.Name != nil {
					toolCall.Function.Name = cohereToolCall.Function.Name
				}
				toolCall.Function.Arguments = cohereToolCall.Function.Arguments
			}

			return newChunk(&schemas.ChatStreamResponseChoiceDelta{
				ToolCalls: []schemas.ChatAssistantMessageToolCall{toolCall},
			}), nil, false
		}

	case StreamEventToolCallEnd:
		return nil, nil, false

	case StreamEventContentEnd:
		return nil, nil, false

	case StreamEventMessageEnd:
		if chunk.Delta != nil {
			var finishReason string
			usage := &schemas.BifrostLLMUsage{}

			// Set finish reason
			if chunk.Delta.FinishReason != nil {
				finishReason = ConvertCohereFinishReasonToBifrost(*chunk.Delta.FinishReason)
			}

			// Set usage information
			if chunk.Delta.Usage != nil && chunk.Delta.Usage.Tokens != nil {
				tokens := chunk.Delta.Usage.Tokens
				if tokens.InputTokens != nil {
					usage.PromptTokens = *tokens.InputTokens
				}
				if tokens.OutputTokens != nil {
					usage.CompletionTokens = *tokens.OutputTokens
				}
				usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
			}

			// The final chunk carries an empty delta plus finish reason and usage.
			streamResponse := newChunk(&schemas.ChatStreamResponseChoiceDelta{})
			streamResponse.Choices[0].FinishReason = &finishReason
			streamResponse.Usage = usage
			return streamResponse, nil, true
		}
		return nil, nil, false
	}

	// Unrecognised events, and recognised events lacking the expected payload,
	// produce nothing.
	return nil, nil, false
}
// ToBifrostChatMessage converts a Cohere message into a Bifrost chat message.
// A nil receiver yields nil. String content, or a lone text block, becomes
// ContentStr; otherwise text and image blocks become ContentBlocks, and
// thinking blocks are collected as reasoning details with their text joined
// by newlines. Tool calls whose Function is nil are skipped.
func (cm *CohereMessage) ToBifrostChatMessage() *schemas.ChatMessage {
	if cm == nil {
		return nil
	}

	var (
		text             *string
		blocks           []schemas.ChatContentBlock
		reasoningDetails []schemas.ChatReasoningDetails
		reasoningText    string
	)

	// Convert message content
	if cm.Content != nil {
		switch {
		case cm.Content.IsString():
			text = cm.Content.GetString()
		case cm.Content.IsBlocks() &&
			len(cm.Content.GetBlocks()) == 1 &&
			cm.Content.GetBlocks()[0].Type == CohereContentBlockTypeText:
			text = cm.Content.GetBlocks()[0].Text
		case cm.Content.IsBlocks():
			for _, block := range cm.Content.GetBlocks() {
				switch {
				case block.Type == CohereContentBlockTypeText && block.Text != nil:
					blocks = append(blocks, schemas.ChatContentBlock{
						Type: schemas.ChatContentBlockTypeText,
						Text: block.Text,
					})
				case block.Type == CohereContentBlockTypeImage && block.ImageURL != nil:
					blocks = append(blocks, schemas.ChatContentBlock{
						Type: schemas.ChatContentBlockTypeImage,
						ImageURLStruct: &schemas.ChatInputImage{
							URL: block.ImageURL.URL,
						},
					})
				case block.Type == CohereContentBlockTypeThinking && block.Thinking != nil:
					reasoningDetails = append(reasoningDetails, schemas.ChatReasoningDetails{
						Index: len(reasoningDetails),
						Type:  schemas.BifrostReasoningDetailsTypeText,
						Text:  block.Thinking,
					})
					if len(reasoningText) > 0 {
						reasoningText += "\n"
					}
					reasoningText += *block.Thinking
				}
			}
		}
	}

	// A single remaining text block is collapsed into plain string content.
	if len(blocks) == 1 && blocks[0].Type == schemas.ChatContentBlockTypeText {
		text = blocks[0].Text
		blocks = nil
	}

	// Convert tool calls; Index counts only the calls that are kept.
	var toolCalls []schemas.ChatAssistantMessageToolCall
	for _, toolCall := range cm.ToolCalls {
		if toolCall.Function == nil {
			continue
		}

		// Fall back to an empty name when Cohere did not provide one.
		name := toolCall.Function.Name
		if name == nil {
			name = schemas.Ptr("")
		}

		toolCalls = append(toolCalls, schemas.ChatAssistantMessageToolCall{
			Index: uint16(len(toolCalls)),
			ID:    toolCall.ID,
			Function: schemas.ChatAssistantMessageToolCallFunction{
				Name:      name,
				Arguments: toolCall.Function.Arguments,
			},
		})
	}

	// The assistant section exists only when there are tool calls or reasoning.
	var assistantMessage *schemas.ChatAssistantMessage
	if len(toolCalls) > 0 || len(reasoningDetails) > 0 {
		assistantMessage = &schemas.ChatAssistantMessage{}
		if len(toolCalls) > 0 {
			assistantMessage.ToolCalls = toolCalls
		}
		if len(reasoningDetails) > 0 {
			assistantMessage.ReasoningDetails = reasoningDetails
			assistantMessage.Reasoning = schemas.Ptr(reasoningText)
		}
	}

	bifrostMessage := &schemas.ChatMessage{
		Role: schemas.ChatMessageRole(cm.Role),
		Content: &schemas.ChatMessageContent{
			ContentStr:    text,
			ContentBlocks: blocks,
		},
		ChatAssistantMessage: assistantMessage,
	}

	// Tool-role messages carry the ID of the call they answer.
	if cm.Role == "tool" {
		bifrostMessage.ChatToolMessage = &schemas.ChatToolMessage{
			ToolCallID: cm.ToolCallID,
		}
	}

	return bifrostMessage
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,62 @@
package cohere_test
import (
"os"
"strings"
"testing"
"github.com/maximhq/bifrost/core/internal/llmtests"
"github.com/maximhq/bifrost/core/schemas"
)
// TestCohere runs the shared comprehensive provider suite against Cohere.
// The test is skipped when COHERE_API_KEY is unset or blank.
func TestCohere(t *testing.T) {
	t.Parallel()

	if apiKey := strings.TrimSpace(os.Getenv("COHERE_API_KEY")); apiKey == "" {
		t.Skip("Skipping Cohere tests because COHERE_API_KEY is not set")
	}

	client, ctx, cancel, err := llmtests.SetupTest()
	if err != nil {
		t.Fatalf("Error initializing test setup: %v", err)
	}
	// Deferred calls run in reverse order: the client shuts down before the
	// context is cancelled.
	defer cancel()
	defer client.Shutdown()

	// Scenarios exercised by the suite for this provider.
	scenarios := llmtests.TestScenarios{
		TextCompletion:             false, // no TextModel is configured
		SimpleChat:                 true,
		CompletionStream:           true,
		MultiTurnConversation:      true,
		ToolCalls:                  true,
		ToolCallsStreaming:         true,
		MultipleToolCalls:          true,
		MultipleToolCallsStreaming: true,
		End2EndToolCalling:         true,
		AutomaticFunctionCall:      true,
		ImageURL:                   false,
		ImageBase64:                true, // the only image scenario enabled here
		MultipleImages:             false,
		FileBase64:                 false,
		FileURL:                    false,
		CompleteEnd2End:            false,
		Embedding:                  true,
		Rerank:                     true,
		Reasoning:                  true,
		ListModels:                 true,
		CountTokens:                true,
	}

	testConfig := llmtests.ComprehensiveTestConfig{
		Provider:       schemas.Cohere,
		ChatModel:      "command-a-03-2025",
		VisionModel:    "command-a-vision-07-2025",
		TextModel:      "", // left empty; text completion is disabled above
		EmbeddingModel: "embed-v4.0",
		RerankModel:    "rerank-v3.5",
		ReasoningModel: "command-a-reasoning-08-2025",
		Scenarios:      scenarios,
	}

	t.Run("CohereTests", func(t *testing.T) {
		llmtests.RunAllComprehensiveTests(t, client, ctx, testConfig)
	})
}

View File

@@ -0,0 +1,125 @@
package cohere
import (
"fmt"
"strings"
"unicode/utf8"
"github.com/maximhq/bifrost/core/providers/utils"
"github.com/maximhq/bifrost/core/schemas"
)
// ToBifrostResponsesRequest wraps a Cohere count-tokens request in a Bifrost
// responses request holding a single user message whose content is req.Text.
// A nil receiver yields nil.
func (req *CohereCountTokensRequest) ToBifrostResponsesRequest(ctx *schemas.BifrostContext) *schemas.BifrostResponsesRequest {
	if req == nil {
		return nil
	}

	defaultProvider := utils.CheckAndSetDefaultProvider(ctx, schemas.Cohere)
	provider, model := schemas.ParseModelString(req.Model, defaultProvider)

	role := schemas.ResponsesInputMessageRoleUser
	bifrostReq := &schemas.BifrostResponsesRequest{
		Provider: provider,
		Model:    model,
	}
	bifrostReq.Input = []schemas.ResponsesMessage{
		{
			Role: &role,
			Content: &schemas.ResponsesMessageContent{
				// Points at the request's own Text field rather than a copy.
				ContentStr: &req.Text,
			},
		},
	}
	return bifrostReq
}
// ToCohereCountTokensRequest converts a Bifrost count tokens request to Cohere's tokenize payload.
//
// The input messages are flattened to plain text and trimmed. An error is
// returned when bifrostReq is nil, when Input is nil, when the flattened text
// is empty, or when its length in runes falls outside
// [cohereTokenizeMinTextLength, cohereTokenizeMaxTextLength].
func ToCohereCountTokensRequest(bifrostReq *schemas.BifrostResponsesRequest) (*CohereCountTokensRequest, error) {
	if bifrostReq == nil {
		// Report the problem instead of returning (nil, nil), which would hand
		// the caller a nil payload with no error to check.
		return nil, fmt.Errorf("bifrost request is nil")
	}
	if bifrostReq.Input == nil {
		return nil, fmt.Errorf("count tokens input is not provided")
	}

	text := buildCohereCountTokensText(bifrostReq.Input)
	trimmed := strings.TrimSpace(text)
	if trimmed == "" {
		return nil, fmt.Errorf("count tokens text is empty after conversion")
	}

	// Length limits are checked in runes, not bytes.
	runeCount := utf8.RuneCountInString(trimmed)
	if runeCount < cohereTokenizeMinTextLength || runeCount > cohereTokenizeMaxTextLength {
		return nil, fmt.Errorf("count tokens text length must be between %d and %d characters", cohereTokenizeMinTextLength, cohereTokenizeMaxTextLength)
	}

	cohereReq := &CohereCountTokensRequest{
		Model: bifrostReq.Model,
		Text:  trimmed,
	}
	if bifrostReq.Params != nil {
		cohereReq.ExtraParams = bifrostReq.Params.ExtraParams
	}
	return cohereReq, nil
}
// ToBifrostCountTokensResponse turns a Cohere tokenize response into a Bifrost
// count-tokens response. The input token count is the number of token IDs,
// falling back to the number of token strings when no IDs are present; the
// total equals the input count. A nil receiver yields nil.
func (resp *CohereCountTokensResponse) ToBifrostCountTokensResponse(model string) *schemas.BifrostCountTokensResponse {
	if resp == nil {
		return nil
	}

	count := len(resp.Tokens)
	if count == 0 {
		count = len(resp.TokenStrings)
	}
	total := count

	result := &schemas.BifrostCountTokensResponse{
		Model:       model,
		InputTokens: count,
		TotalTokens: &total,
		Object:      "response.input_tokens",
	}
	result.TokenStrings = resp.TokenStrings
	result.Tokens = resp.Tokens
	return result
}
// buildCohereCountTokensText collects, in message order, each message's
// string content, the text and non-empty refusal of its content blocks, and
// its non-empty reasoning summaries, then joins the pieces with newlines and
// trims surrounding whitespace.
func buildCohereCountTokensText(messages []schemas.ResponsesMessage) string {
	// Every piece is separated by "\n" regardless of which message it came
	// from, so one flat list is enough.
	var segments []string

	for i := range messages {
		msg := &messages[i]

		if content := msg.Content; content != nil {
			if content.ContentStr != nil {
				segments = append(segments, *content.ContentStr)
			}
			for _, block := range content.ContentBlocks {
				if block.Text != nil {
					segments = append(segments, *block.Text)
				}
				if refusal := block.ResponsesOutputMessageContentRefusal; refusal != nil && refusal.Refusal != "" {
					segments = append(segments, refusal.Refusal)
				}
			}
		}

		if reasoning := msg.ResponsesReasoning; reasoning != nil {
			for _, summary := range reasoning.Summary {
				if summary.Text == "" {
					continue
				}
				segments = append(segments, summary.Text)
			}
		}
	}

	joined := strings.Join(segments, "\n")
	return strings.TrimSpace(joined)
}

View File

@@ -0,0 +1,190 @@
package cohere
import (
"github.com/maximhq/bifrost/core/providers/utils"
"github.com/maximhq/bifrost/core/schemas"
)
// ToCohereEmbeddingRequest converts a Bifrost embedding request to Cohere format.
//
// It returns nil when the request, its input, or both Text and Texts are nil.
// InputType defaults to "search_document". The extra params "max_tokens",
// "input_type", "embedding_types" (when non-empty) and "truncate" are promoted
// to typed fields and removed from the ExtraParams forwarded to Cohere.
//
// The caller's Params.ExtraParams map is never modified: promotion happens on a
// private copy, so the same Bifrost request can be converted again (for
// example on a retry or fallback) and still carry its original keys.
func ToCohereEmbeddingRequest(bifrostReq *schemas.BifrostEmbeddingRequest) *CohereEmbeddingRequest {
	if bifrostReq == nil || bifrostReq.Input == nil || (bifrostReq.Input.Text == nil && bifrostReq.Input.Texts == nil) {
		return nil
	}

	embeddingInput := bifrostReq.Input
	cohereReq := &CohereEmbeddingRequest{
		Model:     bifrostReq.Model,
		InputType: "search_document", // Default; may be overridden by extra params below.
	}

	// A single Text takes precedence over Texts; empty Texts leaves the field unset.
	if embeddingInput.Text != nil {
		cohereReq.Texts = []string{*embeddingInput.Text}
	} else if len(embeddingInput.Texts) > 0 {
		cohereReq.Texts = embeddingInput.Texts
	}

	params := bifrostReq.Params
	if params == nil {
		return cohereReq
	}
	cohereReq.OutputDimension = params.Dimensions
	if params.ExtraParams == nil {
		return cohereReq
	}

	// Copy before deleting promoted keys so the caller's map stays intact.
	extra := make(map[string]interface{}, len(params.ExtraParams))
	for key, value := range params.ExtraParams {
		extra[key] = value
	}
	cohereReq.ExtraParams = extra

	if maxTokens, ok := schemas.SafeExtractIntPointer(extra["max_tokens"]); ok {
		delete(extra, "max_tokens")
		cohereReq.MaxTokens = maxTokens
	}
	if inputType, ok := schemas.SafeExtractString(extra["input_type"]); ok {
		delete(extra, "input_type")
		cohereReq.InputType = inputType
	}
	// An empty embedding_types list is left in ExtraParams untouched.
	if embeddingTypes, ok := schemas.SafeExtractStringSlice(extra["embedding_types"]); ok && len(embeddingTypes) > 0 {
		delete(extra, "embedding_types")
		cohereReq.EmbeddingTypes = embeddingTypes
	}
	if truncate, ok := schemas.SafeExtractStringPointer(extra["truncate"]); ok {
		delete(extra, "truncate")
		cohereReq.Truncate = truncate
	}

	return cohereReq
}
// ToBifrostEmbeddingRequest converts a Cohere embedding request to Bifrost format.
//
// The provider and model are parsed from req.Model, with Cohere as the default
// provider. Exactly one text becomes Input.Text; more than one becomes
// Input.Texts. Cohere-specific fields (input_type, embedding_types, truncate,
// max_tokens) are carried in Params.ExtraParams. A nil receiver yields nil.
func (req *CohereEmbeddingRequest) ToBifrostEmbeddingRequest(ctx *schemas.BifrostContext) *schemas.BifrostEmbeddingRequest {
	if req == nil {
		return nil
	}

	provider, model := schemas.ParseModelString(req.Model, utils.CheckAndSetDefaultProvider(ctx, schemas.Cohere))

	input := &schemas.EmbeddingInput{}
	switch len(req.Texts) {
	case 0:
		// Nothing to embed; leave both Text and Texts unset.
	case 1:
		input.Text = &req.Texts[0]
	default:
		input.Texts = req.Texts
	}

	// Cohere-only knobs travel as extra params.
	extras := map[string]interface{}{}
	if req.InputType != "" {
		extras["input_type"] = req.InputType
	}
	if req.EmbeddingTypes != nil {
		extras["embedding_types"] = req.EmbeddingTypes
	}
	if req.Truncate != nil {
		extras["truncate"] = *req.Truncate
	}
	if req.MaxTokens != nil {
		extras["max_tokens"] = *req.MaxTokens
	}

	params := &schemas.EmbeddingParameters{Dimensions: req.OutputDimension}
	if len(extras) > 0 {
		params.ExtraParams = extras
	}

	return &schemas.BifrostEmbeddingRequest{
		Provider: provider,
		Model:    model,
		Input:    input,
		Params:   params,
	}
}
// ToBifrostEmbeddingResponse converts a Cohere embedding response to Bifrost format.
//
// Float embeddings are preferred; base64 embeddings are used only when no float
// embeddings are present. Int8, Uint8, Binary and Ubinary embeddings are not
// converted. Usage comes from Meta.Tokens when set, otherwise from
// Meta.BilledUnits. A nil receiver yields nil.
func (response *CohereEmbeddingResponse) ToBifrostEmbeddingResponse() *schemas.BifrostEmbeddingResponse {
	if response == nil {
		return nil
	}

	bifrostResponse := &schemas.BifrostEmbeddingResponse{
		Object: "list",
	}

	// Convert embeddings data
	if embeddings := response.Embeddings; embeddings != nil {
		var bifrostEmbeddings []schemas.EmbeddingData

		switch {
		case embeddings.Float != nil:
			for i, vector := range embeddings.Float {
				bifrostEmbeddings = append(bifrostEmbeddings, schemas.EmbeddingData{
					Object: "embedding",
					Index:  i,
					Embedding: schemas.EmbeddingStruct{
						EmbeddingArray: vector,
					},
				})
			}
		case embeddings.Base64 != nil:
			for i := range embeddings.Base64 {
				// Take the address of a per-iteration copy, never of the range
				// variable: before Go 1.22 that variable is shared by all
				// iterations, so every entry would point at the last string.
				encoded := embeddings.Base64[i]
				bifrostEmbeddings = append(bifrostEmbeddings, schemas.EmbeddingData{
					Object: "embedding",
					Index:  i,
					Embedding: schemas.EmbeddingStruct{
						EmbeddingStr: &encoded,
					},
				})
			}
		}
		// Note: Int8, Uint8, Binary, Ubinary types would need special handling
		// depending on how Bifrost wants to represent them
		bifrostResponse.Data = bifrostEmbeddings
	}

	// Convert usage information
	if meta := response.Meta; meta != nil {
		switch {
		case meta.Tokens != nil:
			usage := &schemas.BifrostLLMUsage{}
			if meta.Tokens.InputTokens != nil {
				usage.PromptTokens = int(*meta.Tokens.InputTokens)
			}
			if meta.Tokens.OutputTokens != nil {
				usage.CompletionTokens = int(*meta.Tokens.OutputTokens)
			}
			usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
			bifrostResponse.Usage = usage
		case meta.BilledUnits != nil:
			usage := &schemas.BifrostLLMUsage{}
			if meta.BilledUnits.InputTokens != nil {
				usage.PromptTokens = int(*meta.BilledUnits.InputTokens)
			}
			if meta.BilledUnits.OutputTokens != nil {
				usage.CompletionTokens = int(*meta.BilledUnits.OutputTokens)
			}
			usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
			bifrostResponse.Usage = usage
		}
	}

	return bifrostResponse
}

View File

@@ -0,0 +1,90 @@
package cohere
import (
"context"
"testing"
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
"github.com/maximhq/bifrost/core/schemas"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestToCohereEmbeddingRequest(t *testing.T) {
t.Run("returns nil for missing input", func(t *testing.T) {
assert.Nil(t, ToCohereEmbeddingRequest(nil))
assert.Nil(t, ToCohereEmbeddingRequest(&schemas.BifrostEmbeddingRequest{}))
assert.Nil(t, ToCohereEmbeddingRequest(&schemas.BifrostEmbeddingRequest{
Input: &schemas.EmbeddingInput{},
}))
})
t.Run("single text keeps model in direct cohere body", func(t *testing.T) {
text := "hello"
truncate := "END"
dimensions := 1024
maxTokens := 256
bifrostReq := &schemas.BifrostEmbeddingRequest{
Model: "embed-v4.0",
Input: &schemas.EmbeddingInput{Text: &text},
Params: &schemas.EmbeddingParameters{
Dimensions: &dimensions,
ExtraParams: map[string]interface{}{
"input_type": "classification",
"embedding_types": []string{"float", "int8"},
"truncate": truncate,
"max_tokens": maxTokens,
"priority": "high",
},
},
}
req := ToCohereEmbeddingRequest(bifrostReq)
require.NotNil(t, req)
assert.Equal(t, "embed-v4.0", req.Model)
assert.Equal(t, "classification", req.InputType)
assert.Equal(t, []string{"hello"}, req.Texts)
assert.Equal(t, []string{"float", "int8"}, req.EmbeddingTypes)
assert.Equal(t, &dimensions, req.OutputDimension)
assert.Equal(t, &maxTokens, req.MaxTokens)
require.NotNil(t, req.Truncate)
assert.Equal(t, truncate, *req.Truncate)
assert.Equal(t, map[string]interface{}{"priority": "high"}, req.ExtraParams)
})
t.Run("multiple texts use default input type", func(t *testing.T) {
req := ToCohereEmbeddingRequest(&schemas.BifrostEmbeddingRequest{
Model: "embed-english-v3.0",
Input: &schemas.EmbeddingInput{Texts: []string{"hello", "world"}},
})
require.NotNil(t, req)
assert.Equal(t, "embed-english-v3.0", req.Model)
assert.Equal(t, "search_document", req.InputType)
assert.Equal(t, []string{"hello", "world"}, req.Texts)
assert.Nil(t, req.ExtraParams)
})
}
// TestToCohereEmbeddingRequestBodyIncludesModelForDirectCohere checks the wire
// body produced for a direct Cohere embedding call: it must carry the model,
// the default input type and the single text.
func TestToCohereEmbeddingRequestBodyIncludesModelForDirectCohere(t *testing.T) {
	input := "hello"
	request := &schemas.BifrostEmbeddingRequest{
		Model: "embed-v4.0",
		Input: &schemas.EmbeddingInput{Text: &input},
	}

	convert := func() (providerUtils.RequestBodyWithExtraParams, error) {
		return ToCohereEmbeddingRequest(request), nil
	}

	body, bifrostErr := providerUtils.CheckContextAndGetRequestBody(
		context.Background(),
		request,
		convert,
		schemas.Cohere,
	)
	require.Nil(t, bifrostErr)

	assert.JSONEq(t, `{
"model": "embed-v4.0",
"input_type": "search_document",
"texts": ["hello"]
}`, string(body))
}

View File

@@ -0,0 +1,21 @@
package cohere
import (
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
"github.com/maximhq/bifrost/core/schemas"
"github.com/valyala/fasthttp"
)
// parseCohereError decodes a Cohere error response into a BifrostError.
//
// The generic provider handler fills a CohereError from the response; the
// error's Type and Message are then overwritten with the Cohere values, and the
// Code is overwritten only when Cohere supplied one.
func parseCohereError(resp *fasthttp.Response) *schemas.BifrostError {
	var payload CohereError
	result := providerUtils.HandleProviderAPIError(resp, &payload)

	result.Type = &payload.Type

	// Make sure there is an error field to write into.
	if result.Error == nil {
		result.Error = &schemas.ErrorField{}
	}
	details := result.Error
	details.Message = payload.Message
	if code := payload.Code; code != nil {
		details.Code = code
	}

	return result
}

View File

@@ -0,0 +1,92 @@
package cohere
import (
"encoding/json"
"strings"
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
"github.com/maximhq/bifrost/core/schemas"
)
// CohereRerankRequest represents a Cohere rerank API request.
// Pointer fields are optional and omitted from the JSON body when nil.
type CohereRerankRequest struct {
Model string `json:"model"` // Model identifier, always serialized
Query string `json:"query"` // Query the documents are ranked against
Documents []string `json:"documents"` // Documents to rank, each as a plain string
TopN *int `json:"top_n,omitempty"` // Optional: omitted when nil
MaxTokensPerDoc *int `json:"max_tokens_per_doc,omitempty"` // Optional: omitted when nil
Priority *int `json:"priority,omitempty"` // Optional: omitted when nil
ExtraParams map[string]interface{} `json:"-"` // Excluded from JSON encoding of this struct; exposed via GetExtraParams
}
// GetExtraParams returns the rerank request's ExtraParams map as stored (no copy is made).
// The receiver must be non-nil.
func (r *CohereRerankRequest) GetExtraParams() map[string]interface{} {
return r.ExtraParams
}
// CohereRerankResult represents a single result from Cohere rerank.
type CohereRerankResult struct {
Index int `json:"index"` // Index reported by Cohere; used to look up the request document when documents are returned
RelevanceScore float64 `json:"relevance_score"` // Score used to order results (higher first)
Document json.RawMessage `json:"document,omitempty"` // Optional: kept as raw JSON; decoded as a JSON object during response conversion
}
// CohereRerankResponse represents a Cohere rerank API response.
type CohereRerankResponse struct {
ID string `json:"id"` // Response ID, copied to the Bifrost response
Results []CohereRerankResult `json:"results"` // Per-document results as returned by Cohere
Meta *CohereRerankMeta `json:"meta,omitempty"` // Optional: metadata, including usage information
}
// CohereRerankMeta represents metadata in Cohere rerank response.
// When converting usage, Tokens is consulted first and BilledUnits only when Tokens is nil.
type CohereRerankMeta struct {
APIVersion *CohereEmbeddingAPIVersion `json:"api_version,omitempty"` // Optional: API version info (reuses the embedding type)
BilledUnits *CohereBilledUnits `json:"billed_units,omitempty"` // Optional: billing information
Tokens *CohereTokenUsage `json:"tokens,omitempty"` // Optional: token usage
}
// ToBifrostListModelsResponse converts a Cohere list-models response to Bifrost format.
//
// Each Cohere model name is run through the list-models pipeline built from the
// allow list, block list and aliases; every match becomes one entry whose ID is
// "<providerKey>/<resolved id>". Models supplied by the pipeline's backfill step
// are appended afterwards. A nil receiver yields nil, and an early pipeline exit
// yields an empty response.
func (response *CohereListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, allowedModels schemas.WhiteList, blacklistedModels schemas.BlackList, aliases map[string]string, unfiltered bool) *schemas.BifrostListModelsResponse {
	if response == nil {
		return nil
	}

	out := &schemas.BifrostListModelsResponse{
		Data: make([]schemas.Model, 0, len(response.Models)),
	}

	pipeline := &providerUtils.ListModelsPipeline{
		AllowedModels:     allowedModels,
		BlacklistedModels: blacklistedModels,
		Aliases:           aliases,
		Unfiltered:        unfiltered,
		ProviderKey:       providerKey,
		MatchFns:          providerUtils.DefaultMatchFns(),
	}
	if pipeline.ShouldEarlyExit() {
		return out
	}

	// Lower-cased resolved ids already emitted, handed to the backfill step.
	seen := map[string]bool{}
	idPrefix := string(providerKey) + "/"

	for _, model := range response.Models {
		// Cohere uses model.Name as the model identifier
		matches := pipeline.FilterModel(model.Name)
		for _, match := range matches {
			entry := schemas.Model{
				ID:               idPrefix + match.ResolvedID,
				Name:             schemas.Ptr(model.Name),
				ContextLength:    schemas.Ptr(int(model.ContextLength)),
				SupportedMethods: model.Endpoints,
			}
			if alias := match.AliasValue; alias != "" {
				entry.Alias = schemas.Ptr(alias)
			}
			out.Data = append(out.Data, entry)
			seen[strings.ToLower(match.ResolvedID)] = true
		}
	}

	out.Data = append(out.Data, pipeline.BackfillModels(seen)...)
	return out
}

View File

@@ -0,0 +1,209 @@
package cohere
import (
"sort"
"github.com/bytedance/sonic"
"github.com/maximhq/bifrost/core/providers/utils"
"github.com/maximhq/bifrost/core/schemas"
"gopkg.in/yaml.v3"
)
// ToCohereRerankRequest converts a Bifrost rerank request to Cohere format.
//
// Every document is rendered to a single string via formatCohereRerankDocument.
// TopN, MaxTokensPerDoc, Priority and ExtraParams are copied when Params is set.
// A nil request yields nil.
func ToCohereRerankRequest(bifrostReq *schemas.BifrostRerankRequest) *CohereRerankRequest {
	if bifrostReq == nil {
		return nil
	}

	// Cohere v2 expects documents as a list of strings.
	rendered := make([]string, 0, len(bifrostReq.Documents))
	for _, doc := range bifrostReq.Documents {
		rendered = append(rendered, formatCohereRerankDocument(doc))
	}

	out := &CohereRerankRequest{
		Model:     bifrostReq.Model,
		Query:     bifrostReq.Query,
		Documents: rendered,
	}

	if params := bifrostReq.Params; params != nil {
		out.TopN = params.TopN
		out.MaxTokensPerDoc = params.MaxTokensPerDoc
		out.Priority = params.Priority
		out.ExtraParams = params.ExtraParams
	}
	return out
}
// ToBifrostRerankRequest converts a Cohere rerank request to Bifrost format.
//
// The provider and model are parsed from req.Model, with Cohere as the default
// provider. Each document string becomes a RerankDocument carrying only Text.
// Params is always non-nil in the result. A nil receiver yields nil.
func (req *CohereRerankRequest) ToBifrostRerankRequest(ctx *schemas.BifrostContext) *schemas.BifrostRerankRequest {
	if req == nil {
		return nil
	}

	provider, model := schemas.ParseModelString(req.Model, utils.CheckAndSetDefaultProvider(ctx, schemas.Cohere))

	out := &schemas.BifrostRerankRequest{
		Provider: provider,
		Model:    model,
		Query:    req.Query,
		// Optional fields are copied straight across; nil stays nil.
		Params: &schemas.RerankParameters{
			TopN:            req.TopN,
			MaxTokensPerDoc: req.MaxTokensPerDoc,
			Priority:        req.Priority,
			ExtraParams:     req.ExtraParams,
		},
	}

	// Convert documents; Documents stays nil when there are none.
	if count := len(req.Documents); count > 0 {
		docs := make([]schemas.RerankDocument, count)
		for i, text := range req.Documents {
			docs[i].Text = text
		}
		out.Documents = docs
	}

	return out
}
// ToBifrostRerankResponse converts a Cohere rerank response to Bifrost format.
//
// Results are sorted by descending relevance score, with ties broken by
// ascending index. A provider-returned document is decoded as a JSON object:
// "text" and "id" (when strings) fill the matching fields, while the contents
// of "metadata" (or else "meta") plus all remaining keys are merged into Meta.
// When returnDocuments is true, each result whose index is within range gets
// the corresponding entry of documents instead. Usage is taken from
// Meta.Tokens when set, otherwise from Meta.BilledUnits, and is only reported
// when at least one token count is present. A nil receiver yields nil.
func (response *CohereRerankResponse) ToBifrostRerankResponse(documents []schemas.RerankDocument, returnDocuments bool) *schemas.BifrostRerankResponse {
	if response == nil {
		return nil
	}

	out := &schemas.BifrostRerankResponse{ID: response.ID}

	// Convert results
	for _, item := range response.Results {
		converted := schemas.RerankResult{
			Index:          item.Index,
			RelevanceScore: item.RelevanceScore,
		}

		// Convert document if present; a payload that is not a JSON object is ignored.
		if len(item.Document) > 0 {
			var fields map[string]interface{}
			if sonic.Unmarshal(item.Document, &fields) == nil {
				doc := &schemas.RerankDocument{}

				text, hasText := fields["text"].(string)
				if hasText {
					doc.Text = text
				}
				id, hasID := fields["id"].(string)
				if hasID {
					doc.ID = &id
				}

				// Collect metadata: unwrap "metadata"/"meta" keys to avoid nesting.
				// "metadata" wins; "meta" is only consulted when "metadata" is not an object.
				meta := map[string]interface{}{}
				nested, isObject := fields["metadata"].(map[string]interface{})
				if !isObject {
					nested, _ = fields["meta"].(map[string]interface{})
				}
				for key, value := range nested {
					meta[key] = value
				}
				// Remaining top-level keys are merged last, so they override nested ones.
				for key, value := range fields {
					switch key {
					case "text", "id", "metadata", "meta":
						// Already handled above.
					default:
						meta[key] = value
					}
				}
				hasMeta := len(meta) > 0
				if hasMeta {
					doc.Meta = meta
				}

				if hasText || hasID || hasMeta {
					converted.Document = doc
				}
			}
		}

		out.Results = append(out.Results, converted)
	}

	ranked := out.Results
	sort.SliceStable(ranked, func(a, b int) bool {
		left, right := ranked[a], ranked[b]
		if left.RelevanceScore != right.RelevanceScore {
			return left.RelevanceScore > right.RelevanceScore
		}
		return left.Index < right.Index
	})

	if returnDocuments {
		// Request documents replace whatever the provider echoed back.
		for i := range out.Results {
			idx := out.Results[i].Index
			if idx < 0 || idx >= len(documents) {
				continue
			}
			out.Results[i].Document = schemas.Ptr(documents[idx])
		}
	}

	// Convert usage information
	if meta := response.Meta; meta != nil {
		var prompt, completion int
		reported := false

		switch {
		case meta.Tokens != nil:
			if meta.Tokens.InputTokens != nil {
				prompt = int(*meta.Tokens.InputTokens)
				reported = true
			}
			if meta.Tokens.OutputTokens != nil {
				completion = int(*meta.Tokens.OutputTokens)
				reported = true
			}
		case meta.BilledUnits != nil:
			if meta.BilledUnits.InputTokens != nil {
				prompt = int(*meta.BilledUnits.InputTokens)
				reported = true
			}
			if meta.BilledUnits.OutputTokens != nil {
				completion = int(*meta.BilledUnits.OutputTokens)
				reported = true
			}
		}

		if reported {
			out.Usage = &schemas.BifrostLLMUsage{
				PromptTokens:     prompt,
				CompletionTokens: completion,
				TotalTokens:      prompt + completion,
			}
		}
	}

	return out
}
// formatCohereRerankDocument renders a rerank document as the single string
// sent to Cohere.
//
// A document with neither an ID nor metadata is sent as its raw text.
// Otherwise the text, id and metadata are YAML-encoded into one structured
// string; if encoding fails, the raw text is used instead.
func formatCohereRerankDocument(doc schemas.RerankDocument) string {
	hasID, hasMeta := doc.ID != nil, len(doc.Meta) > 0
	if !hasID && !hasMeta {
		return doc.Text
	}

	// Keep metadata/id available by encoding a structured string document.
	payload := map[string]interface{}{"text": doc.Text}
	if hasID {
		payload["id"] = *doc.ID
	}
	if hasMeta {
		payload["metadata"] = doc.Meta
	}

	if encoded, err := yaml.Marshal(payload); err == nil {
		return string(encoded)
	}
	return doc.Text
}

View File

@@ -0,0 +1,72 @@
package cohere
import (
"encoding/json"
"testing"
"github.com/maximhq/bifrost/core/schemas"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestCohereRerankResponseToBifrostRerankResponse verifies that results are
// ordered by descending score and that provider-returned documents are decoded,
// with unknown keys landing in Meta.
func TestCohereRerankResponseToBifrostRerankResponse(t *testing.T) {
	providerResponse := &CohereRerankResponse{
		ID: "rerank-response-id",
		Results: []CohereRerankResult{
			{
				Index:          1,
				RelevanceScore: 0.62,
				Document:       json.RawMessage(`{"text":"provider-doc-1","id":"doc-1","topic":"geography"}`),
			},
			{
				Index:          0,
				RelevanceScore: 0.91,
				Document:       json.RawMessage(`{"text":"provider-doc-0"}`),
			},
		},
	}

	got := providerResponse.ToBifrostRerankResponse(nil, false)
	require.NotNil(t, got)
	assert.Equal(t, "rerank-response-id", got.ID)
	require.Len(t, got.Results, 2)

	// The higher score (index 0) must come first.
	best, second := got.Results[0], got.Results[1]
	assert.Equal(t, 0, best.Index)
	assert.Equal(t, 1, second.Index)

	require.NotNil(t, best.Document)
	require.NotNil(t, second.Document)
	assert.Equal(t, "provider-doc-0", best.Document.Text)
	assert.Equal(t, "provider-doc-1", second.Document.Text)

	require.NotNil(t, second.Document.ID)
	assert.Equal(t, "doc-1", *second.Document.ID)
	assert.Equal(t, "geography", second.Document.Meta["topic"])
}
// TestCohereRerankResponseToBifrostRerankResponseReturnDocuments verifies that,
// with returnDocuments enabled, the request documents replace the documents
// echoed by the provider.
func TestCohereRerankResponseToBifrostRerankResponseReturnDocuments(t *testing.T) {
	sent := []schemas.RerankDocument{
		{Text: "request-doc-0"},
		{Text: "request-doc-1"},
	}
	providerResponse := &CohereRerankResponse{
		Results: []CohereRerankResult{
			{
				Index:          1,
				RelevanceScore: 0.62,
				Document:       json.RawMessage(`{"text":"provider-doc-1"}`),
			},
			{
				Index:          0,
				RelevanceScore: 0.91,
				Document:       json.RawMessage(`{"text":"provider-doc-0"}`),
			},
		},
	}

	got := providerResponse.ToBifrostRerankResponse(sent, true)
	require.NotNil(t, got)
	require.Len(t, got.Results, 2)

	best, second := got.Results[0], got.Results[1]
	require.NotNil(t, best.Document)
	require.NotNil(t, second.Document)

	// Ordering still follows the score.
	assert.Equal(t, 0, best.Index)
	assert.Equal(t, 1, second.Index)

	// Documents come from the request, not from the provider payload.
	assert.Equal(t, "request-doc-0", best.Document.Text)
	assert.Equal(t, "request-doc-1", second.Document.Text)
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,616 @@
package cohere
import (
"encoding/json"
"fmt"
"github.com/bytedance/sonic"
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
"github.com/maximhq/bifrost/core/schemas"
)
// Token defaults used when computing reasoning budgets.
const (
// NOTE(review): presumably the floor for a reasoning token budget — confirm against the call sites.
MinimumReasoningMaxTokens = 1
DefaultCompletionMaxTokens = 4096 // Only used for relative reasoning max token calculation - not passed in body by default
)
// Inclusive bounds on the length of the text sent to the tokenize endpoint,
// in characters; the count-tokens conversion checks them against the rune count
// of the trimmed text.
// See https://docs.cohere.com/reference/tokenize#request
const (
cohereTokenizeMinTextLength = 1
cohereTokenizeMaxTextLength = 65536
)
// ==================== REQUEST TYPES ====================
// CohereChatRequest represents a Cohere chat completion request
type CohereChatRequest struct {
Model string `json:"model"` // Required: Model to use for chat completion
Messages []CohereMessage `json:"messages"` // Required: Array of message objects
Tools []CohereChatRequestTool `json:"tools,omitempty"` // Optional: Tools available for the model
ToolChoice *CohereToolChoice `json:"tool_choice,omitempty"` // Optional: Tool choice configuration
Temperature *float64 `json:"temperature,omitempty"` // Optional: Sampling temperature
P *float64 `json:"p,omitempty"` // Optional: Top-p sampling
K *int `json:"k,omitempty"` // Optional: Top-k sampling
MaxTokens *int `json:"max_tokens,omitempty"` // Optional: Maximum tokens to generate
StopSequences []string `json:"stop_sequences,omitempty"` // Optional: Stop sequences
FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"` // Optional: Frequency penalty
PresencePenalty *float64 `json:"presence_penalty,omitempty"` // Optional: Presence penalty
Stream *bool `json:"stream,omitempty"` // Optional: Enable streaming
SafetyMode *string `json:"safety_mode,omitempty"` // Optional: Safety mode
LogProbs *bool `json:"log_probs,omitempty"` // Optional: Log probabilities
StrictToolChoice *bool `json:"strict_tool_choice,omitempty"` // Optional: Strict tool choice
Thinking *CohereThinking `json:"thinking,omitempty"` // Optional: Reasoning configuration
ResponseFormat *CohereResponseFormat `json:"response_format,omitempty"` // Optional: Format for the response
ExtraParams map[string]interface{} `json:"-"` // Optional: Extra parameters
}
// IsStreamingRequested implements the StreamingRequest interface.
// It reports true only when Stream is set and points to true.
func (r *CohereChatRequest) IsStreamingRequested() bool {
	if r.Stream == nil {
		return false
	}
	return *r.Stream
}
// GetExtraParams returns the chat request's ExtraParams map as stored (no copy is made).
// The receiver must be non-nil.
func (r *CohereChatRequest) GetExtraParams() map[string]interface{} {
return r.ExtraParams
}
// CohereChatRequestTool is one entry of CohereChatRequest.Tools: a typed wrapper
// around a function definition.
type CohereChatRequestTool struct {
Type string `json:"type"` // always "function"
Function CohereChatRequestFunction `json:"function"`
}
// CohereChatRequestFunction describes a function exposed to the model as a tool.
type CohereChatRequestFunction struct {
Name string `json:"name"` // Function name
Parameters interface{} `json:"parameters,omitempty"` // Function parameters; untyped, so any JSON-serializable value is accepted (presumably a JSON Schema object — confirm against callers)
Description *string `json:"description,omitempty"` // Optional: Function description
}
// CohereMessage represents a message in Cohere format
type CohereMessage struct {
Role string `json:"role"` // Required: Message role (system, user, assistant, tool)
Content *CohereMessageContent `json:"content,omitempty"` // Optional: Message content (string or array of content blocks)
ToolCalls []CohereToolCall `json:"tool_calls,omitempty"` // Optional: Tool calls (for assistant messages)
ToolCallID *string `json:"tool_call_id,omitempty"` // Optional: Tool call ID (for tool messages)
ToolPlan *string `json:"tool_plan,omitempty"` // Optional: Chain-of-thought style reflection (assistant only)
}
// CohereMessageContent represents flexible content that can be string or content blocks
type CohereMessageContent struct {
// Use custom marshaling to handle string or []CohereContentBlock
StringContent *string `json:"-"`
BlocksContent []CohereContentBlock `json:"-"`
}
// MarshalJSON implements custom JSON marshaling for CohereMessageContent.
// String content takes precedence over block content; when neither is set the
// value encodes as JSON null.
func (c *CohereMessageContent) MarshalJSON() ([]byte, error) {
	switch {
	case c.StringContent != nil:
		return providerUtils.MarshalSorted(*c.StringContent)
	case c.BlocksContent != nil:
		return providerUtils.MarshalSorted(c.BlocksContent)
	default:
		return []byte("null"), nil
	}
}
// UnmarshalJSON implements custom JSON unmarshaling for CohereMessageContent.
// The payload is decoded as a string first and, failing that, as an array of
// content blocks; anything else is rejected with an error.
func (c *CohereMessageContent) UnmarshalJSON(data []byte) error {
	// Plain string form.
	var text string
	if sonic.Unmarshal(data, &text) == nil {
		c.StringContent = &text
		return nil
	}

	// Array-of-blocks form.
	var parsed []CohereContentBlock
	if sonic.Unmarshal(data, &parsed) != nil {
		return fmt.Errorf("content must be either string or array of content blocks")
	}
	c.BlocksContent = parsed
	return nil
}
// Helper methods for CohereMessageContent
// NewStringContent creates a CohereMessageContent holding the given string.
func NewStringContent(content string) *CohereMessageContent {
	wrapped := new(CohereMessageContent)
	wrapped.StringContent = &content
	return wrapped
}
// NewBlocksContent creates a CohereMessageContent holding the given content blocks.
// The slice is stored as passed, without copying.
func NewBlocksContent(blocks []CohereContentBlock) *CohereMessageContent {
	wrapped := new(CohereMessageContent)
	wrapped.BlocksContent = blocks
	return wrapped
}
// IsString returns true if content is a string
func (c *CohereMessageContent) IsString() bool {
return c.StringContent != nil
}
// IsBlocks returns true if content is content blocks
func (c *CohereMessageContent) IsBlocks() bool {
return c.BlocksContent != nil
}
// GetString returns the string content (nil if not string)
func (c *CohereMessageContent) GetString() *string {
return c.StringContent
}
// GetBlocks returns the content blocks (nil if not blocks)
func (c *CohereMessageContent) GetBlocks() []CohereContentBlock {
return c.BlocksContent
}
// CohereContentBlockType is the discriminator stored in CohereContentBlock.Type.
type CohereContentBlockType string
// Content block type values as they appear on the wire.
const (
CohereContentBlockTypeText CohereContentBlockType = "text"
CohereContentBlockTypeImage CohereContentBlockType = "image_url"
CohereContentBlockTypeThinking CohereContentBlockType = "thinking"
CohereContentBlockTypeDocument CohereContentBlockType = "document"
)
// CohereContentBlock represents a content block in Cohere format
// This is a union type that can be text, image_url, thinking, or document
type CohereContentBlock struct {
Type CohereContentBlockType `json:"type"` // Required: Content block type
// Text content block
Text *string `json:"text,omitempty"`
// Image URL content block
ImageURL *CohereImageURL `json:"image_url,omitempty"`
// Thinking content block (assistant only)
Thinking *string `json:"thinking,omitempty"`
// Document content block (tool messages)
Document *CohereDocument `json:"document,omitempty"`
}
// CohereImageURL represents an image URL content block
type CohereImageURL struct {
URL string `json:"url"` // Required: Image URL
}
// CohereDocument represents a document content block
type CohereDocument struct {
Data schemas.OrderedMap `json:"data"` // Required: Document data as key-value pairs
ID *string `json:"id,omitempty"` // Optional: Document ID for citations
}
// CohereThinking represents reasoning configuration
type CohereThinking struct {
Type CohereThinkingType `json:"type"` // Required: Reasoning type (enabled, disabled)
TokenBudget *int `json:"token_budget,omitempty"` // Optional: Maximum thinking tokens (>=1)
}
// CohereThinkingType represents the type of reasoning
type CohereThinkingType string
const (
ThinkingTypeEnabled CohereThinkingType = "enabled"
ThinkingTypeDisabled CohereThinkingType = "disabled"
)
// CohereResponseFormat represents the response format configuration for Cohere chat requests
// NOTE(review): the schema is serialized under the "schema" key. Cohere's v2 chat
// reference appears to name this field "json_schema" — confirm which API version
// this struct targets before relying on structured output.
type CohereResponseFormat struct {
Type CohereResponseFormatType `json:"type"` // Required: Response format type
JSONSchema *interface{} `json:"schema,omitempty"` // Optional: JSON schema for structured output (not used when type is "text")
}
// CohereResponseFormatType represents the type of response format
type CohereResponseFormatType string
const (
ResponseFormatTypeText CohereResponseFormatType = "text"
ResponseFormatTypeJSONObject CohereResponseFormatType = "json_object"
)
// CohereToolChoice represents tool choice configuration
type CohereToolChoice string
const (
ToolChoiceRequired CohereToolChoice = "REQUIRED"
ToolChoiceNone CohereToolChoice = "NONE"
ToolChoiceAuto CohereToolChoice = "AUTO"
)
// CohereToolCall represents a tool call in Cohere format
type CohereToolCall struct {
ID *string `json:"id,omitempty"` // Optional: Tool call ID
Type string `json:"type"` // Required: Tool call type (must be "function")
Function *CohereFunction `json:"function"` // Required: Function call details
}
// CohereFunction represents a function call
type CohereFunction struct {
Name *string `json:"name,omitempty"` // Optional: Function name
Arguments string `json:"arguments,omitempty"` // Optional: Function arguments (JSON string)
}
// CohereParameterDefinition represents a parameter definition for a Cohere tool.
// It defines the type, description, and whether the parameter is required.
type CohereParameterDefinition struct {
Type string `json:"type"` // Type of the parameter
Description *string `json:"description,omitempty"` // Optional description of the parameter
Required bool `json:"required"` // Whether the parameter is required
}
// CohereTool represents a tool definition for the Cohere API.
// It includes the tool's name, description, and parameter definitions.
type CohereTool struct {
Name string `json:"name"` // Name of the tool
Description string `json:"description"` // Description of the tool
ParameterDefinitions map[string]CohereParameterDefinition `json:"parameter_definitions"` // Definitions of the tool's parameters
}
// CohereCountTokensRequest represents a Cohere tokenize request
type CohereCountTokensRequest struct {
Model string `json:"model"` // Required: Model whose tokenizer should be used
Text string `json:"text"` // Required: Text to tokenize (1-65536 chars)
ExtraParams map[string]interface{} `json:"-"` // Optional: Extra parameters
}
// GetExtraParams returns the tokenize request's ExtraParams map as stored (no copy is made).
// The receiver must be non-nil.
func (r *CohereCountTokensRequest) GetExtraParams() map[string]interface{} {
return r.ExtraParams
}
// CohereEmbeddingRequest represents a Cohere embedding request
type CohereEmbeddingRequest struct {
Model string `json:"model"` // Required: ID of embedding model
InputType string `json:"input_type"` // Required: Type of input for v3+ models
Texts []string `json:"texts,omitempty"` // Optional: Array of strings to embed (max 96)
Images []string `json:"images,omitempty"` // Optional: Array of image data URIs (max 1)
Inputs []CohereEmbeddingInput `json:"inputs,omitempty"` // Optional: Array of mixed text/image inputs (max 96)
MaxTokens *int `json:"max_tokens,omitempty"` // Optional: Max tokens to embed per input
OutputDimension *int `json:"output_dimension,omitempty"` // Optional: Embedding dimensions (256, 512, 1024, 1536)
EmbeddingTypes []string `json:"embedding_types,omitempty"` // Optional: Types of embeddings to return
Truncate *string `json:"truncate,omitempty"` // Optional: How to handle long inputs
ExtraParams map[string]interface{} `json:"-"` // Optional: Extra parameters
}
// GetExtraParams returns the embedding request's ExtraParams map as stored (no copy is made).
// The receiver must be non-nil.
func (r *CohereEmbeddingRequest) GetExtraParams() map[string]interface{} {
return r.ExtraParams
}
// CohereEmbeddingInput represents a mixed text/image input
type CohereEmbeddingInput struct {
Content []CohereContentBlock `json:"content"` // Required: Array of content blocks (reuses chat content blocks)
}
// CohereEmbeddingResponse represents a Cohere embedding response
type CohereEmbeddingResponse struct {
ID string `json:"id"` // Response ID
Embeddings *CohereEmbeddingData `json:"embeddings,omitempty"` // Embedding data object
ResponseType *string `json:"response_type,omitempty"` // Response type (embeddings_floats, embeddings_by_type)
Texts []string `json:"texts,omitempty"` // Original text entries
Images []CohereEmbeddingImageInfo `json:"images,omitempty"` // Original image entries
Meta *CohereEmbeddingMeta `json:"meta,omitempty"` // Response metadata
}
// CohereEmbeddingData holds the returned embeddings, one field per encoding.
// Every field is optional; for the numeric encodings the outer slice has one
// entry per embedded input and the inner slice is that input's vector.
type CohereEmbeddingData struct {
Float [][]float64 `json:"float,omitempty"` // Float embeddings
Int8 [][]int8 `json:"int8,omitempty"` // Int8 embeddings
Uint8 [][]uint8 `json:"uint8,omitempty"` // Uint8 embeddings
Binary [][]int8 `json:"binary,omitempty"` // Binary embeddings
Ubinary [][]uint8 `json:"ubinary,omitempty"` // Unsigned binary embeddings
Base64 []string `json:"base64,omitempty"` // Base64 embeddings
}
// CohereEmbeddingImageInfo describes one image echoed back in
// CohereEmbeddingResponse.Images. All four fields are always serialized.
type CohereEmbeddingImageInfo struct {
Width int64 `json:"width"` // Width in pixels
Height int64 `json:"height"` // Height in pixels
Format string `json:"format"` // Image format
BitDepth int64 `json:"bit_depth"` // Bit depth
}
// CohereEmbeddingMeta is the optional "meta" object of an embedding response:
// API version details, billing, token usage and warnings. Every field is
// optional and nil/empty when the API omits it.
type CohereEmbeddingMeta struct {
APIVersion *CohereEmbeddingAPIVersion `json:"api_version,omitempty"` // API version info
BilledUnits *CohereBilledUnits `json:"billed_units,omitempty"` // Billing information
Tokens *CohereTokenUsage `json:"tokens,omitempty"` // Token usage
Warnings []string `json:"warnings,omitempty"` // Any warnings
}
// CohereEmbeddingAPIVersion is the "api_version" object inside
// CohereEmbeddingMeta. Pointer fields distinguish "absent" (nil) from an
// explicit false or empty value.
type CohereEmbeddingAPIVersion struct {
Version *string `json:"version,omitempty"` // API version
IsDeprecated *bool `json:"is_deprecated,omitempty"` // Deprecation status
IsExperimental *bool `json:"is_experimental,omitempty"` // Experimental status
}
// ==================== RESPONSE TYPES ====================
// CohereCountTokensResponse represents the response from the tokenize endpoint.
// Tokens is always serialized; TokenStrings and Meta are optional.
type CohereCountTokensResponse struct {
Tokens []int `json:"tokens"` // Token IDs
TokenStrings []string `json:"token_strings,omitempty"` // String form of the tokens, when returned
Meta *CohereTokenizeMeta `json:"meta,omitempty"` // Response metadata, when returned
}
// CohereTokenizeMeta captures metadata returned by the tokenize endpoint.
// Only the API version object is modeled here.
type CohereTokenizeMeta struct {
APIVersion *CohereTokenizeAPIVersion `json:"api_version,omitempty"` // API version info, nil when absent
}
// CohereTokenizeAPIVersion describes API version metadata for the tokenize
// endpoint (the "api_version" object inside CohereTokenizeMeta).
type CohereTokenizeAPIVersion struct {
Version *string `json:"version,omitempty"` // API version string, nil when absent
}
// CohereChatResponse represents a non-streaming Cohere chat completion
// response. Only ID is always serialized; the remaining fields are optional.
type CohereChatResponse struct {
ID string `json:"id"` // Unique identifier for the generated reply
FinishReason *CohereFinishReason `json:"finish_reason,omitempty"` // Reason for completion
Message *CohereMessage `json:"message,omitempty"` // Generated message from assistant
Usage *CohereUsage `json:"usage,omitempty"` // Token usage information
LogProbs []CohereLogProb `json:"logprobs,omitempty"` // Log probabilities (if requested)
}
// CohereFinishReason represents the reason a chat request has finished.
// The constants below are the string values used on the wire.
type CohereFinishReason string
const (
FinishReasonComplete CohereFinishReason = "COMPLETE" // Model finished sending complete message
FinishReasonStopSequence CohereFinishReason = "STOP_SEQUENCE" // Stop sequence was reached
FinishReasonMaxTokens CohereFinishReason = "MAX_TOKENS" // Max tokens exceeded
FinishReasonToolCall CohereFinishReason = "TOOL_CALL" // Model generated tool call
FinishReasonError CohereFinishReason = "ERROR" // Generation failed due to internal error
FinishReasonTimeout CohereFinishReason = "TIMEOUT" // Timeout
)
// CohereUsage represents token usage information attached to a chat response
// or to a streaming delta. All fields are optional.
type CohereUsage struct {
BilledUnits *CohereBilledUnits `json:"billed_units,omitempty"` // Billed usage information
Tokens *CohereTokenUsage `json:"tokens,omitempty"` // Token usage details
CachedTokens *int `json:"cached_tokens,omitempty"` // Cached tokens
}
// CohereBilledUnits represents billed usage information. Each counter is a
// pointer so an absent value (nil) can be told apart from an explicit zero.
type CohereBilledUnits struct {
InputTokens *int `json:"input_tokens,omitempty"` // Number of billed input tokens
OutputTokens *int `json:"output_tokens,omitempty"` // Number of billed output tokens
SearchUnits *int `json:"search_units,omitempty"` // Number of billed search units
Classifications *int `json:"classifications,omitempty"` // Number of billed classification units
}
// CohereTokenUsage represents detailed token usage.
// Unlike CohereBilledUnits, these tags carry no omitempty, so a nil pointer is
// serialized as JSON null rather than being dropped.
type CohereTokenUsage struct {
InputTokens *int `json:"input_tokens"` // Number of input tokens used
OutputTokens *int `json:"output_tokens"` // Number of output tokens produced
}
// CohereLogProb represents log probability information for one text chunk.
// TokenIDs is always serialized; Text and LogProbs are optional.
type CohereLogProb struct {
TokenIDs []int `json:"token_ids"` // Token IDs of each token in text chunk
Text *string `json:"text,omitempty"` // Text chunk for log probabilities
LogProbs []float64 `json:"logprobs,omitempty"` // Log probability of each token
}
// CohereCitationType is the value of CohereCitation.Type; it names the kind
// of content a citation is attached to.
type CohereCitationType string
const (
CitationTypeTextContent CohereCitationType = "TEXT_CONTENT" // Citation type "TEXT_CONTENT"
CitationTypeThinkingContent CohereCitationType = "THINKING_CONTENT" // Citation type "THINKING_CONTENT"
CitationTypePlan CohereCitationType = "PLAN" // Citation type "PLAN"
)
// CohereSourceType is the value of CohereSource.Type; it selects which of the
// source's payload fields (tool output or document) is relevant.
type CohereSourceType string
const (
SourceTypeTool CohereSourceType = "tool" // Source produced by a tool call
SourceTypeDocument CohereSourceType = "document" // Source taken from a document
)
// CohereCitation represents a citation in the response: a span of generated
// text (Start/End/Text) together with the sources backing it. Every field
// except Sources is always serialized.
type CohereCitation struct {
Start int `json:"start"` // Start position of cited text
End int `json:"end"` // End position of cited text
Text string `json:"text"` // Cited text
Sources []CohereSource `json:"sources,omitempty"` // Citation sources
ContentIndex int `json:"content_index"` // Content index of the citation
Type CohereCitationType `json:"type"` // Type of citation
}
// CohereSource represents a citation source. ToolOutput and Document are kept
// as raw JSON so their bytes are passed through without being re-encoded.
type CohereSource struct {
Type CohereSourceType `json:"type"` // Source type ("tool" or "document")
ID *string `json:"id,omitempty"` // Source ID (nullable)
ToolOutput *json.RawMessage `json:"tool_output,omitempty"` // Tool output (for tool sources, json.RawMessage preserves key ordering)
Document *json.RawMessage `json:"document,omitempty"` // Document data (for document sources, json.RawMessage preserves key ordering)
}
// ==================== STREAMING TYPES ====================
// CohereStreamEventType represents the type of streaming event, i.e. the
// value of CohereStreamEvent.Type. The constants are the wire strings.
type CohereStreamEventType string
const (
StreamEventMessageStart CohereStreamEventType = "message-start" // Start of a message
StreamEventContentStart CohereStreamEventType = "content-start" // Start of a content block
StreamEventContentDelta CohereStreamEventType = "content-delta" // Incremental content
StreamEventContentEnd CohereStreamEventType = "content-end" // End of a content block
StreamEventToolPlanDelta CohereStreamEventType = "tool-plan-delta" // Incremental tool plan text
StreamEventToolCallStart CohereStreamEventType = "tool-call-start" // Start of a tool call
StreamEventToolCallDelta CohereStreamEventType = "tool-call-delta" // Incremental tool call data
StreamEventToolCallEnd CohereStreamEventType = "tool-call-end" // End of a tool call
StreamEventCitationStart CohereStreamEventType = "citation-start" // Start of a citation
StreamEventCitationEnd CohereStreamEventType = "citation-end" // End of a citation
StreamEventMessageEnd CohereStreamEventType = "message-end" // End of the message
StreamEventDebug CohereStreamEventType = "debug" // Debug event
)
// CohereStreamEvent represents a unified streaming event from Cohere API.
// One struct covers every event type; Type says which of the optional fields
// are meaningful for a given event.
type CohereStreamEvent struct {
Type CohereStreamEventType `json:"type"` // Event discriminator
ID *string `json:"id,omitempty"` // For message-start
Index *int `json:"index,omitempty"` // For indexed events
Delta *CohereStreamDelta `json:"delta,omitempty"` // Event payload, nil when the event carries none
}
// CohereStreamDelta represents the delta content in streaming events.
// All fields are optional; which ones are set depends on the event type.
type CohereStreamDelta struct {
Message *CohereStreamMessage `json:"message,omitempty"` // Partial message data
FinishReason *CohereFinishReason `json:"finish_reason,omitempty"` // Why generation finished, when reported
Usage *CohereUsage `json:"usage,omitempty"` // Token usage, when reported
}
// CohereStreamToolCallStruct models the streaming "tool_calls" field, which
// may arrive either as a single object or as an array. At most one of the two
// members is populated after decoding.
type CohereStreamToolCallStruct struct {
	CohereToolCallObject *CohereToolCall  // Set when the field was a single object
	CohereToolCallArray  []CohereToolCall // Set when the field was an array
}

// MarshalJSON emits whichever representation is populated, preferring the
// single object over the array, and falls back to marshaling nil when neither
// is set.
func (c *CohereStreamToolCallStruct) MarshalJSON() ([]byte, error) {
	switch {
	case c.CohereToolCallObject != nil:
		return providerUtils.MarshalSorted(c.CohereToolCallObject)
	case c.CohereToolCallArray != nil:
		return providerUtils.MarshalSorted(c.CohereToolCallArray)
	default:
		return providerUtils.MarshalSorted(nil)
	}
}

// UnmarshalJSON accepts JSON null (leaving the receiver untouched), an array
// of tool calls, or a single tool call object. The array form is attempted
// first; an error is returned only when both decodings fail.
func (c *CohereStreamToolCallStruct) UnmarshalJSON(data []byte) error {
	if string(data) == "null" {
		return nil
	}

	var many []CohereToolCall
	if sonic.Unmarshal(data, &many) == nil {
		c.CohereToolCallArray = many
		return nil
	}

	var single CohereToolCall
	if sonic.Unmarshal(data, &single) != nil {
		return fmt.Errorf("tool_calls field is neither array nor object")
	}
	c.CohereToolCallObject = &single
	return nil
}
// CohereStreamContentStruct models the streaming "content" field, which may
// arrive either as a single object or as an array. At most one of the two
// members is populated after decoding.
type CohereStreamContentStruct struct {
	CohereStreamContentObject *CohereStreamContent  // Set when the field was a single object
	CohereStreamContentArray  []CohereStreamContent // Set when the field was an array
}

// MarshalJSON emits whichever representation is populated, preferring the
// single object over the array, and falls back to marshaling nil when neither
// is set.
func (c *CohereStreamContentStruct) MarshalJSON() ([]byte, error) {
	switch {
	case c.CohereStreamContentObject != nil:
		return providerUtils.MarshalSorted(c.CohereStreamContentObject)
	case c.CohereStreamContentArray != nil:
		return providerUtils.MarshalSorted(c.CohereStreamContentArray)
	default:
		return providerUtils.MarshalSorted(nil)
	}
}

// UnmarshalJSON accepts JSON null (leaving the receiver untouched), an array
// of content items, or a single content object. The array form is attempted
// first; an error is returned only when both decodings fail.
func (c *CohereStreamContentStruct) UnmarshalJSON(data []byte) error {
	if string(data) == "null" {
		return nil
	}

	var many []CohereStreamContent
	if sonic.Unmarshal(data, &many) == nil {
		c.CohereStreamContentArray = many
		return nil
	}

	var single CohereStreamContent
	if sonic.Unmarshal(data, &single) != nil {
		return fmt.Errorf("content field is neither array nor object")
	}
	c.CohereStreamContentObject = &single
	return nil
}
// CohereStreamCitationStruct models the streaming "citations" field, which
// may arrive either as a single object or as an array. At most one of the two
// members is populated after decoding.
type CohereStreamCitationStruct struct {
	CohereStreamCitationObject *CohereCitation  // Set when the field was a single object
	CohereStreamCitationArray  []CohereCitation // Set when the field was an array
}

// MarshalJSON emits whichever representation is populated, preferring the
// single object over the array, and falls back to marshaling nil when neither
// is set.
func (c *CohereStreamCitationStruct) MarshalJSON() ([]byte, error) {
	switch {
	case c.CohereStreamCitationObject != nil:
		return providerUtils.MarshalSorted(c.CohereStreamCitationObject)
	case c.CohereStreamCitationArray != nil:
		return providerUtils.MarshalSorted(c.CohereStreamCitationArray)
	default:
		return providerUtils.MarshalSorted(nil)
	}
}

// UnmarshalJSON accepts JSON null (leaving the receiver untouched), an array
// of citations, or a single citation object. The array form is attempted
// first; an error is returned only when both decodings fail.
func (c *CohereStreamCitationStruct) UnmarshalJSON(data []byte) error {
	if string(data) == "null" {
		return nil
	}

	var many []CohereCitation
	if sonic.Unmarshal(data, &many) == nil {
		c.CohereStreamCitationArray = many
		return nil
	}

	var single CohereCitation
	if sonic.Unmarshal(data, &single) != nil {
		return fmt.Errorf("citations field is neither array nor object")
	}
	c.CohereStreamCitationObject = &single
	return nil
}
// CohereStreamMessage represents the message part of streaming deltas.
// Content, ToolCalls and Citations use wrapper types whose UnmarshalJSON
// accepts either a single object or an array for the same field.
type CohereStreamMessage struct {
Role *string `json:"role,omitempty"` // For message-start
Content *CohereStreamContentStruct `json:"content,omitempty"` // For content events (object or array)
ToolPlan *string `json:"tool_plan,omitempty"` // For tool-plan-delta
ToolCalls *CohereStreamToolCallStruct `json:"tool_calls,omitempty"` // For tool-call events (object or array)
Citations *CohereStreamCitationStruct `json:"citations,omitempty"` // For citation events (object or array)
}
// CohereStreamContent represents content in streaming events.
// All fields are optional on the wire; Text and Thinking are pointers so an
// absent value can be told apart from an empty string.
type CohereStreamContent struct {
Type CohereContentBlockType `json:"type,omitempty"` // For content-start
Text *string `json:"text,omitempty"` // For content deltas
Thinking *string `json:"thinking,omitempty"` // For thinking deltas
}
// ==================== ERROR TYPES ====================
// CohereError represents an error response from the Cohere API.
// Type and Message are always serialized; Code is optional.
type CohereError struct {
Type string `json:"type"` // Error type
Message string `json:"message"` // Error message
Code *string `json:"code,omitempty"` // Optional error code
}
// ==================== MODEL TYPES ====================
// CohereModel is one entry of CohereListModelsResponse.Models. Field meanings
// below are read off the JSON tag names; all fields are always serialized.
type CohereModel struct {
Name string `json:"name"` // Model name
IsDeprecated bool `json:"is_deprecated"` // Deprecation flag
Endpoints []string `json:"endpoints"` // Endpoints listed for the model
Finetuned bool `json:"finetuned"` // Fine-tuned flag
ContextLength int `json:"context_length"` // Context length
TokenizerURL string `json:"tokenizer_url"` // Tokenizer URL
DefaultEndpoints []string `json:"default_endpoints"` // Default endpoints listed for the model
Features []string `json:"features"` // Feature names listed for the model
}
// CohereListModelsResponse is the JSON body of a list-models response: one
// page of models plus the token for requesting the next page.
type CohereListModelsResponse struct {
Models []CohereModel `json:"models"` // Models in this page
NextPageToken string `json:"next_page_token"` // Pagination token for the following page
}

View File

@@ -0,0 +1,293 @@
package cohere
import (
"encoding/json"
"github.com/maximhq/bifrost/core/schemas"
"github.com/tidwall/sjson"
)
var (
// cohereFinishReasonToBifrost maps Cohere finish reasons to Bifrost finish
// reason strings. FinishReasonError and FinishReasonTimeout have no entry, so
// ConvertCohereFinishReasonToBifrost returns their raw Cohere value unchanged.
cohereFinishReasonToBifrost = map[CohereFinishReason]string{
FinishReasonComplete: "stop",
FinishReasonStopSequence: "stop",
FinishReasonMaxTokens: "length",
FinishReasonToolCall: "tool_calls",
}
)
// ConvertCohereFinishReasonToBifrost translates a Cohere finish reason into
// the Bifrost finish reason string using cohereFinishReasonToBifrost.
// Reasons without a mapping are returned unchanged as their raw string value.
func ConvertCohereFinishReasonToBifrost(providerReason CohereFinishReason) string {
	mapped, found := cohereFinishReasonToBifrost[providerReason]
	if !found {
		return string(providerReason)
	}
	return mapped
}
// convertInterfaceToToolFunctionParameters converts a loosely typed,
// JSON-Schema-like value into schemas.ToolFunctionParameters.
//
// It returns nil when params is nil or is not a map[string]interface{}.
// Otherwise each recognized key is copied into the result only when its value
// has the expected dynamic type; keys with unexpected types are ignored.
// String lists ("required", "enum") keep only their string elements, and the
// schema lists ("anyOf", "oneOf", "allOf") keep only elements that
// schemas.SafeExtractOrderedMap accepts.
func convertInterfaceToToolFunctionParameters(params interface{}) *schemas.ToolFunctionParameters {
	// A nil interface fails the type assertion as well, so this single check
	// covers both "no params" and "params is not a map".
	paramsMap, isMap := params.(map[string]interface{})
	if !isMap {
		return nil
	}

	// str reads a string-valued key.
	str := func(key string) (string, bool) {
		s, ok := paramsMap[key].(string)
		return s, ok
	}

	// stringList reads a []interface{}-valued key, keeping string elements only.
	stringList := func(key string) ([]string, bool) {
		raw, ok := paramsMap[key].([]interface{})
		if !ok {
			return nil, false
		}
		out := make([]string, 0, len(raw))
		for _, item := range raw {
			if s, isStr := item.(string); isStr {
				out = append(out, s)
			}
		}
		return out, true
	}

	// schemaList reads a []interface{}-valued key, keeping map-like elements only.
	schemaList := func(key string) ([]schemas.OrderedMap, bool) {
		raw, ok := paramsMap[key].([]interface{})
		if !ok {
			return nil, false
		}
		out := make([]schemas.OrderedMap, 0, len(raw))
		for _, item := range raw {
			if m, isOrdered := schemas.SafeExtractOrderedMap(item); isOrdered {
				out = append(out, *m)
			}
		}
		return out, true
	}

	result := &schemas.ToolFunctionParameters{}

	// Core descriptors.
	if s, ok := str("type"); ok {
		result.Type = s
	}
	if s, ok := str("description"); ok {
		result.Description = &s
	}
	if list, ok := stringList("required"); ok {
		result.Required = list
	}
	if props, ok := schemas.SafeExtractOrderedMap(paramsMap["properties"]); ok {
		result.Properties = props
	}
	if list, ok := stringList("enum"); ok {
		result.Enum = list
	}

	// additionalProperties may be a boolean or a schema object; the object
	// form is checked second and therefore wins if both checks succeed.
	if flag, ok := paramsMap["additionalProperties"].(bool); ok {
		result.AdditionalProperties = &schemas.AdditionalPropertiesStruct{
			AdditionalPropertiesBool: &flag,
		}
	}
	if schema, ok := schemas.SafeExtractOrderedMap(paramsMap["additionalProperties"]); ok {
		result.AdditionalProperties = &schemas.AdditionalPropertiesStruct{
			AdditionalPropertiesMap: schema,
		}
	}

	// Schema definitions and references: "$defs" (JSON Schema draft 2019-09+)
	// and "definitions" (legacy draft-07).
	if defs, ok := schemas.SafeExtractOrderedMap(paramsMap["$defs"]); ok {
		result.Defs = defs
	}
	if defs, ok := schemas.SafeExtractOrderedMap(paramsMap["definitions"]); ok {
		result.Definitions = defs
	}
	if s, ok := str("$ref"); ok {
		result.Ref = &s
	}

	// Array constraints.
	if items, ok := schemas.SafeExtractOrderedMap(paramsMap["items"]); ok {
		result.Items = items
	}
	if n, ok := extractInt64(paramsMap["minItems"]); ok {
		result.MinItems = &n
	}
	if n, ok := extractInt64(paramsMap["maxItems"]); ok {
		result.MaxItems = &n
	}

	// Schema composition.
	if list, ok := schemaList("anyOf"); ok {
		result.AnyOf = list
	}
	if list, ok := schemaList("oneOf"); ok {
		result.OneOf = list
	}
	if list, ok := schemaList("allOf"); ok {
		result.AllOf = list
	}

	// String constraints.
	if s, ok := str("format"); ok {
		result.Format = &s
	}
	if s, ok := str("pattern"); ok {
		result.Pattern = &s
	}
	if n, ok := extractInt64(paramsMap["minLength"]); ok {
		result.MinLength = &n
	}
	if n, ok := extractInt64(paramsMap["maxLength"]); ok {
		result.MaxLength = &n
	}

	// Numeric constraints.
	if f, ok := extractFloat64(paramsMap["minimum"]); ok {
		result.Minimum = &f
	}
	if f, ok := extractFloat64(paramsMap["maximum"]); ok {
		result.Maximum = &f
	}

	// Miscellaneous annotations. "default" is copied whenever the key exists,
	// whatever its value (including nil).
	if s, ok := str("title"); ok {
		result.Title = &s
	}
	if def, present := paramsMap["default"]; present {
		result.Default = def
	}
	if flag, ok := paramsMap["nullable"].(bool); ok {
		result.Nullable = &flag
	}

	return result
}
// extractInt64 converts a dynamically typed numeric value to int64.
//
// Supported inputs are all built-in signed and unsigned integer types,
// float32/float64 and json.Number. The boolean result is false when v is not
// numeric or cannot be represented as an int64:
//   - unsigned values above the int64 maximum are rejected;
//   - NaN, ±Inf and floats outside the int64 range are rejected, because
//     converting them to an integer yields an implementation-defined value;
//   - in-range floats with a fractional part are truncated toward zero.
func extractInt64(v interface{}) (int64, bool) {
	const maxInt64 = 1<<63 - 1

	switch val := v.(type) {
	case int:
		return int64(val), true
	case int8:
		return int64(val), true
	case int16:
		return int64(val), true
	case int32:
		return int64(val), true
	case int64:
		return val, true
	case uint8:
		return int64(val), true
	case uint16:
		return int64(val), true
	case uint32:
		return int64(val), true
	case uint:
		if uint64(val) > maxInt64 {
			return 0, false
		}
		return int64(val), true
	case uint64:
		if val > maxInt64 {
			return 0, false
		}
		return int64(val), true
	case float32:
		// Widening to float64 is exact, so the float64 checks apply as-is.
		return extractInt64(float64(val))
	case float64:
		// Written as a negated range check so that NaN (for which every
		// comparison is false) is rejected together with ±Inf and
		// out-of-range magnitudes.
		if !(val >= -(1<<63) && val < 1<<63) {
			return 0, false
		}
		return int64(val), true
	case json.Number:
		if n, err := val.Int64(); err == nil {
			return n, true
		}
		// Not an integer literal (e.g. "5.0" or "1e3"): fall back to the
		// float path and its range checks.
		f, err := val.Float64()
		if err != nil {
			return 0, false
		}
		return extractInt64(f)
	default:
		return 0, false
	}
}
// extractFloat64 converts a dynamically typed numeric value to float64.
//
// Supported inputs are float32/float64, all built-in signed and unsigned
// integer types, and json.Number. The boolean result is false when v is not
// numeric or, for json.Number, when its text cannot be parsed as a float.
// Integers with a magnitude above 2^53 may lose precision in the conversion.
func extractFloat64(v interface{}) (float64, bool) {
	switch val := v.(type) {
	case float64:
		return val, true
	case float32:
		return float64(val), true
	case int:
		return float64(val), true
	case int8:
		return float64(val), true
	case int16:
		return float64(val), true
	case int32:
		return float64(val), true
	case int64:
		return float64(val), true
	case uint:
		return float64(val), true
	case uint8:
		return float64(val), true
	case uint16:
		return float64(val), true
	case uint32:
		return float64(val), true
	case uint64:
		return float64(val), true
	case json.Number:
		f, err := val.Float64()
		if err != nil {
			return 0, false
		}
		return f, true
	default:
		return 0, false
	}
}
// convertResponseFormatToCohere converts an OpenAI-style response_format value
// into Cohere's typed CohereResponseFormat.
//
// Expected input shape: { type: "json_schema", json_schema: { schema: {...} } }.
// Output shape: { type: "json_object", json_schema: {...} } (the nested schema
// is lifted to the top level).
//
// It returns nil when responseFormat is nil, when the pointed-to value is not
// a map[string]interface{}, or when "type" is anything other than "text",
// "json_object" or "json_schema". For the two JSON types the schema is set
// only if json_schema.schema exists and is itself a map.
func convertResponseFormatToCohere(responseFormat *interface{}) *CohereResponseFormat {
	if responseFormat == nil {
		return nil
	}
	formatMap, isMap := (*responseFormat).(map[string]interface{})
	if !isMap {
		return nil
	}

	// A missing or non-string "type" yields "" and falls into the default case.
	typeStr, _ := formatMap["type"].(string)
	switch typeStr {
	case "text":
		return &CohereResponseFormat{Type: ResponseFormatTypeText}
	case "json_object", "json_schema":
		converted := &CohereResponseFormat{Type: ResponseFormatTypeJSONObject}
		// OpenAI format: { type: "json_schema", json_schema: { name: "X", strict: true, schema: {...} } }
		wrapper, hasWrapper := formatMap["json_schema"].(map[string]interface{})
		if !hasWrapper {
			return converted
		}
		schema, hasSchema := wrapper["schema"].(map[string]interface{})
		if !hasSchema {
			return converted
		}
		var schemaValue interface{} = schema
		converted.JSONSchema = &schemaValue
		return converted
	default:
		return nil
	}
}
// convertCohereResponseFormatToBifrost converts Cohere's typed response_format
// back into an interface{} holding a json.RawMessage.
//
// With a schema present the JSON is {"type":"json_schema","json_schema":<schema>};
// otherwise it is {"type":<cohereFormat.Type>}. Keys are written with sjson in
// a fixed order. It returns nil only when cohereFormat is nil.
//
// Errors from sjson and schemas.MarshalSorted are deliberately discarded here.
// NOTE(review): the schema is written directly under "json_schema", whereas
// convertResponseFormatToCohere reads it from json_schema.schema — confirm
// that consumers expect this flattened shape.
func convertCohereResponseFormatToBifrost(cohereFormat *CohereResponseFormat) *interface{} {
	if cohereFormat == nil {
		return nil
	}

	payload := []byte(`{}`)
	if cohereFormat.JSONSchema == nil {
		payload, _ = sjson.SetBytes(payload, "type", string(cohereFormat.Type))
	} else {
		payload, _ = sjson.SetBytes(payload, "type", "json_schema")
		rawSchema, _ := schemas.MarshalSorted(cohereFormat.JSONSchema)
		payload, _ = sjson.SetRawBytes(payload, "json_schema", rawSchema)
	}

	var wrapped interface{} = json.RawMessage(payload)
	return &wrapped
}

View File

@@ -0,0 +1,933 @@
package elevenlabs
import (
"bytes"
"context"
"errors"
"io"
"mime/multipart"
"net/http"
"net/url"
"path"
"strconv"
"strings"
"time"
"github.com/bytedance/sonic"
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
schemas "github.com/maximhq/bifrost/core/schemas"
"github.com/valyala/fasthttp"
)
// ElevenlabsProvider is the Bifrost provider for the Elevenlabs API.
// It holds two HTTP clients (unary and streaming), the network configuration
// captured at construction time, and the flags that control whether raw
// request/response payloads are attached to Bifrost responses.
type ElevenlabsProvider struct {
logger schemas.Logger // Logger for provider operations
client *fasthttp.Client // HTTP client for unary API requests (ReadTimeout bounds overall response)
streamingClient *fasthttp.Client // HTTP client for streaming API requests (no ReadTimeout; idle governed by NewIdleTimeoutReader)
networkConfig schemas.NetworkConfig // Network configuration including extra headers
sendBackRawRequest bool // Whether to include raw request in BifrostResponse
sendBackRawResponse bool // Whether to include raw response in BifrostResponse
customProviderConfig *schemas.CustomProviderConfig // Custom provider config
}
// NewElevenlabsProvider creates a new Elevenlabs provider instance.
//
// It applies config defaults, builds the unary fasthttp client (timeouts and
// connection limits taken from config.NetworkConfig), runs it through the
// proxy, dialer and TLS configuration helpers, and derives the streaming
// client from it. The base URL defaults to https://api.elevenlabs.io and is
// stored without trailing slashes.
//
// Note that config is modified in place: CheckAndSetDefaults is called on it
// and config.NetworkConfig.BaseURL is overwritten with the normalized URL.
func NewElevenlabsProvider(config *schemas.ProviderConfig, logger schemas.Logger) *ElevenlabsProvider {
	config.CheckAndSetDefaults()

	requestTimeout := time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds) * time.Second
	maxConnDuration := time.Duration(schemas.DefaultMaxConnDurationInSeconds) * time.Second

	client := &fasthttp.Client{
		ReadTimeout:         requestTimeout,
		WriteTimeout:        requestTimeout,
		MaxConnsPerHost:     config.NetworkConfig.MaxConnsPerHost,
		MaxIdleConnDuration: 30 * time.Second,
		MaxConnWaitTimeout:  requestTimeout,
		MaxConnDuration:     maxConnDuration,
		ConnPoolStrategy:    fasthttp.FIFO,
	}

	// Layer proxy, dialer and TLS settings onto the client, in that order.
	client = providerUtils.ConfigureProxy(client, config.ProxyConfig, logger)
	client = providerUtils.ConfigureDialer(client)
	client = providerUtils.ConfigureTLS(client, config.NetworkConfig, logger)

	streamingClient := providerUtils.BuildStreamingClient(client)

	// Fall back to the public API host, then strip trailing slashes so that
	// endpoint paths can be appended directly.
	baseURL := config.NetworkConfig.BaseURL
	if baseURL == "" {
		baseURL = "https://api.elevenlabs.io"
	}
	config.NetworkConfig.BaseURL = strings.TrimRight(baseURL, "/")

	provider := &ElevenlabsProvider{
		logger:               logger,
		client:               client,
		streamingClient:      streamingClient,
		networkConfig:        config.NetworkConfig,
		customProviderConfig: config.CustomProviderConfig,
		sendBackRawRequest:   config.SendBackRawRequest,
		sendBackRawResponse:  config.SendBackRawResponse,
	}
	return provider
}
// GetProviderKey returns the provider identifier for this instance, resolved
// by providerUtils.GetProviderName from schemas.Elevenlabs and the custom
// provider config.
func (provider *ElevenlabsProvider) GetProviderKey() schemas.ModelProvider {
	providerKey := providerUtils.GetProviderName(schemas.Elevenlabs, provider.customProviderConfig)
	return providerKey
}
// listModelsByKey performs a list-models request for a single key.
//
// It issues GET <BaseURL>/v1/models (the path can be overridden through the
// context), sending the key in the "xi-api-key" header when it is non-empty.
// A transport failure or a non-200 status is returned as a BifrostError.
// On success the decoded body is converted to a Bifrost response, and latency,
// provider response headers and — when enabled — the raw request/response are
// attached to its ExtraFields.
func (provider *ElevenlabsProvider) listModelsByKey(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) {
	// Pooled request/response objects; both go back to the pool on return.
	httpReq := fasthttp.AcquireRequest()
	httpResp := fasthttp.AcquireResponse()
	defer fasthttp.ReleaseRequest(httpReq)
	defer fasthttp.ReleaseResponse(httpResp)

	// Extra headers from the network config go on first.
	providerUtils.SetExtraHeaders(ctx, httpReq, provider.networkConfig.ExtraHeaders, nil)

	modelsURL := provider.networkConfig.BaseURL + providerUtils.GetPathFromContext(ctx, "/v1/models")
	httpReq.SetRequestURI(modelsURL)
	httpReq.Header.SetMethod(http.MethodGet)
	httpReq.Header.SetContentType("application/json")
	if apiKey := key.Value.GetValue(); apiKey != "" {
		httpReq.Header.Set("xi-api-key", apiKey)
	}

	latency, requestErr, wait := providerUtils.MakeRequestWithContext(ctx, provider.client, httpReq, httpResp)
	defer wait()
	if requestErr != nil {
		return nil, requestErr
	}

	// Publish the provider response headers on the context before any error
	// return so they are available on error paths too.
	ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerUtils.ExtractProviderResponseHeaders(httpResp))

	if httpResp.StatusCode() != fasthttp.StatusOK {
		return nil, parseElevenlabsError(httpResp)
	}

	var parsed ElevenlabsListModelsResponse
	rawRequest, rawResponse, parseErr := providerUtils.HandleProviderResponse(httpResp.Body(), &parsed, nil, providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse))
	if parseErr != nil {
		return nil, parseErr
	}

	result := parsed.ToBifrostListModelsResponse(provider.GetProviderKey(), key.Models, key.BlacklistedModels, key.Aliases, request.Unfiltered)
	result.ExtraFields.Latency = latency.Milliseconds()
	result.ExtraFields.ProviderResponseHeaders = providerUtils.ExtractProviderResponseHeaders(httpResp)

	// Raw payloads are attached only when enabled for this request.
	if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) {
		result.ExtraFields.RawRequest = rawRequest
	}
	if providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) {
		result.ExtraFields.RawResponse = rawResponse
	}
	return result, nil
}
// ListModels performs a list models request to Elevenlabs' API.
// After checking that the operation is allowed for this provider, it delegates
// to providerUtils.HandleMultipleListModelsRequests, which invokes
// listModelsByKey for the supplied keys.
func (provider *ElevenlabsProvider) ListModels(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) {
	notAllowed := providerUtils.CheckOperationAllowed(schemas.Elevenlabs, provider.customProviderConfig, schemas.ListModelsRequest)
	if notAllowed != nil {
		return nil, notAllowed
	}
	return providerUtils.HandleMultipleListModelsRequests(ctx, keys, request, provider.listModelsByKey)
}
// TextCompletion is not supported by the Elevenlabs provider; it always
// returns an unsupported-operation error and a nil response.
func (provider *ElevenlabsProvider) TextCompletion(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (*schemas.BifrostTextCompletionResponse, *schemas.BifrostError) {
	unsupported := providerUtils.NewUnsupportedOperationError(schemas.TextCompletionRequest, provider.GetProviderKey())
	return nil, unsupported
}
// TextCompletionStream is not supported by the Elevenlabs provider; it always
// returns an unsupported-operation error and a nil channel.
func (provider *ElevenlabsProvider) TextCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostTextCompletionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
	unsupported := providerUtils.NewUnsupportedOperationError(schemas.TextCompletionStreamRequest, provider.GetProviderKey())
	return nil, unsupported
}
// ChatCompletion is not supported by the Elevenlabs provider; it always
// returns an unsupported-operation error and a nil response.
func (provider *ElevenlabsProvider) ChatCompletion(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError) {
	unsupported := providerUtils.NewUnsupportedOperationError(schemas.ChatCompletionRequest, provider.GetProviderKey())
	return nil, unsupported
}
// ChatCompletionStream is not supported by the Elevenlabs provider; it always
// returns an unsupported-operation error and a nil channel.
func (provider *ElevenlabsProvider) ChatCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
	unsupported := providerUtils.NewUnsupportedOperationError(schemas.ChatCompletionStreamRequest, provider.GetProviderKey())
	return nil, unsupported
}
// Responses is not supported by the Elevenlabs provider; it always returns an
// unsupported-operation error and a nil response.
func (provider *ElevenlabsProvider) Responses(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) {
	unsupported := providerUtils.NewUnsupportedOperationError(schemas.ResponsesRequest, provider.GetProviderKey())
	return nil, unsupported
}
// ResponsesStream is not supported by the Elevenlabs provider; it always
// returns an unsupported-operation error and a nil channel.
func (provider *ElevenlabsProvider) ResponsesStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
	unsupported := providerUtils.NewUnsupportedOperationError(schemas.ResponsesStreamRequest, provider.GetProviderKey())
	return nil, unsupported
}
// Embedding is not supported by the Elevenlabs provider; it always returns an
// unsupported-operation error and a nil response.
func (provider *ElevenlabsProvider) Embedding(ctx *schemas.BifrostContext, key schemas.Key, input *schemas.BifrostEmbeddingRequest) (*schemas.BifrostEmbeddingResponse, *schemas.BifrostError) {
	unsupported := providerUtils.NewUnsupportedOperationError(schemas.EmbeddingRequest, provider.GetProviderKey())
	return nil, unsupported
}
// Speech performs a text-to-speech request against the Elevenlabs API.
//
// The voice ID from request.Params.VoiceConfig.Voice is required and becomes
// part of the URL path (/v1/text-to-speech/<voice>, with a /with-timestamps
// suffix when request.Params.WithTimestamps is true). A missing or blank voice
// yields a "voice parameter is required" error without contacting the API.
//
// On success the response carries either:
//   - Audio: the raw audio bytes (plain request), or
//   - AudioBase64 plus optional Alignment / NormalizedAlignment
//     (with-timestamps request).
//
// Every failure after the request body has been built is passed through
// providerUtils.EnrichError together with the request JSON.
func (provider *ElevenlabsProvider) Speech(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostSpeechRequest) (*schemas.BifrostSpeechResponse, *schemas.BifrostError) {
	if err := providerUtils.CheckOperationAllowed(schemas.Elevenlabs, provider.customProviderConfig, schemas.SpeechRequest); err != nil {
		return nil, err
	}

	// Pooled request/response objects; both go back to the pool on return.
	req := fasthttp.AcquireRequest()
	resp := fasthttp.AcquireResponse()
	defer fasthttp.ReleaseRequest(req)
	defer fasthttp.ReleaseResponse(resp)

	// Set any extra headers from network config
	providerUtils.SetExtraHeaders(ctx, req, provider.networkConfig.ExtraHeaders, nil)

	params := request.Params
	withTimestampsRequest := params != nil && params.WithTimestamps != nil && *params.WithTimestamps

	// The voice ID is mandatory; a blank value would otherwise produce a
	// request to the bare /v1/text-to-speech/ path.
	if params == nil || params.VoiceConfig == nil || params.VoiceConfig.Voice == nil || strings.TrimSpace(*params.VoiceConfig.Voice) == "" {
		return nil, providerUtils.NewBifrostOperationError("voice parameter is required", nil)
	}

	// The voice ID is caller-supplied and lands in the URL path, so escape it
	// to keep characters such as '/', '?' or '#' from altering the target
	// endpoint. Ordinary alphanumeric IDs are unchanged by escaping.
	endpoint := "/v1/text-to-speech/" + url.PathEscape(*params.VoiceConfig.Voice)
	if withTimestampsRequest {
		endpoint += "/with-timestamps"
	}

	requestURL := provider.buildBaseSpeechRequestURL(ctx, endpoint, schemas.SpeechRequest, request)
	req.SetRequestURI(requestURL)
	req.Header.SetMethod(http.MethodPost)
	req.Header.SetContentType("application/json")
	if apiKey := key.Value.GetValue(); apiKey != "" {
		req.Header.Set("xi-api-key", apiKey)
	}

	jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody(
		ctx,
		request,
		func() (providerUtils.RequestBodyWithExtraParams, error) {
			return ToElevenlabsSpeechRequest(request), nil
		})
	if bifrostErr != nil {
		return nil, bifrostErr
	}
	// Use the in-memory JSON body unless the large-payload helper has already
	// attached a body to the request.
	if !providerUtils.ApplyLargePayloadRequestBody(ctx, req) {
		req.SetBody(jsonData)
	}

	// Make request
	latency, bifrostErr, wait := providerUtils.MakeRequestWithContext(ctx, provider.client, req, resp)
	defer wait()
	if bifrostErr != nil {
		return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
	}

	// Extract and set provider response headers so they're available on error paths
	ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerUtils.ExtractProviderResponseHeaders(resp))

	// Handle error response
	if resp.StatusCode() != fasthttp.StatusOK {
		return nil, providerUtils.EnrichError(ctx, parseElevenlabsError(resp), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
	}

	// Get the response body
	body, err := providerUtils.CheckAndDecodeBody(resp)
	if err != nil {
		return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
	}

	bifrostResponse := &schemas.BifrostSpeechResponse{
		ExtraFields: schemas.BifrostResponseExtraFields{
			Latency:                 latency.Milliseconds(),
			ProviderResponseHeaders: providerUtils.ExtractProviderResponseHeaders(resp),
		},
	}
	if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) {
		providerUtils.ParseAndSetRawRequest(&bifrostResponse.ExtraFields, jsonData)
	}

	if withTimestampsRequest {
		var timestampResponse ElevenlabsSpeechWithTimestampsResponse
		if err := sonic.Unmarshal(body, &timestampResponse); err != nil {
			// Enriched like every other failure in this function so the
			// caller gets the raw request on parse errors too.
			return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError("failed to parse with-timestamps response", err), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
		}
		bifrostResponse.AudioBase64 = &timestampResponse.AudioBase64
		if alignment := timestampResponse.Alignment; alignment != nil {
			bifrostResponse.Alignment = &schemas.SpeechAlignment{
				CharStartTimesMs: alignment.CharStartTimesMs,
				CharEndTimesMs:   alignment.CharEndTimesMs,
				Characters:       alignment.Characters,
			}
		}
		if normalized := timestampResponse.NormalizedAlignment; normalized != nil {
			bifrostResponse.NormalizedAlignment = &schemas.SpeechAlignment{
				CharStartTimesMs: normalized.CharStartTimesMs,
				CharEndTimesMs:   normalized.CharEndTimesMs,
				Characters:       normalized.Characters,
			}
		}
		return bifrostResponse, nil
	}

	// NOTE(review): body may alias the pooled response buffer, which is
	// released (and reused) when this function returns — confirm against
	// CheckAndDecodeBody. Copying keeps the returned audio valid either way.
	bifrostResponse.Audio = append([]byte(nil), body...)
	return bifrostResponse, nil
}
// Rerank is not implemented for the Elevenlabs provider; it always returns an
// unsupported-operation error for schemas.RerankRequest.
func (provider *ElevenlabsProvider) Rerank(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostRerankRequest) (*schemas.BifrostRerankResponse, *schemas.BifrostError) {
	unsupported := providerUtils.NewUnsupportedOperationError(schemas.RerankRequest, provider.GetProviderKey())
	return nil, unsupported
}
// OCR is not implemented for the Elevenlabs provider; it always returns an
// unsupported-operation error for schemas.OCRRequest.
func (provider *ElevenlabsProvider) OCR(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostOCRRequest) (*schemas.BifrostOCRResponse, *schemas.BifrostError) {
	unsupported := providerUtils.NewUnsupportedOperationError(schemas.OCRRequest, provider.GetProviderKey())
	return nil, unsupported
}
// SpeechStream performs a streaming text-to-speech request.
//
// The request body is built from request and posted with the streaming client
// to the voice-specific stream endpoint (default path
// "/v1/text-to-speech/{voice}/stream"). Binary audio read from the response is
// forwarded on the returned channel as delta chunks, followed by a final
// "done" chunk once the provider closes the stream.
//
// A *schemas.BifrostError is returned instead of a channel when the operation
// is not allowed, the body cannot be built, the voice parameter is missing,
// the HTTP call fails, or the provider answers with a non-200 status.
func (provider *ElevenlabsProvider) SpeechStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostSpeechRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
	if err := providerUtils.CheckOperationAllowed(schemas.Elevenlabs, provider.customProviderConfig, schemas.SpeechStreamRequest); err != nil {
		return nil, err
	}
	jsonBody, bifrostErr := providerUtils.CheckContextAndGetRequestBody(
		ctx,
		request,
		func() (providerUtils.RequestBodyWithExtraParams, error) {
			return ToElevenlabsSpeechRequest(request), nil
		})
	if bifrostErr != nil {
		return nil, bifrostErr
	}
	// Create HTTP request for streaming
	req := fasthttp.AcquireRequest()
	resp := fasthttp.AcquireResponse()
	resp.StreamBody = true
	defer fasthttp.ReleaseRequest(req)
	// Set any extra headers from network config
	providerUtils.SetExtraHeaders(ctx, req, provider.networkConfig.ExtraHeaders, nil)
	if request.Params == nil || request.Params.VoiceConfig == nil || request.Params.VoiceConfig.Voice == nil {
		// The response has already been acquired but will never reach the
		// streaming goroutine, so release it here to avoid leaking it.
		defer providerUtils.ReleaseStreamingResponse(resp)
		return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError("voice parameter is required", nil), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
	}
	req.SetRequestURI(provider.buildBaseSpeechRequestURL(ctx, "/v1/text-to-speech/"+*request.Params.VoiceConfig.Voice+"/stream", schemas.SpeechStreamRequest, request))
	req.Header.SetMethod(http.MethodPost)
	req.Header.SetContentType("application/json")
	if key.Value.GetValue() != "" {
		req.Header.Set("xi-api-key", key.Value.GetValue())
	}
	if !providerUtils.ApplyLargePayloadRequestBody(ctx, req) {
		req.SetBody(jsonBody)
	}
	// Make request
	startTime := time.Now()
	err := provider.streamingClient.Do(req, resp)
	if err != nil {
		defer providerUtils.ReleaseStreamingResponse(resp)
		if errors.Is(err, context.Canceled) {
			return nil, providerUtils.EnrichError(ctx, &schemas.BifrostError{
				IsBifrostError: false,
				Error: &schemas.ErrorField{
					Type:    schemas.Ptr(schemas.RequestCancelled),
					Message: schemas.ErrRequestCancelled,
					Error:   err,
				},
			}, jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
		}
		if errors.Is(err, fasthttp.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) {
			return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
		}
		return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
	}
	// Extract provider response headers before status check so error responses also forward them
	ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerUtils.ExtractProviderResponseHeaders(resp))
	// Check for HTTP errors
	if resp.StatusCode() != fasthttp.StatusOK {
		defer providerUtils.ReleaseStreamingResponse(resp)
		return nil, providerUtils.EnrichError(ctx, parseElevenlabsError(resp), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
	}
	// Create response channel
	responseChan := make(chan *schemas.BifrostStreamChunk, schemas.DefaultStreamBufferSize)
	providerUtils.SetStreamIdleTimeoutIfEmpty(ctx, provider.networkConfig.StreamIdleTimeoutInSeconds)
	go func() {
		defer func() {
			if ctx.Err() == context.Canceled {
				providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, provider.logger, postHookSpanFinalizer)
			} else if ctx.Err() == context.DeadlineExceeded {
				providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, provider.logger, postHookSpanFinalizer)
			}
			close(responseChan)
		}()
		defer providerUtils.ReleaseStreamingResponse(resp)
		// Decompress gzip-encoded streams transparently (no-op for non-gzip)
		reader, releaseGzip := providerUtils.DecompressStreamBody(resp)
		defer releaseGzip()
		// Wrap reader with idle timeout to detect stalled streams.
		reader, stopIdleTimeout := providerUtils.NewIdleTimeoutReader(reader, resp.BodyStream(), providerUtils.GetStreamIdleTimeout(ctx))
		defer stopIdleTimeout()
		// Setup cancellation handler to close the raw network stream on ctx cancellation,
		// which immediately unblocks any in-progress read (including reads blocked inside a gzip decompression layer).
		stopCancellation := providerUtils.SetupStreamCancellation(ctx, resp.BodyStream(), provider.logger)
		defer stopCancellation()
		defer providerUtils.EnsureStreamFinalizerCalled(ctx, postHookSpanFinalizer)
		// Read binary audio chunks from the stream through a reusable 4KB buffer.
		buffer := make([]byte, 4096)
		bodyStream := reader
		chunkIndex := -1
		lastChunkTime := time.Now()
		for {
			// If context was cancelled/timed out, let defer handle it
			if ctx.Err() != nil {
				return
			}
			n, err := bodyStream.Read(buffer)
			// A read that failed because the context was cancelled/timed out is
			// reported by the deferred handler, not as a stream error.
			if err != nil && ctx.Err() != nil {
				return
			}
			// io.Reader may return n > 0 together with an error (including
			// io.EOF); forward those bytes before acting on the error so the
			// tail of the audio is never dropped.
			if n > 0 {
				chunkIndex++
				// Copy out of the shared buffer: it is overwritten by the next Read.
				audioChunk := make([]byte, n)
				copy(audioChunk, buffer[:n])
				response := &schemas.BifrostSpeechStreamResponse{
					Type:  schemas.SpeechStreamResponseTypeDelta,
					Audio: audioChunk,
					ExtraFields: schemas.BifrostResponseExtraFields{
						ChunkIndex: chunkIndex,
						Latency:    time.Since(lastChunkTime).Milliseconds(),
					},
				}
				lastChunkTime = time.Now()
				if providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) {
					response.ExtraFields.RawResponse = audioChunk
				}
				providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, nil, response, nil, nil), responseChan, postHookSpanFinalizer)
			}
			if err != nil {
				if err == io.EOF {
					break
				}
				ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
				provider.logger.Warn("Error reading stream: %v", err)
				providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, provider.logger, postHookSpanFinalizer)
				return
			}
		}
		// Send final response after natural loop termination (similar to Gemini pattern)
		finalResponse := &schemas.BifrostSpeechStreamResponse{
			Type:  schemas.SpeechStreamResponseTypeDone,
			Audio: []byte{},
			ExtraFields: schemas.BifrostResponseExtraFields{
				ChunkIndex: chunkIndex + 1,
				// Latency of the final chunk covers the whole request.
				Latency: time.Since(startTime).Milliseconds(),
			},
		}
		// Set raw request if enabled
		if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) {
			providerUtils.ParseAndSetRawRequest(&finalResponse.ExtraFields, jsonBody)
		}
		ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
		providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, nil, finalResponse, nil, nil), responseChan, postHookSpanFinalizer)
	}()
	return responseChan, nil
}
// Transcription performs a speech-to-text request.
//
// The converted request must carry exactly one audio source: inline file bytes
// or a non-blank cloud_storage_url. It is encoded as a multipart form, posted
// to the transcription endpoint (default path "/v1/speech-to-text"), and the
// parsed chunks are converted into a Bifrost transcription response. The raw
// provider payload is attached when raw responses are enabled.
func (provider *ElevenlabsProvider) Transcription(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (*schemas.BifrostTranscriptionResponse, *schemas.BifrostError) {
	if opErr := providerUtils.CheckOperationAllowed(schemas.Elevenlabs, provider.customProviderConfig, schemas.TranscriptionRequest); opErr != nil {
		return nil, opErr
	}

	payload := ToElevenlabsTranscriptionRequest(request)
	if payload == nil {
		return nil, providerUtils.NewBifrostOperationError("transcription request is not provided", nil)
	}

	// Exactly one audio source must be supplied.
	fileGiven := len(payload.File) > 0
	urlGiven := payload.CloudStorageURL != nil && strings.TrimSpace(*payload.CloudStorageURL) != ""
	switch {
	case fileGiven && urlGiven:
		return nil, providerUtils.NewBifrostOperationError("provide either a file or cloud_storage_url, not both", nil)
	case !fileGiven && !urlGiven:
		return nil, providerUtils.NewBifrostOperationError("either a transcription file or cloud_storage_url must be provided", nil)
	}

	// Encode the request as multipart/form-data.
	var form bytes.Buffer
	formWriter := multipart.NewWriter(&form)
	if writeErr := writeTranscriptionMultipart(formWriter, payload); writeErr != nil {
		return nil, writeErr
	}
	contentType := formWriter.FormDataContentType()
	if closeErr := formWriter.Close(); closeErr != nil {
		return nil, providerUtils.NewBifrostOperationError("failed to finalize multipart transcription request", closeErr)
	}

	req := fasthttp.AcquireRequest()
	resp := fasthttp.AcquireResponse()
	defer fasthttp.ReleaseRequest(req)
	defer fasthttp.ReleaseResponse(resp)

	providerUtils.SetExtraHeaders(ctx, req, provider.networkConfig.ExtraHeaders, nil)

	// A custom provider config may supply a complete URL; otherwise the path
	// is appended to the configured base URL.
	requestPath, isCompleteURL := providerUtils.GetRequestPath(ctx, "/v1/speech-to-text", provider.customProviderConfig, schemas.TranscriptionRequest)
	targetURL := requestPath
	if !isCompleteURL {
		targetURL = provider.networkConfig.BaseURL + requestPath
	}
	req.SetRequestURI(targetURL)
	req.Header.SetMethod(http.MethodPost)
	req.Header.SetContentType(contentType)
	if apiKey := key.Value.GetValue(); apiKey != "" {
		req.Header.Set("xi-api-key", apiKey)
	}
	req.SetBody(form.Bytes())

	latency, bifrostErr, wait := providerUtils.MakeRequestWithContext(ctx, provider.client, req, resp)
	defer wait()
	if bifrostErr != nil {
		return nil, bifrostErr
	}

	// Publish the provider response headers on the context before any status
	// handling so they are available on error paths too.
	ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerUtils.ExtractProviderResponseHeaders(resp))
	if resp.StatusCode() != fasthttp.StatusOK {
		return nil, parseElevenlabsError(resp)
	}

	responseBody, decodeErr := providerUtils.CheckAndDecodeBody(resp)
	if decodeErr != nil {
		return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, decodeErr)
	}

	// A body that is empty or whitespace-only is reported as an empty response.
	if len(bytes.TrimSpace(responseBody)) == 0 {
		return nil, &schemas.BifrostError{
			IsBifrostError: true,
			Error: &schemas.ErrorField{
				Message: schemas.ErrProviderResponseEmpty,
			},
		}
	}

	chunks, parseErr := parseTranscriptionResponse(responseBody)
	if parseErr != nil {
		return nil, providerUtils.NewBifrostOperationError(parseErr.Error(), nil)
	}
	if len(chunks) == 0 {
		return nil, providerUtils.NewBifrostOperationError("no chunks found in transcription response", nil)
	}

	response := ToBifrostTranscriptionResponse(chunks)
	response.ExtraFields = schemas.BifrostResponseExtraFields{
		Latency:                 latency.Milliseconds(),
		ProviderResponseHeaders: providerUtils.ExtractProviderResponseHeaders(resp),
	}

	if providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) {
		// Prefer the decoded JSON value; fall back to the raw text if the
		// body is not valid JSON.
		var raw interface{}
		if unmarshalErr := sonic.Unmarshal(responseBody, &raw); unmarshalErr != nil {
			raw = string(responseBody)
		}
		response.ExtraFields.RawResponse = raw
	}
	return response, nil
}
// writeTranscriptionMultipart serializes reqBody into writer as multipart form
// parts. model_id is always written; the audio file and every optional field
// are written only when set (string fields additionally must be non-empty or
// non-blank, num_speakers must be positive). Structured values
// (additional_formats, non-string webhook_metadata) are marshalled with
// providerUtils.MarshalSorted first. The first failure is returned as a
// Bifrost operation error; nil means every part was written.
func writeTranscriptionMultipart(writer *multipart.Writer, reqBody *ElevenlabsTranscriptionRequest) *schemas.BifrostError {
	// field writes one text form field and wraps a failure in an operation error.
	field := func(name, value string) *schemas.BifrostError {
		if err := writer.WriteField(name, value); err != nil {
			return providerUtils.NewBifrostOperationError("failed to write "+name+" field", err)
		}
		return nil
	}
	// present reports whether an optional string is set and not blank.
	present := func(s *string) bool {
		return s != nil && strings.TrimSpace(*s) != ""
	}

	if bErr := field("model_id", reqBody.ModelID); bErr != nil {
		return bErr
	}

	if audio := reqBody.File; len(audio) > 0 {
		name := reqBody.Filename
		if name == "" {
			// No filename supplied: derive one from the audio bytes.
			name = providerUtils.AudioFilenameFromBytes(audio)
		}
		part, err := writer.CreateFormFile("file", name)
		if err != nil {
			return providerUtils.NewBifrostOperationError("failed to create file field", err)
		}
		if _, err := part.Write(audio); err != nil {
			return providerUtils.NewBifrostOperationError("failed to write file data", err)
		}
	}

	if present(reqBody.CloudStorageURL) {
		if bErr := field("cloud_storage_url", *reqBody.CloudStorageURL); bErr != nil {
			return bErr
		}
	}

	if present(reqBody.LanguageCode) {
		if bErr := field("language_code", *reqBody.LanguageCode); bErr != nil {
			return bErr
		}
	}

	if flag := reqBody.TagAudioEvents; flag != nil {
		if bErr := field("tag_audio_events", strconv.FormatBool(*flag)); bErr != nil {
			return bErr
		}
	}

	if count := reqBody.NumSpeakers; count != nil && *count > 0 {
		if bErr := field("num_speakers", strconv.Itoa(*count)); bErr != nil {
			return bErr
		}
	}

	if granularity := reqBody.TimestampsGranularity; granularity != nil && *granularity != "" {
		if bErr := field("timestamps_granularity", string(*granularity)); bErr != nil {
			return bErr
		}
	}

	if flag := reqBody.Diarize; flag != nil {
		if bErr := field("diarize", strconv.FormatBool(*flag)); bErr != nil {
			return bErr
		}
	}

	if threshold := reqBody.DiarizationThreshold; threshold != nil {
		if bErr := field("diarization_threshold", strconv.FormatFloat(*threshold, 'f', -1, 64)); bErr != nil {
			return bErr
		}
	}

	if len(reqBody.AdditionalFormats) > 0 {
		encoded, err := providerUtils.MarshalSorted(reqBody.AdditionalFormats)
		if err != nil {
			return providerUtils.NewBifrostOperationError("failed to marshal additional_formats", err)
		}
		if bErr := field("additional_formats", string(encoded)); bErr != nil {
			return bErr
		}
	}

	if format := reqBody.FileFormat; format != nil && *format != "" {
		if bErr := field("file_format", string(*format)); bErr != nil {
			return bErr
		}
	}

	if flag := reqBody.Webhook; flag != nil {
		if bErr := field("webhook", strconv.FormatBool(*flag)); bErr != nil {
			return bErr
		}
	}

	if present(reqBody.WebhookID) {
		if bErr := field("webhook_id", *reqBody.WebhookID); bErr != nil {
			return bErr
		}
	}

	if temperature := reqBody.Temperature; temperature != nil {
		if bErr := field("temperature", strconv.FormatFloat(*temperature, 'f', -1, 64)); bErr != nil {
			return bErr
		}
	}

	if seed := reqBody.Seed; seed != nil {
		if bErr := field("seed", strconv.Itoa(*seed)); bErr != nil {
			return bErr
		}
	}

	if flag := reqBody.UseMultiChannel; flag != nil {
		if bErr := field("use_multi_channel", strconv.FormatBool(*flag)); bErr != nil {
			return bErr
		}
	}

	if reqBody.WebhookMetadata == nil {
		return nil
	}
	// Strings are sent verbatim (blank ones are skipped); any other value is
	// marshalled before being written.
	if text, isString := reqBody.WebhookMetadata.(string); isString {
		if strings.TrimSpace(text) == "" {
			return nil
		}
		return field("webhook_metadata", text)
	}
	encoded, err := providerUtils.MarshalSorted(reqBody.WebhookMetadata)
	if err != nil {
		return providerUtils.NewBifrostOperationError("failed to marshal webhook_metadata", err)
	}
	return field("webhook_metadata", string(encoded))
}
// TranscriptionStream is not implemented for the Elevenlabs provider; it always
// returns an unsupported-operation error for schemas.TranscriptionStreamRequest.
func (provider *ElevenlabsProvider) TranscriptionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
	unsupported := providerUtils.NewUnsupportedOperationError(schemas.TranscriptionStreamRequest, provider.GetProviderKey())
	return nil, unsupported
}
// ImageGeneration is not implemented for the Elevenlabs provider; it always
// returns an unsupported-operation error for schemas.ImageGenerationRequest.
func (provider *ElevenlabsProvider) ImageGeneration(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostImageGenerationRequest) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) {
	unsupported := providerUtils.NewUnsupportedOperationError(schemas.ImageGenerationRequest, provider.GetProviderKey())
	return nil, unsupported
}
// ImageGenerationStream is not implemented for the Elevenlabs provider; it
// always returns an unsupported-operation error for
// schemas.ImageGenerationStreamRequest.
func (provider *ElevenlabsProvider) ImageGenerationStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostImageGenerationRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
	unsupported := providerUtils.NewUnsupportedOperationError(schemas.ImageGenerationStreamRequest, provider.GetProviderKey())
	return nil, unsupported
}
// ImageEdit is not implemented for the Elevenlabs provider; it always returns
// an unsupported-operation error for schemas.ImageEditRequest.
func (provider *ElevenlabsProvider) ImageEdit(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostImageEditRequest) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) {
	unsupported := providerUtils.NewUnsupportedOperationError(schemas.ImageEditRequest, provider.GetProviderKey())
	return nil, unsupported
}
// ImageEditStream is not implemented for the Elevenlabs provider; it always
// returns an unsupported-operation error for schemas.ImageEditStreamRequest.
func (provider *ElevenlabsProvider) ImageEditStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostImageEditRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
	unsupported := providerUtils.NewUnsupportedOperationError(schemas.ImageEditStreamRequest, provider.GetProviderKey())
	return nil, unsupported
}
// ImageVariation is not implemented for the Elevenlabs provider; it always
// returns an unsupported-operation error for schemas.ImageVariationRequest.
func (provider *ElevenlabsProvider) ImageVariation(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostImageVariationRequest) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) {
	unsupported := providerUtils.NewUnsupportedOperationError(schemas.ImageVariationRequest, provider.GetProviderKey())
	return nil, unsupported
}
// VideoGeneration is not implemented for the Elevenlabs provider; it always
// returns an unsupported-operation error for schemas.VideoGenerationRequest.
func (provider *ElevenlabsProvider) VideoGeneration(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostVideoGenerationRequest) (*schemas.BifrostVideoGenerationResponse, *schemas.BifrostError) {
	unsupported := providerUtils.NewUnsupportedOperationError(schemas.VideoGenerationRequest, provider.GetProviderKey())
	return nil, unsupported
}
// VideoRetrieve is not implemented for the Elevenlabs provider; it always
// returns an unsupported-operation error for schemas.VideoRetrieveRequest.
func (provider *ElevenlabsProvider) VideoRetrieve(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostVideoRetrieveRequest) (*schemas.BifrostVideoGenerationResponse, *schemas.BifrostError) {
	unsupported := providerUtils.NewUnsupportedOperationError(schemas.VideoRetrieveRequest, provider.GetProviderKey())
	return nil, unsupported
}
// VideoDownload is not implemented for the Elevenlabs provider; it always
// returns an unsupported-operation error for schemas.VideoDownloadRequest.
func (provider *ElevenlabsProvider) VideoDownload(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostVideoDownloadRequest) (*schemas.BifrostVideoDownloadResponse, *schemas.BifrostError) {
	unsupported := providerUtils.NewUnsupportedOperationError(schemas.VideoDownloadRequest, provider.GetProviderKey())
	return nil, unsupported
}
// VideoDelete is not implemented for the Elevenlabs provider; it always returns
// an unsupported-operation error for schemas.VideoDeleteRequest.
func (provider *ElevenlabsProvider) VideoDelete(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostVideoDeleteRequest) (*schemas.BifrostVideoDeleteResponse, *schemas.BifrostError) {
	unsupported := providerUtils.NewUnsupportedOperationError(schemas.VideoDeleteRequest, provider.GetProviderKey())
	return nil, unsupported
}
// VideoList is not implemented for the Elevenlabs provider; it always returns
// an unsupported-operation error for schemas.VideoListRequest.
func (provider *ElevenlabsProvider) VideoList(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostVideoListRequest) (*schemas.BifrostVideoListResponse, *schemas.BifrostError) {
	unsupported := providerUtils.NewUnsupportedOperationError(schemas.VideoListRequest, provider.GetProviderKey())
	return nil, unsupported
}
// VideoRemix is not implemented for the Elevenlabs provider; it always returns
// an unsupported-operation error for schemas.VideoRemixRequest.
func (provider *ElevenlabsProvider) VideoRemix(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostVideoRemixRequest) (*schemas.BifrostVideoGenerationResponse, *schemas.BifrostError) {
	unsupported := providerUtils.NewUnsupportedOperationError(schemas.VideoRemixRequest, provider.GetProviderKey())
	return nil, unsupported
}
// buildBaseSpeechRequestURL returns the full URL for a speech request.
//
// The path comes from providerUtils.GetRequestPath (defaultPath unless the
// custom provider config overrides it). A complete URL is used as-is;
// otherwise the path is joined onto the configured base URL. Speech parameters
// present on request.Params (enable_logging, output_format,
// optimize_streaming_latency) are then added as query parameters. If the
// assembled URL cannot be parsed it is returned without query parameters.
//
// NOTE(review): optimize_streaming_latency is sent as "true"/"false"; the
// ElevenLabs API reference appears to describe it as an integer level —
// confirm the boolean form is accepted.
func (provider *ElevenlabsProvider) buildBaseSpeechRequestURL(ctx *schemas.BifrostContext, defaultPath string, requestType schemas.RequestType, request *schemas.BifrostSpeechRequest) string {
	requestPath, isCompleteURL := providerUtils.GetRequestPath(ctx, defaultPath, provider.customProviderConfig, requestType)

	endpoint := requestPath
	if !isCompleteURL {
		base := provider.networkConfig.BaseURL
		if baseParsed, baseErr := url.Parse(base); baseErr == nil {
			baseParsed.Path = path.Join(baseParsed.Path, requestPath)
			endpoint = baseParsed.String()
		} else {
			// Unparseable base URL: fall back to plain concatenation.
			endpoint = base + requestPath
		}
	}

	// Re-parse the assembled URL so query parameters can be merged in.
	parsed, err := url.Parse(endpoint)
	if err != nil {
		return endpoint
	}

	query := parsed.Query()
	if params := request.Params; params != nil {
		if params.EnableLogging != nil {
			query.Set("enable_logging", strconv.FormatBool(*params.EnableLogging))
		}
		if format := ConvertBifrostSpeechFormatToElevenlabs(params.ResponseFormat); format != "" {
			query.Set("output_format", format)
		}
		if params.OptimizeStreamingLatency != nil {
			query.Set("optimize_streaming_latency", strconv.FormatBool(*params.OptimizeStreamingLatency))
		}
	}
	parsed.RawQuery = query.Encode()
	return parsed.String()
}
// BatchCreate is not implemented for the Elevenlabs provider; it always returns
// an unsupported-operation error for schemas.BatchCreateRequest.
func (provider *ElevenlabsProvider) BatchCreate(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostBatchCreateRequest) (*schemas.BifrostBatchCreateResponse, *schemas.BifrostError) {
	unsupported := providerUtils.NewUnsupportedOperationError(schemas.BatchCreateRequest, provider.GetProviderKey())
	return nil, unsupported
}
// BatchList is not implemented for the Elevenlabs provider; it always returns
// an unsupported-operation error for schemas.BatchListRequest.
func (provider *ElevenlabsProvider) BatchList(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchListRequest) (*schemas.BifrostBatchListResponse, *schemas.BifrostError) {
	unsupported := providerUtils.NewUnsupportedOperationError(schemas.BatchListRequest, provider.GetProviderKey())
	return nil, unsupported
}
// BatchRetrieve is not implemented for the Elevenlabs provider; it always
// returns an unsupported-operation error for schemas.BatchRetrieveRequest.
func (provider *ElevenlabsProvider) BatchRetrieve(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchRetrieveRequest) (*schemas.BifrostBatchRetrieveResponse, *schemas.BifrostError) {
	unsupported := providerUtils.NewUnsupportedOperationError(schemas.BatchRetrieveRequest, provider.GetProviderKey())
	return nil, unsupported
}
// BatchCancel is not implemented for the Elevenlabs provider; it always returns
// an unsupported-operation error for schemas.BatchCancelRequest.
func (provider *ElevenlabsProvider) BatchCancel(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchCancelRequest) (*schemas.BifrostBatchCancelResponse, *schemas.BifrostError) {
	unsupported := providerUtils.NewUnsupportedOperationError(schemas.BatchCancelRequest, provider.GetProviderKey())
	return nil, unsupported
}
// BatchDelete is not implemented for the Elevenlabs provider; it always returns
// an unsupported-operation error for schemas.BatchDeleteRequest.
func (provider *ElevenlabsProvider) BatchDelete(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchDeleteRequest) (*schemas.BifrostBatchDeleteResponse, *schemas.BifrostError) {
	unsupported := providerUtils.NewUnsupportedOperationError(schemas.BatchDeleteRequest, provider.GetProviderKey())
	return nil, unsupported
}
// BatchResults is not implemented for the Elevenlabs provider; it always
// returns an unsupported-operation error for schemas.BatchResultsRequest.
func (provider *ElevenlabsProvider) BatchResults(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchResultsRequest) (*schemas.BifrostBatchResultsResponse, *schemas.BifrostError) {
	unsupported := providerUtils.NewUnsupportedOperationError(schemas.BatchResultsRequest, provider.GetProviderKey())
	return nil, unsupported
}
// FileUpload is not implemented for the Elevenlabs provider; it always returns
// an unsupported-operation error for schemas.FileUploadRequest.
func (provider *ElevenlabsProvider) FileUpload(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostFileUploadRequest) (*schemas.BifrostFileUploadResponse, *schemas.BifrostError) {
	unsupported := providerUtils.NewUnsupportedOperationError(schemas.FileUploadRequest, provider.GetProviderKey())
	return nil, unsupported
}
// FileList is not implemented for the Elevenlabs provider; it always returns an
// unsupported-operation error for schemas.FileListRequest.
func (provider *ElevenlabsProvider) FileList(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostFileListRequest) (*schemas.BifrostFileListResponse, *schemas.BifrostError) {
	unsupported := providerUtils.NewUnsupportedOperationError(schemas.FileListRequest, provider.GetProviderKey())
	return nil, unsupported
}
// FileRetrieve is not implemented for the Elevenlabs provider; it always
// returns an unsupported-operation error for schemas.FileRetrieveRequest.
func (provider *ElevenlabsProvider) FileRetrieve(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostFileRetrieveRequest) (*schemas.BifrostFileRetrieveResponse, *schemas.BifrostError) {
	unsupported := providerUtils.NewUnsupportedOperationError(schemas.FileRetrieveRequest, provider.GetProviderKey())
	return nil, unsupported
}
// FileDelete is not implemented for the Elevenlabs provider; it always returns
// an unsupported-operation error for schemas.FileDeleteRequest.
func (provider *ElevenlabsProvider) FileDelete(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostFileDeleteRequest) (*schemas.BifrostFileDeleteResponse, *schemas.BifrostError) {
	unsupported := providerUtils.NewUnsupportedOperationError(schemas.FileDeleteRequest, provider.GetProviderKey())
	return nil, unsupported
}
// FileContent is not implemented for the Elevenlabs provider; it always returns
// an unsupported-operation error for schemas.FileContentRequest.
func (provider *ElevenlabsProvider) FileContent(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostFileContentRequest) (*schemas.BifrostFileContentResponse, *schemas.BifrostError) {
	unsupported := providerUtils.NewUnsupportedOperationError(schemas.FileContentRequest, provider.GetProviderKey())
	return nil, unsupported
}
// CountTokens is not implemented for the Elevenlabs provider; it always returns
// an unsupported-operation error for schemas.CountTokensRequest.
func (provider *ElevenlabsProvider) CountTokens(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostResponsesRequest) (*schemas.BifrostCountTokensResponse, *schemas.BifrostError) {
	unsupported := providerUtils.NewUnsupportedOperationError(schemas.CountTokensRequest, provider.GetProviderKey())
	return nil, unsupported
}
// ContainerCreate is not implemented for the Elevenlabs provider; it always
// returns an unsupported-operation error for schemas.ContainerCreateRequest.
func (provider *ElevenlabsProvider) ContainerCreate(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostContainerCreateRequest) (*schemas.BifrostContainerCreateResponse, *schemas.BifrostError) {
	unsupported := providerUtils.NewUnsupportedOperationError(schemas.ContainerCreateRequest, provider.GetProviderKey())
	return nil, unsupported
}
// ContainerList is not implemented for the Elevenlabs provider; it always
// returns an unsupported-operation error for schemas.ContainerListRequest.
func (provider *ElevenlabsProvider) ContainerList(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostContainerListRequest) (*schemas.BifrostContainerListResponse, *schemas.BifrostError) {
	unsupported := providerUtils.NewUnsupportedOperationError(schemas.ContainerListRequest, provider.GetProviderKey())
	return nil, unsupported
}
// ContainerRetrieve is not implemented for the Elevenlabs provider; it always
// returns an unsupported-operation error for schemas.ContainerRetrieveRequest.
func (provider *ElevenlabsProvider) ContainerRetrieve(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostContainerRetrieveRequest) (*schemas.BifrostContainerRetrieveResponse, *schemas.BifrostError) {
	unsupported := providerUtils.NewUnsupportedOperationError(schemas.ContainerRetrieveRequest, provider.GetProviderKey())
	return nil, unsupported
}
// ContainerDelete is not implemented for the Elevenlabs provider; it always
// returns an unsupported-operation error for schemas.ContainerDeleteRequest.
func (provider *ElevenlabsProvider) ContainerDelete(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostContainerDeleteRequest) (*schemas.BifrostContainerDeleteResponse, *schemas.BifrostError) {
	unsupported := providerUtils.NewUnsupportedOperationError(schemas.ContainerDeleteRequest, provider.GetProviderKey())
	return nil, unsupported
}
// ContainerFileCreate is not implemented for the Elevenlabs provider; it always
// returns an unsupported-operation error for schemas.ContainerFileCreateRequest.
func (provider *ElevenlabsProvider) ContainerFileCreate(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostContainerFileCreateRequest) (*schemas.BifrostContainerFileCreateResponse, *schemas.BifrostError) {
	unsupported := providerUtils.NewUnsupportedOperationError(schemas.ContainerFileCreateRequest, provider.GetProviderKey())
	return nil, unsupported
}
// ContainerFileList is not implemented for the Elevenlabs provider; it always
// returns an unsupported-operation error for schemas.ContainerFileListRequest.
func (provider *ElevenlabsProvider) ContainerFileList(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostContainerFileListRequest) (*schemas.BifrostContainerFileListResponse, *schemas.BifrostError) {
	unsupported := providerUtils.NewUnsupportedOperationError(schemas.ContainerFileListRequest, provider.GetProviderKey())
	return nil, unsupported
}
// ContainerFileRetrieve is not implemented for the Elevenlabs provider; it
// always returns an unsupported-operation error for
// schemas.ContainerFileRetrieveRequest.
func (provider *ElevenlabsProvider) ContainerFileRetrieve(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostContainerFileRetrieveRequest) (*schemas.BifrostContainerFileRetrieveResponse, *schemas.BifrostError) {
	unsupported := providerUtils.NewUnsupportedOperationError(schemas.ContainerFileRetrieveRequest, provider.GetProviderKey())
	return nil, unsupported
}
// ContainerFileContent is not implemented for the Elevenlabs provider; it
// always returns an unsupported-operation error for
// schemas.ContainerFileContentRequest.
func (provider *ElevenlabsProvider) ContainerFileContent(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostContainerFileContentRequest) (*schemas.BifrostContainerFileContentResponse, *schemas.BifrostError) {
	unsupported := providerUtils.NewUnsupportedOperationError(schemas.ContainerFileContentRequest, provider.GetProviderKey())
	return nil, unsupported
}
// ContainerFileDelete is not implemented for the Elevenlabs provider; it always
// returns an unsupported-operation error for schemas.ContainerFileDeleteRequest.
func (provider *ElevenlabsProvider) ContainerFileDelete(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostContainerFileDeleteRequest) (*schemas.BifrostContainerFileDeleteResponse, *schemas.BifrostError) {
	unsupported := providerUtils.NewUnsupportedOperationError(schemas.ContainerFileDeleteRequest, provider.GetProviderKey())
	return nil, unsupported
}
// Passthrough is not implemented for the Elevenlabs provider; it always returns
// an unsupported-operation error for schemas.PassthroughRequest.
func (provider *ElevenlabsProvider) Passthrough(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostPassthroughRequest) (*schemas.BifrostPassthroughResponse, *schemas.BifrostError) {
	unsupported := providerUtils.NewUnsupportedOperationError(schemas.PassthroughRequest, provider.GetProviderKey())
	return nil, unsupported
}
// PassthroughStream is not implemented for the Elevenlabs provider; it always
// returns an unsupported-operation error for schemas.PassthroughStreamRequest.
func (provider *ElevenlabsProvider) PassthroughStream(_ *schemas.BifrostContext, _ schemas.PostHookRunner, _ func(context.Context), _ schemas.Key, _ *schemas.BifrostPassthroughRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
	unsupported := providerUtils.NewUnsupportedOperationError(schemas.PassthroughStreamRequest, provider.GetProviderKey())
	return nil, unsupported
}

View File

@@ -0,0 +1,62 @@
package elevenlabs_test
import (
"os"
"strings"
"testing"
"github.com/maximhq/bifrost/core/internal/llmtests"
"github.com/maximhq/bifrost/core/schemas"
)
// TestElevenlabs runs the shared comprehensive llmtests suite against the
// Elevenlabs provider. The test is skipped when ELEVENLABS_API_KEY is unset or
// blank. Only speech synthesis (plain and streaming) and transcription
// scenarios are enabled.
func TestElevenlabs(t *testing.T) {
	t.Parallel()

	apiKey := strings.TrimSpace(os.Getenv("ELEVENLABS_API_KEY"))
	if apiKey == "" {
		t.Skip("Skipping Elevenlabs tests because ELEVENLABS_API_KEY is not set")
	}

	client, ctx, cancel, err := llmtests.SetupTest()
	if err != nil {
		t.Fatalf("Error initializing test setup: %v", err)
	}
	defer cancel()
	defer client.Shutdown()

	// The agent ID doubles as the realtime "model" for this provider.
	agentID := strings.TrimSpace(os.Getenv("ELEVENLABS_AGENT_ID"))

	// NOTE(review): the realtime scenario is hard-coded off even when
	// ELEVENLABS_AGENT_ID is set — confirm whether it should instead be
	// enabled whenever an agent ID is available.
	const realtimeEnabled = false

	scenarios := llmtests.TestScenarios{
		TextCompletion:        false,
		TextCompletionStream:  false,
		SimpleChat:            false,
		CompletionStream:      false,
		MultiTurnConversation: false,
		ToolCalls:             false,
		MultipleToolCalls:     false,
		End2EndToolCalling:    false,
		AutomaticFunctionCall: false,
		ImageURL:              false,
		ImageBase64:           false,
		MultipleImages:        false,
		CompleteEnd2End:       false,
		SpeechSynthesis:       true,
		SpeechSynthesisStream: true,
		Transcription:         true,
		TranscriptionStream:   false,
		Embedding:             false,
		Reasoning:             false,
		ListModels:            false,
		Realtime:              realtimeEnabled,
	}

	testConfig := llmtests.ComprehensiveTestConfig{
		Provider:             schemas.Elevenlabs,
		SpeechSynthesisModel: "eleven_turbo_v2_5",
		TranscriptionModel:   "scribe_v1",
		RealtimeModel:        agentID,
		Scenarios:            scenarios,
	}

	t.Run("ElevenlabsTests", func(t *testing.T) {
		llmtests.RunAllComprehensiveTests(t, client, ctx, testConfig)
	})
}

View File

@@ -0,0 +1,90 @@
package elevenlabs
import (
"strings"
"github.com/valyala/fasthttp"
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
schemas "github.com/maximhq/bifrost/core/schemas"
)
// parseElevenlabsError converts an Elevenlabs error response into a
// BifrostError.
//
// The response is first passed through providerUtils.HandleProviderAPIError,
// which also decodes the body into an ElevenlabsError. If the decoded body has
// a detail section, it refines the result:
//
//   - Validation errors: every entry's message (Message, falling back to Msg),
//     dotted location and type are collected. Messages are joined with "; ",
//     locations are appended as "[a.b, c.d]", and types are joined with ", ".
//     When this yields a non-empty message a fresh BifrostError carrying the
//     response status code is returned.
//   - Otherwise, a non-empty detail message overrides the message and type
//     (taken from the detail status) on the error produced by
//     HandleProviderAPIError.
func parseElevenlabsError(resp *fasthttp.Response) *schemas.BifrostError {
	var errorResp ElevenlabsError
	bifrostErr := providerUtils.HandleProviderAPIError(resp, &errorResp)

	detail := errorResp.Detail
	if detail == nil {
		return bifrostErr
	}

	// Handle validation errors (array format).
	if len(detail.ValidationErrors) > 0 {
		var messages, locations, errorTypes []string
		for _, validationErr := range detail.ValidationErrors {
			// Prefer Message; fall back to Msg.
			msg := validationErr.Message
			if msg == "" {
				msg = validationErr.Msg
			}
			if msg != "" {
				messages = append(messages, msg)
			}
			if len(validationErr.Loc) > 0 {
				locations = append(locations, strings.Join(validationErr.Loc, "."))
			}
			if validationErr.Type != "" {
				errorTypes = append(errorTypes, validationErr.Type)
			}
		}

		message := strings.Join(messages, "; ")
		if len(locations) > 0 {
			suffix := "[" + strings.Join(locations, ", ") + "]"
			// Only separate with a space when there is message text; otherwise
			// the result would start with a stray leading space.
			if message == "" {
				message = suffix
			} else {
				message += " " + suffix
			}
		}

		if message != "" {
			return &schemas.BifrostError{
				IsBifrostError: false,
				StatusCode:     schemas.Ptr(resp.StatusCode()),
				Error: &schemas.ErrorField{
					Type:    schemas.Ptr(strings.Join(errorTypes, ", ")),
					Message: message,
				},
			}
		}
	}

	// Handle non-validation errors (single object format).
	var message string
	if detail.Message != nil {
		message = *detail.Message
	}
	errorType := ""
	if detail.Status != nil {
		errorType = *detail.Status
	}
	if message != "" {
		if bifrostErr.Error == nil {
			bifrostErr.Error = &schemas.ErrorField{}
		}
		bifrostErr.Error.Type = schemas.Ptr(errorType)
		bifrostErr.Error.Message = message
	}
	return bifrostErr
}

View File

@@ -0,0 +1,51 @@
package elevenlabs
import (
"strings"
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
"github.com/maximhq/bifrost/core/schemas"
)
// ToBifrostListModelsResponse converts the Elevenlabs model list into a
// Bifrost list-models response.
//
// Each model ID is run through a providerUtils.ListModelsPipeline built from
// the allow list, block list, aliases and unfiltered flag. Every result the
// pipeline yields becomes an entry whose ID is "<providerKey>/<resolved id>",
// with the Elevenlabs display name and, when present, the alias. Finally the
// pipeline's backfilled models are appended. A nil receiver yields nil; an
// early-exiting pipeline yields an empty (non-nil) response.
func (response *ElevenlabsListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, allowedModels schemas.WhiteList, blacklistedModels schemas.BlackList, aliases map[string]string, unfiltered bool) *schemas.BifrostListModelsResponse {
	if response == nil {
		return nil
	}

	models := *response
	out := &schemas.BifrostListModelsResponse{
		Data: make([]schemas.Model, 0, len(models)),
	}

	filter := &providerUtils.ListModelsPipeline{
		AllowedModels:     allowedModels,
		BlacklistedModels: blacklistedModels,
		Aliases:           aliases,
		Unfiltered:        unfiltered,
		ProviderKey:       providerKey,
		MatchFns:          providerUtils.DefaultMatchFns(),
	}
	if filter.ShouldEarlyExit() {
		return out
	}

	idPrefix := string(providerKey) + "/"
	// seen records (lower-cased) resolved IDs so the backfill step can tell
	// which models have already been emitted.
	seen := make(map[string]bool)
	for _, m := range models {
		matches := filter.FilterModel(m.ModelID)
		for _, match := range matches {
			item := schemas.Model{
				ID:   idPrefix + match.ResolvedID,
				Name: schemas.Ptr(m.Name),
			}
			if match.AliasValue != "" {
				item.Alias = schemas.Ptr(match.AliasValue)
			}
			out.Data = append(out.Data, item)
			seen[strings.ToLower(match.ResolvedID)] = true
		}
	}

	backfilled := filter.BackfillModels(seen)
	out.Data = append(out.Data, backfilled...)
	return out
}

View File

@@ -0,0 +1,257 @@
package elevenlabs
import (
	"encoding/json"
	"fmt"
	"net/url"
	"strings"

	providerUtils "github.com/maximhq/bifrost/core/providers/utils"
	"github.com/maximhq/bifrost/core/schemas"
)
// SupportsRealtimeAPI reports whether the provider exposes a realtime API.
// It always returns true: ElevenLabs supports Conversational AI via WebSocket.
func (provider *ElevenlabsProvider) SupportsRealtimeAPI() bool {
return true
}
// RealtimeWebSocketURL returns the WebSocket URL for the ElevenLabs
// Conversational AI endpoint.
//
// The URL is derived from the provider's configured base URL by swapping a
// leading "https://" for "wss://" (or "http://" for "ws://"); any other base
// URL is used unchanged. The model parameter is sent as the agent_id query
// parameter and is query-escaped so reserved characters cannot corrupt the
// URL. The key parameter is unused.
//
// Format: wss://api.elevenlabs.io/v1/convai/conversation?agent_id=<model>
func (provider *ElevenlabsProvider) RealtimeWebSocketURL(key schemas.Key, model string) string {
	base := provider.networkConfig.BaseURL
	// Only rewrite the scheme when it is actually the URL prefix.
	switch {
	case strings.HasPrefix(base, "https://"):
		base = "wss://" + strings.TrimPrefix(base, "https://")
	case strings.HasPrefix(base, "http://"):
		base = "ws://" + strings.TrimPrefix(base, "http://")
	}
	return base + "/v1/convai/conversation?agent_id=" + url.QueryEscape(model)
}
// RealtimeHeaders builds the header set for the ElevenLabs Conversational AI
// WebSocket handshake: the API key under "xi-api-key" plus every configured
// extra header. An extra header whose name matches "xi-api-key"
// (case-insensitively) is dropped so it cannot override the key.
func (provider *ElevenlabsProvider) RealtimeHeaders(key schemas.Key) map[string]string {
	result := map[string]string{
		"xi-api-key": key.Value.GetValue(),
	}
	for name, value := range provider.networkConfig.ExtraHeaders {
		if !strings.EqualFold(name, "xi-api-key") {
			result[name] = value
		}
	}
	return result
}
// SupportsRealtimeWebRTC reports whether realtime sessions can be established
// over WebRTC. It always returns false — ElevenLabs WebRTC SDP exchange is not
// yet implemented.
func (provider *ElevenlabsProvider) SupportsRealtimeWebRTC() bool {
return false
}
// ExchangeRealtimeWebRTCSDP is not yet implemented for ElevenLabs. It ignores
// every argument and always returns an empty answer together with a 400
// "invalid_request_error" BifrostError.
func (provider *ElevenlabsProvider) ExchangeRealtimeWebRTCSDP(_ *schemas.BifrostContext, _ schemas.Key, _ string, _ string, _ json.RawMessage) (string, *schemas.BifrostError) {
	notImplemented := &schemas.BifrostError{
		IsBifrostError: true,
		StatusCode:     schemas.Ptr(400),
		Error: &schemas.ErrorField{
			Type:    schemas.Ptr("invalid_request_error"),
			Message: "WebRTC SDP exchange is not yet implemented for ElevenLabs",
		},
	}
	return "", notImplemented
}
// ShouldStartRealtimeTurn always returns false, regardless of the event.
func (provider *ElevenlabsProvider) ShouldStartRealtimeTurn(event *schemas.BifrostRealtimeEvent) bool {
return false
}
// RealtimeTurnFinalEvent returns schemas.RTEventResponseDone, the event type
// that ToBifrostRealtimeEvent assigns to ElevenLabs "agent_response" events.
func (provider *ElevenlabsProvider) RealtimeTurnFinalEvent() schemas.RealtimeEventType {
return schemas.RTEventResponseDone
}
// RealtimeWebRTCDataChannelLabel returns an empty label; this provider reports
// no WebRTC support (see SupportsRealtimeWebRTC).
func (provider *ElevenlabsProvider) RealtimeWebRTCDataChannelLabel() string {
return ""
}
// RealtimeWebSocketSubprotocol returns an empty string, i.e. no WebSocket
// subprotocol is specified for this provider.
func (provider *ElevenlabsProvider) RealtimeWebSocketSubprotocol() string {
return ""
}
// ShouldForwardRealtimeEvent always returns true, regardless of the event.
func (provider *ElevenlabsProvider) ShouldForwardRealtimeEvent(event *schemas.BifrostRealtimeEvent) bool {
return true
}
// ShouldAccumulateRealtimeOutput returns true only for
// schemas.RTEventResponseDone events.
func (provider *ElevenlabsProvider) ShouldAccumulateRealtimeOutput(eventType schemas.RealtimeEventType) bool {
return eventType == schemas.RTEventResponseDone
}
// ElevenLabs Conversational AI WebSocket event types, i.e. the values of the
// "type" field on the wire.
//
// The first group (through elClientToolCall) is what ToBifrostRealtimeEvent
// translates into Bifrost event types. elUserAudioChunk is the type
// ToProviderRealtimeEvent emits for input-audio-append events.
// NOTE(review): elClientToolResult and elContextualUpdate are presumably
// client-to-server event types — confirm against their users.
const (
elConversationInitMetadata = "conversation_initiation_metadata"
elPing = "ping"
elAudio = "audio"
elUserTranscript = "user_transcript"
elAgentResponse = "agent_response"
elAgentResponseCorrection = "agent_response_correction"
elInterruption = "interruption"
elClientToolCall = "client_tool_call"
elUserAudioChunk = "user_audio_chunk"
elPong = "pong"
elClientToolResult = "client_tool_result"
elContextualUpdate = "contextual_update"
)
// elevenlabsEvent represents a raw ElevenLabs Conversational AI WebSocket event.
// Type selects the event kind; every payload is kept as raw JSON so that
// ToBifrostRealtimeEvent can decode only the one matching Type.
type elevenlabsEvent struct {
Type string `json:"type"`
// Server events
ConversationInitMetadata json.RawMessage `json:"conversation_initiation_metadata_event,omitempty"`
Audio json.RawMessage `json:"audio_event,omitempty"`
UserTranscript json.RawMessage `json:"user_transcription_event,omitempty"`
AgentResponse json.RawMessage `json:"agent_response_event,omitempty"`
AgentResponseCorrection json.RawMessage `json:"agent_response_correction_event,omitempty"`
ClientToolCall json.RawMessage `json:"client_tool_call,omitempty"`
PingEvent json.RawMessage `json:"ping_event,omitempty"`
// Client events
UserAudioChunk json.RawMessage `json:"user_audio_chunk,omitempty"`
}
// elevenlabsAudioEvent is the audio event structure from ElevenLabs (the
// "audio_event" payload). Audio carries the "audio_base_64" string, which
// ToBifrostRealtimeEvent copies verbatim into the delta; Alignment is kept raw.
type elevenlabsAudioEvent struct {
Audio string `json:"audio_base_64,omitempty"`
Alignment json.RawMessage `json:"alignment,omitempty"`
}
// elevenlabsTranscriptEvent is the user/agent transcript event from ElevenLabs.
// The same struct decodes both the user transcription payload (UserTranscript)
// and the agent response payload (AgentResponse).
type elevenlabsTranscriptEvent struct {
UserTranscript string `json:"user_transcript,omitempty"`
AgentResponse string `json:"agent_response,omitempty"`
AgentResponseID string `json:"agent_response_id,omitempty"`
}
// elevenlabsCorrectionEvent is the agent response correction event from ElevenLabs.
// ToBifrostRealtimeEvent forwards only CorrectedAgentResponse.
type elevenlabsCorrectionEvent struct {
OriginalAgentResponse string `json:"original_agent_response,omitempty"`
CorrectedAgentResponse string `json:"corrected_agent_response,omitempty"`
}
// ToBifrostRealtimeEvent translates one ElevenLabs Conversational AI WebSocket
// message into the unified Bifrost realtime event.
//
// The original bytes are always kept in RawData. Known ElevenLabs event types
// are mapped to Bifrost event types; unknown types are passed through using
// their ElevenLabs type name. Payload decoding is best-effort: if a nested
// payload is missing or fails to decode, the event is still returned with its
// type set but without Delta/Item. Only a failure to decode the outer envelope
// is reported as an error.
func (provider *ElevenlabsProvider) ToBifrostRealtimeEvent(providerEvent json.RawMessage) (*schemas.BifrostRealtimeEvent, error) {
	var envelope elevenlabsEvent
	if err := json.Unmarshal(providerEvent, &envelope); err != nil {
		return nil, fmt.Errorf("failed to unmarshal ElevenLabs realtime event: %w", err)
	}

	out := &schemas.BifrostRealtimeEvent{RawData: providerEvent}

	switch envelope.Type {
	case elConversationInitMetadata:
		out.Type = schemas.RTEventSessionCreated
		out.Session = &schemas.RealtimeSession{}

	case elPing:
		out.Type = schemas.RealtimeEventType("ping")

	case elAudio:
		out.Type = schemas.RTEventResponseAudioDelta
		if envelope.Audio == nil {
			break
		}
		var payload elevenlabsAudioEvent
		if json.Unmarshal(envelope.Audio, &payload) != nil {
			break
		}
		out.Delta = &schemas.RealtimeDelta{Audio: payload.Audio}

	case elUserTranscript:
		out.Type = schemas.RTEventInputAudioTransCompleted
		if envelope.UserTranscript == nil {
			break
		}
		var payload elevenlabsTranscriptEvent
		if json.Unmarshal(envelope.UserTranscript, &payload) != nil {
			break
		}
		out.Delta = &schemas.RealtimeDelta{Transcript: payload.UserTranscript}

	case elAgentResponse:
		out.Type = schemas.RTEventResponseDone
		if envelope.AgentResponse == nil {
			break
		}
		var payload elevenlabsTranscriptEvent
		if json.Unmarshal(envelope.AgentResponse, &payload) != nil {
			break
		}
		out.Delta = &schemas.RealtimeDelta{Text: payload.AgentResponse}

	case elAgentResponseCorrection:
		out.Type = schemas.RTEventResponseTextDelta
		if envelope.AgentResponseCorrection == nil {
			break
		}
		var payload elevenlabsCorrectionEvent
		if json.Unmarshal(envelope.AgentResponseCorrection, &payload) != nil {
			break
		}
		out.Delta = &schemas.RealtimeDelta{Text: payload.CorrectedAgentResponse}

	case elInterruption:
		out.Type = schemas.RTEventResponseCancel

	case elClientToolCall:
		out.Type = schemas.RealtimeEventType("client_tool_call")
		if envelope.ClientToolCall == nil {
			break
		}
		var call struct {
			ToolName   string          `json:"tool_name"`
			Parameters json.RawMessage `json:"parameters"`
			ToolCallID string          `json:"tool_call_id"`
		}
		if json.Unmarshal(envelope.ClientToolCall, &call) != nil {
			break
		}
		// Re-encode the parameters through MarshalSorted when possible;
		// otherwise fall back to the raw parameter bytes.
		arguments := string(call.Parameters)
		if len(call.Parameters) > 0 {
			var decoded interface{}
			if json.Unmarshal(call.Parameters, &decoded) == nil {
				if canonical, marshalErr := providerUtils.MarshalSorted(decoded); marshalErr == nil {
					arguments = string(canonical)
				}
			}
		}
		out.Item = &schemas.RealtimeItem{
			Type:      "function_call",
			Name:      call.ToolName,
			CallID:    call.ToolCallID,
			Arguments: arguments,
		}

	default:
		out.Type = schemas.RealtimeEventType(envelope.Type)
	}

	return out, nil
}
// ToProviderRealtimeEvent converts a unified Bifrost realtime event to
// ElevenLabs' native JSON.
//
//   - Input-audio-append events become a "user_audio_chunk" message carrying
//     the delta's audio; a missing delta is an error.
//   - "pong" events become a bare {"type":"pong"} message.
//   - Any other event is forwarded as a message containing only its type.
//
// A nil event is rejected with an error instead of being dereferenced.
func (provider *ElevenlabsProvider) ToProviderRealtimeEvent(bifrostEvent *schemas.BifrostRealtimeEvent) (json.RawMessage, error) {
	if bifrostEvent == nil {
		return nil, fmt.Errorf("bifrost realtime event must not be nil")
	}

	switch bifrostEvent.Type {
	case schemas.RTEventInputAudioAppend:
		if bifrostEvent.Delta == nil {
			return nil, fmt.Errorf("delta must be set for input_audio_buffer.append events")
		}
		out := map[string]interface{}{
			"type":             elUserAudioChunk,
			"user_audio_chunk": bifrostEvent.Delta.Audio,
		}
		return schemas.MarshalSorted(out)
	case schemas.RealtimeEventType(elPong):
		return schemas.MarshalSorted(map[string]interface{}{
			"type": elPong,
		})
	default:
		out := map[string]interface{}{
			"type": string(bifrostEvent.Type),
		}
		return schemas.MarshalSorted(out)
	}
}

View File

@@ -0,0 +1,102 @@
package elevenlabs
import (
"github.com/maximhq/bifrost/core/schemas"
)
// ToElevenlabsSpeechRequest converts a Bifrost speech request into the
// Elevenlabs text-to-speech request body. It returns nil when the request or
// its input is nil.
//
// Known keys in Params.ExtraParams (voice settings, seed, previous/next text
// and request IDs, text-normalization flags, use_pvc_as_ivc) are lifted into
// typed fields; the remaining keys are passed through as ExtraParams.
//
// The extra params are copied before known keys are removed, so the caller's
// map is never mutated and converting the same request again (e.g. on a retry)
// sees the same parameters.
//
// When at least one voice setting is supplied, the settings that were not
// supplied are filled with the defaults documented on ElevenlabsVoiceSettings
// rather than zero values, because that struct serializes every field.
func ToElevenlabsSpeechRequest(bifrostReq *schemas.BifrostSpeechRequest) *ElevenlabsSpeechRequest {
	if bifrostReq == nil || bifrostReq.Input == nil {
		return nil
	}

	elevenlabsReq := &ElevenlabsSpeechRequest{
		ModelID: bifrostReq.Model,
		Text:    bifrostReq.Input.Input,
	}

	params := bifrostReq.Params
	if params == nil {
		return elevenlabsReq
	}

	// Copy the caller's extra params; known keys are deleted from the copy only.
	var extra map[string]interface{}
	if params.ExtraParams != nil {
		extra = make(map[string]interface{}, len(params.ExtraParams))
		for k, v := range params.ExtraParams {
			extra[k] = v
		}
	}
	elevenlabsReq.ExtraParams = extra

	// Start from the documented defaults so unspecified settings are not sent
	// as 0/false.
	voiceSettings := ElevenlabsVoiceSettings{
		Stability:       0.5,
		UseSpeakerBoost: true,
		SimilarityBoost: 0.75,
		Style:           0,
		Speed:           1,
	}
	hasVoiceSettings := false

	if params.Speed != nil {
		voiceSettings.Speed = *params.Speed
		hasVoiceSettings = true
	}

	if extra != nil {
		if stability, ok := schemas.SafeExtractFloat64Pointer(extra["stability"]); ok {
			delete(extra, "stability")
			voiceSettings.Stability = *stability
			hasVoiceSettings = true
		}
		if useSpeakerBoost, ok := schemas.SafeExtractBoolPointer(extra["use_speaker_boost"]); ok {
			delete(extra, "use_speaker_boost")
			voiceSettings.UseSpeakerBoost = *useSpeakerBoost
			hasVoiceSettings = true
		}
		if similarityBoost, ok := schemas.SafeExtractFloat64Pointer(extra["similarity_boost"]); ok {
			delete(extra, "similarity_boost")
			voiceSettings.SimilarityBoost = *similarityBoost
			hasVoiceSettings = true
		}
		if style, ok := schemas.SafeExtractFloat64Pointer(extra["style"]); ok {
			delete(extra, "style")
			voiceSettings.Style = *style
			hasVoiceSettings = true
		}
		if seed, ok := schemas.SafeExtractIntPointer(extra["seed"]); ok {
			delete(extra, "seed")
			elevenlabsReq.Seed = seed
		}
		if previousText, ok := schemas.SafeExtractStringPointer(extra["previous_text"]); ok {
			delete(extra, "previous_text")
			elevenlabsReq.PreviousText = previousText
		}
		if nextText, ok := schemas.SafeExtractStringPointer(extra["next_text"]); ok {
			delete(extra, "next_text")
			elevenlabsReq.NextText = nextText
		}
		if previousRequestIDs, ok := schemas.SafeExtractStringSlice(extra["previous_request_ids"]); ok {
			delete(extra, "previous_request_ids")
			elevenlabsReq.PreviousRequestIDs = previousRequestIDs
		}
		if nextRequestIDs, ok := schemas.SafeExtractStringSlice(extra["next_request_ids"]); ok {
			delete(extra, "next_request_ids")
			elevenlabsReq.NextRequestIDs = nextRequestIDs
		}
		if applyTextNormalization, ok := schemas.SafeExtractStringPointer(extra["apply_text_normalization"]); ok {
			delete(extra, "apply_text_normalization")
			elevenlabsReq.ApplyTextNormalization = applyTextNormalization
		}
		if applyLanguageTextNormalization, ok := schemas.SafeExtractBoolPointer(extra["apply_language_text_normalization"]); ok {
			delete(extra, "apply_language_text_normalization")
			elevenlabsReq.ApplyLanguageTextNormalization = applyLanguageTextNormalization
		}
		if usePVCAsIVC, ok := schemas.SafeExtractBoolPointer(extra["use_pvc_as_ivc"]); ok {
			delete(extra, "use_pvc_as_ivc")
			elevenlabsReq.UsePVCAsIVC = usePVCAsIVC
		}
	}

	if hasVoiceSettings {
		elevenlabsReq.VoiceSettings = &voiceSettings
	}

	if params.LanguageCode != nil {
		elevenlabsReq.LanguageCode = params.LanguageCode
	}

	if len(params.PronunciationDictionaryLocators) > 0 {
		elevenlabsReq.PronunciationDictionaryLocators = make([]ElevenlabsPronunciationDictionaryLocator, len(params.PronunciationDictionaryLocators))
		for i, locator := range params.PronunciationDictionaryLocators {
			elevenlabsReq.PronunciationDictionaryLocators[i] = ElevenlabsPronunciationDictionaryLocator{
				PronunciationDictionaryID: locator.PronunciationDictionaryID,
				VersionID:                 locator.VersionID,
			}
		}
	}

	return elevenlabsReq
}

View File

@@ -0,0 +1,269 @@
package elevenlabs
import (
"errors"
"strings"
"github.com/bytedance/sonic"
"github.com/maximhq/bifrost/core/schemas"
)
// ToElevenlabsTranscriptionRequest converts a Bifrost transcription request
// into the Elevenlabs speech-to-text request. It returns nil for a nil request.
//
// The audio file and filename are copied when the input has file bytes. Known
// keys in Params.ExtraParams are lifted into typed fields; the remaining keys
// are passed through as ExtraParams. Additional formats with an empty Format
// are dropped. An empty map-typed webhook metadata value is omitted; any other
// non-nil value is forwarded as-is.
//
// The extra params are copied before known keys are removed, so the caller's
// map is never mutated and converting the same request again (e.g. on a retry)
// sees the same parameters.
func ToElevenlabsTranscriptionRequest(bifrostReq *schemas.BifrostTranscriptionRequest) *ElevenlabsTranscriptionRequest {
	if bifrostReq == nil {
		return nil
	}

	req := &ElevenlabsTranscriptionRequest{
		ModelID: bifrostReq.Model,
	}

	if bifrostReq.Input != nil && len(bifrostReq.Input.File) > 0 {
		req.File = bifrostReq.Input.File
		req.Filename = bifrostReq.Input.Filename
	}

	if bifrostReq.Params == nil {
		return req
	}
	params := bifrostReq.Params

	if params.Language != nil {
		req.LanguageCode = params.Language
	}

	if params.ExtraParams != nil {
		// Copy the caller's extra params; known keys are deleted from the copy only.
		extra := make(map[string]interface{}, len(params.ExtraParams))
		for k, v := range params.ExtraParams {
			extra[k] = v
		}

		if tagAudioEvents, ok := schemas.SafeExtractBoolPointer(extra["tag_audio_events"]); ok {
			delete(extra, "tag_audio_events")
			req.TagAudioEvents = tagAudioEvents
		}
		if numSpeakers, ok := schemas.SafeExtractIntPointer(extra["num_speakers"]); ok {
			delete(extra, "num_speakers")
			req.NumSpeakers = numSpeakers
		}
		if timestampsGranularity, ok := schemas.SafeExtractStringPointer(extra["timestamps_granularity"]); ok {
			granularity := ElevenlabsTimestampsGranularity(*timestampsGranularity)
			delete(extra, "timestamps_granularity")
			req.TimestampsGranularity = &granularity
		}
		if diarize, ok := schemas.SafeExtractBoolPointer(extra["diarize"]); ok {
			delete(extra, "diarize")
			req.Diarize = diarize
		}
		if diarizationThreshold, ok := schemas.SafeExtractFloat64Pointer(extra["diarization_threshold"]); ok {
			delete(extra, "diarization_threshold")
			req.DiarizationThreshold = diarizationThreshold
		}
		if rawFileFormat, ok := schemas.SafeExtractStringPointer(extra["file_format"]); ok {
			fileFormat := ElevenlabsFileFormat(*rawFileFormat)
			delete(extra, "file_format")
			req.FileFormat = &fileFormat
		}
		if cloudStorageURL, ok := schemas.SafeExtractStringPointer(extra["cloud_storage_url"]); ok {
			delete(extra, "cloud_storage_url")
			req.CloudStorageURL = cloudStorageURL
		}
		if webhook, ok := schemas.SafeExtractBoolPointer(extra["webhook"]); ok {
			delete(extra, "webhook")
			req.Webhook = webhook
		}
		if webhookID, ok := schemas.SafeExtractStringPointer(extra["webhook_id"]); ok {
			delete(extra, "webhook_id")
			req.WebhookID = webhookID
		}
		if temperature, ok := schemas.SafeExtractFloat64Pointer(extra["temperature"]); ok {
			delete(extra, "temperature")
			req.Temperature = temperature
		}
		if seed, ok := schemas.SafeExtractIntPointer(extra["seed"]); ok {
			delete(extra, "seed")
			req.Seed = seed
		}
		if useMultiChannel, ok := schemas.SafeExtractBoolPointer(extra["use_multi_channel"]); ok {
			delete(extra, "use_multi_channel")
			req.UseMultiChannel = useMultiChannel
		}

		req.ExtraParams = extra
	}

	if len(params.AdditionalFormats) > 0 {
		additionalFormats := make([]ElevenlabsAdditionalFormat, 0, len(params.AdditionalFormats))
		for _, format := range params.AdditionalFormats {
			if converted, ok := convertAdditionalFormat(format); ok {
				additionalFormats = append(additionalFormats, converted)
			}
		}
		if len(additionalFormats) > 0 {
			req.AdditionalFormats = additionalFormats
		}
	}

	if params.WebhookMetadata != nil {
		if metadataMap, ok := params.WebhookMetadata.(map[string]interface{}); ok {
			// Skip empty maps so no empty metadata object is sent.
			if len(metadataMap) > 0 {
				req.WebhookMetadata = metadataMap
			}
		} else {
			req.WebhookMetadata = params.WebhookMetadata
		}
	}

	return req
}
// ToBifrostTranscriptionResponse merges one or more Elevenlabs transcript
// chunks into a single Bifrost transcription response. It returns nil when
// there are no chunks.
//
// Chunk texts are joined with newlines, and words and log-probs are
// concatenated in chunk order. The language is the first non-empty language
// code, and the duration is the largest per-chunk duration reported by
// convertWords (nil when no chunk has one).
func ToBifrostTranscriptionResponse(chunks []ElevenlabsSpeechToTextChunkResponse) *schemas.BifrostTranscriptionResponse {
	if len(chunks) == 0 {
		return nil
	}

	var (
		texts    = make([]string, 0, len(chunks))
		words    = make([]schemas.TranscriptionWord, 0)
		logProbs = make([]schemas.TranscriptionLogProb, 0)
		lang     *string
		longest  *float64
	)

	for i := range chunks {
		chunk := &chunks[i]
		texts = append(texts, chunk.Text)

		chunkWords, chunkLogProbs, chunkDuration := convertWords(chunk.Words)
		words = append(words, chunkWords...)
		logProbs = append(logProbs, chunkLogProbs...)

		if lang == nil && chunk.LanguageCode != "" {
			code := chunk.LanguageCode
			lang = &code
		}

		if chunkDuration != nil && (longest == nil || *chunkDuration > *longest) {
			d := *chunkDuration
			longest = &d
		}
	}

	return &schemas.BifrostTranscriptionResponse{
		Text:     strings.Join(texts, "\n"),
		Words:    words,
		LogProbs: logProbs,
		Language: lang,
		Duration: longest,
	}
}
// convertAdditionalFormat maps a Bifrost additional-format request onto the
// Elevenlabs equivalent. The boolean result is false (with a zero value) when
// the format name is empty; otherwise every optional pointer field is carried
// over unchanged.
func convertAdditionalFormat(format schemas.TranscriptionAdditionalFormat) (ElevenlabsAdditionalFormat, bool) {
	if format.Format == "" {
		return ElevenlabsAdditionalFormat{}, false
	}

	// Copying a nil pointer leaves the target nil, so no per-field guards are needed.
	return ElevenlabsAdditionalFormat{
		Format:                      ElevenlabsExportOptions(format.Format),
		IncludeSpeakers:             format.IncludeSpeakers,
		IncludeTimestamps:           format.IncludeTimestamps,
		SegmentOnSilenceLongerThanS: format.SegmentOnSilenceLongerThanS,
		MaxSegmentDurationS:         format.MaxSegmentDurationS,
		MaxSegmentChars:             format.MaxSegmentChars,
		MaxCharactersPerLine:        format.MaxCharactersPerLine,
	}, true
}
// convertWords converts Elevenlabs word entries into Bifrost transcription
// words and log-probs.
//
// Entries of type "spacing" whose text is blank are skipped. The third result
// is the largest End timestamp among the kept words, or nil when none of them
// has an End. An empty input yields three nils.
func convertWords(words []ElevenlabsSpeechToTextWord) ([]schemas.TranscriptionWord, []schemas.TranscriptionLogProb, *float64) {
	if len(words) == 0 {
		return nil, nil, nil
	}

	outWords := make([]schemas.TranscriptionWord, 0, len(words))
	outLogProbs := make([]schemas.TranscriptionLogProb, 0, len(words))
	var latestEnd *float64

	for i := range words {
		w := &words[i]
		if w.Type == "spacing" && strings.TrimSpace(w.Text) == "" {
			continue
		}

		entry := schemas.TranscriptionWord{Word: w.Text}
		if w.Start != nil {
			entry.Start = *w.Start
		}
		if w.End != nil {
			entry.End = *w.End
			if latestEnd == nil || *w.End > *latestEnd {
				end := *w.End
				latestEnd = &end
			}
		}

		outWords = append(outWords, entry)
		outLogProbs = append(outLogProbs, schemas.TranscriptionLogProb{
			Token:   w.Text,
			LogProb: w.LogProb,
		})
	}

	return outWords, outLogProbs, latestEnd
}
// parseTranscriptionResponse decodes an Elevenlabs speech-to-text response
// body, trying the known shapes in order:
//
//  1. a multichannel response with at least one transcript;
//  2. a single chunk with a language code, text, or words;
//  3. a webhook response, whose non-blank message is returned as the error.
//
// A body matching none of these yields a generic format error.
func parseTranscriptionResponse(body []byte) ([]ElevenlabsSpeechToTextChunkResponse, error) {
	var multi ElevenlabsMultichannelSpeechToTextResponse
	if sonic.Unmarshal(body, &multi) == nil && len(multi.Transcripts) > 0 {
		return multi.Transcripts, nil
	}

	var chunk ElevenlabsSpeechToTextChunkResponse
	if sonic.Unmarshal(body, &chunk) == nil {
		hasContent := chunk.LanguageCode != "" || chunk.Text != "" || len(chunk.Words) > 0
		if hasContent {
			return []ElevenlabsSpeechToTextChunkResponse{chunk}, nil
		}
	}

	var ack ElevenlabsSpeechToTextWebhookResponse
	if sonic.Unmarshal(body, &ack) == nil && strings.TrimSpace(ack.Message) != "" {
		return nil, errors.New(ack.Message)
	}

	return nil, errors.New("unexpected Elevenlabs transcription response format")
}

View File

@@ -0,0 +1,289 @@
package elevenlabs
import (
"strings"
"github.com/bytedance/sonic"
)
// SPEECH TYPES

// ElevenlabsSpeechRequest is the JSON body of an Elevenlabs text-to-speech
// request, as built by ToElevenlabsSpeechRequest. ExtraParams is excluded from
// the JSON encoding (tag "-") and exposed through GetExtraParams instead.
//
// NOTE(review): PronunciationDictionaryLocators, PreviousRequestIDs and
// NextRequestIDs have no omitempty, so a nil slice is encoded as JSON null by
// encoding/json — confirm the API accepts that.
type ElevenlabsSpeechRequest struct {
Text string `json:"text"`
ModelID string `json:"model_id"` // defaults to "eleven_multilingual_v2"
LanguageCode *string `json:"language_code,omitempty"`
VoiceSettings *ElevenlabsVoiceSettings `json:"voice_settings,omitempty"`
PronunciationDictionaryLocators []ElevenlabsPronunciationDictionaryLocator `json:"pronunciation_dictionary_locators"`
Seed *int `json:"seed,omitempty"`
PreviousText *string `json:"previous_text,omitempty"`
NextText *string `json:"next_text,omitempty"`
PreviousRequestIDs []string `json:"previous_request_ids"`
NextRequestIDs []string `json:"next_request_ids"`
ApplyTextNormalization *string `json:"apply_text_normalization,omitempty"`
ApplyLanguageTextNormalization *bool `json:"apply_language_text_normalization,omitempty"`
UsePVCAsIVC *bool `json:"use_pvc_as_ivc,omitempty"` // deprecated
ExtraParams map[string]interface{} `json:"-"`
}
// GetExtraParams implements the providerUtils.RequestBodyWithExtraParams interface.
// It returns the request's ExtraParams map as-is (no copy).
func (r *ElevenlabsSpeechRequest) GetExtraParams() map[string]interface{} {
return r.ExtraParams
}
// ElevenlabsSpeechWithTimestampsResponse represents the response from the with-timestamps endpoint.
// The audio is carried as a string in "audio_base64"; both alignment sections are optional.
type ElevenlabsSpeechWithTimestampsResponse struct {
AudioBase64 string `json:"audio_base64"`
Alignment *ElevenlabsAlignment `json:"alignment,omitempty"`
NormalizedAlignment *ElevenlabsAlignment `json:"normalized_alignment,omitempty"`
}
// ElevenlabsAlignment represents character-level timing information.
// NOTE(review): the three slices are presumably index-aligned (entry i of each
// describes the same character) — confirm against the API reference.
type ElevenlabsAlignment struct {
CharStartTimesMs []float64 `json:"char_start_times_ms"`
CharEndTimesMs []float64 `json:"char_end_times_ms"`
Characters []string `json:"characters"`
}
// ElevenlabsVoiceSettings is the "voice_settings" object of a speech request.
//
// NOTE(review): no field uses omitempty or a pointer type, so all five values
// are always serialized — a zero value is sent explicitly rather than falling
// back to the default noted beside each field.
type ElevenlabsVoiceSettings struct {
Stability float64 `json:"stability"` // 0-1, default 0.5
UseSpeakerBoost bool `json:"use_speaker_boost"` // default true
SimilarityBoost float64 `json:"similarity_boost"` // 0-1, default 0.75
Style float64 `json:"style"` // default 0
Speed float64 `json:"speed"` // default 1
}
// ElevenlabsPronunciationDictionaryLocator identifies a pronunciation
// dictionary by ID, optionally pinned to a version.
type ElevenlabsPronunciationDictionaryLocator struct {
PronunciationDictionaryID string `json:"pronunciation_dictionary_id"`
VersionID *string `json:"version_id,omitempty"`
}
// TRANSCRIPTION TYPES

// ElevenlabsTranscriptionRequest holds the parameters of an Elevenlabs
// speech-to-text request, as built by ToElevenlabsTranscriptionRequest.
// File, Filename and ExtraParams are excluded from the JSON encoding (tag "-");
// ExtraParams is exposed through GetExtraParams instead.
type ElevenlabsTranscriptionRequest struct {
ModelID string `json:"model_id"`
File []byte `json:"-"`
Filename string `json:"-"` // Original filename, used to preserve file format extension
LanguageCode *string `json:"language_code,omitempty"`
TagAudioEvents *bool `json:"tag_audio_events,omitempty"`
NumSpeakers *int `json:"num_speakers,omitempty"`
TimestampsGranularity *ElevenlabsTimestampsGranularity `json:"timestamps_granularity,omitempty"`
Diarize *bool `json:"diarize,omitempty"`
DiarizationThreshold *float64 `json:"diarization_threshold,omitempty"`
AdditionalFormats []ElevenlabsAdditionalFormat `json:"additional_formats,omitempty"`
FileFormat *ElevenlabsFileFormat `json:"file_format,omitempty"`
CloudStorageURL *string `json:"cloud_storage_url,omitempty"`
Webhook *bool `json:"webhook,omitempty"`
WebhookID *string `json:"webhook_id,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
Seed *int `json:"seed,omitempty"`
UseMultiChannel *bool `json:"use_multi_channel,omitempty"`
WebhookMetadata interface{} `json:"webhook_metadata,omitempty"`
ExtraParams map[string]interface{} `json:"-"`
}
// GetExtraParams implements the RequestBodyWithExtraParams interface.
// It returns the request's ExtraParams map as-is (no copy).
func (req *ElevenlabsTranscriptionRequest) GetExtraParams() map[string]interface{} {
return req.ExtraParams
}
// ElevenlabsTimestampsGranularity is the value of the "timestamps_granularity"
// transcription parameter.
type ElevenlabsTimestampsGranularity string
// Known timestamp granularity values.
const (
ElevenlabsTimestampsGranularityNone ElevenlabsTimestampsGranularity = "none"
ElevenlabsTimestampsGranularityWord ElevenlabsTimestampsGranularity = "word"
ElevenlabsTimestampsGranularityCharacter ElevenlabsTimestampsGranularity = "character"
)
// ElevenlabsFileFormat is the value of the "file_format" transcription
// parameter.
type ElevenlabsFileFormat string
// Known file format values.
const (
ElevenlabsFileFormatPcmS16le16 ElevenlabsFileFormat = "pcm_s16le_16"
ElevenlabsFileFormatOther ElevenlabsFileFormat = "other"
)
// ElevenlabsAdditionalFormat describes one extra export format requested with
// a transcription (an entry of "additional_formats"). Format is required; all
// other options are optional and omitted from JSON when nil.
type ElevenlabsAdditionalFormat struct {
Format ElevenlabsExportOptions `json:"format"`
IncludeSpeakers *bool `json:"include_speakers,omitempty"`
IncludeTimestamps *bool `json:"include_timestamps,omitempty"`
SegmentOnSilenceLongerThanS *float64 `json:"segment_on_silence_longer_than_s,omitempty"`
MaxSegmentDurationS *float64 `json:"max_segment_duration_s,omitempty"`
MaxSegmentChars *int `json:"max_segment_chars,omitempty"`
MaxCharactersPerLine *int `json:"max_characters_per_line,omitempty"`
}
// ElevenlabsExportOptions names an export format for
// ElevenlabsAdditionalFormat.Format.
type ElevenlabsExportOptions string
// Known export format names.
const (
ElevenlabsExportOptionsSegmentedJson ElevenlabsExportOptions = "segmented_json"
ElevenlabsExportOptionsDocx ElevenlabsExportOptions = "docx"
ElevenlabsExportOptionsPdf ElevenlabsExportOptions = "pdf"
ElevenlabsExportOptionsTxt ElevenlabsExportOptions = "txt"
ElevenlabsExportOptionsHtml ElevenlabsExportOptions = "html"
ElevenlabsExportOptionsSrt ElevenlabsExportOptions = "srt"
)
// ElevenlabsSpeechToTextChunkResponse is a single transcript: either the whole
// body of a single-channel response or one element of a multichannel
// response's "transcripts" array.
type ElevenlabsSpeechToTextChunkResponse struct {
LanguageCode string `json:"language_code"`
LanguageProbability *float64 `json:"language_probability,omitempty"`
Text string `json:"text"`
Words []ElevenlabsSpeechToTextWord `json:"words"`
ChannelIndex *int `json:"channel_index,omitempty"`
AdditionalFormats []*ElevenlabsAdditionalFormatResponse `json:"additional_formats,omitempty"`
TranscriptionID *string `json:"transcription_id,omitempty"`
}
// ElevenlabsSpeechToTextWord is one entry of a transcript's "words" array.
// Type distinguishes entry kinds; convertWords skips entries of type "spacing"
// whose text is blank. Start and End are optional timestamps.
type ElevenlabsSpeechToTextWord struct {
Text string `json:"text"`
Start *float64 `json:"start,omitempty"`
End *float64 `json:"end,omitempty"`
Type string `json:"type"`
SpeakerID *string `json:"speaker_id,omitempty"`
LogProb float64 `json:"logprob"`
Characters []ElevenlabsSpeechToTextCharacter `json:"characters,omitempty"`
}
// ElevenlabsSpeechToTextCharacter is one entry of a word's "characters" array,
// with optional start and end timestamps.
type ElevenlabsSpeechToTextCharacter struct {
Text string `json:"text"`
Start *float64 `json:"start,omitempty"`
End *float64 `json:"end,omitempty"`
}
// ElevenlabsAdditionalFormatResponse is one exported transcript format returned
// in a transcript's "additional_formats" array. IsBase64Encoded indicates
// whether Content is base64-encoded.
type ElevenlabsAdditionalFormatResponse struct {
RequestedFormat string `json:"requested_format"`
FileExtension string `json:"file_extension"`
ContentType string `json:"content_type"`
IsBase64Encoded bool `json:"is_base64_encoded"`
Content string `json:"content"`
}
// ElevenlabsMultichannelSpeechToTextResponse is the multichannel
// speech-to-text response shape: a list of transcripts under "transcripts".
// parseTranscriptionResponse tries this shape first.
type ElevenlabsMultichannelSpeechToTextResponse struct {
Transcripts []ElevenlabsSpeechToTextChunkResponse `json:"transcripts"`
TranscriptionID *string `json:"transcription_id,omitempty"`
}
// ElevenlabsSpeechToTextWebhookResponse is the speech-to-text response shape
// that carries a message instead of a transcript. parseTranscriptionResponse
// returns a non-blank Message as an error.
type ElevenlabsSpeechToTextWebhookResponse struct {
Message string `json:"message"`
RequestID string `json:"request_id"`
TranscriptionID *string `json:"transcription_id,omitempty"`
}
// ERROR TYPES

// ElevenlabsError is the top-level error body returned by the ElevenLabs API;
// everything of interest lives under the "detail" key.
type ElevenlabsError struct {
Detail *ElevenlabsErrorDetail `json:"detail,omitempty"`
}
// ElevenlabsErrorDetail handles both single object (non-validation errors) and
// array of objects (validation errors) formats from ElevenLabs API.
// It is decoded by its custom UnmarshalJSON method; ValidationErrors has the
// tag "-" and is populated only by that method.
type ElevenlabsErrorDetail struct {
// Non-validation error fields (when detail is a single object)
Status *string `json:"status,omitempty"`
Message *string `json:"message,omitempty"`
// Validation error fields (when detail is an array)
ValidationErrors []ElevenlabsValidationError `json:"-"`
}
// ElevenlabsValidationError represents a single validation error entry.
// Loc is the path to the offending field; parseElevenlabsError joins its
// segments with ".". The message may arrive under either "msg" or "message".
type ElevenlabsValidationError struct {
Loc []string `json:"loc"`
Msg string `json:"msg"`
Message string `json:"message"` // Some APIs use "message" instead of "msg"
Type string `json:"type"`
}
// UnmarshalJSON implements custom JSON unmarshaling for the "detail" value of
// an ElevenLabs error body. Three shapes are accepted:
//
//   - Array: decoded as a list of validation errors. Message is set from the
//     first entry ("message", falling back to "msg").
//   - String: a plain-text detail. It becomes Message (left nil when empty).
//     Previously a string detail failed to decode and its text was lost.
//   - Object: Status and Message ("message", falling back to "msg") are
//     copied. If the object also has "loc" or "type", it is additionally
//     recorded as a single validation error.
//
// Any other JSON value is reported as a decoding error.
func (d *ElevenlabsErrorDetail) UnmarshalJSON(data []byte) error {
	// Dispatch on the first non-whitespace character of the value.
	trimmed := strings.TrimSpace(string(data))
	if len(trimmed) > 0 {
		switch trimmed[0] {
		case '[':
			var validationErrors []ElevenlabsValidationError
			if err := sonic.Unmarshal(data, &validationErrors); err != nil {
				return err
			}
			d.ValidationErrors = validationErrors
			// Extract message from first validation error if available.
			if len(validationErrors) > 0 {
				if validationErrors[0].Message != "" {
					d.Message = &validationErrors[0].Message
				} else if validationErrors[0].Msg != "" {
					d.Message = &validationErrors[0].Msg
				}
			}
			return nil
		case '"':
			var text string
			if err := sonic.Unmarshal(data, &text); err != nil {
				return err
			}
			if text != "" {
				d.Message = &text
			}
			return nil
		}
	}

	// Otherwise decode as a single object (non-validation error).
	var obj struct {
		Type    *string  `json:"type,omitempty"`
		Loc     []string `json:"loc,omitempty"`
		Message *string  `json:"message,omitempty"`
		Status  *string  `json:"status,omitempty"`
		Msg     *string  `json:"msg,omitempty"` // Some APIs use "msg" instead of "message"
	}
	if err := sonic.Unmarshal(data, &obj); err != nil {
		return err
	}

	// Populate non-validation error fields.
	d.Status = obj.Status
	if obj.Message != nil {
		d.Message = obj.Message
	} else if obj.Msg != nil {
		d.Message = obj.Msg
	}

	// An object with validation-like fields (loc, type) is treated as a
	// single validation error.
	if len(obj.Loc) > 0 || obj.Type != nil {
		validationErr := ElevenlabsValidationError{Loc: obj.Loc}
		if obj.Type != nil {
			validationErr.Type = *obj.Type
		}
		if obj.Message != nil {
			validationErr.Message = *obj.Message
		} else if obj.Msg != nil {
			validationErr.Msg = *obj.Msg
			validationErr.Message = *obj.Msg
		}
		d.ValidationErrors = []ElevenlabsValidationError{validationErr}
	}
	return nil
}
// MODEL TYPES
// ElevenlabsModel describes one model entry as decoded from JSON: its
// identifier and display metadata, capability flags (the CanXxx fields),
// supported languages, request-size limits, rates and concurrency group.
// Each field maps to the JSON key named in its tag.
type ElevenlabsModel struct {
ModelID string `json:"model_id"`
Name string `json:"name"`
Description string `json:"description"`
ServesProVoices bool `json:"serves_pro_voices"`
TokenCostFactor float64 `json:"token_cost_factor"`
CanBeFinetuned bool `json:"can_be_finetuned"`
CanDoTextToSpeech bool `json:"can_do_text_to_speech"`
CanDoVoiceConversion bool `json:"can_do_voice_conversion"`
CanUseStyle bool `json:"can_use_style"`
CanUseSpeakerBoost bool `json:"can_use_speaker_boost"`
Languages []ElevenlabsLanguage `json:"languages"`
RequiresAlphaAccess bool `json:"requires_alpha_access"`
MaxCharactersRequestFreeUser int `json:"max_characters_request_free_user"`
MaxCharactersRequestSubscribedUser int `json:"max_characters_request_subscribed_user"`
MaxTextLengthPerRequest int `json:"maximum_text_length_per_request"`
ModelRates ElevenlabsModelRate `json:"model_rates"`
ConcurrencyGroup string `json:"concurrency_group"`
}
// ElevenlabsLanguage is a language entry of ElevenlabsModel.Languages: an
// identifier and a display name.
type ElevenlabsLanguage struct {
LanguageID string `json:"language_id"`
Name string `json:"name"`
}
// ElevenlabsModelRate holds the rate information attached to a model, which
// currently consists of a single character cost multiplier.
type ElevenlabsModelRate struct {
CharacterCostMultiplier float64 `json:"character_cost_multiplier"`
}
// ElevenlabsListModelsResponse is a list-models payload: a bare JSON array of
// ElevenlabsModel values.
type ElevenlabsListModelsResponse []ElevenlabsModel

View File

@@ -0,0 +1,35 @@
package elevenlabs
// bifrostToElevenlabsSpeechFormat maps a Bifrost speech (audio) format name to
// the Elevenlabs output-format identifier. The empty string selects the same
// identifier as "mp3"; "wav" and "pcm" share one identifier.
var bifrostToElevenlabsSpeechFormat = map[string]string{
	"":     "mp3_44100_128",
	"mp3":  "mp3_44100_128",
	"opus": "opus_48000_128",
	"wav":  "pcm_44100",
	"pcm":  "pcm_44100",
}

// elevenlabsSpeechFormatToBifrost maps an Elevenlabs output-format identifier
// back to a Bifrost speech format name. "pcm_44100" is reported as "wav".
var elevenlabsSpeechFormatToBifrost = map[string]string{
	"mp3_44100_128":  "mp3",
	"opus_48000_128": "opus",
	"pcm_44100":      "wav",
}
// ConvertBifrostSpeechFormatToElevenlabs translates a Bifrost speech format
// name into its Elevenlabs identifier using bifrostToElevenlabsSpeechFormat.
// Formats without an entry in that table are returned unchanged.
func ConvertBifrostSpeechFormatToElevenlabs(format string) string {
	mapped, known := bifrostToElevenlabsSpeechFormat[format]
	if !known {
		return format
	}
	return mapped
}
// ConvertElevenlabsSpeechFormatToBifrost translates an Elevenlabs format
// identifier into the Bifrost speech format name using
// elevenlabsSpeechFormatToBifrost. Identifiers without an entry in that table
// are returned unchanged.
func ConvertElevenlabsSpeechFormatToBifrost(format string) string {
	mapped, known := elevenlabsSpeechFormatToBifrost[format]
	if !known {
		return format
	}
	return mapped
}

View File

@@ -0,0 +1,435 @@
// Package fireworks implements the Fireworks AI provider and its utility functions.
package fireworks
import (
"context"
"strings"
"time"
"github.com/maximhq/bifrost/core/providers/openai"
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
schemas "github.com/maximhq/bifrost/core/schemas"
"github.com/valyala/fasthttp"
)
// FireworksProvider implements the Provider interface for Fireworks AI's API.
// It keeps two HTTP clients (one for unary calls, one for streaming calls), the
// network configuration captured at construction time, and the provider-level
// defaults for echoing raw requests and responses.
type FireworksProvider struct {
logger schemas.Logger // Logger for provider operations
client *fasthttp.Client // HTTP client for unary API requests (ReadTimeout bounds overall response)
streamingClient *fasthttp.Client // HTTP client for streaming API requests (no ReadTimeout; idle governed by NewIdleTimeoutReader)
networkConfig schemas.NetworkConfig // Network configuration including extra headers
sendBackRawRequest bool // Whether to include raw request in BifrostResponse
sendBackRawResponse bool // Whether to include raw response in BifrostResponse
}
// NewFireworksProvider builds a FireworksProvider from config.
//
// It applies config defaults, creates the unary fasthttp client with timeouts
// and connection limits taken from the network configuration, passes it
// through the proxy, dialer and TLS configuration helpers, and derives the
// streaming client from the result.
//
// config is modified in place: CheckAndSetDefaults is called on it, and
// NetworkConfig.BaseURL is set to "https://api.fireworks.ai/inference" when
// empty and stripped of trailing slashes. The returned error is always nil.
func NewFireworksProvider(config *schemas.ProviderConfig, logger schemas.Logger) (*FireworksProvider, error) {
	config.CheckAndSetDefaults()

	timeout := time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds) * time.Second
	maxConnLifetime := time.Duration(schemas.DefaultMaxConnDurationInSeconds) * time.Second

	unaryClient := &fasthttp.Client{
		ReadTimeout:         timeout,
		WriteTimeout:        timeout,
		MaxConnsPerHost:     config.NetworkConfig.MaxConnsPerHost,
		MaxIdleConnDuration: 30 * time.Second,
		MaxConnWaitTimeout:  timeout,
		MaxConnDuration:     maxConnLifetime,
		ConnPoolStrategy:    fasthttp.FIFO,
	}

	// Layer proxy, dialer and TLS settings onto the unary client, then derive
	// the streaming client from the fully configured one.
	unaryClient = providerUtils.ConfigureProxy(unaryClient, config.ProxyConfig, logger)
	unaryClient = providerUtils.ConfigureDialer(unaryClient)
	unaryClient = providerUtils.ConfigureTLS(unaryClient, config.NetworkConfig, logger)
	streamClient := providerUtils.BuildStreamingClient(unaryClient)

	// Fall back to the public Fireworks endpoint and normalise trailing slashes
	// so that request paths can be appended directly.
	baseURL := config.NetworkConfig.BaseURL
	if baseURL == "" {
		baseURL = "https://api.fireworks.ai/inference"
	}
	config.NetworkConfig.BaseURL = strings.TrimRight(baseURL, "/")

	fireworks := &FireworksProvider{
		logger:              logger,
		client:              unaryClient,
		streamingClient:     streamClient,
		networkConfig:       config.NetworkConfig,
		sendBackRawRequest:  config.SendBackRawRequest,
		sendBackRawResponse: config.SendBackRawResponse,
	}
	return fireworks, nil
}
// GetProviderKey returns the provider identifier for Fireworks AI, which is
// always schemas.Fireworks.
func (provider *FireworksProvider) GetProviderKey() schemas.ModelProvider {
return schemas.Fireworks
}
// ListModels lists the models available from Fireworks AI by delegating to the
// shared OpenAI list-models handler with the unary client.
// The request URL is the configured base URL followed by the path that
// providerUtils.GetPathFromContext yields for "/v1/models".
func (provider *FireworksProvider) ListModels(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) {
	endpoint := provider.networkConfig.BaseURL + providerUtils.GetPathFromContext(ctx, "/v1/models")
	includeRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest)
	includeRawResponse := providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse)

	return openai.HandleOpenAIListModelsRequest(
		ctx,
		provider.client,
		request,
		endpoint,
		keys,
		provider.networkConfig.ExtraHeaders,
		schemas.Fireworks,
		includeRawRequest,
		includeRawResponse,
	)
}
// TextCompletion sends a unary text completion request to Fireworks AI by
// delegating to the shared OpenAI text-completion handler.
// The request URL is the configured base URL followed by the path that
// providerUtils.GetPathFromContext yields for "/v1/completions".
func (provider *FireworksProvider) TextCompletion(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (*schemas.BifrostTextCompletionResponse, *schemas.BifrostError) {
	endpoint := provider.networkConfig.BaseURL + providerUtils.GetPathFromContext(ctx, "/v1/completions")
	providerKey := provider.GetProviderKey()
	includeRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest)
	includeRawResponse := providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse)

	return openai.HandleOpenAITextCompletionRequest(
		ctx,
		provider.client,
		endpoint,
		request,
		key,
		provider.networkConfig.ExtraHeaders,
		providerKey,
		includeRawRequest,
		includeRawResponse,
		nil,
		nil,
		provider.logger,
	)
}
// TextCompletionStream sends a streaming text completion request to Fireworks
// AI by delegating to the shared OpenAI text-completion streaming handler with
// the streaming client. A Bearer Authorization header is attached only when the
// key carries a non-empty value. The result is a channel of stream chunks, or
// an error.
func (provider *FireworksProvider) TextCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostTextCompletionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
	var authHeader map[string]string
	if apiKey := key.Value.GetValue(); apiKey != "" {
		authHeader = map[string]string{"Authorization": "Bearer " + apiKey}
	}

	endpoint := provider.networkConfig.BaseURL + providerUtils.GetPathFromContext(ctx, "/v1/completions")
	includeRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest)
	includeRawResponse := providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse)
	providerKey := provider.GetProviderKey()

	return openai.HandleOpenAITextCompletionStreaming(
		ctx,
		provider.streamingClient,
		endpoint,
		request,
		authHeader,
		provider.networkConfig.ExtraHeaders,
		includeRawRequest,
		includeRawResponse,
		providerKey,
		nil,
		postHookRunner,
		nil,
		nil,
		provider.logger,
		postHookSpanFinalizer,
	)
}
// ChatCompletion sends a unary chat completion request to Fireworks AI by
// delegating to the shared OpenAI chat-completion handler.
// The request URL is the configured base URL followed by the path that
// providerUtils.GetPathFromContext yields for "/v1/chat/completions".
func (provider *FireworksProvider) ChatCompletion(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError) {
	endpoint := provider.networkConfig.BaseURL + providerUtils.GetPathFromContext(ctx, "/v1/chat/completions")
	includeRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest)
	includeRawResponse := providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse)
	providerKey := provider.GetProviderKey()

	return openai.HandleOpenAIChatCompletionRequest(
		ctx,
		provider.client,
		endpoint,
		request,
		key,
		provider.networkConfig.ExtraHeaders,
		includeRawRequest,
		includeRawResponse,
		providerKey,
		nil,
		nil,
		provider.logger,
	)
}
// ChatCompletionStream sends a streaming chat completion request to Fireworks
// AI by delegating to the shared OpenAI-compatible chat streaming handler with
// the streaming client. A Bearer Authorization header is attached only when the
// key carries a non-empty value.
// It returns a channel of BifrostStreamChunk values representing the stream, or
// an error if the request fails.
func (provider *FireworksProvider) ChatCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
	var authHeader map[string]string
	if apiKey := key.Value.GetValue(); apiKey != "" {
		authHeader = map[string]string{"Authorization": "Bearer " + apiKey}
	}

	endpoint := provider.networkConfig.BaseURL + providerUtils.GetPathFromContext(ctx, "/v1/chat/completions")
	includeRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest)
	includeRawResponse := providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse)

	return openai.HandleOpenAIChatCompletionStreaming(
		ctx,
		provider.streamingClient,
		endpoint,
		request,
		authHeader,
		provider.networkConfig.ExtraHeaders,
		includeRawRequest,
		includeRawResponse,
		schemas.Fireworks,
		postHookRunner,
		nil,
		nil,
		nil,
		nil,
		nil,
		provider.logger,
		postHookSpanFinalizer,
	)
}
// Responses sends a unary responses request to Fireworks AI by delegating to
// the shared OpenAI responses handler.
// The request URL is the configured base URL followed by the path that
// providerUtils.GetPathFromContext yields for "/v1/responses".
func (provider *FireworksProvider) Responses(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) {
	endpoint := provider.networkConfig.BaseURL + providerUtils.GetPathFromContext(ctx, "/v1/responses")
	includeRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest)
	includeRawResponse := providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse)
	providerKey := provider.GetProviderKey()

	return openai.HandleOpenAIResponsesRequest(
		ctx,
		provider.client,
		endpoint,
		request,
		key,
		provider.networkConfig.ExtraHeaders,
		includeRawRequest,
		includeRawResponse,
		providerKey,
		nil,
		nil,
		provider.logger,
	)
}
// ResponsesStream sends a streaming responses request to Fireworks AI by
// delegating to the shared OpenAI responses streaming handler with the
// streaming client. A Bearer Authorization header is attached only when the key
// carries a non-empty value. The result is a channel of stream chunks, or an
// error.
func (provider *FireworksProvider) ResponsesStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
	var authHeader map[string]string
	if apiKey := key.Value.GetValue(); apiKey != "" {
		authHeader = map[string]string{"Authorization": "Bearer " + apiKey}
	}

	endpoint := provider.networkConfig.BaseURL + providerUtils.GetPathFromContext(ctx, "/v1/responses")
	includeRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest)
	includeRawResponse := providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse)
	providerKey := provider.GetProviderKey()

	return openai.HandleOpenAIResponsesStreaming(
		ctx,
		provider.streamingClient,
		endpoint,
		request,
		authHeader,
		provider.networkConfig.ExtraHeaders,
		includeRawRequest,
		includeRawResponse,
		providerKey,
		postHookRunner,
		nil,
		nil,
		nil,
		nil,
		provider.logger,
		postHookSpanFinalizer,
	)
}
// Embedding sends an embedding request to Fireworks AI by delegating to the
// shared OpenAI embedding handler.
// The request URL is the configured base URL followed by the path that
// providerUtils.GetPathFromContext yields for "/v1/embeddings".
func (provider *FireworksProvider) Embedding(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostEmbeddingRequest) (*schemas.BifrostEmbeddingResponse, *schemas.BifrostError) {
	endpoint := provider.networkConfig.BaseURL + providerUtils.GetPathFromContext(ctx, "/v1/embeddings")
	providerKey := provider.GetProviderKey()
	includeRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest)
	includeRawResponse := providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse)

	return openai.HandleOpenAIEmbeddingRequest(
		ctx,
		provider.client,
		endpoint,
		request,
		key,
		provider.networkConfig.ExtraHeaders,
		providerKey,
		includeRawRequest,
		includeRawResponse,
		nil,
		provider.logger,
	)
}
// Speech is not supported by the Fireworks AI provider: every call returns a
// nil response and the unsupported-operation error for schemas.SpeechRequest.
func (provider *FireworksProvider) Speech(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostSpeechRequest) (*schemas.BifrostSpeechResponse, *schemas.BifrostError) {
	providerKey := provider.GetProviderKey()
	return nil, providerUtils.NewUnsupportedOperationError(schemas.SpeechRequest, providerKey)
}
// Rerank is not supported by the Fireworks AI provider: every call returns a
// nil response and the unsupported-operation error for schemas.RerankRequest.
func (provider *FireworksProvider) Rerank(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostRerankRequest) (*schemas.BifrostRerankResponse, *schemas.BifrostError) {
	providerKey := provider.GetProviderKey()
	return nil, providerUtils.NewUnsupportedOperationError(schemas.RerankRequest, providerKey)
}
// OCR is not supported by the Fireworks AI provider: every call returns a nil
// response and the unsupported-operation error for schemas.OCRRequest.
func (provider *FireworksProvider) OCR(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostOCRRequest) (*schemas.BifrostOCRResponse, *schemas.BifrostError) {
	providerKey := provider.GetProviderKey()
	return nil, providerUtils.NewUnsupportedOperationError(schemas.OCRRequest, providerKey)
}
// SpeechStream is not supported by the Fireworks AI provider: every call
// returns a nil channel and the unsupported-operation error for
// schemas.SpeechStreamRequest.
func (provider *FireworksProvider) SpeechStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostSpeechRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
	providerKey := provider.GetProviderKey()
	return nil, providerUtils.NewUnsupportedOperationError(schemas.SpeechStreamRequest, providerKey)
}
// Transcription is not supported by the Fireworks AI provider: every call
// returns a nil response and the unsupported-operation error for
// schemas.TranscriptionRequest.
func (provider *FireworksProvider) Transcription(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (*schemas.BifrostTranscriptionResponse, *schemas.BifrostError) {
	providerKey := provider.GetProviderKey()
	return nil, providerUtils.NewUnsupportedOperationError(schemas.TranscriptionRequest, providerKey)
}
// TranscriptionStream is not supported by the Fireworks AI provider: every call
// returns a nil channel and the unsupported-operation error for
// schemas.TranscriptionStreamRequest.
func (provider *FireworksProvider) TranscriptionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
	providerKey := provider.GetProviderKey()
	return nil, providerUtils.NewUnsupportedOperationError(schemas.TranscriptionStreamRequest, providerKey)
}
// ImageGeneration is not supported by the Fireworks AI provider: every call
// returns a nil response and the unsupported-operation error for
// schemas.ImageGenerationRequest.
func (provider *FireworksProvider) ImageGeneration(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostImageGenerationRequest) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) {
	providerKey := provider.GetProviderKey()
	return nil, providerUtils.NewUnsupportedOperationError(schemas.ImageGenerationRequest, providerKey)
}
// ImageGenerationStream is not supported by the Fireworks AI provider: every
// call returns a nil channel and the unsupported-operation error for
// schemas.ImageGenerationStreamRequest.
func (provider *FireworksProvider) ImageGenerationStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostImageGenerationRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
	providerKey := provider.GetProviderKey()
	return nil, providerUtils.NewUnsupportedOperationError(schemas.ImageGenerationStreamRequest, providerKey)
}
// ImageEdit is not supported by the Fireworks AI provider: every call returns a
// nil response and the unsupported-operation error for schemas.ImageEditRequest.
func (provider *FireworksProvider) ImageEdit(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostImageEditRequest) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) {
	providerKey := provider.GetProviderKey()
	return nil, providerUtils.NewUnsupportedOperationError(schemas.ImageEditRequest, providerKey)
}
// ImageEditStream is not supported by the Fireworks AI provider: every call
// returns a nil channel and the unsupported-operation error for
// schemas.ImageEditStreamRequest.
func (provider *FireworksProvider) ImageEditStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostImageEditRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
	providerKey := provider.GetProviderKey()
	return nil, providerUtils.NewUnsupportedOperationError(schemas.ImageEditStreamRequest, providerKey)
}
// ImageVariation is not supported by the Fireworks AI provider: every call
// returns a nil response and the unsupported-operation error for
// schemas.ImageVariationRequest.
func (provider *FireworksProvider) ImageVariation(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostImageVariationRequest) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) {
	providerKey := provider.GetProviderKey()
	return nil, providerUtils.NewUnsupportedOperationError(schemas.ImageVariationRequest, providerKey)
}
// VideoGeneration is not supported by the Fireworks AI provider: every call
// returns a nil response and the unsupported-operation error for
// schemas.VideoGenerationRequest.
func (provider *FireworksProvider) VideoGeneration(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostVideoGenerationRequest) (*schemas.BifrostVideoGenerationResponse, *schemas.BifrostError) {
	providerKey := provider.GetProviderKey()
	return nil, providerUtils.NewUnsupportedOperationError(schemas.VideoGenerationRequest, providerKey)
}
// VideoRetrieve is not supported by the Fireworks AI provider: every call
// returns a nil response and the unsupported-operation error for
// schemas.VideoRetrieveRequest.
func (provider *FireworksProvider) VideoRetrieve(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostVideoRetrieveRequest) (*schemas.BifrostVideoGenerationResponse, *schemas.BifrostError) {
	providerKey := provider.GetProviderKey()
	return nil, providerUtils.NewUnsupportedOperationError(schemas.VideoRetrieveRequest, providerKey)
}
// VideoDownload is not supported by the Fireworks AI provider: every call
// returns a nil response and the unsupported-operation error for
// schemas.VideoDownloadRequest.
func (provider *FireworksProvider) VideoDownload(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostVideoDownloadRequest) (*schemas.BifrostVideoDownloadResponse, *schemas.BifrostError) {
	providerKey := provider.GetProviderKey()
	return nil, providerUtils.NewUnsupportedOperationError(schemas.VideoDownloadRequest, providerKey)
}
// VideoDelete is not supported by the Fireworks AI provider: every call returns
// a nil response and the unsupported-operation error for
// schemas.VideoDeleteRequest.
func (provider *FireworksProvider) VideoDelete(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostVideoDeleteRequest) (*schemas.BifrostVideoDeleteResponse, *schemas.BifrostError) {
	providerKey := provider.GetProviderKey()
	return nil, providerUtils.NewUnsupportedOperationError(schemas.VideoDeleteRequest, providerKey)
}
// VideoList is not supported by the Fireworks AI provider: every call returns a
// nil response and the unsupported-operation error for schemas.VideoListRequest.
func (provider *FireworksProvider) VideoList(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostVideoListRequest) (*schemas.BifrostVideoListResponse, *schemas.BifrostError) {
	providerKey := provider.GetProviderKey()
	return nil, providerUtils.NewUnsupportedOperationError(schemas.VideoListRequest, providerKey)
}
// VideoRemix is not supported by the Fireworks AI provider: every call returns
// a nil response and the unsupported-operation error for
// schemas.VideoRemixRequest.
func (provider *FireworksProvider) VideoRemix(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostVideoRemixRequest) (*schemas.BifrostVideoGenerationResponse, *schemas.BifrostError) {
	providerKey := provider.GetProviderKey()
	return nil, providerUtils.NewUnsupportedOperationError(schemas.VideoRemixRequest, providerKey)
}
// BatchCreate is not supported by the Fireworks AI provider: every call returns
// a nil response and the unsupported-operation error for
// schemas.BatchCreateRequest.
func (provider *FireworksProvider) BatchCreate(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostBatchCreateRequest) (*schemas.BifrostBatchCreateResponse, *schemas.BifrostError) {
	providerKey := provider.GetProviderKey()
	return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchCreateRequest, providerKey)
}
// BatchList is not supported by the Fireworks AI provider: every call returns a
// nil response and the unsupported-operation error for schemas.BatchListRequest.
func (provider *FireworksProvider) BatchList(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchListRequest) (*schemas.BifrostBatchListResponse, *schemas.BifrostError) {
	providerKey := provider.GetProviderKey()
	return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchListRequest, providerKey)
}
// BatchRetrieve is not supported by the Fireworks AI provider: every call
// returns a nil response and the unsupported-operation error for
// schemas.BatchRetrieveRequest.
func (provider *FireworksProvider) BatchRetrieve(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchRetrieveRequest) (*schemas.BifrostBatchRetrieveResponse, *schemas.BifrostError) {
	providerKey := provider.GetProviderKey()
	return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchRetrieveRequest, providerKey)
}
// BatchCancel is not supported by the Fireworks AI provider: every call returns
// a nil response and the unsupported-operation error for
// schemas.BatchCancelRequest.
func (provider *FireworksProvider) BatchCancel(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchCancelRequest) (*schemas.BifrostBatchCancelResponse, *schemas.BifrostError) {
	providerKey := provider.GetProviderKey()
	return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchCancelRequest, providerKey)
}
// BatchDelete is not supported by the Fireworks AI provider: every call returns
// a nil response and the unsupported-operation error for
// schemas.BatchDeleteRequest.
func (provider *FireworksProvider) BatchDelete(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchDeleteRequest) (*schemas.BifrostBatchDeleteResponse, *schemas.BifrostError) {
	providerKey := provider.GetProviderKey()
	return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchDeleteRequest, providerKey)
}
// BatchResults is not supported by the Fireworks AI provider: every call
// returns a nil response and the unsupported-operation error for
// schemas.BatchResultsRequest.
func (provider *FireworksProvider) BatchResults(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchResultsRequest) (*schemas.BifrostBatchResultsResponse, *schemas.BifrostError) {
	providerKey := provider.GetProviderKey()
	return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchResultsRequest, providerKey)
}
// FileUpload is not supported by the Fireworks AI provider: every call returns
// a nil response and the unsupported-operation error for
// schemas.FileUploadRequest.
func (provider *FireworksProvider) FileUpload(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostFileUploadRequest) (*schemas.BifrostFileUploadResponse, *schemas.BifrostError) {
	providerKey := provider.GetProviderKey()
	return nil, providerUtils.NewUnsupportedOperationError(schemas.FileUploadRequest, providerKey)
}
// FileList is not supported by the Fireworks AI provider: every call returns a
// nil response and the unsupported-operation error for schemas.FileListRequest.
func (provider *FireworksProvider) FileList(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostFileListRequest) (*schemas.BifrostFileListResponse, *schemas.BifrostError) {
	providerKey := provider.GetProviderKey()
	return nil, providerUtils.NewUnsupportedOperationError(schemas.FileListRequest, providerKey)
}
// FileRetrieve is not supported by the Fireworks AI provider: every call
// returns a nil response and the unsupported-operation error for
// schemas.FileRetrieveRequest.
func (provider *FireworksProvider) FileRetrieve(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostFileRetrieveRequest) (*schemas.BifrostFileRetrieveResponse, *schemas.BifrostError) {
	providerKey := provider.GetProviderKey()
	return nil, providerUtils.NewUnsupportedOperationError(schemas.FileRetrieveRequest, providerKey)
}
// FileDelete is not supported by the Fireworks AI provider: every call returns
// a nil response and the unsupported-operation error for
// schemas.FileDeleteRequest.
func (provider *FireworksProvider) FileDelete(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostFileDeleteRequest) (*schemas.BifrostFileDeleteResponse, *schemas.BifrostError) {
	providerKey := provider.GetProviderKey()
	return nil, providerUtils.NewUnsupportedOperationError(schemas.FileDeleteRequest, providerKey)
}
// FileContent is not supported by the Fireworks AI provider: every call returns
// a nil response and the unsupported-operation error for
// schemas.FileContentRequest.
func (provider *FireworksProvider) FileContent(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostFileContentRequest) (*schemas.BifrostFileContentResponse, *schemas.BifrostError) {
	providerKey := provider.GetProviderKey()
	return nil, providerUtils.NewUnsupportedOperationError(schemas.FileContentRequest, providerKey)
}
// CountTokens is not supported by the Fireworks AI provider: every call returns
// a nil response and the unsupported-operation error for
// schemas.CountTokensRequest.
func (provider *FireworksProvider) CountTokens(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostResponsesRequest) (*schemas.BifrostCountTokensResponse, *schemas.BifrostError) {
	providerKey := provider.GetProviderKey()
	return nil, providerUtils.NewUnsupportedOperationError(schemas.CountTokensRequest, providerKey)
}
// ContainerCreate is not supported by the Fireworks AI provider: every call
// returns a nil response and the unsupported-operation error for
// schemas.ContainerCreateRequest.
func (provider *FireworksProvider) ContainerCreate(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostContainerCreateRequest) (*schemas.BifrostContainerCreateResponse, *schemas.BifrostError) {
	providerKey := provider.GetProviderKey()
	return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerCreateRequest, providerKey)
}
// ContainerList is not supported by the Fireworks AI provider: every call
// returns a nil response and the unsupported-operation error for
// schemas.ContainerListRequest.
func (provider *FireworksProvider) ContainerList(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostContainerListRequest) (*schemas.BifrostContainerListResponse, *schemas.BifrostError) {
	providerKey := provider.GetProviderKey()
	return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerListRequest, providerKey)
}
// ContainerRetrieve is not supported by the Fireworks AI provider: every call
// returns a nil response and the unsupported-operation error for
// schemas.ContainerRetrieveRequest.
func (provider *FireworksProvider) ContainerRetrieve(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostContainerRetrieveRequest) (*schemas.BifrostContainerRetrieveResponse, *schemas.BifrostError) {
	providerKey := provider.GetProviderKey()
	return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerRetrieveRequest, providerKey)
}
// ContainerDelete is not supported by the Fireworks AI provider: every call
// returns a nil response and the unsupported-operation error for
// schemas.ContainerDeleteRequest.
func (provider *FireworksProvider) ContainerDelete(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostContainerDeleteRequest) (*schemas.BifrostContainerDeleteResponse, *schemas.BifrostError) {
	providerKey := provider.GetProviderKey()
	return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerDeleteRequest, providerKey)
}
// ContainerFileCreate is not supported by the Fireworks AI provider: every call
// returns a nil response and the unsupported-operation error for
// schemas.ContainerFileCreateRequest.
func (provider *FireworksProvider) ContainerFileCreate(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostContainerFileCreateRequest) (*schemas.BifrostContainerFileCreateResponse, *schemas.BifrostError) {
	providerKey := provider.GetProviderKey()
	return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerFileCreateRequest, providerKey)
}
// ContainerFileList is not supported by the Fireworks AI provider: every call
// returns a nil response and the unsupported-operation error for
// schemas.ContainerFileListRequest.
func (provider *FireworksProvider) ContainerFileList(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostContainerFileListRequest) (*schemas.BifrostContainerFileListResponse, *schemas.BifrostError) {
	providerKey := provider.GetProviderKey()
	return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerFileListRequest, providerKey)
}
// ContainerFileRetrieve is not supported by the Fireworks AI provider: every
// call returns a nil response and the unsupported-operation error for
// schemas.ContainerFileRetrieveRequest.
func (provider *FireworksProvider) ContainerFileRetrieve(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostContainerFileRetrieveRequest) (*schemas.BifrostContainerFileRetrieveResponse, *schemas.BifrostError) {
	providerKey := provider.GetProviderKey()
	return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerFileRetrieveRequest, providerKey)
}
// ContainerFileContent is not supported by the Fireworks AI provider: every
// call returns a nil response and the unsupported-operation error for
// schemas.ContainerFileContentRequest.
func (provider *FireworksProvider) ContainerFileContent(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostContainerFileContentRequest) (*schemas.BifrostContainerFileContentResponse, *schemas.BifrostError) {
	providerKey := provider.GetProviderKey()
	return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerFileContentRequest, providerKey)
}
// ContainerFileDelete is not supported by the Fireworks AI provider: every call
// returns a nil response and the unsupported-operation error for
// schemas.ContainerFileDeleteRequest.
func (provider *FireworksProvider) ContainerFileDelete(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostContainerFileDeleteRequest) (*schemas.BifrostContainerFileDeleteResponse, *schemas.BifrostError) {
	providerKey := provider.GetProviderKey()
	return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerFileDeleteRequest, providerKey)
}
// Passthrough is not supported by the Fireworks AI provider: every call returns
// a nil response and the unsupported-operation error for
// schemas.PassthroughRequest.
func (provider *FireworksProvider) Passthrough(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostPassthroughRequest) (*schemas.BifrostPassthroughResponse, *schemas.BifrostError) {
	providerKey := provider.GetProviderKey()
	return nil, providerUtils.NewUnsupportedOperationError(schemas.PassthroughRequest, providerKey)
}
// PassthroughStream is not supported by the Fireworks AI provider: every call
// returns a nil channel and the unsupported-operation error for
// schemas.PassthroughStreamRequest.
func (provider *FireworksProvider) PassthroughStream(_ *schemas.BifrostContext, _ schemas.PostHookRunner, _ func(context.Context), _ schemas.Key, _ *schemas.BifrostPassthroughRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
	providerKey := provider.GetProviderKey()
	return nil, providerUtils.NewUnsupportedOperationError(schemas.PassthroughStreamRequest, providerKey)
}

View File

@@ -0,0 +1,443 @@
package fireworks_test
import (
"context"
"fmt"
"net/http"
"net/http/httptest"
"os"
"strings"
"testing"
bifrost "github.com/maximhq/bifrost/core"
"github.com/maximhq/bifrost/core/internal/llmtests"
fireworksprovider "github.com/maximhq/bifrost/core/providers/fireworks"
"github.com/maximhq/bifrost/core/schemas"
)
// TestFireworks runs the comprehensive llmtests suite against the Fireworks
// provider. It is skipped when FIREWORKS_API_KEY is unset or blank. The chat,
// text and embedding models come from resolveFireworksModels; text completion
// and embedding scenarios are enabled only when a model was resolved for them.
func TestFireworks(t *testing.T) {
	t.Parallel()

	apiKey := strings.TrimSpace(os.Getenv("FIREWORKS_API_KEY"))
	if apiKey == "" {
		t.Skip("Skipping Fireworks tests because FIREWORKS_API_KEY is not set")
	}

	client, ctx, cancel, err := llmtests.SetupTest()
	if err != nil {
		t.Fatalf("Error initializing test setup: %v", err)
	}
	defer cancel()
	defer client.Shutdown()

	chatModel, textModel, embeddingModel := resolveFireworksModels(t, client, ctx)
	textEnabled := textModel != ""
	embeddingEnabled := embeddingModel != ""

	scenarios := llmtests.TestScenarios{
		TextCompletion:        textEnabled,
		TextCompletionStream:  textEnabled,
		SimpleChat:            true,
		CompletionStream:      true,
		MultiTurnConversation: true,
		ToolCalls:             true,
		ToolCallsStreaming:    true,
		MultipleToolCalls:     false,
		End2EndToolCalling:    false,
		AutomaticFunctionCall: false,
		ImageURL:              false,
		ImageBase64:           false,
		MultipleImages:        false,
		FileBase64:            false,
		FileURL:               false,
		CompleteEnd2End:       true,
		Embedding:             embeddingEnabled,
		ListModels:            true,
		Reasoning:             false,
		Transcription:         false,
		SpeechSynthesis:       false,
		PromptCaching:         false,
	}

	testConfig := llmtests.ComprehensiveTestConfig{
		Provider:                schemas.Fireworks,
		ChatModel:               chatModel,
		Fallbacks:               []schemas.Fallback{},
		TextModel:               textModel,
		TextCompletionFallbacks: []schemas.Fallback{},
		EmbeddingModel:          embeddingModel,
		ReasoningModel:          "",
		TranscriptionModel:      "",
		SpeechSynthesisModel:    "",
		Scenarios:               scenarios,
	}

	t.Run("FireworksTests", func(t *testing.T) {
		llmtests.RunAllComprehensiveTests(t, client, ctx, testConfig)
	})
}
// resolveFireworksModels picks the chat, text completion and embedding models
// used by the Fireworks tests and returns them in that order.
//
// Each model can be pinned through FIREWORKS_CHAT_MODEL, FIREWORKS_TEXT_MODEL
// and FIREWORKS_EMBEDDING_MODEL. Any model left unset is discovered from the
// list-models API, scanning at most five pages of 200 entries. The text and
// embedding candidates are then probed; a candidate that fails its probe is
// cleared so the matching scenarios get skipped.
//
// The test fails when no chat model can be determined. The text and embedding
// results may be empty strings.
func resolveFireworksModels(t *testing.T, client *bifrost.Bifrost, ctx context.Context) (string, string, string) {
	t.Helper()

	chatModel := normalizeFireworksModelID(os.Getenv("FIREWORKS_CHAT_MODEL"))
	textModel := normalizeFireworksModelID(os.Getenv("FIREWORKS_TEXT_MODEL"))
	embeddingModel := normalizeFireworksModelID(os.Getenv("FIREWORKS_EMBEDDING_MODEL"))

	if chatModel != "" {
		t.Logf("Using FIREWORKS_CHAT_MODEL=%q override", chatModel)
	}
	if textModel != "" {
		t.Logf("Using FIREWORKS_TEXT_MODEL=%q override", textModel)
	}
	if embeddingModel != "" {
		t.Logf("Using FIREWORKS_EMBEDDING_MODEL=%q override", embeddingModel)
	}

	allResolved := func() bool {
		return chatModel != "" && textModel != "" && embeddingModel != ""
	}

	// Discover whatever was not pinned through the environment.
	const maxPages = 5
	pageToken := ""
	for page := 0; page < maxPages && !allResolved(); page++ {
		listReq := &schemas.BifrostListModelsRequest{
			Provider:  schemas.Fireworks,
			PageSize:  200,
			PageToken: pageToken,
		}
		listCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
		resp, bifrostErr := client.ListModelsRequest(listCtx, listReq)
		if bifrostErr != nil {
			reason := llmtests.GetErrorMessage(bifrostErr)
			// A chat model is mandatory, so a listing failure is fatal only
			// when none has been found yet.
			if chatModel == "" {
				t.Fatalf("Failed to list Fireworks models for test discovery: %v", reason)
			}
			t.Logf("Fireworks model discovery failed: %v", reason)
			break
		}

		if chatModel == "" {
			chatModel = pickFireworksChatModel(resp.Data)
		}
		if textModel == "" {
			// Fireworks text completions currently reuse the chat-capable model pool;
			// a later probe verifies that the selected model accepts /v1/completions.
			textModel = pickFireworksChatModel(resp.Data)
		}
		if embeddingModel == "" {
			embeddingModel = pickFireworksEmbeddingModel(resp.Data)
		}

		if allResolved() || resp.NextPageToken == "" {
			break
		}
		pageToken = resp.NextPageToken
	}

	if chatModel == "" {
		t.Fatal("Unable to discover a Fireworks chat model from /v1/models; set FIREWORKS_CHAT_MODEL to override")
	}

	// Probe the optional models; drop any that the live API rejects.
	if textModel != "" && !fireworksModelSupportsTextCompletions(t, client, ctx, textModel) {
		t.Logf("Skipping Fireworks text completion scenarios because model %q did not accept /v1/completions", textModel)
		textModel = ""
	}
	if embeddingModel != "" && !fireworksModelSupportsEmbeddings(t, client, ctx, embeddingModel) {
		t.Logf("Skipping Fireworks embedding scenario because model %q did not accept /v1/embeddings", embeddingModel)
		embeddingModel = ""
	}

	if textModel == "" {
		t.Log("No Fireworks completions-capable model discovered from /v1/models; text completion scenarios will be skipped unless FIREWORKS_TEXT_MODEL is set")
	}
	if embeddingModel == "" {
		t.Log("No Fireworks embedding model discovered from /v1/models; embedding scenario will be skipped unless FIREWORKS_EMBEDDING_MODEL is set")
	}

	return chatModel, textModel, embeddingModel
}
// fireworksModelSupportsTextCompletions sends a tiny text-completion request for model and reports
// whether the provider answered with at least one choice. A request error is logged and treated as
// "not supported".
func fireworksModelSupportsTextCompletions(t *testing.T, client *bifrost.Bifrost, ctx context.Context, model string) bool {
	t.Helper()

	probePrompt := "Say ok"
	probeReq := &schemas.BifrostTextCompletionRequest{
		Provider: schemas.Fireworks,
		Model:    model,
		Input: &schemas.TextCompletionInput{
			PromptStr: &probePrompt,
		},
		Params: &schemas.TextCompletionParameters{
			// Keep the probe cheap: a handful of tokens is enough to prove the endpoint works.
			MaxTokens: schemas.Ptr(8),
		},
	}

	result, probeErr := client.TextCompletionRequest(schemas.NewBifrostContext(ctx, schemas.NoDeadline), probeReq)
	if probeErr != nil {
		t.Logf("Fireworks /v1/completions probe failed for %q: %v", model, llmtests.GetErrorMessage(probeErr))
		return false
	}
	if result == nil {
		return false
	}
	return len(result.Choices) > 0
}
// fireworksModelSupportsEmbeddings sends a single-text embedding request for model and reports whether
// the provider returned at least one embedding entry. A request error is logged and treated as
// "not supported".
func fireworksModelSupportsEmbeddings(t *testing.T, client *bifrost.Bifrost, ctx context.Context, model string) bool {
	t.Helper()

	probeText := "embedding probe"
	probeReq := &schemas.BifrostEmbeddingRequest{
		Provider: schemas.Fireworks,
		Model:    model,
		Input: &schemas.EmbeddingInput{
			Text: &probeText,
		},
	}

	result, probeErr := client.EmbeddingRequest(schemas.NewBifrostContext(ctx, schemas.NoDeadline), probeReq)
	if probeErr != nil {
		t.Logf("Fireworks /v1/embeddings probe failed for %q: %v", model, llmtests.GetErrorMessage(probeErr))
		return false
	}
	if result == nil {
		return false
	}
	return len(result.Data) > 0
}
// pickFireworksChatModel returns the normalized ID of the first listed model that
// isFireworksTextCapable accepts, or "" when none qualifies.
func pickFireworksChatModel(models []schemas.Model) string {
	for i := range models {
		if candidate := normalizeFireworksModelID(models[i].ID); isFireworksTextCapable(candidate) {
			return candidate
		}
	}
	return ""
}
// pickFireworksEmbeddingModel returns the normalized ID of the first listed model whose name
// (case-insensitively) mentions "embedding" or "embed", or "" when none does.
func pickFireworksEmbeddingModel(models []schemas.Model) string {
	for i := range models {
		candidate := normalizeFireworksModelID(models[i].ID)
		lowered := strings.ToLower(candidate)
		switch {
		case strings.Contains(lowered, "embedding"), strings.Contains(lowered, "embed"):
			return candidate
		}
	}
	return ""
}
// normalizeFireworksModelID trims surrounding whitespace from modelID and returns the model part
// produced by schemas.ParseModelString with Fireworks as the default provider. Blank input yields "".
func normalizeFireworksModelID(modelID string) string {
	trimmed := strings.TrimSpace(modelID)
	if trimmed == "" {
		return ""
	}
	_, model := schemas.ParseModelString(trimmed, schemas.Fireworks)
	return model
}
// isFireworksTextCapable reports whether modelID looks like a text/chat model, judged purely by
// case-insensitive substrings of its name. Any exclusion hint (image, audio, embedding, rerank
// families) wins over the inclusion hints; a name matching neither list is rejected.
func isFireworksTextCapable(modelID string) bool {
	lowered := strings.ToLower(modelID)

	// containsAny reports whether the lowered model ID includes at least one of the hints.
	containsAny := func(hints ...string) bool {
		for _, hint := range hints {
			if strings.Contains(lowered, hint) {
				return true
			}
		}
		return false
	}

	// Non-text model families are rejected first, even if a text hint also matches.
	if containsAny(
		"flux",
		"whisper",
		"audio",
		"speech",
		"transcrib",
		"embedding",
		"embed",
		"rerank",
	) {
		return false
	}

	// Only well-known text model families are accepted; everything else is treated as unknown.
	return containsAny(
		"instruct",
		"chat",
		"gpt-oss",
		"deepseek",
		"qwen",
		"llama",
		"glm",
		"mixtral",
		"mistral",
		"cogito",
		"gemma",
	)
}
// TestFireworksProviderUsesNativeEndpoints verifies that the Fireworks provider targets native completions, responses, and embeddings endpoints.
//
// Each table entry issues one provider call against a local httptest server that serves canned JSON
// for the three known paths and 404 for anything else. The subtest then checks both the decoded
// response (inside run) and the path the server actually received.
func TestFireworksProviderUsesNativeEndpoints(t *testing.T) {
tests := []struct {
name string
expectedPath string
// run performs the provider call and asserts on the decoded response.
run func(t *testing.T, provider *fireworksprovider.FireworksProvider, ctx *schemas.BifrostContext, key schemas.Key)
}{
{
name: "TextCompletion",
expectedPath: "/v1/completions",
run: func(t *testing.T, provider *fireworksprovider.FireworksProvider, ctx *schemas.BifrostContext, key schemas.Key) {
prompt := "A is for apple and B is for"
resp, err := provider.TextCompletion(ctx, key, &schemas.BifrostTextCompletionRequest{
Provider: schemas.Fireworks,
Model: "accounts/fireworks/models/deepseek-v3p2",
Input: &schemas.TextCompletionInput{
PromptStr: &prompt,
},
})
if err != nil {
t.Fatalf("TextCompletion returned error: %v", llmtests.GetErrorMessage(err))
}
// The canned completion carries exactly one non-empty text choice.
if resp == nil || len(resp.Choices) == 0 || resp.Choices[0].Text == nil || *resp.Choices[0].Text == "" {
t.Fatalf("unexpected text completion response: %#v", resp)
}
},
},
{
name: "Responses",
expectedPath: "/v1/responses",
run: func(t *testing.T, provider *fireworksprovider.FireworksProvider, ctx *schemas.BifrostContext, key schemas.Key) {
resp, err := provider.Responses(ctx, key, &schemas.BifrostResponsesRequest{
Provider: schemas.Fireworks,
Model: "accounts/fireworks/models/deepseek-v3p2",
Input: []schemas.ResponsesMessage{
llmtests.CreateBasicResponsesMessage("hello"),
},
Params: &schemas.ResponsesParameters{
PreviousResponseID: schemas.Ptr("resp_previous"),
MaxToolCalls: schemas.Ptr(2),
Store: schemas.Ptr(true),
},
})
if err != nil {
t.Fatalf("Responses returned error: %v", llmtests.GetErrorMessage(err))
}
// previous_response_id comes from the canned server payload, so this checks response decoding.
if resp == nil || resp.PreviousResponseID == nil || *resp.PreviousResponseID != "resp_previous" {
t.Fatalf("unexpected responses response: %#v", resp)
}
},
},
{
name: "Embedding",
expectedPath: "/v1/embeddings",
run: func(t *testing.T, provider *fireworksprovider.FireworksProvider, ctx *schemas.BifrostContext, key schemas.Key) {
resp, err := provider.Embedding(ctx, key, &schemas.BifrostEmbeddingRequest{
Provider: schemas.Fireworks,
Model: "accounts/fireworks/models/nomic-embed-text-v1.5",
Input: &schemas.EmbeddingInput{
Text: schemas.Ptr("embedding test"),
},
})
if err != nil {
t.Fatalf("Embedding returned error: %v", llmtests.GetErrorMessage(err))
}
// The canned payload holds a single embedding with three components.
if resp == nil || len(resp.Data) != 1 || len(resp.Data[0].Embedding.EmbeddingArray) != 3 {
t.Fatalf("unexpected embedding response: %#v", resp)
}
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// requestedPath records the last path the fake server saw for this subtest.
var requestedPath string
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
requestedPath = r.URL.Path
w.Header().Set("Content-Type", "application/json")
switch r.URL.Path {
case "/v1/completions":
_, _ = fmt.Fprint(w, `{"id":"cmpl_1","object":"text_completion","created":1,"model":"accounts/fireworks/models/deepseek-v3p2","choices":[{"text":" banana","index":0,"finish_reason":"stop"}],"usage":{"prompt_tokens":4,"completion_tokens":1,"total_tokens":5}}`)
case "/v1/responses":
_, _ = fmt.Fprint(w, `{"id":"resp_1","object":"response","created_at":1,"status":"completed","model":"accounts/fireworks/models/deepseek-v3p2","output":[{"id":"msg_1","type":"message","status":"completed","role":"assistant","content":[{"type":"output_text","text":"hello","annotations":[],"logprobs":[]}]}],"previous_response_id":"resp_previous","max_tool_calls":2,"store":true,"usage":{"input_tokens":1,"input_tokens_details":{"cached_tokens":0,"cached_read_tokens":0,"cached_write_tokens":0},"output_tokens":1,"total_tokens":2}}`)
case "/v1/embeddings":
_, _ = fmt.Fprint(w, `{"object":"list","model":"accounts/fireworks/models/nomic-embed-text-v1.5","data":[{"object":"embedding","index":0,"embedding":[0.1,0.2,0.3]}],"usage":{"prompt_tokens":2,"total_tokens":2}}`)
default:
// Any other path means the provider did not use a native endpoint.
http.NotFound(w, r)
}
}))
defer server.Close()
provider := newTestFireworksProvider(t, server.URL)
ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
key := schemas.Key{Value: *schemas.NewEnvVar("test-key")}
tt.run(t, provider, ctx, key)
if requestedPath != tt.expectedPath {
t.Fatalf("expected request path %q, got %q", tt.expectedPath, requestedPath)
}
})
}
}
// TestFireworksResponsesStreamUsesNativeResponsesEndpoint checks that ResponsesStream sends its request
// to /v1/responses and surfaces the server's "response.completed" event as a completed stream chunk.
// The fake server answers every other path with 404.
func TestFireworksResponsesStreamUsesNativeResponsesEndpoint(t *testing.T) {
	const responsesPath = "/v1/responses"

	// hitPath records the last path the fake server received.
	var hitPath string
	handler := func(w http.ResponseWriter, r *http.Request) {
		hitPath = r.URL.Path
		if hitPath != responsesPath {
			http.NotFound(w, r)
			return
		}
		// A single server-sent event carrying the final, completed response.
		w.Header().Set("Content-Type", "text/event-stream")
		_, _ = fmt.Fprint(w, "data: {\"type\":\"response.completed\",\"sequence_number\":0,\"response\":{\"id\":\"resp_1\",\"object\":\"response\",\"created_at\":1,\"status\":\"completed\",\"model\":\"accounts/fireworks/models/deepseek-v3p2\",\"output\":[{\"id\":\"msg_1\",\"type\":\"message\",\"status\":\"completed\",\"role\":\"assistant\",\"content\":[{\"type\":\"output_text\",\"text\":\"hello\",\"annotations\":[],\"logprobs\":[]}]}],\"usage\":{\"input_tokens\":1,\"input_tokens_details\":{\"cached_tokens\":0,\"cached_read_tokens\":0,\"cached_write_tokens\":0},\"output_tokens\":1,\"total_tokens\":2}}}\n\n")
	}
	server := httptest.NewServer(http.HandlerFunc(handler))
	defer server.Close()

	provider := newTestFireworksProvider(t, server.URL)
	ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
	key := schemas.Key{Value: *schemas.NewEnvVar("test-key")}

	// The post hook passes results through untouched.
	passthroughHook := func(_ *schemas.BifrostContext, result *schemas.BifrostResponse, hookErr *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError) {
		return result, hookErr
	}

	streamReq := &schemas.BifrostResponsesRequest{
		Provider: schemas.Fireworks,
		Model:    "accounts/fireworks/models/deepseek-v3p2",
		Input: []schemas.ResponsesMessage{
			llmtests.CreateBasicResponsesMessage("hello"),
		},
	}
	stream, err := provider.ResponsesStream(ctx, passthroughHook, nil, key, streamReq)
	if err != nil {
		t.Fatalf("ResponsesStream returned error: %v", llmtests.GetErrorMessage(err))
	}

	// Drain the stream fully, remembering whether a completed event came through.
	completedSeen := false
	for chunk := range stream {
		if chunk == nil || chunk.BifrostResponsesStreamResponse == nil {
			continue
		}
		if chunk.BifrostResponsesStreamResponse.Type == schemas.ResponsesStreamResponseTypeCompleted {
			completedSeen = true
		}
	}

	if hitPath != responsesPath {
		t.Fatalf("expected responses stream to hit /v1/responses, got %q", hitPath)
	}
	if !completedSeen {
		t.Fatal("expected a completed responses stream chunk")
	}
}
// newTestFireworksProvider builds a Fireworks provider whose base URL points at baseURL (normally a
// local httptest server), with a 30-second default request timeout and a no-op logger. The test
// fails immediately if the provider cannot be constructed.
func newTestFireworksProvider(t *testing.T, baseURL string) *fireworksprovider.FireworksProvider {
	t.Helper()

	config := &schemas.ProviderConfig{
		NetworkConfig: schemas.NetworkConfig{
			BaseURL:                        baseURL,
			DefaultRequestTimeoutInSeconds: 30,
		},
	}
	provider, err := fireworksprovider.NewFireworksProvider(config, bifrost.NewNoOpLogger())
	if err != nil {
		t.Fatalf("failed to create Fireworks provider: %v", err)
	}
	return provider
}

View File

@@ -0,0 +1,370 @@
package gemini
import (
"context"
"fmt"
"net/http"
"strings"
"time"
"github.com/bytedance/sonic"
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
"github.com/maximhq/bifrost/core/schemas"
"github.com/valyala/fasthttp"
)
// ToBifrostBatchStatus maps a Gemini batch job state onto the matching Bifrost batch status.
// Pending and running both map to in-progress. A state that is not recognised is passed through
// verbatim as a BatchStatus.
func ToBifrostBatchStatus(geminiState string) schemas.BatchStatus {
	// Start from the pass-through value so unknown states fall out unchanged.
	status := schemas.BatchStatus(geminiState)

	switch geminiState {
	case GeminiBatchStatePending, GeminiBatchStateRunning:
		status = schemas.BatchStatusInProgress
	case GeminiBatchStateSucceeded:
		status = schemas.BatchStatusCompleted
	case GeminiBatchStateFailed:
		status = schemas.BatchStatusFailed
	case GeminiBatchStateCancelling:
		status = schemas.BatchStatusCancelling
	case GeminiBatchStateCancelled:
		status = schemas.BatchStatusCancelled
	case GeminiBatchStateExpired:
		status = schemas.BatchStatusExpired
	}

	return status
}
// ToGeminiBatchStatus maps a Bifrost batch status onto a Gemini batch job state.
// Validating, in-progress, and finalizing all collapse to the running state; completed and ended
// both map to succeeded. Any status without a mapping yields the unspecified state.
func ToGeminiBatchStatus(status schemas.BatchStatus) string {
	state := GeminiBatchStateUnspecified

	switch status {
	case schemas.BatchStatusValidating, schemas.BatchStatusInProgress, schemas.BatchStatusFinalizing:
		state = GeminiBatchStateRunning
	case schemas.BatchStatusCompleted, schemas.BatchStatusEnded:
		state = GeminiBatchStateSucceeded
	case schemas.BatchStatusFailed:
		state = GeminiBatchStateFailed
	case schemas.BatchStatusCancelling:
		state = GeminiBatchStateCancelling
	case schemas.BatchStatusCancelled:
		state = GeminiBatchStateCancelled
	case schemas.BatchStatusExpired:
		state = GeminiBatchStateExpired
	}

	return state
}
// ToGeminiBatchJobResponse renders a Bifrost batch-create response in Gemini's batch job shape.
//
// The job name is the operation name when one is set and non-empty, otherwise the batch ID. Both
// create and update timestamps are derived from CreatedAt. The successful count falls back to
// RequestCounts.Completed when Succeeded is zero, and the pending count is whatever remains of the
// total after successes and failures (never negative). Done is set for terminal statuses.
// A nil resp yields nil.
func ToGeminiBatchJobResponse(resp *schemas.BifrostBatchCreateResponse) *GeminiBatchJobResponse {
	if resp == nil {
		return nil
	}

	counts := resp.RequestCounts
	succeeded := counts.Succeeded
	if succeeded == 0 {
		succeeded = counts.Completed
	}

	// Prefer the provider operation name over the plain batch ID.
	jobName := resp.ID
	if resp.OperationName != nil && *resp.OperationName != "" {
		jobName = *resp.OperationName
	}

	metadata := &GeminiBatchMetadata{
		Name:       jobName,
		Type:       "type.googleapis.com/google.ai.generativelanguage.v1beta.BatchPredictionJob",
		CreateTime: formatGeminiTimestamp(resp.CreatedAt),
		UpdateTime: formatGeminiTimestamp(resp.CreatedAt),
		State:      ToGeminiBatchStatus(resp.Status),
		BatchStats: &GeminiBatchStats{
			RequestCount:           counts.Total,
			PendingRequestCount:    max(0, counts.Total-succeeded-counts.Failed),
			SuccessfulRequestCount: succeeded,
		},
	}
	if resp.InputFileID != "" {
		metadata.InputConfig = &GeminiBatchMetadataInputConfig{FileName: resp.InputFileID}
	}

	job := &GeminiBatchJobResponse{
		Name:     jobName,
		Metadata: metadata,
	}

	// The output file is exposed both as the job destination and inside the metadata.
	if out := resp.OutputFileID; out != nil && *out != "" {
		job.Dest = &GeminiBatchDest{FileName: *out}
		metadata.Output = &GeminiBatchMetadataOutputConfig{ResponsesFile: *out}
	}

	switch resp.Status {
	case schemas.BatchStatusCompleted,
		schemas.BatchStatusEnded,
		schemas.BatchStatusFailed,
		schemas.BatchStatusExpired,
		schemas.BatchStatusCancelled:
		job.Done = true
	}

	return job
}
// ToGeminiBatchRetrieveResponse renders a Bifrost batch-retrieve response in Gemini's batch job shape.
//
// Naming, timestamps, and the successful-count fallback are handled as follows: the job name is the
// non-empty operation name or else the batch ID; create and update times both come from CreatedAt;
// Succeeded falls back to Completed when zero. The pending count uses RequestCounts.Pending when it
// is non-zero, otherwise it is derived from total minus processed minus failed (clamped at zero).
// Done honours an explicit resp.Done and otherwise reflects whether the status is terminal. The end
// time is taken from the first set timestamp among completed, failed, expired, and cancelled.
// A nil resp yields nil.
func ToGeminiBatchRetrieveResponse(resp *schemas.BifrostBatchRetrieveResponse) *GeminiBatchJobResponse {
	if resp == nil {
		return nil
	}

	counts := resp.RequestCounts
	succeeded := counts.Succeeded
	if succeeded == 0 {
		succeeded = counts.Completed
	}

	// Derive the pending count only when the provider did not report one.
	pending := counts.Pending
	if pending == 0 && counts.Total > 0 {
		processed := counts.Completed
		if processed == 0 {
			processed = succeeded
		}
		pending = max(0, counts.Total-processed-counts.Failed)
	}

	jobName := resp.ID
	if resp.OperationName != nil && *resp.OperationName != "" {
		jobName = *resp.OperationName
	}

	metadata := &GeminiBatchMetadata{
		Name:       jobName,
		Type:       "type.googleapis.com/google.ai.generativelanguage.v1beta.BatchPredictionJob",
		CreateTime: formatGeminiTimestamp(resp.CreatedAt),
		UpdateTime: formatGeminiTimestamp(resp.CreatedAt),
		State:      ToGeminiBatchStatus(resp.Status),
		BatchStats: &GeminiBatchStats{
			RequestCount:           counts.Total,
			PendingRequestCount:    pending,
			SuccessfulRequestCount: succeeded,
		},
	}
	job := &GeminiBatchJobResponse{
		Name:     jobName,
		Metadata: metadata,
	}

	// An explicit Done flag wins; otherwise infer it from the terminal statuses.
	switch {
	case resp.Done != nil:
		job.Done = *resp.Done
	default:
		switch resp.Status {
		case schemas.BatchStatusCompleted,
			schemas.BatchStatusEnded,
			schemas.BatchStatusFailed,
			schemas.BatchStatusExpired,
			schemas.BatchStatusCancelled:
			job.Done = true
		}
	}

	if resp.InputFileID != "" {
		metadata.InputConfig = &GeminiBatchMetadataInputConfig{FileName: resp.InputFileID}
	}
	if out := resp.OutputFileID; out != nil && *out != "" {
		job.Dest = &GeminiBatchDest{FileName: *out}
		metadata.Output = &GeminiBatchMetadataOutputConfig{ResponsesFile: *out}
	}

	// Pick the end time from the first terminal timestamp that is set.
	var endTime int64
	switch {
	case resp.CompletedAt != nil:
		endTime = *resp.CompletedAt
	case resp.FailedAt != nil:
		endTime = *resp.FailedAt
	case resp.ExpiredAt != nil:
		endTime = *resp.ExpiredAt
	case resp.CancelledAt != nil:
		endTime = *resp.CancelledAt
	}
	if endTime > 0 {
		metadata.EndTime = formatGeminiTimestamp(endTime)
	}

	return job
}
// ToGeminiBatchListResponse renders a Bifrost batch list in Gemini's list format: every entry is
// converted with ToGeminiBatchRetrieveResponse, and the next cursor (when set) becomes the next
// page token. A nil resp yields nil.
func ToGeminiBatchListResponse(resp *schemas.BifrostBatchListResponse) *GeminiBatchListResponse {
	if resp == nil {
		return nil
	}

	listResp := &GeminiBatchListResponse{
		Operations: make([]GeminiBatchJobResponse, 0, len(resp.Data)),
	}
	for i := range resp.Data {
		converted := ToGeminiBatchRetrieveResponse(&resp.Data[i])
		if converted == nil {
			continue
		}
		listResp.Operations = append(listResp.Operations, *converted)
	}

	if cursor := resp.NextCursor; cursor != nil {
		listResp.NextPageToken = *cursor
	}
	return listResp
}
// parseGeminiTimestamp converts an RFC 3339 timestamp string into Unix seconds.
// An empty or unparsable timestamp yields 0.
func parseGeminiTimestamp(timestamp string) int64 {
	// time.Parse rejects the empty string as well, so one check covers both failure cases.
	if parsed, err := time.Parse(time.RFC3339, timestamp); err == nil {
		return parsed.Unix()
	}
	return 0
}
// extractBatchIDFromName returns the final "/"-separated segment of a resource name,
// e.g. "batches/abc123" -> "abc123". A name without any "/" is returned unchanged, and a name
// ending in "/" yields "".
func extractBatchIDFromName(name string) string {
	if idx := strings.LastIndex(name, "/"); idx >= 0 {
		return name[idx+1:]
	}
	return name
}
// downloadBatchResultsFile downloads and parses a batch results file from Gemini.
// Returns the parsed result items from the JSONL file and any parse errors encountered.
//
// fileName may be given with or without the "files/" prefix. A transport failure, a non-200
// status, or a body decode failure is returned as the BifrostError with nil results; lines that
// fail to unmarshal are reported through the returned []schemas.BatchError instead.
func (provider *GeminiProvider) downloadBatchResultsFile(ctx context.Context, key schemas.Key, fileName string) ([]schemas.BatchResultItem, []schemas.BatchError, *schemas.BifrostError) {
// Create request to download the file
req := fasthttp.AcquireRequest()
resp := fasthttp.AcquireResponse()
defer fasthttp.ReleaseRequest(req)
defer fasthttp.ReleaseResponse(resp)
// Build download URL - use the download endpoint with alt=media
// The base URL is like https://generativelanguage.googleapis.com/v1beta
// We need to change it to https://generativelanguage.googleapis.com/download/v1beta
// Only the first "/v1beta" occurrence is rewritten; a base URL without it is used unchanged.
baseURL := strings.Replace(provider.networkConfig.BaseURL, "/v1beta", "/download/v1beta", 1)
// Ensure fileName has proper format
fileID := fileName
if !strings.HasPrefix(fileID, "files/") {
fileID = "files/" + fileID
}
url := fmt.Sprintf("%s/%s:download?alt=media", baseURL, fileID)
provider.logger.Debug("gemini batch results file download url: " + url)
providerUtils.SetExtraHeaders(ctx, req, provider.networkConfig.ExtraHeaders, nil)
req.SetRequestURI(url)
req.Header.SetMethod(http.MethodGet)
// The API key travels in a header, so it never appears in the logged URL.
if key.Value.GetValue() != "" {
req.Header.Set("x-goog-api-key", key.Value.GetValue())
}
// Make request
_, bifrostErr, wait := providerUtils.MakeRequestWithContext(ctx, provider.client, req, resp)
// NOTE(review): wait is presumably a completion/cleanup hook for the in-flight request — confirm against providerUtils.
defer wait()
if bifrostErr != nil {
return nil, nil, bifrostErr
}
// Handle error response
if resp.StatusCode() != fasthttp.StatusOK {
return nil, nil, parseGeminiError(resp)
}
body, err := providerUtils.CheckAndDecodeBody(resp)
if err != nil {
return nil, nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err)
}
// Parse JSONL content - each line is a separate JSON object
// Use streaming parser to avoid string conversion and collect parse errors
results := make([]schemas.BatchResultItem, 0)
parseResult := providerUtils.ParseJSONL(body, func(line []byte) error {
var resultLine GeminiBatchFileResultLine
if err := sonic.Unmarshal(line, &resultLine); err != nil {
provider.logger.Warn("gemini batch results file parse error: " + err.Error())
return err
}
// Lines without a key get a synthetic ID based on how many items were collected so far.
customID := resultLine.Key
if customID == "" {
customID = fmt.Sprintf("request-%d", len(results))
}
resultItem := schemas.BatchResultItem{
CustomID: customID,
}
// An error entry takes precedence over a response; a line with neither keeps only its CustomID.
if resultLine.Error != nil {
resultItem.Error = &schemas.BatchResultError{
Code: fmt.Sprintf("%d", resultLine.Error.Code),
Message: resultLine.Error.Message,
}
} else if resultLine.Response != nil {
// Convert the response to a map for the Body field
respBody := make(map[string]interface{})
// Only the first candidate is inspected.
if len(resultLine.Response.Candidates) > 0 {
candidate := resultLine.Response.Candidates[0]
if candidate.Content != nil && len(candidate.Content.Parts) > 0 {
var textParts []string
for _, part := range candidate.Content.Parts {
if part.Text != "" {
textParts = append(textParts, part.Text)
}
}
// Text parts are concatenated without a separator.
if len(textParts) > 0 {
respBody["text"] = strings.Join(textParts, "")
}
}
respBody["finish_reason"] = string(candidate.FinishReason)
}
if resultLine.Response.UsageMetadata != nil {
respBody["usage"] = map[string]interface{}{
"prompt_tokens": resultLine.Response.UsageMetadata.PromptTokenCount,
"completion_tokens": resultLine.Response.UsageMetadata.CandidatesTokenCount,
"total_tokens": resultLine.Response.UsageMetadata.TotalTokenCount,
}
}
// Every line that carries a response is reported with status code 200.
resultItem.Response = &schemas.BatchResultResponse{
StatusCode: 200,
Body: respBody,
}
}
results = append(results, resultItem)
return nil
})
return results, parseResult.Errors, nil
}
// extractGeminiUsageMetadata returns the prompt, candidates, and total token counts of a Gemini
// response as ints, in that order. All three are 0 when the response has no usage metadata.
// geminiResponse itself must be non-nil.
func extractGeminiUsageMetadata(geminiResponse *GenerateContentResponse) (int, int, int) {
	usage := geminiResponse.UsageMetadata
	if usage == nil {
		return 0, 0, 0
	}
	return int(usage.PromptTokenCount), int(usage.CandidatesTokenCount), int(usage.TotalTokenCount)
}

View File

@@ -0,0 +1,558 @@
package gemini
import (
"encoding/base64"
"fmt"
"strings"
"github.com/maximhq/bifrost/core/schemas"
)
// ToGeminiChatCompletionRequest converts a BifrostChatRequest to Gemini's generation request format for chat completion.
//
// Parameters are turned into a generation config, tools and tool choice into Gemini tool settings,
// and messages into contents plus an optional system instruction. The extra parameters
// "safety_settings", "cached_content", and "labels" are lifted into their dedicated request fields
// and removed from the request's pass-through ExtraParams.
//
// The pass-through map is a copy: the caller's Params.ExtraParams is never modified, so converting
// the same request again (for example on a retry or fallback) sees the same extra parameters.
//
// Returns (nil, nil) for a nil request and (nil, err) when the generation config cannot be built.
func ToGeminiChatCompletionRequest(bifrostReq *schemas.BifrostChatRequest) (*GeminiGenerationRequest, error) {
	if bifrostReq == nil {
		return nil, nil
	}

	// Create the base Gemini generation request
	geminiReq := &GeminiGenerationRequest{
		Model: bifrostReq.Model,
	}

	if params := bifrostReq.Params; params != nil {
		// Convert parameters to generation config
		generationConfig, err := convertParamsToGenerationConfig(params, []string{}, bifrostReq.Model)
		if err != nil {
			return nil, err
		}
		geminiReq.GenerationConfig = generationConfig

		// Handle tool-related parameters
		if len(params.Tools) > 0 {
			geminiReq.Tools = convertBifrostToolsToGemini(params.Tools)
			// Convert tool choice to tool config
			if params.ToolChoice != nil {
				geminiReq.ToolConfig = convertToolChoiceToToolConfig(params.ToolChoice)
			}
		}

		// Handle extra parameters
		if params.ExtraParams != nil {
			// Work on a private copy so the deletes below cannot strip keys from the caller's map.
			// The copy is taken after the generation config is built so it reflects the same map
			// state the original shared reference would have had at this point.
			extraParams := make(map[string]interface{}, len(params.ExtraParams))
			for name, value := range params.ExtraParams {
				extraParams[name] = value
			}
			geminiReq.ExtraParams = extraParams

			// Safety settings
			if safetySettings, ok := schemas.SafeExtractFromMap(params.ExtraParams, "safety_settings"); ok {
				delete(extraParams, "safety_settings")
				if settings, ok := SafeExtractSafetySettings(safetySettings); ok {
					geminiReq.SafetySettings = settings
				}
			}
			// Cached content
			if cachedContent, ok := schemas.SafeExtractString(params.ExtraParams["cached_content"]); ok {
				delete(extraParams, "cached_content")
				geminiReq.CachedContent = cachedContent
			}
			// Labels
			if labels, ok := schemas.SafeExtractFromMap(params.ExtraParams, "labels"); ok {
				delete(extraParams, "labels")
				if labelMap, ok := schemas.SafeExtractStringMap(labels); ok {
					geminiReq.Labels = labelMap
				}
			}
		}
	}

	// Convert chat completion messages to Gemini format
	contents, systemInstruction := convertBifrostMessagesToGemini(bifrostReq.Input)
	if systemInstruction != nil {
		geminiReq.SystemInstruction = systemInstruction
	}
	geminiReq.Contents = contents

	return geminiReq, nil
}
// ToBifrostChatResponse converts a GenerateContentResponse to a BifrostChatResponse
//
// Only the first candidate is converted. A response without candidates, or whose first candidate
// has an error finish reason, is turned into an error response via createErrorResponse. All parts
// of the candidate are folded into a single assistant message: text becomes content, function
// calls become tool calls, and thoughts/thought signatures become reasoning details. A choice is
// appended only when the candidate has at least one content part. Usage is always populated from
// the response's usage metadata.
func (response *GenerateContentResponse) ToBifrostChatResponse() *schemas.BifrostChatResponse {
bifrostResp := &schemas.BifrostChatResponse{
ID: response.ResponseID,
Model: response.ModelVersion,
Object: "chat.completion",
}
// Set creation timestamp if available
if !response.CreateTime.IsZero() {
bifrostResp.Created = int(response.CreateTime.Unix())
}
// Handle empty candidates (filtered/malformed responses)
// An empty candidate list is reported with the malformed-function-call finish reason.
if len(response.Candidates) == 0 {
finishReason := ConvertGeminiFinishReasonToBifrost(FinishReasonMalformedFunctionCall)
return createErrorResponse(response, finishReason, false)
}
candidate := response.Candidates[0]
// Check for filtered finish reasons that indicate errors
if isErrorFinishReason(candidate.FinishReason) {
finishReason := ConvertGeminiFinishReasonToBifrost(candidate.FinishReason)
return createErrorResponse(response, finishReason, false)
}
// Collect all content and tool calls into a single message
var toolCalls []schemas.ChatAssistantMessageToolCall
var contentBlocks []schemas.ChatContentBlock
var reasoningDetails []schemas.ChatReasoningDetails
var contentStr *string
// Process candidate content to extract text, tool calls, and reasoning
if candidate.Content != nil && len(candidate.Content.Parts) > 0 {
for _, part := range candidate.Content.Parts {
// Handle thought/reasoning text separately - add to reasoning details
// A thought part is consumed entirely here; none of the checks below run for it.
if part.Text != "" && part.Thought {
reasoningDetails = append(reasoningDetails, schemas.ChatReasoningDetails{
Index: len(reasoningDetails),
Type: schemas.BifrostReasoningDetailsTypeText,
Text: &part.Text,
})
continue
}
// Handle regular text
if part.Text != "" {
contentBlocks = append(contentBlocks, schemas.ChatContentBlock{
Type: schemas.ChatContentBlockTypeText,
Text: &part.Text,
})
// Add thought signature to reasoning details if present with text
if len(part.ThoughtSignature) > 0 {
thoughtSig := base64.StdEncoding.EncodeToString(part.ThoughtSignature)
reasoningDetails = append(reasoningDetails, schemas.ChatReasoningDetails{
Index: len(reasoningDetails),
Type: schemas.BifrostReasoningDetailsTypeEncrypted,
Signature: &thoughtSig,
})
}
}
if part.FunctionCall != nil {
function := schemas.ChatAssistantMessageToolCallFunction{
Name: &part.FunctionCall.Name,
}
if len(part.FunctionCall.Args) > 0 {
function.Arguments = string(part.FunctionCall.Args)
}
// The call ID defaults to the function name when Gemini supplies no explicit ID.
callID := part.FunctionCall.Name
if part.FunctionCall.ID != "" {
callID = part.FunctionCall.ID
}
// Embed thought signature into CallID if present (matches responses.go pattern)
// The embedded form uses unpadded URL-safe base64, unlike the standard base64 used for reasoning details.
if len(part.ThoughtSignature) > 0 && !strings.Contains(callID, thoughtSignatureSeparator) {
encoded := base64.RawURLEncoding.EncodeToString(part.ThoughtSignature)
callID = fmt.Sprintf("%s%s%s", callID, thoughtSignatureSeparator, encoded)
}
toolCall := schemas.ChatAssistantMessageToolCall{
Index: uint16(len(toolCalls)),
Type: schemas.Ptr(string(schemas.ChatToolChoiceTypeFunction)),
ID: &callID,
Function: function,
}
toolCalls = append(toolCalls, toolCall)
// Also add to reasoning details for backward compatibility
if len(part.ThoughtSignature) > 0 {
thoughtSig := base64.StdEncoding.EncodeToString(part.ThoughtSignature)
// Extract base ID without signature for reasoning detail lookup
baseCallID := callID
if strings.Contains(callID, thoughtSignatureSeparator) {
parts := strings.SplitN(callID, thoughtSignatureSeparator, 2)
if len(parts) == 2 {
baseCallID = parts[0]
}
}
reasoningDetails = append(reasoningDetails, schemas.ChatReasoningDetails{
Index: len(reasoningDetails),
Type: schemas.BifrostReasoningDetailsTypeEncrypted,
Signature: &thoughtSig,
ID: schemas.Ptr(fmt.Sprintf("tool_call_%s", baseCallID)),
})
}
}
if part.FunctionResponse != nil {
// Extract the output from the response
output := extractFunctionResponseOutput(part.FunctionResponse)
// Add as text content block
if output != "" {
contentBlocks = append(contentBlocks, schemas.ChatContentBlock{
Type: schemas.ChatContentBlockTypeText,
Text: &output,
})
}
}
// Handle code execution results
if part.CodeExecutionResult != nil {
output := part.CodeExecutionResult.Output
// Any outcome other than OK is surfaced as an "Error: "-prefixed text block.
if part.CodeExecutionResult.Outcome != OutcomeOK {
output = "Error: " + output
}
if output != "" {
contentBlocks = append(contentBlocks, schemas.ChatContentBlock{
Type: schemas.ChatContentBlockTypeText,
Text: &output,
})
}
}
// Handle executable code
// Rendered as a fenced code block tagged with the reported language.
if part.ExecutableCode != nil {
codeContent := "```" + part.ExecutableCode.Language + "\n" + part.ExecutableCode.Code + "\n```"
contentBlocks = append(contentBlocks, schemas.ChatContentBlock{
Type: schemas.ChatContentBlockTypeText,
Text: &codeContent,
})
}
// Handle standalone thought signature (not associated with function call or text)
if len(part.ThoughtSignature) > 0 && part.FunctionCall == nil && part.Text == "" {
thoughtSig := base64.StdEncoding.EncodeToString(part.ThoughtSignature)
reasoningDetails = append(reasoningDetails, schemas.ChatReasoningDetails{
Index: len(reasoningDetails),
Type: schemas.BifrostReasoningDetailsTypeEncrypted,
Signature: &thoughtSig,
})
}
}
// Build the choice with message
message := &schemas.ChatMessage{
Role: schemas.ChatMessageRoleAssistant,
}
// A lone text block is collapsed into a plain content string instead of a block list.
if len(contentBlocks) == 1 && contentBlocks[0].Type == schemas.ChatContentBlockTypeText {
contentStr = contentBlocks[0].Text
contentBlocks = nil
}
message.Content = &schemas.ChatMessageContent{
ContentStr: contentStr,
ContentBlocks: contentBlocks,
}
// Assistant-specific fields are attached only when there is something to report.
if len(toolCalls) > 0 || len(reasoningDetails) > 0 {
message.ChatAssistantMessage = &schemas.ChatAssistantMessage{
ToolCalls: toolCalls,
ReasoningDetails: reasoningDetails,
}
}
// Convert finish reason to Bifrost format.
// Gemini uses "STOP" for both normal text completions and tool call responses —
// it has no dedicated finish reason for tool calls. Override to "tool_calls" when
// tool calls are present so downstream consumers see a uniform signal.
finishReason := ConvertGeminiFinishReasonToBifrost(candidate.FinishReason)
if len(toolCalls) > 0 && finishReason == "stop" {
finishReason = "tool_calls"
}
bifrostResp.Choices = append(bifrostResp.Choices, schemas.BifrostResponseChoice{
Index: 0,
FinishReason: &finishReason,
LogProbs: ConvertGeminiLogprobsResultToBifrost(candidate.LogprobsResult),
ChatNonStreamResponseChoice: &schemas.ChatNonStreamResponseChoice{
Message: message,
},
})
}
// Set usage information
bifrostResp.Usage = ConvertGeminiUsageMetadataToChatUsage(response.UsageMetadata)
return bifrostResp
}
// GeminiStreamState tracks tool-call index across streaming chunks.
// One state value is meant to be shared by all chunk conversions of a single streaming response.
type GeminiStreamState struct {
// NOTE(review): presumably the index handed to the next streamed tool call — its use is outside this view; confirm.
nextToolCallIndex int
hadToolCalls bool // true if any tool calls were seen in this stream
}
// NewGeminiStreamState returns a zero-valued stream state for a single streaming response:
// the tool-call index starts at 0 and no tool calls have been seen.
func NewGeminiStreamState() *GeminiStreamState {
	return new(GeminiStreamState)
}
// ToBifrostChatCompletionStream converts one Gemini streaming chunk into a Bifrost
// chat-completion stream chunk (object "chat.completion.chunk").
//
// state carries the running tool-call index and the "tool calls seen" flag across the
// chunks of one stream. When state is nil a fresh state is created for this call only,
// so nothing is carried over to later chunks.
//
// Only the first candidate of the chunk is converted.
//
// Returns the converted chunk, an error (nil on every path in this function), and a
// boolean that is true when this chunk is the last one. A nil response, or a chunk that
// has no delta content and is not the last chunk, yields (nil, nil, false).
func (response *GenerateContentResponse) ToBifrostChatCompletionStream(state *GeminiStreamState) (*schemas.BifrostChatResponse, *schemas.BifrostError, bool) {
if response == nil {
return nil, nil, false
}
if state == nil {
state = NewGeminiStreamState()
}
// Handle empty candidates (filtered/malformed responses)
// NOTE(review): a chunk without candidates is always reported with the converted
// malformed-function-call finish reason and ends the stream — confirm that such chunks
// cannot legitimately occur mid-stream.
if len(response.Candidates) == 0 {
finishReason := ConvertGeminiFinishReasonToBifrost(FinishReasonMalformedFunctionCall)
return createErrorResponse(response, finishReason, true), nil, true
}
candidate := response.Candidates[0]
// Check for filtered finish reasons that indicate errors
if isErrorFinishReason(candidate.FinishReason) {
finishReason := ConvertGeminiFinishReasonToBifrost(candidate.FinishReason)
return createErrorResponse(response, finishReason, true), nil, true
}
// Determine if this is the last chunk based on finish reason and usage metadata:
// both must be present for the chunk to count as final.
isLastChunk := candidate.FinishReason != "" && response.UsageMetadata != nil
// Create the streaming response
streamResponse := &schemas.BifrostChatResponse{
ID: response.ResponseID,
Model: response.ModelVersion,
Object: "chat.completion.chunk",
}
// Set creation timestamp if available
if !response.CreateTime.IsZero() {
streamResponse.Created = int(response.CreateTime.Unix())
}
// Build delta content
delta := &schemas.ChatStreamResponseChoiceDelta{}
// Process content parts
if candidate.Content != nil && len(candidate.Content.Parts) > 0 {
// Set the role whenever the chunk carries one (Gemini uses "model" for assistant)
if candidate.Content.Role != "" {
role := candidate.Content.Role
if role == string(RoleModel) {
role = string(schemas.ChatMessageRoleAssistant)
}
delta.Role = &role
}
// Accumulators for this chunk: all textual output is concatenated into textContent,
// tool calls and reasoning details are collected in order of appearance.
var textContent string
var toolCalls []schemas.ChatAssistantMessageToolCall
var reasoningDetails []schemas.ChatReasoningDetails
for _, part := range candidate.Content.Parts {
// The first matching case wins, so a part is treated as exactly one of:
// thought text, plain text, function call, function response, code execution
// result, or executable code.
switch {
case part.Text != "" && part.Thought:
// Thought/reasoning content - add to reasoning details
reasoningDetails = append(reasoningDetails, schemas.ChatReasoningDetails{
Index: len(reasoningDetails),
Type: schemas.BifrostReasoningDetailsTypeText,
Text: &part.Text,
})
case part.Text != "":
// Regular text content
textContent += part.Text
case part.FunctionCall != nil:
// Function call: arguments are passed through as the raw bytes of Args,
// or an empty string when there are none.
jsonArgs := ""
if len(part.FunctionCall.Args) > 0 {
jsonArgs = string(part.FunctionCall.Args)
}
// Use ID if available, otherwise use function name
callID := part.FunctionCall.Name
if part.FunctionCall.ID != "" {
callID = part.FunctionCall.ID
}
// Embed thought signature into CallID if present (raw URL-safe base64), unless the
// ID already contains the separator.
if len(part.ThoughtSignature) > 0 && !strings.Contains(callID, thoughtSignatureSeparator) {
encoded := base64.RawURLEncoding.EncodeToString(part.ThoughtSignature)
callID = fmt.Sprintf("%s%s%s", callID, thoughtSignatureSeparator, encoded)
}
// The index comes from the shared state so it keeps increasing across chunks.
toolCallIdx := state.nextToolCallIndex
state.nextToolCallIndex++
toolCall := schemas.ChatAssistantMessageToolCall{
Index: uint16(toolCallIdx),
Type: schemas.Ptr(string(schemas.ChatToolTypeFunction)),
ID: &callID,
Function: schemas.ChatAssistantMessageToolCallFunction{
Name: &part.FunctionCall.Name,
Arguments: jsonArgs,
},
}
toolCalls = append(toolCalls, toolCall)
// Also add thought signature to reasoning details if present (standard base64 here,
// unlike the URL-safe encoding embedded in the call ID above).
if len(part.ThoughtSignature) > 0 {
thoughtSig := base64.StdEncoding.EncodeToString(part.ThoughtSignature)
// Extract base ID without signature for reasoning detail lookup
baseCallID := callID
if strings.Contains(callID, thoughtSignatureSeparator) {
parts := strings.SplitN(callID, thoughtSignatureSeparator, 2)
if len(parts) == 2 {
baseCallID = parts[0]
}
}
reasoningDetails = append(reasoningDetails, schemas.ChatReasoningDetails{
Index: len(reasoningDetails),
Type: schemas.BifrostReasoningDetailsTypeEncrypted,
Signature: &thoughtSig,
ID: schemas.Ptr(fmt.Sprintf("tool_call_%s", baseCallID)),
})
}
case part.FunctionResponse != nil:
// Extract the output from the response and add to text content
output := extractFunctionResponseOutput(part.FunctionResponse)
if output != "" {
textContent += output
}
case part.CodeExecutionResult != nil:
// Code execution output is appended as text; any outcome other than OutcomeOK is
// prefixed with "Error: ".
output := part.CodeExecutionResult.Output
if part.CodeExecutionResult.Outcome != OutcomeOK {
output = "Error: " + output
}
if output != "" {
textContent += output
}
case part.ExecutableCode != nil:
// Executable code is rendered as a fenced code block tagged with its language.
codeContent := "```" + part.ExecutableCode.Language + "\n" + part.ExecutableCode.Code + "\n```"
textContent += codeContent
}
// Handle thought signature separately (not part of the switch since it can co-exist with other types).
// Function-call parts are excluded because their signature was already recorded above.
if len(part.ThoughtSignature) > 0 && part.FunctionCall == nil {
thoughtSig := base64.StdEncoding.EncodeToString(part.ThoughtSignature)
reasoningDetails = append(reasoningDetails, schemas.ChatReasoningDetails{
Index: len(reasoningDetails),
Type: schemas.BifrostReasoningDetailsTypeEncrypted,
Signature: &thoughtSig,
})
}
}
// Set text content if present
if textContent != "" {
delta.Content = &textContent
}
// Set reasoning details if present
if len(reasoningDetails) > 0 {
delta.ReasoningDetails = reasoningDetails
}
// Set tool calls if present, and remember on the shared state that this stream had them
if len(toolCalls) > 0 {
delta.ToolCalls = toolCalls
state.hadToolCalls = true
}
}
// Check if delta has any content - if not and it's not the last chunk, skip it
hasDeltaContent := delta.Role != nil || delta.Content != nil || len(delta.ToolCalls) > 0 || len(delta.ReasoningDetails) > 0
if !hasDeltaContent && !isLastChunk {
return nil, nil, false
}
// Build the choice; the finish reason is only reported on the last chunk.
var finishReason *string
if isLastChunk && candidate.FinishReason != "" {
reason := ConvertGeminiFinishReasonToBifrost(candidate.FinishReason)
// Gemini uses "STOP" for both text completions and tool call responses.
// Override to "tool_calls" when tool calls were seen in this stream for uniformity.
if (len(delta.ToolCalls) > 0 || state.hadToolCalls) && reason == "stop" {
reason = "tool_calls"
}
finishReason = &reason
}
choice := schemas.BifrostResponseChoice{
Index: int(candidate.Index),
FinishReason: finishReason,
LogProbs: ConvertGeminiLogprobsResultToBifrost(candidate.LogprobsResult),
ChatStreamResponseChoice: &schemas.ChatStreamResponseChoice{
Delta: delta,
},
}
streamResponse.Choices = []schemas.BifrostResponseChoice{choice}
// Add usage information if this is the last chunk
if isLastChunk && response.UsageMetadata != nil {
streamResponse.Usage = ConvertGeminiUsageMetadataToChatUsage(response.UsageMetadata)
}
return streamResponse, nil, isLastChunk
}
// isErrorFinishReason reports whether reason is one of the Gemini finish reasons
// that this package treats as a filtered or failed response rather than a normal
// completion.
func isErrorFinishReason(reason FinishReason) bool {
	switch reason {
	case FinishReasonSafety,
		FinishReasonRecitation,
		FinishReasonMalformedFunctionCall,
		FinishReasonBlocklist,
		FinishReasonProhibitedContent,
		FinishReasonSPII,
		FinishReasonImageSafety,
		FinishReasonUnexpectedToolCall,
		FinishReasonMissingThoughtSignature,
		FinishReasonMalformedResponse,
		FinishReasonImageProhibitedContent,
		FinishReasonImageRecitation,
		FinishReasonTooManyToolCalls,
		FinishReasonNoImage:
		return true
	default:
		return false
	}
}
// createErrorResponse builds a BifrostChatResponse with a single choice (index 0)
// that carries only finishReason.
//
// When isStream is true the choice holds an empty stream delta and the object type is
// "chat.completion.chunk"; otherwise it holds an assistant message with empty content
// and the object type is "chat.completion". ID, model, usage and (when set) the
// creation time are copied from response.
func createErrorResponse(response *GenerateContentResponse, finishReason string, isStream bool) *schemas.BifrostChatResponse {
	// Fields shared by both variants; the variant-specific payload is attached below.
	choice := schemas.BifrostResponseChoice{
		Index:        0,
		FinishReason: &finishReason,
	}

	objectType := "chat.completion"
	if isStream {
		objectType = "chat.completion.chunk"
		choice.ChatStreamResponseChoice = &schemas.ChatStreamResponseChoice{
			Delta: &schemas.ChatStreamResponseChoiceDelta{},
		}
	} else {
		choice.ChatNonStreamResponseChoice = &schemas.ChatNonStreamResponseChoice{
			Message: &schemas.ChatMessage{
				Role:    schemas.ChatMessageRoleAssistant,
				Content: &schemas.ChatMessageContent{},
			},
		}
	}

	result := &schemas.BifrostChatResponse{
		ID:      response.ResponseID,
		Model:   response.ModelVersion,
		Object:  objectType,
		Choices: []schemas.BifrostResponseChoice{choice},
		Usage:   ConvertGeminiUsageMetadataToChatUsage(response.UsageMetadata),
	}
	if created := response.CreateTime; !created.IsZero() {
		result.Created = int(created.Unix())
	}
	return result
}

View File

@@ -0,0 +1,80 @@
package gemini
import (
"strings"
"github.com/maximhq/bifrost/core/schemas"
)
// ToBifrostCountTokensResponse converts a Gemini count-tokens response into the
// Bifrost representation for the given model. A nil receiver yields nil.
//
// InputTokens is the sum of the per-modality prompt token counts. Entries whose
// modality name contains "audio" (case-insensitive) are also added to AudioTokens.
// CachedReadTokens comes from CachedContentTokenCount when that is non-zero; otherwise
// it is the sum of CacheTokensDetails, whose audio entries are added to AudioTokens too.
func (resp *GeminiCountTokensResponse) ToBifrostCountTokensResponse(model string) *schemas.BifrostCountTokensResponse {
	if resp == nil {
		return nil
	}

	isAudio := func(modality string) bool {
		return strings.Contains(strings.ToLower(modality), "audio")
	}

	details := &schemas.ResponsesResponseInputTokens{}

	// Prompt tokens: total plus the audio share.
	promptTotal := 0
	for _, entry := range resp.PromptTokensDetails {
		if entry == nil {
			continue
		}
		count := int(entry.TokenCount)
		promptTotal += count
		if isAudio(string(entry.Modality)) {
			details.AudioTokens += count
		}
	}

	// Cached tokens: the top-level count wins; the per-modality list is the fallback.
	switch {
	case resp.CachedContentTokenCount != 0:
		details.CachedReadTokens = int(resp.CachedContentTokenCount)
	case resp.CacheTokensDetails != nil:
		cachedTotal := 0
		for _, entry := range resp.CacheTokensDetails {
			if entry == nil {
				continue
			}
			count := int(entry.TokenCount)
			cachedTotal += count
			if isAudio(string(entry.Modality)) {
				// Cached audio tokens are added on top of the prompt audio tokens.
				details.AudioTokens += count
			}
		}
		details.CachedReadTokens = cachedTotal
	}

	totalTokens := int(resp.TotalTokens)
	return &schemas.BifrostCountTokensResponse{
		Model:              model,
		Object:             "response.input_tokens",
		InputTokens:        promptTotal,
		InputTokensDetails: details,
		TotalTokens:        &totalTokens,
		ExtraFields:        schemas.BifrostResponseExtraFields{},
	}
}
// ToGeminiCountTokensResponse converts a Bifrost count-tokens response to Gemini
// format. A nil input yields nil.
//
// TotalTokens is taken from the Bifrost TotalTokens field when it is set and positive,
// and falls back to InputTokens otherwise. Using InputTokens unconditionally dropped
// the total on a Gemini -> Bifrost -> Gemini round trip whenever the upstream response
// carried a total but no per-modality prompt details, because InputTokens is the sum
// of those details.
//
// CachedContentTokenCount is set only when the input reports a positive number of
// cached read tokens; otherwise it keeps its zero value.
func ToGeminiCountTokensResponse(bifrostResp *schemas.BifrostCountTokensResponse) *GeminiCountTokensResponse {
	if bifrostResp == nil {
		return nil
	}

	total := bifrostResp.InputTokens
	if bifrostResp.TotalTokens != nil && *bifrostResp.TotalTokens > 0 {
		total = *bifrostResp.TotalTokens
	}

	response := &GeminiCountTokensResponse{
		TotalTokens: int32(total),
	}

	// Map cached content token count if available.
	if details := bifrostResp.InputTokensDetails; details != nil && details.CachedReadTokens > 0 {
		response.CachedContentTokenCount = int32(details.CachedReadTokens)
	}
	return response
}

View File

@@ -0,0 +1,247 @@
package gemini
import (
"github.com/maximhq/bifrost/core/providers/utils"
"github.com/maximhq/bifrost/core/schemas"
)
// ToGeminiEmbeddingRequest converts a Bifrost embedding request into Gemini's batch
// embedding request, producing one GeminiEmbeddingRequest per input text.
//
// It returns nil when the request or its input is nil, when both Text and Texts are
// nil, or when no text is collected. Text (if set) comes first, followed by Texts.
//
// Params.Dimensions is applied to every request as OutputDimensionality. The
// "taskType" and "title" extra params are applied to every request and removed from
// the batch-level ExtraParams.
//
// ExtraParams is copied before those keys are removed. The batch used to share the
// caller's map, so deleting the keys while building the first request hid them from
// every later request in the same batch and also mutated the caller's request.
func ToGeminiEmbeddingRequest(bifrostReq *schemas.BifrostEmbeddingRequest) *GeminiBatchEmbeddingRequest {
	if bifrostReq == nil || bifrostReq.Input == nil || (bifrostReq.Input.Text == nil && bifrostReq.Input.Texts == nil) {
		return nil
	}
	embeddingInput := bifrostReq.Input

	// Collect all texts to embed.
	texts := make([]string, 0, len(embeddingInput.Texts)+1)
	if embeddingInput.Text != nil {
		texts = append(texts, *embeddingInput.Text)
	}
	texts = append(texts, embeddingInput.Texts...)
	if len(texts) == 0 {
		return nil
	}

	batchRequest := &GeminiBatchEmbeddingRequest{
		Requests: make([]GeminiEmbeddingRequest, len(texts)),
	}

	// Everything except the content is identical for each request, so resolve the
	// shared settings once and stamp them onto every entry.
	template := GeminiEmbeddingRequest{
		Model: "models/" + bifrostReq.Model,
	}
	if params := bifrostReq.Params; params != nil {
		if params.Dimensions != nil {
			template.OutputDimensionality = params.Dimensions
		}
		if params.ExtraParams != nil {
			// Work on a private copy so the caller's map is never modified.
			extra := make(map[string]interface{}, len(params.ExtraParams))
			for key, value := range params.ExtraParams {
				extra[key] = value
			}
			if taskType, ok := schemas.SafeExtractStringPointer(extra["taskType"]); ok {
				delete(extra, "taskType")
				template.TaskType = taskType
			}
			if title, ok := schemas.SafeExtractStringPointer(extra["title"]); ok {
				delete(extra, "title")
				template.Title = title
			}
			batchRequest.ExtraParams = extra
		}
	}

	// Create individual embedding requests for each text.
	for i, text := range texts {
		embeddingReq := template
		embeddingReq.Content = &Content{
			Parts: []*Part{
				{Text: text},
			},
		}
		batchRequest.Requests[i] = embeddingReq
	}
	return batchRequest
}
// ToGeminiEmbeddingResponse converts a Bifrost embedding response into Gemini's
// embedding response format. It returns nil when the input is nil or has no data.
//
// Each embedding's values are copied from EmbeddingArray when it is non-nil, otherwise
// from the first row of Embedding2DArray when one exists. When usage is present, its
// prompt token count is reported both as every embedding's statistics token count and
// as the response metadata's billable character count.
func ToGeminiEmbeddingResponse(bifrostResp *schemas.BifrostEmbeddingResponse) *GeminiEmbeddingResponse {
	if bifrostResp == nil || len(bifrostResp.Data) == 0 {
		return nil
	}

	usage := bifrostResp.Usage
	converted := &GeminiEmbeddingResponse{
		Embeddings: make([]GeminiEmbedding, len(bifrostResp.Data)),
	}

	for idx, item := range bifrostResp.Data {
		var vector []float64
		switch {
		case item.Embedding.EmbeddingArray != nil:
			vector = append([]float64(nil), item.Embedding.EmbeddingArray...)
		case len(item.Embedding.Embedding2DArray) > 0:
			// Only the first row of a 2D embedding is carried over.
			vector = append([]float64(nil), item.Embedding.Embedding2DArray[0]...)
		}

		out := GeminiEmbedding{Values: vector}
		if usage != nil {
			out.Statistics = &ContentEmbeddingStatistics{
				TokenCount: int32(usage.PromptTokens),
			}
		}
		converted.Embeddings[idx] = out
	}

	// Metadata mirrors the same count (for Vertex API compatibility).
	if usage != nil {
		converted.Metadata = &EmbedContentMetadata{
			BillableCharacterCount: int32(usage.PromptTokens),
		}
	}
	return converted
}
// ToBifrostEmbeddingResponse converts a Gemini embedding response into a
// BifrostEmbeddingResponse for the given model. It returns nil when the input is nil
// or contains no embeddings.
//
// Usage is populated only when the first embedding has statistics or the response has
// metadata. The first embedding's token count takes precedence over the metadata's
// billable character count, and total tokens always equals prompt tokens.
func ToBifrostEmbeddingResponse(geminiResp *GeminiEmbeddingResponse, model string) *schemas.BifrostEmbeddingResponse {
	if geminiResp == nil || len(geminiResp.Embeddings) == 0 {
		return nil
	}

	result := &schemas.BifrostEmbeddingResponse{
		Data:   make([]schemas.EmbeddingData, len(geminiResp.Embeddings)),
		Model:  model,
		Object: "list",
	}
	for idx := range geminiResp.Embeddings {
		result.Data[idx] = schemas.EmbeddingData{
			Index:  idx,
			Object: "embedding",
			Embedding: schemas.EmbeddingStruct{
				EmbeddingArray: geminiResp.Embeddings[idx].Values,
			},
		}
	}

	// Derive usage from the first embedding's statistics, falling back to metadata.
	firstStats := geminiResp.Embeddings[0].Statistics
	switch {
	case firstStats != nil:
		tokens := int(firstStats.TokenCount)
		result.Usage = &schemas.BifrostLLMUsage{PromptTokens: tokens, TotalTokens: tokens}
	case geminiResp.Metadata != nil:
		tokens := int(geminiResp.Metadata.BillableCharacterCount)
		result.Usage = &schemas.BifrostLLMUsage{PromptTokens: tokens, TotalTokens: tokens}
	}
	return result
}
// ToBifrostEmbeddingRequest converts a GeminiGenerationRequest into a
// BifrostEmbeddingRequest. A nil receiver yields nil.
//
// Input text is taken from the non-empty text parts of requests[] (SDK batch form).
// If that yields no input, the non-empty text parts of contents[] are used instead
// (generation-style form). One text is stored in Input.Text, several in Input.Texts,
// and Input stays nil when there is no text at all.
//
// Parameters are read from the first entry of requests[] only: OutputDimensionality
// becomes Params.Dimensions, while TaskType and Title are stored in Params.ExtraParams
// under "taskType" and "title".
func (request *GeminiGenerationRequest) ToBifrostEmbeddingRequest(ctx *schemas.BifrostContext) *schemas.BifrostEmbeddingRequest {
	if request == nil {
		return nil
	}

	// buildInput wraps collected texts: one goes in Text, several in Texts, none is nil.
	buildInput := func(texts []string) *schemas.EmbeddingInput {
		switch len(texts) {
		case 0:
			return nil
		case 1:
			return &schemas.EmbeddingInput{Text: &texts[0]}
		default:
			return &schemas.EmbeddingInput{Texts: texts}
		}
	}

	defaultProvider := utils.CheckAndSetDefaultProvider(ctx, schemas.Gemini)
	provider, model := schemas.ParseModelString(request.Model, defaultProvider)
	converted := &schemas.BifrostEmbeddingRequest{
		Provider:  provider,
		Model:     model,
		Fallbacks: schemas.ParseFallbacks(request.Fallbacks),
	}

	// SDK batch form: several embedding requests sharing parameters but not text.
	if len(request.Requests) > 0 {
		var collected []string
		for _, item := range request.Requests {
			if item.Content == nil {
				continue
			}
			for _, part := range item.Content.Parts {
				if part != nil && part.Text != "" {
					collected = append(collected, part.Text)
				}
			}
		}
		converted.Input = buildInput(collected)

		first := request.Requests[0]
		hasExtras := first.TaskType != nil || first.Title != nil
		if first.OutputDimensionality != nil || hasExtras {
			params := &schemas.EmbeddingParameters{}
			if first.OutputDimensionality != nil {
				params.Dimensions = first.OutputDimensionality
			}
			if hasExtras {
				extras := make(map[string]interface{})
				if first.TaskType != nil {
					extras["taskType"] = first.TaskType
				}
				if first.Title != nil {
					extras["title"] = first.Title
				}
				params.ExtraParams = extras
			}
			converted.Params = params
		}
	}

	// Generation-style requests (e.g., non-Imagen :predict) carry text in
	// contents[].parts[]; use them when requests[] produced no input.
	if converted.Input == nil {
		var collected []string
		for _, content := range request.Contents {
			for _, part := range content.Parts {
				if part != nil && part.Text != "" {
					collected = append(collected, part.Text)
				}
			}
		}
		converted.Input = buildInput(collected)
	}
	return converted
}

View File

@@ -0,0 +1,79 @@
package gemini
import (
"strconv"
"strings"
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
"github.com/maximhq/bifrost/core/schemas"
"github.com/valyala/fasthttp"
)
// ToGeminiError builds a GeminiGenerationError from a BifrostError. A nil input
// yields nil.
//
// The code defaults to 500 and is replaced by the Bifrost status code when one is set.
// The status is the Bifrost error type and the message is the Bifrost error message;
// both stay empty when the corresponding source field is missing.
func ToGeminiError(bifrostErr *schemas.BifrostError) *GeminiGenerationError {
	if bifrostErr == nil {
		return nil
	}

	details := &GeminiGenerationErrorStruct{Code: 500}
	if inner := bifrostErr.Error; inner != nil {
		if inner.Type != nil {
			details.Status = *inner.Type
		}
		details.Message = inner.Message
	}
	if bifrostErr.StatusCode != nil {
		details.Code = *bifrostErr.StatusCode
	}
	return &GeminiGenerationError{Error: details}
}
// parseGeminiError turns a Gemini error HTTP response into a BifrostError.
//
// The body is first decoded as a list of GeminiGenerationError. When that list is
// non-empty, the messages of all entries that carry an error are joined with newlines
// and the code of the first such entry is used. Otherwise the body is decoded as a
// single GeminiGenerationError and its code and message are used when present.
func parseGeminiError(resp *fasthttp.Response) *schemas.BifrostError {
	// Array form.
	var batch []GeminiGenerationError
	bifrostErr := providerUtils.HandleProviderAPIError(resp, &batch)
	if len(batch) > 0 {
		var primary *GeminiGenerationErrorStruct
		messages := make([]string, 0, len(batch))
		for i := range batch {
			inner := batch[i].Error
			if inner == nil {
				continue
			}
			if primary == nil {
				primary = inner
			}
			messages = append(messages, inner.Message)
		}

		if bifrostErr.Error == nil {
			bifrostErr.Error = &schemas.ErrorField{}
		}
		if primary != nil {
			bifrostErr.Error.Code = schemas.Ptr(strconv.Itoa(primary.Code))
		}
		bifrostErr.Error.Message = strings.Join(messages, "\n")
		return bifrostErr
	}

	// Single-object form.
	var single GeminiGenerationError
	bifrostErr = providerUtils.HandleProviderAPIError(resp, &single)
	if single.Error == nil {
		return bifrostErr
	}
	if bifrostErr.Error == nil {
		bifrostErr.Error = &schemas.ErrorField{}
	}
	bifrostErr.Error.Code = schemas.Ptr(strconv.Itoa(single.Error.Code))
	bifrostErr.Error.Message = single.Error.Message
	return bifrostErr
}

View File

@@ -0,0 +1,145 @@
package gemini
import (
"fmt"
"strings"
"time"
"github.com/maximhq/bifrost/core/schemas"
)
// Gemini Files API types
// The Gemini Files API allows uploading files for use with multimodal models.
// GeminiFileResponse represents a file object from Gemini's API.
// Sizes and timestamps are carried as strings, as shown by the field types below.
type GeminiFileResponse struct {
Name string `json:"name"` // Resource name (e.g., "files/abc123")
DisplayName string `json:"displayName"` // User-provided display name
MimeType string `json:"mimeType"` // MIME type of the file
SizeBytes string `json:"sizeBytes"` // Size in bytes (as string)
CreateTime string `json:"createTime"` // RFC3339 timestamp
UpdateTime string `json:"updateTime"` // RFC3339 timestamp
ExpirationTime string `json:"expirationTime,omitempty"` // RFC3339 timestamp when file will be deleted
SHA256Hash string `json:"sha256Hash"` // Base64 encoded SHA256 hash
URI string `json:"uri"` // URI for accessing the file
State string `json:"state"` // "PROCESSING", "ACTIVE", "FAILED"
VideoMetadata *GeminiFileVideoMetadata `json:"videoMetadata,omitempty"` // Video-specific metadata; omitted from JSON when nil
}
// GeminiFileVideoMetadata contains video-specific metadata attached to a
// GeminiFileResponse.
type GeminiFileVideoMetadata struct {
VideoDuration string `json:"videoDuration"` // Duration in seconds
}
// GeminiFileListResponse represents the response from listing files: one page of
// files plus an optional token for the next page.
type GeminiFileListResponse struct {
Files []GeminiFileResponse `json:"files"` // Files on this page
NextPageToken string `json:"nextPageToken,omitempty"` // Omitted from JSON when empty
}
// ToBifrostFileStatus maps a Gemini file state string to the Bifrost file status.
// "PROCESSING", "ACTIVE" and "FAILED" map to the processing, processed and error
// statuses; any other state is passed through lower-cased.
func ToBifrostFileStatus(state string) schemas.FileStatus {
	if state == "PROCESSING" {
		return schemas.FileStatusProcessing
	}
	if state == "ACTIVE" {
		return schemas.FileStatusProcessed
	}
	if state == "FAILED" {
		return schemas.FileStatusError
	}
	return schemas.FileStatus(strings.ToLower(state))
}
// ToGeminiFileListResponse converts a Bifrost file list response to Gemini format.
//
// A nil input yields nil instead of panicking on the nil dereference, matching the
// other converters in this package.
//
// For each file the update time falls back to the creation time when UpdatedAt is
// zero, and the expiration time is left empty when ExpiresAt is unset. The next page
// token is copied from After when that is set and non-empty.
func ToGeminiFileListResponse(resp *schemas.BifrostFileListResponse) *GeminiFileListResponse {
	if resp == nil {
		return nil
	}

	files := make([]GeminiFileResponse, len(resp.Data))
	for i, f := range resp.Data {
		updateAt := f.UpdatedAt
		if updateAt == 0 {
			updateAt = f.CreatedAt
		}
		files[i] = GeminiFileResponse{
			Name:           f.ID,
			DisplayName:    f.Filename,
			SizeBytes:      fmt.Sprintf("%d", f.Bytes),
			CreateTime:     formatGeminiTimestamp(f.CreatedAt),
			UpdateTime:     formatGeminiTimestamp(updateAt),
			State:          toGeminiFileState(f.Status),
			ExpirationTime: formatGeminiTimestamp(safeDerefInt64(f.ExpiresAt)),
		}
	}

	result := &GeminiFileListResponse{Files: files}
	if resp.After != nil && *resp.After != "" {
		result.NextPageToken = *resp.After
	}
	return result
}
// ToGeminiFileRetrieveResponse converts a Bifrost file retrieve response to Gemini
// format.
//
// A nil input yields nil instead of panicking on the nil dereference, matching the
// other converters in this package. The update time falls back to the creation time
// when UpdatedAt is zero, and the expiration time is left empty when ExpiresAt is unset.
func ToGeminiFileRetrieveResponse(resp *schemas.BifrostFileRetrieveResponse) *GeminiFileResponse {
	if resp == nil {
		return nil
	}

	updateAt := resp.UpdatedAt
	if updateAt == 0 {
		updateAt = resp.CreatedAt
	}
	return &GeminiFileResponse{
		Name:           resp.ID,
		DisplayName:    resp.Filename,
		SizeBytes:      fmt.Sprintf("%d", resp.Bytes),
		CreateTime:     formatGeminiTimestamp(resp.CreatedAt),
		UpdateTime:     formatGeminiTimestamp(updateAt),
		State:          toGeminiFileState(resp.Status),
		URI:            resp.StorageURI,
		ExpirationTime: formatGeminiTimestamp(safeDerefInt64(resp.ExpiresAt)),
	}
}
// toGeminiFileState maps a Bifrost file status to the Gemini state string. The
// processing, processed and error statuses become "PROCESSING", "ACTIVE" and
// "FAILED"; any other status is passed through upper-cased.
func toGeminiFileState(status schemas.FileStatus) string {
	if status == schemas.FileStatusProcessing {
		return "PROCESSING"
	}
	if status == schemas.FileStatusProcessed {
		return "ACTIVE"
	}
	if status == schemas.FileStatusError {
		return "FAILED"
	}
	return strings.ToUpper(string(status))
}
// formatGeminiTimestamp renders a Unix timestamp (seconds) as an RFC3339 string in
// UTC. A zero timestamp is treated as "not set" and yields the empty string.
func formatGeminiTimestamp(unixTime int64) string {
	var formatted string
	if unixTime != 0 {
		formatted = time.Unix(unixTime, 0).UTC().Format(time.RFC3339)
	}
	return formatted
}
// safeDerefInt64 returns the value ptr points to, or 0 when ptr is nil.
func safeDerefInt64(ptr *int64) int64 {
	var value int64
	if ptr != nil {
		value = *ptr
	}
	return value
}
// ToGeminiFileUploadResponse converts a Bifrost file upload response to Gemini
// format: a map with a single "file" entry holding the file object.
//
// A nil input yields a nil map instead of panicking on the nil dereference. The MIME
// type is always reported as "application/octet-stream", the update time equals the
// creation time, and "expirationTime" is included only when ExpiresAt is set to a
// non-zero value.
func ToGeminiFileUploadResponse(resp *schemas.BifrostFileUploadResponse) map[string]interface{} {
	if resp == nil {
		return nil
	}

	createdAt := formatGeminiTimestamp(resp.CreatedAt)
	file := map[string]interface{}{
		"name":        resp.ID,
		"displayName": resp.Filename,
		"mimeType":    "application/octet-stream",
		"sizeBytes":   fmt.Sprintf("%d", resp.Bytes),
		"createTime":  createdAt,
		"updateTime":  createdAt,
		"state":       toGeminiFileState(resp.Status),
		"uri":         resp.StorageURI,
	}
	if exp := formatGeminiTimestamp(safeDerefInt64(resp.ExpiresAt)); exp != "" {
		file["expirationTime"] = exp
	}
	return map[string]interface{}{
		"file": file,
	}
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,62 @@
package gemini
import (
"bufio"
"bytes"
"compress/gzip"
"io"
"testing"
)
// TestReadNextSSEDataLine_SkipInlineDataOnGzipReader feeds readNextSSEDataLine a
// gzip-compressed stream of two SSE events, the first carrying inlineData, and
// expects the second (text) event's payload back when the skip flag is true.
func TestReadNextSSEDataLine_SkipInlineDataOnGzipReader(t *testing.T) {
	const (
		inlineEvent = "data: {\"candidates\":[{\"content\":{\"parts\":[{\"inlineData\":{\"data\":\"abc\"}}]}}]}\n"
		textEvent   = "data: {\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"ok\"}]}}]}\n"
	)

	// Compress both events into an in-memory gzip stream.
	var gzipped bytes.Buffer
	zw := gzip.NewWriter(&gzipped)
	if _, err := zw.Write([]byte(inlineEvent + textEvent)); err != nil {
		t.Fatalf("failed to write gzip payload: %v", err)
	}
	if err := zw.Close(); err != nil {
		t.Fatalf("failed to close gzip writer: %v", err)
	}

	zr, err := gzip.NewReader(bytes.NewReader(gzipped.Bytes()))
	if err != nil {
		t.Fatalf("failed to create gzip reader: %v", err)
	}
	defer zr.Close()

	buffered := bufio.NewReaderSize(zr, 64*1024)
	got, err := readNextSSEDataLine(buffered, true)
	if err != nil {
		t.Fatalf("expected next non-inline line, got error: %v", err)
	}

	expected := []byte(`{"candidates":[{"content":{"parts":[{"text":"ok"}]}}]}`)
	if !bytes.Equal(got, expected) {
		t.Fatalf("expected %q, got %q", string(expected), string(got))
	}
}
// TestReadNextSSEDataLine_SkipInlineDataContinuedLine checks that an inlineData event
// longer than the 64 KiB reader buffer is skipped and the following text event is
// returned, and that an empty reader reports io.EOF.
func TestReadNextSSEDataLine_SkipInlineDataContinuedLine(t *testing.T) {
	// 70 KiB of inline data, so the line cannot fit in a single 64 KiB buffer.
	var sse bytes.Buffer
	sse.WriteString("data: {\"candidates\":[{\"content\":{\"parts\":[{\"inlineData\":{\"data\":\"")
	sse.Write(bytes.Repeat([]byte("x"), 70*1024))
	sse.WriteString("\"}}]}}]}\n")
	sse.WriteString("data: {\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"ok\"}]}}]}\n")

	buffered := bufio.NewReaderSize(bytes.NewReader(sse.Bytes()), 64*1024)
	got, err := readNextSSEDataLine(buffered, true)
	if err != nil {
		t.Fatalf("expected next non-inline line, got error: %v", err)
	}
	expected := []byte(`{"candidates":[{"content":{"parts":[{"text":"ok"}]}}]}`)
	if !bytes.Equal(got, expected) {
		t.Fatalf("expected %q, got %q", string(expected), string(got))
	}

	emptyReader := bufio.NewReaderSize(bytes.NewReader(nil), 64*1024)
	if _, err = readNextSSEDataLine(emptyReader, true); err != io.EOF {
		t.Fatalf("expected EOF on empty reader, got %v", err)
	}
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,61 @@
package gemini
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"github.com/maximhq/bifrost/core/schemas"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// testNoopLogger is a logger stub for tests: every method below has an empty body or
// returns schemas.NoopLogEvent, so nothing is logged.
type testNoopLogger struct{}
// Debug discards the message.
func (testNoopLogger) Debug(string, ...any) {}
// Info discards the message.
func (testNoopLogger) Info(string, ...any) {}
// Warn discards the message.
func (testNoopLogger) Warn(string, ...any) {}
// Error discards the message.
func (testNoopLogger) Error(string, ...any) {}
// Fatal discards the message; the body is empty, so it does not terminate the process.
func (testNoopLogger) Fatal(string, ...any) {}
// SetLevel ignores the requested level.
func (testNoopLogger) SetLevel(schemas.LogLevel) {}
// SetOutputType ignores the requested output type.
func (testNoopLogger) SetOutputType(schemas.LoggerOutputType) {}
// LogHTTPRequest returns schemas.NoopLogEvent regardless of its arguments.
func (testNoopLogger) LogHTTPRequest(schemas.LogLevel, string) schemas.LogEventBuilder {
return schemas.NoopLogEvent
}
func TestListModelsByKey_ParsesSingleModelPayload(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
http.Error(w, "unexpected method", http.StatusMethodNotAllowed)
return
}
if r.URL.Path != "/models/gemini-2.5-pro" {
http.Error(w, "unexpected path", http.StatusNotFound)
return
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(`{"name":"models/gemini-2.5-pro","displayName":"Gemini 2.5 Pro","description":"test","inputTokenLimit":1048576,"outputTokenLimit":8192,"supportedGenerationMethods":["generateContent"]}`))
}))
defer ts.Close()
provider := NewGeminiProvider(&schemas.ProviderConfig{
NetworkConfig: schemas.NetworkConfig{BaseURL: ts.URL},
}, testNoopLogger{})
ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
ctx.SetValue(schemas.BifrostContextKeyURLPath, "/models/gemini-2.5-pro")
key := schemas.Key{Value: *schemas.NewEnvVar("dummy-key")}
// Unfiltered=true bypasses the allowed/alias/blacklist filter pipeline so
// this test can focus on the single-model-payload parsing code path in
// listModelsByKey (gemini.go:215-220).
resp, err := provider.listModelsByKey(ctx, key, &schemas.BifrostListModelsRequest{Provider: schemas.Gemini, Unfiltered: true})
require.Nil(t, err)
require.NotNil(t, resp)
require.Len(t, resp.Data, 1)
assert.Equal(t, "gemini/gemini-2.5-pro", resp.Data[0].ID)
require.NotNil(t, resp.Data[0].Name)
assert.Equal(t, "Gemini 2.5 Pro", *resp.Data[0].Name)
}

View File

@@ -0,0 +1,105 @@
package gemini
import (
"strings"
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
"github.com/maximhq/bifrost/core/schemas"
)
// toGeminiModelResourceName converts a model ID into Gemini's "models/<name>"
// resource form. IDs already starting with "models/" are returned unchanged. For
// an ID of the form "<prefix>/<rest>" with a non-empty rest, everything up to and
// including the first slash is replaced by "models/". Anything else is simply
// prefixed with "models/".
func toGeminiModelResourceName(modelID string) string {
	const resourcePrefix = "models/"
	if strings.HasPrefix(modelID, resourcePrefix) {
		return modelID
	}
	if _, rest, found := strings.Cut(modelID, "/"); found && rest != "" {
		return resourcePrefix + rest
	}
	return resourcePrefix + modelID
}
// ToBifrostListModelsResponse converts a Gemini list-models response into the Bifrost
// format, passing every model through the list-models filter pipeline built from
// allowedModels, blacklistedModels, aliases and unfiltered. A nil receiver yields nil.
//
// Model names have their "models/" prefix stripped before filtering. Each pipeline
// result becomes one entry whose ID is "<providerKey>/<resolved id>" and whose context
// length is the sum of the input and output token limits. After all models are
// processed, the pipeline's backfill entries are appended.
func (response *GeminiListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, allowedModels schemas.WhiteList, blacklistedModels schemas.BlackList, aliases map[string]string, unfiltered bool) *schemas.BifrostListModelsResponse {
	if response == nil {
		return nil
	}

	converted := &schemas.BifrostListModelsResponse{
		Data: make([]schemas.Model, 0, len(response.Models)),
	}
	filter := &providerUtils.ListModelsPipeline{
		AllowedModels:     allowedModels,
		BlacklistedModels: blacklistedModels,
		Aliases:           aliases,
		Unfiltered:        unfiltered,
		ProviderKey:       providerKey,
		MatchFns:          providerUtils.DefaultMatchFns(),
	}
	if filter.ShouldEarlyExit() {
		return converted
	}

	// Lower-cased resolved IDs that made it into the output; handed to BackfillModels.
	seen := make(map[string]bool)
	for _, geminiModel := range response.Models {
		// Gemini returns names with a "models/" prefix; strip it so that filter
		// entries such as "gemini-1.5-pro" match.
		bareName := strings.TrimPrefix(geminiModel.Name, "models/")
		totalContext := int(geminiModel.InputTokenLimit + geminiModel.OutputTokenLimit)

		for _, match := range filter.FilterModel(bareName) {
			item := schemas.Model{
				ID:               string(providerKey) + "/" + match.ResolvedID,
				Name:             schemas.Ptr(geminiModel.DisplayName),
				Description:      schemas.Ptr(geminiModel.Description),
				ContextLength:    schemas.Ptr(totalContext),
				MaxInputTokens:   schemas.Ptr(geminiModel.InputTokenLimit),
				MaxOutputTokens:  schemas.Ptr(geminiModel.OutputTokenLimit),
				SupportedMethods: geminiModel.SupportedGenerationMethods,
			}
			if match.AliasValue != "" {
				item.Alias = schemas.Ptr(match.AliasValue)
			}
			converted.Data = append(converted.Data, item)
			seen[strings.ToLower(match.ResolvedID)] = true
		}
	}

	converted.Data = append(converted.Data, filter.BackfillModels(seen)...)
	return converted
}
// ToGeminiListModelsResponse converts a Bifrost list-models response into Gemini's
// format. A nil input yields nil.
//
// Each model ID is rewritten to Gemini's "models/<name>" resource form. Display name,
// description and token limits are copied only when set on the Bifrost model, and the
// next page token is passed through.
func ToGeminiListModelsResponse(resp *schemas.BifrostListModelsResponse) *GeminiListModelsResponse {
	if resp == nil {
		return nil
	}

	models := make([]GeminiModel, len(resp.Data))
	for i, src := range resp.Data {
		dst := &models[i]
		dst.Name = toGeminiModelResourceName(src.ID)
		dst.SupportedGenerationMethods = src.SupportedMethods
		if src.Name != nil {
			dst.DisplayName = *src.Name
		}
		if src.Description != nil {
			dst.Description = *src.Description
		}
		if src.MaxInputTokens != nil {
			dst.InputTokenLimit = *src.MaxInputTokens
		}
		if src.MaxOutputTokens != nil {
			dst.OutputTokenLimit = *src.MaxOutputTokens
		}
	}

	return &GeminiListModelsResponse{
		Models:        models,
		NextPageToken: resp.NextPageToken,
	}
}

View File

@@ -0,0 +1,41 @@
package gemini
import (
"testing"
"github.com/maximhq/bifrost/core/schemas"
"github.com/stretchr/testify/assert"
)
func TestToGeminiModelResourceName(t *testing.T) {
tests := []struct {
name string
input string
want string
}{
{name: "already native", input: "models/gemini-2.5-pro", want: "models/gemini-2.5-pro"},
{name: "provider prefixed", input: "gemini/gemini-2.5-pro", want: "models/gemini-2.5-pro"},
{name: "bare model", input: "gemini-2.5-pro", want: "models/gemini-2.5-pro"},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
assert.Equal(t, tc.want, toGeminiModelResourceName(tc.input))
})
}
}
// TestToGeminiListModelsResponse_UsesNativeModelResourceName verifies that the
// converted response reports every model name in the "models/<id>" form,
// whether the Bifrost ID was provider-prefixed or already native.
func TestToGeminiListModelsResponse_UsesNativeModelResourceName(t *testing.T) {
	input := &schemas.BifrostListModelsResponse{
		Data: []schemas.Model{
			{ID: "gemini/gemini-2.5-pro"},
			{ID: "models/gemini-2.5-flash"},
		},
	}
	wantNames := []string{
		"models/gemini-2.5-pro",
		"models/gemini-2.5-flash",
	}

	got := ToGeminiListModelsResponse(input)

	// Skip the per-element checks when the length is wrong to avoid indexing out of range.
	if !assert.Len(t, got.Models, len(wantNames)) {
		return
	}
	for i, want := range wantNames {
		assert.Equal(t, want, got.Models[i].Name)
	}
}

View File

@@ -0,0 +1,56 @@
package gemini
import (
"testing"
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
schemas "github.com/maximhq/bifrost/core/schemas"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestPayloadOrdering_GeminiGenerationRequest pins the exact bytes produced by
// providerUtils.MarshalSorted for a representative GeminiGenerationRequest
// (one user message, a temperature, and one function tool) and then checks
// that repeated marshaling of the same value yields identical output.
func TestPayloadOrdering_GeminiGenerationRequest(t *testing.T) {
// Fixture: the smallest request that exercises contents, generation config,
// and a tool schema containing a map (Properties).
req := &GeminiGenerationRequest{
Model: "gemini-2.0-flash",
Contents: []Content{
{
Parts: []*Part{{Text: "hello"}},
Role: "user",
},
},
GenerationConfig: GenerationConfig{
Temperature: schemas.Ptr(float64(0.7)),
},
Tools: []Tool{
{
FunctionDeclarations: []*FunctionDeclaration{
{
Name: "get_weather",
Description: "Get weather",
Parameters: &Schema{
Type: "OBJECT",
Properties: map[string]*Schema{
"location": {Type: "STRING"},
},
Required: []string{"location"},
},
},
},
},
},
}
result, err := providerUtils.MarshalSorted(req)
require.NoError(t, err)
// Golden payload: any change to the serialized bytes fails the assertion below.
golden := `{"model":"gemini-2.0-flash","contents":[{"parts":[{"text":"hello"}],"role":"user"}],"generationConfig":{"temperature":0.7},"tools":[{"functionDeclarations":[{"description":"Get weather","name":"get_weather","parameters":{"properties":{"location":{"type":"STRING"}},"required":["location"],"type":"OBJECT"}}]}]}`
assert.Equal(t, golden, string(result), "payload field ordering changed — if intentional, update the golden string")
// Determinism: 100 iterations must produce identical bytes
for i := 0; i < 100; i++ {
iter, err := providerUtils.MarshalSorted(req)
require.NoError(t, err)
assert.Equal(t, string(result), string(iter), "non-deterministic marshal output on iteration %d", i)
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,200 @@
package gemini
import (
"context"
"fmt"
"strings"
"github.com/maximhq/bifrost/core/providers/utils"
"github.com/maximhq/bifrost/core/schemas"
)
// ToBifrostSpeechRequest converts a GeminiGenerationRequest to a BifrostSpeechRequest.
//
// The speech input is the concatenation of every text part across all
// contents, in order. Params is populated only when the request carries a
// speech config or at least one response modality:
//   - a single-speaker voice config becomes VoiceConfig.Voice,
//   - otherwise a multi-speaker config becomes VoiceConfig.MultiVoiceConfig,
//   - response modalities are stored as strings under
//     ExtraParams["response_modalities"].
func (request *GeminiGenerationRequest) ToBifrostSpeechRequest(ctx *schemas.BifrostContext) *schemas.BifrostSpeechRequest {
	provider, model := schemas.ParseModelString(request.Model, utils.CheckAndSetDefaultProvider(ctx, schemas.Gemini))

	// Gather the text of every part; empty parts contribute nothing.
	var text strings.Builder
	for _, content := range request.Contents {
		for _, part := range content.Parts {
			text.WriteString(part.Text)
		}
	}

	bifrostReq := &schemas.BifrostSpeechRequest{
		Provider: provider,
		Model:    model,
		Input:    &schemas.SpeechInput{Input: text.String()},
	}

	genCfg := &request.GenerationConfig
	speechCfg := genCfg.SpeechConfig
	if speechCfg == nil && len(genCfg.ResponseModalities) == 0 {
		// Nothing to translate into parameters.
		return bifrostReq
	}

	params := &schemas.SpeechParameters{}
	bifrostReq.Params = params

	if speechCfg != nil {
		switch {
		case speechCfg.VoiceConfig != nil:
			// Single-speaker: VoiceConfig is set even when no prebuilt voice is named.
			voiceInput := &schemas.SpeechVoiceInput{}
			if prebuilt := speechCfg.VoiceConfig.PrebuiltVoiceConfig; prebuilt != nil {
				voiceName := prebuilt.VoiceName
				voiceInput.Voice = &voiceName
			}
			params.VoiceConfig = voiceInput

		case speechCfg.MultiSpeakerVoiceConfig != nil:
			// Multi-speaker: keep only speakers that name a prebuilt voice.
			speakers := speechCfg.MultiSpeakerVoiceConfig.SpeakerVoiceConfigs
			if len(speakers) > 0 {
				voices := make([]schemas.VoiceConfig, 0, len(speakers))
				for _, speaker := range speakers {
					if speaker.VoiceConfig == nil || speaker.VoiceConfig.PrebuiltVoiceConfig == nil {
						continue
					}
					voices = append(voices, schemas.VoiceConfig{
						Speaker: speaker.Speaker,
						Voice:   speaker.VoiceConfig.PrebuiltVoiceConfig.VoiceName,
					})
				}
				params.VoiceConfig = &schemas.SpeechVoiceInput{}
				params.VoiceConfig.MultiVoiceConfig = voices
			}
		}
	}

	// Carry the requested response modalities along as plain strings.
	if count := len(genCfg.ResponseModalities); count > 0 {
		if params.ExtraParams == nil {
			params.ExtraParams = make(map[string]interface{})
		}
		modalities := make([]string, 0, count)
		for _, modality := range genCfg.ResponseModalities {
			modalities = append(modalities, string(modality))
		}
		params.ExtraParams["response_modalities"] = modalities
	}

	return bifrostReq
}
// ToGeminiSpeechRequest converts a BifrostSpeechRequest to a GeminiGenerationRequest.
//
// It returns an error for a nil request or for any response_format other than
// "wav" or the empty string. The resulting request always asks for the audio
// modality. Contents, speech config, and ExtraParams are filled in only when
// the request has non-empty input text; ExtraParams is additionally forwarded
// only when Params.VoiceConfig is non-nil.
func ToGeminiSpeechRequest(bifrostReq *schemas.BifrostSpeechRequest) (*GeminiGenerationRequest, error) {
	if bifrostReq == nil {
		return nil, fmt.Errorf("bifrostReq is nil")
	}

	params := bifrostReq.Params

	// Reject every response format except "wav" and the unset (empty) value.
	if params != nil {
		switch params.ResponseFormat {
		case "", "wav":
			// accepted
		default:
			return nil, fmt.Errorf("gemini does not support response_format: %s. Only wav or empty string is supported which defaults to wav", params.ResponseFormat)
		}
	}

	geminiReq := &GeminiGenerationRequest{Model: bifrostReq.Model}
	geminiReq.GenerationConfig.ResponseModalities = []Modality{ModalityAudio}

	// Without input text there is nothing more to translate.
	if bifrostReq.Input == nil || bifrostReq.Input.Input == "" {
		return geminiReq, nil
	}
	geminiReq.Contents = []Content{
		{Parts: []*Part{{Text: bifrostReq.Input.Input}}},
	}

	if params == nil || params.VoiceConfig == nil {
		return geminiReq, nil
	}

	// Covers both single-voice and multi-voice configurations.
	voice := params.VoiceConfig
	if voice.Voice != nil || len(voice.MultiVoiceConfig) > 0 {
		addSpeechConfigToGenerationConfig(&geminiReq.GenerationConfig, voice)
	}
	geminiReq.ExtraParams = params.ExtraParams

	return geminiReq, nil
}
// ToBifrostSpeechResponse converts a GenerateContentResponse to a BifrostSpeechResponse.
//
// Only the first candidate is inspected. The base64 payloads of all inline
// parts whose MIME type starts with "audio/" are decoded and concatenated.
// When the context value stored under BifrostContextKeyResponseFormat is the
// string "wav", the concatenated PCM bytes are wrapped into a WAV container;
// in every other case (including a missing or non-string context value) the
// raw decoded bytes are returned unchanged.
//
// Usage is populated only when the first candidate has at least one part and
// the response carries usage metadata.
//
// An error is returned when a payload cannot be base64-decoded or the WAV
// conversion fails.
func (response *GenerateContentResponse) ToBifrostSpeechResponse(ctx context.Context) (*schemas.BifrostSpeechResponse, error) {
	bifrostResp := &schemas.BifrostSpeechResponse{}

	if len(response.Candidates) == 0 {
		return bifrostResp, nil
	}
	candidate := response.Candidates[0]
	if candidate == nil || candidate.Content == nil || len(candidate.Content.Parts) == 0 {
		return bifrostResp, nil
	}

	// Extract audio data from all parts.
	var audioData []byte
	for _, part := range candidate.Content.Parts {
		if part == nil || part.InlineData == nil || len(part.InlineData.Data) == 0 {
			continue
		}
		if !strings.HasPrefix(part.InlineData.MIMEType, "audio/") {
			continue
		}
		decodedData, err := decodeBase64StringToBytes(part.InlineData.Data)
		if err != nil {
			return nil, fmt.Errorf("failed to decode base64 audio data: %w", err)
		}
		audioData = append(audioData, decodedData...)
	}

	if len(audioData) > 0 {
		// Comma-ok assertion: a missing or non-string context value must not
		// panic; it simply leaves responseFormat empty.
		var responseFormat string
		if ctx != nil {
			responseFormat, _ = ctx.Value(BifrostContextKeyResponseFormat).(string)
		}

		// Gemini returns PCM audio (s16le, 24000 Hz, mono)
		// Convert to WAV for standard playable output format
		if responseFormat == "wav" {
			wavData, err := utils.ConvertPCMToWAV(audioData, utils.DefaultGeminiPCMConfig())
			if err != nil {
				return nil, fmt.Errorf("failed to convert PCM to WAV: %w", err)
			}
			bifrostResp.Audio = wavData
		} else {
			bifrostResp.Audio = audioData
		}
	}

	// Set usage information.
	if response.UsageMetadata != nil {
		bifrostResp.Usage = convertGeminiUsageMetadataToSpeechUsage(response.UsageMetadata)
	}

	return bifrostResp, nil
}
// ToGeminiSpeechResponse converts a BifrostSpeechResponse to Gemini's GenerateContentResponse.
//
// The audio bytes are base64-encoded into a single inline-data part whose MIME
// type comes from utils.DetectAudioMimeType, wrapped in one candidate with the
// model role. Usage metadata is attached only when the Bifrost response has
// usage. A nil input yields a nil result.
func ToGeminiSpeechResponse(bifrostResp *schemas.BifrostSpeechResponse) *GenerateContentResponse {
	if bifrostResp == nil {
		return nil
	}

	audioPart := &Part{
		InlineData: &Blob{
			Data:     encodeBytesToBase64String(bifrostResp.Audio),
			MIMEType: utils.DetectAudioMimeType(bifrostResp.Audio),
		},
	}

	out := &GenerateContentResponse{
		Candidates: []*Candidate{
			{
				Content: &Content{
					Parts: []*Part{audioPart},
					Role:  string(RoleModel),
				},
			},
		},
	}

	if usage := bifrostResp.Usage; usage != nil {
		out.UsageMetadata = convertBifrostSpeechUsageToGeminiUsageMetadata(usage)
	}
	return out
}

View File

@@ -0,0 +1,233 @@
package gemini
import (
"fmt"
"strings"
"github.com/maximhq/bifrost/core/providers/utils"
"github.com/maximhq/bifrost/core/schemas"
)
// ToBifrostTranscriptionRequest converts a GeminiGenerationRequest to a BifrostTranscriptionRequest.
//
// Across all parts of all contents:
//   - non-empty text parts are joined with single spaces into Params.Prompt,
//   - inline audio payloads (MIME type starting with "audio/", case-insensitive)
//     are base64-decoded and concatenated into Input.File,
//   - audio file references are not fetched; their URI is recorded under
//     ExtraParams["file_uri"] (a later reference overwrites an earlier one).
//
// Safety settings, cached content, and labels are copied into ExtraParams when
// present. Params is always non-nil on success. An error is returned when an
// inline audio payload cannot be base64-decoded.
func (request *GeminiGenerationRequest) ToBifrostTranscriptionRequest(ctx *schemas.BifrostContext) (*schemas.BifrostTranscriptionRequest, error) {
	provider, model := schemas.ParseModelString(request.Model, utils.CheckAndSetDefaultProvider(ctx, schemas.Gemini))
	bifrostReq := &schemas.BifrostTranscriptionRequest{
		Provider: provider,
		Model:    model,
	}

	// setExtra stores one extra parameter, allocating Params and ExtraParams on first use.
	setExtra := func(key string, value interface{}) {
		if bifrostReq.Params == nil {
			bifrostReq.Params = &schemas.TranscriptionParameters{}
		}
		if bifrostReq.Params.ExtraParams == nil {
			bifrostReq.Params.ExtraParams = make(map[string]interface{})
		}
		bifrostReq.Params.ExtraParams[key] = value
	}
	isAudio := func(mimeType string) bool {
		return strings.HasPrefix(strings.ToLower(mimeType), "audio/")
	}

	var (
		promptPieces []string
		audioData    []byte
		// audioMimeType remembers the first audio MIME type seen; nothing in
		// this function reads it after the loop.
		audioMimeType string
	)

	for _, content := range request.Contents {
		for _, part := range content.Parts {
			if part.Text != "" {
				promptPieces = append(promptPieces, part.Text)
			}

			// Inline audio: decode and append to the accumulated payload.
			if inline := part.InlineData; inline != nil && isAudio(inline.MIMEType) {
				decoded, err := decodeBase64StringToBytes(inline.Data)
				if err != nil {
					return nil, fmt.Errorf("failed to decode base64 audio data: %v", err)
				}
				audioData = append(audioData, decoded...)
				if audioMimeType == "" {
					audioMimeType = inline.MIMEType
				}
			}

			// Referenced audio file: only the URI is noted, the file is not fetched here.
			if file := part.FileData; file != nil && isAudio(file.MIMEType) {
				setExtra("file_uri", file.FileURI)
				if audioMimeType == "" {
					audioMimeType = file.MIMEType
				}
			}
		}
	}

	bifrostReq.Input = &schemas.TranscriptionInput{File: audioData}

	if bifrostReq.Params == nil {
		bifrostReq.Params = &schemas.TranscriptionParameters{}
	}

	if len(promptPieces) > 0 {
		promptText := strings.Join(promptPieces, " ")
		bifrostReq.Params.Prompt = &promptText
	}

	if len(request.SafetySettings) > 0 {
		setExtra("safety_settings", request.SafetySettings)
	}
	if request.CachedContent != "" {
		setExtra("cached_content", request.CachedContent)
	}
	if len(request.Labels) > 0 {
		setExtra("labels", request.Labels)
	}

	return bifrostReq, nil
}
// ToGeminiTranscriptionRequest converts a BifrostTranscriptionRequest to a
// GeminiGenerationRequest. A nil input yields a nil result.
//
// The request contains a single content whose first part is the prompt
// (Params.Prompt when set, otherwise a default transcription instruction) and
// whose optional second part is the audio file as base64 inline data. A nil
// Input is treated the same as an empty file: no audio part is added.
//
// The "safety_settings", "cached_content", and "labels" extra parameters are
// lifted into their dedicated request fields. Note that geminiReq.ExtraParams
// is the same map as bifrostReq.Params.ExtraParams, so removing those keys also
// removes them from the caller's map.
func ToGeminiTranscriptionRequest(bifrostReq *schemas.BifrostTranscriptionRequest) *GeminiGenerationRequest {
	if bifrostReq == nil {
		return nil
	}

	// Create the base Gemini generation request.
	geminiReq := &GeminiGenerationRequest{
		Model: bifrostReq.Model,
	}

	// Convert parameters to generation config.
	if bifrostReq.Params != nil {
		geminiReq.ExtraParams = bifrostReq.Params.ExtraParams

		if bifrostReq.Params.ExtraParams != nil {
			// Safety settings: the key is removed even if the value cannot be parsed.
			if safetySettings, ok := schemas.SafeExtractFromMap(bifrostReq.Params.ExtraParams, "safety_settings"); ok {
				delete(geminiReq.ExtraParams, "safety_settings")
				if settings, ok := SafeExtractSafetySettings(safetySettings); ok {
					geminiReq.SafetySettings = settings
				}
			}

			// Cached content.
			if cachedContent, ok := schemas.SafeExtractString(bifrostReq.Params.ExtraParams["cached_content"]); ok {
				delete(geminiReq.ExtraParams, "cached_content")
				geminiReq.CachedContent = cachedContent
			}

			// Labels: the key is removed only when the value converts to a string map.
			if labels, ok := schemas.SafeExtractFromMap(bifrostReq.Params.ExtraParams, "labels"); ok {
				if labelMap, ok := schemas.SafeExtractStringMap(labels); ok {
					delete(geminiReq.ExtraParams, "labels")
					geminiReq.Labels = labelMap
				}
			}
		}
	}

	// Determine the prompt text.
	prompt := "Generate a transcript of the speech."
	if bifrostReq.Params != nil && bifrostReq.Params.Prompt != nil {
		prompt = *bifrostReq.Params.Prompt
	}

	// Create parts for the transcription request.
	parts := []*Part{
		{Text: prompt},
	}

	// Add the audio file if present; guard against a nil Input to avoid a nil dereference.
	if bifrostReq.Input != nil && len(bifrostReq.Input.File) > 0 {
		parts = append(parts, &Part{
			InlineData: &Blob{
				MIMEType: utils.DetectAudioMimeType(bifrostReq.Input.File),
				Data:     encodeBytesToBase64String(bifrostReq.Input.File),
			},
		})
	}

	geminiReq.Contents = []Content{
		{Parts: parts},
	}

	return geminiReq
}
// ToBifrostTranscriptionResponse converts a GenerateContentResponse to a BifrostTranscriptionResponse.
//
// Only the first candidate is read. The text of all its parts is concatenated
// in order; when the result is non-empty the response gets that text, the
// "transcribe" task, and usage derived from the response's usage metadata.
// Otherwise an empty response is returned.
func (response *GenerateContentResponse) ToBifrostTranscriptionResponse() *schemas.BifrostTranscriptionResponse {
	bifrostResp := &schemas.BifrostTranscriptionResponse{}

	if len(response.Candidates) == 0 {
		return bifrostResp
	}
	content := response.Candidates[0].Content
	if content == nil || len(content.Parts) == 0 {
		return bifrostResp
	}

	// Stitch the text parts together; empty parts contribute nothing.
	var transcript strings.Builder
	for _, part := range content.Parts {
		transcript.WriteString(part.Text)
	}
	if transcript.Len() == 0 {
		return bifrostResp
	}

	bifrostResp.Text = transcript.String()
	bifrostResp.Task = schemas.Ptr("transcribe")
	// Usage carries the modality details from the Gemini metadata.
	bifrostResp.Usage = convertGeminiUsageMetadataToTranscriptionUsage(response.UsageMetadata)
	return bifrostResp
}
// ToGeminiTranscriptionResponse converts a BifrostTranscriptionResponse to Gemini's GenerateContentResponse.
//
// The transcript text becomes the single part of a single model-role
// candidate, and usage metadata is derived from the Bifrost usage value
// (which is passed through as-is, even when nil). A nil input yields nil.
func ToGeminiTranscriptionResponse(bifrostResp *schemas.BifrostTranscriptionResponse) *GenerateContentResponse {
	if bifrostResp == nil {
		return nil
	}

	transcript := &Content{
		Parts: []*Part{{Text: bifrostResp.Text}},
		Role:  string(RoleModel),
	}

	return &GenerateContentResponse{
		UsageMetadata: convertBifrostTranscriptionUsageToGeminiUsageMetadata(bifrostResp.Usage),
		Candidates:    []*Candidate{{Content: transcript}},
	}
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,595 @@
package gemini
import (
"encoding/base64"
"encoding/json"
"fmt"
"net/url"
"strconv"
"strings"
"time"
"github.com/bytedance/sonic"
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
"github.com/maximhq/bifrost/core/schemas"
)
// defaultVideoContentType is the MIME type applied to a video output when no
// content type (or only whitespace) is supplied.
const defaultVideoContentType = "video/mp4"
// sizeToAspectRatio maps an OpenAI-style "WIDTHxHEIGHT" size string to a
// Gemini aspect ratio. The two recognized portrait sizes map to "9:16"; every
// other value, including the recognized landscape sizes and any unknown
// string, maps to the default "16:9".
func sizeToAspectRatio(size string) string {
	const (
		landscape = "16:9"
		portrait  = "9:16"
	)
	if size == "720x1280" || size == "1024x1792" {
		return portrait
	}
	// "1280x720", "1792x1024", and anything unrecognized.
	return landscape
}
// addVideoURLOutput builds a URL-typed video output for uri. It returns nil
// when uri is empty, and falls back to defaultVideoContentType when
// contentType is empty or only whitespace.
func addVideoURLOutput(uri, contentType string) *schemas.VideoOutput {
	if uri == "" {
		return nil
	}

	resolvedType := contentType
	if strings.TrimSpace(resolvedType) == "" {
		resolvedType = defaultVideoContentType
	}

	output := schemas.VideoOutput{
		Type:        schemas.VideoOutputTypeURL,
		URL:         schemas.Ptr(uri),
		ContentType: resolvedType,
	}
	return &output
}
// addVideoBase64Output builds a base64-typed video output for base64Value. It
// returns nil when base64Value is empty, and falls back to
// defaultVideoContentType when contentType is empty or only whitespace.
func addVideoBase64Output(base64Value, contentType string) *schemas.VideoOutput {
	if base64Value == "" {
		return nil
	}

	resolvedType := contentType
	if strings.TrimSpace(resolvedType) == "" {
		resolvedType = defaultVideoContentType
	}

	output := schemas.VideoOutput{
		Type:        schemas.VideoOutputTypeBase64,
		Base64Data:  schemas.Ptr(base64Value),
		ContentType: resolvedType,
	}
	return &output
}
// parseVideoDataURL splits a "data:<mime>[;params],<payload>" URL into its
// MIME type and payload. ok is false when the input does not start with
// "data:", has no comma, or has an empty payload. Everything after the first
// ";" in the header is dropped, and the payload is everything after the first
// comma (later commas are kept).
func parseVideoDataURL(data string) (mimeType string, base64Payload string, ok bool) {
	const scheme = "data:"
	if !strings.HasPrefix(data, scheme) {
		return "", "", false
	}

	meta, body, hasComma := strings.Cut(data[len(scheme):], ",")
	if !hasComma || body == "" {
		return "", "", false
	}

	// Keep only the media type; Cut returns meta unchanged when there is no ";".
	mediaType, _, _ := strings.Cut(meta, ";")
	return mediaType, body, true
}
// ToGeminiVideoGenerationRequest converts a Bifrost video generation request to Gemini REST API format
// This creates the request body for POST /models/{model}:predictLongRunning
//
// The request always contains exactly one instance built from the prompt, the
// optional input reference image, and the optional Params.VideoURI. Parameters
// are populated only when bifrostReq.Params is non-nil. An error is returned
// when the request or its input is nil, when the input reference fails
// sanitization, or when Params.Seconds is not an integer string.
func ToGeminiVideoGenerationRequest(bifrostReq *schemas.BifrostVideoGenerationRequest) (*GeminiVideoGenerationRequest, error) {
if bifrostReq == nil || bifrostReq.Input == nil {
return nil, fmt.Errorf("bifrost request or input is nil")
}
// Create the instance with prompt
instance := &GeminiVideoGenerationInstance{
Prompt: bifrostReq.Input.Prompt,
}
// Handle input reference (image for image-to-video)
if bifrostReq.Input.InputReference != nil && *bifrostReq.Input.InputReference != "" {
// extract mime type and base64 string from input reference
sanitizedURL, err := schemas.SanitizeImageURL(*bifrostReq.Input.InputReference)
if err != nil {
return nil, fmt.Errorf("invalid input reference: %w", err)
}
urlInfo := schemas.ExtractURLTypeInfo(sanitizedURL)
image := &VideoImageData{}
// The base64 payload is set only when the URL info carries a data-URL body.
if urlInfo.DataURLWithoutPrefix != nil {
image.BytesBase64Encoded = urlInfo.DataURLWithoutPrefix
}
// Default to PNG unless the URL info reports a media type.
image.MimeType = schemas.Ptr("image/png")
if urlInfo.MediaType != nil {
image.MimeType = urlInfo.MediaType
}
instance.Image = image
}
if bifrostReq.Params != nil && bifrostReq.Params.VideoURI != nil {
instance.Video = &VideoGenerationVideoInput{
URI: bifrostReq.Params.VideoURI,
}
}
req := &GeminiVideoGenerationRequest{
Instances: []GeminiVideoGenerationInstance{*instance},
}
// Map parameters if provided
if bifrostReq.Params != nil {
params := &VideoGenerationParameters{}
// Copy the typed Bifrost parameters first; Gemini-specific ones are read
// from ExtraParams further below.
if bifrostReq.Params.NegativePrompt != nil {
params.NegativePrompt = bifrostReq.Params.NegativePrompt
}
if bifrostReq.Params.Seconds != nil {
seconds, err := strconv.Atoi(*bifrostReq.Params.Seconds)
if err != nil {
return nil, fmt.Errorf("invalid seconds value: %w", err)
}
params.DurationSeconds = &seconds
}
if bifrostReq.Params.Seed != nil {
params.Seed = bifrostReq.Params.Seed
}
if bifrostReq.Params.Audio != nil {
params.GenerateAudio = bifrostReq.Params.Audio
}
if bifrostReq.Params.ExtraParams != nil {
// The whole map is forwarded as-is; recognized keys are additionally
// copied into typed fields and are not removed from the map.
req.ExtraParams = bifrostReq.Params.ExtraParams
if aspectRatio, ok := schemas.SafeExtractStringPointer(bifrostReq.Params.ExtraParams["aspectRatio"]); ok {
params.AspectRatio = aspectRatio
}
if resolution, ok := schemas.SafeExtractStringPointer(bifrostReq.Params.ExtraParams["resolution"]); ok {
params.Resolution = resolution
}
if sampleCount, ok := schemas.SafeExtractIntPointer(bifrostReq.Params.ExtraParams["sampleCount"]); ok {
params.SampleCount = sampleCount
}
if personGeneration, ok := schemas.SafeExtractStringPointer(bifrostReq.Params.ExtraParams["personGeneration"]); ok {
params.PersonGeneration = personGeneration
}
if numberOfVideos, ok := schemas.SafeExtractIntPointer(bifrostReq.Params.ExtraParams["numberOfVideos"]); ok {
params.NumberOfVideos = numberOfVideos
}
if storageURI, ok := schemas.SafeExtractStringPointer(bifrostReq.Params.ExtraParams["storageURI"]); ok {
params.StorageURI = storageURI
}
if compressionQuality, ok := schemas.SafeExtractStringPointer(bifrostReq.Params.ExtraParams["compressionQuality"]); ok {
params.CompressionQuality = compressionQuality
}
if enhancePrompt, ok := schemas.SafeExtractBoolPointer(bifrostReq.Params.ExtraParams["enhancePrompt"]); ok {
params.EnhancePrompt = enhancePrompt
}
if resizeMode, ok := schemas.SafeExtractStringPointer(bifrostReq.Params.ExtraParams["resizeMode"]); ok {
params.ResizeMode = resizeMode
}
// referenceImages: use the value directly when it already has the native
// slice type; otherwise round-trip it through JSON. Marshal/unmarshal
// failures are ignored and leave the field unset.
if referenceImages, ok := bifrostReq.Params.ExtraParams["referenceImages"]; ok {
if referenceImages, ok := referenceImages.([]VideoReferenceImage); ok && referenceImages != nil {
params.ReferenceImages = referenceImages
} else if data, err := providerUtils.MarshalSorted(referenceImages); err == nil {
var referenceImages []VideoReferenceImage
if sonic.Unmarshal(data, &referenceImages) == nil {
params.ReferenceImages = referenceImages
}
}
}
// lastFrame: same strategy as referenceImages, with a pointer type.
if lastFrame, ok := bifrostReq.Params.ExtraParams["lastFrame"]; ok {
if lastFrame, ok := lastFrame.(*VideoImageData); ok {
params.LastFrame = lastFrame
} else if data, err := providerUtils.MarshalSorted(lastFrame); err == nil {
var lastFrame VideoImageData
if sonic.Unmarshal(data, &lastFrame) == nil {
params.LastFrame = &lastFrame
}
}
}
}
// Convert size to aspect ratio if size is provided and aspect ratio is not already set
if params.AspectRatio == nil && bifrostReq.Params.Size != "" {
aspectRatio := sizeToAspectRatio(bifrostReq.Params.Size)
// sizeToAspectRatio as defined in this file never returns "", so this
// check is purely defensive.
if aspectRatio != "" {
params.AspectRatio = &aspectRatio
}
}
req.Parameters = params
}
return req, nil
}
// ToBifrostVideoGenerationResponse converts Gemini operation response to Bifrost format
//
// Status mapping:
//   - operation not done: in progress, with Progress taken from the metadata
//     "progress" field when it exists;
//   - done with an error: failed, with code/message read from the error JSON
//     (falling back to "video_generation_failed" and the raw error text);
//   - done with a response: failed when media was filtered, otherwise
//     completed with the videos found in the first non-empty response shape;
//   - done with neither: completed with no videos.
//
// CreatedAt defaults to the current time and is replaced by the metadata
// "createTime" when that parses as RFC 3339. CompletedAt is set from
// "updateTime" only for finished operations. model is copied to the response
// when non-empty. A nil operation yields a Bifrost operation error.
func ToBifrostVideoGenerationResponse(operation *GenerateVideosOperation, model string) (*schemas.BifrostVideoGenerationResponse, *schemas.BifrostError) {
if operation == nil {
return nil, providerUtils.NewBifrostOperationError("operation is nil", nil)
}
response := &schemas.BifrostVideoGenerationResponse{
ID: operation.Name,
Object: "video",
CreatedAt: time.Now().Unix(),
}
if model != "" {
response.Model = model
}
// Set status based on operation state
if !operation.Done {
response.Status = schemas.VideoStatusInProgress
if operation.Metadata != nil {
if p := providerUtils.GetJSONField([]byte(operation.Metadata), "progress"); p.Exists() {
progress := p.Float()
response.Progress = &progress
}
}
} else if operation.Error != nil {
response.Status = schemas.VideoStatusFailed
code := providerUtils.GetJSONField(operation.Error, "code").String()
message := providerUtils.GetJSONField(operation.Error, "message").String()
if code == "" {
code = "video_generation_failed"
}
// Fall back to the raw error payload when no message field is present.
if message == "" {
message = string(operation.Error)
}
response.Error = &schemas.VideoCreateError{
Code: code,
Message: message,
}
} else if operation.Response != nil {
// Check new response format with content filtering support
if genVideoResp := operation.Response.GenerateVideoResponse; genVideoResp != nil {
// Check for content filtering
if genVideoResp.RAIMediaFilteredCount > 0 {
response.Status = schemas.VideoStatusFailed
response.ContentFilter = &schemas.ContentFilterInfo{
FilteredCount: int(genVideoResp.RAIMediaFilteredCount),
Reasons: genVideoResp.RAIMediaFilteredReasons,
}
// Use the first filter reason as the error message when one is available.
errorMsg := "Content filtered by safety policies"
if len(genVideoResp.RAIMediaFilteredReasons) > 0 {
errorMsg = genVideoResp.RAIMediaFilteredReasons[0]
}
response.Error = &schemas.VideoCreateError{
Code: "content_filtered",
Message: errorMsg,
}
} else {
response.Status = schemas.VideoStatusCompleted
// Collect all generated videos from multiple possible locations.
var videos []schemas.VideoOutput
// Priority 1: GeneratedSamples
if len(genVideoResp.GeneratedSamples) > 0 {
for _, sample := range genVideoResp.GeneratedSamples {
if sample == nil || sample.Video == nil {
continue
}
// A sample with both a URI and raw bytes produces two outputs.
if sample.Video.URI != "" {
videoOutput := addVideoURLOutput(sample.Video.URI, sample.Video.MIMEType)
if videoOutput != nil {
videos = append(videos, *videoOutput)
}
}
if len(sample.Video.VideoBytes) > 0 {
videoOutput := addVideoBase64Output(
base64.StdEncoding.EncodeToString(sample.Video.VideoBytes),
sample.Video.MIMEType,
)
if videoOutput != nil {
videos = append(videos, *videoOutput)
}
}
}
}
if len(videos) > 0 {
response.Videos = videos
}
}
} else if len(operation.Response.GeneratedVideos) > 0 {
// Backward compatibility for older response shapes
response.Status = schemas.VideoStatusCompleted
var videos []schemas.VideoOutput
for _, genVideo := range operation.Response.GeneratedVideos {
if genVideo == nil || genVideo.Video == nil {
continue
}
if genVideo.Video.URI != "" {
videoOutput := addVideoURLOutput(genVideo.Video.URI, genVideo.Video.MIMEType)
if videoOutput != nil {
videos = append(videos, *videoOutput)
}
}
if len(genVideo.Video.VideoBytes) > 0 {
videoOutput := addVideoBase64Output(
base64.StdEncoding.EncodeToString(genVideo.Video.VideoBytes),
genVideo.Video.MIMEType,
)
if videoOutput != nil {
videos = append(videos, *videoOutput)
}
}
}
if len(videos) > 0 {
response.Videos = videos
}
} else if len(operation.Response.Videos) > 0 {
// Third response shape: entries with optional GCS URI / base64 pointer
// fields. Unlike the shapes above, a GCS URI takes precedence and the
// base64 payload is used only when no URI is present.
response.Status = schemas.VideoStatusCompleted
var videos []schemas.VideoOutput
for _, video := range operation.Response.Videos {
if video.GCSURI != nil && *video.GCSURI != "" {
mimeType := defaultVideoContentType
if video.MIMEType != nil && *video.MIMEType != "" {
mimeType = *video.MIMEType
}
videoOutput := addVideoURLOutput(*video.GCSURI, mimeType)
if videoOutput != nil {
videos = append(videos, *videoOutput)
}
} else if video.BytesBase64Encoded != nil && *video.BytesBase64Encoded != "" {
mimeType := defaultVideoContentType
if video.MIMEType != nil && *video.MIMEType != "" {
mimeType = *video.MIMEType
}
videoOutput := addVideoBase64Output(*video.BytesBase64Encoded, mimeType)
if videoOutput != nil {
videos = append(videos, *videoOutput)
}
}
}
if len(videos) > 0 {
response.Videos = videos
}
} else {
// Response present but carrying no recognized video payload.
response.Status = schemas.VideoStatusCompleted
}
} else {
// Done, with neither an error nor a response.
response.Status = schemas.VideoStatusCompleted
}
// Try to extract timestamps from metadata
if operation.Metadata != nil {
if ct := providerUtils.GetJSONField([]byte(operation.Metadata), "createTime"); ct.Exists() {
if t, err := time.Parse(time.RFC3339, ct.String()); err == nil {
response.CreatedAt = t.Unix()
}
}
if ut := providerUtils.GetJSONField([]byte(operation.Metadata), "updateTime"); ut.Exists() {
if t, err := time.Parse(time.RFC3339, ut.String()); err == nil && operation.Done {
response.CompletedAt = schemas.Ptr(t.Unix())
}
}
}
return response, nil
}
// ToBifrostVideoGenerationRequest converts a Gemini video generation request
// into the Bifrost format. Only the first instance is used. The instance image
// (when it has base64 bytes) becomes a data-URL input reference, the video URI
// maps to Params.VideoURI, and the remaining Gemini-specific values are stored
// in Params.ExtraParams under their Gemini field names. Params stays nil when
// none of these are present. An error is returned when the request is nil or
// has no instances.
func (request *GeminiVideoGenerationRequest) ToBifrostVideoGenerationRequest(ctx *schemas.BifrostContext) (*schemas.BifrostVideoGenerationRequest, error) {
if request == nil || len(request.Instances) == 0 {
return nil, fmt.Errorf("request is nil or has no instances")
}
// Use the first instance for the main input
instance := request.Instances[0]
provider, model := schemas.ParseModelString(request.Model, providerUtils.CheckAndSetDefaultProvider(ctx, schemas.Gemini))
bifrostReq := &schemas.BifrostVideoGenerationRequest{
Provider: provider,
Model: model,
Input: &schemas.VideoGenerationInput{
Prompt: instance.Prompt,
},
}
// Handle image input for image-to-video
if instance.Image != nil && instance.Image.BytesBase64Encoded != nil && *instance.Image.BytesBase64Encoded != "" {
// attach mime type and base64 string to input reference
// PNG is assumed when the image carries no MIME type.
mimeType := "image/png"
if instance.Image.MimeType != nil && *instance.Image.MimeType != "" {
mimeType = *instance.Image.MimeType
}
bifrostReq.Input.InputReference = schemas.Ptr(fmt.Sprintf("data:%s;base64,%s", mimeType, *instance.Image.BytesBase64Encoded))
}
// Helper to ensure params are initialized
ensureParams := func() {
if bifrostReq.Params == nil {
bifrostReq.Params = &schemas.VideoGenerationParameters{
ExtraParams: make(map[string]any),
}
}
}
// Handle reference images
// NOTE(review): these are read from the instance here, while
// ToGeminiVideoGenerationRequest writes referenceImages/lastFrame to the
// request parameters — confirm the asymmetry is intended.
if len(instance.ReferenceImages) > 0 {
ensureParams()
bifrostReq.Params.ExtraParams["referenceImages"] = instance.ReferenceImages
}
// Handle video URI
if instance.Video != nil && instance.Video.URI != nil {
ensureParams()
bifrostReq.Params.VideoURI = instance.Video.URI
}
// Handle last frame
if instance.LastFrame != nil {
ensureParams()
bifrostReq.Params.ExtraParams["lastFrame"] = instance.LastFrame
}
// Map parameters if provided
if request.Parameters != nil {
ensureParams()
params := bifrostReq.Params
if request.Parameters.NegativePrompt != nil {
params.NegativePrompt = request.Parameters.NegativePrompt
}
// Bifrost carries the duration as a decimal string.
if request.Parameters.DurationSeconds != nil {
seconds := strconv.Itoa(*request.Parameters.DurationSeconds)
params.Seconds = &seconds
}
if request.Parameters.Seed != nil {
params.Seed = request.Parameters.Seed
}
if request.Parameters.GenerateAudio != nil {
params.Audio = request.Parameters.GenerateAudio
}
// The remaining values have no typed Bifrost field and are stored
// dereferenced in ExtraParams.
if request.Parameters.AspectRatio != nil {
params.ExtraParams["aspectRatio"] = *request.Parameters.AspectRatio
}
if request.Parameters.Resolution != nil {
params.ExtraParams["resolution"] = *request.Parameters.Resolution
}
if request.Parameters.SampleCount != nil {
params.ExtraParams["sampleCount"] = *request.Parameters.SampleCount
}
if request.Parameters.PersonGeneration != nil {
params.ExtraParams["personGeneration"] = *request.Parameters.PersonGeneration
}
if request.Parameters.NumberOfVideos != nil {
params.ExtraParams["numberOfVideos"] = *request.Parameters.NumberOfVideos
}
if request.Parameters.StorageURI != nil {
params.ExtraParams["storageURI"] = *request.Parameters.StorageURI
}
if request.Parameters.CompressionQuality != nil {
params.ExtraParams["compressionQuality"] = *request.Parameters.CompressionQuality
}
if request.Parameters.EnhancePrompt != nil {
params.ExtraParams["enhancePrompt"] = *request.Parameters.EnhancePrompt
}
if request.Parameters.ResizeMode != nil {
params.ExtraParams["resizeMode"] = *request.Parameters.ResizeMode
}
}
return bifrostReq, nil
}
// ToGeminiVideoGenerationResponse converts a Bifrost video generation response
// into a Gemini GenerateVideosOperation. A nil input yields nil.
//
// The operation name is the (path-unescaped) response ID, prefixed with
// "models/<escaped model>/operations/" unless it already looks like a Gemini
// operation name. Completed responses map their video outputs to generated
// samples; failed responses map either the content-filter info or the error;
// every other status produces an operation that is not done.
func ToGeminiVideoGenerationResponse(response *schemas.BifrostVideoGenerationResponse) *GenerateVideosOperation {
if response == nil {
return nil
}
decodedID := response.ID
// Keep the raw ID when it is not valid percent-encoding.
if decoded, err := url.PathUnescape(decodedID); err == nil {
decodedID = decoded
}
// if id is in gemini or vertex format, set name in format models/model/operations/operation_id:provider
// else make the id in gemini format
// The ID counts as already formatted only when it starts with "models/",
// contains the model string, and contains "operations/".
if !(strings.HasPrefix(decodedID, "models/") && strings.Contains(decodedID, response.Model) && strings.Contains(decodedID, "operations/")) {
// url encode model
encodedModel := url.PathEscape(response.Model)
decodedID = "models/" + encodedModel + "/operations/" + decodedID
}
operation := &GenerateVideosOperation{
Name: decodedID,
}
switch response.Status {
case schemas.VideoStatusCompleted:
operation.Done = true
if len(response.Videos) > 0 {
generatedSamples := make([]*GeneratedVideo, 0, len(response.Videos))
for _, output := range response.Videos {
var video *Video
switch output.Type {
case schemas.VideoOutputTypeURL:
if output.URL == nil || *output.URL == "" {
continue
}
video = &Video{
URI: *output.URL,
}
if output.ContentType != "" {
video.MIMEType = output.ContentType
}
case schemas.VideoOutputTypeBase64:
if output.Base64Data == nil || *output.Base64Data == "" {
continue
}
// Accept either a bare base64 string or a full data URL; for a data
// URL the embedded MIME type is used only when ContentType is empty.
base64Payload := *output.Base64Data
mimeType := output.ContentType
if parsedMimeType, payload, ok := parseVideoDataURL(*output.Base64Data); ok {
base64Payload = payload
if mimeType == "" {
mimeType = parsedMimeType
}
}
// Outputs whose payload is not valid standard base64 are skipped silently.
decoded, err := base64.StdEncoding.DecodeString(base64Payload)
if err != nil {
continue
}
if mimeType == "" {
mimeType = defaultVideoContentType
}
video = &Video{
VideoBytes: decoded,
MIMEType: mimeType,
}
default:
// Unknown output types are ignored.
continue
}
if video == nil {
continue
}
generatedSamples = append(generatedSamples, &GeneratedVideo{
Video: video,
})
}
// Response stays nil when no output could be converted.
if len(generatedSamples) > 0 {
operation.Response = &GenerateVideosOperationResponse{
GenerateVideoResponse: &GenerateVideoResponse{
GeneratedSamples: generatedSamples,
},
}
}
}
case schemas.VideoStatusFailed:
operation.Done = true
// Check if this is a content filtering case
if response.ContentFilter != nil && response.ContentFilter.FilteredCount > 0 {
operation.Response = &GenerateVideosOperationResponse{
GenerateVideoResponse: &GenerateVideoResponse{
RAIMediaFilteredCount: int32(response.ContentFilter.FilteredCount),
RAIMediaFilteredReasons: response.ContentFilter.Reasons,
},
}
} else if response.Error != nil {
// The marshal error is discarded here; on failure errBytes is whatever
// MarshalSorted returned alongside the error.
errBytes, _ := providerUtils.MarshalSorted(map[string]any{
"message": response.Error.Message,
"code": response.Error.Code,
})
operation.Error = json.RawMessage(errBytes)
}
default:
// Any other status is reported as still running.
operation.Done = false
}
return operation
}

402
core/providers/groq/groq.go Normal file
View File

@@ -0,0 +1,402 @@
// Package groq implements the Groq provider and its utility functions.
package groq
import (
"context"
"strings"
"time"
"github.com/maximhq/bifrost/core/providers/openai"
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
schemas "github.com/maximhq/bifrost/core/schemas"
"github.com/valyala/fasthttp"
)
// GroqProvider implements the Provider interface for Groq's API.
// It holds the HTTP clients and per-provider settings used by its methods;
// supported operations are delegated to the shared handlers in the openai package.
type GroqProvider struct {
logger schemas.Logger // Logger for provider operations
client *fasthttp.Client // HTTP client for unary API requests (ReadTimeout bounds overall response)
streamingClient *fasthttp.Client // HTTP client for streaming API requests (no ReadTimeout; idle governed by NewIdleTimeoutReader)
networkConfig schemas.NetworkConfig // Network configuration (base URL, extra headers)
sendBackRawRequest bool // Provider-level default for including the raw request in BifrostResponse
sendBackRawResponse bool // Provider-level default for including the raw response in BifrostResponse
}
// NewGroqProvider builds a Groq provider from the given configuration.
//
// It applies the configuration defaults, creates the unary fasthttp client
// (timeouts derived from DefaultRequestTimeoutInSeconds, per-host connection
// limit from MaxConnsPerHost), applies proxy, dialer and TLS settings to it and
// derives a separate streaming client from the result. When no BaseURL is
// configured it defaults to "https://api.groq.com/openai"; trailing slashes are
// trimmed in either case.
//
// config is modified in place (defaults and the normalized BaseURL).
// The returned error is currently always nil.
func NewGroqProvider(config *schemas.ProviderConfig, logger schemas.Logger) (*GroqProvider, error) {
	config.CheckAndSetDefaults()

	timeout := time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds) * time.Second
	maxConnLifetime := time.Duration(schemas.DefaultMaxConnDurationInSeconds) * time.Second

	unaryClient := &fasthttp.Client{
		ReadTimeout:         timeout,
		WriteTimeout:        timeout,
		MaxConnsPerHost:     config.NetworkConfig.MaxConnsPerHost,
		MaxIdleConnDuration: 30 * time.Second,
		MaxConnWaitTimeout:  timeout,
		MaxConnDuration:     maxConnLifetime,
		ConnPoolStrategy:    fasthttp.FIFO,
	}

	// Apply proxy, dialer and TLS settings before deriving the streaming client.
	unaryClient = providerUtils.ConfigureProxy(unaryClient, config.ProxyConfig, logger)
	unaryClient = providerUtils.ConfigureDialer(unaryClient)
	unaryClient = providerUtils.ConfigureTLS(unaryClient, config.NetworkConfig, logger)
	streamClient := providerUtils.BuildStreamingClient(unaryClient)

	// Normalize the base URL: default when empty, never ending in "/".
	baseURL := config.NetworkConfig.BaseURL
	if baseURL == "" {
		baseURL = "https://api.groq.com/openai"
	}
	config.NetworkConfig.BaseURL = strings.TrimRight(baseURL, "/")

	provider := &GroqProvider{
		logger:              logger,
		client:              unaryClient,
		streamingClient:     streamClient,
		networkConfig:       config.NetworkConfig,
		sendBackRawRequest:  config.SendBackRawRequest,
		sendBackRawResponse: config.SendBackRawResponse,
	}
	return provider, nil
}
// GetProviderKey returns the provider identifier for Groq (schemas.Groq).
// The other methods of GroqProvider use it to tag unsupported-operation errors.
func (provider *GroqProvider) GetProviderKey() schemas.ModelProvider {
return schemas.Groq
}
// ListModels lists the models available from Groq by delegating to the shared
// OpenAI-compatible list-models handler. The request path is resolved through
// providerUtils.GetPathFromContext with "/v1/models" as the default.
func (provider *GroqProvider) ListModels(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) {
	endpoint := provider.networkConfig.BaseURL + providerUtils.GetPathFromContext(ctx, "/v1/models")
	includeRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest)
	includeRawResponse := providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse)

	return openai.HandleOpenAIListModelsRequest(
		ctx,
		provider.client,
		request,
		endpoint,
		keys,
		provider.networkConfig.ExtraHeaders,
		schemas.Groq,
		includeRawRequest,
		includeRawResponse,
	)
}
// TextCompletion is not supported by the Groq provider.
// It always returns an unsupported-operation error, tagged with the canonical
// text completion request type and this provider's key so that it matches the
// errors returned by the other unsupported operations in this file.
func (provider *GroqProvider) TextCompletion(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (*schemas.BifrostTextCompletionResponse, *schemas.BifrostError) {
	return nil, providerUtils.NewUnsupportedOperationError(schemas.TextCompletionRequest, provider.GetProviderKey())
}
// TextCompletionStream is not supported by the Groq provider.
// It always returns a nil channel and an unsupported-operation error, tagged
// with the streaming text completion request type and this provider's key.
func (provider *GroqProvider) TextCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostTextCompletionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
	return nil, providerUtils.NewUnsupportedOperationError(schemas.TextCompletionStreamRequest, provider.GetProviderKey())
}
// ChatCompletion sends a non-streaming chat completion request to Groq by
// delegating to the shared OpenAI-compatible handler. The request path is
// resolved through providerUtils.GetPathFromContext with
// "/v1/chat/completions" as the default.
func (provider *GroqProvider) ChatCompletion(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError) {
	endpoint := provider.networkConfig.BaseURL + providerUtils.GetPathFromContext(ctx, "/v1/chat/completions")
	includeRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest)
	includeRawResponse := providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse)

	return openai.HandleOpenAIChatCompletionRequest(
		ctx,
		provider.client,
		endpoint,
		request,
		key,
		provider.networkConfig.ExtraHeaders,
		includeRawRequest,
		includeRawResponse,
		provider.GetProviderKey(),
		nil,
		nil,
		provider.logger,
	)
}
// ChatCompletionStream sends a streaming chat completion request to Groq.
// It delegates to the shared OpenAI-compatible streaming handler using the
// dedicated streaming client, and returns the channel of stream chunks produced
// by that handler, or an error if the request fails.
//
// The request path is resolved through providerUtils.GetPathFromContext with
// "/v1/chat/completions" as the default, exactly like the non-streaming
// ChatCompletion, so both variants target the same endpoint.
//
// An Authorization header is only sent when the key has a non-empty value.
func (provider *GroqProvider) ChatCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
	var authHeader map[string]string
	if v := key.Value.GetValue(); v != "" {
		authHeader = map[string]string{"Authorization": "Bearer " + v}
	}

	// Use shared OpenAI-compatible streaming logic
	return openai.HandleOpenAIChatCompletionStreaming(
		ctx,
		provider.streamingClient,
		provider.networkConfig.BaseURL+providerUtils.GetPathFromContext(ctx, "/v1/chat/completions"),
		request,
		authHeader,
		provider.networkConfig.ExtraHeaders,
		providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest),
		providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse),
		schemas.Groq,
		postHookRunner,
		nil,
		nil,
		nil,
		nil,
		nil,
		provider.logger,
		postHookSpanFinalizer,
	)
}
// Responses serves a responses request by converting it to a chat request,
// running it through ChatCompletion and converting the chat response back
// into the responses format.
func (provider *GroqProvider) Responses(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) {
	chatResponse, bifrostErr := provider.ChatCompletion(ctx, key, request.ToChatRequest())
	if bifrostErr != nil {
		return nil, bifrostErr
	}
	return chatResponse.ToBifrostResponsesResponse(), nil
}
// ResponsesStream serves a streaming responses request through
// ChatCompletionStream. It first flags the context with
// BifrostContextKeyIsResponsesToChatCompletionFallback and then streams the
// request converted to a chat request.
func (provider *GroqProvider) ResponsesStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
	ctx.SetValue(schemas.BifrostContextKeyIsResponsesToChatCompletionFallback, true)

	chatRequest := request.ToChatRequest()
	return provider.ChatCompletionStream(ctx, postHookRunner, postHookSpanFinalizer, key, chatRequest)
}
// Embedding is not supported by the Groq provider and always returns an
// unsupported-operation error.
func (provider *GroqProvider) Embedding(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostEmbeddingRequest) (*schemas.BifrostEmbeddingResponse, *schemas.BifrostError) {
	unsupported := providerUtils.NewUnsupportedOperationError(schemas.EmbeddingRequest, provider.GetProviderKey())
	return nil, unsupported
}
// Speech sends a non-streaming speech synthesis request to Groq by delegating
// to the shared OpenAI-compatible speech handler. The request path is resolved
// through providerUtils.GetPathFromContext with "/v1/audio/speech" as the default.
func (provider *GroqProvider) Speech(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostSpeechRequest) (*schemas.BifrostSpeechResponse, *schemas.BifrostError) {
	endpoint := provider.networkConfig.BaseURL + providerUtils.GetPathFromContext(ctx, "/v1/audio/speech")
	includeRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest)
	includeRawResponse := providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse)

	return openai.HandleOpenAISpeechRequest(
		ctx,
		provider.client,
		endpoint,
		request,
		key,
		provider.networkConfig.ExtraHeaders,
		schemas.Groq,
		includeRawRequest,
		includeRawResponse,
		nil,
		provider.logger,
	)
}
// Rerank is not supported by the Groq provider and always returns an
// unsupported-operation error.
func (provider *GroqProvider) Rerank(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostRerankRequest) (*schemas.BifrostRerankResponse, *schemas.BifrostError) {
	unsupported := providerUtils.NewUnsupportedOperationError(schemas.RerankRequest, provider.GetProviderKey())
	return nil, unsupported
}
// OCR is not supported by the Groq provider and always returns an
// unsupported-operation error.
func (provider *GroqProvider) OCR(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostOCRRequest) (*schemas.BifrostOCRResponse, *schemas.BifrostError) {
	unsupported := providerUtils.NewUnsupportedOperationError(schemas.OCRRequest, provider.GetProviderKey())
	return nil, unsupported
}
// SpeechStream is not supported by the Groq provider and always returns a nil
// channel together with an unsupported-operation error.
func (provider *GroqProvider) SpeechStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostSpeechRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
	unsupported := providerUtils.NewUnsupportedOperationError(schemas.SpeechStreamRequest, provider.GetProviderKey())
	return nil, unsupported
}
// Transcription sends a non-streaming transcription request to Groq by
// delegating to the shared OpenAI-compatible transcription handler. The request
// path is resolved through providerUtils.GetPathFromContext with
// "/v1/audio/transcriptions" as the default.
func (provider *GroqProvider) Transcription(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (*schemas.BifrostTranscriptionResponse, *schemas.BifrostError) {
	endpoint := provider.networkConfig.BaseURL + providerUtils.GetPathFromContext(ctx, "/v1/audio/transcriptions")
	includeRawResponse := providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse)

	return openai.HandleOpenAITranscriptionRequest(
		ctx,
		provider.client,
		endpoint,
		request,
		key,
		provider.networkConfig.ExtraHeaders,
		schemas.Groq,
		includeRawResponse,
		nil,
		provider.logger,
	)
}
// TranscriptionStream is not supported by the Groq provider and always returns
// a nil channel together with an unsupported-operation error.
func (provider *GroqProvider) TranscriptionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
	unsupported := providerUtils.NewUnsupportedOperationError(schemas.TranscriptionStreamRequest, provider.GetProviderKey())
	return nil, unsupported
}
// ImageGeneration is not supported by the Groq provider and always returns an
// unsupported-operation error.
func (provider *GroqProvider) ImageGeneration(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostImageGenerationRequest) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) {
	unsupported := providerUtils.NewUnsupportedOperationError(schemas.ImageGenerationRequest, provider.GetProviderKey())
	return nil, unsupported
}
// ImageGenerationStream is not supported by the Groq provider and always
// returns a nil channel together with an unsupported-operation error.
func (provider *GroqProvider) ImageGenerationStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostImageGenerationRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
	unsupported := providerUtils.NewUnsupportedOperationError(schemas.ImageGenerationStreamRequest, provider.GetProviderKey())
	return nil, unsupported
}
// ImageEdit is not supported by the Groq provider and always returns an
// unsupported-operation error.
func (provider *GroqProvider) ImageEdit(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostImageEditRequest) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) {
	unsupported := providerUtils.NewUnsupportedOperationError(schemas.ImageEditRequest, provider.GetProviderKey())
	return nil, unsupported
}
// ImageEditStream is not supported by the Groq provider and always returns a
// nil channel together with an unsupported-operation error.
func (provider *GroqProvider) ImageEditStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostImageEditRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
	unsupported := providerUtils.NewUnsupportedOperationError(schemas.ImageEditStreamRequest, provider.GetProviderKey())
	return nil, unsupported
}
// ImageVariation is not supported by the Groq provider and always returns an
// unsupported-operation error.
func (provider *GroqProvider) ImageVariation(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostImageVariationRequest) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) {
	unsupported := providerUtils.NewUnsupportedOperationError(schemas.ImageVariationRequest, provider.GetProviderKey())
	return nil, unsupported
}
// VideoGeneration is not supported by the Groq provider and always returns an
// unsupported-operation error.
func (provider *GroqProvider) VideoGeneration(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostVideoGenerationRequest) (*schemas.BifrostVideoGenerationResponse, *schemas.BifrostError) {
	unsupported := providerUtils.NewUnsupportedOperationError(schemas.VideoGenerationRequest, provider.GetProviderKey())
	return nil, unsupported
}
// VideoRetrieve is not supported by the Groq provider and always returns an
// unsupported-operation error.
func (provider *GroqProvider) VideoRetrieve(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostVideoRetrieveRequest) (*schemas.BifrostVideoGenerationResponse, *schemas.BifrostError) {
	unsupported := providerUtils.NewUnsupportedOperationError(schemas.VideoRetrieveRequest, provider.GetProviderKey())
	return nil, unsupported
}
// VideoDownload is not supported by the Groq provider and always returns an
// unsupported-operation error.
func (provider *GroqProvider) VideoDownload(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostVideoDownloadRequest) (*schemas.BifrostVideoDownloadResponse, *schemas.BifrostError) {
	unsupported := providerUtils.NewUnsupportedOperationError(schemas.VideoDownloadRequest, provider.GetProviderKey())
	return nil, unsupported
}
// VideoDelete is not supported by the Groq provider and always returns an
// unsupported-operation error.
func (provider *GroqProvider) VideoDelete(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostVideoDeleteRequest) (*schemas.BifrostVideoDeleteResponse, *schemas.BifrostError) {
	unsupported := providerUtils.NewUnsupportedOperationError(schemas.VideoDeleteRequest, provider.GetProviderKey())
	return nil, unsupported
}
// VideoList is not supported by the Groq provider and always returns an
// unsupported-operation error.
func (provider *GroqProvider) VideoList(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostVideoListRequest) (*schemas.BifrostVideoListResponse, *schemas.BifrostError) {
	unsupported := providerUtils.NewUnsupportedOperationError(schemas.VideoListRequest, provider.GetProviderKey())
	return nil, unsupported
}
// VideoRemix is not supported by the Groq provider and always returns an
// unsupported-operation error.
func (provider *GroqProvider) VideoRemix(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostVideoRemixRequest) (*schemas.BifrostVideoGenerationResponse, *schemas.BifrostError) {
	unsupported := providerUtils.NewUnsupportedOperationError(schemas.VideoRemixRequest, provider.GetProviderKey())
	return nil, unsupported
}
// BatchCreate is not supported by the Groq provider and always returns an
// unsupported-operation error.
func (provider *GroqProvider) BatchCreate(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostBatchCreateRequest) (*schemas.BifrostBatchCreateResponse, *schemas.BifrostError) {
	unsupported := providerUtils.NewUnsupportedOperationError(schemas.BatchCreateRequest, provider.GetProviderKey())
	return nil, unsupported
}
// BatchList is not supported by the Groq provider and always returns an
// unsupported-operation error.
func (provider *GroqProvider) BatchList(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchListRequest) (*schemas.BifrostBatchListResponse, *schemas.BifrostError) {
	unsupported := providerUtils.NewUnsupportedOperationError(schemas.BatchListRequest, provider.GetProviderKey())
	return nil, unsupported
}
// BatchRetrieve is not supported by the Groq provider and always returns an
// unsupported-operation error.
func (provider *GroqProvider) BatchRetrieve(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchRetrieveRequest) (*schemas.BifrostBatchRetrieveResponse, *schemas.BifrostError) {
	unsupported := providerUtils.NewUnsupportedOperationError(schemas.BatchRetrieveRequest, provider.GetProviderKey())
	return nil, unsupported
}
// BatchCancel is not supported by the Groq provider and always returns an
// unsupported-operation error.
func (provider *GroqProvider) BatchCancel(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchCancelRequest) (*schemas.BifrostBatchCancelResponse, *schemas.BifrostError) {
	unsupported := providerUtils.NewUnsupportedOperationError(schemas.BatchCancelRequest, provider.GetProviderKey())
	return nil, unsupported
}
// BatchDelete is not supported by the Groq provider and always returns an
// unsupported-operation error.
func (provider *GroqProvider) BatchDelete(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchDeleteRequest) (*schemas.BifrostBatchDeleteResponse, *schemas.BifrostError) {
	unsupported := providerUtils.NewUnsupportedOperationError(schemas.BatchDeleteRequest, provider.GetProviderKey())
	return nil, unsupported
}
// BatchResults is not supported by the Groq provider and always returns an
// unsupported-operation error.
func (provider *GroqProvider) BatchResults(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchResultsRequest) (*schemas.BifrostBatchResultsResponse, *schemas.BifrostError) {
	unsupported := providerUtils.NewUnsupportedOperationError(schemas.BatchResultsRequest, provider.GetProviderKey())
	return nil, unsupported
}
// FileUpload is not supported by the Groq provider and always returns an
// unsupported-operation error.
func (provider *GroqProvider) FileUpload(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostFileUploadRequest) (*schemas.BifrostFileUploadResponse, *schemas.BifrostError) {
	unsupported := providerUtils.NewUnsupportedOperationError(schemas.FileUploadRequest, provider.GetProviderKey())
	return nil, unsupported
}
// FileList is not supported by the Groq provider and always returns an
// unsupported-operation error.
func (provider *GroqProvider) FileList(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostFileListRequest) (*schemas.BifrostFileListResponse, *schemas.BifrostError) {
	unsupported := providerUtils.NewUnsupportedOperationError(schemas.FileListRequest, provider.GetProviderKey())
	return nil, unsupported
}
// FileRetrieve is not supported by the Groq provider and always returns an
// unsupported-operation error.
func (provider *GroqProvider) FileRetrieve(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostFileRetrieveRequest) (*schemas.BifrostFileRetrieveResponse, *schemas.BifrostError) {
	unsupported := providerUtils.NewUnsupportedOperationError(schemas.FileRetrieveRequest, provider.GetProviderKey())
	return nil, unsupported
}
// FileDelete is not supported by the Groq provider and always returns an
// unsupported-operation error.
func (provider *GroqProvider) FileDelete(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostFileDeleteRequest) (*schemas.BifrostFileDeleteResponse, *schemas.BifrostError) {
	unsupported := providerUtils.NewUnsupportedOperationError(schemas.FileDeleteRequest, provider.GetProviderKey())
	return nil, unsupported
}
// FileContent is not supported by the Groq provider and always returns an
// unsupported-operation error.
func (provider *GroqProvider) FileContent(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostFileContentRequest) (*schemas.BifrostFileContentResponse, *schemas.BifrostError) {
	unsupported := providerUtils.NewUnsupportedOperationError(schemas.FileContentRequest, provider.GetProviderKey())
	return nil, unsupported
}
// CountTokens is not supported by the Groq provider and always returns an
// unsupported-operation error.
func (provider *GroqProvider) CountTokens(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostResponsesRequest) (*schemas.BifrostCountTokensResponse, *schemas.BifrostError) {
	unsupported := providerUtils.NewUnsupportedOperationError(schemas.CountTokensRequest, provider.GetProviderKey())
	return nil, unsupported
}
// ContainerCreate is not supported by the Groq provider and always returns an
// unsupported-operation error.
func (provider *GroqProvider) ContainerCreate(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostContainerCreateRequest) (*schemas.BifrostContainerCreateResponse, *schemas.BifrostError) {
	unsupported := providerUtils.NewUnsupportedOperationError(schemas.ContainerCreateRequest, provider.GetProviderKey())
	return nil, unsupported
}
// ContainerList is not supported by the Groq provider and always returns an
// unsupported-operation error.
func (provider *GroqProvider) ContainerList(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostContainerListRequest) (*schemas.BifrostContainerListResponse, *schemas.BifrostError) {
	unsupported := providerUtils.NewUnsupportedOperationError(schemas.ContainerListRequest, provider.GetProviderKey())
	return nil, unsupported
}
// ContainerRetrieve is not supported by the Groq provider and always returns an
// unsupported-operation error.
func (provider *GroqProvider) ContainerRetrieve(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostContainerRetrieveRequest) (*schemas.BifrostContainerRetrieveResponse, *schemas.BifrostError) {
	unsupported := providerUtils.NewUnsupportedOperationError(schemas.ContainerRetrieveRequest, provider.GetProviderKey())
	return nil, unsupported
}
// ContainerDelete is not supported by the Groq provider and always returns an
// unsupported-operation error.
func (provider *GroqProvider) ContainerDelete(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostContainerDeleteRequest) (*schemas.BifrostContainerDeleteResponse, *schemas.BifrostError) {
	unsupported := providerUtils.NewUnsupportedOperationError(schemas.ContainerDeleteRequest, provider.GetProviderKey())
	return nil, unsupported
}
// ContainerFileCreate is not supported by the Groq provider and always returns
// an unsupported-operation error.
func (provider *GroqProvider) ContainerFileCreate(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostContainerFileCreateRequest) (*schemas.BifrostContainerFileCreateResponse, *schemas.BifrostError) {
	unsupported := providerUtils.NewUnsupportedOperationError(schemas.ContainerFileCreateRequest, provider.GetProviderKey())
	return nil, unsupported
}
// ContainerFileList is not supported by the Groq provider and always returns an
// unsupported-operation error.
func (provider *GroqProvider) ContainerFileList(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostContainerFileListRequest) (*schemas.BifrostContainerFileListResponse, *schemas.BifrostError) {
	unsupported := providerUtils.NewUnsupportedOperationError(schemas.ContainerFileListRequest, provider.GetProviderKey())
	return nil, unsupported
}
// ContainerFileRetrieve is not supported by the Groq provider and always
// returns an unsupported-operation error.
func (provider *GroqProvider) ContainerFileRetrieve(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostContainerFileRetrieveRequest) (*schemas.BifrostContainerFileRetrieveResponse, *schemas.BifrostError) {
	unsupported := providerUtils.NewUnsupportedOperationError(schemas.ContainerFileRetrieveRequest, provider.GetProviderKey())
	return nil, unsupported
}
// ContainerFileContent is not supported by the Groq provider and always returns
// an unsupported-operation error.
func (provider *GroqProvider) ContainerFileContent(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostContainerFileContentRequest) (*schemas.BifrostContainerFileContentResponse, *schemas.BifrostError) {
	unsupported := providerUtils.NewUnsupportedOperationError(schemas.ContainerFileContentRequest, provider.GetProviderKey())
	return nil, unsupported
}
// ContainerFileDelete is not supported by the Groq provider and always returns
// an unsupported-operation error.
func (provider *GroqProvider) ContainerFileDelete(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostContainerFileDeleteRequest) (*schemas.BifrostContainerFileDeleteResponse, *schemas.BifrostError) {
	unsupported := providerUtils.NewUnsupportedOperationError(schemas.ContainerFileDeleteRequest, provider.GetProviderKey())
	return nil, unsupported
}
// Passthrough is not supported by the Groq provider and always returns an
// unsupported-operation error.
func (provider *GroqProvider) Passthrough(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostPassthroughRequest) (*schemas.BifrostPassthroughResponse, *schemas.BifrostError) {
	unsupported := providerUtils.NewUnsupportedOperationError(schemas.PassthroughRequest, provider.GetProviderKey())
	return nil, unsupported
}
// PassthroughStream is not supported by the Groq provider and always returns a
// nil channel together with an unsupported-operation error.
func (provider *GroqProvider) PassthroughStream(_ *schemas.BifrostContext, _ schemas.PostHookRunner, _ func(context.Context), _ schemas.Key, _ *schemas.BifrostPassthroughRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
	unsupported := providerUtils.NewUnsupportedOperationError(schemas.PassthroughStreamRequest, provider.GetProviderKey())
	return nil, unsupported
}

View File

@@ -0,0 +1,68 @@
package groq_test
import (
"os"
"strings"
"testing"
"github.com/maximhq/bifrost/core/internal/llmtests"
"github.com/maximhq/bifrost/core/schemas"
)
// TestGroq runs the comprehensive llmtests suite against the Groq provider.
// The test is skipped when GROQ_API_KEY is unset or blank.
func TestGroq(t *testing.T) {
	t.Parallel()

	if apiKey := strings.TrimSpace(os.Getenv("GROQ_API_KEY")); apiKey == "" {
		t.Skip("Skipping Groq tests because GROQ_API_KEY is not set")
	}

	client, ctx, cancel, setupErr := llmtests.SetupTest()
	if setupErr != nil {
		t.Fatalf("Error initializing test setup: %v", setupErr)
	}
	// Deferred calls run in reverse order: the client shuts down before the
	// context is cancelled.
	defer cancel()
	defer client.Shutdown()

	// Scenarios exercised for Groq; everything not listed keeps its zero value.
	scenarios := llmtests.TestScenarios{
		TextCompletion:             false,
		TextCompletionStream:       false,
		SimpleChat:                 true,
		CompletionStream:           true,
		MultiTurnConversation:      true,
		ToolCalls:                  true,
		ToolCallsStreaming:         true,
		MultipleToolCalls:          true,
		MultipleToolCallsStreaming: true,
		End2EndToolCalling:         true,
		AutomaticFunctionCall:      true,
		ImageURL:                   false,
		ImageBase64:                false,
		MultipleImages:             false,
		FileBase64:                 false, // Not supported
		FileURL:                    false, // Not supported
		CompleteEnd2End:            true,
		Embedding:                  false,
		ListModels:                 true,
		Reasoning:                  true,
		Transcription:              true,
		SpeechSynthesis:            true,
	}

	testConfig := llmtests.ComprehensiveTestConfig{
		Provider:  schemas.Groq,
		ChatModel: "llama-3.3-70b-versatile",
		Fallbacks: []schemas.Fallback{
			{Provider: schemas.Groq, Model: "openai/gpt-oss-120b"},
		},
		TextModel: "llama-3.3-70b-versatile",
		TextCompletionFallbacks: []schemas.Fallback{
			{Provider: schemas.Groq, Model: "openai/gpt-oss-20b"},
		},
		EmbeddingModel:       "", // Groq doesn't support embedding
		ReasoningModel:       "openai/gpt-oss-120b",
		TranscriptionModel:   "whisper-large-v3",
		SpeechSynthesisModel: "canopylabs/orpheus-v1-english",
		Scenarios:            scenarios,
	}

	t.Run("GroqTests", func(t *testing.T) {
		llmtests.RunAllComprehensiveTests(t, client, ctx, testConfig)
	})
}

View File

@@ -0,0 +1,144 @@
// Package huggingface provides a HuggingFace chat provider.
package huggingface
import (
"fmt"
"github.com/bytedance/sonic"
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
schemas "github.com/maximhq/bifrost/core/schemas"
)
// sanitizeMessagesForHuggingFace returns a copy of messages reduced to the
// fields HuggingFace's OpenAI-compatible API accepts: name, role, content and
// the tool-message part. Of the assistant-specific part only ToolCalls is kept
// (and only when there is at least one); fields such as reasoning_details,
// reasoning, annotations, audio and refusal are dropped.
// The input slice is not modified.
func sanitizeMessagesForHuggingFace(messages []schemas.ChatMessage) []schemas.ChatMessage {
	cleaned := make([]schemas.ChatMessage, 0, len(messages))
	for idx := range messages {
		source := &messages[idx]
		message := schemas.ChatMessage{
			Name:            source.Name,
			Role:            source.Role,
			Content:         source.Content,
			ChatToolMessage: source.ChatToolMessage,
		}
		// Rebuild the assistant part so that only the tool calls survive.
		if assistant := source.ChatAssistantMessage; assistant != nil && len(assistant.ToolCalls) > 0 {
			message.ChatAssistantMessage = &schemas.ChatAssistantMessage{
				ToolCalls: assistant.ToolCalls,
			}
		}
		cleaned = append(cleaned, message)
	}
	return cleaned
}
// ToHuggingFaceChatCompletionRequest converts a Bifrost chat request into the
// HuggingFace chat request shape.
//
// Messages are sanitized with sanitizeMessagesForHuggingFace. When Params is
// set, the sampling parameters, response format, stream options, tools, tool
// choice and extra params are mapped onto the HuggingFace request;
// MaxCompletionTokens is sent as MaxTokens.
//
// It returns (nil, nil) when bifrostReq or its Input is nil. An error is
// returned when the response format cannot be converted: a map-shaped format
// whose json_schema cannot be re-encoded, or any other value that cannot be
// converted into a HuggingFaceResponseFormat. The latter used to be dropped
// silently, which sent the request without the response format the caller
// asked for.
func ToHuggingFaceChatCompletionRequest(bifrostReq *schemas.BifrostChatRequest) (*HuggingFaceChatRequest, error) {
	if bifrostReq == nil || bifrostReq.Input == nil {
		return nil, nil
	}

	// Sanitize messages to remove unsupported fields like reasoning_details.
	hfReq := &HuggingFaceChatRequest{
		Messages: sanitizeMessagesForHuggingFace(bifrostReq.Input),
		Model:    bifrostReq.Model,
	}

	params := bifrostReq.Params
	if params == nil {
		return hfReq, nil
	}

	if params.FrequencyPenalty != nil {
		hfReq.FrequencyPenalty = params.FrequencyPenalty
	}
	if params.LogProbs != nil {
		hfReq.Logprobs = params.LogProbs
	}
	if params.MaxCompletionTokens != nil {
		hfReq.MaxTokens = params.MaxCompletionTokens
	}
	if params.PresencePenalty != nil {
		hfReq.PresencePenalty = params.PresencePenalty
	}
	if params.Seed != nil {
		hfReq.Seed = params.Seed
	}
	if len(params.Stop) > 0 {
		hfReq.Stop = params.Stop
	}
	if params.Temperature != nil {
		hfReq.Temperature = params.Temperature
	}
	if params.TopLogProbs != nil {
		hfReq.TopLogprobs = params.TopLogProbs
	}
	if params.TopP != nil {
		hfReq.TopP = params.TopP
	}

	// Response format: a map is read field by field (only json_schema needs a
	// marshal/unmarshal round-trip); any other value is converted via JSON.
	if params.ResponseFormat != nil {
		var hfRF *HuggingFaceResponseFormat
		if rfMap, ok := (*params.ResponseFormat).(map[string]any); ok {
			hfRF = &HuggingFaceResponseFormat{}
			if t, ok := rfMap["type"].(string); ok {
				hfRF.Type = t
			}
			if jsVal, ok := rfMap["json_schema"]; ok {
				jsBytes, err := providerUtils.MarshalSorted(jsVal)
				if err != nil {
					return nil, fmt.Errorf("failed to marshal json_schema: %w", err)
				}
				var hfSchema HuggingFaceJSONSchema
				if err := sonic.Unmarshal(jsBytes, &hfSchema); err != nil {
					return nil, fmt.Errorf("failed to unmarshal json_schema: %w", err)
				}
				hfRF.JSONSchema = &hfSchema
			}
		} else {
			converted, err := schemas.ConvertViaJSON[HuggingFaceResponseFormat](*params.ResponseFormat)
			if err != nil {
				// Surface the failure instead of silently dropping the format.
				return nil, fmt.Errorf("failed to convert response_format: %w", err)
			}
			hfRF = &converted
		}
		hfReq.ResponseFormat = hfRF
	}

	// Only IncludeUsage is forwarded from the stream options.
	if params.StreamOptions != nil {
		hfReq.StreamOptions = &schemas.ChatStreamOptions{
			IncludeUsage: params.StreamOptions.IncludeUsage,
		}
	}

	hfReq.Tools = params.Tools

	// Tool choice: either one of the known enum strings or a specific function.
	// Anything else leaves ToolChoice unset.
	if params.ToolChoice != nil {
		hfToolChoice := &HuggingFaceToolChoice{}
		if params.ToolChoice.ChatToolChoiceStr != nil {
			switch *params.ToolChoice.ChatToolChoiceStr {
			case "auto":
				auto := EnumStringTypeAuto
				hfToolChoice.EnumValue = &auto
			case "none":
				none := EnumStringTypeNone
				hfToolChoice.EnumValue = &none
			case "required":
				required := EnumStringTypeRequired
				hfToolChoice.EnumValue = &required
			}
		} else if params.ToolChoice.ChatToolChoiceStruct != nil {
			choice := params.ToolChoice.ChatToolChoiceStruct
			if choice.Type == schemas.ChatToolChoiceTypeFunction && choice.Function != nil {
				hfToolChoice.Function = &schemas.ChatToolChoiceFunction{
					Name: choice.Function.Name,
				}
			}
		}
		if hfToolChoice.EnumValue != nil || hfToolChoice.Function != nil {
			hfReq.ToolChoice = hfToolChoice
		}
	}

	hfReq.ExtraParams = params.ExtraParams
	return hfReq, nil
}

Some files were not shown because too many files have changed in this diff Show More