first commit
This commit is contained in:
2740
core/providers/anthropic/anthropic.go
Normal file
2740
core/providers/anthropic/anthropic.go
Normal file
File diff suppressed because it is too large
Load Diff
84
core/providers/anthropic/anthropic_test.go
Normal file
84
core/providers/anthropic/anthropic_test.go
Normal file
@@ -0,0 +1,84 @@
|
||||
package anthropic_test
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/maximhq/bifrost/core/internal/llmtests"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
func TestAnthropic(t *testing.T) {
|
||||
t.Parallel()
|
||||
if strings.TrimSpace(os.Getenv("ANTHROPIC_API_KEY")) == "" {
|
||||
t.Skip("Skipping Anthropic tests because ANTHROPIC_API_KEY is not set")
|
||||
}
|
||||
|
||||
client, ctx, cancel, err := llmtests.SetupTest()
|
||||
if err != nil {
|
||||
t.Fatalf("Error initializing test setup: %v", err)
|
||||
}
|
||||
defer cancel()
|
||||
defer client.Shutdown()
|
||||
|
||||
testConfig := llmtests.ComprehensiveTestConfig{
|
||||
Provider: schemas.Anthropic,
|
||||
ChatModel: "claude-sonnet-4-5",
|
||||
Fallbacks: []schemas.Fallback{
|
||||
{Provider: schemas.Anthropic, Model: "claude-3-7-sonnet-20250219"},
|
||||
{Provider: schemas.Anthropic, Model: "claude-sonnet-4-20250514"},
|
||||
},
|
||||
VisionModel: "claude-sonnet-4-5", // Same model supports vision
|
||||
ReasoningModel: "claude-opus-4-5",
|
||||
PromptCachingModel: "claude-sonnet-4-20250514",
|
||||
PassthroughModel: "claude-sonnet-4-5",
|
||||
Scenarios: llmtests.TestScenarios{
|
||||
TextCompletion: false, // Not supported
|
||||
SimpleChat: true,
|
||||
CompletionStream: true,
|
||||
MultiTurnConversation: true,
|
||||
ToolCalls: true,
|
||||
ToolCallsStreaming: true,
|
||||
MultipleToolCalls: true,
|
||||
MultipleToolCallsStreaming: true,
|
||||
End2EndToolCalling: true,
|
||||
AutomaticFunctionCall: true,
|
||||
WebSearchTool: true,
|
||||
ImageURL: true,
|
||||
ImageBase64: true,
|
||||
MultipleImages: true,
|
||||
FileBase64: true,
|
||||
FileURL: true,
|
||||
CompleteEnd2End: true,
|
||||
Embedding: false,
|
||||
Reasoning: true,
|
||||
PromptCaching: true,
|
||||
ListModels: true,
|
||||
BatchCreate: true,
|
||||
BatchList: true,
|
||||
BatchRetrieve: true,
|
||||
BatchCancel: true,
|
||||
BatchResults: true,
|
||||
FileUpload: true,
|
||||
FileList: true,
|
||||
FileRetrieve: true,
|
||||
FileDelete: true,
|
||||
FileContent: false,
|
||||
FileBatchInput: false, // Anthropic batch API only supports inline requests, not file-based input
|
||||
CountTokens: true,
|
||||
StructuredOutputs: true, // Structured outputs with nullable enum support
|
||||
PassthroughAPI: true,
|
||||
Compaction: true,
|
||||
InterleavedThinking: true,
|
||||
FastMode: false, // Enable when test API key has Opus 4.6 access
|
||||
EagerInputStreaming: true, // fine-grained-tool-streaming-2025-05-14 (GA on Anthropic)
|
||||
ServerToolsViaOpenAIEndpoint: true, // web_search / web_fetch / code_execution via /v1/chat/completions
|
||||
},
|
||||
}
|
||||
|
||||
t.Run("AnthropicTests", func(t *testing.T) {
|
||||
llmtests.RunAllComprehensiveTests(t, client, ctx, testConfig)
|
||||
})
|
||||
}
|
||||
359
core/providers/anthropic/batch.go
Normal file
359
core/providers/anthropic/batch.go
Normal file
@@ -0,0 +1,359 @@
|
||||
package anthropic
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// Anthropic Batch API Types
|
||||
|
||||
// AnthropicBatchRequestItem is one entry of a batch submission: a custom ID
// paired with the parameters of that request.
type AnthropicBatchRequestItem struct {
	// CustomID is serialized as "custom_id".
	CustomID string `json:"custom_id"`
	// Params is serialized as "params" and carried as an untyped map.
	Params map[string]any `json:"params"`
}
|
||||
|
||||
// AnthropicBatchCreateRequest represents the request body for creating a batch.
|
||||
type AnthropicBatchCreateRequest struct {
|
||||
Requests []AnthropicBatchRequestItem `json:"requests"`
|
||||
}
|
||||
|
||||
// AnthropicBatchCancelRequest is the request body used to cancel a batch.
type AnthropicBatchCancelRequest struct {
	// BatchID is serialized as "batch_id".
	BatchID string `json:"batch_id"`
}
|
||||
|
||||
// AnthropicBatchRetrieveRequest is the request body used to retrieve a batch.
type AnthropicBatchRetrieveRequest struct {
	// BatchID is serialized as "batch_id".
	BatchID string `json:"batch_id"`
}
|
||||
|
||||
// AnthropicBatchListRequest is the request body used to list batches.
type AnthropicBatchListRequest struct {
	// PageToken is optional and serialized as "page_token" (null when nil).
	PageToken *string `json:"page_token"`
	// PageSize is serialized as "page_size".
	PageSize int `json:"page_size"`
}
|
||||
|
||||
// AnthropicBatchResultsRequest is the request body used to fetch the results
// of a batch.
type AnthropicBatchResultsRequest struct {
	// BatchID is serialized as "batch_id".
	BatchID string `json:"batch_id"`
}
|
||||
|
||||
// AnthropicBatchResponse represents an Anthropic batch response.
|
||||
type AnthropicBatchResponse struct {
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type"`
|
||||
ProcessingStatus string `json:"processing_status"`
|
||||
RequestCounts *AnthropicBatchRequestCounts `json:"request_counts,omitempty"`
|
||||
EndedAt *string `json:"ended_at,omitempty"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
ExpiresAt string `json:"expires_at"`
|
||||
ArchivedAt *string `json:"archived_at,omitempty"`
|
||||
CancelInitiatedAt *string `json:"cancel_initiated_at,omitempty"`
|
||||
ResultsURL *string `json:"results_url,omitempty"`
|
||||
}
|
||||
|
||||
// AnthropicBatchRequestCounts holds the per-state request tallies of a batch.
type AnthropicBatchRequestCounts struct {
	Processing int `json:"processing"` // serialized as "processing"
	Succeeded  int `json:"succeeded"`  // serialized as "succeeded"
	Errored    int `json:"errored"`    // serialized as "errored"
	Canceled   int `json:"canceled"`   // serialized as "canceled"
	Expired    int `json:"expired"`    // serialized as "expired"
}
|
||||
|
||||
// AnthropicBatchListResponse represents the response from listing batches.
|
||||
type AnthropicBatchListResponse struct {
|
||||
Data []AnthropicBatchResponse `json:"data"`
|
||||
HasMore bool `json:"has_more"`
|
||||
FirstID *string `json:"first_id,omitempty"`
|
||||
LastID *string `json:"last_id,omitempty"`
|
||||
}
|
||||
|
||||
// AnthropicBatchResultItem represents a single result from a batch.
|
||||
type AnthropicBatchResultItem struct {
|
||||
CustomID string `json:"custom_id"`
|
||||
Result AnthropicBatchResultData `json:"result"`
|
||||
}
|
||||
|
||||
// AnthropicBatchResultData represents the result data.
|
||||
type AnthropicBatchResultData struct {
|
||||
Type string `json:"type"` // "succeeded", "errored", "expired", "canceled"
|
||||
Message map[string]interface{} `json:"message,omitempty"`
|
||||
Error *AnthropicBatchError `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
// AnthropicBatchError describes an error reported inside batch results.
type AnthropicBatchError struct {
	// Type is serialized as "type".
	Type string `json:"type"`
	// Message is serialized as "message".
	Message string `json:"message"`
}
|
||||
|
||||
// ToBifrostBatchStatus converts Anthropic processing_status to Bifrost status.
|
||||
func ToBifrostBatchStatus(status string) schemas.BatchStatus {
|
||||
switch status {
|
||||
case "in_progress":
|
||||
return schemas.BatchStatusInProgress
|
||||
case "canceling":
|
||||
return schemas.BatchStatusCancelling
|
||||
case "ended":
|
||||
return schemas.BatchStatusEnded
|
||||
default:
|
||||
return schemas.BatchStatus(status)
|
||||
}
|
||||
}
|
||||
|
||||
// parseAnthropicTimestamp converts an RFC 3339 timestamp string into Unix
// seconds. It returns 0 for an empty string or for any value that fails to
// parse.
func parseAnthropicTimestamp(timestamp string) int64 {
	if timestamp != "" {
		if parsed, err := time.Parse(time.RFC3339Nano, timestamp); err == nil {
			return parsed.Unix()
		}
	}
	// Empty and malformed inputs both collapse to zero.
	return 0
}
|
||||
|
||||
// ToBifrostObjectType translates an Anthropic object type into the Bifrost
// object type: "message_batch" becomes "batch", and every other value is
// returned unchanged.
func ToBifrostObjectType(anthropicType string) string {
	if anthropicType == "message_batch" {
		return "batch"
	}
	return anthropicType
}
|
||||
|
||||
// ToBifrostBatchCreateResponse converts Anthropic batch response to Bifrost batch create response.
|
||||
func (r *AnthropicBatchResponse) ToBifrostBatchCreateResponse(latency time.Duration, sendBackRawRequest bool, sendBackRawResponse bool, rawRequest interface{}, rawResponse interface{}) *schemas.BifrostBatchCreateResponse {
|
||||
expiresAt := parseAnthropicTimestamp(r.ExpiresAt)
|
||||
resp := &schemas.BifrostBatchCreateResponse{
|
||||
ID: r.ID,
|
||||
Object: ToBifrostObjectType(r.Type),
|
||||
Status: ToBifrostBatchStatus(r.ProcessingStatus),
|
||||
ProcessingStatus: &r.ProcessingStatus,
|
||||
ResultsURL: r.ResultsURL,
|
||||
CreatedAt: parseAnthropicTimestamp(r.CreatedAt),
|
||||
ExpiresAt: &expiresAt,
|
||||
ExtraFields: schemas.BifrostResponseExtraFields{
|
||||
Latency: latency.Milliseconds(),
|
||||
},
|
||||
}
|
||||
|
||||
if r.RequestCounts != nil {
|
||||
resp.RequestCounts = schemas.BatchRequestCounts{
|
||||
Total: r.RequestCounts.Processing + r.RequestCounts.Succeeded + r.RequestCounts.Errored + r.RequestCounts.Canceled + r.RequestCounts.Expired,
|
||||
Completed: r.RequestCounts.Succeeded,
|
||||
Failed: r.RequestCounts.Errored,
|
||||
Succeeded: r.RequestCounts.Succeeded,
|
||||
Expired: r.RequestCounts.Expired,
|
||||
Canceled: r.RequestCounts.Canceled,
|
||||
Pending: r.RequestCounts.Processing,
|
||||
}
|
||||
}
|
||||
|
||||
if sendBackRawRequest {
|
||||
resp.ExtraFields.RawRequest = rawRequest
|
||||
}
|
||||
|
||||
if sendBackRawResponse {
|
||||
resp.ExtraFields.RawResponse = rawResponse
|
||||
}
|
||||
|
||||
return resp
|
||||
}
|
||||
|
||||
// ToBifrostBatchRetrieveResponse converts Anthropic batch response to Bifrost batch retrieve response.
|
||||
func (r *AnthropicBatchResponse) ToBifrostBatchRetrieveResponse(latency time.Duration, sendBackRawRequest bool, sendBackRawResponse bool, rawRequest interface{}, rawResponse interface{}) *schemas.BifrostBatchRetrieveResponse {
|
||||
resp := &schemas.BifrostBatchRetrieveResponse{
|
||||
ID: r.ID,
|
||||
Object: ToBifrostObjectType(r.Type),
|
||||
Status: ToBifrostBatchStatus(r.ProcessingStatus),
|
||||
ProcessingStatus: &r.ProcessingStatus,
|
||||
ResultsURL: r.ResultsURL,
|
||||
CreatedAt: parseAnthropicTimestamp(r.CreatedAt),
|
||||
ExtraFields: schemas.BifrostResponseExtraFields{
|
||||
Latency: latency.Milliseconds(),
|
||||
},
|
||||
}
|
||||
|
||||
if sendBackRawRequest {
|
||||
resp.ExtraFields.RawRequest = rawRequest
|
||||
}
|
||||
|
||||
expiresAt := parseAnthropicTimestamp(r.ExpiresAt)
|
||||
if expiresAt > 0 {
|
||||
resp.ExpiresAt = &expiresAt
|
||||
}
|
||||
|
||||
if r.EndedAt != nil {
|
||||
endedAt := parseAnthropicTimestamp(*r.EndedAt)
|
||||
resp.CompletedAt = &endedAt
|
||||
}
|
||||
|
||||
if r.ArchivedAt != nil {
|
||||
archivedAt := parseAnthropicTimestamp(*r.ArchivedAt)
|
||||
resp.ArchivedAt = &archivedAt
|
||||
}
|
||||
|
||||
if r.CancelInitiatedAt != nil {
|
||||
cancellingAt := parseAnthropicTimestamp(*r.CancelInitiatedAt)
|
||||
resp.CancellingAt = &cancellingAt
|
||||
}
|
||||
|
||||
if r.RequestCounts != nil {
|
||||
resp.RequestCounts = schemas.BatchRequestCounts{
|
||||
Total: r.RequestCounts.Processing + r.RequestCounts.Succeeded + r.RequestCounts.Errored + r.RequestCounts.Canceled + r.RequestCounts.Expired,
|
||||
Completed: r.RequestCounts.Succeeded,
|
||||
Failed: r.RequestCounts.Errored,
|
||||
Succeeded: r.RequestCounts.Succeeded,
|
||||
Expired: r.RequestCounts.Expired,
|
||||
Canceled: r.RequestCounts.Canceled,
|
||||
Pending: r.RequestCounts.Processing,
|
||||
}
|
||||
}
|
||||
|
||||
if sendBackRawResponse {
|
||||
resp.ExtraFields.RawResponse = rawResponse
|
||||
}
|
||||
|
||||
return resp
|
||||
}
|
||||
|
||||
// ToAnthropicBatchCreateResponse converts a Bifrost batch create response to Anthropic format.
|
||||
func ToAnthropicBatchCreateResponse(resp *schemas.BifrostBatchCreateResponse) *AnthropicBatchResponse {
|
||||
result := &AnthropicBatchResponse{
|
||||
ID: resp.ID,
|
||||
Type: "message_batch",
|
||||
ProcessingStatus: toAnthropicProcessingStatus(resp.Status),
|
||||
CreatedAt: formatAnthropicTimestamp(resp.CreatedAt),
|
||||
ResultsURL: resp.ResultsURL,
|
||||
}
|
||||
if resp.ExpiresAt != nil {
|
||||
result.ExpiresAt = formatAnthropicTimestamp(*resp.ExpiresAt)
|
||||
} else {
|
||||
// This is a fallback for worst case scenario where expires_at is not available
|
||||
// Which is never expected to happen, but just in case.
|
||||
result.ExpiresAt = formatAnthropicTimestamp(time.Now().Add(24 * time.Hour).Unix())
|
||||
}
|
||||
if resp.RequestCounts.Total > 0 {
|
||||
result.RequestCounts = &AnthropicBatchRequestCounts{
|
||||
Processing: resp.RequestCounts.Pending,
|
||||
Succeeded: resp.RequestCounts.Succeeded,
|
||||
Errored: resp.RequestCounts.Failed,
|
||||
Canceled: resp.RequestCounts.Canceled,
|
||||
Expired: resp.RequestCounts.Expired,
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// ToAnthropicBatchListResponse converts a Bifrost batch list response to Anthropic format.
|
||||
func ToAnthropicBatchListResponse(resp *schemas.BifrostBatchListResponse) *AnthropicBatchListResponse {
|
||||
result := &AnthropicBatchListResponse{
|
||||
Data: make([]AnthropicBatchResponse, len(resp.Data)),
|
||||
HasMore: resp.HasMore,
|
||||
FirstID: resp.FirstID,
|
||||
LastID: resp.LastID,
|
||||
}
|
||||
|
||||
for i, batch := range resp.Data {
|
||||
result.Data[i] = *ToAnthropicBatchRetrieveResponse(&batch)
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// ToAnthropicBatchRetrieveResponse converts a Bifrost batch retrieve response to Anthropic format.
|
||||
func ToAnthropicBatchRetrieveResponse(resp *schemas.BifrostBatchRetrieveResponse) *AnthropicBatchResponse {
|
||||
result := &AnthropicBatchResponse{
|
||||
ID: resp.ID,
|
||||
Type: "message_batch",
|
||||
ProcessingStatus: toAnthropicProcessingStatus(resp.Status),
|
||||
CreatedAt: formatAnthropicTimestamp(resp.CreatedAt),
|
||||
ResultsURL: resp.ResultsURL,
|
||||
}
|
||||
|
||||
if resp.ExpiresAt != nil {
|
||||
result.ExpiresAt = formatAnthropicTimestamp(*resp.ExpiresAt)
|
||||
}
|
||||
|
||||
if resp.CompletedAt != nil {
|
||||
endedAt := formatAnthropicTimestamp(*resp.CompletedAt)
|
||||
result.EndedAt = &endedAt
|
||||
}
|
||||
|
||||
if resp.ArchivedAt != nil {
|
||||
archivedAt := formatAnthropicTimestamp(*resp.ArchivedAt)
|
||||
result.ArchivedAt = &archivedAt
|
||||
}
|
||||
|
||||
if resp.CancellingAt != nil {
|
||||
cancelInitiatedAt := formatAnthropicTimestamp(*resp.CancellingAt)
|
||||
result.CancelInitiatedAt = &cancelInitiatedAt
|
||||
}
|
||||
|
||||
if resp.RequestCounts.Total > 0 {
|
||||
result.RequestCounts = &AnthropicBatchRequestCounts{
|
||||
Processing: resp.RequestCounts.Pending,
|
||||
Succeeded: resp.RequestCounts.Succeeded,
|
||||
Errored: resp.RequestCounts.Failed,
|
||||
Canceled: resp.RequestCounts.Canceled,
|
||||
Expired: resp.RequestCounts.Expired,
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// ToAnthropicBatchCancelResponse converts a Bifrost batch cancel response to Anthropic format.
|
||||
func ToAnthropicBatchCancelResponse(resp *schemas.BifrostBatchCancelResponse) *AnthropicBatchResponse {
|
||||
result := &AnthropicBatchResponse{
|
||||
ID: resp.ID,
|
||||
Type: "message_batch",
|
||||
ProcessingStatus: toAnthropicProcessingStatus(resp.Status),
|
||||
}
|
||||
|
||||
if resp.CancellingAt != nil {
|
||||
cancelInitiatedAt := formatAnthropicTimestamp(*resp.CancellingAt)
|
||||
result.CancelInitiatedAt = &cancelInitiatedAt
|
||||
}
|
||||
|
||||
if resp.RequestCounts.Total > 0 {
|
||||
result.RequestCounts = &AnthropicBatchRequestCounts{
|
||||
Processing: resp.RequestCounts.Pending,
|
||||
Succeeded: resp.RequestCounts.Succeeded,
|
||||
Canceled: resp.RequestCounts.Canceled,
|
||||
Expired: resp.RequestCounts.Expired,
|
||||
Errored: resp.RequestCounts.Failed,
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// toAnthropicProcessingStatus converts Bifrost batch status to Anthropic processing_status.
|
||||
func toAnthropicProcessingStatus(status schemas.BatchStatus) string {
|
||||
switch status {
|
||||
case schemas.BatchStatusInProgress:
|
||||
fallthrough
|
||||
case schemas.BatchStatusValidating:
|
||||
return "in_progress"
|
||||
case schemas.BatchStatusCancelling:
|
||||
return "canceling"
|
||||
case schemas.BatchStatusEnded, schemas.BatchStatusCompleted, schemas.BatchStatusCancelled:
|
||||
return "ended"
|
||||
default:
|
||||
return string(status)
|
||||
}
|
||||
}
|
||||
|
||||
// formatAnthropicTimestamp renders Unix seconds as a UTC RFC 3339 timestamp.
// A zero input yields the empty string.
func formatAnthropicTimestamp(unixTime int64) string {
	if unixTime != 0 {
		instant := time.Unix(unixTime, 0).UTC()
		return instant.Format(time.RFC3339)
	}
	return ""
}
|
||||
1388
core/providers/anthropic/chat.go
Normal file
1388
core/providers/anthropic/chat.go
Normal file
File diff suppressed because it is too large
Load Diff
366
core/providers/anthropic/chat_server_tools_test.go
Normal file
366
core/providers/anthropic/chat_server_tools_test.go
Normal file
@@ -0,0 +1,366 @@
|
||||
package anthropic
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/bytedance/sonic"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// TestChatTool_ServerToolRoundTrip verifies that every Anthropic server-tool
|
||||
// variant survives Marshal/Unmarshal through the neutral ChatTool schema.
|
||||
// This locks in the fix for the user-reported bug where a raw JSON tool like
|
||||
// {"type":"web_search_20260209","name":"web_search","max_uses":5} was being
|
||||
// dropped at the neutral-schema layer because ChatTool had no slots for the
|
||||
// server-tool metadata.
|
||||
func TestChatTool_ServerToolRoundTrip(t *testing.T) {
|
||||
five := 5
|
||||
ptrTrue := true
|
||||
w, h := 1280, 800
|
||||
maxChars := 16000
|
||||
maxContent := 32000
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
raw string
|
||||
}{
|
||||
{
|
||||
name: "web_search_20260209",
|
||||
raw: `{"type":"web_search_20260209","name":"web_search","max_uses":5,"allowed_callers":["direct"]}`,
|
||||
},
|
||||
{
|
||||
name: "web_search_with_domains",
|
||||
raw: `{"type":"web_search_20250305","name":"web_search","allowed_domains":["example.com","docs.example.com"]}`,
|
||||
},
|
||||
{
|
||||
name: "web_search_with_user_location",
|
||||
raw: `{"type":"web_search_20250305","name":"web_search","user_location":{"type":"approximate","city":"San Francisco","country":"US","timezone":"America/Los_Angeles"}}`,
|
||||
},
|
||||
{
|
||||
name: "web_fetch_20260309",
|
||||
raw: `{"type":"web_fetch_20260309","name":"web_fetch","max_uses":5,"max_content_tokens":32000,"citations":{"enabled":true},"use_cache":true}`,
|
||||
},
|
||||
{
|
||||
name: "computer_20251124",
|
||||
raw: `{"type":"computer_20251124","name":"computer","display_width_px":1280,"display_height_px":800,"display_number":1,"enable_zoom":true}`,
|
||||
},
|
||||
{
|
||||
name: "text_editor_20250728",
|
||||
raw: `{"type":"text_editor_20250728","name":"str_replace_based_edit_tool","max_characters":16000}`,
|
||||
},
|
||||
{
|
||||
name: "bash_20250124",
|
||||
raw: `{"type":"bash_20250124","name":"bash"}`,
|
||||
},
|
||||
{
|
||||
name: "memory_20250818",
|
||||
raw: `{"type":"memory_20250818","name":"memory"}`,
|
||||
},
|
||||
{
|
||||
name: "code_execution_20250825",
|
||||
raw: `{"type":"code_execution_20250825","name":"code_execution"}`,
|
||||
},
|
||||
{
|
||||
name: "tool_search_tool_bm25",
|
||||
raw: `{"type":"tool_search_tool_bm25","name":"tool_search_tool_bm25"}`,
|
||||
},
|
||||
{
|
||||
name: "mcp_toolset",
|
||||
raw: `{"type":"mcp_toolset","name":"my_mcp","mcp_server_name":"notion","configs":{"search":{"enabled":true}}}`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Variant-specific field assertions. Invoked twice — once after
|
||||
// initial decode, once after round-trip — so that a regression in
|
||||
// MarshalSorted that silently drops any variant-specific field
|
||||
// fails this test instead of sneaking through.
|
||||
assertVariantFields := func(label string, tl schemas.ChatTool) {
|
||||
t.Helper()
|
||||
switch tc.name {
|
||||
case "web_search_20260209":
|
||||
if tl.MaxUses == nil || *tl.MaxUses != five {
|
||||
t.Errorf("%s: MaxUses not preserved, got %v", label, tl.MaxUses)
|
||||
}
|
||||
if len(tl.AllowedCallers) != 1 || tl.AllowedCallers[0] != "direct" {
|
||||
t.Errorf("%s: AllowedCallers not preserved, got %v", label, tl.AllowedCallers)
|
||||
}
|
||||
case "web_fetch_20260309":
|
||||
if tl.MaxContentTokens == nil || *tl.MaxContentTokens != maxContent {
|
||||
t.Errorf("%s: MaxContentTokens not preserved, got %v", label, tl.MaxContentTokens)
|
||||
}
|
||||
if tl.Citations == nil || tl.Citations.Enabled == nil || !*tl.Citations.Enabled {
|
||||
t.Errorf("%s: Citations not preserved, got %v", label, tl.Citations)
|
||||
}
|
||||
if tl.UseCache == nil || !*tl.UseCache {
|
||||
t.Errorf("%s: UseCache not preserved", label)
|
||||
}
|
||||
_ = ptrTrue
|
||||
case "computer_20251124":
|
||||
if tl.DisplayWidthPx == nil || *tl.DisplayWidthPx != w {
|
||||
t.Errorf("%s: DisplayWidthPx not preserved, got %v", label, tl.DisplayWidthPx)
|
||||
}
|
||||
if tl.DisplayHeightPx == nil || *tl.DisplayHeightPx != h {
|
||||
t.Errorf("%s: DisplayHeightPx not preserved, got %v", label, tl.DisplayHeightPx)
|
||||
}
|
||||
case "text_editor_20250728":
|
||||
if tl.MaxCharacters == nil || *tl.MaxCharacters != maxChars {
|
||||
t.Errorf("%s: MaxCharacters not preserved, got %v", label, tl.MaxCharacters)
|
||||
}
|
||||
case "mcp_toolset":
|
||||
if tl.MCPServerName != "notion" {
|
||||
t.Errorf("%s: MCPServerName not preserved, got %q", label, tl.MCPServerName)
|
||||
}
|
||||
if len(tl.Configs) != 1 {
|
||||
t.Errorf("%s: Configs not preserved, got %v", label, tl.Configs)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var tool schemas.ChatTool
|
||||
if err := sonic.Unmarshal([]byte(tc.raw), &tool); err != nil {
|
||||
t.Fatalf("unmarshal failed: %v", err)
|
||||
}
|
||||
if string(tool.Type) == "" {
|
||||
t.Errorf("Type should be preserved, got empty")
|
||||
}
|
||||
if tool.Name == "" {
|
||||
t.Errorf("Name should be preserved, got empty")
|
||||
}
|
||||
assertVariantFields("first decode", tool)
|
||||
|
||||
// Re-marshal and re-decode — all preserved fields should survive round trip.
|
||||
out, err := schemas.MarshalSorted(tool)
|
||||
if err != nil {
|
||||
t.Fatalf("marshal failed: %v", err)
|
||||
}
|
||||
var tool2 schemas.ChatTool
|
||||
if err := sonic.Unmarshal(out, &tool2); err != nil {
|
||||
t.Fatalf("second unmarshal failed: %v\njson: %s", err, string(out))
|
||||
}
|
||||
if tool.Name != tool2.Name || tool.Type != tool2.Type {
|
||||
t.Errorf("round-trip mismatch\n in: %s\n out: %s", tc.raw, string(out))
|
||||
}
|
||||
assertVariantFields("round trip", tool2)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestToAnthropicChatRequest_ServerTools verifies every ChatTool server-tool
|
||||
// shape converts correctly through ToAnthropicChatRequest.
|
||||
func TestToAnthropicChatRequest_ServerTools(t *testing.T) {
|
||||
mk := func(rawTool string) *schemas.BifrostChatRequest {
|
||||
var tool schemas.ChatTool
|
||||
if err := sonic.Unmarshal([]byte(rawTool), &tool); err != nil {
|
||||
t.Fatalf("test setup: %v", err)
|
||||
}
|
||||
return &schemas.BifrostChatRequest{
|
||||
Provider: schemas.Anthropic,
|
||||
Model: "claude-sonnet-4-6",
|
||||
Input: []schemas.ChatMessage{{Role: schemas.ChatMessageRoleUser, Content: &schemas.ChatMessageContent{ContentStr: schemas.Ptr("hi")}}},
|
||||
Params: &schemas.ChatParameters{Tools: []schemas.ChatTool{tool}},
|
||||
}
|
||||
}
|
||||
|
||||
type check struct {
|
||||
expectName string
|
||||
expectType AnthropicToolType
|
||||
expectWebSearch bool
|
||||
expectWebFetch bool
|
||||
expectComputer bool
|
||||
expectTextEditor bool
|
||||
expectMCPToolset bool
|
||||
}
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
raw string
|
||||
want check
|
||||
}{
|
||||
{
|
||||
name: "web_search",
|
||||
raw: `{"type":"web_search_20260209","name":"web_search","max_uses":5}`,
|
||||
want: check{expectName: "web_search", expectType: "web_search_20260209", expectWebSearch: true},
|
||||
},
|
||||
{
|
||||
name: "web_fetch",
|
||||
raw: `{"type":"web_fetch_20260309","name":"web_fetch","max_uses":3,"use_cache":true}`,
|
||||
want: check{expectName: "web_fetch", expectType: "web_fetch_20260309", expectWebFetch: true},
|
||||
},
|
||||
{
|
||||
name: "computer_20251124",
|
||||
raw: `{"type":"computer_20251124","name":"computer","display_width_px":1280,"display_height_px":800}`,
|
||||
want: check{expectName: "computer", expectType: "computer_20251124", expectComputer: true},
|
||||
},
|
||||
{
|
||||
name: "text_editor_20250728",
|
||||
raw: `{"type":"text_editor_20250728","name":"str_replace_based_edit_tool","max_characters":16000}`,
|
||||
want: check{expectName: "str_replace_based_edit_tool", expectType: "text_editor_20250728", expectTextEditor: true},
|
||||
},
|
||||
{
|
||||
name: "bash_20250124",
|
||||
raw: `{"type":"bash_20250124","name":"bash"}`,
|
||||
want: check{expectName: "bash", expectType: "bash_20250124"},
|
||||
},
|
||||
{
|
||||
name: "mcp_toolset",
|
||||
raw: `{"type":"mcp_toolset","name":"notion","mcp_server_name":"notion"}`,
|
||||
want: check{expectMCPToolset: true},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
req := mk(tc.raw)
|
||||
out, err := ToAnthropicChatRequest(nil, req)
|
||||
if err != nil {
|
||||
t.Fatalf("conversion failed: %v", err)
|
||||
}
|
||||
if len(out.Tools) != 1 {
|
||||
t.Fatalf("expected 1 tool, got %d (raw: %s)", len(out.Tools), tc.raw)
|
||||
}
|
||||
at := out.Tools[0]
|
||||
if tc.want.expectMCPToolset {
|
||||
if at.MCPToolset == nil {
|
||||
t.Errorf("expected MCPToolset to be set")
|
||||
}
|
||||
return
|
||||
}
|
||||
if at.Name != tc.want.expectName {
|
||||
t.Errorf("Name: got %q want %q", at.Name, tc.want.expectName)
|
||||
}
|
||||
if at.Type == nil || *at.Type != tc.want.expectType {
|
||||
t.Errorf("Type: got %v want %q", at.Type, tc.want.expectType)
|
||||
}
|
||||
if tc.want.expectWebSearch && at.AnthropicToolWebSearch == nil {
|
||||
t.Errorf("expected AnthropicToolWebSearch populated")
|
||||
}
|
||||
if tc.want.expectWebFetch && at.AnthropicToolWebFetch == nil {
|
||||
t.Errorf("expected AnthropicToolWebFetch populated")
|
||||
}
|
||||
if tc.want.expectComputer && at.AnthropicToolComputerUse == nil {
|
||||
t.Errorf("expected AnthropicToolComputerUse populated")
|
||||
}
|
||||
if tc.want.expectTextEditor && at.AnthropicToolTextEditor == nil {
|
||||
t.Errorf("expected AnthropicToolTextEditor populated")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestToBifrostResponsesRequest_MCPToolsetPreservesAnthropicFlags verifies
|
||||
// that when an Anthropic request carries an mcp_toolset tool with the four
|
||||
// Anthropic-native flags (DeferLoading, AllowedCallers, InputExamples,
|
||||
// EagerInputStreaming), those flags survive the inbound conversion into the
|
||||
// neutral ResponsesTool on the mcp_servers merge path. Before the fix, the
|
||||
// merge path only applied MCP configs (allowlist/cache-control) and dropped
|
||||
// the flags because convertAnthropicToolToBifrost skips mcp_toolset entries.
|
||||
func TestToBifrostResponsesRequest_MCPToolsetPreservesAnthropicFlags(t *testing.T) {
|
||||
toolsetType := "mcp_toolset"
|
||||
_ = toolsetType // shape documentation only; AnthropicTool.Type is pointer-to-enum and left nil for mcp_toolset
|
||||
|
||||
req := &AnthropicMessageRequest{
|
||||
Model: "claude-sonnet-4-6",
|
||||
Tools: []AnthropicTool{
|
||||
{
|
||||
Name: "notion",
|
||||
DeferLoading: schemas.Ptr(true),
|
||||
AllowedCallers: []string{"direct", "agent"},
|
||||
EagerInputStreaming: schemas.Ptr(false),
|
||||
InputExamples: []AnthropicToolInputExample{
|
||||
{Input: json.RawMessage(`{"q":"hello"}`), Description: schemas.Ptr("basic")},
|
||||
},
|
||||
MCPToolset: &AnthropicMCPToolsetTool{
|
||||
Type: "mcp_toolset",
|
||||
MCPServerName: "notion",
|
||||
DefaultConfig: &AnthropicMCPToolsetConfig{Enabled: schemas.Ptr(true)},
|
||||
},
|
||||
},
|
||||
},
|
||||
MCPServers: []AnthropicMCPServerV2{
|
||||
{Type: "url", URL: "https://mcp.example.com", Name: "notion"},
|
||||
},
|
||||
}
|
||||
|
||||
got := req.ToBifrostResponsesRequest(nil)
|
||||
if got == nil || got.Params == nil {
|
||||
t.Fatalf("ToBifrostResponsesRequest returned nil params")
|
||||
}
|
||||
|
||||
// The mcp_toolset tool should have been dropped by convertAnthropicToolToBifrost
|
||||
// and re-created on the mcp_servers merge path — end result: exactly one tool,
|
||||
// of type mcp, carrying the Anthropic flags we set.
|
||||
if len(got.Params.Tools) != 1 {
|
||||
t.Fatalf("expected 1 mcp tool after merge, got %d", len(got.Params.Tools))
|
||||
}
|
||||
mcp := got.Params.Tools[0]
|
||||
if mcp.Type != schemas.ResponsesToolTypeMCP {
|
||||
t.Errorf("expected MCP tool, got type=%q", mcp.Type)
|
||||
}
|
||||
if mcp.DeferLoading == nil || !*mcp.DeferLoading {
|
||||
t.Errorf("DeferLoading dropped on mcp_toolset merge path")
|
||||
}
|
||||
if len(mcp.AllowedCallers) != 2 || mcp.AllowedCallers[0] != "direct" {
|
||||
t.Errorf("AllowedCallers dropped on mcp_toolset merge path, got %v", mcp.AllowedCallers)
|
||||
}
|
||||
if len(mcp.InputExamples) != 1 {
|
||||
t.Errorf("InputExamples dropped on mcp_toolset merge path, got len=%d", len(mcp.InputExamples))
|
||||
}
|
||||
if mcp.EagerInputStreaming == nil || *mcp.EagerInputStreaming {
|
||||
t.Errorf("EagerInputStreaming dropped on mcp_toolset merge path, got %v", mcp.EagerInputStreaming)
|
||||
}
|
||||
}
|
||||
|
||||
// TestToAnthropicChatRequest_ServerTools_ReproUserBug is the exact shape
|
||||
// from the reported curl — web_search_20260209 with max_uses + allowed_callers.
|
||||
// Verifies the request reaches ToAnthropicChatRequest output with a populated
|
||||
// tools array (previously it was silently dropped).
|
||||
func TestToAnthropicChatRequest_ServerTools_ReproUserBug(t *testing.T) {
|
||||
raw := []byte(`{
|
||||
"model":"claude-sonnet-4-6",
|
||||
"messages":[{"role":"user","content":"What is the weather in SF?"}],
|
||||
"tools":[{"name":"web_search","type":"web_search_20260209","max_uses":5,"allowed_callers":["direct"]}]
|
||||
}`)
|
||||
// Unmarshal through the neutral schema the way the OpenAI endpoint does.
|
||||
var inner struct {
|
||||
Model string `json:"model"`
|
||||
Messages []json.RawMessage `json:"messages"`
|
||||
Tools []schemas.ChatTool `json:"tools"`
|
||||
}
|
||||
if err := sonic.Unmarshal(raw, &inner); err != nil {
|
||||
t.Fatalf("outer unmarshal: %v", err)
|
||||
}
|
||||
if len(inner.Tools) != 1 {
|
||||
t.Fatalf("setup: expected 1 tool in raw JSON, got %d", len(inner.Tools))
|
||||
}
|
||||
if inner.Tools[0].Name == "" {
|
||||
t.Errorf("Name lost at neutral-schema decode (was the bug). Got: %+v", inner.Tools[0])
|
||||
}
|
||||
if inner.Tools[0].MaxUses == nil {
|
||||
t.Errorf("MaxUses lost at neutral-schema decode (was the bug)")
|
||||
}
|
||||
|
||||
req := &schemas.BifrostChatRequest{
|
||||
Provider: schemas.Anthropic,
|
||||
Model: inner.Model,
|
||||
Input: []schemas.ChatMessage{{Role: schemas.ChatMessageRoleUser, Content: &schemas.ChatMessageContent{ContentStr: schemas.Ptr("hi")}}},
|
||||
Params: &schemas.ChatParameters{Tools: inner.Tools},
|
||||
}
|
||||
out, err := ToAnthropicChatRequest(nil, req)
|
||||
if err != nil {
|
||||
t.Fatalf("conversion failed: %v", err)
|
||||
}
|
||||
if len(out.Tools) != 1 {
|
||||
t.Fatalf("repro bug: expected 1 tool after conversion, got %d (tools array was empty — this was the bug)", len(out.Tools))
|
||||
}
|
||||
if out.Tools[0].Name != "web_search" {
|
||||
t.Errorf("tool Name: got %q, want %q", out.Tools[0].Name, "web_search")
|
||||
}
|
||||
if out.Tools[0].AnthropicToolWebSearch == nil ||
|
||||
out.Tools[0].AnthropicToolWebSearch.MaxUses == nil ||
|
||||
*out.Tools[0].AnthropicToolWebSearch.MaxUses != 5 {
|
||||
t.Errorf("tool max_uses lost: %+v", out.Tools[0])
|
||||
}
|
||||
}
|
||||
752
core/providers/anthropic/chat_test.go
Normal file
752
core/providers/anthropic/chat_test.go
Normal file
@@ -0,0 +1,752 @@
|
||||
package anthropic
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
func TestToAnthropicChatRequest_PreservesPropertyOrder(t *testing.T) {
|
||||
params := &schemas.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: schemas.NewOrderedMapFromPairs(
|
||||
schemas.KV("chain_of_thought", schemas.NewOrderedMapFromPairs(
|
||||
schemas.KV("type", "string"),
|
||||
schemas.KV("description", "Reasoning steps"),
|
||||
)),
|
||||
schemas.KV("answer", schemas.NewOrderedMapFromPairs(
|
||||
schemas.KV("type", "string"),
|
||||
schemas.KV("description", "The answer"),
|
||||
)),
|
||||
schemas.KV("citations", schemas.NewOrderedMapFromPairs(
|
||||
schemas.KV("type", "array"),
|
||||
)),
|
||||
schemas.KV("is_unanswered", schemas.NewOrderedMapFromPairs(
|
||||
schemas.KV("type", "boolean"),
|
||||
)),
|
||||
),
|
||||
Required: []string{"answer", "is_unanswered"},
|
||||
}
|
||||
|
||||
bifrostReq := &schemas.BifrostChatRequest{
|
||||
Provider: schemas.Anthropic,
|
||||
Model: "claude-sonnet-4-20250514",
|
||||
Input: []schemas.ChatMessage{{
|
||||
Role: schemas.ChatMessageRoleUser,
|
||||
Content: &schemas.ChatMessageContent{ContentStr: schemas.Ptr("test")},
|
||||
}},
|
||||
Params: &schemas.ChatParameters{
|
||||
Tools: []schemas.ChatTool{{
|
||||
Type: schemas.ChatToolTypeFunction,
|
||||
Function: &schemas.ChatToolFunction{
|
||||
Name: "AnswerResponseModel",
|
||||
Description: schemas.Ptr("Extract answer"),
|
||||
Parameters: params,
|
||||
},
|
||||
}},
|
||||
},
|
||||
}
|
||||
|
||||
ctx, cancel := schemas.NewBifrostContextWithCancel(nil)
|
||||
defer cancel()
|
||||
result, err := ToAnthropicChatRequest(ctx, bifrostReq)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if len(result.Tools) == 0 {
|
||||
t.Fatal("expected at least one tool")
|
||||
}
|
||||
|
||||
inputSchema := result.Tools[0].InputSchema
|
||||
if inputSchema == nil {
|
||||
t.Fatal("expected InputSchema to be non-nil")
|
||||
}
|
||||
|
||||
// CoT: property order preserved
|
||||
keys := inputSchema.Properties.Keys()
|
||||
expected := []string{"chain_of_thought", "answer", "citations", "is_unanswered"}
|
||||
if len(keys) != len(expected) {
|
||||
t.Fatalf("expected %d properties, got %d: %v", len(expected), len(keys), keys)
|
||||
}
|
||||
for i, k := range expected {
|
||||
if keys[i] != k {
|
||||
t.Errorf("property %d: expected %q, got %q (full order: %v)", i, k, keys[i], keys)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestToAnthropicChatRequest_CachingDeterminism(t *testing.T) {
|
||||
makeReq := func(props *schemas.OrderedMap) *schemas.BifrostChatRequest {
|
||||
return &schemas.BifrostChatRequest{
|
||||
Provider: schemas.Anthropic,
|
||||
Model: "claude-sonnet-4-20250514",
|
||||
Input: []schemas.ChatMessage{{
|
||||
Role: schemas.ChatMessageRoleUser,
|
||||
Content: &schemas.ChatMessageContent{ContentStr: new("test")},
|
||||
}},
|
||||
Params: &schemas.ChatParameters{
|
||||
Tools: []schemas.ChatTool{{
|
||||
Type: schemas.ChatToolTypeFunction,
|
||||
Function: &schemas.ChatToolFunction{
|
||||
Name: "test",
|
||||
Parameters: &schemas.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: props,
|
||||
},
|
||||
},
|
||||
}},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Version A: type before description
|
||||
propsA := schemas.NewOrderedMapFromPairs(
|
||||
schemas.KV("reasoning", schemas.NewOrderedMapFromPairs(
|
||||
schemas.KV("type", "string"),
|
||||
schemas.KV("description", "Step by step"),
|
||||
)),
|
||||
schemas.KV("answer", schemas.NewOrderedMapFromPairs(
|
||||
schemas.KV("type", "string"),
|
||||
schemas.KV("description", "Final answer"),
|
||||
)),
|
||||
)
|
||||
|
||||
// Version B: description before type
|
||||
propsB := schemas.NewOrderedMapFromPairs(
|
||||
schemas.KV("reasoning", schemas.NewOrderedMapFromPairs(
|
||||
schemas.KV("description", "Step by step"),
|
||||
schemas.KV("type", "string"),
|
||||
)),
|
||||
schemas.KV("answer", schemas.NewOrderedMapFromPairs(
|
||||
schemas.KV("description", "Final answer"),
|
||||
schemas.KV("type", "string"),
|
||||
)),
|
||||
)
|
||||
|
||||
ctx, cancel := schemas.NewBifrostContextWithCancel(nil)
|
||||
defer cancel()
|
||||
resultA, err := ToAnthropicChatRequest(ctx, makeReq(propsA))
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
resultB, err := ToAnthropicChatRequest(ctx, makeReq(propsB))
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
jsonA, err := schemas.Marshal(resultA.Tools[0].InputSchema)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to marshal params A: %v", err)
|
||||
}
|
||||
jsonB, err := schemas.Marshal(resultB.Tools[0].InputSchema)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to marshal params B: %v", err)
|
||||
}
|
||||
|
||||
if string(jsonA) != string(jsonB) {
|
||||
t.Errorf("caching broken: same schema produced different JSON\nA: %s\nB: %s", jsonA, jsonB)
|
||||
}
|
||||
}
|
||||
|
||||
func TestToAnthropicChatRequest_NestedProperties_Preserved(t *testing.T) {
|
||||
params := &schemas.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: schemas.NewOrderedMapFromPairs(
|
||||
schemas.KV("output", schemas.NewOrderedMapFromPairs(
|
||||
schemas.KV("type", "object"),
|
||||
schemas.KV("properties", schemas.NewOrderedMapFromPairs(
|
||||
schemas.KV("verdict", schemas.NewOrderedMapFromPairs(schemas.KV("type", "string"))),
|
||||
schemas.KV("score", schemas.NewOrderedMapFromPairs(schemas.KV("type", "number"))),
|
||||
schemas.KV("explanation", schemas.NewOrderedMapFromPairs(schemas.KV("type", "string"))),
|
||||
)),
|
||||
)),
|
||||
schemas.KV("reasoning", schemas.NewOrderedMapFromPairs(schemas.KV("type", "string"))),
|
||||
),
|
||||
}
|
||||
|
||||
bifrostReq := &schemas.BifrostChatRequest{
|
||||
Provider: schemas.Anthropic,
|
||||
Model: "claude-sonnet-4-20250514",
|
||||
Input: []schemas.ChatMessage{{
|
||||
Role: schemas.ChatMessageRoleUser,
|
||||
Content: &schemas.ChatMessageContent{ContentStr: schemas.Ptr("test")},
|
||||
}},
|
||||
Params: &schemas.ChatParameters{
|
||||
Tools: []schemas.ChatTool{{
|
||||
Type: schemas.ChatToolTypeFunction,
|
||||
Function: &schemas.ChatToolFunction{
|
||||
Name: "nested_tool",
|
||||
Parameters: params,
|
||||
},
|
||||
}},
|
||||
},
|
||||
}
|
||||
|
||||
ctx, cancel := schemas.NewBifrostContextWithCancel(nil)
|
||||
defer cancel()
|
||||
result, err := ToAnthropicChatRequest(ctx, bifrostReq)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if len(result.Tools) == 0 {
|
||||
t.Fatal("expected at least one tool")
|
||||
}
|
||||
inputSchema := result.Tools[0].InputSchema
|
||||
|
||||
// CoT: top-level property order preserved
|
||||
keys := inputSchema.Properties.Keys()
|
||||
if len(keys) != 2 || keys[0] != "output" || keys[1] != "reasoning" {
|
||||
t.Errorf("expected top-level property order [output, reasoning], got %v", keys)
|
||||
}
|
||||
|
||||
// CoT: nested property order preserved
|
||||
output, ok := inputSchema.Properties.Get("output")
|
||||
if !ok {
|
||||
t.Fatal("expected output property")
|
||||
}
|
||||
outputOM, ok := output.(*schemas.OrderedMap)
|
||||
if !ok {
|
||||
t.Fatalf("expected output to be *schemas.OrderedMap, got %T", output)
|
||||
}
|
||||
nestedProps, ok := outputOM.Get("properties")
|
||||
if !ok {
|
||||
t.Fatal("expected nested properties in output")
|
||||
}
|
||||
nestedPropsOM, ok := nestedProps.(*schemas.OrderedMap)
|
||||
if !ok {
|
||||
t.Fatalf("expected nested properties to be *schemas.OrderedMap, got %T", nestedProps)
|
||||
}
|
||||
nestedKeys := nestedPropsOM.Keys()
|
||||
if len(nestedKeys) != 3 || nestedKeys[0] != "verdict" || nestedKeys[1] != "score" || nestedKeys[2] != "explanation" {
|
||||
t.Errorf("expected nested property order [verdict, score, explanation], got %v", nestedKeys)
|
||||
}
|
||||
}
|
||||
|
||||
// TestToAnthropicChatRequest_ToolInputKeyOrderPreservation verifies that tool_use input
|
||||
// arguments preserve the client's original key ordering after conversion to Anthropic format.
|
||||
// This is critical for prompt caching, which relies on exact byte-for-byte prefix matching.
|
||||
// The test uses multiple parallel tool calls in a single assistant message — each with
|
||||
// a different key ordering — matching real-world Claude Code usage patterns.
|
||||
func TestToAnthropicChatRequest_ToolInputKeyOrderPreservation(t *testing.T) {
|
||||
bifrostReq := &schemas.BifrostChatRequest{
|
||||
Provider: schemas.Anthropic,
|
||||
Model: "claude-sonnet-4-20250514",
|
||||
Input: []schemas.ChatMessage{
|
||||
{
|
||||
Role: schemas.ChatMessageRoleUser,
|
||||
Content: &schemas.ChatMessageContent{ContentStr: schemas.Ptr("test")},
|
||||
},
|
||||
{
|
||||
// Multiple parallel tool calls with different key orderings per block
|
||||
Role: schemas.ChatMessageRoleAssistant,
|
||||
ChatAssistantMessage: &schemas.ChatAssistantMessage{
|
||||
ToolCalls: []schemas.ChatAssistantMessageToolCall{
|
||||
{
|
||||
Index: 0,
|
||||
Type: schemas.Ptr("function"),
|
||||
ID: schemas.Ptr("toolu_vrtx_013t7gabfKz98BKpdwrnS6LP"),
|
||||
Function: schemas.ChatAssistantMessageToolCallFunction{
|
||||
Name: schemas.Ptr("bash"),
|
||||
Arguments: `{"description":"Find references to auth_injector quickly","timeout":30000,"command":"grep -r \"auth_injector\" . --include=\"Makefile\" -l 2>/dev/null"}`,
|
||||
},
|
||||
},
|
||||
{
|
||||
Index: 1,
|
||||
Type: schemas.Ptr("function"),
|
||||
ID: schemas.Ptr("toolu_vrtx_01K2kr3wi7M4RriLgE7Kq3vJ"),
|
||||
Function: schemas.ChatAssistantMessageToolCallFunction{
|
||||
Name: schemas.Ptr("bash"),
|
||||
Arguments: `{"command":"git diff main...HEAD --stat","description":"Show diff of commits in branch"}`,
|
||||
},
|
||||
},
|
||||
{
|
||||
Index: 2,
|
||||
Type: schemas.Ptr("function"),
|
||||
ID: schemas.Ptr("toolu_vrtx_01D1mMkcvpfqGrEhkcxUQpGc"),
|
||||
Function: schemas.ChatAssistantMessageToolCallFunction{
|
||||
Name: schemas.Ptr("bash"),
|
||||
Arguments: `{"command":"git log main..HEAD --format=\"%H %s\" | head -20","description":"Show detailed commits in branch"}`,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
ctx, cancel := schemas.NewBifrostContextWithCancel(nil)
|
||||
defer cancel()
|
||||
result, err := ToAnthropicChatRequest(ctx, bifrostReq)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// Collect all tool_use content blocks
|
||||
var toolUseBlocks []AnthropicContentBlock
|
||||
for _, msg := range result.Messages {
|
||||
for _, block := range msg.Content.ContentBlocks {
|
||||
if block.Type == AnthropicContentBlockTypeToolUse {
|
||||
toolUseBlocks = append(toolUseBlocks, block)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(toolUseBlocks) != 3 {
|
||||
t.Fatalf("expected 3 tool_use blocks, got %d", len(toolUseBlocks))
|
||||
}
|
||||
|
||||
// Block 0: keys should be description, timeout, command (NOT alphabetical)
|
||||
json0, _ := json.Marshal(toolUseBlocks[0].Input)
|
||||
s0 := string(json0)
|
||||
descIdx0 := strings.Index(s0, `"description"`)
|
||||
timeIdx0 := strings.Index(s0, `"timeout"`)
|
||||
cmdIdx0 := strings.Index(s0, `"command"`)
|
||||
if descIdx0 < 0 || timeIdx0 < 0 || cmdIdx0 < 0 {
|
||||
t.Fatalf("block 0: missing expected key(s) in: %s", s0)
|
||||
}
|
||||
if !(descIdx0 < timeIdx0 && timeIdx0 < cmdIdx0) {
|
||||
t.Errorf("block 0: key order not preserved, expected description < timeout < command in: %s", s0)
|
||||
}
|
||||
|
||||
// Block 1: keys should be command, description (NOT alphabetical)
|
||||
json1, _ := json.Marshal(toolUseBlocks[1].Input)
|
||||
s1 := string(json1)
|
||||
cmdIdx1 := strings.Index(s1, `"command"`)
|
||||
descIdx1 := strings.Index(s1, `"description"`)
|
||||
if cmdIdx1 < 0 || descIdx1 < 0 {
|
||||
t.Fatalf("block 1: missing expected key(s) in: %s", s1)
|
||||
}
|
||||
if !(cmdIdx1 < descIdx1) {
|
||||
t.Errorf("block 1: key order not preserved, expected command < description in: %s", s1)
|
||||
}
|
||||
|
||||
// Block 2: keys should be command, description (same as block 1)
|
||||
json2, _ := json.Marshal(toolUseBlocks[2].Input)
|
||||
s2 := string(json2)
|
||||
cmdIdx2 := strings.Index(s2, `"command"`)
|
||||
descIdx2 := strings.Index(s2, `"description"`)
|
||||
if cmdIdx2 < 0 || descIdx2 < 0 {
|
||||
t.Fatalf("block 2: missing expected key(s) in: %s", s2)
|
||||
}
|
||||
if !(cmdIdx2 < descIdx2) {
|
||||
t.Errorf("block 2: key order not preserved, expected command < description in: %s", s2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestToBifrostChatResponse_MultipleTextBlocksWithThinking(t *testing.T) {
|
||||
thinkingText := "Let me reason step by step about this problem."
|
||||
textBlock1 := "The answer is 42."
|
||||
textBlock2 := "Here is why that is the case."
|
||||
signature := "sig_abc123"
|
||||
|
||||
response := &AnthropicMessageResponse{
|
||||
ID: "msg_test123",
|
||||
Type: "message",
|
||||
Role: "assistant",
|
||||
Model: "claude-opus-4-6-20250514",
|
||||
Content: []AnthropicContentBlock{
|
||||
{
|
||||
Type: AnthropicContentBlockTypeThinking,
|
||||
Thinking: &thinkingText,
|
||||
Signature: &signature,
|
||||
},
|
||||
{
|
||||
Type: AnthropicContentBlockTypeText,
|
||||
Text: &textBlock1,
|
||||
},
|
||||
{
|
||||
Type: AnthropicContentBlockTypeText,
|
||||
Text: &textBlock2,
|
||||
},
|
||||
},
|
||||
StopReason: "end_turn",
|
||||
Usage: &AnthropicUsage{
|
||||
InputTokens: 100,
|
||||
OutputTokens: 50,
|
||||
},
|
||||
}
|
||||
|
||||
ctx, cancel := schemas.NewBifrostContextWithCancel(nil)
|
||||
defer cancel()
|
||||
result := response.ToBifrostChatResponse(ctx)
|
||||
|
||||
if result == nil {
|
||||
t.Fatal("expected non-nil result")
|
||||
}
|
||||
|
||||
// With multiple text blocks, ToBifrostChatResponse preserves them as ContentBlocks
|
||||
// (only a single text block collapses to ContentStr — see chat.go:812-815).
|
||||
// Thinking flows through ReasoningDetails below, not ContentStr.
|
||||
choice := result.Choices[0]
|
||||
msg := choice.ChatNonStreamResponseChoice.Message
|
||||
if msg.Content.ContentStr != nil {
|
||||
t.Errorf("expected ContentStr to be nil with multiple text blocks, got %q", *msg.Content.ContentStr)
|
||||
}
|
||||
if len(msg.Content.ContentBlocks) != 2 {
|
||||
t.Fatalf("expected 2 content blocks (one per text block), got %d", len(msg.Content.ContentBlocks))
|
||||
}
|
||||
if msg.Content.ContentBlocks[0].Text == nil || *msg.Content.ContentBlocks[0].Text != textBlock1 {
|
||||
t.Errorf("block 0 text mismatch: got %v, want %q", msg.Content.ContentBlocks[0].Text, textBlock1)
|
||||
}
|
||||
if msg.Content.ContentBlocks[1].Text == nil || *msg.Content.ContentBlocks[1].Text != textBlock2 {
|
||||
t.Errorf("block 1 text mismatch: got %v, want %q", msg.Content.ContentBlocks[1].Text, textBlock2)
|
||||
}
|
||||
|
||||
// Thinking is surfaced via ReasoningDetails with the signature preserved
|
||||
// (see chat.go:798-807).
|
||||
if msg.ChatAssistantMessage == nil {
|
||||
t.Fatal("expected ChatAssistantMessage to be non-nil")
|
||||
}
|
||||
rd := msg.ChatAssistantMessage.ReasoningDetails
|
||||
if len(rd) != 1 {
|
||||
t.Fatalf("expected 1 reasoning details entry (the thinking block), got %d", len(rd))
|
||||
}
|
||||
if rd[0].Type != schemas.BifrostReasoningDetailsTypeText {
|
||||
t.Errorf("expected reasoning detail type %s, got %s", schemas.BifrostReasoningDetailsTypeText, rd[0].Type)
|
||||
}
|
||||
if rd[0].Signature == nil || *rd[0].Signature != signature {
|
||||
t.Error("expected thinking signature to be preserved on reasoning detail")
|
||||
}
|
||||
if rd[0].Text == nil || *rd[0].Text != thinkingText {
|
||||
t.Errorf("expected reasoning text to match thinking text")
|
||||
}
|
||||
}
|
||||
|
||||
func TestToBifrostChatResponse_SingleTextBlockNoThinking(t *testing.T) {
|
||||
// Verify existing behavior: single text block without thinking collapses to string
|
||||
text := "Simple response"
|
||||
response := &AnthropicMessageResponse{
|
||||
ID: "msg_simple",
|
||||
Type: "message",
|
||||
Role: "assistant",
|
||||
Model: "claude-sonnet-4-6-20250514",
|
||||
Content: []AnthropicContentBlock{
|
||||
{Type: AnthropicContentBlockTypeText, Text: &text},
|
||||
},
|
||||
StopReason: "end_turn",
|
||||
Usage: &AnthropicUsage{InputTokens: 10, OutputTokens: 5},
|
||||
}
|
||||
|
||||
ctx, cancel := schemas.NewBifrostContextWithCancel(nil)
|
||||
defer cancel()
|
||||
result := response.ToBifrostChatResponse(ctx)
|
||||
|
||||
msg := result.Choices[0].ChatNonStreamResponseChoice.Message
|
||||
if msg.Content.ContentStr == nil || *msg.Content.ContentStr != text {
|
||||
t.Error("expected ContentStr to be the text")
|
||||
}
|
||||
if msg.Content.ContentBlocks != nil {
|
||||
t.Error("expected ContentBlocks to be nil")
|
||||
}
|
||||
// No reasoning details for plain text
|
||||
if msg.ChatAssistantMessage != nil && len(msg.ChatAssistantMessage.ReasoningDetails) > 0 {
|
||||
t.Error("expected no reasoning details for single text block without thinking")
|
||||
}
|
||||
}
|
||||
|
||||
func TestToAnthropicChatRequest_BoundaryMismatchFallback(t *testing.T) {
|
||||
// If content was modified by the client, boundaries won't match — fall back to single text block
|
||||
signature := "sig_fallback"
|
||||
modifiedContent := "The user edited this content"
|
||||
|
||||
bifrostReq := &schemas.BifrostChatRequest{
|
||||
Provider: schemas.Anthropic,
|
||||
Model: "claude-opus-4-6-20250514",
|
||||
Input: []schemas.ChatMessage{
|
||||
{
|
||||
Role: schemas.ChatMessageRoleUser,
|
||||
Content: &schemas.ChatMessageContent{ContentStr: schemas.Ptr("Hi")},
|
||||
},
|
||||
{
|
||||
Role: schemas.ChatMessageRoleAssistant,
|
||||
Content: &schemas.ChatMessageContent{ContentStr: &modifiedContent},
|
||||
ChatAssistantMessage: &schemas.ChatAssistantMessage{
|
||||
ReasoningDetails: []schemas.ChatReasoningDetails{
|
||||
{Index: 0, Type: schemas.BifrostReasoningDetailsTypeText, Text: &modifiedContent, Signature: &signature},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Role: schemas.ChatMessageRoleUser,
|
||||
Content: &schemas.ChatMessageContent{ContentStr: schemas.Ptr("Continue")},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
ctx, cancel := schemas.NewBifrostContextWithCancel(nil)
|
||||
defer cancel()
|
||||
result, err := ToAnthropicChatRequest(ctx, bifrostReq)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
var assistantMsg *AnthropicMessage
|
||||
for i := range result.Messages {
|
||||
if result.Messages[i].Role == "assistant" {
|
||||
assistantMsg = &result.Messages[i]
|
||||
break
|
||||
}
|
||||
}
|
||||
if assistantMsg == nil {
|
||||
t.Fatal("expected assistant message")
|
||||
}
|
||||
|
||||
// Should have thinking block (from reasoning_details with signature) + single text fallback
|
||||
blocks := assistantMsg.Content.ContentBlocks
|
||||
// First block: thinking (from reasoning_details, text is nil since it was cleared)
|
||||
// Plus: fallback single text block with the full modified content
|
||||
foundText := false
|
||||
for _, block := range blocks {
|
||||
if block.Type == AnthropicContentBlockTypeText {
|
||||
if block.Text != nil && *block.Text == modifiedContent {
|
||||
foundText = true
|
||||
}
|
||||
}
|
||||
}
|
||||
if !foundText {
|
||||
t.Error("expected fallback to single text block with full content")
|
||||
}
|
||||
}
|
||||
|
||||
func TestToAnthropicChatRequest_NormalFlowUnchanged(t *testing.T) {
|
||||
// Verify that the normal multi-turn flow (reasoning_details with text + signature,
|
||||
// no bifrost.content_blocks) produces the same output as before.
|
||||
thinkingText := "I need to think about this carefully"
|
||||
signature := "sig_normal"
|
||||
responseText := "Here is my answer"
|
||||
|
||||
bifrostReq := &schemas.BifrostChatRequest{
|
||||
Provider: schemas.Anthropic,
|
||||
Model: "claude-opus-4-6-20250514",
|
||||
Input: []schemas.ChatMessage{
|
||||
{
|
||||
Role: schemas.ChatMessageRoleUser,
|
||||
Content: &schemas.ChatMessageContent{ContentStr: schemas.Ptr("What is 2+2?")},
|
||||
},
|
||||
{
|
||||
Role: schemas.ChatMessageRoleAssistant,
|
||||
Content: &schemas.ChatMessageContent{ContentStr: &responseText},
|
||||
ChatAssistantMessage: &schemas.ChatAssistantMessage{
|
||||
ReasoningDetails: []schemas.ChatReasoningDetails{
|
||||
{
|
||||
Index: 0,
|
||||
Type: schemas.BifrostReasoningDetailsTypeText,
|
||||
Text: &thinkingText,
|
||||
Signature: &signature,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Role: schemas.ChatMessageRoleUser,
|
||||
Content: &schemas.ChatMessageContent{ContentStr: schemas.Ptr("Are you sure?")},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
ctx, cancel := schemas.NewBifrostContextWithCancel(nil)
|
||||
defer cancel()
|
||||
result, err := ToAnthropicChatRequest(ctx, bifrostReq)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
var assistantMsg *AnthropicMessage
|
||||
for i := range result.Messages {
|
||||
if result.Messages[i].Role == "assistant" {
|
||||
assistantMsg = &result.Messages[i]
|
||||
break
|
||||
}
|
||||
}
|
||||
if assistantMsg == nil {
|
||||
t.Fatal("expected assistant message")
|
||||
}
|
||||
|
||||
blocks := assistantMsg.Content.ContentBlocks
|
||||
if len(blocks) != 2 {
|
||||
t.Fatalf("expected 2 content blocks (thinking + text), got %d", len(blocks))
|
||||
}
|
||||
|
||||
// Block 0: thinking with original text and signature
|
||||
if blocks[0].Type != AnthropicContentBlockTypeThinking {
|
||||
t.Errorf("block 0: expected thinking, got %s", blocks[0].Type)
|
||||
}
|
||||
if blocks[0].Thinking == nil || *blocks[0].Thinking != thinkingText {
|
||||
t.Errorf("block 0: expected thinking text %q, got %v", thinkingText, blocks[0].Thinking)
|
||||
}
|
||||
if blocks[0].Signature == nil || *blocks[0].Signature != signature {
|
||||
t.Errorf("block 0: expected signature %q, got %v", signature, blocks[0].Signature)
|
||||
}
|
||||
|
||||
// Block 1: text with response
|
||||
if blocks[1].Type != AnthropicContentBlockTypeText {
|
||||
t.Errorf("block 1: expected text, got %s", blocks[1].Type)
|
||||
}
|
||||
if blocks[1].Text == nil || *blocks[1].Text != responseText {
|
||||
t.Errorf("block 1: expected text %q, got %v", responseText, blocks[1].Text)
|
||||
}
|
||||
}
|
||||
|
||||
func TestToAnthropicChatRequest_Opus47_StripsTemperatureTopPTopK(t *testing.T) {
|
||||
temp := 0.7
|
||||
topP := 0.9
|
||||
|
||||
bifrostReq := &schemas.BifrostChatRequest{
|
||||
Provider: schemas.Anthropic,
|
||||
Model: "claude-opus-4-7-20260401",
|
||||
Input: []schemas.ChatMessage{
|
||||
{Role: schemas.ChatMessageRoleUser, Content: &schemas.ChatMessageContent{ContentStr: schemas.Ptr("hi")}},
|
||||
},
|
||||
Params: &schemas.ChatParameters{
|
||||
Temperature: &temp,
|
||||
TopP: &topP,
|
||||
ExtraParams: map[string]interface{}{"top_k": 40},
|
||||
},
|
||||
}
|
||||
|
||||
ctx, cancel := schemas.NewBifrostContextWithCancel(nil)
|
||||
defer cancel()
|
||||
result, err := ToAnthropicChatRequest(ctx, bifrostReq)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if result.Temperature != nil {
|
||||
t.Errorf("expected Temperature to be nil for Opus 4.7, got %v", result.Temperature)
|
||||
}
|
||||
if result.TopP != nil {
|
||||
t.Errorf("expected TopP to be nil for Opus 4.7, got %v", result.TopP)
|
||||
}
|
||||
if result.TopK != nil {
|
||||
t.Errorf("expected TopK to be nil for Opus 4.7, got %v", result.TopK)
|
||||
}
|
||||
}
|
||||
|
||||
func TestToAnthropicChatRequest_NonOpus47_PreservesTemperature(t *testing.T) {
|
||||
temp := 0.7
|
||||
|
||||
bifrostReq := &schemas.BifrostChatRequest{
|
||||
Provider: schemas.Anthropic,
|
||||
Model: "claude-opus-4-6-20250514",
|
||||
Input: []schemas.ChatMessage{
|
||||
{Role: schemas.ChatMessageRoleUser, Content: &schemas.ChatMessageContent{ContentStr: schemas.Ptr("hi")}},
|
||||
},
|
||||
Params: &schemas.ChatParameters{
|
||||
Temperature: &temp,
|
||||
},
|
||||
}
|
||||
|
||||
ctx, cancel := schemas.NewBifrostContextWithCancel(nil)
|
||||
defer cancel()
|
||||
result, err := ToAnthropicChatRequest(ctx, bifrostReq)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if result.Temperature == nil || *result.Temperature != temp {
|
||||
t.Errorf("expected Temperature %v, got %v", temp, result.Temperature)
|
||||
}
|
||||
}
|
||||
|
||||
func TestToAnthropicChatRequest_Opus47_ReasoningMaxTokens_AdaptiveOnly(t *testing.T) {
|
||||
maxTok := 2048
|
||||
|
||||
bifrostReq := &schemas.BifrostChatRequest{
|
||||
Provider: schemas.Anthropic,
|
||||
Model: "claude-opus-4-7-20260401",
|
||||
Input: []schemas.ChatMessage{
|
||||
{Role: schemas.ChatMessageRoleUser, Content: &schemas.ChatMessageContent{ContentStr: schemas.Ptr("think")}},
|
||||
},
|
||||
Params: &schemas.ChatParameters{
|
||||
MaxCompletionTokens: schemas.Ptr(8192),
|
||||
Reasoning: &schemas.ChatReasoning{MaxTokens: &maxTok},
|
||||
},
|
||||
}
|
||||
|
||||
ctx, cancel := schemas.NewBifrostContextWithCancel(nil)
|
||||
defer cancel()
|
||||
result, err := ToAnthropicChatRequest(ctx, bifrostReq)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if result.Thinking == nil {
|
||||
t.Fatal("expected Thinking to be set")
|
||||
}
|
||||
if result.Thinking.Type != "adaptive" {
|
||||
t.Errorf("expected thinking type 'adaptive' for Opus 4.7, got %q", result.Thinking.Type)
|
||||
}
|
||||
if result.Thinking.BudgetTokens != nil {
|
||||
t.Errorf("expected BudgetTokens to be nil for Opus 4.7, got %v", result.Thinking.BudgetTokens)
|
||||
}
|
||||
}
|
||||
|
||||
func TestToAnthropicChatRequest_NonOpus47_ReasoningMaxTokens_EnabledWithBudget(t *testing.T) {
|
||||
maxTok := 2048
|
||||
|
||||
bifrostReq := &schemas.BifrostChatRequest{
|
||||
Provider: schemas.Anthropic,
|
||||
Model: "claude-opus-4-6-20250514",
|
||||
Input: []schemas.ChatMessage{
|
||||
{Role: schemas.ChatMessageRoleUser, Content: &schemas.ChatMessageContent{ContentStr: schemas.Ptr("think")}},
|
||||
},
|
||||
Params: &schemas.ChatParameters{
|
||||
MaxCompletionTokens: schemas.Ptr(8192),
|
||||
Reasoning: &schemas.ChatReasoning{MaxTokens: &maxTok},
|
||||
},
|
||||
}
|
||||
|
||||
ctx, cancel := schemas.NewBifrostContextWithCancel(nil)
|
||||
defer cancel()
|
||||
result, err := ToAnthropicChatRequest(ctx, bifrostReq)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if result.Thinking == nil {
|
||||
t.Fatal("expected Thinking to be set")
|
||||
}
|
||||
if result.Thinking.Type != "enabled" {
|
||||
t.Errorf("expected thinking type 'enabled' for Opus 4.6, got %q", result.Thinking.Type)
|
||||
}
|
||||
if result.Thinking.BudgetTokens == nil || *result.Thinking.BudgetTokens != maxTok {
|
||||
t.Errorf("expected BudgetTokens %d, got %v", maxTok, result.Thinking.BudgetTokens)
|
||||
}
|
||||
}
|
||||
|
||||
func TestToAnthropicChatRequest_Opus47_ReasoningEffort_AdaptiveWithEffort(t *testing.T) {
|
||||
effort := "high"
|
||||
|
||||
bifrostReq := &schemas.BifrostChatRequest{
|
||||
Provider: schemas.Anthropic,
|
||||
Model: "claude-opus-4-7-20260401",
|
||||
Input: []schemas.ChatMessage{
|
||||
{Role: schemas.ChatMessageRoleUser, Content: &schemas.ChatMessageContent{ContentStr: schemas.Ptr("think")}},
|
||||
},
|
||||
Params: &schemas.ChatParameters{
|
||||
MaxCompletionTokens: schemas.Ptr(8192),
|
||||
Reasoning: &schemas.ChatReasoning{Effort: &effort},
|
||||
},
|
||||
}
|
||||
|
||||
ctx, cancel := schemas.NewBifrostContextWithCancel(nil)
|
||||
defer cancel()
|
||||
result, err := ToAnthropicChatRequest(ctx, bifrostReq)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if result.Thinking == nil {
|
||||
t.Fatal("expected Thinking to be set")
|
||||
}
|
||||
if result.Thinking.Type != "adaptive" {
|
||||
t.Errorf("expected thinking type 'adaptive' for Opus 4.7 effort-based, got %q", result.Thinking.Type)
|
||||
}
|
||||
if result.OutputConfig == nil || result.OutputConfig.Effort == nil {
|
||||
t.Error("expected OutputConfig.Effort to be set for Opus 4.7 effort-based reasoning")
|
||||
}
|
||||
}
|
||||
703
core/providers/anthropic/compaction_test.go
Normal file
703
core/providers/anthropic/compaction_test.go
Normal file
@@ -0,0 +1,703 @@
|
||||
package anthropic
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// --- isCompactionItem tests ---
|
||||
|
||||
func TestIsCompactionItem(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
item *schemas.ResponsesMessage
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "nil item",
|
||||
item: nil,
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "nil type",
|
||||
item: &schemas.ResponsesMessage{
|
||||
Content: &schemas.ResponsesMessageContent{
|
||||
ContentBlocks: []schemas.ResponsesMessageContentBlock{
|
||||
{Type: schemas.ResponsesOutputMessageContentTypeCompaction},
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "message type with compaction content block",
|
||||
item: &schemas.ResponsesMessage{
|
||||
Type: schemas.Ptr(schemas.ResponsesMessageTypeMessage),
|
||||
Content: &schemas.ResponsesMessageContent{
|
||||
ContentBlocks: []schemas.ResponsesMessageContentBlock{
|
||||
{
|
||||
Type: schemas.ResponsesOutputMessageContentTypeCompaction,
|
||||
ResponsesOutputMessageContentCompaction: &schemas.ResponsesOutputMessageContentCompaction{
|
||||
Summary: "Summary of conversation",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "message type with text content block",
|
||||
item: &schemas.ResponsesMessage{
|
||||
Type: schemas.Ptr(schemas.ResponsesMessageTypeMessage),
|
||||
Content: &schemas.ResponsesMessageContent{
|
||||
ContentBlocks: []schemas.ResponsesMessageContentBlock{
|
||||
{
|
||||
Type: schemas.ResponsesOutputMessageContentTypeText,
|
||||
Text: schemas.Ptr("Hello"),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "function call type",
|
||||
item: &schemas.ResponsesMessage{
|
||||
Type: schemas.Ptr(schemas.ResponsesMessageTypeFunctionCall),
|
||||
},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "message type with nil content",
|
||||
item: &schemas.ResponsesMessage{
|
||||
Type: schemas.Ptr(schemas.ResponsesMessageTypeMessage),
|
||||
Content: nil,
|
||||
},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "message type with empty content blocks",
|
||||
item: &schemas.ResponsesMessage{
|
||||
Type: schemas.Ptr(schemas.ResponsesMessageTypeMessage),
|
||||
Content: &schemas.ResponsesMessageContent{
|
||||
ContentBlocks: []schemas.ResponsesMessageContentBlock{},
|
||||
},
|
||||
},
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := isCompactionItem(tt.item)
|
||||
if result != tt.expected {
|
||||
t.Errorf("isCompactionItem() = %v, want %v", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// --- Streaming: Anthropic → Bifrost (inbound) ---
|
||||
|
||||
func TestToBifrostResponsesStream_CompactionContentBlockStart(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
state := &AnthropicResponsesStreamState{
|
||||
ContentIndexToOutputIndex: make(map[int]int),
|
||||
ContentIndexToBlockType: make(map[int]AnthropicContentBlockType),
|
||||
ToolArgumentBuffers: make(map[int]string),
|
||||
MCPCallOutputIndices: make(map[int]bool),
|
||||
ItemIDs: make(map[int]string),
|
||||
OutputItems: make(map[int]*schemas.ResponsesMessage),
|
||||
ReasoningSignatures: make(map[int]string),
|
||||
TextContentIndices: make(map[int]bool),
|
||||
ReasoningContentIndices: make(map[int]bool),
|
||||
CompactionContentIndices: make(map[int]*schemas.CacheControl),
|
||||
CurrentOutputIndex: 0,
|
||||
CreatedAt: 1234567890,
|
||||
HasEmittedCreated: true,
|
||||
HasEmittedInProgress: true,
|
||||
}
|
||||
|
||||
// content_block_start with compaction type should return nil (defers to delta)
|
||||
chunk := &AnthropicStreamEvent{
|
||||
Type: AnthropicStreamEventTypeContentBlockStart,
|
||||
Index: schemas.Ptr(0),
|
||||
ContentBlock: &AnthropicContentBlock{
|
||||
Type: AnthropicContentBlockTypeCompaction,
|
||||
CacheControl: &schemas.CacheControl{
|
||||
Type: "ephemeral",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
responses, err, isLast := chunk.ToBifrostResponsesStream(context.Background(), 0, state)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if isLast {
|
||||
t.Error("should not be last chunk")
|
||||
}
|
||||
if len(responses) != 0 {
|
||||
t.Errorf("expected 0 responses for compaction content_block_start, got %d", len(responses))
|
||||
}
|
||||
|
||||
// Verify state was tracked
|
||||
if _, exists := state.CompactionContentIndices[0]; !exists {
|
||||
t.Error("expected compaction to be tracked in CompactionContentIndices")
|
||||
}
|
||||
if blockType, exists := state.ContentIndexToBlockType[0]; !exists || blockType != AnthropicContentBlockTypeCompaction {
|
||||
t.Error("expected compaction block type tracked in ContentIndexToBlockType")
|
||||
}
|
||||
}
|
||||
|
||||
func TestToBifrostResponsesStream_CompactionDelta(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
state := &AnthropicResponsesStreamState{
|
||||
ContentIndexToOutputIndex: map[int]int{0: 0},
|
||||
ContentIndexToBlockType: map[int]AnthropicContentBlockType{0: AnthropicContentBlockTypeCompaction},
|
||||
ToolArgumentBuffers: make(map[int]string),
|
||||
MCPCallOutputIndices: make(map[int]bool),
|
||||
ItemIDs: map[int]string{0: "cmp_0"},
|
||||
OutputItems: make(map[int]*schemas.ResponsesMessage),
|
||||
ReasoningSignatures: make(map[int]string),
|
||||
TextContentIndices: make(map[int]bool),
|
||||
ReasoningContentIndices: make(map[int]bool),
|
||||
CompactionContentIndices: map[int]*schemas.CacheControl{0: {Type: "ephemeral"}},
|
||||
CurrentOutputIndex: 1,
|
||||
CreatedAt: 1234567890,
|
||||
HasEmittedCreated: true,
|
||||
HasEmittedInProgress: true,
|
||||
}
|
||||
|
||||
summary := "The user asked about building a website. We discussed HTML, CSS, and JavaScript."
|
||||
chunk := &AnthropicStreamEvent{
|
||||
Type: AnthropicStreamEventTypeContentBlockDelta,
|
||||
Index: schemas.Ptr(0),
|
||||
Delta: &AnthropicStreamDelta{
|
||||
Type: AnthropicStreamDeltaTypeCompaction,
|
||||
Content: &summary,
|
||||
},
|
||||
}
|
||||
|
||||
responses, err, isLast := chunk.ToBifrostResponsesStream(context.Background(), 0, state)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if isLast {
|
||||
t.Error("should not be last chunk")
|
||||
}
|
||||
|
||||
// Should emit output_item.added and output_item.done
|
||||
if len(responses) != 2 {
|
||||
t.Fatalf("expected 2 responses for compaction delta, got %d", len(responses))
|
||||
}
|
||||
|
||||
// First: output_item.added
|
||||
added := responses[0]
|
||||
if added.Type != schemas.ResponsesStreamResponseTypeOutputItemAdded {
|
||||
t.Errorf("first response type = %v, want %v", added.Type, schemas.ResponsesStreamResponseTypeOutputItemAdded)
|
||||
}
|
||||
if added.Item == nil || added.Item.Content == nil || len(added.Item.Content.ContentBlocks) == 0 {
|
||||
t.Fatal("output_item.added should have content blocks")
|
||||
}
|
||||
block := added.Item.Content.ContentBlocks[0]
|
||||
if block.Type != schemas.ResponsesOutputMessageContentTypeCompaction {
|
||||
t.Errorf("content block type = %v, want compaction", block.Type)
|
||||
}
|
||||
if block.ResponsesOutputMessageContentCompaction == nil {
|
||||
t.Fatal("expected compaction content to be non-nil")
|
||||
}
|
||||
if block.ResponsesOutputMessageContentCompaction.Summary != summary {
|
||||
t.Errorf("summary = %q, want %q", block.ResponsesOutputMessageContentCompaction.Summary, summary)
|
||||
}
|
||||
// Cache control should be preserved from content_block_start
|
||||
if block.CacheControl == nil || block.CacheControl.Type != "ephemeral" {
|
||||
t.Error("expected cache control to be preserved")
|
||||
}
|
||||
|
||||
// Second: output_item.done
|
||||
done := responses[1]
|
||||
if done.Type != schemas.ResponsesStreamResponseTypeOutputItemDone {
|
||||
t.Errorf("second response type = %v, want %v", done.Type, schemas.ResponsesStreamResponseTypeOutputItemDone)
|
||||
}
|
||||
}
|
||||
|
||||
func TestToBifrostResponsesStream_CompactionContentBlockStop(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
state := &AnthropicResponsesStreamState{
|
||||
ContentIndexToOutputIndex: map[int]int{0: 0},
|
||||
ContentIndexToBlockType: map[int]AnthropicContentBlockType{0: AnthropicContentBlockTypeCompaction},
|
||||
ToolArgumentBuffers: make(map[int]string),
|
||||
MCPCallOutputIndices: make(map[int]bool),
|
||||
ItemIDs: map[int]string{0: "cmp_0"},
|
||||
OutputItems: make(map[int]*schemas.ResponsesMessage),
|
||||
ReasoningSignatures: make(map[int]string),
|
||||
TextContentIndices: make(map[int]bool),
|
||||
ReasoningContentIndices: make(map[int]bool),
|
||||
CompactionContentIndices: make(map[int]*schemas.CacheControl),
|
||||
CurrentOutputIndex: 1,
|
||||
CreatedAt: 1234567890,
|
||||
HasEmittedCreated: true,
|
||||
HasEmittedInProgress: true,
|
||||
}
|
||||
|
||||
chunk := &AnthropicStreamEvent{
|
||||
Type: AnthropicStreamEventTypeContentBlockStop,
|
||||
Index: schemas.Ptr(0),
|
||||
}
|
||||
|
||||
responses, err, isLast := chunk.ToBifrostResponsesStream(context.Background(), 0, state)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if isLast {
|
||||
t.Error("should not be last chunk")
|
||||
}
|
||||
// content_block_stop for compaction should return nil (done was already emitted with delta)
|
||||
if len(responses) != 0 {
|
||||
t.Errorf("expected 0 responses for compaction content_block_stop, got %d", len(responses))
|
||||
}
|
||||
}
|
||||
|
||||
// --- Streaming: Bifrost → Anthropic (outbound, non-passthrough) ---
|
||||
|
||||
func TestToAnthropicResponsesStreamResponse_CompactionOutputItemAdded(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := schemas.NewBifrostContextWithCancel(nil)
|
||||
defer cancel()
|
||||
|
||||
summary := "Summary of the conversation about building a website"
|
||||
bifrostResp := &schemas.BifrostResponsesStreamResponse{
|
||||
Type: schemas.ResponsesStreamResponseTypeOutputItemAdded,
|
||||
OutputIndex: schemas.Ptr(0),
|
||||
Item: &schemas.ResponsesMessage{
|
||||
ID: schemas.Ptr("cmp_test123"),
|
||||
Type: schemas.Ptr(schemas.ResponsesMessageTypeMessage),
|
||||
Status: schemas.Ptr("completed"),
|
||||
Role: schemas.Ptr(schemas.ResponsesInputMessageRoleAssistant),
|
||||
Content: &schemas.ResponsesMessageContent{
|
||||
ContentBlocks: []schemas.ResponsesMessageContentBlock{
|
||||
{
|
||||
Type: schemas.ResponsesOutputMessageContentTypeCompaction,
|
||||
ResponsesOutputMessageContentCompaction: &schemas.ResponsesOutputMessageContentCompaction{
|
||||
Summary: summary,
|
||||
},
|
||||
CacheControl: &schemas.CacheControl{Type: "ephemeral"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
events := ToAnthropicResponsesStreamResponse(ctx, bifrostResp)
|
||||
|
||||
// Should emit: content_block_start (compaction) + content_block_delta (compaction_delta)
|
||||
if len(events) < 2 {
|
||||
t.Fatalf("expected at least 2 events, got %d", len(events))
|
||||
}
|
||||
|
||||
// Event 1: content_block_start
|
||||
start := events[0]
|
||||
if start.Type != AnthropicStreamEventTypeContentBlockStart {
|
||||
t.Errorf("event[0] type = %v, want content_block_start", start.Type)
|
||||
}
|
||||
if start.ContentBlock == nil {
|
||||
t.Fatal("content_block_start should have ContentBlock")
|
||||
}
|
||||
if start.ContentBlock.Type != AnthropicContentBlockTypeCompaction {
|
||||
t.Errorf("ContentBlock.Type = %v, want compaction", start.ContentBlock.Type)
|
||||
}
|
||||
if start.ContentBlock.CacheControl == nil || start.ContentBlock.CacheControl.Type != "ephemeral" {
|
||||
t.Error("expected cache control to be preserved on content_block_start")
|
||||
}
|
||||
|
||||
// Event 2: content_block_delta with compaction_delta
|
||||
delta := events[1]
|
||||
if delta.Type != AnthropicStreamEventTypeContentBlockDelta {
|
||||
t.Errorf("event[1] type = %v, want content_block_delta", delta.Type)
|
||||
}
|
||||
if delta.Delta == nil {
|
||||
t.Fatal("content_block_delta should have Delta")
|
||||
}
|
||||
if delta.Delta.Type != AnthropicStreamDeltaTypeCompaction {
|
||||
t.Errorf("Delta.Type = %v, want compaction_delta", delta.Delta.Type)
|
||||
}
|
||||
if delta.Delta.Content == nil || *delta.Delta.Content != summary {
|
||||
t.Errorf("Delta.Content = %v, want %q", delta.Delta.Content, summary)
|
||||
}
|
||||
}
|
||||
|
||||
func TestToAnthropicResponsesStreamResponse_CompactionOutputItemDone(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := schemas.NewBifrostContextWithCancel(nil)
|
||||
defer cancel()
|
||||
|
||||
bifrostResp := &schemas.BifrostResponsesStreamResponse{
|
||||
Type: schemas.ResponsesStreamResponseTypeOutputItemDone,
|
||||
OutputIndex: schemas.Ptr(0),
|
||||
ItemID: schemas.Ptr("cmp_test123"),
|
||||
Item: &schemas.ResponsesMessage{
|
||||
ID: schemas.Ptr("cmp_test123"),
|
||||
Type: schemas.Ptr(schemas.ResponsesMessageTypeMessage),
|
||||
Status: schemas.Ptr("completed"),
|
||||
Role: schemas.Ptr(schemas.ResponsesInputMessageRoleAssistant),
|
||||
Content: &schemas.ResponsesMessageContent{
|
||||
ContentBlocks: []schemas.ResponsesMessageContentBlock{
|
||||
{
|
||||
Type: schemas.ResponsesOutputMessageContentTypeCompaction,
|
||||
ResponsesOutputMessageContentCompaction: &schemas.ResponsesOutputMessageContentCompaction{
|
||||
Summary: "Summary text",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
events := ToAnthropicResponsesStreamResponse(ctx, bifrostResp)
|
||||
|
||||
// Should emit content_block_stop
|
||||
if len(events) != 1 {
|
||||
t.Fatalf("expected 1 event for output_item.done, got %d", len(events))
|
||||
}
|
||||
|
||||
stop := events[0]
|
||||
if stop.Type != AnthropicStreamEventTypeContentBlockStop {
|
||||
t.Errorf("event type = %v, want content_block_stop", stop.Type)
|
||||
}
|
||||
}
|
||||
|
||||
func TestToAnthropicResponsesStreamResponse_TextOutputItemAdded_NotAffectedByCompactionCheck(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := schemas.NewBifrostContextWithCancel(nil)
|
||||
defer cancel()
|
||||
|
||||
// Regular text message should still emit content_block_start with type=text
|
||||
bifrostResp := &schemas.BifrostResponsesStreamResponse{
|
||||
Type: schemas.ResponsesStreamResponseTypeOutputItemAdded,
|
||||
OutputIndex: schemas.Ptr(0),
|
||||
Item: &schemas.ResponsesMessage{
|
||||
ID: schemas.Ptr("msg_test123"),
|
||||
Type: schemas.Ptr(schemas.ResponsesMessageTypeMessage),
|
||||
Status: schemas.Ptr("in_progress"),
|
||||
Role: schemas.Ptr(schemas.ResponsesInputMessageRoleAssistant),
|
||||
Content: &schemas.ResponsesMessageContent{
|
||||
ContentBlocks: []schemas.ResponsesMessageContentBlock{
|
||||
{
|
||||
Type: schemas.ResponsesOutputMessageContentTypeText,
|
||||
Text: schemas.Ptr(""),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
events := ToAnthropicResponsesStreamResponse(ctx, bifrostResp)
|
||||
if len(events) == 0 {
|
||||
t.Fatal("expected at least 1 event")
|
||||
}
|
||||
|
||||
start := events[0]
|
||||
if start.Type != AnthropicStreamEventTypeContentBlockStart {
|
||||
t.Errorf("event type = %v, want content_block_start", start.Type)
|
||||
}
|
||||
if start.ContentBlock == nil {
|
||||
t.Fatal("expected ContentBlock to be non-nil")
|
||||
}
|
||||
if start.ContentBlock.Type != AnthropicContentBlockTypeText {
|
||||
t.Errorf("ContentBlock.Type = %v, want text", start.ContentBlock.Type)
|
||||
}
|
||||
}
|
||||
|
||||
// --- Non-Streaming: stop_reason mapping ---
|
||||
|
||||
func TestToBifrostResponsesResponse_PreservesStopReason(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
stopReason AnthropicStopReason
|
||||
expectedStopReason string
|
||||
}{
|
||||
{
|
||||
name: "compaction stop reason",
|
||||
stopReason: AnthropicStopReasonCompaction,
|
||||
expectedStopReason: "compaction",
|
||||
},
|
||||
{
|
||||
name: "end_turn stop reason",
|
||||
stopReason: AnthropicStopReasonEndTurn,
|
||||
expectedStopReason: "end_turn",
|
||||
},
|
||||
{
|
||||
name: "tool_use stop reason",
|
||||
stopReason: AnthropicStopReasonToolUse,
|
||||
expectedStopReason: "tool_use",
|
||||
},
|
||||
{
|
||||
name: "max_tokens stop reason",
|
||||
stopReason: AnthropicStopReasonMaxTokens,
|
||||
expectedStopReason: "max_tokens",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx, cancel := schemas.NewBifrostContextWithCancel(nil)
|
||||
defer cancel()
|
||||
|
||||
resp := &AnthropicMessageResponse{
|
||||
ID: "msg_test",
|
||||
Type: "message",
|
||||
Role: "assistant",
|
||||
Model: "claude-sonnet-4-6",
|
||||
StopReason: tt.stopReason,
|
||||
Content: []AnthropicContentBlock{
|
||||
{Type: AnthropicContentBlockTypeText, Text: schemas.Ptr("Hello")},
|
||||
},
|
||||
}
|
||||
|
||||
bifrostResp := resp.ToBifrostResponsesResponse(ctx)
|
||||
|
||||
if bifrostResp.StopReason == nil {
|
||||
t.Fatal("expected StopReason to be non-nil")
|
||||
}
|
||||
if *bifrostResp.StopReason != tt.expectedStopReason {
|
||||
t.Errorf("StopReason = %q, want %q", *bifrostResp.StopReason, tt.expectedStopReason)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestToBifrostResponsesResponse_EmptyStopReason(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := schemas.NewBifrostContextWithCancel(nil)
|
||||
defer cancel()
|
||||
|
||||
resp := &AnthropicMessageResponse{
|
||||
ID: "msg_test",
|
||||
Type: "message",
|
||||
Role: "assistant",
|
||||
Model: "claude-sonnet-4-6",
|
||||
Content: []AnthropicContentBlock{},
|
||||
}
|
||||
|
||||
bifrostResp := resp.ToBifrostResponsesResponse(ctx)
|
||||
|
||||
if bifrostResp.StopReason != nil {
|
||||
t.Errorf("expected nil StopReason for empty stop_reason, got %q", *bifrostResp.StopReason)
|
||||
}
|
||||
}
|
||||
|
||||
func TestToAnthropicResponsesResponse_StopReasonFromBifrost(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
stopReason *string
|
||||
contentBlocks []schemas.ResponsesMessage
|
||||
expectedReason AnthropicStopReason
|
||||
}{
|
||||
{
|
||||
name: "compaction stop reason from bifrost",
|
||||
stopReason: schemas.Ptr("compaction"),
|
||||
expectedReason: AnthropicStopReasonCompaction,
|
||||
},
|
||||
{
|
||||
name: "end_turn mapped from stop",
|
||||
stopReason: schemas.Ptr("stop"),
|
||||
expectedReason: AnthropicStopReasonEndTurn,
|
||||
},
|
||||
{
|
||||
name: "tool_use mapped from tool_calls",
|
||||
stopReason: schemas.Ptr("tool_calls"),
|
||||
expectedReason: AnthropicStopReasonToolUse,
|
||||
},
|
||||
{
|
||||
name: "nil stop_reason defaults to end_turn",
|
||||
stopReason: nil,
|
||||
expectedReason: AnthropicStopReasonEndTurn,
|
||||
},
|
||||
{
|
||||
name: "nil stop_reason with tool_use content defaults to tool_use",
|
||||
stopReason: nil,
|
||||
contentBlocks: []schemas.ResponsesMessage{
|
||||
{
|
||||
Type: schemas.Ptr(schemas.ResponsesMessageTypeFunctionCall),
|
||||
ResponsesToolMessage: &schemas.ResponsesToolMessage{
|
||||
CallID: schemas.Ptr("call_123"),
|
||||
Name: schemas.Ptr("my_tool"),
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedReason: AnthropicStopReasonToolUse,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx, cancel := schemas.NewBifrostContextWithCancel(nil)
|
||||
defer cancel()
|
||||
|
||||
bifrostResp := &schemas.BifrostResponsesResponse{
|
||||
ID: schemas.Ptr("resp_test"),
|
||||
Model: "claude-sonnet-4-6",
|
||||
StopReason: tt.stopReason,
|
||||
Output: tt.contentBlocks,
|
||||
}
|
||||
|
||||
result := ToAnthropicResponsesResponse(ctx, bifrostResp)
|
||||
|
||||
if result.StopReason != tt.expectedReason {
|
||||
t.Errorf("StopReason = %v, want %v", result.StopReason, tt.expectedReason)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// --- Non-Streaming: compaction content block round-trip ---
|
||||
|
||||
func TestCompactionContentBlock_NonStreamingRoundTrip(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := schemas.NewBifrostContextWithCancel(nil)
|
||||
defer cancel()
|
||||
|
||||
summary := "The user requested help building a web scraper using Python with BeautifulSoup."
|
||||
|
||||
// Simulate Anthropic response with compaction block
|
||||
anthropicResp := &AnthropicMessageResponse{
|
||||
ID: "msg_compaction_test",
|
||||
Type: "message",
|
||||
Role: "assistant",
|
||||
Model: "claude-opus-4-6",
|
||||
StopReason: AnthropicStopReasonCompaction,
|
||||
Content: []AnthropicContentBlock{
|
||||
{
|
||||
Type: AnthropicContentBlockTypeCompaction,
|
||||
Content: &AnthropicContent{
|
||||
ContentStr: &summary,
|
||||
},
|
||||
CacheControl: &schemas.CacheControl{Type: "ephemeral"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Step 1: Anthropic → Bifrost
|
||||
bifrostResp := anthropicResp.ToBifrostResponsesResponse(ctx)
|
||||
|
||||
if bifrostResp.StopReason == nil || *bifrostResp.StopReason != "compaction" {
|
||||
t.Fatalf("expected stop_reason='compaction', got %v", bifrostResp.StopReason)
|
||||
}
|
||||
if len(bifrostResp.Output) == 0 {
|
||||
t.Fatal("expected at least one output message")
|
||||
}
|
||||
|
||||
// Find the compaction block
|
||||
var foundCompaction bool
|
||||
for _, msg := range bifrostResp.Output {
|
||||
if msg.Content != nil {
|
||||
for _, block := range msg.Content.ContentBlocks {
|
||||
if block.Type == schemas.ResponsesOutputMessageContentTypeCompaction {
|
||||
foundCompaction = true
|
||||
if block.ResponsesOutputMessageContentCompaction == nil {
|
||||
t.Fatal("expected compaction content to be non-nil")
|
||||
}
|
||||
if block.ResponsesOutputMessageContentCompaction.Summary != summary {
|
||||
t.Errorf("summary = %q, want %q", block.ResponsesOutputMessageContentCompaction.Summary, summary)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if !foundCompaction {
|
||||
t.Error("compaction block not found in Bifrost output")
|
||||
}
|
||||
|
||||
// Step 2: Bifrost → Anthropic
|
||||
result := ToAnthropicResponsesResponse(ctx, bifrostResp)
|
||||
|
||||
if result.StopReason != AnthropicStopReasonCompaction {
|
||||
t.Errorf("result StopReason = %v, want compaction", result.StopReason)
|
||||
}
|
||||
|
||||
// Find compaction content block in result
|
||||
var foundResultCompaction bool
|
||||
for _, block := range result.Content {
|
||||
if block.Type == AnthropicContentBlockTypeCompaction {
|
||||
foundResultCompaction = true
|
||||
if block.Content == nil || block.Content.ContentStr == nil {
|
||||
t.Fatal("expected compaction content string")
|
||||
}
|
||||
if *block.Content.ContentStr != summary {
|
||||
t.Errorf("result summary = %q, want %q", *block.Content.ContentStr, summary)
|
||||
}
|
||||
if block.CacheControl == nil || block.CacheControl.Type != "ephemeral" {
|
||||
t.Error("expected cache control to be preserved")
|
||||
}
|
||||
}
|
||||
}
|
||||
if !foundResultCompaction {
|
||||
t.Error("compaction block not found in Anthropic result")
|
||||
}
|
||||
}
|
||||
|
||||
// --- Streaming: compaction stop_reason in response.completed ---
|
||||
|
||||
func TestToAnthropicResponsesStreamResponse_CompletedWithCompactionStopReason(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := schemas.NewBifrostContextWithCancel(nil)
|
||||
defer cancel()
|
||||
|
||||
bifrostResp := &schemas.BifrostResponsesStreamResponse{
|
||||
Type: schemas.ResponsesStreamResponseTypeCompleted,
|
||||
Response: &schemas.BifrostResponsesResponse{
|
||||
ID: schemas.Ptr("resp_test"),
|
||||
Model: "claude-opus-4-6",
|
||||
StopReason: schemas.Ptr("compaction"),
|
||||
Usage: &schemas.ResponsesResponseUsage{
|
||||
InputTokens: 1000,
|
||||
OutputTokens: 500,
|
||||
TotalTokens: 1500,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
events := ToAnthropicResponsesStreamResponse(ctx, bifrostResp)
|
||||
|
||||
// Should emit message_delta + message_stop
|
||||
if len(events) != 2 {
|
||||
t.Fatalf("expected 2 events for response.completed, got %d", len(events))
|
||||
}
|
||||
|
||||
// message_delta should have stop_reason=compaction
|
||||
messageDelta := events[0]
|
||||
if messageDelta.Type != AnthropicStreamEventTypeMessageDelta {
|
||||
t.Errorf("event[0] type = %v, want message_delta", messageDelta.Type)
|
||||
}
|
||||
if messageDelta.Delta == nil || messageDelta.Delta.StopReason == nil {
|
||||
t.Fatal("expected Delta.StopReason in message_delta")
|
||||
}
|
||||
if *messageDelta.Delta.StopReason != AnthropicStopReasonCompaction {
|
||||
t.Errorf("StopReason = %v, want compaction", *messageDelta.Delta.StopReason)
|
||||
}
|
||||
|
||||
// message_stop
|
||||
messageStop := events[1]
|
||||
if messageStop.Type != AnthropicStreamEventTypeMessageStop {
|
||||
t.Errorf("event[1] type = %v, want message_stop", messageStop.Type)
|
||||
}
|
||||
}
|
||||
34
core/providers/anthropic/count_tokens.go
Normal file
34
core/providers/anthropic/count_tokens.go
Normal file
@@ -0,0 +1,34 @@
|
||||
package anthropic
|
||||
|
||||
import (
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// ToBifrostCountTokensResponse converts an Anthropic count tokens response to Bifrost format
|
||||
func (resp *AnthropicCountTokensResponse) ToBifrostCountTokensResponse(model string) *schemas.BifrostCountTokensResponse {
|
||||
if resp == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
totalTokens := resp.InputTokens
|
||||
|
||||
bifrostResp := &schemas.BifrostCountTokensResponse{
|
||||
Model: model,
|
||||
InputTokens: resp.InputTokens,
|
||||
TotalTokens: &totalTokens,
|
||||
Object: "response.input_tokens",
|
||||
}
|
||||
|
||||
return bifrostResp
|
||||
}
|
||||
|
||||
// ToAnthropicCountTokensResponse converts a Bifrost count tokens response to Anthropic format.
|
||||
func ToAnthropicCountTokensResponse(bifrostResp *schemas.BifrostCountTokensResponse) *AnthropicCountTokensResponse {
|
||||
if bifrostResp == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &AnthropicCountTokensResponse{
|
||||
InputTokens: bifrostResp.InputTokens,
|
||||
}
|
||||
}
|
||||
68
core/providers/anthropic/errors.go
Normal file
68
core/providers/anthropic/errors.go
Normal file
@@ -0,0 +1,68 @@
|
||||
package anthropic
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
|
||||
schemas "github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// ToAnthropicChatCompletionError converts a BifrostError to AnthropicMessageError
|
||||
func ToAnthropicChatCompletionError(bifrostErr *schemas.BifrostError) *AnthropicMessageError {
|
||||
if bifrostErr == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Safely extract type and message from nested error
|
||||
errorType := "api_error"
|
||||
message := ""
|
||||
if bifrostErr.Error != nil {
|
||||
if bifrostErr.Error.Type != nil && *bifrostErr.Error.Type != "" {
|
||||
errorType = *bifrostErr.Error.Type
|
||||
}
|
||||
message = bifrostErr.Error.Message
|
||||
}
|
||||
|
||||
// Handle nested error fields with nil checks
|
||||
errorStruct := AnthropicMessageErrorStruct{
|
||||
Type: errorType,
|
||||
Message: message,
|
||||
}
|
||||
|
||||
return &AnthropicMessageError{
|
||||
Type: "error", // always "error" for Anthropic
|
||||
Error: errorStruct,
|
||||
}
|
||||
}
|
||||
|
||||
// ToAnthropicResponsesStreamError converts a BifrostError to Anthropic responses streaming error in SSE format
|
||||
func ToAnthropicResponsesStreamError(bifrostErr *schemas.BifrostError) string {
|
||||
if bifrostErr == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
anthropicErr := ToAnthropicChatCompletionError(bifrostErr)
|
||||
|
||||
// Marshal to JSON
|
||||
jsonData, err := providerUtils.MarshalSorted(anthropicErr)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Format as Anthropic SSE error event
|
||||
return fmt.Sprintf("event: error\ndata: %s\n\n", jsonData)
|
||||
}
|
||||
|
||||
func parseAnthropicError(resp *fasthttp.Response) *schemas.BifrostError {
|
||||
var errorResp AnthropicError
|
||||
bifrostErr := providerUtils.HandleProviderAPIError(resp, &errorResp)
|
||||
if errorResp.Error != nil {
|
||||
if bifrostErr.Error == nil {
|
||||
bifrostErr.Error = &schemas.ErrorField{}
|
||||
}
|
||||
bifrostErr.Error.Type = &errorResp.Error.Type
|
||||
bifrostErr.Error.Message = errorResp.Error.Message
|
||||
}
|
||||
return bifrostErr
|
||||
}
|
||||
96
core/providers/anthropic/errors_test.go
Normal file
96
core/providers/anthropic/errors_test.go
Normal file
@@ -0,0 +1,96 @@
|
||||
package anthropic
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
schemas "github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
func TestToAnthropicChatCompletionError(t *testing.T) {
|
||||
strPtr := func(s string) *string { return &s }
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input *schemas.BifrostError
|
||||
expectNil bool
|
||||
expectedType string
|
||||
}{
|
||||
{
|
||||
name: "nil BifrostError returns nil",
|
||||
input: nil,
|
||||
expectNil: true,
|
||||
},
|
||||
{
|
||||
name: "nil ErrorField.Type defaults to api_error",
|
||||
input: &schemas.BifrostError{
|
||||
Error: &schemas.ErrorField{
|
||||
Type: nil,
|
||||
Message: "connection failed",
|
||||
},
|
||||
},
|
||||
expectedType: "api_error",
|
||||
},
|
||||
{
|
||||
name: "empty string Type defaults to api_error",
|
||||
input: &schemas.BifrostError{
|
||||
Error: &schemas.ErrorField{
|
||||
Type: strPtr(""),
|
||||
Message: "rate limited",
|
||||
},
|
||||
},
|
||||
expectedType: "api_error",
|
||||
},
|
||||
{
|
||||
name: "valid Type is preserved",
|
||||
input: &schemas.BifrostError{
|
||||
Error: &schemas.ErrorField{
|
||||
Type: strPtr("rate_limit_error"),
|
||||
Message: "rate limited",
|
||||
},
|
||||
},
|
||||
expectedType: "rate_limit_error",
|
||||
},
|
||||
{
|
||||
name: "internal Type is preserved",
|
||||
input: &schemas.BifrostError{
|
||||
Error: &schemas.ErrorField{
|
||||
Type: strPtr("request_cancelled"),
|
||||
Message: "cancelled",
|
||||
},
|
||||
},
|
||||
expectedType: "request_cancelled",
|
||||
},
|
||||
{
|
||||
name: "nil Error field defaults to api_error",
|
||||
input: &schemas.BifrostError{
|
||||
Error: nil,
|
||||
},
|
||||
expectedType: "api_error",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := ToAnthropicChatCompletionError(tt.input)
|
||||
|
||||
if tt.expectNil {
|
||||
if result != nil {
|
||||
t.Fatalf("expected nil, got %+v", result)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if result == nil {
|
||||
t.Fatal("expected non-nil result")
|
||||
}
|
||||
|
||||
if result.Type != "error" {
|
||||
t.Errorf("expected top-level Type %q, got %q", "error", result.Type)
|
||||
}
|
||||
|
||||
if result.Error.Type != tt.expectedType {
|
||||
t.Errorf("expected error Type %q, got %q", tt.expectedType, result.Error.Type)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
71
core/providers/anthropic/files.go
Normal file
71
core/providers/anthropic/files.go
Normal file
@@ -0,0 +1,71 @@
|
||||
package anthropic
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// ToAnthropicFileUploadResponse converts a Bifrost file upload response to Anthropic format.
|
||||
func ToAnthropicFileUploadResponse(resp *schemas.BifrostFileUploadResponse) *AnthropicFileResponse {
|
||||
return &AnthropicFileResponse{
|
||||
ID: resp.ID,
|
||||
Type: resp.Object,
|
||||
Filename: resp.Filename,
|
||||
MimeType: "",
|
||||
SizeBytes: resp.Bytes,
|
||||
CreatedAt: formatAnthropicFileTimestamp(resp.CreatedAt),
|
||||
}
|
||||
}
|
||||
|
||||
// ToAnthropicFileListResponse converts a Bifrost file list response to Anthropic format.
|
||||
func ToAnthropicFileListResponse(resp *schemas.BifrostFileListResponse) *AnthropicFileListResponse {
|
||||
data := make([]AnthropicFileResponse, len(resp.Data))
|
||||
for i, file := range resp.Data {
|
||||
data[i] = AnthropicFileResponse{
|
||||
ID: file.ID,
|
||||
Type: file.Object,
|
||||
Filename: file.Filename,
|
||||
MimeType: "",
|
||||
SizeBytes: file.Bytes,
|
||||
CreatedAt: formatAnthropicFileTimestamp(file.CreatedAt),
|
||||
}
|
||||
}
|
||||
|
||||
return &AnthropicFileListResponse{
|
||||
Data: data,
|
||||
HasMore: resp.HasMore,
|
||||
}
|
||||
}
|
||||
|
||||
// ToAnthropicFileRetrieveResponse converts a Bifrost file retrieve response to Anthropic format.
|
||||
func ToAnthropicFileRetrieveResponse(resp *schemas.BifrostFileRetrieveResponse) *AnthropicFileResponse {
|
||||
return &AnthropicFileResponse{
|
||||
ID: resp.ID,
|
||||
Type: resp.Object,
|
||||
Filename: resp.Filename,
|
||||
MimeType: "", // Not supported in Bifrost responses
|
||||
SizeBytes: resp.Bytes,
|
||||
CreatedAt: formatAnthropicFileTimestamp(resp.CreatedAt),
|
||||
}
|
||||
}
|
||||
|
||||
// ToAnthropicFileDeleteResponse converts a Bifrost file delete response to Anthropic format.
|
||||
func ToAnthropicFileDeleteResponse(resp *schemas.BifrostFileDeleteResponse) *AnthropicFileDeleteResponse {
|
||||
respType := "file"
|
||||
if resp.Deleted {
|
||||
respType = "file_deleted"
|
||||
}
|
||||
return &AnthropicFileDeleteResponse{
|
||||
ID: resp.ID,
|
||||
Type: respType,
|
||||
}
|
||||
}
|
||||
|
||||
// formatAnthropicFileTimestamp renders a Unix timestamp (seconds) as an
// RFC 3339 string in UTC. A zero timestamp is treated as "unset" and yields
// the empty string.
func formatAnthropicFileTimestamp(unixTime int64) string {
	if unixTime != 0 {
		return time.Unix(unixTime, 0).UTC().Format(time.RFC3339)
	}
	return ""
}
|
||||
108
core/providers/anthropic/models.go
Normal file
108
core/providers/anthropic/models.go
Normal file
@@ -0,0 +1,108 @@
|
||||
package anthropic
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
func (response *AnthropicListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, allowedModels schemas.WhiteList, blacklistedModels schemas.BlackList, aliases map[string]string, unfiltered bool) *schemas.BifrostListModelsResponse {
|
||||
if response == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
bifrostResponse := &schemas.BifrostListModelsResponse{
|
||||
Data: make([]schemas.Model, 0, len(response.Data)),
|
||||
FirstID: response.FirstID,
|
||||
LastID: response.LastID,
|
||||
HasMore: schemas.Ptr(response.HasMore),
|
||||
}
|
||||
|
||||
// Map Anthropic's cursor-based pagination to Bifrost's token-based pagination.
|
||||
// If there are more results, set next_page_token to last_id for the next request.
|
||||
if response.HasMore && response.LastID != nil {
|
||||
bifrostResponse.NextPageToken = *response.LastID
|
||||
}
|
||||
|
||||
pipeline := &providerUtils.ListModelsPipeline{
|
||||
AllowedModels: allowedModels,
|
||||
BlacklistedModels: blacklistedModels,
|
||||
Aliases: aliases,
|
||||
Unfiltered: unfiltered,
|
||||
ProviderKey: providerKey,
|
||||
MatchFns: providerUtils.DefaultMatchFns(),
|
||||
}
|
||||
if pipeline.ShouldEarlyExit() {
|
||||
return bifrostResponse
|
||||
}
|
||||
|
||||
included := make(map[string]bool)
|
||||
|
||||
for _, model := range response.Data {
|
||||
for _, result := range pipeline.FilterModel(model.ID) {
|
||||
resolvedKey := strings.ToLower(result.ResolvedID)
|
||||
if included[resolvedKey] {
|
||||
continue
|
||||
}
|
||||
entry := schemas.Model{
|
||||
ID: string(providerKey) + "/" + result.ResolvedID,
|
||||
Name: schemas.Ptr(model.DisplayName),
|
||||
Created: schemas.Ptr(model.CreatedAt.Unix()),
|
||||
MaxInputTokens: model.MaxInputTokens,
|
||||
MaxOutputTokens: model.MaxTokens,
|
||||
ProviderExtra: model.Capabilities,
|
||||
}
|
||||
if result.AliasValue != "" {
|
||||
entry.Alias = schemas.Ptr(result.AliasValue)
|
||||
}
|
||||
bifrostResponse.Data = append(bifrostResponse.Data, entry)
|
||||
included[resolvedKey] = true
|
||||
}
|
||||
}
|
||||
|
||||
bifrostResponse.Data = append(bifrostResponse.Data,
|
||||
pipeline.BackfillModels(included)...)
|
||||
|
||||
return bifrostResponse
|
||||
}
|
||||
|
||||
func ToAnthropicListModelsResponse(response *schemas.BifrostListModelsResponse) *AnthropicListModelsResponse {
|
||||
if response == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
anthropicResponse := &AnthropicListModelsResponse{
|
||||
Data: make([]AnthropicModel, 0, len(response.Data)),
|
||||
}
|
||||
if response.FirstID != nil {
|
||||
anthropicResponse.FirstID = response.FirstID
|
||||
}
|
||||
if response.LastID != nil {
|
||||
anthropicResponse.LastID = response.LastID
|
||||
}
|
||||
if response.HasMore != nil {
|
||||
anthropicResponse.HasMore = *response.HasMore
|
||||
}
|
||||
|
||||
for _, model := range response.Data {
|
||||
_, modelID := schemas.ParseModelString(model.ID, schemas.Anthropic)
|
||||
anthropicModel := AnthropicModel{
|
||||
ID: modelID,
|
||||
Type: "model",
|
||||
MaxInputTokens: model.MaxInputTokens,
|
||||
MaxTokens: model.MaxOutputTokens,
|
||||
Capabilities: model.ProviderExtra,
|
||||
}
|
||||
if model.Name != nil {
|
||||
anthropicModel.DisplayName = *model.Name
|
||||
}
|
||||
if model.Created != nil {
|
||||
anthropicModel.CreatedAt = time.Unix(*model.Created, 0)
|
||||
}
|
||||
anthropicResponse.Data = append(anthropicResponse.Data, anthropicModel)
|
||||
}
|
||||
|
||||
return anthropicResponse
|
||||
}
|
||||
52
core/providers/anthropic/payload_ordering_test.go
Normal file
52
core/providers/anthropic/payload_ordering_test.go
Normal file
@@ -0,0 +1,52 @@
|
||||
package anthropic
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
|
||||
schemas "github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestPayloadOrdering_AnthropicMessageRequest(t *testing.T) {
|
||||
req := &AnthropicMessageRequest{
|
||||
Model: "claude-sonnet-4-20250514",
|
||||
MaxTokens: 1024,
|
||||
Messages: []AnthropicMessage{
|
||||
{
|
||||
Role: "user",
|
||||
Content: AnthropicContent{ContentStr: schemas.Ptr("hello")},
|
||||
},
|
||||
},
|
||||
Temperature: schemas.Ptr(0.7),
|
||||
Stream: schemas.Ptr(true),
|
||||
Tools: []AnthropicTool{
|
||||
{
|
||||
Name: "get_weather",
|
||||
Description: schemas.Ptr("Get weather"),
|
||||
InputSchema: &schemas.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: schemas.NewOrderedMapFromPairs(
|
||||
schemas.KV("location", map[string]interface{}{"type": "string"}),
|
||||
),
|
||||
Required: []string{"location"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result, err := providerUtils.MarshalSorted(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
golden := `{"model":"claude-sonnet-4-20250514","max_tokens":1024,"messages":[{"role":"user","content":"hello"}],"temperature":0.7,"stream":true,"tools":[{"name":"get_weather","description":"Get weather","input_schema":{"type":"object","properties":{"location":{"type":"string"}},"required":["location"]}}]}`
|
||||
|
||||
assert.Equal(t, golden, string(result), "payload field ordering changed — if intentional, update the golden string")
|
||||
|
||||
// Determinism: 100 iterations must produce identical bytes
|
||||
for i := 0; i < 100; i++ {
|
||||
iter, err := providerUtils.MarshalSorted(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, string(result), string(iter), "non-deterministic marshal output on iteration %d", i)
|
||||
}
|
||||
}
|
||||
5900
core/providers/anthropic/responses.go
Normal file
5900
core/providers/anthropic/responses.go
Normal file
File diff suppressed because it is too large
Load Diff
137
core/providers/anthropic/text.go
Normal file
137
core/providers/anthropic/text.go
Normal file
@@ -0,0 +1,137 @@
|
||||
package anthropic
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// ToAnthropicTextCompletionRequest converts a Bifrost text completion request to Anthropic format
|
||||
func ToAnthropicTextCompletionRequest(bifrostReq *schemas.BifrostTextCompletionRequest) *AnthropicTextRequest {
|
||||
if bifrostReq == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
prompt := ""
|
||||
if bifrostReq.Input.PromptStr != nil {
|
||||
prompt = *bifrostReq.Input.PromptStr
|
||||
} else if len(bifrostReq.Input.PromptArray) > 0 {
|
||||
prompt = strings.Join(bifrostReq.Input.PromptArray, "\n\n")
|
||||
}
|
||||
|
||||
anthropicReq := &AnthropicTextRequest{
|
||||
Model: bifrostReq.Model,
|
||||
Prompt: fmt.Sprintf("\n\nHuman: %s\n\nAssistant:", prompt),
|
||||
MaxTokensToSample: providerUtils.GetMaxOutputTokensOrDefault(bifrostReq.Model, AnthropicDefaultMaxTokens),
|
||||
}
|
||||
|
||||
// Convert parameters
|
||||
if bifrostReq.Params != nil {
|
||||
if bifrostReq.Params.MaxTokens != nil {
|
||||
anthropicReq.MaxTokensToSample = *bifrostReq.Params.MaxTokens
|
||||
}
|
||||
anthropicReq.Temperature = bifrostReq.Params.Temperature
|
||||
anthropicReq.TopP = bifrostReq.Params.TopP
|
||||
anthropicReq.StopSequences = bifrostReq.Params.Stop
|
||||
|
||||
if bifrostReq.Params.ExtraParams != nil {
|
||||
anthropicReq.ExtraParams = bifrostReq.Params.ExtraParams
|
||||
if topK, ok := schemas.SafeExtractIntPointer(bifrostReq.Params.ExtraParams["top_k"]); ok {
|
||||
delete(anthropicReq.ExtraParams, "top_k")
|
||||
anthropicReq.TopK = topK
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return anthropicReq
|
||||
}
|
||||
|
||||
// ToBifrostTextCompletionRequest converts an Anthropic text request back to Bifrost format
|
||||
func (req *AnthropicTextRequest) ToBifrostTextCompletionRequest(ctx *schemas.BifrostContext) *schemas.BifrostTextCompletionRequest {
|
||||
if req == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
provider, model := schemas.ParseModelString(req.Model, providerUtils.CheckAndSetDefaultProvider(ctx, schemas.Anthropic))
|
||||
|
||||
bifrostReq := &schemas.BifrostTextCompletionRequest{
|
||||
Provider: provider,
|
||||
Model: model,
|
||||
Input: &schemas.TextCompletionInput{
|
||||
PromptStr: &req.Prompt,
|
||||
},
|
||||
Params: &schemas.TextCompletionParameters{
|
||||
MaxTokens: &req.MaxTokensToSample,
|
||||
Temperature: req.Temperature,
|
||||
TopP: req.TopP,
|
||||
Stop: req.StopSequences,
|
||||
},
|
||||
Fallbacks: schemas.ParseFallbacks(req.Fallbacks),
|
||||
}
|
||||
|
||||
// Add extra params if present
|
||||
if req.TopK != nil {
|
||||
bifrostReq.Params.ExtraParams = map[string]interface{}{
|
||||
"top_k": *req.TopK,
|
||||
}
|
||||
}
|
||||
|
||||
return bifrostReq
|
||||
}
|
||||
|
||||
// ToBifrostTextCompletionResponse converts an Anthropic text response back to Bifrost format
|
||||
func (response *AnthropicTextResponse) ToBifrostTextCompletionResponse() *schemas.BifrostTextCompletionResponse {
|
||||
if response == nil {
|
||||
return nil
|
||||
}
|
||||
return &schemas.BifrostTextCompletionResponse{
|
||||
ID: response.ID,
|
||||
Object: "text_completion",
|
||||
Choices: []schemas.BifrostResponseChoice{
|
||||
{
|
||||
Index: 0,
|
||||
TextCompletionResponseChoice: &schemas.TextCompletionResponseChoice{
|
||||
Text: &response.Completion,
|
||||
},
|
||||
},
|
||||
},
|
||||
Usage: &schemas.BifrostLLMUsage{
|
||||
PromptTokens: response.Usage.InputTokens,
|
||||
CompletionTokens: response.Usage.OutputTokens,
|
||||
TotalTokens: response.Usage.InputTokens + response.Usage.OutputTokens,
|
||||
},
|
||||
Model: response.Model,
|
||||
}
|
||||
}
|
||||
|
||||
// ToAnthropicTextCompletionResponse converts a BifrostResponse back to Anthropic text completion format
|
||||
func ToAnthropicTextCompletionResponse(bifrostResp *schemas.BifrostTextCompletionResponse) *AnthropicTextResponse {
|
||||
if bifrostResp == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
anthropicResp := &AnthropicTextResponse{
|
||||
ID: bifrostResp.ID,
|
||||
Type: "completion",
|
||||
Model: bifrostResp.Model,
|
||||
}
|
||||
|
||||
// Convert choices to completion text
|
||||
if len(bifrostResp.Choices) > 0 {
|
||||
choice := bifrostResp.Choices[0] // Anthropic text API typically returns one choice
|
||||
|
||||
if choice.TextCompletionResponseChoice != nil && choice.TextCompletionResponseChoice.Text != nil {
|
||||
anthropicResp.Completion = *choice.TextCompletionResponseChoice.Text
|
||||
}
|
||||
}
|
||||
|
||||
// Convert usage information
|
||||
if bifrostResp.Usage != nil {
|
||||
anthropicResp.Usage.InputTokens = bifrostResp.Usage.PromptTokens
|
||||
anthropicResp.Usage.OutputTokens = bifrostResp.Usage.CompletionTokens
|
||||
}
|
||||
|
||||
return anthropicResp
|
||||
}
|
||||
1635
core/providers/anthropic/types.go
Normal file
1635
core/providers/anthropic/types.go
Normal file
File diff suppressed because it is too large
Load Diff
2712
core/providers/anthropic/utils.go
Normal file
2712
core/providers/anthropic/utils.go
Normal file
File diff suppressed because it is too large
Load Diff
1984
core/providers/anthropic/utils_test.go
Normal file
1984
core/providers/anthropic/utils_test.go
Normal file
File diff suppressed because it is too large
Load Diff
138
core/providers/anthropic/validate_chat_tools_test.go
Normal file
138
core/providers/anthropic/validate_chat_tools_test.go
Normal file
@@ -0,0 +1,138 @@
|
||||
package anthropic
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// TestValidateChatToolsForProvider locks in the partition:
|
||||
// function/custom tools always survive; server tools survive only when the
|
||||
// target provider's ProviderFeatures flag is true for that tool type.
|
||||
func TestValidateChatToolsForProvider(t *testing.T) {
|
||||
fnTool := schemas.ChatTool{
|
||||
Type: schemas.ChatToolTypeFunction,
|
||||
Function: &schemas.ChatToolFunction{Name: "get_weather"},
|
||||
}
|
||||
serverTool := func(tpe, name string) schemas.ChatTool {
|
||||
return schemas.ChatTool{Type: schemas.ChatToolType(tpe), Name: name}
|
||||
}
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
provider schemas.ModelProvider
|
||||
input []schemas.ChatTool
|
||||
wantKeep int
|
||||
wantDropped []string
|
||||
assertNotes string
|
||||
}{
|
||||
{
|
||||
name: "function tools always survive on any provider",
|
||||
provider: schemas.Bedrock,
|
||||
input: []schemas.ChatTool{fnTool, fnTool},
|
||||
wantKeep: 2,
|
||||
},
|
||||
{
|
||||
name: "bedrock drops web_search",
|
||||
provider: schemas.Bedrock,
|
||||
input: []schemas.ChatTool{serverTool("web_search_20260209", "web_search")},
|
||||
wantKeep: 0,
|
||||
wantDropped: []string{"web_search_20260209"},
|
||||
assertNotes: "Bedrock has WebSearch=false per Table 20 (AWS user guide beta-header list + Anthropic overview)",
|
||||
},
|
||||
{
|
||||
name: "bedrock drops web_fetch + code_execution + mcp_toolset",
|
||||
provider: schemas.Bedrock,
|
||||
input: []schemas.ChatTool{
|
||||
serverTool("web_fetch_20260309", "web_fetch"),
|
||||
serverTool("code_execution_20250825", "code_execution"),
|
||||
serverTool("mcp_toolset", "notion"),
|
||||
},
|
||||
wantKeep: 0,
|
||||
wantDropped: []string{"web_fetch_20260309", "code_execution_20250825", "mcp_toolset"},
|
||||
},
|
||||
{
|
||||
name: "bedrock keeps computer/bash/memory/text_editor/tool_search",
|
||||
provider: schemas.Bedrock,
|
||||
input: []schemas.ChatTool{
|
||||
serverTool("computer_20251124", "computer"),
|
||||
serverTool("bash_20250124", "bash"),
|
||||
serverTool("memory_20250818", "memory"),
|
||||
serverTool("text_editor_20250728", "str_replace_based_edit_tool"),
|
||||
serverTool("tool_search_tool_bm25", "tool_search_tool_bm25"),
|
||||
},
|
||||
wantKeep: 5,
|
||||
},
|
||||
{
|
||||
name: "bedrock partial drop mixes function + server tools",
|
||||
provider: schemas.Bedrock,
|
||||
input: []schemas.ChatTool{
|
||||
fnTool,
|
||||
serverTool("web_search_20260209", "web_search"),
|
||||
serverTool("bash_20250124", "bash"),
|
||||
},
|
||||
wantKeep: 2, // fnTool + bash
|
||||
wantDropped: []string{"web_search_20260209"},
|
||||
},
|
||||
{
|
||||
name: "vertex drops web_fetch",
|
||||
provider: schemas.Vertex,
|
||||
input: []schemas.ChatTool{serverTool("web_fetch_20260309", "web_fetch")},
|
||||
wantKeep: 0,
|
||||
wantDropped: []string{"web_fetch_20260309"},
|
||||
assertNotes: "Vertex has WebFetch=false per Table 20",
|
||||
},
|
||||
{
|
||||
name: "vertex drops mcp_toolset",
|
||||
provider: schemas.Vertex,
|
||||
input: []schemas.ChatTool{serverTool("mcp_toolset", "notion")},
|
||||
wantKeep: 0,
|
||||
wantDropped: []string{"mcp_toolset"},
|
||||
assertNotes: "Vertex has MCP=false per MCP-excl (explicit exclusion in Anthropic docs)",
|
||||
},
|
||||
{
|
||||
name: "anthropic keeps everything",
|
||||
provider: schemas.Anthropic,
|
||||
input: []schemas.ChatTool{
|
||||
serverTool("web_search_20260209", "web_search"),
|
||||
serverTool("web_fetch_20260309", "web_fetch"),
|
||||
serverTool("code_execution_20250825", "code_execution"),
|
||||
serverTool("mcp_toolset", "x"),
|
||||
serverTool("computer_20251124", "computer"),
|
||||
},
|
||||
wantKeep: 5,
|
||||
},
|
||||
{
|
||||
name: "unknown provider keeps everything (forward-compat)",
|
||||
provider: schemas.ModelProvider("custom-new-provider"),
|
||||
input: []schemas.ChatTool{serverTool("web_search_20260209", "web_search")},
|
||||
wantKeep: 1,
|
||||
},
|
||||
{
|
||||
name: "unknown tool type on known provider is kept (forward-compat)",
|
||||
provider: schemas.Bedrock,
|
||||
input: []schemas.ChatTool{serverTool("future_tool_20270101", "future")},
|
||||
wantKeep: 1,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
keep, dropped := ValidateChatToolsForProvider(tc.input, tc.provider)
|
||||
if len(keep) != tc.wantKeep {
|
||||
t.Errorf("keep count: got %d, want %d (%s)", len(keep), tc.wantKeep, tc.assertNotes)
|
||||
}
|
||||
if len(dropped) != len(tc.wantDropped) {
|
||||
t.Errorf("dropped count: got %v, want %v", dropped, tc.wantDropped)
|
||||
}
|
||||
for i, d := range tc.wantDropped {
|
||||
if i >= len(dropped) {
|
||||
break
|
||||
}
|
||||
if dropped[i] != d {
|
||||
t.Errorf("dropped[%d]: got %q, want %q", i, dropped[i], d)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
270
core/providers/anthropic/websearch_test.go
Normal file
270
core/providers/anthropic/websearch_test.go
Normal file
@@ -0,0 +1,270 @@
|
||||
package anthropic
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// TestWebSearch_OutputItemAdded_StoresID verifies that a WebSearch function_call
|
||||
// output_item.added event stores the item ID in the per-request stream state so that
|
||||
// subsequent argument deltas can be skipped.
|
||||
func TestWebSearch_OutputItemAdded_StoresID(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const itemID = "toolu_ws_storesid_test"
|
||||
|
||||
ctx, cancel := schemas.NewBifrostContextWithCancel(nil)
|
||||
defer cancel()
|
||||
|
||||
bifrostResp := &schemas.BifrostResponsesStreamResponse{
|
||||
Type: schemas.ResponsesStreamResponseTypeOutputItemAdded,
|
||||
OutputIndex: schemas.Ptr(0),
|
||||
Item: &schemas.ResponsesMessage{
|
||||
ID: schemas.Ptr(itemID),
|
||||
Type: schemas.Ptr(schemas.ResponsesMessageTypeFunctionCall),
|
||||
ResponsesToolMessage: &schemas.ResponsesToolMessage{
|
||||
CallID: schemas.Ptr(itemID),
|
||||
Name: schemas.Ptr("WebSearch"),
|
||||
Arguments: schemas.Ptr(""),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
events := ToAnthropicResponsesStreamResponse(ctx, bifrostResp)
|
||||
|
||||
// Should emit content_block_start
|
||||
if len(events) == 0 {
|
||||
t.Fatal("expected at least one event")
|
||||
}
|
||||
if events[0].Type != AnthropicStreamEventTypeContentBlockStart {
|
||||
t.Errorf("event[0].Type = %v, want content_block_start", events[0].Type)
|
||||
}
|
||||
if events[0].ContentBlock == nil || events[0].ContentBlock.Input == nil {
|
||||
t.Fatal("expected ContentBlock with Input")
|
||||
}
|
||||
if string(events[0].ContentBlock.Input) != "{}" {
|
||||
t.Errorf("ContentBlock.Input = %s, want {}", events[0].ContentBlock.Input)
|
||||
}
|
||||
|
||||
// ID must now be tracked in per-request state
|
||||
state := getOrCreateAnthropicToResponsesStreamState(ctx)
|
||||
if !state.webSearchItemIDs[itemID] {
|
||||
t.Error("expected item ID to be stored in per-request stream state after output_item.added")
|
||||
}
|
||||
}
|
||||
|
||||
// TestWebSearch_FunctionCallArgumentsDelta_Skipped verifies that argument deltas
|
||||
// for a tracked WebSearch item are skipped (returning nil) regardless of the
|
||||
// user agent — the fix for the original bug where non-Claude Code clients lost
|
||||
// the query.
|
||||
func TestWebSearch_FunctionCallArgumentsDelta_Skipped(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const itemID = "toolu_ws_skip_test"
|
||||
|
||||
ctx, cancel := schemas.NewBifrostContextWithCancel(nil)
|
||||
defer cancel()
|
||||
|
||||
// Pre-seed per-request state as if output_item.added already fired
|
||||
state := getOrCreateAnthropicToResponsesStreamState(ctx)
|
||||
state.webSearchItemIDs = map[string]bool{itemID: true}
|
||||
|
||||
partial := `{"query": "world news"`
|
||||
bifrostResp := &schemas.BifrostResponsesStreamResponse{
|
||||
Type: schemas.ResponsesStreamResponseTypeFunctionCallArgumentsDelta,
|
||||
OutputIndex: schemas.Ptr(0),
|
||||
ItemID: schemas.Ptr(itemID),
|
||||
Delta: &partial,
|
||||
}
|
||||
|
||||
events := ToAnthropicResponsesStreamResponse(ctx, bifrostResp)
|
||||
|
||||
if len(events) != 0 {
|
||||
t.Errorf("expected deltas to be skipped (0 events), got %d", len(events))
|
||||
}
|
||||
}
|
||||
|
||||
// TestWebSearch_OutputItemDone_GeneratesSyntheticDeltas verifies that when
|
||||
// output_item.done fires for a tracked WebSearch item, synthetic input_json_delta
|
||||
// events carrying the full query are emitted, followed by content_block_stop.
|
||||
// This applies for ALL clients regardless of user agent.
|
||||
func TestWebSearch_OutputItemDone_GeneratesSyntheticDeltas(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const itemID = "toolu_ws_synth_test"
|
||||
|
||||
ctx, cancel := schemas.NewBifrostContextWithCancel(nil)
|
||||
defer cancel()
|
||||
|
||||
// Pre-seed per-request state as if output_item.added already fired
|
||||
state := getOrCreateAnthropicToResponsesStreamState(ctx)
|
||||
state.webSearchItemIDs = map[string]bool{itemID: true}
|
||||
|
||||
query := `{"query":"world news today"}`
|
||||
bifrostResp := &schemas.BifrostResponsesStreamResponse{
|
||||
Type: schemas.ResponsesStreamResponseTypeOutputItemDone,
|
||||
OutputIndex: schemas.Ptr(1),
|
||||
Item: &schemas.ResponsesMessage{
|
||||
ID: schemas.Ptr(itemID),
|
||||
Type: schemas.Ptr(schemas.ResponsesMessageTypeFunctionCall),
|
||||
ResponsesToolMessage: &schemas.ResponsesToolMessage{
|
||||
CallID: schemas.Ptr(itemID),
|
||||
Name: schemas.Ptr("WebSearch"),
|
||||
Arguments: &query,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
events := ToAnthropicResponsesStreamResponse(ctx, bifrostResp)
|
||||
|
||||
// Must have at least one input_json_delta and a final content_block_stop
|
||||
if len(events) < 2 {
|
||||
t.Fatalf("expected at least 2 events (deltas + stop), got %d", len(events))
|
||||
}
|
||||
|
||||
// All events except last must be input_json_delta
|
||||
for i, ev := range events[:len(events)-1] {
|
||||
if ev.Type != AnthropicStreamEventTypeContentBlockDelta {
|
||||
t.Errorf("event[%d].Type = %v, want content_block_delta", i, ev.Type)
|
||||
continue
|
||||
}
|
||||
if ev.Delta == nil || ev.Delta.Type != AnthropicStreamDeltaTypeInputJSON {
|
||||
t.Errorf("event[%d].Delta.Type = %v, want input_json", i, ev.Delta)
|
||||
}
|
||||
}
|
||||
|
||||
// Last event must be content_block_stop
|
||||
last := events[len(events)-1]
|
||||
if last.Type != AnthropicStreamEventTypeContentBlockStop {
|
||||
t.Errorf("last event.Type = %v, want content_block_stop", last.Type)
|
||||
}
|
||||
|
||||
// Reconstruct the accumulated JSON from the deltas
|
||||
var accumulated string
|
||||
for _, ev := range events[:len(events)-1] {
|
||||
if ev.Delta != nil && ev.Delta.PartialJSON != nil {
|
||||
accumulated += *ev.Delta.PartialJSON
|
||||
}
|
||||
}
|
||||
var got map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(accumulated), &got); err != nil {
|
||||
t.Fatalf("accumulated JSON invalid: %v — got %q", err, accumulated)
|
||||
}
|
||||
if got["query"] != "world news today" {
|
||||
t.Errorf("query = %v, want %q", got["query"], "world news today")
|
||||
}
|
||||
|
||||
// ID must have been cleaned up from per-request state
|
||||
if state.webSearchItemIDs[itemID] {
|
||||
t.Error("expected item ID to be removed from per-request stream state after output_item.done")
|
||||
}
|
||||
}
|
||||
|
||||
// TestWebSearch_FullFlow_AnyUserAgent is the regression test for the original bug.
|
||||
// It simulates the complete streaming sequence:
|
||||
//
|
||||
// output_item.added → FunctionCallArgumentsDelta (×N) → output_item.done
|
||||
//
|
||||
// and verifies that the client-facing Anthropic stream contains proper
|
||||
// input_json_delta events with the query, regardless of user agent.
|
||||
func TestWebSearch_FullFlow_AnyUserAgent(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const itemID = "toolu_ws_fullflow_test"
|
||||
|
||||
ctx, cancel := schemas.NewBifrostContextWithCancel(nil)
|
||||
defer cancel()
|
||||
|
||||
var allEvents []*AnthropicStreamEvent
|
||||
|
||||
// Step 1: output_item.added
|
||||
addedResp := &schemas.BifrostResponsesStreamResponse{
|
||||
Type: schemas.ResponsesStreamResponseTypeOutputItemAdded,
|
||||
OutputIndex: schemas.Ptr(0),
|
||||
Item: &schemas.ResponsesMessage{
|
||||
ID: schemas.Ptr(itemID),
|
||||
Type: schemas.Ptr(schemas.ResponsesMessageTypeFunctionCall),
|
||||
ResponsesToolMessage: &schemas.ResponsesToolMessage{
|
||||
CallID: schemas.Ptr(itemID),
|
||||
Name: schemas.Ptr("WebSearch"),
|
||||
Arguments: schemas.Ptr(""),
|
||||
},
|
||||
},
|
||||
}
|
||||
allEvents = append(allEvents, ToAnthropicResponsesStreamResponse(ctx, addedResp)...)
|
||||
|
||||
// Step 2: FunctionCallArgumentsDelta events (should be skipped)
|
||||
for _, partial := range []string{`{"query": "`, `latest AI`, `news"}`} {
|
||||
p := partial
|
||||
deltaResp := &schemas.BifrostResponsesStreamResponse{
|
||||
Type: schemas.ResponsesStreamResponseTypeFunctionCallArgumentsDelta,
|
||||
OutputIndex: schemas.Ptr(0),
|
||||
ItemID: schemas.Ptr(itemID),
|
||||
Delta: &p,
|
||||
}
|
||||
allEvents = append(allEvents, ToAnthropicResponsesStreamResponse(ctx, deltaResp)...)
|
||||
}
|
||||
|
||||
// Step 3: output_item.done with full accumulated arguments
|
||||
fullArgs := `{"query":"latest AI news"}`
|
||||
doneResp := &schemas.BifrostResponsesStreamResponse{
|
||||
Type: schemas.ResponsesStreamResponseTypeOutputItemDone,
|
||||
OutputIndex: schemas.Ptr(0),
|
||||
Item: &schemas.ResponsesMessage{
|
||||
ID: schemas.Ptr(itemID),
|
||||
Type: schemas.Ptr(schemas.ResponsesMessageTypeFunctionCall),
|
||||
ResponsesToolMessage: &schemas.ResponsesToolMessage{
|
||||
CallID: schemas.Ptr(itemID),
|
||||
Name: schemas.Ptr("WebSearch"),
|
||||
Arguments: &fullArgs,
|
||||
},
|
||||
},
|
||||
}
|
||||
allEvents = append(allEvents, ToAnthropicResponsesStreamResponse(ctx, doneResp)...)
|
||||
|
||||
// Verify the sequence:
|
||||
// [0] content_block_start (input:{})
|
||||
// [1..N-1] input_json_delta events
|
||||
// [N] content_block_stop
|
||||
if len(allEvents) < 3 {
|
||||
t.Fatalf("expected at least 3 events, got %d: %v", len(allEvents), allEvents)
|
||||
}
|
||||
|
||||
// First event: content_block_start with empty input
|
||||
if allEvents[0].Type != AnthropicStreamEventTypeContentBlockStart {
|
||||
t.Errorf("allEvents[0].Type = %v, want content_block_start", allEvents[0].Type)
|
||||
}
|
||||
|
||||
// Last event: content_block_stop
|
||||
last := allEvents[len(allEvents)-1]
|
||||
if last.Type != AnthropicStreamEventTypeContentBlockStop {
|
||||
t.Errorf("last event.Type = %v, want content_block_stop", last.Type)
|
||||
}
|
||||
|
||||
// Middle events: all input_json_delta
|
||||
for i, ev := range allEvents[1 : len(allEvents)-1] {
|
||||
if ev.Type != AnthropicStreamEventTypeContentBlockDelta {
|
||||
t.Errorf("allEvents[%d].Type = %v, want content_block_delta", i+1, ev.Type)
|
||||
}
|
||||
if ev.Delta == nil || ev.Delta.Type != AnthropicStreamDeltaTypeInputJSON {
|
||||
t.Errorf("allEvents[%d].Delta.Type = %v, want input_json", i+1, ev.Delta)
|
||||
}
|
||||
}
|
||||
|
||||
// Reconstruct query from synthetic deltas
|
||||
var accumulated string
|
||||
for _, ev := range allEvents[1 : len(allEvents)-1] {
|
||||
if ev.Delta != nil && ev.Delta.PartialJSON != nil {
|
||||
accumulated += *ev.Delta.PartialJSON
|
||||
}
|
||||
}
|
||||
var got map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(accumulated), &got); err != nil {
|
||||
t.Fatalf("reconstructed JSON is invalid: %v — got %q", err, accumulated)
|
||||
}
|
||||
if got["query"] != "latest AI news" {
|
||||
t.Errorf("reconstructed query = %v, want %q", got["query"], "latest AI news")
|
||||
}
|
||||
}
|
||||
2958
core/providers/azure/azure.go
Normal file
2958
core/providers/azure/azure.go
Normal file
File diff suppressed because it is too large
Load Diff
198
core/providers/azure/azure_caching_test.go
Normal file
198
core/providers/azure/azure_caching_test.go
Normal file
@@ -0,0 +1,198 @@
|
||||
package azure_test
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/maximhq/bifrost/core/providers/openai"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// TestAzure_OpenAIModel_CachingDeterminism verifies that Azure's delegation to
|
||||
// openai.ToOpenAIChatRequest() produces deterministic JSON for prompt caching.
|
||||
// Two schemas with the same properties but different structural key order within
|
||||
// property definitions must produce byte-identical JSON after normalization.
|
||||
func TestAzure_OpenAIModel_CachingDeterminism(t *testing.T) {
|
||||
makeReq := func(props *schemas.OrderedMap) *schemas.BifrostChatRequest {
|
||||
return &schemas.BifrostChatRequest{
|
||||
Provider: schemas.Azure,
|
||||
Model: "gpt-4o",
|
||||
Input: []schemas.ChatMessage{{Role: schemas.ChatMessageRoleUser}},
|
||||
Params: &schemas.ChatParameters{
|
||||
Tools: []schemas.ChatTool{{
|
||||
Type: schemas.ChatToolTypeFunction,
|
||||
Function: &schemas.ChatToolFunction{
|
||||
Name: "test",
|
||||
Parameters: &schemas.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: props,
|
||||
},
|
||||
},
|
||||
}},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Version A: type before description
|
||||
propsA := schemas.NewOrderedMapFromPairs(
|
||||
schemas.KV("reasoning", schemas.NewOrderedMapFromPairs(
|
||||
schemas.KV("type", "string"),
|
||||
schemas.KV("description", "Step by step"),
|
||||
)),
|
||||
schemas.KV("answer", schemas.NewOrderedMapFromPairs(
|
||||
schemas.KV("type", "string"),
|
||||
schemas.KV("description", "Final answer"),
|
||||
)),
|
||||
)
|
||||
|
||||
// Version B: description before type (different structural order)
|
||||
propsB := schemas.NewOrderedMapFromPairs(
|
||||
schemas.KV("reasoning", schemas.NewOrderedMapFromPairs(
|
||||
schemas.KV("description", "Step by step"),
|
||||
schemas.KV("type", "string"),
|
||||
)),
|
||||
schemas.KV("answer", schemas.NewOrderedMapFromPairs(
|
||||
schemas.KV("description", "Final answer"),
|
||||
schemas.KV("type", "string"),
|
||||
)),
|
||||
)
|
||||
|
||||
// Azure delegates OpenAI models to openai.ToOpenAIChatRequest()
|
||||
ctx, cancel := schemas.NewBifrostContextWithCancel(nil)
|
||||
defer cancel()
|
||||
resultA := openai.ToOpenAIChatRequest(ctx, makeReq(propsA))
|
||||
resultB := openai.ToOpenAIChatRequest(ctx, makeReq(propsB))
|
||||
|
||||
jsonA, err := schemas.Marshal(resultA.ChatParameters.Tools[0].Function.Parameters)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to marshal params A: %v", err)
|
||||
}
|
||||
jsonB, err := schemas.Marshal(resultB.ChatParameters.Tools[0].Function.Parameters)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to marshal params B: %v", err)
|
||||
}
|
||||
|
||||
// Caching: byte-identical JSON
|
||||
if string(jsonA) != string(jsonB) {
|
||||
t.Errorf("caching broken via Azure→OpenAI path: same schema produced different JSON\nA: %s\nB: %s", jsonA, jsonB)
|
||||
}
|
||||
|
||||
// CoT: property order preserved
|
||||
keys := resultA.ChatParameters.Tools[0].Function.Parameters.Properties.Keys()
|
||||
if len(keys) != 2 || keys[0] != "reasoning" || keys[1] != "answer" {
|
||||
t.Errorf("expected property order [reasoning, answer], got %v", keys)
|
||||
}
|
||||
}
|
||||
|
||||
// TestAzure_OpenAIModel_PreservesPropertyOrder verifies that the Azure→OpenAI
|
||||
// delegation path preserves user-defined property ordering.
|
||||
func TestAzure_OpenAIModel_PreservesPropertyOrder(t *testing.T) {
|
||||
bifrostReq := &schemas.BifrostChatRequest{
|
||||
Provider: schemas.Azure,
|
||||
Model: "gpt-4o",
|
||||
Input: []schemas.ChatMessage{{Role: schemas.ChatMessageRoleUser}},
|
||||
Params: &schemas.ChatParameters{
|
||||
Tools: []schemas.ChatTool{{
|
||||
Type: schemas.ChatToolTypeFunction,
|
||||
Function: &schemas.ChatToolFunction{
|
||||
Name: "AnswerResponseModel",
|
||||
Parameters: &schemas.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: schemas.NewOrderedMapFromPairs(
|
||||
schemas.KV("chain_of_thought", schemas.NewOrderedMapFromPairs(schemas.KV("type", "string"))),
|
||||
schemas.KV("answer", schemas.NewOrderedMapFromPairs(schemas.KV("type", "string"))),
|
||||
schemas.KV("citations", schemas.NewOrderedMapFromPairs(schemas.KV("type", "array"))),
|
||||
),
|
||||
},
|
||||
},
|
||||
}},
|
||||
},
|
||||
}
|
||||
|
||||
ctx, cancel := schemas.NewBifrostContextWithCancel(nil)
|
||||
defer cancel()
|
||||
result := openai.ToOpenAIChatRequest(ctx, bifrostReq)
|
||||
|
||||
keys := result.ChatParameters.Tools[0].Function.Parameters.Properties.Keys()
|
||||
if len(keys) != 3 || keys[0] != "chain_of_thought" || keys[1] != "answer" || keys[2] != "citations" {
|
||||
t.Errorf("expected property order [chain_of_thought, answer, citations], got %v", keys)
|
||||
}
|
||||
}
|
||||
|
||||
// TestAzure_ToolInputKeyOrderPreservation verifies that tool call arguments
|
||||
// preserve their original key ordering through the Azure→OpenAI delegation path.
|
||||
// TestAzure_ToolInputKeyOrderPreservation verifies that Azure→OpenAI delegation
|
||||
// preserves the original key ordering of tool call arguments for prompt caching.
|
||||
// Tests multiple parallel tool calls with different key orderings per block.
|
||||
func TestAzure_ToolInputKeyOrderPreservation(t *testing.T) {
|
||||
bifrostReq := &schemas.BifrostChatRequest{
|
||||
Provider: schemas.Azure,
|
||||
Model: "gpt-4o",
|
||||
Input: []schemas.ChatMessage{
|
||||
{
|
||||
Role: schemas.ChatMessageRoleUser,
|
||||
Content: &schemas.ChatMessageContent{ContentStr: schemas.Ptr("test")},
|
||||
},
|
||||
{
|
||||
Role: schemas.ChatMessageRoleAssistant,
|
||||
ChatAssistantMessage: &schemas.ChatAssistantMessage{
|
||||
ToolCalls: []schemas.ChatAssistantMessageToolCall{
|
||||
{
|
||||
Index: 0,
|
||||
Type: schemas.Ptr("function"),
|
||||
ID: schemas.Ptr("toolu_001"),
|
||||
Function: schemas.ChatAssistantMessageToolCallFunction{
|
||||
Name: schemas.Ptr("bash"),
|
||||
Arguments: `{"description":"Find references quickly","timeout":30000,"command":"grep -r auth_injector ."}`,
|
||||
},
|
||||
},
|
||||
{
|
||||
Index: 1,
|
||||
Type: schemas.Ptr("function"),
|
||||
ID: schemas.Ptr("toolu_002"),
|
||||
Function: schemas.ChatAssistantMessageToolCallFunction{
|
||||
Name: schemas.Ptr("bash"),
|
||||
Arguments: `{"command":"git diff main...HEAD --stat","description":"Show diff"}`,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
ctx, cancel := schemas.NewBifrostContextWithCancel(nil)
|
||||
defer cancel()
|
||||
result := openai.ToOpenAIChatRequest(ctx, bifrostReq)
|
||||
if result == nil {
|
||||
t.Fatal("expected non-nil result")
|
||||
}
|
||||
|
||||
// Collect tool call arguments from assistant message
|
||||
var argsList []string
|
||||
for _, msg := range result.Messages {
|
||||
if msg.OpenAIChatAssistantMessage != nil {
|
||||
for _, tc := range msg.OpenAIChatAssistantMessage.ToolCalls {
|
||||
argsList = append(argsList, tc.Function.Arguments)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(argsList) != 2 {
|
||||
t.Fatalf("expected 2 tool call arguments, got %d", len(argsList))
|
||||
}
|
||||
|
||||
// OpenAI path passes Arguments through as strings — verify key order is preserved
|
||||
// Block 0: keys should be description, timeout, command
|
||||
s0 := argsList[0]
|
||||
if !(strings.Index(s0, "description") < strings.Index(s0, "timeout") &&
|
||||
strings.Index(s0, "timeout") < strings.Index(s0, "command")) {
|
||||
t.Errorf("block 0: key order not preserved, expected description < timeout < command in: %s", s0)
|
||||
}
|
||||
|
||||
// Block 1: keys should be command, description
|
||||
s1 := argsList[1]
|
||||
if !(strings.Index(s1, "command") < strings.Index(s1, "description")) {
|
||||
t.Errorf("block 1: key order not preserved, expected command < description in: %s", s1)
|
||||
}
|
||||
}
|
||||
92
core/providers/azure/azure_test.go
Normal file
92
core/providers/azure/azure_test.go
Normal file
@@ -0,0 +1,92 @@
|
||||
package azure_test
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/maximhq/bifrost/core/internal/llmtests"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
func TestAzure(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
if strings.TrimSpace(os.Getenv("AZURE_API_KEY")) == "" {
|
||||
t.Skip("Skipping Azure tests because AZURE_API_KEY is not set")
|
||||
}
|
||||
|
||||
client, ctx, cancel, err := llmtests.SetupTest()
|
||||
if err != nil {
|
||||
t.Fatalf("Error initializing test setup: %v", err)
|
||||
}
|
||||
defer cancel()
|
||||
defer client.Shutdown()
|
||||
|
||||
testConfig := llmtests.ComprehensiveTestConfig{
|
||||
Provider: schemas.Azure,
|
||||
ChatModel: "gpt-4o",
|
||||
PromptCachingModel: "gpt-4o",
|
||||
VisionModel: "gpt-4o",
|
||||
ChatAudioModel: "gpt-4o-mini-audio-preview",
|
||||
Fallbacks: []schemas.Fallback{
|
||||
{Provider: schemas.Azure, Model: "gpt-4o"},
|
||||
},
|
||||
TextModel: "", // Azure doesn't support text completion in newer models
|
||||
EmbeddingModel: "text-embedding-ada-002",
|
||||
ReasoningModel: "claude-opus-4-5",
|
||||
SpeechSynthesisModel: "gpt-4o-mini-tts",
|
||||
TranscriptionModel: "whisper",
|
||||
ImageGenerationModel: "gpt-image-1",
|
||||
ImageEditModel: "gpt-image-1",
|
||||
VideoGenerationModel: "sora-2",
|
||||
PassthroughModel: "gpt-4o",
|
||||
Scenarios: llmtests.TestScenarios{
|
||||
TextCompletion: false, // Not supported
|
||||
SimpleChat: true,
|
||||
CompletionStream: true,
|
||||
MultiTurnConversation: true,
|
||||
ToolCalls: true,
|
||||
ToolCallsStreaming: true,
|
||||
MultipleToolCalls: true,
|
||||
MultipleToolCallsStreaming: true,
|
||||
End2EndToolCalling: true,
|
||||
AutomaticFunctionCall: true,
|
||||
ImageURL: true,
|
||||
ImageBase64: true,
|
||||
MultipleImages: true,
|
||||
CompleteEnd2End: true,
|
||||
Embedding: true,
|
||||
ListModels: true,
|
||||
Reasoning: true,
|
||||
ChatAudio: false,
|
||||
Transcription: false, // Disabled for azure because of 3 calls/minute quota
|
||||
TranscriptionStream: false, // Not properly supported yet by Azure
|
||||
SpeechSynthesis: false, // Disabled for azure because of 3 calls/minute quota
|
||||
SpeechSynthesisStream: false, // Disabled for azure because of 3 calls/minute quota
|
||||
StructuredOutputs: true, // Structured outputs with nullable enum support
|
||||
PromptCaching: true,
|
||||
ImageGeneration: false, // Skipped for Azure
|
||||
ImageGenerationStream: false, // Skipped for Azure
|
||||
ImageEdit: false, // Model not deployed on Azure endpoint
|
||||
ImageEditStream: false, // Model not deployed on Azure endpoint
|
||||
ImageVariation: false, // Not supported by Azure
|
||||
VideoGeneration: false, // disabled for now because of long running operations
|
||||
VideoDownload: false,
|
||||
VideoRetrieve: false,
|
||||
VideoRemix: false,
|
||||
VideoList: false,
|
||||
VideoDelete: false,
|
||||
InterleavedThinking: true,
|
||||
PassthroughAPI: true,
|
||||
EagerInputStreaming: true, // fine-grained-tool-streaming-2025-05-14 (Beta on Azure Foundry)
|
||||
ServerToolsViaOpenAIEndpoint: true, // web_search / web_fetch / code_execution on Azure per Table 20
|
||||
},
|
||||
DisableParallelFor: []string{"Transcription"}, // Azure Whisper has 3 calls/minute quota
|
||||
}
|
||||
|
||||
t.Run("AzureTests", func(t *testing.T) {
|
||||
llmtests.RunAllComprehensiveTests(t, client, ctx, testConfig)
|
||||
})
|
||||
}
|
||||
122
core/providers/azure/files.go
Normal file
122
core/providers/azure/files.go
Normal file
@@ -0,0 +1,122 @@
|
||||
package azure
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
|
||||
"github.com/maximhq/bifrost/core/providers/openai"
|
||||
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// setAzureAuth sets the Azure authentication header on the request for OpenAI models.
|
||||
// It handles authentication in order of priority:
|
||||
// 1. Service Principal (client ID/secret/tenant ID) - uses Bearer token
|
||||
// 2. Context token - uses Bearer token
|
||||
// 3. API key - uses api-key header
|
||||
// 4. DefaultAzureCredential auto-detection (managed identity, workload identity, env vars, CLI)
|
||||
func (provider *AzureProvider) setAzureAuth(ctx context.Context, req *fasthttp.Request, key schemas.Key) *schemas.BifrostError {
|
||||
// Service Principal authentication
|
||||
if key.AzureKeyConfig != nil && key.AzureKeyConfig.ClientID != nil &&
|
||||
key.AzureKeyConfig.ClientSecret != nil && key.AzureKeyConfig.TenantID != nil && key.AzureKeyConfig.ClientID.GetValue() != "" && key.AzureKeyConfig.ClientSecret.GetValue() != "" && key.AzureKeyConfig.TenantID.GetValue() != "" {
|
||||
cred, err := provider.getOrCreateAuth(key.AzureKeyConfig.TenantID.GetValue(), key.AzureKeyConfig.ClientID.GetValue(), key.AzureKeyConfig.ClientSecret.GetValue())
|
||||
if err != nil {
|
||||
return providerUtils.NewBifrostOperationError("failed to get or create Azure authentication", err)
|
||||
}
|
||||
|
||||
scopes := getAzureScopes(key.AzureKeyConfig.Scopes)
|
||||
|
||||
token, err := cred.GetToken(ctx, policy.TokenRequestOptions{
|
||||
Scopes: scopes,
|
||||
})
|
||||
if err != nil {
|
||||
return providerUtils.NewBifrostOperationError("failed to get Azure access token", err)
|
||||
}
|
||||
|
||||
if token.Token == "" {
|
||||
return providerUtils.NewBifrostOperationError("azure access token is empty", fmt.Errorf("token is empty"))
|
||||
}
|
||||
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.Token))
|
||||
req.Header.Del("api-key")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Context token authentication
|
||||
if authToken, ok := ctx.Value(AzureAuthorizationTokenKey).(string); ok && authToken != "" {
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", authToken))
|
||||
req.Header.Del("api-key")
|
||||
return nil
|
||||
}
|
||||
|
||||
// API key authentication
|
||||
value := key.Value.GetValue()
|
||||
if value != "" {
|
||||
req.Header.Del("Authorization")
|
||||
req.Header.Set("api-key", value)
|
||||
return nil
|
||||
}
|
||||
|
||||
// No explicit credentials - attempt DefaultAzureCredential auto-detection.
|
||||
scopes := getAzureScopes(nil)
|
||||
if key.AzureKeyConfig != nil {
|
||||
scopes = getAzureScopes(key.AzureKeyConfig.Scopes)
|
||||
}
|
||||
|
||||
cred, err := provider.getOrCreateDefaultAzureCredential()
|
||||
if err != nil {
|
||||
return providerUtils.NewBifrostOperationError("no credentials provided and DefaultAzureCredential unavailable", err)
|
||||
}
|
||||
|
||||
token, err := cred.GetToken(ctx, policy.TokenRequestOptions{Scopes: scopes})
|
||||
if err != nil {
|
||||
return providerUtils.NewBifrostOperationError("no credentials provided and DefaultAzureCredential failed to get token", err)
|
||||
}
|
||||
|
||||
if token.Token == "" {
|
||||
return providerUtils.NewBifrostOperationError("no credentials provided and DefaultAzureCredential returned empty token", fmt.Errorf("token is empty"))
|
||||
}
|
||||
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.Token))
|
||||
req.Header.Del("api-key")
|
||||
return nil
|
||||
}
|
||||
|
||||
// AzureFileResponse represents an Azure file response (same as OpenAI).
|
||||
type AzureFileResponse struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Bytes int64 `json:"bytes"`
|
||||
CreatedAt int64 `json:"created_at"`
|
||||
Filename string `json:"filename"`
|
||||
Purpose schemas.FilePurpose `json:"purpose"`
|
||||
Status string `json:"status,omitempty"`
|
||||
StatusDetails *string `json:"status_details,omitempty"`
|
||||
}
|
||||
|
||||
// ToBifrostFileUploadResponse converts Azure file response to Bifrost response.
|
||||
func (r *AzureFileResponse) ToBifrostFileUploadResponse(providerName schemas.ModelProvider, latency time.Duration, sendBackRawResponse bool, rawResponse interface{}) *schemas.BifrostFileUploadResponse {
|
||||
resp := &schemas.BifrostFileUploadResponse{
|
||||
ID: r.ID,
|
||||
Object: r.Object,
|
||||
Bytes: r.Bytes,
|
||||
CreatedAt: r.CreatedAt,
|
||||
Filename: r.Filename,
|
||||
Purpose: r.Purpose,
|
||||
Status: openai.ToBifrostFileStatus(r.Status),
|
||||
StatusDetails: r.StatusDetails,
|
||||
StorageBackend: schemas.FileStorageAPI,
|
||||
ExtraFields: schemas.BifrostResponseExtraFields{
|
||||
Latency: latency.Milliseconds(),
|
||||
},
|
||||
}
|
||||
|
||||
if sendBackRawResponse {
|
||||
resp.ExtraFields.RawResponse = rawResponse
|
||||
}
|
||||
|
||||
return resp
|
||||
}
|
||||
51
core/providers/azure/models.go
Normal file
51
core/providers/azure/models.go
Normal file
@@ -0,0 +1,51 @@
|
||||
package azure
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
func (response *AzureListModelsResponse) ToBifrostListModelsResponse(allowedModels schemas.WhiteList, blacklistedModels schemas.BlackList, aliases map[string]string, unfiltered bool) *schemas.BifrostListModelsResponse {
|
||||
if response == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
bifrostResponse := &schemas.BifrostListModelsResponse{
|
||||
Data: make([]schemas.Model, 0, len(response.Data)),
|
||||
}
|
||||
|
||||
pipeline := &providerUtils.ListModelsPipeline{
|
||||
AllowedModels: allowedModels,
|
||||
BlacklistedModels: blacklistedModels,
|
||||
Aliases: aliases,
|
||||
Unfiltered: unfiltered,
|
||||
ProviderKey: schemas.Azure,
|
||||
MatchFns: providerUtils.DefaultMatchFns(),
|
||||
}
|
||||
if pipeline.ShouldEarlyExit() {
|
||||
return bifrostResponse
|
||||
}
|
||||
|
||||
included := make(map[string]bool)
|
||||
|
||||
for _, model := range response.Data {
|
||||
for _, result := range pipeline.FilterModel(model.ID) {
|
||||
entry := schemas.Model{
|
||||
ID: string(schemas.Azure) + "/" + result.ResolvedID,
|
||||
Created: schemas.Ptr(model.CreatedAt),
|
||||
}
|
||||
if result.AliasValue != "" {
|
||||
entry.Alias = schemas.Ptr(result.AliasValue)
|
||||
}
|
||||
bifrostResponse.Data = append(bifrostResponse.Data, entry)
|
||||
included[strings.ToLower(result.ResolvedID)] = true
|
||||
}
|
||||
}
|
||||
|
||||
bifrostResponse.Data = append(bifrostResponse.Data,
|
||||
pipeline.BackfillModels(included)...)
|
||||
|
||||
return bifrostResponse
|
||||
}
|
||||
36
core/providers/azure/types.go
Normal file
36
core/providers/azure/types.go
Normal file
@@ -0,0 +1,36 @@
|
||||
package azure
|
||||
|
||||
// Azure API version strings used by the provider.
const (
	// AzureAPIVersionDefault is the default Azure API version to use when not specified.
	AzureAPIVersionDefault = "2024-10-21"
	// AzureAPIVersionPreview is the literal "preview" API version value.
	AzureAPIVersionPreview = "preview"
	// AzureAPIVersionImageEditDefault is, judging by its name, the default
	// API version for image edit requests (usage is not visible here).
	AzureAPIVersionImageEditDefault = "2025-04-01-preview"
	// AzureAnthropicAPIVersionDefault is, judging by its name, the default
	// API version for Anthropic models on Azure (usage is not visible here).
	AzureAnthropicAPIVersionDefault = "2023-06-01"
)
|
||||
|
||||
// AzureModelCapabilities holds the boolean capability flags found under the
// "capabilities" key of an Azure model entry.
type AzureModelCapabilities struct {
	// FineTune maps to the "fine_tune" flag.
	FineTune bool `json:"fine_tune"`
	// Inference maps to the "inference" flag.
	Inference bool `json:"inference"`
	// Completion maps to the "completion" flag.
	Completion bool `json:"completion"`
	// ChatCompletion maps to the "chat_completion" flag.
	ChatCompletion bool `json:"chat_completion"`
	// Embeddings maps to the "embeddings" flag.
	Embeddings bool `json:"embeddings"`
}
|
||||
|
||||
// AzureModelDeprecation holds the values found under the "deprecation" key of
// an Azure model entry. Both fields are omitted from JSON when zero.
type AzureModelDeprecation struct {
	// FineTune maps to "fine_tune" (presumably a Unix timestamp; confirm against the API).
	FineTune int64 `json:"fine_tune,omitempty"`
	// Inference maps to "inference" (presumably a Unix timestamp; confirm against the API).
	Inference int64 `json:"inference,omitempty"`
}
|
||||
|
||||
type AzureModel struct {
|
||||
ID string `json:"id"`
|
||||
Status string `json:"status"`
|
||||
FineTune string `json:"fine_tune,omitempty"`
|
||||
Capabilities AzureModelCapabilities `json:"capabilities,omitempty"`
|
||||
LifecycleStatus string `json:"lifecycle_status"`
|
||||
Deprecation *AzureModelDeprecation `json:"deprecation,omitempty"`
|
||||
CreatedAt int64 `json:"created_at"`
|
||||
Object string `json:"object"`
|
||||
}
|
||||
|
||||
type AzureListModelsResponse struct {
|
||||
Object string `json:"object"`
|
||||
Data []AzureModel `json:"data"`
|
||||
}
|
||||
94
core/providers/azure/utils.go
Normal file
94
core/providers/azure/utils.go
Normal file
@@ -0,0 +1,94 @@
|
||||
package azure
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/maximhq/bifrost/core/providers/anthropic"
|
||||
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
func getRequestBodyForAnthropicResponses(ctx *schemas.BifrostContext, request *schemas.BifrostResponsesRequest, deployment string, isStreaming bool) ([]byte, *schemas.BifrostError) {
|
||||
// Large payload mode: body streams directly from the LP reader — skip all body building
|
||||
// (matches CheckContextAndGetRequestBody guard).
|
||||
if providerUtils.IsLargePayloadPassthroughEnabled(ctx) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var jsonBody []byte
|
||||
var err error
|
||||
|
||||
// Check if raw request body should be used
|
||||
if useRawBody, ok := ctx.Value(schemas.BifrostContextKeyUseRawRequestBody).(bool); ok && useRawBody {
|
||||
jsonBody = request.GetRawRequestBody()
|
||||
|
||||
// Add max_tokens if not present (using sjson to preserve key order for prompt caching)
|
||||
if !providerUtils.JSONFieldExists(jsonBody, "max_tokens") {
|
||||
jsonBody, err = providerUtils.SetJSONField(jsonBody, "max_tokens", providerUtils.GetMaxOutputTokensOrDefault(deployment, anthropic.AnthropicDefaultMaxTokens))
|
||||
if err != nil {
|
||||
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err)
|
||||
}
|
||||
}
|
||||
// Replace model with deployment
|
||||
jsonBody, err = providerUtils.SetJSONField(jsonBody, "model", deployment)
|
||||
if err != nil {
|
||||
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err)
|
||||
}
|
||||
// Delete fallbacks field
|
||||
jsonBody, err = providerUtils.DeleteJSONField(jsonBody, "fallbacks")
|
||||
if err != nil {
|
||||
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err)
|
||||
}
|
||||
// Add stream if streaming
|
||||
if isStreaming {
|
||||
jsonBody, err = providerUtils.SetJSONField(jsonBody, "stream", true)
|
||||
if err != nil {
|
||||
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Convert request to Anthropic format
|
||||
request.Model = deployment
|
||||
reqBody, convErr := anthropic.ToAnthropicResponsesRequest(ctx, request)
|
||||
if convErr != nil {
|
||||
return nil, providerUtils.NewBifrostOperationError(schemas.ErrRequestBodyConversion, convErr)
|
||||
}
|
||||
if reqBody == nil {
|
||||
return nil, providerUtils.NewBifrostOperationError("request body is not provided", nil)
|
||||
}
|
||||
|
||||
if isStreaming {
|
||||
reqBody.Stream = schemas.Ptr(true)
|
||||
}
|
||||
|
||||
// Add provider-aware beta headers for Azure
|
||||
anthropic.AddMissingBetaHeadersToContext(ctx, reqBody, schemas.Azure)
|
||||
|
||||
// Marshal struct to JSON bytes, preserving field order
|
||||
jsonBody, err = providerUtils.MarshalSorted(reqBody)
|
||||
if err != nil {
|
||||
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, fmt.Errorf("failed to marshal request body: %w", err))
|
||||
}
|
||||
}
|
||||
|
||||
return jsonBody, nil
|
||||
}
|
||||
|
||||
// getCleanedScopes returns cleaned scopes or default scope if none are valid.
|
||||
// It filters out empty/whitespace-only strings and returns the default scope if no valid scopes remain.
|
||||
func getAzureScopes(configuredScopes []string) []string {
|
||||
scopes := []string{DefaultAzureScope}
|
||||
if len(configuredScopes) > 0 {
|
||||
cleaned := make([]string, 0, len(configuredScopes))
|
||||
for _, s := range configuredScopes {
|
||||
if strings.TrimSpace(s) != "" {
|
||||
cleaned = append(cleaned, strings.TrimSpace(s))
|
||||
}
|
||||
}
|
||||
if len(cleaned) > 0 {
|
||||
scopes = cleaned
|
||||
}
|
||||
}
|
||||
return scopes
|
||||
}
|
||||
417
core/providers/bedrock/batch.go
Normal file
417
core/providers/bedrock/batch.go
Normal file
@@ -0,0 +1,417 @@
|
||||
package bedrock
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/bytedance/sonic"
|
||||
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// BedrockBatchJobRequest represents a request to create a batch inference job.
|
||||
type BedrockBatchJobRequest struct {
|
||||
JobName string `json:"jobName"`
|
||||
ModelID *string `json:"modelId"`
|
||||
RoleArn string `json:"roleArn"`
|
||||
InputDataConfig BedrockInputDataConfig `json:"inputDataConfig"`
|
||||
OutputDataConfig BedrockOutputDataConfig `json:"outputDataConfig"`
|
||||
TimeoutDurationInHours int `json:"timeoutDurationInHours,omitempty"`
|
||||
Tags []BedrockTag `json:"tags,omitempty"`
|
||||
}
|
||||
|
||||
// BedrockInputDataConfig represents the input configuration for a batch job.
|
||||
type BedrockInputDataConfig struct {
|
||||
S3InputDataConfig BedrockS3InputDataConfig `json:"s3InputDataConfig"`
|
||||
}
|
||||
|
||||
// BedrockS3InputDataConfig describes the S3 input of a batch job.
type BedrockS3InputDataConfig struct {
	// S3Uri is the S3 URI of the input.
	S3Uri string `json:"s3Uri"`
	// S3InputFormat is the input format, e.g. "JSONL"; omitted when empty.
	S3InputFormat string `json:"s3InputFormat,omitempty"`
	// Endpoint is an optional endpoint value.
	Endpoint *string `json:"endpoint,omitempty"`
	// FileID is an optional file identifier.
	FileID *string `json:"file_id,omitempty"`
}
|
||||
|
||||
// BedrockOutputDataConfig represents the output configuration for a batch job.
|
||||
type BedrockOutputDataConfig struct {
|
||||
S3OutputDataConfig BedrockS3OutputDataConfig `json:"s3OutputDataConfig"`
|
||||
}
|
||||
|
||||
// BedrockS3OutputDataConfig describes the S3 output of a batch job.
type BedrockS3OutputDataConfig struct {
	// S3Uri is the S3 URI of the output.
	S3Uri string `json:"s3Uri"`
}
|
||||
|
||||
// BedrockTag is a key/value tag attached to a batch job.
type BedrockTag struct {
	// Key is the tag name.
	Key string `json:"key"`
	// Value is the tag value.
	Value string `json:"value"`
}
|
||||
|
||||
// BedrockBatchJobResponse represents a batch job response.
|
||||
type BedrockBatchJobResponse struct {
|
||||
JobArn string `json:"jobArn"`
|
||||
Status string `json:"status"`
|
||||
JobName string `json:"jobName,omitempty"`
|
||||
ModelID string `json:"modelId,omitempty"`
|
||||
RoleArn string `json:"roleArn,omitempty"`
|
||||
InputDataConfig *BedrockInputDataConfig `json:"inputDataConfig,omitempty"`
|
||||
OutputDataConfig *BedrockOutputDataConfig `json:"outputDataConfig,omitempty"`
|
||||
VpcConfig *BedrockVpcConfig `json:"vpcConfig,omitempty"`
|
||||
SubmitTime *time.Time `json:"submitTime,omitempty"`
|
||||
LastModifiedTime *time.Time `json:"lastModifiedTime,omitempty"`
|
||||
EndTime *time.Time `json:"endTime,omitempty"`
|
||||
Message string `json:"message,omitempty"`
|
||||
ClientRequestToken string `json:"clientRequestToken,omitempty"`
|
||||
JobExpirationTime *time.Time `json:"jobExpirationTime,omitempty"`
|
||||
TimeoutDurationInHours int `json:"timeoutDurationInHours,omitempty"`
|
||||
}
|
||||
|
||||
// BedrockBatchJobListResponse represents a list of batch jobs.
|
||||
type BedrockBatchJobListResponse struct {
|
||||
InvocationJobSummaries []BedrockBatchJobSummary `json:"invocationJobSummaries"`
|
||||
NextToken *string `json:"nextToken,omitempty"`
|
||||
}
|
||||
|
||||
// BedrockBatchJobSummary is the condensed view of a batch job used in list
// responses.
type BedrockBatchJobSummary struct {
	// JobArn identifies the job.
	JobArn string `json:"jobArn"`
	// JobName is the name of the job.
	JobName string `json:"jobName"`
	// ModelID identifies the model used by the job.
	ModelID string `json:"modelId"`
	// Status is the Bedrock-style status string.
	Status string `json:"status"`
	// SubmitTime is when the job was submitted.
	SubmitTime *time.Time `json:"submitTime,omitempty"`
	// LastModifiedTime is when the job was last modified.
	LastModifiedTime *time.Time `json:"lastModifiedTime,omitempty"`
	// EndTime is when the job ended.
	EndTime *time.Time `json:"endTime,omitempty"`
	// Message is an optional message attached to the job.
	Message string `json:"message,omitempty"`
}
|
||||
|
||||
// BedrockBatchResultRecord represents a single result record in Bedrock batch output JSONL.
|
||||
type BedrockBatchResultRecord struct {
|
||||
RecordID string `json:"recordId"`
|
||||
ModelOutput json.RawMessage `json:"modelOutput,omitempty"`
|
||||
Error *BedrockBatchError `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
// BedrockBatchError describes a per-record failure in batch processing.
type BedrockBatchError struct {
	// ErrorCode is the numeric error code.
	ErrorCode int `json:"errorCode,omitempty"`
	// ErrorMessage is the error text.
	ErrorMessage string `json:"errorMessage,omitempty"`
}
|
||||
|
||||
// BedrockBatchListRequest carries the parameters of a list-batch-jobs call.
// Every field is optional and omitted from JSON when unset.
type BedrockBatchListRequest struct {
	// MaxResults limits the number of jobs returned.
	MaxResults int `json:"maxResults,omitempty"`
	// NextToken is the pagination token from a previous page.
	NextToken *string `json:"nextToken,omitempty"`
	// StatusEquals filters by status.
	StatusEquals string `json:"statusEquals,omitempty"`
	// NameContains filters by job name.
	NameContains string `json:"nameContains,omitempty"`
}
|
||||
|
||||
// BedrockBatchRetrieveRequest identifies the batch job to retrieve.
type BedrockBatchRetrieveRequest struct {
	// JobIdentifier identifies the job.
	JobIdentifier string `json:"jobIdentifier"`
}
|
||||
|
||||
// BedrockBatchCancelRequest identifies the batch job to cancel/stop.
type BedrockBatchCancelRequest struct {
	// JobIdentifier identifies the job.
	JobIdentifier string `json:"jobIdentifier"`
}
|
||||
|
||||
// BedrockBatchCancelResponse is the result of stopping a batch job.
type BedrockBatchCancelResponse struct {
	// JobArn identifies the job.
	JobArn string `json:"jobArn"`
	// Status is the Bedrock-style status string after the stop request.
	Status string `json:"status"`
}
|
||||
|
||||
// ToBifrostBatchStatus converts Bedrock status to Bifrost status.
|
||||
func ToBifrostBatchStatus(status string) schemas.BatchStatus {
|
||||
switch status {
|
||||
case "Submitted", "Validating":
|
||||
return schemas.BatchStatusValidating
|
||||
case "InProgress":
|
||||
return schemas.BatchStatusInProgress
|
||||
case "Completed":
|
||||
return schemas.BatchStatusCompleted
|
||||
case "Failed", "PartiallyCompleted":
|
||||
return schemas.BatchStatusFailed
|
||||
case "Stopping":
|
||||
return schemas.BatchStatusCancelling
|
||||
case "Stopped":
|
||||
return schemas.BatchStatusCancelled
|
||||
case "Expired":
|
||||
return schemas.BatchStatusExpired
|
||||
case "Scheduled":
|
||||
return schemas.BatchStatusValidating
|
||||
default:
|
||||
return schemas.BatchStatus(status)
|
||||
}
|
||||
}
|
||||
|
||||
// parseBatchResultsJSONL parses JSONL content from Bedrock batch output into Bifrost format.
|
||||
// Returns the parsed results and any parse errors encountered.
|
||||
func parseBatchResultsJSONL(content []byte, provider *BedrockProvider) ([]schemas.BatchResultItem, []schemas.BatchError) {
|
||||
var results []schemas.BatchResultItem
|
||||
|
||||
parseResult := providerUtils.ParseJSONL(content, func(line []byte) error {
|
||||
var bedrockResult BedrockBatchResultRecord
|
||||
if err := sonic.Unmarshal(line, &bedrockResult); err != nil {
|
||||
provider.logger.Warn(fmt.Sprintf("failed to parse batch result line: %v", err))
|
||||
return err
|
||||
}
|
||||
|
||||
// Convert Bedrock format to Bifrost format
|
||||
resultItem := schemas.BatchResultItem{
|
||||
CustomID: bedrockResult.RecordID,
|
||||
}
|
||||
|
||||
if bedrockResult.ModelOutput != nil {
|
||||
var bodyMap map[string]interface{}
|
||||
if err := sonic.Unmarshal(bedrockResult.ModelOutput, &bodyMap); err == nil {
|
||||
resultItem.Response = &schemas.BatchResultResponse{
|
||||
StatusCode: 200,
|
||||
Body: bodyMap,
|
||||
}
|
||||
} else {
|
||||
resultItem.Error = &schemas.BatchResultError{
|
||||
Code: "parse_error",
|
||||
Message: fmt.Sprintf("failed to parse model output: %v", err),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if bedrockResult.Error != nil {
|
||||
resultItem.Error = &schemas.BatchResultError{
|
||||
Code: fmt.Sprintf("%d", bedrockResult.Error.ErrorCode),
|
||||
Message: bedrockResult.Error.ErrorMessage,
|
||||
}
|
||||
// Set status code to indicate error if there's an error
|
||||
if resultItem.Response == nil {
|
||||
resultItem.Response = &schemas.BatchResultResponse{
|
||||
StatusCode: bedrockResult.Error.ErrorCode,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
results = append(results, resultItem)
|
||||
return nil
|
||||
})
|
||||
|
||||
return results, parseResult.Errors
|
||||
}
|
||||
|
||||
// ToBedrockBatchJobResponse converts a Bifrost batch create response to Bedrock format.
|
||||
func ToBedrockBatchJobResponse(resp *schemas.BifrostBatchCreateResponse) *BedrockBatchJobResponse {
|
||||
// Here if the provider is not Bedrock - then we create a dummy arn and string using the batch ID
|
||||
if resp.ExtraFields.Provider != schemas.Bedrock {
|
||||
return &BedrockBatchJobResponse{
|
||||
JobArn: fmt.Sprintf("arn:aws:bedrock:us-east-1:444444444444:batch:%s", resp.ID),
|
||||
Status: toBedrockBatchStatus(resp.Status),
|
||||
}
|
||||
}
|
||||
// For bedrock, we go as is
|
||||
result := &BedrockBatchJobResponse{
|
||||
JobArn: resp.ID,
|
||||
Status: toBedrockBatchStatus(resp.Status),
|
||||
}
|
||||
|
||||
if resp.Metadata != nil {
|
||||
if jobName, ok := resp.Metadata["job_name"]; ok {
|
||||
result.JobName = jobName
|
||||
}
|
||||
}
|
||||
|
||||
if resp.CreatedAt > 0 {
|
||||
t := time.Unix(resp.CreatedAt, 0)
|
||||
result.SubmitTime = &t
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// ToBedrockBatchJobListResponse converts a Bifrost batch list response to Bedrock format.
|
||||
func ToBedrockBatchJobListResponse(resp *schemas.BifrostBatchListResponse) *BedrockBatchJobListResponse {
|
||||
result := &BedrockBatchJobListResponse{
|
||||
InvocationJobSummaries: make([]BedrockBatchJobSummary, len(resp.Data)),
|
||||
}
|
||||
|
||||
for i, batch := range resp.Data {
|
||||
summary := BedrockBatchJobSummary{
|
||||
JobArn: batch.ID,
|
||||
Status: toBedrockBatchStatus(batch.Status),
|
||||
}
|
||||
|
||||
if batch.Metadata != nil {
|
||||
if jobName, ok := batch.Metadata["job_name"]; ok {
|
||||
summary.JobName = jobName
|
||||
}
|
||||
if modelId, ok := batch.Metadata["model_id"]; ok {
|
||||
summary.ModelID = modelId
|
||||
}
|
||||
}
|
||||
|
||||
if batch.CreatedAt > 0 {
|
||||
t := time.Unix(batch.CreatedAt, 0)
|
||||
summary.SubmitTime = &t
|
||||
}
|
||||
|
||||
if batch.CompletedAt != nil && *batch.CompletedAt > 0 {
|
||||
t := time.Unix(*batch.CompletedAt, 0)
|
||||
summary.EndTime = &t
|
||||
}
|
||||
|
||||
result.InvocationJobSummaries[i] = summary
|
||||
}
|
||||
|
||||
if resp.LastID != nil {
|
||||
result.NextToken = resp.LastID
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// ToBedrockBatchJobRetrieveResponse converts a Bifrost batch retrieve response to Bedrock format.
|
||||
func ToBedrockBatchJobRetrieveResponse(resp *schemas.BifrostBatchRetrieveResponse) *BedrockBatchJobResponse {
|
||||
result := &BedrockBatchJobResponse{
|
||||
JobArn: resp.ID,
|
||||
Status: toBedrockBatchStatus(resp.Status),
|
||||
}
|
||||
|
||||
if resp.Metadata != nil {
|
||||
if jobName, ok := resp.Metadata["job_name"]; ok {
|
||||
result.JobName = jobName
|
||||
}
|
||||
}
|
||||
|
||||
if resp.CreatedAt > 0 {
|
||||
t := time.Unix(resp.CreatedAt, 0)
|
||||
result.SubmitTime = &t
|
||||
}
|
||||
|
||||
if resp.CompletedAt != nil && *resp.CompletedAt > 0 {
|
||||
t := time.Unix(*resp.CompletedAt, 0)
|
||||
result.EndTime = &t
|
||||
}
|
||||
|
||||
if resp.InputFileID != "" {
|
||||
result.InputDataConfig = &BedrockInputDataConfig{
|
||||
S3InputDataConfig: BedrockS3InputDataConfig{
|
||||
S3Uri: resp.InputFileID,
|
||||
S3InputFormat: "JSONL",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
if resp.OutputFileID != nil && *resp.OutputFileID != "" {
|
||||
result.OutputDataConfig = &BedrockOutputDataConfig{
|
||||
S3OutputDataConfig: BedrockS3OutputDataConfig{
|
||||
S3Uri: *resp.OutputFileID,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// toBedrockBatchStatus converts Bifrost batch status to Bedrock status.
|
||||
func toBedrockBatchStatus(status schemas.BatchStatus) string {
|
||||
switch status {
|
||||
case schemas.BatchStatusValidating:
|
||||
return "Validating"
|
||||
case schemas.BatchStatusInProgress:
|
||||
return "InProgress"
|
||||
case schemas.BatchStatusCompleted:
|
||||
fallthrough
|
||||
case schemas.BatchStatusEnded:
|
||||
return "Completed"
|
||||
case schemas.BatchStatusFailed:
|
||||
return "Failed"
|
||||
case schemas.BatchStatusCancelling:
|
||||
return "Stopping"
|
||||
case schemas.BatchStatusCancelled:
|
||||
return "Stopped"
|
||||
case schemas.BatchStatusExpired:
|
||||
return "Expired"
|
||||
default:
|
||||
return string(status)
|
||||
}
|
||||
}
|
||||
|
||||
// ToBifrostBatchListRequest converts a Bedrock batch list request to Bifrost format.
|
||||
func ToBifrostBatchListRequest(req *BedrockBatchListRequest, provider schemas.ModelProvider) *schemas.BifrostBatchListRequest {
|
||||
result := &schemas.BifrostBatchListRequest{
|
||||
Provider: provider,
|
||||
Limit: req.MaxResults,
|
||||
}
|
||||
|
||||
if req.NextToken != nil {
|
||||
result.PageToken = req.NextToken
|
||||
}
|
||||
|
||||
if req.StatusEquals != "" || req.NameContains != "" {
|
||||
result.ExtraParams = make(map[string]interface{})
|
||||
if req.StatusEquals != "" {
|
||||
result.ExtraParams["statusEquals"] = req.StatusEquals
|
||||
}
|
||||
if req.NameContains != "" {
|
||||
result.ExtraParams["nameContains"] = req.NameContains
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// ToBifrostBatchRetrieveRequest converts a Bedrock batch retrieve request to Bifrost format.
|
||||
func ToBifrostBatchRetrieveRequest(req *BedrockBatchRetrieveRequest, provider schemas.ModelProvider) *schemas.BifrostBatchRetrieveRequest {
|
||||
return &schemas.BifrostBatchRetrieveRequest{
|
||||
Provider: provider,
|
||||
BatchID: req.JobIdentifier,
|
||||
}
|
||||
}
|
||||
|
||||
// ToBifrostBatchCancelRequest converts a Bedrock batch cancel request to Bifrost format.
|
||||
func ToBifrostBatchCancelRequest(req *BedrockBatchCancelRequest, provider schemas.ModelProvider) *schemas.BifrostBatchCancelRequest {
|
||||
return &schemas.BifrostBatchCancelRequest{
|
||||
Provider: provider,
|
||||
BatchID: req.JobIdentifier,
|
||||
}
|
||||
}
|
||||
|
||||
// ToBedrockBatchCancelResponse converts a Bifrost batch cancel response to Bedrock format.
|
||||
func ToBedrockBatchCancelResponse(resp *schemas.BifrostBatchCancelResponse) *BedrockBatchCancelResponse {
|
||||
return &BedrockBatchCancelResponse{
|
||||
JobArn: resp.ID,
|
||||
Status: toBedrockBatchStatus(resp.Status),
|
||||
}
|
||||
}
|
||||
|
||||
// splitJSONL splits JSONL content on '\n' and returns the non-empty lines.
// Empty lines are skipped and a final line without a trailing newline is
// included. The returned slices alias data; nothing is copied. The result is
// nil when data contains no non-empty line.
func splitJSONL(data []byte) [][]byte {
	var out [][]byte
	lineStart := 0
	for pos := 0; pos < len(data); pos++ {
		if data[pos] != '\n' {
			continue
		}
		if pos > lineStart {
			out = append(out, data[lineStart:pos])
		}
		lineStart = pos + 1
	}
	// Trailing content after the last newline (or the whole input if none).
	if lineStart < len(data) {
		out = append(out, data[lineStart:])
	}
	return out
}
|
||||
|
||||
// BedrockVpcConfig represents VPC configuration for a batch job.
// Both lists are omitted from the JSON encoding when empty.
type BedrockVpcConfig struct {
	// SecurityGroupIds is serialized as "securityGroupIds".
	SecurityGroupIds []string `json:"securityGroupIds,omitempty"`
	// SubnetIds is serialized as "subnetIds".
	SubnetIds []string `json:"subnetIds,omitempty"`
}
|
||||
|
||||
// BedrockBatchManifest represents the manifest.json.out file structure from S3.
// It carries the record counters of a batch job.
type BedrockBatchManifest struct {
	// TotalRecordCount is decoded from "totalRecordCount".
	TotalRecordCount int `json:"totalRecordCount"`
	// ProcessedRecordCount is decoded from "processedRecordCount".
	ProcessedRecordCount int `json:"processedRecordCount"`
	// ErrorRecordCount is decoded from "errorRecordCount".
	ErrorRecordCount int `json:"errorRecordCount"`
}
|
||||
3597
core/providers/bedrock/bedrock.go
Normal file
3597
core/providers/bedrock/bedrock.go
Normal file
File diff suppressed because it is too large
Load Diff
4326
core/providers/bedrock/bedrock_test.go
Normal file
4326
core/providers/bedrock/bedrock_test.go
Normal file
File diff suppressed because it is too large
Load Diff
449
core/providers/bedrock/chat.go
Normal file
449
core/providers/bedrock/chat.go
Normal file
@@ -0,0 +1,449 @@
|
||||
package bedrock
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// ToBedrockChatCompletionRequest converts a Bifrost request to Bedrock Converse API format
|
||||
func ToBedrockChatCompletionRequest(ctx *schemas.BifrostContext, bifrostReq *schemas.BifrostChatRequest) (*BedrockConverseRequest, error) {
|
||||
if bifrostReq == nil {
|
||||
return nil, fmt.Errorf("bifrost request is nil")
|
||||
}
|
||||
|
||||
if bifrostReq.Input == nil {
|
||||
return nil, fmt.Errorf("only chat completion requests are supported for Bedrock Converse API")
|
||||
}
|
||||
|
||||
bedrockReq := &BedrockConverseRequest{
|
||||
ModelID: bifrostReq.Model,
|
||||
}
|
||||
|
||||
// Convert messages and system messages
|
||||
messages, systemMessages, err := convertMessages(bifrostReq.Input)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to convert messages: %w", err)
|
||||
}
|
||||
bedrockReq.Messages = messages
|
||||
if len(systemMessages) > 0 {
|
||||
bedrockReq.System = systemMessages
|
||||
}
|
||||
|
||||
// Convert parameters and configurations
|
||||
if err := convertChatParameters(ctx, bifrostReq, bedrockReq); err != nil {
|
||||
return nil, fmt.Errorf("failed to convert chat parameters: %w", err)
|
||||
}
|
||||
|
||||
// Ensure tool config is present when needed
|
||||
ensureChatToolConfigForConversation(bifrostReq, bedrockReq)
|
||||
|
||||
return bedrockReq, nil
|
||||
}
|
||||
|
||||
// ToBifrostChatResponse converts a Bedrock Converse API response to Bifrost format
|
||||
func (response *BedrockConverseResponse) ToBifrostChatResponse(ctx context.Context, model string) (*schemas.BifrostChatResponse, error) {
|
||||
if response == nil {
|
||||
return nil, fmt.Errorf("bedrock response is nil")
|
||||
}
|
||||
|
||||
// Convert content blocks and tool calls
|
||||
var contentStr *string
|
||||
var contentBlocks []schemas.ChatContentBlock
|
||||
var toolCalls []schemas.ChatAssistantMessageToolCall
|
||||
var reasoningDetails []schemas.ChatReasoningDetails
|
||||
var reasoningText string
|
||||
|
||||
if response.Output.Message != nil {
|
||||
for _, contentBlock := range response.Output.Message.Content {
|
||||
// Handle text content
|
||||
if contentBlock.Text != nil && *contentBlock.Text != "" {
|
||||
chatContentBlock := schemas.ChatContentBlock{
|
||||
Type: schemas.ChatContentBlockTypeText,
|
||||
Text: contentBlock.Text,
|
||||
}
|
||||
contentBlocks = append(contentBlocks, chatContentBlock)
|
||||
}
|
||||
|
||||
if contentBlock.ToolUse != nil {
|
||||
// Check if this is the structured output tool
|
||||
if structuredOutputToolName, ok := ctx.Value(schemas.BifrostContextKeyStructuredOutputToolName).(string); ok && contentBlock.ToolUse.Name == structuredOutputToolName {
|
||||
// This is structured output - set contentStr and skip adding to toolCalls
|
||||
if contentBlock.ToolUse.Input != nil {
|
||||
jsonStr := string(contentBlock.ToolUse.Input)
|
||||
contentStr = &jsonStr
|
||||
}
|
||||
continue // Skip adding to toolCalls
|
||||
}
|
||||
|
||||
// Regular tool call processing
|
||||
var arguments string
|
||||
if contentBlock.ToolUse.Input != nil {
|
||||
arguments = string(contentBlock.ToolUse.Input)
|
||||
} else {
|
||||
arguments = "{}"
|
||||
}
|
||||
|
||||
toolUseID := contentBlock.ToolUse.ToolUseID
|
||||
toolUseName := contentBlock.ToolUse.Name
|
||||
|
||||
toolCalls = append(toolCalls, schemas.ChatAssistantMessageToolCall{
|
||||
Index: uint16(len(toolCalls)),
|
||||
Type: schemas.Ptr("function"),
|
||||
ID: &toolUseID,
|
||||
Function: schemas.ChatAssistantMessageToolCallFunction{
|
||||
Name: &toolUseName,
|
||||
Arguments: arguments,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// Handle reasoning content
|
||||
if contentBlock.ReasoningContent != nil {
|
||||
if contentBlock.ReasoningContent.ReasoningText == nil {
|
||||
continue
|
||||
}
|
||||
reasoningDetails = append(reasoningDetails, schemas.ChatReasoningDetails{
|
||||
Index: len(reasoningDetails),
|
||||
Type: schemas.BifrostReasoningDetailsTypeText,
|
||||
Text: contentBlock.ReasoningContent.ReasoningText.Text,
|
||||
Signature: contentBlock.ReasoningContent.ReasoningText.Signature,
|
||||
})
|
||||
if contentBlock.ReasoningContent.ReasoningText.Text != nil {
|
||||
reasoningText += *contentBlock.ReasoningContent.ReasoningText.Text + "\n"
|
||||
}
|
||||
}
|
||||
|
||||
// Handle document content
|
||||
if contentBlock.Document != nil {
|
||||
fileBlock := schemas.ChatContentBlock{
|
||||
Type: schemas.ChatContentBlockTypeFile,
|
||||
File: &schemas.ChatInputFile{},
|
||||
}
|
||||
|
||||
// Set filename from document name
|
||||
if contentBlock.Document.Name != "" {
|
||||
fileBlock.File.Filename = &contentBlock.Document.Name
|
||||
}
|
||||
|
||||
// Set file type based on format
|
||||
if contentBlock.Document.Format != "" {
|
||||
var fileType string
|
||||
switch contentBlock.Document.Format {
|
||||
case "pdf":
|
||||
fileType = "application/pdf"
|
||||
case "txt":
|
||||
fileType = "text/plain"
|
||||
case "md":
|
||||
fileType = "text/markdown"
|
||||
case "html":
|
||||
fileType = "text/html"
|
||||
case "csv":
|
||||
fileType = "text/csv"
|
||||
case "doc":
|
||||
fileType = "application/msword"
|
||||
case "docx":
|
||||
fileType = "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
|
||||
case "xls":
|
||||
fileType = "application/vnd.ms-excel"
|
||||
case "xlsx":
|
||||
fileType = "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
|
||||
default:
|
||||
fileType = "application/pdf"
|
||||
}
|
||||
fileBlock.File.FileType = &fileType
|
||||
}
|
||||
|
||||
// Convert document source data
|
||||
if contentBlock.Document.Source != nil {
|
||||
if contentBlock.Document.Source.Bytes != nil {
|
||||
fileBlock.File.FileData = contentBlock.Document.Source.Bytes
|
||||
} else if contentBlock.Document.Source.Text != nil {
|
||||
fileBlock.File.FileData = contentBlock.Document.Source.Text
|
||||
}
|
||||
}
|
||||
|
||||
contentBlocks = append(contentBlocks, fileBlock)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(contentBlocks) == 1 && contentBlocks[0].Type == schemas.ChatContentBlockTypeText {
|
||||
contentStr = contentBlocks[0].Text
|
||||
contentBlocks = nil
|
||||
}
|
||||
|
||||
// Create the message content
|
||||
messageContent := schemas.ChatMessageContent{
|
||||
ContentStr: contentStr,
|
||||
ContentBlocks: contentBlocks,
|
||||
}
|
||||
|
||||
// Create assistant message if we have tool calls
|
||||
var assistantMessage *schemas.ChatAssistantMessage
|
||||
if len(toolCalls) > 0 {
|
||||
assistantMessage = &schemas.ChatAssistantMessage{
|
||||
ToolCalls: toolCalls,
|
||||
}
|
||||
}
|
||||
if len(reasoningDetails) > 0 {
|
||||
if assistantMessage == nil {
|
||||
assistantMessage = &schemas.ChatAssistantMessage{}
|
||||
}
|
||||
assistantMessage.ReasoningDetails = reasoningDetails
|
||||
if reasoningText != "" {
|
||||
assistantMessage.Reasoning = new(reasoningText)
|
||||
}
|
||||
}
|
||||
|
||||
// Create the response choice
|
||||
choices := []schemas.BifrostResponseChoice{
|
||||
{
|
||||
Index: 0,
|
||||
ChatNonStreamResponseChoice: &schemas.ChatNonStreamResponseChoice{
|
||||
Message: &schemas.ChatMessage{
|
||||
Role: schemas.ChatMessageRoleAssistant,
|
||||
Content: &messageContent,
|
||||
ChatAssistantMessage: assistantMessage,
|
||||
},
|
||||
},
|
||||
FinishReason: schemas.Ptr(convertBedrockStopReason(response.StopReason)),
|
||||
},
|
||||
}
|
||||
var usage *schemas.BifrostLLMUsage
|
||||
if response.Usage != nil {
|
||||
// Convert usage information
|
||||
usage = &schemas.BifrostLLMUsage{
|
||||
PromptTokens: response.Usage.InputTokens,
|
||||
CompletionTokens: response.Usage.OutputTokens,
|
||||
TotalTokens: response.Usage.TotalTokens,
|
||||
}
|
||||
// Handle cached tokens if present
|
||||
if response.Usage.CacheReadInputTokens > 0 {
|
||||
if usage.PromptTokensDetails == nil {
|
||||
usage.PromptTokensDetails = &schemas.ChatPromptTokensDetails{}
|
||||
}
|
||||
usage.PromptTokensDetails.CachedReadTokens = response.Usage.CacheReadInputTokens
|
||||
usage.PromptTokens = usage.PromptTokens + response.Usage.CacheReadInputTokens
|
||||
}
|
||||
if response.Usage.CacheWriteInputTokens > 0 {
|
||||
if usage.PromptTokensDetails == nil {
|
||||
usage.PromptTokensDetails = &schemas.ChatPromptTokensDetails{}
|
||||
}
|
||||
usage.PromptTokensDetails.CachedWriteTokens = response.Usage.CacheWriteInputTokens
|
||||
usage.PromptTokens = usage.PromptTokens + response.Usage.CacheWriteInputTokens
|
||||
}
|
||||
}
|
||||
|
||||
// Create the final Bifrost response
|
||||
bifrostResponse := &schemas.BifrostChatResponse{
|
||||
ID: uuid.New().String(),
|
||||
Model: model,
|
||||
Object: "chat.completion",
|
||||
Choices: choices,
|
||||
Usage: usage,
|
||||
Created: int(time.Now().Unix()),
|
||||
ExtraFields: schemas.BifrostResponseExtraFields{
|
||||
},
|
||||
}
|
||||
|
||||
if response.ServiceTier != nil && response.ServiceTier.Type != "" {
|
||||
bifrostResponse.ServiceTier = &response.ServiceTier.Type
|
||||
}
|
||||
|
||||
return bifrostResponse, nil
|
||||
}
|
||||
|
||||
// BedrockStreamState tracks per-stream tool call index state.
type BedrockStreamState struct {
	// nextToolCallIndex is the index handed to the next tool call that starts.
	nextToolCallIndex int
	// contentBlockToToolCallIdx maps a content block index to the tool call
	// index that was assigned when that block started.
	contentBlockToToolCallIdx map[int]int
}

// NewBedrockStreamState returns initialised stream state for one streaming response.
// The block-to-tool-call map is allocated so it can be written to immediately.
func NewBedrockStreamState() *BedrockStreamState {
	state := new(BedrockStreamState)
	state.contentBlockToToolCallIdx = map[int]int{}
	return state
}
|
||||
|
||||
func (chunk *BedrockStreamEvent) ToBifrostChatCompletionStream(state *BedrockStreamState) (*schemas.BifrostChatResponse, *schemas.BifrostError, bool) {
|
||||
if state == nil {
|
||||
state = NewBedrockStreamState()
|
||||
} else if state.contentBlockToToolCallIdx == nil {
|
||||
state.contentBlockToToolCallIdx = make(map[int]int)
|
||||
}
|
||||
|
||||
// event with metrics/usage is the last and with stop reason is the second last
|
||||
switch {
|
||||
case chunk.Role != nil:
|
||||
// Send empty response to signal start
|
||||
streamResponse := &schemas.BifrostChatResponse{
|
||||
Object: "chat.completion.chunk",
|
||||
Choices: []schemas.BifrostResponseChoice{
|
||||
{
|
||||
Index: 0,
|
||||
ChatStreamResponseChoice: &schemas.ChatStreamResponseChoice{
|
||||
Delta: &schemas.ChatStreamResponseChoiceDelta{
|
||||
Role: chunk.Role,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
return streamResponse, nil, false
|
||||
|
||||
case chunk.Start != nil && chunk.Start.ToolUse != nil:
|
||||
toolUseStart := chunk.Start.ToolUse
|
||||
|
||||
toolCallIdx := 0
|
||||
if chunk.ContentBlockIndex != nil {
|
||||
toolCallIdx = state.nextToolCallIndex
|
||||
state.contentBlockToToolCallIdx[*chunk.ContentBlockIndex] = toolCallIdx
|
||||
state.nextToolCallIndex++
|
||||
}
|
||||
|
||||
// Create tool call structure for start event
|
||||
var toolCall schemas.ChatAssistantMessageToolCall
|
||||
toolCall.Index = uint16(toolCallIdx)
|
||||
toolCall.ID = schemas.Ptr(toolUseStart.ToolUseID)
|
||||
toolCall.Type = schemas.Ptr("function")
|
||||
toolCall.Function.Name = schemas.Ptr(toolUseStart.Name)
|
||||
toolCall.Function.Arguments = "" // Start with empty arguments
|
||||
|
||||
streamResponse := &schemas.BifrostChatResponse{
|
||||
Object: "chat.completion.chunk",
|
||||
Choices: []schemas.BifrostResponseChoice{
|
||||
{
|
||||
Index: 0,
|
||||
ChatStreamResponseChoice: &schemas.ChatStreamResponseChoice{
|
||||
Delta: &schemas.ChatStreamResponseChoiceDelta{
|
||||
ToolCalls: []schemas.ChatAssistantMessageToolCall{toolCall},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
return streamResponse, nil, false
|
||||
|
||||
case chunk.Delta != nil:
|
||||
switch {
|
||||
case chunk.Delta.Text != nil:
|
||||
// Handle text delta
|
||||
text := *chunk.Delta.Text
|
||||
if text != "" {
|
||||
streamResponse := &schemas.BifrostChatResponse{
|
||||
Object: "chat.completion.chunk",
|
||||
Choices: []schemas.BifrostResponseChoice{
|
||||
{
|
||||
Index: 0,
|
||||
ChatStreamResponseChoice: &schemas.ChatStreamResponseChoice{
|
||||
Delta: &schemas.ChatStreamResponseChoiceDelta{
|
||||
Content: &text,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
return streamResponse, nil, false
|
||||
}
|
||||
|
||||
case chunk.Delta.ToolUse != nil:
|
||||
// Handle tool use delta
|
||||
toolUseDelta := chunk.Delta.ToolUse
|
||||
|
||||
toolCallIdx := 0
|
||||
if chunk.ContentBlockIndex != nil {
|
||||
toolCallIdx = state.contentBlockToToolCallIdx[*chunk.ContentBlockIndex]
|
||||
}
|
||||
|
||||
// Create tool call structure
|
||||
var toolCall schemas.ChatAssistantMessageToolCall
|
||||
toolCall.Index = uint16(toolCallIdx)
|
||||
toolCall.Type = schemas.Ptr("function")
|
||||
|
||||
// For streaming, we need to accumulate tool use data
|
||||
// This is a simplified approach - in practice, you'd need to track tool calls across chunks
|
||||
toolCall.Function.Arguments = toolUseDelta.Input
|
||||
|
||||
streamResponse := &schemas.BifrostChatResponse{
|
||||
Object: "chat.completion.chunk",
|
||||
Choices: []schemas.BifrostResponseChoice{
|
||||
{
|
||||
Index: 0,
|
||||
ChatStreamResponseChoice: &schemas.ChatStreamResponseChoice{
|
||||
Delta: &schemas.ChatStreamResponseChoiceDelta{
|
||||
ToolCalls: []schemas.ChatAssistantMessageToolCall{toolCall},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
return streamResponse, nil, false
|
||||
|
||||
case chunk.Delta.ReasoningContent != nil:
|
||||
// Handle reasoning content delta
|
||||
reasoningContentDelta := chunk.Delta.ReasoningContent
|
||||
|
||||
// Only construct and return a response when either Text or Signature is set
|
||||
if (reasoningContentDelta.Text == nil || *reasoningContentDelta.Text == "") && reasoningContentDelta.Signature == nil {
|
||||
return nil, nil, false
|
||||
}
|
||||
|
||||
var streamResponse *schemas.BifrostChatResponse
|
||||
if reasoningContentDelta.Text != nil && *reasoningContentDelta.Text != "" {
|
||||
streamResponse = &schemas.BifrostChatResponse{
|
||||
Object: "chat.completion.chunk",
|
||||
Choices: []schemas.BifrostResponseChoice{
|
||||
{
|
||||
Index: 0,
|
||||
ChatStreamResponseChoice: &schemas.ChatStreamResponseChoice{
|
||||
Delta: &schemas.ChatStreamResponseChoiceDelta{
|
||||
Reasoning: reasoningContentDelta.Text,
|
||||
ReasoningDetails: []schemas.ChatReasoningDetails{
|
||||
{
|
||||
Index: 0,
|
||||
Type: schemas.BifrostReasoningDetailsTypeText,
|
||||
Text: reasoningContentDelta.Text,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
} else if reasoningContentDelta.Signature != nil {
|
||||
streamResponse = &schemas.BifrostChatResponse{
|
||||
Object: "chat.completion.chunk",
|
||||
Choices: []schemas.BifrostResponseChoice{
|
||||
{
|
||||
Index: 0,
|
||||
ChatStreamResponseChoice: &schemas.ChatStreamResponseChoice{
|
||||
Delta: &schemas.ChatStreamResponseChoiceDelta{
|
||||
ReasoningDetails: []schemas.ChatReasoningDetails{
|
||||
{
|
||||
Index: 0,
|
||||
Type: schemas.BifrostReasoningDetailsTypeText,
|
||||
Signature: reasoningContentDelta.Signature,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
return streamResponse, nil, false
|
||||
}
|
||||
}
|
||||
|
||||
return nil, nil, false
|
||||
}
|
||||
477
core/providers/bedrock/convert_tool_config_test.go
Normal file
477
core/providers/bedrock/convert_tool_config_test.go
Normal file
@@ -0,0 +1,477 @@
|
||||
package bedrock
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// TestConvertToolConfig_DropsServerToolsOnBedrock locks in the bug fix from
|
||||
// the user-reported repro: sending `web_search_20260209` via the OpenAI-
|
||||
// compatible /v1/chat/completions endpoint to Bedrock was producing a
|
||||
// malformed ToolConfig that Bedrock rejected with 400 "The provided request
|
||||
// is not valid". The fix strips unsupported server tools before the
|
||||
// conversion loop so the outbound request is valid.
|
||||
func TestConvertToolConfig_DropsServerToolsOnBedrock(t *testing.T) {
|
||||
params := &schemas.ChatParameters{
|
||||
Tools: []schemas.ChatTool{
|
||||
{
|
||||
Type: schemas.ChatToolTypeFunction,
|
||||
Function: &schemas.ChatToolFunction{
|
||||
Name: "get_weather",
|
||||
Description: schemas.Ptr("Get weather by city"),
|
||||
Parameters: &schemas.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
// Server tool — Bedrock doesn't support web_search per Table 20.
|
||||
// Should be stripped silently.
|
||||
Type: schemas.ChatToolType("web_search_20260209"),
|
||||
Name: "web_search",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
cfg := convertToolConfig("global.anthropic.claude-sonnet-4-6", params)
|
||||
if cfg == nil {
|
||||
t.Fatalf("expected ToolConfig, got nil (function tool should have survived)")
|
||||
}
|
||||
if len(cfg.Tools) != 1 {
|
||||
t.Fatalf("expected exactly 1 tool (function), got %d: %+v", len(cfg.Tools), cfg.Tools)
|
||||
}
|
||||
if cfg.Tools[0].ToolSpec == nil || cfg.Tools[0].ToolSpec.Name != "get_weather" {
|
||||
t.Errorf("expected function tool 'get_weather' to survive, got %+v", cfg.Tools[0])
|
||||
}
|
||||
}
|
||||
|
||||
// TestConvertToolConfig_ReturnsNilWhenAllDropped locks in the empty-slice
|
||||
// guard. Bedrock's Converse API rejects `"toolConfig": {"tools": []}` with a
|
||||
// 400; when every tool is unsupported and gets stripped, convertToolConfig
|
||||
// must return nil so no ToolConfig ships at all.
|
||||
func TestConvertToolConfig_ReturnsNilWhenAllDropped(t *testing.T) {
|
||||
params := &schemas.ChatParameters{
|
||||
Tools: []schemas.ChatTool{
|
||||
{
|
||||
Type: schemas.ChatToolType("web_search_20260209"),
|
||||
Name: "web_search",
|
||||
},
|
||||
{
|
||||
Type: schemas.ChatToolType("web_fetch_20260309"),
|
||||
Name: "web_fetch",
|
||||
},
|
||||
{
|
||||
Type: schemas.ChatToolType("code_execution_20250825"),
|
||||
Name: "code_execution",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
cfg := convertToolConfig("global.anthropic.claude-sonnet-4-6", params)
|
||||
if cfg != nil {
|
||||
t.Fatalf("expected nil ToolConfig (all tools unsupported on Bedrock), got %+v", cfg)
|
||||
}
|
||||
}
|
||||
|
||||
// TestConvertToolConfig_KeepsBedrockSupportedServerTools — locks in that
|
||||
// Bedrock-supported server tools (bash, memory, text_editor, computer,
|
||||
// tool_search) do NOT appear in Converse's typed toolConfig.tools slot —
|
||||
// they must be tunneled via additionalModelRequestFields (exercised in
|
||||
// TestCollectBedrockServerTools_*). If the only tool is a server tool,
|
||||
// toolConfig is nil so we don't ship {"toolConfig": {"tools": []}}.
|
||||
func TestConvertToolConfig_KeepsBedrockSupportedServerTools(t *testing.T) {
|
||||
params := &schemas.ChatParameters{
|
||||
Tools: []schemas.ChatTool{
|
||||
{
|
||||
Type: schemas.ChatToolType("bash_20250124"),
|
||||
Name: "bash",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
cfg := convertToolConfig("global.anthropic.claude-sonnet-4-6", params)
|
||||
if cfg != nil {
|
||||
t.Fatalf("expected nil toolConfig (server tools flow via additionalModelRequestFields, not toolSpec), got %+v", cfg)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCollectBedrockServerTools_BashOnly — bash is Bedrock-supported per the
|
||||
// B-header list; the helper must emit it as a native-JSON tool entry with no
|
||||
// derived beta header (bash has no high-confidence 1:1 beta-header mapping;
|
||||
// callers rely on extra-headers for that).
|
||||
func TestCollectBedrockServerTools_BashOnly(t *testing.T) {
|
||||
params := &schemas.ChatParameters{
|
||||
Tools: []schemas.ChatTool{
|
||||
{
|
||||
Type: schemas.ChatToolType("bash_20250124"),
|
||||
Name: "bash",
|
||||
},
|
||||
},
|
||||
}
|
||||
tools, betas := collectBedrockServerTools(params)
|
||||
if len(tools) != 1 {
|
||||
t.Fatalf("expected 1 server tool, got %d", len(tools))
|
||||
}
|
||||
got := string(tools[0])
|
||||
if !strings.Contains(got, `"type":"bash_20250124"`) || !strings.Contains(got, `"name":"bash"`) {
|
||||
t.Errorf("expected native Anthropic bash shape, got %s", got)
|
||||
}
|
||||
if len(betas) != 0 {
|
||||
t.Errorf("expected no derived beta headers for bash (no 1:1 mapping), got %v", betas)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCollectBedrockServerTools_ComputerDerivesBeta — computer_YYYYMMDD must
|
||||
// derive computer-use-YYYY-MM-DD as the beta header, gated through
|
||||
// FilterBetaHeadersForProvider(Bedrock) which keeps computer-use-* headers.
|
||||
func TestCollectBedrockServerTools_ComputerDerivesBeta(t *testing.T) {
|
||||
params := &schemas.ChatParameters{
|
||||
Tools: []schemas.ChatTool{
|
||||
{
|
||||
Type: schemas.ChatToolType("computer_20251124"),
|
||||
Name: "computer",
|
||||
DisplayWidthPx: schemas.Ptr(1280),
|
||||
DisplayHeightPx: schemas.Ptr(800),
|
||||
},
|
||||
},
|
||||
}
|
||||
tools, betas := collectBedrockServerTools(params)
|
||||
if len(tools) != 1 {
|
||||
t.Fatalf("expected 1 server tool, got %d", len(tools))
|
||||
}
|
||||
if !strings.Contains(string(tools[0]), `"display_width_px":1280`) {
|
||||
t.Errorf("expected computer variant fields to flow through, got %s", string(tools[0]))
|
||||
}
|
||||
if len(betas) != 1 || betas[0] != "computer-use-2025-11-24" {
|
||||
t.Errorf("expected [computer-use-2025-11-24], got %v", betas)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCollectBedrockServerTools_MemoryDerivesContextManagement — memory
|
||||
// activates via the context-management-2025-06-27 bundle on Bedrock (cite:
|
||||
// anthropic/types.go:179).
|
||||
func TestCollectBedrockServerTools_MemoryDerivesContextManagement(t *testing.T) {
|
||||
params := &schemas.ChatParameters{
|
||||
Tools: []schemas.ChatTool{
|
||||
{
|
||||
Type: schemas.ChatToolType("memory_20250818"),
|
||||
Name: "memory",
|
||||
},
|
||||
},
|
||||
}
|
||||
_, betas := collectBedrockServerTools(params)
|
||||
if len(betas) != 1 || betas[0] != "context-management-2025-06-27" {
|
||||
t.Errorf("expected [context-management-2025-06-27], got %v", betas)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCollectBedrockServerTools_StripsUnsupported — web_search isn't in
|
||||
// Bedrock's ProviderFeatures (WebSearch=false), so ValidateChatToolsForProvider
|
||||
// drops it and the helper must emit nothing.
|
||||
func TestCollectBedrockServerTools_StripsUnsupported(t *testing.T) {
|
||||
params := &schemas.ChatParameters{
|
||||
Tools: []schemas.ChatTool{
|
||||
{
|
||||
Type: schemas.ChatToolType("web_search_20260209"),
|
||||
Name: "web_search",
|
||||
},
|
||||
},
|
||||
}
|
||||
tools, betas := collectBedrockServerTools(params)
|
||||
if len(tools) != 0 {
|
||||
t.Errorf("expected no server tools (web_search unsupported on Bedrock), got %d", len(tools))
|
||||
}
|
||||
if len(betas) != 0 {
|
||||
t.Errorf("expected no betas when all tools filtered, got %v", betas)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCollectBedrockServerTools_FunctionToolsIgnored — function/custom tools
|
||||
// go through convertToolConfig, not this helper.
|
||||
func TestCollectBedrockServerTools_FunctionToolsIgnored(t *testing.T) {
|
||||
params := &schemas.ChatParameters{
|
||||
Tools: []schemas.ChatTool{
|
||||
{
|
||||
Type: schemas.ChatToolTypeFunction,
|
||||
Function: &schemas.ChatToolFunction{
|
||||
Name: "get_weather",
|
||||
Parameters: &schemas.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
tools, betas := collectBedrockServerTools(params)
|
||||
if len(tools) != 0 || len(betas) != 0 {
|
||||
t.Errorf("function tools should not flow through server-tool helper, got tools=%d betas=%v", len(tools), betas)
|
||||
}
|
||||
}
|
||||
|
||||
// TestBuildBedrockServerToolChoice_PinnedServerTool — caller pins a kept
|
||||
// server tool (computer) by name. Converse's typed toolConfig.toolChoice path
|
||||
// can't carry this because toolConfig.tools doesn't include server tools; the
|
||||
// existing reconciliation silently drops the pin. The tunneled path must
|
||||
// emit {"type":"tool","name":"computer"} into additionalModelRequestFields.
|
||||
func TestBuildBedrockServerToolChoice_PinnedServerTool(t *testing.T) {
|
||||
computer := schemas.ChatTool{
|
||||
Type: schemas.ChatToolType("computer_20251124"),
|
||||
Name: "computer",
|
||||
DisplayWidthPx: schemas.Ptr(1280),
|
||||
}
|
||||
params := &schemas.ChatParameters{
|
||||
Tools: []schemas.ChatTool{computer},
|
||||
ToolChoice: &schemas.ChatToolChoice{
|
||||
ChatToolChoiceStruct: &schemas.ChatToolChoiceStruct{
|
||||
Type: schemas.ChatToolChoiceTypeFunction,
|
||||
Function: &schemas.ChatToolChoiceFunction{Name: "computer"},
|
||||
},
|
||||
},
|
||||
}
|
||||
choice, ok := buildBedrockServerToolChoice(params, []schemas.ChatTool{computer})
|
||||
if !ok {
|
||||
t.Fatalf("expected tunneled tool_choice for pinned server tool, got (nil, false)")
|
||||
}
|
||||
got := string(choice)
|
||||
if !strings.Contains(got, `"type":"tool"`) || !strings.Contains(got, `"name":"computer"`) {
|
||||
t.Errorf("expected Anthropic-native {type:tool,name:computer}, got %s", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestBuildBedrockServerToolChoice_PinnedFunctionTool_NotTunneled — function
|
||||
// tool pins stay on Converse's typed path (toolConfig.toolChoice.tool). The
|
||||
// helper must not double-emit.
|
||||
func TestBuildBedrockServerToolChoice_PinnedFunctionTool_NotTunneled(t *testing.T) {
|
||||
fn := schemas.ChatTool{
|
||||
Type: schemas.ChatToolTypeFunction,
|
||||
Function: &schemas.ChatToolFunction{
|
||||
Name: "get_weather",
|
||||
Parameters: &schemas.ToolFunctionParameters{Type: "object"},
|
||||
},
|
||||
}
|
||||
params := &schemas.ChatParameters{
|
||||
Tools: []schemas.ChatTool{fn},
|
||||
ToolChoice: &schemas.ChatToolChoice{
|
||||
ChatToolChoiceStruct: &schemas.ChatToolChoiceStruct{
|
||||
Type: schemas.ChatToolChoiceTypeFunction,
|
||||
Function: &schemas.ChatToolChoiceFunction{Name: "get_weather"},
|
||||
},
|
||||
},
|
||||
}
|
||||
if _, ok := buildBedrockServerToolChoice(params, []schemas.ChatTool{fn}); ok {
|
||||
t.Errorf("expected no tunneling for function-tool pin (typed Converse path handles it)")
|
||||
}
|
||||
}
|
||||
|
||||
// TestBuildBedrockServerToolChoice_AnyWithOnlyServerTools — tool_choice:any
|
||||
// with only server tools: convertToolConfig returns nil (bedrockTools empty),
|
||||
// so the typed any-contract is lost. The tunneled path must emit
|
||||
// {"type":"any"} to preserve the forcing semantics.
|
||||
func TestBuildBedrockServerToolChoice_AnyWithOnlyServerTools(t *testing.T) {
|
||||
bash := schemas.ChatTool{
|
||||
Type: schemas.ChatToolType("bash_20250124"),
|
||||
Name: "bash",
|
||||
}
|
||||
anyStr := string(schemas.ChatToolChoiceTypeAny)
|
||||
params := &schemas.ChatParameters{
|
||||
Tools: []schemas.ChatTool{bash},
|
||||
ToolChoice: &schemas.ChatToolChoice{
|
||||
ChatToolChoiceStr: &anyStr,
|
||||
},
|
||||
}
|
||||
choice, ok := buildBedrockServerToolChoice(params, []schemas.ChatTool{bash})
|
||||
if !ok {
|
||||
t.Fatalf("expected tunneled any-contract when only server tools are present, got (nil, false)")
|
||||
}
|
||||
got := string(choice)
|
||||
if !strings.Contains(got, `"type":"any"`) {
|
||||
t.Errorf("expected {type:any}, got %s", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestBuildBedrockServerToolChoice_AnyWithFunctionTool_NotTunneled — when at
|
||||
// least one function/custom tool is present, Converse's typed
|
||||
// toolConfig.toolChoice.any carries the any-contract. Don't double-emit.
|
||||
func TestBuildBedrockServerToolChoice_AnyWithFunctionTool_NotTunneled(t *testing.T) {
|
||||
fn := schemas.ChatTool{
|
||||
Type: schemas.ChatToolTypeFunction,
|
||||
Function: &schemas.ChatToolFunction{
|
||||
Name: "get_weather",
|
||||
Parameters: &schemas.ToolFunctionParameters{Type: "object"},
|
||||
},
|
||||
}
|
||||
bash := schemas.ChatTool{
|
||||
Type: schemas.ChatToolType("bash_20250124"),
|
||||
Name: "bash",
|
||||
}
|
||||
anyStr := string(schemas.ChatToolChoiceTypeAny)
|
||||
params := &schemas.ChatParameters{
|
||||
Tools: []schemas.ChatTool{fn, bash},
|
||||
ToolChoice: &schemas.ChatToolChoice{
|
||||
ChatToolChoiceStr: &anyStr,
|
||||
},
|
||||
}
|
||||
if _, ok := buildBedrockServerToolChoice(params, []schemas.ChatTool{fn, bash}); ok {
|
||||
t.Errorf("expected no tunneling when function/custom tool is present (typed Converse path handles any)")
|
||||
}
|
||||
}
|
||||
|
||||
// TestBuildBedrockServerToolChoice_UnsupportedServerToolPin_NotTunneled — the
|
||||
// caller pins web_search, which ValidateChatToolsForProvider strips on
|
||||
// Bedrock. The pin name is absent from the filtered set; the helper must not
|
||||
// fabricate a tunneled tool_choice for a tool that isn't in the request.
|
||||
func TestBuildBedrockServerToolChoice_UnsupportedServerToolPin_NotTunneled(t *testing.T) {
|
||||
// The caller's original request had web_search, but it's been stripped.
|
||||
// We pass the filtered slice (empty for the server-tool axis) to mimic
|
||||
// the convertChatParameters call path.
|
||||
params := &schemas.ChatParameters{
|
||||
Tools: []schemas.ChatTool{{Type: schemas.ChatToolType("web_search_20260209"), Name: "web_search"}},
|
||||
ToolChoice: &schemas.ChatToolChoice{
|
||||
ChatToolChoiceStruct: &schemas.ChatToolChoiceStruct{
|
||||
Type: schemas.ChatToolChoiceTypeFunction,
|
||||
Function: &schemas.ChatToolChoiceFunction{Name: "web_search"},
|
||||
},
|
||||
},
|
||||
}
|
||||
// Filtered (post-ValidateChatToolsForProvider(Bedrock)) — web_search is dropped.
|
||||
filtered := []schemas.ChatTool{}
|
||||
if _, ok := buildBedrockServerToolChoice(params, filtered); ok {
|
||||
t.Errorf("expected no tunneling when pinned name was stripped by provider validation")
|
||||
}
|
||||
}
|
||||
|
||||
// TestConvertChatParameters_PinnedServerToolE2E — end-to-end verification
|
||||
// that convertChatParameters composes convertToolConfig +
|
||||
// collectBedrockServerTools + buildBedrockServerToolChoice such that a
|
||||
// request pinning a kept server tool produces:
|
||||
// - AdditionalModelRequestFields.tools containing the server tool
|
||||
// - AdditionalModelRequestFields.tool_choice with Anthropic-native shape
|
||||
// - ToolConfig nil (no function tools → Converse's typed path is inactive)
|
||||
func TestConvertChatParameters_PinnedServerToolE2E(t *testing.T) {
|
||||
bifrostReq := &schemas.BifrostChatRequest{
|
||||
Model: "global.anthropic.claude-sonnet-4-6",
|
||||
Params: &schemas.ChatParameters{
|
||||
Tools: []schemas.ChatTool{
|
||||
{
|
||||
Type: schemas.ChatToolType("computer_20251124"),
|
||||
Name: "computer",
|
||||
DisplayWidthPx: schemas.Ptr(1280),
|
||||
},
|
||||
},
|
||||
ToolChoice: &schemas.ChatToolChoice{
|
||||
ChatToolChoiceStruct: &schemas.ChatToolChoiceStruct{
|
||||
Type: schemas.ChatToolChoiceTypeFunction,
|
||||
Function: &schemas.ChatToolChoiceFunction{Name: "computer"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
bedrockReq := &BedrockConverseRequest{}
|
||||
if err := convertChatParameters(nil, bifrostReq, bedrockReq); err != nil {
|
||||
t.Fatalf("convertChatParameters failed: %v", err)
|
||||
}
|
||||
if bedrockReq.ToolConfig != nil {
|
||||
t.Errorf("expected nil ToolConfig (no function/custom tools), got %+v", bedrockReq.ToolConfig)
|
||||
}
|
||||
if bedrockReq.AdditionalModelRequestFields == nil {
|
||||
t.Fatalf("expected AdditionalModelRequestFields to carry server-tool payload, got nil")
|
||||
}
|
||||
tools, ok := bedrockReq.AdditionalModelRequestFields.Get("tools")
|
||||
if !ok {
|
||||
t.Errorf("expected additionalModelRequestFields.tools to be set for server tool")
|
||||
} else if toolsSlice, castOK := tools.([]json.RawMessage); !castOK || len(toolsSlice) != 1 {
|
||||
t.Errorf("expected 1 server tool in additionalModelRequestFields.tools, got %+v", tools)
|
||||
}
|
||||
choice, ok := bedrockReq.AdditionalModelRequestFields.Get("tool_choice")
|
||||
if !ok {
|
||||
t.Fatalf("expected additionalModelRequestFields.tool_choice to carry pinned server-tool contract")
|
||||
}
|
||||
choiceRaw, castOK := choice.(json.RawMessage)
|
||||
if !castOK {
|
||||
t.Fatalf("expected tool_choice value to be json.RawMessage, got %T", choice)
|
||||
}
|
||||
got := string(choiceRaw)
|
||||
if !strings.Contains(got, `"type":"tool"`) || !strings.Contains(got, `"name":"computer"`) {
|
||||
t.Errorf("expected {type:tool,name:computer}, got %s", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestConvertChatParameters_ResponseFormatWithPinnedServerTool_NoConflictingChoice
|
||||
// locks in the fix for the "two conflicting tool-choice directives" hazard:
|
||||
// when response_format forces the synthetic bf_so_* tool via
|
||||
// ToolConfig.ToolChoice, the tunneled additionalModelRequestFields.tool_choice
|
||||
// (which would pin a server tool) must be suppressed so Bedrock doesn't
|
||||
// receive both pins in the same Converse call. Uses a Nova model since
|
||||
// Anthropic models route response_format through native output_config.format
|
||||
// (no synthetic tool), so the conflict only surfaces on non-Anthropic
|
||||
// Bedrock targets.
|
||||
func TestConvertChatParameters_ResponseFormatWithPinnedServerTool_NoConflictingChoice(t *testing.T) {
|
||||
responseFormat := any(map[string]any{
|
||||
"type": "json_schema",
|
||||
"json_schema": map[string]any{
|
||||
"name": "classification",
|
||||
"schema": map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"topic": map[string]any{"type": "string"},
|
||||
},
|
||||
"required": []any{"topic"},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
bifrostReq := &schemas.BifrostChatRequest{
|
||||
Model: "amazon.nova-pro-v1:0",
|
||||
Params: &schemas.ChatParameters{
|
||||
ResponseFormat: &responseFormat,
|
||||
Tools: []schemas.ChatTool{
|
||||
{
|
||||
Type: schemas.ChatToolType("bash_20250124"),
|
||||
Name: "bash",
|
||||
},
|
||||
},
|
||||
ToolChoice: &schemas.ChatToolChoice{
|
||||
ChatToolChoiceStruct: &schemas.ChatToolChoiceStruct{
|
||||
Type: schemas.ChatToolChoiceTypeFunction,
|
||||
Function: &schemas.ChatToolChoiceFunction{Name: "bash"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
|
||||
bedrockReq := &BedrockConverseRequest{}
|
||||
if err := convertChatParameters(ctx, bifrostReq, bedrockReq); err != nil {
|
||||
t.Fatalf("convertChatParameters failed: %v", err)
|
||||
}
|
||||
|
||||
// Synthetic bf_so_* tool must be injected and pinned via Converse's typed path.
|
||||
if bedrockReq.ToolConfig == nil {
|
||||
t.Fatalf("expected ToolConfig with synthetic bf_so_* tool, got nil")
|
||||
}
|
||||
if bedrockReq.ToolConfig.ToolChoice == nil || bedrockReq.ToolConfig.ToolChoice.Tool == nil {
|
||||
t.Fatalf("expected ToolConfig.ToolChoice.Tool to pin synthetic structured-output tool, got %+v", bedrockReq.ToolConfig.ToolChoice)
|
||||
}
|
||||
if !strings.HasPrefix(bedrockReq.ToolConfig.ToolChoice.Tool.Name, "bf_so_") {
|
||||
t.Errorf("expected ToolConfig.ToolChoice.Tool.Name to start with bf_so_, got %q", bedrockReq.ToolConfig.ToolChoice.Tool.Name)
|
||||
}
|
||||
|
||||
// Server tool must still be tunneled so the model has it available.
|
||||
if bedrockReq.AdditionalModelRequestFields == nil {
|
||||
t.Fatalf("expected AdditionalModelRequestFields to carry tunneled server-tool payload, got nil")
|
||||
}
|
||||
if _, ok := bedrockReq.AdditionalModelRequestFields.Get("tools"); !ok {
|
||||
t.Errorf("expected additionalModelRequestFields.tools to still carry bash server tool")
|
||||
}
|
||||
|
||||
// Guarded field: tunneled tool_choice MUST be absent because response_format
|
||||
// forces the synthetic tool. Two tool-choice directives in the same request
|
||||
// would let Bedrock pick one and silently violate the structured-output contract.
|
||||
if _, ok := bedrockReq.AdditionalModelRequestFields.Get("tool_choice"); ok {
|
||||
t.Errorf("expected NO additionalModelRequestFields.tool_choice when response_format pins bf_so_* (conflict hazard)")
|
||||
}
|
||||
}
|
||||
57
core/providers/bedrock/count_tokens.go
Normal file
57
core/providers/bedrock/count_tokens.go
Normal file
@@ -0,0 +1,57 @@
|
||||
package bedrock
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// estimatedBytesPerToken is the divisor estimateTokenCount applies to a
// request body's byte length to approximate its token count.
const estimatedBytesPerToken = 4
|
||||
|
||||
// ToBifrostCountTokensResponse converts a Bedrock count tokens response to Bifrost format
|
||||
func (resp *BedrockCountTokensResponse) ToBifrostCountTokensResponse(model string) *schemas.BifrostCountTokensResponse {
|
||||
if resp == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
totalTokens := resp.InputTokens
|
||||
|
||||
return &schemas.BifrostCountTokensResponse{
|
||||
Model: model,
|
||||
InputTokens: resp.InputTokens,
|
||||
TotalTokens: &totalTokens,
|
||||
Object: "response.input_tokens",
|
||||
}
|
||||
}
|
||||
|
||||
// ToBedrockCountTokensResponse converts a Bifrost count tokens response to Bedrock native format
|
||||
func ToBedrockCountTokensResponse(resp *schemas.BifrostCountTokensResponse) *BedrockCountTokensResponse {
|
||||
if resp == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &BedrockCountTokensResponse{
|
||||
InputTokens: resp.InputTokens,
|
||||
}
|
||||
}
|
||||
|
||||
// isCountTokensUnsupported checks whether a BifrostError indicates that the
|
||||
// Bedrock model does not support the count-tokens operation.
|
||||
func isCountTokensUnsupported(err *schemas.BifrostError) bool {
|
||||
if err == nil || err.Error == nil {
|
||||
return false
|
||||
}
|
||||
return strings.Contains(strings.ToLower(err.Error.Message), "doesn't support counting tokens")
|
||||
}
|
||||
|
||||
// estimateTokenCount returns a rough token count derived from the byte length
|
||||
// of the serialized request body. Claude's tokenizer averages ~4 bytes per
|
||||
// token on mixed content; this intentionally rounds up so that context-window
|
||||
// management decisions stay on the conservative side.
|
||||
func estimateTokenCount(requestBody []byte) int {
|
||||
n := len(requestBody)
|
||||
if n == 0 {
|
||||
return 0
|
||||
}
|
||||
return (n + estimatedBytesPerToken - 1) / estimatedBytesPerToken
|
||||
}
|
||||
105
core/providers/bedrock/count_tokens_test.go
Normal file
105
core/providers/bedrock/count_tokens_test.go
Normal file
@@ -0,0 +1,105 @@
|
||||
package bedrock
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestIsCountTokensUnsupported(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
err *schemas.BifrostError
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "nil error",
|
||||
err: nil,
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "nil error field",
|
||||
err: &schemas.BifrostError{},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "matching bedrock error message",
|
||||
err: &schemas.BifrostError{
|
||||
Error: &schemas.ErrorField{
|
||||
Message: "The provided model doesn't support counting tokens.",
|
||||
},
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "matching message with different casing",
|
||||
err: &schemas.BifrostError{
|
||||
Error: &schemas.ErrorField{
|
||||
Message: "the provided model DOESN'T SUPPORT COUNTING TOKENS.",
|
||||
},
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "unrelated error message",
|
||||
err: &schemas.BifrostError{
|
||||
Error: &schemas.ErrorField{
|
||||
Message: "access denied",
|
||||
},
|
||||
},
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
assert.Equal(t, tc.expected, isCountTokensUnsupported(tc.err))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEstimateTokenCount(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input []byte
|
||||
expected int
|
||||
}{
|
||||
{
|
||||
name: "empty input",
|
||||
input: []byte{},
|
||||
expected: 0,
|
||||
},
|
||||
{
|
||||
name: "nil input",
|
||||
input: nil,
|
||||
expected: 0,
|
||||
},
|
||||
{
|
||||
name: "exact multiple of 4",
|
||||
input: make([]byte, 100),
|
||||
expected: 25,
|
||||
},
|
||||
{
|
||||
name: "rounds up",
|
||||
input: make([]byte, 101),
|
||||
expected: 26,
|
||||
},
|
||||
{
|
||||
name: "single byte",
|
||||
input: []byte("x"),
|
||||
expected: 1,
|
||||
},
|
||||
{
|
||||
name: "realistic json body",
|
||||
input: []byte(`{"messages":[{"role":"user","content":"Hello, how are you today?"}],"model":"us.anthropic.claude-sonnet-4-6"}`),
|
||||
expected: 28,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
assert.Equal(t, tc.expected, estimateTokenCount(tc.input))
|
||||
})
|
||||
}
|
||||
}
|
||||
271
core/providers/bedrock/embedding.go
Normal file
271
core/providers/bedrock/embedding.go
Normal file
@@ -0,0 +1,271 @@
|
||||
package bedrock
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// ToBedrockTitanEmbeddingRequest converts a Bifrost embedding request to Bedrock Titan format
|
||||
func ToBedrockTitanEmbeddingRequest(bifrostReq *schemas.BifrostEmbeddingRequest) (*BedrockTitanEmbeddingRequest, error) {
|
||||
if bifrostReq == nil {
|
||||
return nil, fmt.Errorf("bifrost embedding request is nil")
|
||||
}
|
||||
|
||||
// Validate that only single text input is provided for Titan models
|
||||
if bifrostReq.Input.Text == nil && len(bifrostReq.Input.Texts) == 0 {
|
||||
return nil, fmt.Errorf("no input text provided for embedding")
|
||||
}
|
||||
|
||||
titanReq := &BedrockTitanEmbeddingRequest{}
|
||||
|
||||
// Set input text
|
||||
if bifrostReq.Input.Text != nil {
|
||||
titanReq.InputText = *bifrostReq.Input.Text
|
||||
} else if len(bifrostReq.Input.Texts) > 0 {
|
||||
var embeddingText string
|
||||
for _, text := range bifrostReq.Input.Texts {
|
||||
embeddingText += text + " \n"
|
||||
}
|
||||
titanReq.InputText = embeddingText
|
||||
}
|
||||
|
||||
if bifrostReq.Params != nil {
|
||||
titanReq.Dimensions = bifrostReq.Params.Dimensions
|
||||
if normalize, ok := bifrostReq.Params.ExtraParams["normalize"]; ok {
|
||||
if b, ok := normalize.(bool); ok {
|
||||
titanReq.Normalize = &b
|
||||
}
|
||||
}
|
||||
// Forward remaining extra params (excluding normalize which is now a first-class field)
|
||||
if len(bifrostReq.Params.ExtraParams) > 0 {
|
||||
extra := make(map[string]interface{})
|
||||
for k, v := range bifrostReq.Params.ExtraParams {
|
||||
if k != "normalize" {
|
||||
extra[k] = v
|
||||
}
|
||||
}
|
||||
if len(extra) > 0 {
|
||||
titanReq.ExtraParams = extra
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return titanReq, nil
|
||||
}
|
||||
|
||||
// ToBifrostEmbeddingResponse converts a Bedrock Titan embedding response to Bifrost format
|
||||
func (response *BedrockTitanEmbeddingResponse) ToBifrostEmbeddingResponse() *schemas.BifrostEmbeddingResponse {
|
||||
if response == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
bifrostResponse := &schemas.BifrostEmbeddingResponse{
|
||||
Object: "list",
|
||||
Data: []schemas.EmbeddingData{
|
||||
{
|
||||
Index: 0,
|
||||
Object: "embedding",
|
||||
Embedding: schemas.EmbeddingStruct{
|
||||
EmbeddingArray: response.Embedding,
|
||||
},
|
||||
},
|
||||
},
|
||||
Usage: &schemas.BifrostLLMUsage{
|
||||
PromptTokens: response.InputTextTokenCount,
|
||||
TotalTokens: response.InputTextTokenCount,
|
||||
},
|
||||
}
|
||||
|
||||
return bifrostResponse
|
||||
}
|
||||
|
||||
// ToBedrockCohereEmbeddingRequest converts a Bifrost embedding request to Bedrock Cohere format.
|
||||
// Unlike the direct Cohere API, Bedrock does not accept a "model" field in the request body.
|
||||
func ToBedrockCohereEmbeddingRequest(bifrostReq *schemas.BifrostEmbeddingRequest) (*BedrockCohereEmbeddingRequest, error) {
|
||||
if bifrostReq == nil {
|
||||
return nil, fmt.Errorf("bifrost embedding request is nil")
|
||||
}
|
||||
if bifrostReq.Input == nil || (bifrostReq.Input.Text == nil && len(bifrostReq.Input.Texts) == 0) {
|
||||
return nil, fmt.Errorf("no input provided for embedding")
|
||||
}
|
||||
|
||||
req := &BedrockCohereEmbeddingRequest{}
|
||||
|
||||
// Map texts
|
||||
if bifrostReq.Input.Text != nil {
|
||||
req.Texts = []string{*bifrostReq.Input.Text}
|
||||
} else if len(bifrostReq.Input.Texts) > 0 {
|
||||
req.Texts = bifrostReq.Input.Texts
|
||||
}
|
||||
|
||||
if bifrostReq.Params != nil {
|
||||
extra := make(map[string]interface{}, len(bifrostReq.Params.ExtraParams))
|
||||
for k, v := range bifrostReq.Params.ExtraParams {
|
||||
extra[k] = v
|
||||
}
|
||||
|
||||
if v, ok := extra["input_type"]; ok {
|
||||
if s, ok := v.(string); ok {
|
||||
req.InputType = s
|
||||
delete(extra, "input_type")
|
||||
}
|
||||
}
|
||||
if v, ok := extra["truncate"]; ok {
|
||||
if s, ok := v.(string); ok {
|
||||
req.Truncate = &s
|
||||
delete(extra, "truncate")
|
||||
}
|
||||
}
|
||||
if v, ok := extra["embedding_types"]; ok {
|
||||
if ss, ok := v.([]string); ok {
|
||||
req.EmbeddingTypes = ss
|
||||
delete(extra, "embedding_types")
|
||||
}
|
||||
}
|
||||
if v, ok := extra["images"]; ok {
|
||||
if ss, ok := v.([]string); ok {
|
||||
req.Images = ss
|
||||
delete(extra, "images")
|
||||
}
|
||||
}
|
||||
if v, ok := extra["inputs"]; ok {
|
||||
if inputs, ok := v.([]BedrockCohereEmbeddingInput); ok {
|
||||
req.Inputs = inputs
|
||||
delete(extra, "inputs")
|
||||
}
|
||||
}
|
||||
if v, ok := extra["max_tokens"]; ok {
|
||||
switch n := v.(type) {
|
||||
case int:
|
||||
req.MaxTokens = &n
|
||||
delete(extra, "max_tokens")
|
||||
case float64:
|
||||
i := int(n)
|
||||
req.MaxTokens = &i
|
||||
delete(extra, "max_tokens")
|
||||
}
|
||||
}
|
||||
if bifrostReq.Params.Dimensions != nil {
|
||||
req.OutputDimension = bifrostReq.Params.Dimensions
|
||||
}
|
||||
if len(extra) > 0 {
|
||||
req.ExtraParams = extra
|
||||
}
|
||||
}
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
// DetermineEmbeddingModelType classifies an embedding model by substring match
// on its name: "titan" for names containing "amazon.titan-embed-text", "cohere"
// for names containing "cohere.embed". The Titan match is checked first. Any
// other name yields an error.
func DetermineEmbeddingModelType(model string) (string, error) {
	if strings.Contains(model, "amazon.titan-embed-text") {
		return "titan", nil
	}
	if strings.Contains(model, "cohere.embed") {
		return "cohere", nil
	}
	return "", fmt.Errorf("unsupported embedding model: %s", model)
}
|
||||
|
||||
// ToBifrostEmbeddingResponse converts a BedrockCohereEmbeddingResponse to Bifrost format.
|
||||
// Bedrock returns embeddings as a raw [][]float32 when response_type is "embeddings_floats"
|
||||
// (the default, when no embedding_types are requested), and as a typed object when
|
||||
// response_type is "embeddings_by_type".
|
||||
func (r *BedrockCohereEmbeddingResponse) ToBifrostEmbeddingResponse() (*schemas.BifrostEmbeddingResponse, error) {
|
||||
if r == nil {
|
||||
return nil, fmt.Errorf("nil Bedrock Cohere embedding response")
|
||||
}
|
||||
|
||||
bifrostResponse := &schemas.BifrostEmbeddingResponse{Object: "list"}
|
||||
|
||||
switch r.ResponseType {
|
||||
case "embeddings_by_type":
|
||||
// Object form: {"float": [[...]], "int8": [[...]], "uint8": [[...]], "binary": [[...]], "ubinary": [[...]], "base64": [...]}
|
||||
var typed struct {
|
||||
Float [][]float32 `json:"float"`
|
||||
Base64 []string `json:"base64"`
|
||||
Int8 [][]int8 `json:"int8"`
|
||||
Uint8 [][]int32 `json:"uint8"` // int32 avoids []byte→base64 JSON issue
|
||||
Binary [][]int8 `json:"binary"`
|
||||
Ubinary [][]int32 `json:"ubinary"` // int32 avoids []byte→base64 JSON issue
|
||||
}
|
||||
if err := json.Unmarshal(r.Embeddings, &typed); err != nil {
|
||||
return nil, fmt.Errorf("error parsing embeddings_by_type: %w", err)
|
||||
}
|
||||
if typed.Float != nil {
|
||||
for i, emb := range typed.Float {
|
||||
float64Emb := make([]float64, len(emb))
|
||||
for j, v := range emb {
|
||||
float64Emb[j] = float64(v)
|
||||
}
|
||||
bifrostResponse.Data = append(bifrostResponse.Data, schemas.EmbeddingData{
|
||||
Object: "embedding",
|
||||
Index: i,
|
||||
Embedding: schemas.EmbeddingStruct{EmbeddingArray: float64Emb},
|
||||
})
|
||||
}
|
||||
}
|
||||
if typed.Base64 != nil {
|
||||
for i, emb := range typed.Base64 {
|
||||
e := emb
|
||||
bifrostResponse.Data = append(bifrostResponse.Data, schemas.EmbeddingData{
|
||||
Object: "embedding",
|
||||
Index: i,
|
||||
Embedding: schemas.EmbeddingStruct{EmbeddingStr: &e},
|
||||
})
|
||||
}
|
||||
}
|
||||
for i, emb := range typed.Int8 {
|
||||
bifrostResponse.Data = append(bifrostResponse.Data, schemas.EmbeddingData{
|
||||
Object: "embedding",
|
||||
Index: i,
|
||||
Embedding: schemas.EmbeddingStruct{EmbeddingInt8Array: emb},
|
||||
})
|
||||
}
|
||||
for i, emb := range typed.Binary {
|
||||
bifrostResponse.Data = append(bifrostResponse.Data, schemas.EmbeddingData{
|
||||
Object: "embedding",
|
||||
Index: i,
|
||||
Embedding: schemas.EmbeddingStruct{EmbeddingInt8Array: emb},
|
||||
})
|
||||
}
|
||||
for i, emb := range typed.Uint8 {
|
||||
bifrostResponse.Data = append(bifrostResponse.Data, schemas.EmbeddingData{
|
||||
Object: "embedding",
|
||||
Index: i,
|
||||
Embedding: schemas.EmbeddingStruct{EmbeddingInt32Array: emb},
|
||||
})
|
||||
}
|
||||
for i, emb := range typed.Ubinary {
|
||||
bifrostResponse.Data = append(bifrostResponse.Data, schemas.EmbeddingData{
|
||||
Object: "embedding",
|
||||
Index: i,
|
||||
Embedding: schemas.EmbeddingStruct{EmbeddingInt32Array: emb},
|
||||
})
|
||||
}
|
||||
|
||||
default:
|
||||
// Default / "embeddings_floats": raw array form [[...], [...]]
|
||||
var floats [][]float32
|
||||
if err := json.Unmarshal(r.Embeddings, &floats); err != nil {
|
||||
return nil, fmt.Errorf("error parsing embeddings_floats: %w", err)
|
||||
}
|
||||
for i, emb := range floats {
|
||||
float64Emb := make([]float64, len(emb))
|
||||
for j, v := range emb {
|
||||
float64Emb[j] = float64(v)
|
||||
}
|
||||
bifrostResponse.Data = append(bifrostResponse.Data, schemas.EmbeddingData{
|
||||
Object: "embedding",
|
||||
Index: i,
|
||||
Embedding: schemas.EmbeddingStruct{EmbeddingArray: float64Emb},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return bifrostResponse, nil
|
||||
}
|
||||
114
core/providers/bedrock/embedding_test.go
Normal file
114
core/providers/bedrock/embedding_test.go
Normal file
@@ -0,0 +1,114 @@
|
||||
package bedrock
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestToBedrockCohereEmbeddingRequest(t *testing.T) {
|
||||
t.Run("returns error for nil request", func(t *testing.T) {
|
||||
req, err := ToBedrockCohereEmbeddingRequest(nil)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, req)
|
||||
assert.Contains(t, err.Error(), "nil")
|
||||
})
|
||||
|
||||
t.Run("returns error for missing input", func(t *testing.T) {
|
||||
req, err := ToBedrockCohereEmbeddingRequest(&schemas.BifrostEmbeddingRequest{})
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, req)
|
||||
assert.Contains(t, err.Error(), "no input")
|
||||
})
|
||||
|
||||
t.Run("returns error for non-nil but empty input", func(t *testing.T) {
|
||||
req, err := ToBedrockCohereEmbeddingRequest(&schemas.BifrostEmbeddingRequest{
|
||||
Input: &schemas.EmbeddingInput{},
|
||||
})
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, req)
|
||||
assert.Contains(t, err.Error(), "no input")
|
||||
})
|
||||
|
||||
t.Run("single text strips model and extracts typed params", func(t *testing.T) {
|
||||
text := "hello"
|
||||
truncate := "RIGHT"
|
||||
dimensions := 512
|
||||
bifrostReq := &schemas.BifrostEmbeddingRequest{
|
||||
Model: "cohere.embed-english-v3",
|
||||
Input: &schemas.EmbeddingInput{Text: &text},
|
||||
Params: &schemas.EmbeddingParameters{
|
||||
Dimensions: &dimensions,
|
||||
ExtraParams: map[string]interface{}{
|
||||
"input_type": "search_query",
|
||||
"embedding_types": []string{"float"},
|
||||
"truncate": truncate,
|
||||
"max_tokens": float64(128),
|
||||
"trace_id": "req-123",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
req, err := ToBedrockCohereEmbeddingRequest(bifrostReq)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, req)
|
||||
assert.Equal(t, "search_query", req.InputType)
|
||||
assert.Equal(t, []string{"hello"}, req.Texts)
|
||||
assert.Equal(t, []string{"float"}, req.EmbeddingTypes)
|
||||
assert.Equal(t, &dimensions, req.OutputDimension)
|
||||
assert.Equal(t, 128, *req.MaxTokens)
|
||||
require.NotNil(t, req.Truncate)
|
||||
assert.Equal(t, truncate, *req.Truncate)
|
||||
assert.Equal(t, map[string]interface{}{"trace_id": "req-123"}, req.ExtraParams)
|
||||
})
|
||||
|
||||
t.Run("multiple texts preserve bedrock body shape", func(t *testing.T) {
|
||||
bifrostReq := &schemas.BifrostEmbeddingRequest{
|
||||
Model: "cohere.embed-multilingual-v3",
|
||||
Input: &schemas.EmbeddingInput{Texts: []string{"hello", "world"}},
|
||||
Params: &schemas.EmbeddingParameters{
|
||||
ExtraParams: map[string]interface{}{
|
||||
"input_type": "search_document",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
req, err := ToBedrockCohereEmbeddingRequest(bifrostReq)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []string{"hello", "world"}, req.Texts)
|
||||
assert.Equal(t, "search_document", req.InputType)
|
||||
})
|
||||
}
|
||||
|
||||
func TestToBedrockCohereEmbeddingRequestBodyOmitsModel(t *testing.T) {
|
||||
text := "hello"
|
||||
bifrostReq := &schemas.BifrostEmbeddingRequest{
|
||||
Model: "cohere.embed-english-v3",
|
||||
Input: &schemas.EmbeddingInput{Text: &text},
|
||||
Params: &schemas.EmbeddingParameters{
|
||||
ExtraParams: map[string]interface{}{
|
||||
"input_type": "search_document",
|
||||
"embedding_types": []string{"float"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
wireBody, bifrostErr := providerUtils.CheckContextAndGetRequestBody(
|
||||
context.Background(),
|
||||
bifrostReq,
|
||||
func() (providerUtils.RequestBodyWithExtraParams, error) {
|
||||
return ToBedrockCohereEmbeddingRequest(bifrostReq)
|
||||
},
|
||||
)
|
||||
require.Nil(t, bifrostErr)
|
||||
assert.NotContains(t, string(wireBody), `"model"`)
|
||||
assert.JSONEq(t, `{
|
||||
"input_type": "search_document",
|
||||
"texts": ["hello"],
|
||||
"embedding_types": ["float"]
|
||||
}`, string(wireBody))
|
||||
}
|
||||
34
core/providers/bedrock/errors.go
Normal file
34
core/providers/bedrock/errors.go
Normal file
@@ -0,0 +1,34 @@
|
||||
package bedrock
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
func parseBedrockHTTPError(statusCode int, headers http.Header, body []byte) *schemas.BifrostError {
|
||||
fastResp := fasthttp.AcquireResponse()
|
||||
defer fasthttp.ReleaseResponse(fastResp)
|
||||
|
||||
fastResp.SetStatusCode(statusCode)
|
||||
for k, values := range headers {
|
||||
for _, value := range values {
|
||||
fastResp.Header.Add(k, value)
|
||||
}
|
||||
}
|
||||
fastResp.SetBody(body)
|
||||
|
||||
var errorResp BedrockError
|
||||
bifrostErr := providerUtils.HandleProviderAPIError(fastResp, &errorResp)
|
||||
if errorResp.Message != "" {
|
||||
if bifrostErr.Error == nil {
|
||||
bifrostErr.Error = &schemas.ErrorField{}
|
||||
}
|
||||
bifrostErr.Error.Message = errorResp.Message
|
||||
bifrostErr.Error.Code = errorResp.Code
|
||||
}
|
||||
|
||||
return bifrostErr
|
||||
}
|
||||
276
core/providers/bedrock/files.go
Normal file
276
core/providers/bedrock/files.go
Normal file
@@ -0,0 +1,276 @@
|
||||
package bedrock
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"html"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/bytedance/sonic"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// escapeS3KeyForURL percent-escapes an S3 key for use in a URL path while
// keeping "/" as the segment separator. Each "/"-delimited segment is passed
// through url.PathEscape on its own, because escaping the whole key at once
// would turn "/" into "%2F". The original author notes this per-segment
// escaping is what AWS SigV4 signing needs. An empty key yields "".
func escapeS3KeyForURL(key string) string {
	if key == "" {
		return ""
	}

	segments := strings.Split(key, "/")
	escaped := make([]string, len(segments))
	for idx := range segments {
		escaped[idx] = url.PathEscape(segments[idx])
	}
	return strings.Join(escaped, "/")
}
|
||||
|
||||
// parseS3URI splits an S3 location into bucket and key. For an "s3://" URI the
// bucket is everything up to the first "/" after the scheme and the key is
// everything after that slash (empty when there is no slash). Any other input
// is treated as a bare bucket name and returned with an empty key.
func parseS3URI(uri string) (bucket, key string) {
	const scheme = "s3://"
	if !strings.HasPrefix(uri, scheme) {
		// Assume it's just a bucket name.
		return uri, ""
	}

	remainder := uri[len(scheme):]
	if slash := strings.IndexByte(remainder, '/'); slash >= 0 {
		return remainder[:slash], remainder[slash+1:]
	}
	return remainder, ""
}
|
||||
|
||||
// S3ListObjectsResponse represents S3 ListObjectsV2 response.
|
||||
type S3ListObjectsResponse struct {
|
||||
Contents []S3Object `json:"contents"`
|
||||
IsTruncated bool `json:"isTruncated"`
|
||||
NextContinuationToken string `json:"nextContinuationToken,omitempty"`
|
||||
}
|
||||
|
||||
// S3Object is a single entry of an S3 object listing.
type S3Object struct {
	// Key is the object's key within the bucket.
	Key string `json:"key"`
	// Size is the object's size as reported by the listing.
	Size int64 `json:"size"`
	// LastModified is the object's last-modified timestamp.
	LastModified time.Time `json:"lastModified"`
}
|
||||
|
||||
// parseS3ListResponse parses S3 ListObjectsV2 XML response.
|
||||
func parseS3ListResponse(body []byte, resp *S3ListObjectsResponse) error {
|
||||
// S3 returns XML, so we need to parse it
|
||||
// Try JSON first (some S3-compatible services return JSON)
|
||||
if err := sonic.Unmarshal(body, resp); err == nil && len(resp.Contents) > 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Parse XML using simple string matching for key fields
|
||||
// This is a lightweight approach that doesn't require encoding/xml
|
||||
bodyStr := string(body)
|
||||
|
||||
// Parse IsTruncated
|
||||
if strings.Contains(bodyStr, "<IsTruncated>true</IsTruncated>") {
|
||||
resp.IsTruncated = true
|
||||
}
|
||||
|
||||
// Parse NextContinuationToken
|
||||
if start := strings.Index(bodyStr, "<NextContinuationToken>"); start >= 0 {
|
||||
start += len("<NextContinuationToken>")
|
||||
if end := strings.Index(bodyStr[start:], "</NextContinuationToken>"); end >= 0 {
|
||||
resp.NextContinuationToken = bodyStr[start : start+end]
|
||||
}
|
||||
}
|
||||
|
||||
// Parse Contents
|
||||
contents := bodyStr
|
||||
for {
|
||||
start := strings.Index(contents, "<Contents>")
|
||||
if start < 0 {
|
||||
break
|
||||
}
|
||||
end := strings.Index(contents[start:], "</Contents>")
|
||||
if end < 0 {
|
||||
break
|
||||
}
|
||||
|
||||
contentBlock := contents[start : start+end+len("</Contents>")]
|
||||
contents = contents[start+end+len("</Contents>"):]
|
||||
|
||||
obj := S3Object{}
|
||||
|
||||
// Parse Key
|
||||
if keyStart := strings.Index(contentBlock, "<Key>"); keyStart >= 0 {
|
||||
keyStart += len("<Key>")
|
||||
if keyEnd := strings.Index(contentBlock[keyStart:], "</Key>"); keyEnd >= 0 {
|
||||
obj.Key = html.UnescapeString(contentBlock[keyStart : keyStart+keyEnd])
|
||||
}
|
||||
}
|
||||
|
||||
// Parse Size
|
||||
if sizeStart := strings.Index(contentBlock, "<Size>"); sizeStart >= 0 {
|
||||
sizeStart += len("<Size>")
|
||||
if sizeEnd := strings.Index(contentBlock[sizeStart:], "</Size>"); sizeEnd >= 0 {
|
||||
sizeStr := contentBlock[sizeStart : sizeStart+sizeEnd]
|
||||
fmt.Sscanf(sizeStr, "%d", &obj.Size)
|
||||
}
|
||||
}
|
||||
|
||||
// Parse LastModified
|
||||
if lmStart := strings.Index(contentBlock, "<LastModified>"); lmStart >= 0 {
|
||||
lmStart += len("<LastModified>")
|
||||
if lmEnd := strings.Index(contentBlock[lmStart:], "</LastModified>"); lmEnd >= 0 {
|
||||
lmStr := contentBlock[lmStart : lmStart+lmEnd]
|
||||
if t, err := time.Parse(time.RFC3339Nano, lmStr); err == nil {
|
||||
obj.LastModified = t
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if obj.Key != "" {
|
||||
resp.Contents = append(resp.Contents, obj)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ==================== BEDROCK FILE TYPE CONVERTERS ====================
|
||||
|
||||
// ToBedrockFileUploadResponse converts a Bifrost file upload response to Bedrock format.
|
||||
func ToBedrockFileUploadResponse(resp *schemas.BifrostFileUploadResponse) *BedrockFileUploadResponse {
|
||||
if resp == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Parse S3 URI to get bucket and key
|
||||
bucket, key := parseS3URI(resp.ID)
|
||||
|
||||
return &BedrockFileUploadResponse{
|
||||
S3Uri: resp.ID,
|
||||
Bucket: bucket,
|
||||
Key: key,
|
||||
SizeBytes: resp.Bytes,
|
||||
ContentType: "application/jsonl",
|
||||
CreatedAt: resp.CreatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
// ToBedrockFileListResponse converts a Bifrost file list response to Bedrock format.
|
||||
func ToBedrockFileListResponse(resp *schemas.BifrostFileListResponse) *BedrockFileListResponse {
|
||||
if resp == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
files := make([]BedrockFileInfo, len(resp.Data))
|
||||
for i, f := range resp.Data {
|
||||
_, key := parseS3URI(f.ID)
|
||||
files[i] = BedrockFileInfo{
|
||||
S3Uri: f.ID,
|
||||
Key: key,
|
||||
SizeBytes: f.Bytes,
|
||||
LastModified: f.CreatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
return &BedrockFileListResponse{
|
||||
Files: files,
|
||||
IsTruncated: resp.HasMore,
|
||||
}
|
||||
}
|
||||
|
||||
// ToBedrockFileRetrieveResponse converts a Bifrost file retrieve response to Bedrock format.
|
||||
func ToBedrockFileRetrieveResponse(resp *schemas.BifrostFileRetrieveResponse) *BedrockFileRetrieveResponse {
|
||||
if resp == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
_, key := parseS3URI(resp.ID)
|
||||
|
||||
return &BedrockFileRetrieveResponse{
|
||||
S3Uri: resp.ID,
|
||||
Key: key,
|
||||
SizeBytes: resp.Bytes,
|
||||
LastModified: resp.CreatedAt,
|
||||
ContentType: "application/jsonl",
|
||||
}
|
||||
}
|
||||
|
||||
// ToBedrockFileDeleteResponse converts a Bifrost file delete response to Bedrock format.
|
||||
func ToBedrockFileDeleteResponse(resp *schemas.BifrostFileDeleteResponse) *BedrockFileDeleteResponse {
|
||||
if resp == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &BedrockFileDeleteResponse{
|
||||
S3Uri: resp.ID,
|
||||
Deleted: resp.Deleted,
|
||||
}
|
||||
}
|
||||
|
||||
// ToBedrockFileContentResponse converts a Bifrost file content response to Bedrock format.
|
||||
func ToBedrockFileContentResponse(resp *schemas.BifrostFileContentResponse) *BedrockFileContentResponse {
|
||||
if resp == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &BedrockFileContentResponse{
|
||||
S3Uri: resp.FileID,
|
||||
Content: resp.Content,
|
||||
ContentType: resp.ContentType,
|
||||
SizeBytes: int64(len(resp.Content)),
|
||||
}
|
||||
}
|
||||
|
||||
// ==================== S3 API XML FORMATTERS ====================
|
||||
|
||||
// ToS3ListObjectsV2XML converts a Bifrost file list response to S3 ListObjectsV2 XML format.
|
||||
func ToS3ListObjectsV2XML(resp *schemas.BifrostFileListResponse, bucket, prefix string, maxKeys int) []byte {
|
||||
if resp == nil {
|
||||
return []byte(`<?xml version="1.0" encoding="UTF-8"?><ListBucketResult xmlns="http://s3.amazonaws.com/doc/2006-03-01/"></ListBucketResult>`)
|
||||
}
|
||||
|
||||
var sb strings.Builder
|
||||
sb.WriteString(`<?xml version="1.0" encoding="UTF-8"?>`)
|
||||
sb.WriteString(`<ListBucketResult xmlns="http://s3.amazonaws.com/doc/2006-03-01/">`)
|
||||
sb.WriteString(fmt.Sprintf("<Name>%s</Name>", bucket))
|
||||
sb.WriteString(fmt.Sprintf("<Prefix>%s</Prefix>", prefix))
|
||||
sb.WriteString(fmt.Sprintf("<KeyCount>%d</KeyCount>", len(resp.Data)))
|
||||
sb.WriteString(fmt.Sprintf("<MaxKeys>%d</MaxKeys>", maxKeys))
|
||||
if resp.HasMore {
|
||||
sb.WriteString("<IsTruncated>true</IsTruncated>")
|
||||
if resp.After != nil && *resp.After != "" {
|
||||
sb.WriteString(fmt.Sprintf("<NextContinuationToken>%s</NextContinuationToken>", *resp.After))
|
||||
}
|
||||
} else {
|
||||
sb.WriteString("<IsTruncated>false</IsTruncated>")
|
||||
}
|
||||
|
||||
for _, f := range resp.Data {
|
||||
// Extract key from S3 URI
|
||||
_, key := parseS3URI(f.ID)
|
||||
sb.WriteString("<Contents>")
|
||||
sb.WriteString(fmt.Sprintf("<Key>%s</Key>", key))
|
||||
sb.WriteString(fmt.Sprintf("<Size>%d</Size>", f.Bytes))
|
||||
if f.CreatedAt > 0 {
|
||||
sb.WriteString(fmt.Sprintf("<LastModified>%s</LastModified>", time.Unix(f.CreatedAt, 0).UTC().Format(time.RFC3339)))
|
||||
}
|
||||
sb.WriteString("<StorageClass>STANDARD</StorageClass>")
|
||||
sb.WriteString("</Contents>")
|
||||
}
|
||||
|
||||
sb.WriteString("</ListBucketResult>")
|
||||
return []byte(sb.String())
|
||||
}
|
||||
|
||||
// ToS3ErrorXML renders an S3-style <Error> XML document containing the given
// code, message, resource and request ID, in that order.
//
// Every value is XML-escaped before being written, so text containing '&',
// '<', '>' or quotes (for example an upstream error message or an object key
// used as the resource) cannot produce a malformed document.
func ToS3ErrorXML(code, message, resource, requestID string) []byte {
	// Replacer matches are non-overlapping and applied in a single pass, so the
	// '&' introduced by one entity is never escaped a second time.
	escape := strings.NewReplacer(
		"&", "&amp;",
		"<", "&lt;",
		">", "&gt;",
		`"`, "&quot;",
		"'", "&apos;",
	).Replace

	var sb strings.Builder
	sb.WriteString(`<?xml version="1.0" encoding="UTF-8"?>`)
	sb.WriteString("<Error>")
	fmt.Fprintf(&sb, "<Code>%s</Code>", escape(code))
	fmt.Fprintf(&sb, "<Message>%s</Message>", escape(message))
	fmt.Fprintf(&sb, "<Resource>%s</Resource>", escape(resource))
	fmt.Fprintf(&sb, "<RequestId>%s</RequestId>", escape(requestID))
	sb.WriteString("</Error>")
	return []byte(sb.String())
}
|
||||
742
core/providers/bedrock/images.go
Normal file
742
core/providers/bedrock/images.go
Normal file
@@ -0,0 +1,742 @@
|
||||
package bedrock
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// mapQualityToBedrock maps quality values to Bedrock format:
|
||||
// - "low" and "medium" -> "standard"
|
||||
// - "high" -> "premium"
|
||||
// - "standard" and "premium" (case-insensitive) -> pass through as lowercase ("standard"/"premium")
|
||||
func mapQualityToBedrock(quality *string) *string {
|
||||
if quality == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
qualityLower := strings.ToLower(strings.TrimSpace(*quality))
|
||||
|
||||
switch qualityLower {
|
||||
case "low", "medium":
|
||||
return schemas.Ptr("standard")
|
||||
case "high":
|
||||
return schemas.Ptr("premium")
|
||||
case "standard":
|
||||
return schemas.Ptr("standard")
|
||||
case "premium":
|
||||
return schemas.Ptr("premium")
|
||||
default:
|
||||
return quality
|
||||
}
|
||||
}
|
||||
|
||||
// isStabilityAIModel reports whether the model identifier contains the
// substring "stability." (compared case-insensitively).
func isStabilityAIModel(model string) bool {
	lowered := strings.ToLower(model)
	return strings.Contains(lowered, "stability.")
}
|
||||
|
||||
// isPromptOnlyImageGenerationModel reports whether the model should be treated
// as an image generation model that takes a flat {"prompt": "..."} payload
// (no taskType field).
//
// The check itself is only a case-insensitive search for the substring
// "image" in the model identifier. It does not exclude Stability AI models on
// its own; per the original note, those are handled separately because they
// also support image edit, so callers are presumably expected to test
// isStabilityAIModel first.
func isPromptOnlyImageGenerationModel(model string) bool {
	return strings.Contains(strings.ToLower(model), "image")
}
|
||||
|
||||
// ToStabilityAIImageGenerationRequest converts a Bifrost image generation request to the Stability AI
|
||||
// flat request format used by Bedrock (stability.stable-image-* models).
|
||||
func ToStabilityAIImageGenerationRequest(request *schemas.BifrostImageGenerationRequest) (*StabilityAIImageGenerationRequest, error) {
|
||||
if request == nil {
|
||||
return nil, fmt.Errorf("request is nil")
|
||||
}
|
||||
if request.Input == nil {
|
||||
return nil, fmt.Errorf("request input is required")
|
||||
}
|
||||
|
||||
req := &StabilityAIImageGenerationRequest{
|
||||
Prompt: request.Input.Prompt,
|
||||
}
|
||||
|
||||
if request.Params != nil {
|
||||
if request.Params.AspectRatio != nil {
|
||||
req.AspectRatio = request.Params.AspectRatio
|
||||
}
|
||||
if request.Params.OutputFormat != nil {
|
||||
req.OutputFormat = request.Params.OutputFormat
|
||||
}
|
||||
if request.Params.Seed != nil {
|
||||
req.Seed = request.Params.Seed
|
||||
}
|
||||
if request.Params.NegativePrompt != nil {
|
||||
req.NegativePrompt = request.Params.NegativePrompt
|
||||
}
|
||||
if request.Params.ExtraParams != nil {
|
||||
// aspect_ratio may also arrive via ExtraParams if not in knownFields; skip if already set
|
||||
if req.AspectRatio == nil {
|
||||
if ar, ok := schemas.SafeExtractStringPointer(request.Params.ExtraParams["aspect_ratio"]); ok {
|
||||
delete(request.Params.ExtraParams, "aspect_ratio")
|
||||
req.AspectRatio = ar
|
||||
}
|
||||
}
|
||||
req.ExtraParams = request.Params.ExtraParams
|
||||
}
|
||||
}
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
// ToBedrockImageGenerationRequest converts a Bifrost image generation request to a Bedrock image generation request
|
||||
func ToBedrockImageGenerationRequest(request *schemas.BifrostImageGenerationRequest) (*BedrockImageGenerationRequest, error) {
|
||||
if request == nil {
|
||||
return nil, fmt.Errorf("request is nil")
|
||||
}
|
||||
|
||||
if request.Input == nil {
|
||||
return nil, fmt.Errorf("request input is required")
|
||||
}
|
||||
|
||||
bedrockReq := &BedrockImageGenerationRequest{
|
||||
TaskType: schemas.Ptr(TaskTypeTextImage),
|
||||
TextToImageParams: &BedrockTextToImageParams{
|
||||
Text: request.Input.Prompt,
|
||||
},
|
||||
ImageGenerationConfig: &ImageGenerationConfig{},
|
||||
}
|
||||
|
||||
if request.Params != nil {
|
||||
if request.Params.N != nil {
|
||||
bedrockReq.ImageGenerationConfig.NumberOfImages = request.Params.N
|
||||
}
|
||||
if request.Params.NegativePrompt != nil {
|
||||
bedrockReq.TextToImageParams.NegativeText = request.Params.NegativePrompt
|
||||
}
|
||||
if request.Params.Seed != nil {
|
||||
bedrockReq.ImageGenerationConfig.Seed = request.Params.Seed
|
||||
}
|
||||
if request.Params.Quality != nil {
|
||||
bedrockReq.ImageGenerationConfig.Quality = mapQualityToBedrock(request.Params.Quality)
|
||||
}
|
||||
if request.Params.Style != nil {
|
||||
bedrockReq.TextToImageParams.Style = request.Params.Style
|
||||
}
|
||||
if request.Params.Size != nil && strings.TrimSpace(strings.ToLower(*request.Params.Size)) != "auto" {
|
||||
|
||||
size := strings.Split(strings.TrimSpace(strings.ToLower(*request.Params.Size)), "x")
|
||||
if len(size) != 2 {
|
||||
return nil, fmt.Errorf("invalid size format: expected 'WIDTHxHEIGHT', got %q", *request.Params.Size)
|
||||
}
|
||||
|
||||
width, err := strconv.Atoi(size[0])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid width in size %q: %w", *request.Params.Size, err)
|
||||
}
|
||||
|
||||
height, err := strconv.Atoi(size[1])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid height in size %q: %w", *request.Params.Size, err)
|
||||
}
|
||||
|
||||
bedrockReq.ImageGenerationConfig.Width = schemas.Ptr(width)
|
||||
bedrockReq.ImageGenerationConfig.Height = schemas.Ptr(height)
|
||||
}
|
||||
if request.Params.ExtraParams != nil {
|
||||
if cfgScale, ok := schemas.SafeExtractFloat64Pointer(request.Params.ExtraParams["cfgScale"]); ok {
|
||||
delete(request.Params.ExtraParams, "cfgScale")
|
||||
bedrockReq.ImageGenerationConfig.CfgScale = cfgScale
|
||||
}
|
||||
bedrockReq.ExtraParams = request.Params.ExtraParams
|
||||
}
|
||||
}
|
||||
|
||||
return bedrockReq, nil
|
||||
}
|
||||
|
||||
// ToStabilityAIImageGenerationResponse converts a BifrostImageGenerationResponse back to
|
||||
// the native Bedrock invoke API response format used by Stability AI models.
|
||||
// Stability AI models use the same BedrockImageGenerationResponse format as Titan/Nova Canvas.
|
||||
func ToStabilityAIImageGenerationResponse(response *schemas.BifrostImageGenerationResponse) (*BedrockImageGenerationResponse, error) {
|
||||
if response == nil {
|
||||
return nil, fmt.Errorf("response is nil")
|
||||
}
|
||||
result := &BedrockImageGenerationResponse{}
|
||||
for _, d := range response.Data {
|
||||
result.Images = append(result.Images, d.B64JSON)
|
||||
}
|
||||
if response.ImageGenerationResponseParameters != nil {
|
||||
result.FinishReasons = response.ImageGenerationResponseParameters.FinishReasons
|
||||
result.Seeds = response.ImageGenerationResponseParameters.Seeds
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// ToBedrockImageVariationRequest converts a Bifrost image variation request to a Bedrock image variation request
|
||||
func ToBedrockImageVariationRequest(request *schemas.BifrostImageVariationRequest) (*BedrockImageVariationRequest, error) {
|
||||
if request == nil {
|
||||
return nil, fmt.Errorf("request is nil")
|
||||
}
|
||||
|
||||
if request.Input == nil || request.Input.Image.Image == nil || len(request.Input.Image.Image) == 0 {
|
||||
return nil, fmt.Errorf("request.Input.Image is required")
|
||||
}
|
||||
|
||||
bedrockReq := &BedrockImageVariationRequest{
|
||||
TaskType: schemas.Ptr(TaskTypeImageVariation),
|
||||
ImageVariationParams: &BedrockImageVariationParams{
|
||||
Images: []string{},
|
||||
},
|
||||
ImageGenerationConfig: &ImageGenerationConfig{},
|
||||
}
|
||||
|
||||
// Convert all images to base64 strings
|
||||
// Primary image from Input.Image
|
||||
imageBase64 := base64.StdEncoding.EncodeToString(request.Input.Image.Image)
|
||||
bedrockReq.ImageVariationParams.Images = append(bedrockReq.ImageVariationParams.Images, imageBase64)
|
||||
|
||||
// Additional images from ExtraParams (stored as [][]byte)
|
||||
if request.Params != nil && request.Params.ExtraParams != nil {
|
||||
if additionalImages, ok := request.Params.ExtraParams["images"]; ok {
|
||||
delete(request.Params.ExtraParams, "images")
|
||||
// Handle array of byte arrays (stored by HTTP handler)
|
||||
if imagesArray, ok := additionalImages.([][]byte); ok {
|
||||
for _, imgBytes := range imagesArray {
|
||||
if len(imgBytes) > 0 {
|
||||
additionalBase64 := base64.StdEncoding.EncodeToString(imgBytes)
|
||||
bedrockReq.ImageVariationParams.Images = append(bedrockReq.ImageVariationParams.Images, additionalBase64)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Extract optional fields from ExtraParams
|
||||
if prompt, ok := schemas.SafeExtractStringPointer(request.Params.ExtraParams["prompt"]); ok {
|
||||
delete(request.Params.ExtraParams, "prompt")
|
||||
bedrockReq.ImageVariationParams.Text = prompt
|
||||
}
|
||||
if negativeText, ok := schemas.SafeExtractStringPointer(request.Params.ExtraParams["negativeText"]); ok {
|
||||
delete(request.Params.ExtraParams, "negativeText")
|
||||
bedrockReq.ImageVariationParams.NegativeText = negativeText
|
||||
}
|
||||
|
||||
if similarityStrength, ok := schemas.SafeExtractFloat64Pointer(request.Params.ExtraParams["similarityStrength"]); ok {
|
||||
delete(request.Params.ExtraParams, "similarityStrength")
|
||||
// Validate similarityStrength range (0.2 to 1.0)
|
||||
if *similarityStrength < 0.2 || *similarityStrength > 1.0 {
|
||||
return nil, fmt.Errorf("similarityStrength must be between 0.2 and 1.0, got %f", *similarityStrength)
|
||||
}
|
||||
bedrockReq.ImageVariationParams.SimilarityStrength = similarityStrength
|
||||
}
|
||||
bedrockReq.ExtraParams = request.Params.ExtraParams
|
||||
}
|
||||
|
||||
// Map standard params to ImageGenerationConfig
|
||||
if request.Params != nil {
|
||||
if request.Params.N != nil {
|
||||
bedrockReq.ImageGenerationConfig.NumberOfImages = request.Params.N
|
||||
}
|
||||
|
||||
if request.Params.Size != nil && strings.TrimSpace(strings.ToLower(*request.Params.Size)) != "auto" {
|
||||
size := strings.Split(strings.TrimSpace(strings.ToLower(*request.Params.Size)), "x")
|
||||
if len(size) != 2 {
|
||||
return nil, fmt.Errorf("invalid size format: expected 'WIDTHxHEIGHT', got %q", *request.Params.Size)
|
||||
}
|
||||
|
||||
width, err := strconv.Atoi(size[0])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid width in size %q: %w", *request.Params.Size, err)
|
||||
}
|
||||
|
||||
height, err := strconv.Atoi(size[1])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid height in size %q: %w", *request.Params.Size, err)
|
||||
}
|
||||
|
||||
bedrockReq.ImageGenerationConfig.Width = schemas.Ptr(width)
|
||||
bedrockReq.ImageGenerationConfig.Height = schemas.Ptr(height)
|
||||
}
|
||||
|
||||
// Extract quality and cfgScale from ExtraParams
|
||||
if request.Params.ExtraParams != nil {
|
||||
if quality, ok := schemas.SafeExtractStringPointer(request.Params.ExtraParams["quality"]); ok {
|
||||
bedrockReq.ImageGenerationConfig.Quality = mapQualityToBedrock(quality)
|
||||
}
|
||||
|
||||
if cfgScale, ok := schemas.SafeExtractFloat64Pointer(request.Params.ExtraParams["cfgScale"]); ok {
|
||||
bedrockReq.ImageGenerationConfig.CfgScale = cfgScale
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return bedrockReq, nil
|
||||
}
|
||||
|
||||
// ToBedrockImageEditRequest converts a Bifrost image edit request to a Bedrock image edit request
|
||||
func ToBedrockImageEditRequest(request *schemas.BifrostImageEditRequest) (*BedrockImageEditRequest, error) {
|
||||
// Validate request
|
||||
if request == nil || request.Input == nil {
|
||||
return nil, fmt.Errorf("request or input is nil")
|
||||
}
|
||||
|
||||
if len(request.Input.Images) == 0 || len(request.Input.Images[0].Image) == 0 {
|
||||
return nil, fmt.Errorf("at least one image is required")
|
||||
}
|
||||
|
||||
// Validate and extract type (required)
|
||||
if request.Params == nil || request.Params.Type == nil {
|
||||
return nil, fmt.Errorf("type field is required (must be inpainting, outpainting, or background_removal)")
|
||||
}
|
||||
|
||||
editType := strings.ToLower(*request.Params.Type)
|
||||
|
||||
// Convert first image to base64
|
||||
imageBase64 := base64.StdEncoding.EncodeToString(request.Input.Images[0].Image)
|
||||
|
||||
bedrockReq := &BedrockImageEditRequest{}
|
||||
|
||||
switch editType {
|
||||
case "inpainting":
|
||||
bedrockReq.TaskType = schemas.Ptr(TaskTypeInpainting)
|
||||
bedrockReq.InPaintingParams = buildInPaintingParams(imageBase64, request)
|
||||
bedrockReq.ImageGenerationConfig = buildImageGenerationConfig(request.Params)
|
||||
|
||||
case "outpainting":
|
||||
bedrockReq.TaskType = schemas.Ptr(TaskTypeOutpainting)
|
||||
bedrockReq.OutPaintingParams = buildOutPaintingParams(imageBase64, request)
|
||||
bedrockReq.ImageGenerationConfig = buildImageGenerationConfig(request.Params)
|
||||
|
||||
case "background_removal":
|
||||
bedrockReq.TaskType = schemas.Ptr(TaskTypeBackgroundRemoval)
|
||||
bedrockReq.BackgroundRemovalParams = &BedrockBackgroundRemovalParams{
|
||||
Image: imageBase64,
|
||||
}
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported type for Bedrock: %s (must be inpainting, outpainting, or background_removal)", editType)
|
||||
}
|
||||
|
||||
bedrockReq.ExtraParams = request.Params.ExtraParams
|
||||
return bedrockReq, nil
|
||||
}
|
||||
|
||||
// Helper functions
|
||||
func buildInPaintingParams(imageBase64 string, request *schemas.BifrostImageEditRequest) *BedrockInPaintingParams {
|
||||
params := &BedrockInPaintingParams{
|
||||
Image: imageBase64,
|
||||
Text: request.Input.Prompt,
|
||||
}
|
||||
|
||||
if request.Params.NegativePrompt != nil {
|
||||
params.NegativeText = request.Params.NegativePrompt
|
||||
}
|
||||
|
||||
if request.Params.ExtraParams != nil {
|
||||
if maskPrompt, ok := schemas.SafeExtractStringPointer(request.Params.ExtraParams["mask_prompt"]); ok {
|
||||
delete(request.Params.ExtraParams, "mask_prompt")
|
||||
params.MaskPrompt = maskPrompt
|
||||
}
|
||||
if returnMask, ok := schemas.SafeExtractBoolPointer(request.Params.ExtraParams["return_mask"]); ok {
|
||||
delete(request.Params.ExtraParams, "return_mask")
|
||||
params.ReturnMask = returnMask
|
||||
}
|
||||
}
|
||||
|
||||
// Convert mask to base64 if present
|
||||
if len(request.Params.Mask) > 0 {
|
||||
maskBase64 := base64.StdEncoding.EncodeToString(request.Params.Mask)
|
||||
params.MaskImage = &maskBase64
|
||||
}
|
||||
|
||||
return params
|
||||
}
|
||||
|
||||
func buildOutPaintingParams(imageBase64 string, request *schemas.BifrostImageEditRequest) *BedrockOutPaintingParams {
|
||||
params := &BedrockOutPaintingParams{
|
||||
Text: request.Input.Prompt,
|
||||
Image: imageBase64,
|
||||
}
|
||||
|
||||
if request.Params.NegativePrompt != nil {
|
||||
params.NegativeText = request.Params.NegativePrompt
|
||||
}
|
||||
|
||||
if request.Params.ExtraParams != nil {
|
||||
if maskPrompt, ok := schemas.SafeExtractStringPointer(request.Params.ExtraParams["mask_prompt"]); ok {
|
||||
delete(request.Params.ExtraParams, "mask_prompt")
|
||||
params.MaskPrompt = maskPrompt
|
||||
}
|
||||
if returnMask, ok := schemas.SafeExtractBoolPointer(request.Params.ExtraParams["return_mask"]); ok {
|
||||
delete(request.Params.ExtraParams, "return_mask")
|
||||
params.ReturnMask = returnMask
|
||||
}
|
||||
if outPaintingMode, ok := schemas.SafeExtractStringPointer(request.Params.ExtraParams["outpainting_mode"]); ok {
|
||||
// Validate mode
|
||||
mode := strings.ToUpper(*outPaintingMode)
|
||||
if mode == "DEFAULT" || mode == "PRECISE" {
|
||||
delete(request.Params.ExtraParams, "outpainting_mode")
|
||||
params.OutPaintingMode = &mode
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Convert mask to base64 if present
|
||||
if len(request.Params.Mask) > 0 {
|
||||
maskBase64 := base64.StdEncoding.EncodeToString(request.Params.Mask)
|
||||
params.MaskImage = &maskBase64
|
||||
}
|
||||
|
||||
return params
|
||||
}
|
||||
|
||||
func buildImageGenerationConfig(params *schemas.ImageEditParameters) *ImageGenerationConfig {
|
||||
config := &ImageGenerationConfig{}
|
||||
|
||||
if params.N != nil {
|
||||
config.NumberOfImages = params.N
|
||||
}
|
||||
|
||||
// Parse size (reuse logic from image generation)
|
||||
if params.Size != nil && strings.TrimSpace(strings.ToLower(*params.Size)) != "auto" {
|
||||
size := strings.Split(strings.TrimSpace(strings.ToLower(*params.Size)), "x")
|
||||
if len(size) == 2 {
|
||||
width, err := strconv.Atoi(size[0])
|
||||
if err == nil {
|
||||
height, err := strconv.Atoi(size[1])
|
||||
if err == nil {
|
||||
config.Width = schemas.Ptr(width)
|
||||
config.Height = schemas.Ptr(height)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if params.Quality != nil {
|
||||
config.Quality = mapQualityToBedrock(params.Quality)
|
||||
}
|
||||
|
||||
if params.Seed != nil {
|
||||
config.Seed = params.Seed
|
||||
}
|
||||
|
||||
if params.ExtraParams != nil {
|
||||
if cfgScale, ok := schemas.SafeExtractFloat64Pointer(params.ExtraParams["cfgScale"]); ok {
|
||||
delete(params.ExtraParams, "cfgScale")
|
||||
config.CfgScale = cfgScale
|
||||
}
|
||||
}
|
||||
|
||||
return config
|
||||
}
|
||||
|
||||
// getStabilityAITaskTypeFromParams translates a generic edit type value (as
// carried in the image edit parameters' Type field) into the Stability AI
// task type string. Matching is case-insensitive. Returns "" when the value
// is not a recognized Stability AI task type.
func getStabilityAITaskTypeFromParams(t string) string {
	// Generic name -> Stability AI task type. A missing key yields "".
	taskByName := map[string]string{
		"inpainting":           "inpaint",
		"inpaint":              "inpaint",
		"outpainting":          "outpaint",
		"outpaint":             "outpaint",
		"background_removal":   "remove-bg",
		"remove_background":    "remove-bg",
		"erase_object":         "erase-object",
		"upscale_fast":         "upscale-fast",
		"upscale_creative":     "upscale-creative",
		"upscale_conservative": "upscale-conservative",
		"recolor":              "recolor",
		"search_replace":       "search-replace",
		"control_sketch":       "control-sketch",
		"control_structure":    "control-structure",
		"style_guide":          "style-guide",
		"style_transfer":       "style-transfer",
	}
	return taskByName[strings.ToLower(t)]
}
|
||||
|
||||
// getStabilityAIEditTaskType infers the Stability AI edit task from the model
// name by looking for known substrings (case-insensitively). The first
// matching pattern wins. Returns an error naming the model when no pattern
// matches.
func getStabilityAIEditTaskType(model string) (string, error) {
	// Ordered pattern table: model-name fragment -> task type.
	patterns := [...]struct {
		fragment string
		task     string
	}{
		{"stable-creative-upscale", "upscale-creative"},
		{"stable-conservative-upscale", "upscale-conservative"},
		{"stable-fast-upscale", "upscale-fast"},
		{"stable-image-inpaint", "inpaint"},
		{"stable-outpaint", "outpaint"},
		{"stable-image-search-recolor", "recolor"},
		{"stable-image-search-replace", "search-replace"},
		{"stable-image-erase-object", "erase-object"},
		{"stable-image-remove-background", "remove-bg"},
		{"stable-image-control-sketch", "control-sketch"},
		{"stable-image-control-structure", "control-structure"},
		{"stable-image-style-guide", "style-guide"},
		{"stable-style-transfer", "style-transfer"},
	}

	lowered := strings.ToLower(model)
	for _, p := range patterns {
		if strings.Contains(lowered, p.fragment) {
			return p.task, nil
		}
	}
	return "", fmt.Errorf("cannot determine task type from stability ai model name %q", model)
}
|
||||
|
||||
// ToStabilityAIImageEditRequest converts a Bifrost image edit request to the Stability AI flat request
|
||||
// format used by Bedrock edit models. Only fields valid for the detected task type are populated.
|
||||
// deployment is the resolved model identifier (after applying any deployment alias mapping); it is
|
||||
// used for task-type inference so that alias-mapped models route correctly.
|
||||
func ToStabilityAIImageEditRequest(request *schemas.BifrostImageEditRequest, deployment string) (*StabilityAIImageEditRequest, error) {
|
||||
if request == nil || request.Input == nil {
|
||||
return nil, fmt.Errorf("request or input is nil")
|
||||
}
|
||||
|
||||
var taskType string
|
||||
if request.Params != nil && request.Params.Type != nil {
|
||||
taskType = getStabilityAITaskTypeFromParams(*request.Params.Type)
|
||||
}
|
||||
if taskType == "" {
|
||||
var err error
|
||||
taskType, err = getStabilityAIEditTaskType(deployment)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
req := &StabilityAIImageEditRequest{}
|
||||
|
||||
// Image sourcing
|
||||
if taskType == "style-transfer" {
|
||||
if len(request.Input.Images) != 2 {
|
||||
return nil, fmt.Errorf("style-transfer requires exactly two images: init_image and style_image")
|
||||
}
|
||||
if len(request.Input.Images[0].Image) == 0 || len(request.Input.Images[1].Image) == 0 {
|
||||
return nil, fmt.Errorf("style-transfer requires non-empty init_image and style_image")
|
||||
}
|
||||
initB64 := base64.StdEncoding.EncodeToString(request.Input.Images[0].Image)
|
||||
styleB64 := base64.StdEncoding.EncodeToString(request.Input.Images[1].Image)
|
||||
req.InitImage = &initB64
|
||||
req.StyleImage = &styleB64
|
||||
} else {
|
||||
if len(request.Input.Images) == 0 || len(request.Input.Images[0].Image) == 0 {
|
||||
return nil, fmt.Errorf("at least one image is required")
|
||||
}
|
||||
imageB64 := base64.StdEncoding.EncodeToString(request.Input.Images[0].Image)
|
||||
req.Image = &imageB64
|
||||
}
|
||||
|
||||
// Common fields populated based on task allowlist
|
||||
prompt := request.Input.Prompt
|
||||
switch taskType {
|
||||
case "inpaint", "recolor", "search-replace", "control-sketch", "control-structure",
|
||||
"style-guide", "upscale-creative", "upscale-conservative", "outpaint", "style-transfer":
|
||||
req.Prompt = &prompt
|
||||
}
|
||||
|
||||
// Negative prompt
|
||||
if request.Params != nil && request.Params.NegativePrompt != nil {
|
||||
switch taskType {
|
||||
case "inpaint", "outpaint", "recolor", "search-replace", "control-sketch",
|
||||
"control-structure", "style-guide", "upscale-creative", "upscale-conservative", "style-transfer":
|
||||
req.NegativePrompt = request.Params.NegativePrompt
|
||||
}
|
||||
}
|
||||
|
||||
// Seed
|
||||
if request.Params != nil && request.Params.Seed != nil {
|
||||
switch taskType {
|
||||
case "inpaint", "outpaint", "recolor", "search-replace", "erase-object", "control-sketch",
|
||||
"control-structure", "style-guide", "upscale-creative", "upscale-conservative", "style-transfer":
|
||||
req.Seed = request.Params.Seed
|
||||
}
|
||||
}
|
||||
|
||||
// Mask (from Params.Mask bytes)
|
||||
if request.Params != nil && len(request.Params.Mask) > 0 {
|
||||
switch taskType {
|
||||
case "inpaint", "erase-object":
|
||||
maskB64 := base64.StdEncoding.EncodeToString(request.Params.Mask)
|
||||
req.Mask = &maskB64
|
||||
}
|
||||
}
|
||||
|
||||
// ExtraParams
|
||||
if request.Params != nil {
|
||||
// Typed OutputFormat takes priority over ExtraParams
|
||||
if request.Params.OutputFormat != nil {
|
||||
req.OutputFormat = request.Params.OutputFormat
|
||||
}
|
||||
|
||||
if request.Params.ExtraParams != nil {
|
||||
ep := make(map[string]interface{}, len(request.Params.ExtraParams))
|
||||
for k, v := range request.Params.ExtraParams {
|
||||
ep[k] = v
|
||||
}
|
||||
|
||||
// output_format — all tasks (fallback if not already set by typed field)
|
||||
if req.OutputFormat == nil {
|
||||
if v, ok := schemas.SafeExtractStringPointer(ep["output_format"]); ok {
|
||||
delete(ep, "output_format")
|
||||
req.OutputFormat = v
|
||||
}
|
||||
}
|
||||
|
||||
// style_preset
|
||||
switch taskType {
|
||||
case "inpaint", "outpaint", "recolor", "search-replace", "control-sketch",
|
||||
"control-structure", "style-guide", "upscale-creative":
|
||||
if v, ok := schemas.SafeExtractStringPointer(ep["style_preset"]); ok {
|
||||
delete(ep, "style_preset")
|
||||
req.StylePreset = v
|
||||
}
|
||||
}
|
||||
|
||||
// grow_mask
|
||||
switch taskType {
|
||||
case "inpaint", "recolor", "search-replace", "erase-object":
|
||||
if v, ok := schemas.SafeExtractIntPointer(ep["grow_mask"]); ok {
|
||||
delete(ep, "grow_mask")
|
||||
req.GrowMask = v
|
||||
}
|
||||
}
|
||||
|
||||
// outpaint directional fields
|
||||
if taskType == "outpaint" {
|
||||
if v, ok := schemas.SafeExtractIntPointer(ep["left"]); ok {
|
||||
delete(ep, "left")
|
||||
req.Left = v
|
||||
}
|
||||
if v, ok := schemas.SafeExtractIntPointer(ep["right"]); ok {
|
||||
delete(ep, "right")
|
||||
req.Right = v
|
||||
}
|
||||
if v, ok := schemas.SafeExtractIntPointer(ep["up"]); ok {
|
||||
delete(ep, "up")
|
||||
req.Up = v
|
||||
}
|
||||
if v, ok := schemas.SafeExtractIntPointer(ep["down"]); ok {
|
||||
delete(ep, "down")
|
||||
req.Down = v
|
||||
}
|
||||
}
|
||||
|
||||
// creativity
|
||||
switch taskType {
|
||||
case "upscale-creative", "upscale-conservative", "outpaint":
|
||||
if v, ok := schemas.SafeExtractFloat64Pointer(ep["creativity"]); ok {
|
||||
delete(ep, "creativity")
|
||||
req.Creativity = v
|
||||
}
|
||||
}
|
||||
|
||||
// select_prompt (recolor)
|
||||
if taskType == "recolor" {
|
||||
if v, ok := schemas.SafeExtractStringPointer(ep["select_prompt"]); ok {
|
||||
delete(ep, "select_prompt")
|
||||
req.SelectPrompt = v
|
||||
}
|
||||
}
|
||||
|
||||
// search_prompt (search-replace)
|
||||
if taskType == "search-replace" {
|
||||
if v, ok := schemas.SafeExtractStringPointer(ep["search_prompt"]); ok {
|
||||
delete(ep, "search_prompt")
|
||||
req.SearchPrompt = v
|
||||
}
|
||||
}
|
||||
|
||||
// control_strength
|
||||
switch taskType {
|
||||
case "control-sketch", "control-structure":
|
||||
if v, ok := schemas.SafeExtractFloat64Pointer(ep["control_strength"]); ok {
|
||||
delete(ep, "control_strength")
|
||||
req.ControlStrength = v
|
||||
}
|
||||
}
|
||||
|
||||
// style-guide fields
|
||||
if taskType == "style-guide" {
|
||||
if v, ok := schemas.SafeExtractStringPointer(ep["aspect_ratio"]); ok {
|
||||
delete(ep, "aspect_ratio")
|
||||
req.AspectRatio = v
|
||||
}
|
||||
if v, ok := schemas.SafeExtractFloat64Pointer(ep["fidelity"]); ok {
|
||||
delete(ep, "fidelity")
|
||||
req.Fidelity = v
|
||||
}
|
||||
}
|
||||
|
||||
// style-transfer fields
|
||||
if taskType == "style-transfer" {
|
||||
if v, ok := schemas.SafeExtractFloat64Pointer(ep["style_strength"]); ok {
|
||||
delete(ep, "style_strength")
|
||||
req.StyleStrength = v
|
||||
}
|
||||
if v, ok := schemas.SafeExtractFloat64Pointer(ep["composition_fidelity"]); ok {
|
||||
delete(ep, "composition_fidelity")
|
||||
req.CompositionFidelity = v
|
||||
}
|
||||
if v, ok := schemas.SafeExtractFloat64Pointer(ep["change_strength"]); ok {
|
||||
delete(ep, "change_strength")
|
||||
req.ChangeStrength = v
|
||||
}
|
||||
}
|
||||
|
||||
req.ExtraParams = ep
|
||||
}
|
||||
}
|
||||
|
||||
// Validate required per-task fields
|
||||
if taskType == "recolor" && (req.SelectPrompt == nil || *req.SelectPrompt == "") {
|
||||
return nil, fmt.Errorf("select_prompt is required for stability ai recolor task")
|
||||
}
|
||||
if taskType == "search-replace" && (req.SearchPrompt == nil || *req.SearchPrompt == "") {
|
||||
return nil, fmt.Errorf("search_prompt is required for stability ai search-replace task")
|
||||
}
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
// ToBifrostImageGenerationResponse converts a Bedrock image generation response to a Bifrost image generation response
|
||||
func ToBifrostImageGenerationResponse(response *BedrockImageGenerationResponse) *schemas.BifrostImageGenerationResponse {
|
||||
if response == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
bifrostResponse := &schemas.BifrostImageGenerationResponse{}
|
||||
|
||||
if len(response.FinishReasons) > 0 || len(response.Seeds) > 0 {
|
||||
bifrostResponse.ImageGenerationResponseParameters = &schemas.ImageGenerationResponseParameters{
|
||||
FinishReasons: append([]*string(nil), response.FinishReasons...),
|
||||
Seeds: append([]int(nil), response.Seeds...),
|
||||
}
|
||||
}
|
||||
|
||||
for index, image := range response.Images {
|
||||
bifrostResponse.Data = append(bifrostResponse.Data, schemas.ImageData{
|
||||
B64JSON: image,
|
||||
Index: index,
|
||||
})
|
||||
}
|
||||
|
||||
return bifrostResponse
|
||||
}
|
||||
1442
core/providers/bedrock/invoke.go
Normal file
1442
core/providers/bedrock/invoke.go
Normal file
File diff suppressed because it is too large
Load Diff
130
core/providers/bedrock/models.go
Normal file
130
core/providers/bedrock/models.go
Normal file
@@ -0,0 +1,130 @@
|
||||
package bedrock
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// BedrockRerankRequest is the Bedrock Agent Runtime rerank request body.
// It carries the queries to rank against, the candidate sources, and the
// reranking configuration (which names the model to use).
type BedrockRerankRequest struct {
	Queries                []BedrockRerankQuery          `json:"queries"`                // queries to rank the sources against
	Sources                []BedrockRerankSource         `json:"sources"`                // candidate documents to be reranked
	RerankingConfiguration BedrockRerankingConfiguration `json:"rerankingConfiguration"` // model selection and result limits
}
|
||||
|
||||
// GetExtraParams implements RequestBodyWithExtraParams.
// It always returns nil: this request type exposes no extra parameters
// through this accessor.
func (*BedrockRerankRequest) GetExtraParams() map[string]interface{} {
	return nil
}
|
||||
|
||||
// Fixed string values placed in the "type" fields of a Bedrock rerank request.
const (
	bedrockRerankQueryTypeText            = "TEXT"                    // BedrockRerankQuery.Type
	bedrockRerankSourceTypeInline         = "INLINE"                  // BedrockRerankSource.Type
	bedrockRerankInlineDocumentTypeText   = "TEXT"                    // BedrockRerankInlineSource.Type
	bedrockRerankConfigurationTypeBedrock = "BEDROCK_RERANKING_MODEL" // BedrockRerankingConfiguration.Type
)
|
||||
|
||||
// BedrockRerankQuery is a single query entry of a rerank request.
type BedrockRerankQuery struct {
	Type      string               `json:"type"`      // query kind, e.g. bedrockRerankQueryTypeText
	TextQuery BedrockRerankTextRef `json:"textQuery"` // the query text
}
|
||||
|
||||
// BedrockRerankSource is a single candidate source of a rerank request.
type BedrockRerankSource struct {
	Type                 string                    `json:"type"`                 // source kind, e.g. bedrockRerankSourceTypeInline
	InlineDocumentSource BedrockRerankInlineSource `json:"inlineDocumentSource"` // the document supplied inline
}
|
||||
|
||||
// BedrockRerankInlineSource describes a document embedded directly in the request.
type BedrockRerankInlineSource struct {
	Type         string                 `json:"type"`         // document kind, e.g. bedrockRerankInlineDocumentTypeText
	TextDocument BedrockRerankTextValue `json:"textDocument"` // the document text
}
|
||||
|
||||
// BedrockRerankTextRef wraps the text of a rerank query.
type BedrockRerankTextRef struct {
	Text string `json:"text"`
}
|
||||
|
||||
// BedrockRerankTextValue wraps the text of a document, both in rerank
// requests (inline sources) and in rerank responses.
type BedrockRerankTextValue struct {
	Text string `json:"text"`
}
|
||||
|
||||
// BedrockRerankingConfiguration is the top-level reranking configuration of a
// rerank request.
type BedrockRerankingConfiguration struct {
	Type                          string                             `json:"type"`                          // configuration kind, e.g. bedrockRerankConfigurationTypeBedrock
	BedrockRerankingConfiguration BedrockRerankingModelConfiguration `json:"bedrockRerankingConfiguration"` // model settings and result limit
}
|
||||
|
||||
// BedrockRerankingModelConfiguration selects the reranking model and optionally
// limits how many results are requested.
type BedrockRerankingModelConfiguration struct {
	ModelConfiguration BedrockRerankModelConfiguration `json:"modelConfiguration"`
	NumberOfResults    *int                            `json:"numberOfResults,omitempty"` // omitted from the JSON when nil
}
|
||||
|
||||
// BedrockRerankModelConfiguration identifies the reranking model by ARN and
// carries any additional model-specific request fields.
type BedrockRerankModelConfiguration struct {
	ModelARN                     string                 `json:"modelArn"`
	AdditionalModelRequestFields map[string]interface{} `json:"additionalModelRequestFields,omitempty"` // omitted from the JSON when empty
}
|
||||
|
||||
// BedrockRerankResponse is the Bedrock Agent Runtime rerank response body.
type BedrockRerankResponse struct {
	Results   []BedrockRerankResult `json:"results"`             // one entry per ranked document
	NextToken *string               `json:"nextToken,omitempty"` // NOTE(review): presumably a pagination token — confirm against the Bedrock API
}
|
||||
|
||||
// BedrockRerankResult is a single ranked entry of a rerank response.
type BedrockRerankResult struct {
	Index          int                            `json:"index"`              // position of the document in the request's sources
	RelevanceScore float64                        `json:"relevanceScore"`     // score assigned by the reranking model
	Document       *BedrockRerankResponseDocument `json:"document,omitempty"` // optional echo of the ranked document
}
|
||||
|
||||
// BedrockRerankResponseDocument is the document optionally returned alongside a
// rerank result.
type BedrockRerankResponseDocument struct {
	Type         string                  `json:"type,omitempty"`
	TextDocument *BedrockRerankTextValue `json:"textDocument,omitempty"` // nil when no text document was returned
}
|
||||
|
||||
func (response *BedrockListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, allowedModels schemas.WhiteList, blacklistedModels schemas.BlackList, aliases map[string]string, unfiltered bool) *schemas.BifrostListModelsResponse {
|
||||
if response == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
bifrostResponse := &schemas.BifrostListModelsResponse{
|
||||
Data: make([]schemas.Model, 0, len(response.ModelSummaries)),
|
||||
}
|
||||
|
||||
pipeline := &providerUtils.ListModelsPipeline{
|
||||
AllowedModels: allowedModels,
|
||||
BlacklistedModels: blacklistedModels,
|
||||
Aliases: aliases,
|
||||
Unfiltered: unfiltered,
|
||||
ProviderKey: providerKey,
|
||||
MatchFns: providerUtils.DefaultMatchFns(),
|
||||
}
|
||||
if pipeline.ShouldEarlyExit() {
|
||||
return bifrostResponse
|
||||
}
|
||||
|
||||
included := make(map[string]bool)
|
||||
|
||||
for _, model := range response.ModelSummaries {
|
||||
for _, result := range pipeline.FilterModel(model.ModelID) {
|
||||
modelEntry := schemas.Model{
|
||||
ID: string(providerKey) + "/" + result.ResolvedID,
|
||||
Name: schemas.Ptr(model.ModelName),
|
||||
OwnedBy: schemas.Ptr(model.ProviderName),
|
||||
Architecture: &schemas.Architecture{
|
||||
InputModalities: model.InputModalities,
|
||||
OutputModalities: model.OutputModalities,
|
||||
},
|
||||
}
|
||||
if result.AliasValue != "" {
|
||||
modelEntry.Alias = schemas.Ptr(result.AliasValue)
|
||||
}
|
||||
bifrostResponse.Data = append(bifrostResponse.Data, modelEntry)
|
||||
included[strings.ToLower(result.ResolvedID)] = true
|
||||
}
|
||||
}
|
||||
|
||||
bifrostResponse.Data = append(bifrostResponse.Data,
|
||||
pipeline.BackfillModels(included)...)
|
||||
|
||||
return bifrostResponse
|
||||
}
|
||||
55
core/providers/bedrock/payload_ordering_test.go
Normal file
55
core/providers/bedrock/payload_ordering_test.go
Normal file
@@ -0,0 +1,55 @@
|
||||
package bedrock
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
|
||||
schemas "github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestPayloadOrdering_BedrockConverseRequest(t *testing.T) {
|
||||
req := &BedrockConverseRequest{
|
||||
Messages: []BedrockMessage{
|
||||
{
|
||||
Role: "user",
|
||||
Content: []BedrockContentBlock{
|
||||
{Text: schemas.Ptr("hello")},
|
||||
},
|
||||
},
|
||||
},
|
||||
InferenceConfig: &BedrockInferenceConfig{
|
||||
Temperature: schemas.Ptr(0.7),
|
||||
MaxTokens: schemas.Ptr(1024),
|
||||
},
|
||||
ToolConfig: &BedrockToolConfig{
|
||||
Tools: []BedrockTool{
|
||||
{
|
||||
ToolSpec: &BedrockToolSpec{
|
||||
Name: "get_weather",
|
||||
Description: schemas.Ptr("Get weather"),
|
||||
InputSchema: BedrockToolInputSchema{
|
||||
JSON: json.RawMessage(`{"type":"object","properties":{"location":{"type":"string"}},"required":["location"]}`),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result, err := providerUtils.MarshalSorted(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
golden := `{"messages":[{"role":"user","content":[{"text":"hello"}]}],"inferenceConfig":{"maxTokens":1024,"temperature":0.7},"toolConfig":{"tools":[{"toolSpec":{"name":"get_weather","description":"Get weather","inputSchema":{"json":{"type":"object","properties":{"location":{"type":"string"}},"required":["location"]}}}}]}}`
|
||||
|
||||
assert.Equal(t, golden, string(result), "payload field ordering changed — if intentional, update the golden string")
|
||||
|
||||
// Determinism: 100 iterations must produce identical bytes
|
||||
for i := 0; i < 100; i++ {
|
||||
iter, err := providerUtils.MarshalSorted(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, string(result), string(iter), "non-deterministic marshal output on iteration %d", i)
|
||||
}
|
||||
}
|
||||
168
core/providers/bedrock/rerank.go
Normal file
168
core/providers/bedrock/rerank.go
Normal file
@@ -0,0 +1,168 @@
|
||||
package bedrock
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// ToBedrockRerankRequest converts a Bifrost rerank request into Bedrock Agent Runtime format.
|
||||
func ToBedrockRerankRequest(bifrostReq *schemas.BifrostRerankRequest, modelARN string) (*BedrockRerankRequest, error) {
|
||||
if bifrostReq == nil {
|
||||
return nil, fmt.Errorf("bifrost rerank request is nil")
|
||||
}
|
||||
if strings.TrimSpace(modelARN) == "" {
|
||||
return nil, fmt.Errorf("bedrock rerank model ARN is empty")
|
||||
}
|
||||
if len(bifrostReq.Documents) == 0 {
|
||||
return nil, fmt.Errorf("documents are required for rerank request")
|
||||
}
|
||||
|
||||
bedrockReq := &BedrockRerankRequest{
|
||||
Queries: []BedrockRerankQuery{
|
||||
{
|
||||
Type: bedrockRerankQueryTypeText,
|
||||
TextQuery: BedrockRerankTextRef{
|
||||
Text: bifrostReq.Query,
|
||||
},
|
||||
},
|
||||
},
|
||||
Sources: make([]BedrockRerankSource, len(bifrostReq.Documents)),
|
||||
RerankingConfiguration: BedrockRerankingConfiguration{
|
||||
Type: bedrockRerankConfigurationTypeBedrock,
|
||||
BedrockRerankingConfiguration: BedrockRerankingModelConfiguration{
|
||||
ModelConfiguration: BedrockRerankModelConfiguration{
|
||||
ModelARN: modelARN,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for i, doc := range bifrostReq.Documents {
|
||||
bedrockReq.Sources[i] = BedrockRerankSource{
|
||||
Type: bedrockRerankSourceTypeInline,
|
||||
InlineDocumentSource: BedrockRerankInlineSource{
|
||||
Type: bedrockRerankInlineDocumentTypeText,
|
||||
TextDocument: BedrockRerankTextValue{
|
||||
Text: doc.Text,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
if bifrostReq.Params == nil {
|
||||
return bedrockReq, nil
|
||||
}
|
||||
|
||||
if bifrostReq.Params.TopN != nil {
|
||||
topN := *bifrostReq.Params.TopN
|
||||
if topN < 1 {
|
||||
return nil, fmt.Errorf("top_n must be at least 1")
|
||||
}
|
||||
if topN > len(bifrostReq.Documents) {
|
||||
topN = len(bifrostReq.Documents)
|
||||
}
|
||||
bedrockReq.RerankingConfiguration.BedrockRerankingConfiguration.NumberOfResults = schemas.Ptr(topN)
|
||||
}
|
||||
|
||||
additionalFields := make(map[string]interface{})
|
||||
if bifrostReq.Params.MaxTokensPerDoc != nil {
|
||||
additionalFields["max_tokens_per_doc"] = *bifrostReq.Params.MaxTokensPerDoc
|
||||
}
|
||||
if bifrostReq.Params.Priority != nil {
|
||||
additionalFields["priority"] = *bifrostReq.Params.Priority
|
||||
}
|
||||
for k, v := range bifrostReq.Params.ExtraParams {
|
||||
additionalFields[k] = v
|
||||
}
|
||||
if len(additionalFields) > 0 {
|
||||
bedrockReq.RerankingConfiguration.BedrockRerankingConfiguration.ModelConfiguration.AdditionalModelRequestFields = additionalFields
|
||||
}
|
||||
|
||||
return bedrockReq, nil
|
||||
}
|
||||
|
||||
// ToBifrostRerankResponse converts a Bedrock rerank response into Bifrost format.
|
||||
func (response *BedrockRerankResponse) ToBifrostRerankResponse(documents []schemas.RerankDocument, returnDocuments bool) *schemas.BifrostRerankResponse {
|
||||
if response == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
bifrostResponse := &schemas.BifrostRerankResponse{
|
||||
Results: make([]schemas.RerankResult, 0, len(response.Results)),
|
||||
}
|
||||
|
||||
for _, result := range response.Results {
|
||||
rerankResult := schemas.RerankResult{
|
||||
Index: result.Index,
|
||||
RelevanceScore: result.RelevanceScore,
|
||||
}
|
||||
if result.Document != nil && result.Document.TextDocument != nil {
|
||||
rerankResult.Document = &schemas.RerankDocument{
|
||||
Text: result.Document.TextDocument.Text,
|
||||
}
|
||||
}
|
||||
bifrostResponse.Results = append(bifrostResponse.Results, rerankResult)
|
||||
}
|
||||
|
||||
sort.SliceStable(bifrostResponse.Results, func(i, j int) bool {
|
||||
if bifrostResponse.Results[i].RelevanceScore == bifrostResponse.Results[j].RelevanceScore {
|
||||
return bifrostResponse.Results[i].Index < bifrostResponse.Results[j].Index
|
||||
}
|
||||
return bifrostResponse.Results[i].RelevanceScore > bifrostResponse.Results[j].RelevanceScore
|
||||
})
|
||||
|
||||
if returnDocuments {
|
||||
for i := range bifrostResponse.Results {
|
||||
resultIndex := bifrostResponse.Results[i].Index
|
||||
if resultIndex >= 0 && resultIndex < len(documents) {
|
||||
bifrostResponse.Results[i].Document = schemas.Ptr(documents[resultIndex])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return bifrostResponse
|
||||
}
|
||||
|
||||
// ToBifrostRerankRequest converts a Bedrock Agent Runtime rerank request to Bifrost format.
|
||||
func (req *BedrockRerankRequest) ToBifrostRerankRequest(ctx *schemas.BifrostContext) *schemas.BifrostRerankRequest {
|
||||
if req == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
modelARN := req.RerankingConfiguration.BedrockRerankingConfiguration.ModelConfiguration.ModelARN
|
||||
provider, model := schemas.ParseModelString(modelARN, providerUtils.CheckAndSetDefaultProvider(ctx, schemas.Bedrock))
|
||||
|
||||
bifrostReq := &schemas.BifrostRerankRequest{
|
||||
Provider: provider,
|
||||
Model: model,
|
||||
Params: &schemas.RerankParameters{},
|
||||
}
|
||||
|
||||
// Extract query from the first query entry
|
||||
if len(req.Queries) > 0 {
|
||||
bifrostReq.Query = req.Queries[0].TextQuery.Text
|
||||
}
|
||||
|
||||
// Convert sources to documents
|
||||
for _, source := range req.Sources {
|
||||
bifrostReq.Documents = append(bifrostReq.Documents, schemas.RerankDocument{
|
||||
Text: source.InlineDocumentSource.TextDocument.Text,
|
||||
})
|
||||
}
|
||||
|
||||
// Extract TopN from NumberOfResults
|
||||
if req.RerankingConfiguration.BedrockRerankingConfiguration.NumberOfResults != nil {
|
||||
bifrostReq.Params.TopN = req.RerankingConfiguration.BedrockRerankingConfiguration.NumberOfResults
|
||||
}
|
||||
|
||||
// Pass AdditionalModelRequestFields as ExtraParams
|
||||
if fields := req.RerankingConfiguration.BedrockRerankingConfiguration.ModelConfiguration.AdditionalModelRequestFields; len(fields) > 0 {
|
||||
bifrostReq.Params.ExtraParams = fields
|
||||
}
|
||||
|
||||
return bifrostReq
|
||||
}
|
||||
230
core/providers/bedrock/rerank_test.go
Normal file
230
core/providers/bedrock/rerank_test.go
Normal file
@@ -0,0 +1,230 @@
|
||||
package bedrock
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestToBedrockRerankRequest(t *testing.T) {
|
||||
topN := 10
|
||||
maxTokensPerDoc := 512
|
||||
priority := 3
|
||||
|
||||
req, err := ToBedrockRerankRequest(&schemas.BifrostRerankRequest{
|
||||
Model: "arn:aws:bedrock:us-east-1::foundation-model/cohere.rerank-v3-5:0",
|
||||
Query: "capital of france",
|
||||
Documents: []schemas.RerankDocument{
|
||||
{Text: "Paris is the capital of France."},
|
||||
{Text: "Berlin is the capital of Germany."},
|
||||
},
|
||||
Params: &schemas.RerankParameters{
|
||||
TopN: schemas.Ptr(topN),
|
||||
MaxTokensPerDoc: schemas.Ptr(maxTokensPerDoc),
|
||||
Priority: schemas.Ptr(priority),
|
||||
ExtraParams: map[string]interface{}{
|
||||
"truncate": "END",
|
||||
},
|
||||
},
|
||||
}, "arn:aws:bedrock:us-east-1::foundation-model/cohere.rerank-v3-5:0")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, req)
|
||||
|
||||
require.Len(t, req.Queries, 1)
|
||||
assert.Equal(t, "TEXT", req.Queries[0].Type)
|
||||
assert.Equal(t, "capital of france", req.Queries[0].TextQuery.Text)
|
||||
require.Len(t, req.Sources, 2)
|
||||
|
||||
require.NotNil(t, req.RerankingConfiguration.BedrockRerankingConfiguration.NumberOfResults)
|
||||
assert.Equal(t, 2, *req.RerankingConfiguration.BedrockRerankingConfiguration.NumberOfResults, "top_n must be clamped to source count")
|
||||
|
||||
fields := req.RerankingConfiguration.BedrockRerankingConfiguration.ModelConfiguration.AdditionalModelRequestFields
|
||||
require.NotNil(t, fields)
|
||||
assert.Equal(t, maxTokensPerDoc, fields["max_tokens_per_doc"])
|
||||
assert.Equal(t, priority, fields["priority"])
|
||||
assert.Equal(t, "END", fields["truncate"])
|
||||
}
|
||||
|
||||
func TestBedrockRerankResponseToBifrostRerankResponse(t *testing.T) {
|
||||
response := (&BedrockRerankResponse{
|
||||
Results: []BedrockRerankResult{
|
||||
{
|
||||
Index: 2,
|
||||
RelevanceScore: 0.21,
|
||||
Document: &BedrockRerankResponseDocument{
|
||||
TextDocument: &BedrockRerankTextValue{Text: "doc-2"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Index: 1,
|
||||
RelevanceScore: 0.95,
|
||||
Document: &BedrockRerankResponseDocument{
|
||||
TextDocument: &BedrockRerankTextValue{Text: "doc-1"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Index: 0,
|
||||
RelevanceScore: 0.95,
|
||||
Document: &BedrockRerankResponseDocument{
|
||||
TextDocument: &BedrockRerankTextValue{Text: "doc-0"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}).ToBifrostRerankResponse(nil, false)
|
||||
|
||||
require.NotNil(t, response)
|
||||
require.Len(t, response.Results, 3)
|
||||
|
||||
assert.Equal(t, 0, response.Results[0].Index)
|
||||
assert.Equal(t, 1, response.Results[1].Index)
|
||||
assert.Equal(t, 2, response.Results[2].Index)
|
||||
assert.Equal(t, "doc-0", response.Results[0].Document.Text)
|
||||
assert.Equal(t, "doc-1", response.Results[1].Document.Text)
|
||||
}
|
||||
|
||||
func TestBedrockRerankResponseToBifrostRerankResponseReturnDocuments(t *testing.T) {
|
||||
requestDocs := []schemas.RerankDocument{
|
||||
{Text: "request-doc-0"},
|
||||
{Text: "request-doc-1"},
|
||||
{Text: "request-doc-2"},
|
||||
}
|
||||
|
||||
response := (&BedrockRerankResponse{
|
||||
Results: []BedrockRerankResult{
|
||||
{
|
||||
Index: 2,
|
||||
RelevanceScore: 0.21,
|
||||
Document: &BedrockRerankResponseDocument{
|
||||
TextDocument: &BedrockRerankTextValue{Text: "provider-doc-2"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Index: 1,
|
||||
RelevanceScore: 0.95,
|
||||
Document: &BedrockRerankResponseDocument{
|
||||
TextDocument: &BedrockRerankTextValue{Text: "provider-doc-1"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Index: 0,
|
||||
RelevanceScore: 0.95,
|
||||
Document: &BedrockRerankResponseDocument{
|
||||
TextDocument: &BedrockRerankTextValue{Text: "provider-doc-0"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}).ToBifrostRerankResponse(requestDocs, true)
|
||||
|
||||
require.NotNil(t, response)
|
||||
require.Len(t, response.Results, 3)
|
||||
require.NotNil(t, response.Results[0].Document)
|
||||
require.NotNil(t, response.Results[1].Document)
|
||||
require.NotNil(t, response.Results[2].Document)
|
||||
|
||||
assert.Equal(t, 0, response.Results[0].Index)
|
||||
assert.Equal(t, 1, response.Results[1].Index)
|
||||
assert.Equal(t, 2, response.Results[2].Index)
|
||||
assert.Equal(t, "request-doc-0", response.Results[0].Document.Text)
|
||||
assert.Equal(t, "request-doc-1", response.Results[1].Document.Text)
|
||||
assert.Equal(t, "request-doc-2", response.Results[2].Document.Text)
|
||||
}
|
||||
|
||||
func TestBedrockRerankRequestToBifrostRerankRequest(t *testing.T) {
|
||||
topN := 3
|
||||
bedrockReq := &BedrockRerankRequest{
|
||||
Queries: []BedrockRerankQuery{
|
||||
{
|
||||
Type: bedrockRerankQueryTypeText,
|
||||
TextQuery: BedrockRerankTextRef{Text: "capital of france"},
|
||||
},
|
||||
},
|
||||
Sources: []BedrockRerankSource{
|
||||
{
|
||||
Type: bedrockRerankSourceTypeInline,
|
||||
InlineDocumentSource: BedrockRerankInlineSource{
|
||||
Type: bedrockRerankInlineDocumentTypeText,
|
||||
TextDocument: BedrockRerankTextValue{Text: "Paris is the capital of France."},
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: bedrockRerankSourceTypeInline,
|
||||
InlineDocumentSource: BedrockRerankInlineSource{
|
||||
Type: bedrockRerankInlineDocumentTypeText,
|
||||
TextDocument: BedrockRerankTextValue{Text: "Berlin is the capital of Germany."},
|
||||
},
|
||||
},
|
||||
},
|
||||
RerankingConfiguration: BedrockRerankingConfiguration{
|
||||
Type: bedrockRerankConfigurationTypeBedrock,
|
||||
BedrockRerankingConfiguration: BedrockRerankingModelConfiguration{
|
||||
NumberOfResults: &topN,
|
||||
ModelConfiguration: BedrockRerankModelConfiguration{
|
||||
ModelARN: "arn:aws:bedrock:us-east-1::foundation-model/cohere.rerank-v3-5:0",
|
||||
AdditionalModelRequestFields: map[string]interface{}{
|
||||
"truncate": "END",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
bifrostCtx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
|
||||
result := bedrockReq.ToBifrostRerankRequest(bifrostCtx)
|
||||
|
||||
require.NotNil(t, result)
|
||||
assert.Equal(t, schemas.Bedrock, result.Provider)
|
||||
assert.Equal(t, "arn:aws:bedrock:us-east-1::foundation-model/cohere.rerank-v3-5:0", result.Model)
|
||||
assert.Equal(t, "capital of france", result.Query)
|
||||
require.Len(t, result.Documents, 2)
|
||||
assert.Equal(t, "Paris is the capital of France.", result.Documents[0].Text)
|
||||
assert.Equal(t, "Berlin is the capital of Germany.", result.Documents[1].Text)
|
||||
require.NotNil(t, result.Params)
|
||||
require.NotNil(t, result.Params.TopN)
|
||||
assert.Equal(t, 3, *result.Params.TopN)
|
||||
require.NotNil(t, result.Params.ExtraParams)
|
||||
assert.Equal(t, "END", result.Params.ExtraParams["truncate"])
|
||||
}
|
||||
|
||||
func TestBedrockRerankRequestToBifrostRerankRequestNil(t *testing.T) {
|
||||
var req *BedrockRerankRequest
|
||||
assert.Nil(t, req.ToBifrostRerankRequest(nil))
|
||||
}
|
||||
|
||||
func TestResolveBedrockDeployment(t *testing.T) {
|
||||
key := schemas.Key{
|
||||
Aliases: schemas.KeyAliases{
|
||||
"cohere-rerank": "arn:aws:bedrock:us-east-1::foundation-model/cohere.rerank-v3-5:0",
|
||||
},
|
||||
}
|
||||
|
||||
deployment := key.Aliases.Resolve("cohere-rerank")
|
||||
assert.Equal(t, "arn:aws:bedrock:us-east-1::foundation-model/cohere.rerank-v3-5:0", deployment)
|
||||
assert.Equal(t, "cohere.rerank-v3-5:0", key.Aliases.Resolve("cohere.rerank-v3-5:0"))
|
||||
assert.Equal(t, "", key.Aliases.Resolve(""))
|
||||
}
|
||||
|
||||
func TestBedrockRerankRequiresARNModelIdentifier(t *testing.T) {
|
||||
provider := &BedrockProvider{}
|
||||
ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
|
||||
key := schemas.Key{
|
||||
Aliases: schemas.KeyAliases{
|
||||
"cohere-rerank": "cohere.rerank-v3-5:0",
|
||||
},
|
||||
}
|
||||
|
||||
response, bifrostErr := provider.Rerank(ctx, key, &schemas.BifrostRerankRequest{
|
||||
Model: "cohere-rerank",
|
||||
Query: "capital of france",
|
||||
Documents: []schemas.RerankDocument{
|
||||
{Text: "Paris is the capital of France."},
|
||||
},
|
||||
})
|
||||
|
||||
require.Nil(t, response)
|
||||
require.NotNil(t, bifrostErr)
|
||||
require.NotNil(t, bifrostErr.Error)
|
||||
assert.Contains(t, bifrostErr.Error.Message, "requires an ARN")
|
||||
}
|
||||
3394
core/providers/bedrock/responses.go
Normal file
3394
core/providers/bedrock/responses.go
Normal file
File diff suppressed because it is too large
Load Diff
130
core/providers/bedrock/s3.go
Normal file
130
core/providers/bedrock/s3.go
Normal file
@@ -0,0 +1,130 @@
|
||||
package bedrock
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-sdk-go-v2/aws"
|
||||
"github.com/aws/aws-sdk-go-v2/config"
|
||||
"github.com/aws/aws-sdk-go-v2/credentials"
|
||||
"github.com/aws/aws-sdk-go-v2/service/s3"
|
||||
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// uploadToS3 uploads content to an S3 bucket using the provided credentials.
|
||||
func uploadToS3(
|
||||
ctx context.Context,
|
||||
accessKey, secretKey string,
|
||||
sessionToken *string,
|
||||
region string,
|
||||
bucket, key string,
|
||||
content []byte,
|
||||
) *schemas.BifrostError {
|
||||
// Create AWS config with credentials
|
||||
var cfg aws.Config
|
||||
var err error
|
||||
|
||||
if accessKey != "" && secretKey != "" {
|
||||
// Use provided credentials
|
||||
var creds aws.CredentialsProvider
|
||||
if sessionToken != nil && *sessionToken != "" {
|
||||
creds = credentials.NewStaticCredentialsProvider(accessKey, secretKey, *sessionToken)
|
||||
} else {
|
||||
creds = credentials.NewStaticCredentialsProvider(accessKey, secretKey, "")
|
||||
}
|
||||
|
||||
cfg, err = config.LoadDefaultConfig(ctx,
|
||||
config.WithRegion(region),
|
||||
config.WithCredentialsProvider(creds),
|
||||
)
|
||||
} else {
|
||||
// Use default credentials chain (IAM role, env vars, etc.)
|
||||
cfg, err = config.LoadDefaultConfig(ctx, config.WithRegion(region))
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return providerUtils.NewBifrostOperationError("failed to load aws config for s3", err)
|
||||
}
|
||||
|
||||
// Create S3 client
|
||||
client := s3.NewFromConfig(cfg)
|
||||
|
||||
// Upload the content
|
||||
_, err = client.PutObject(ctx, &s3.PutObjectInput{
|
||||
Bucket: aws.String(bucket),
|
||||
Key: aws.String(key),
|
||||
Body: bytes.NewReader(content),
|
||||
ContentType: aws.String("application/jsonl"),
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return providerUtils.NewBifrostOperationError(fmt.Sprintf("failed to upload to s3: %s/%s", bucket, key), err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// generateBatchInputS3Key generates a unique S3 key for batch input files.
// The key has the form "bifrost-batch-input/<jobName>-<unix nanoseconds>.jsonl",
// using the current time for the numeric part.
func generateBatchInputS3Key(jobName string) string {
	return fmt.Sprintf("bifrost-batch-input/%s-%d.jsonl", jobName, time.Now().UnixNano())
}
|
||||
|
||||
// deriveInputS3URIFromOutput derives an input S3 URI from the output S3 URI.
// It uses the same bucket but with a different path for input files: the
// result is "s3://<bucket>/<inputKey>", where bucket is the first value
// parseS3URI returns for outputS3URI.
func deriveInputS3URIFromOutput(outputS3URI, inputKey string) string {
	// Only the first return value is used; the second is deliberately discarded.
	// NOTE(review): parseS3URI is defined elsewhere — presumably it returns
	// (bucket, key) and an empty bucket for a malformed URI; confirm, since no
	// validation happens here.
	bucket, _ := parseS3URI(outputS3URI)
	return fmt.Sprintf("s3://%s/%s", bucket, inputKey)
}
|
||||
|
||||
// ConvertBedrockRequestsToJSONL converts batch request items to JSONL format for Bedrock.
|
||||
// Bedrock uses a specific format for batch inference requests.
|
||||
func ConvertBedrockRequestsToJSONL(requests []schemas.BatchRequestItem, modelID *string) ([]byte, error) {
|
||||
// Model ID is required for Bedrock batch JSONL conversion
|
||||
if modelID == nil || *modelID == "" {
|
||||
return nil, fmt.Errorf("modelID is required for Bedrock batch JSONL conversion")
|
||||
}
|
||||
// Initialize the buffer
|
||||
var buf bytes.Buffer
|
||||
|
||||
// Iterate over the requests
|
||||
for _, req := range requests {
|
||||
// Build the Bedrock batch request format
|
||||
bedrockReq := map[string]interface{}{
|
||||
"recordId": req.CustomID,
|
||||
"modelInput": map[string]interface{}{
|
||||
"modelId": *modelID,
|
||||
},
|
||||
}
|
||||
|
||||
// If the request has a body, use it as the model input parameters
|
||||
if req.Body != nil {
|
||||
modelInput := bedrockReq["modelInput"].(map[string]interface{})
|
||||
for k, v := range req.Body {
|
||||
if k != "model" { // Don't override modelId
|
||||
modelInput[k] = v
|
||||
}
|
||||
}
|
||||
} else if req.Params != nil {
|
||||
modelInput := bedrockReq["modelInput"].(map[string]interface{})
|
||||
for k, v := range req.Params {
|
||||
if k != "model" {
|
||||
modelInput[k] = v
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Marshal the request as a JSON line
|
||||
line, err := providerUtils.MarshalSorted(bedrockReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal batch request item %s: %w", req.CustomID, err)
|
||||
}
|
||||
buf.Write(line)
|
||||
buf.WriteByte('\n')
|
||||
}
|
||||
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
433
core/providers/bedrock/signer.go
Normal file
433
core/providers/bedrock/signer.go
Normal file
@@ -0,0 +1,433 @@
|
||||
package bedrock
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-sdk-go-v2/config"
|
||||
"github.com/aws/smithy-go/encoding/httpbinding"
|
||||
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
|
||||
schemas "github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// Constants used when signing requests.
// NOTE(review): their call sites are outside this view; the comments below
// describe the values themselves.
const (
	signingAlgorithm = "AWS4-HMAC-SHA256"     // signature algorithm identifier
	amzDateKey       = "X-Amz-Date"           // header name for the request timestamp
	amzSecurityToken = "X-Amz-Security-Token" // header name for the session token
	timeFormat       = "20060102T150405Z"     // Go time layout: compact date and time with a literal "Z"
	shortTimeFormat  = "20060102"             // Go time layout: date only (YYYYMMDD)
)
|
||||
|
||||
// ignoredHeaders is the set of headers to ignore during signing.
// Keys are written in lower case.
// NOTE(review): lookups presumably lower-case the header name first — confirm at the call site.
var ignoredHeaders = map[string]struct{}{
	"authorization":     {},
	"user-agent":        {},
	"x-amzn-trace-id":   {},
	"expect":            {},
	"transfer-encoding": {},
}
|
||||
|
||||
// signingKeyCache caches derived signing keys to avoid recomputation.
// Access to cache is guarded by mu.
type signingKeyCache struct {
	cache map[string]cachedKey // keyed by "accessKey/dateStamp/region/service"
	mu    sync.RWMutex         // read lock for lookups, write lock for inserts
}
|
||||
|
||||
// cachedKey is one signingKeyCache entry: a derived signing key together with
// the date and access key it was derived for, which are re-checked on lookup.
type cachedKey struct {
	key       []byte
	date      string // YYYYMMDD format
	accessKey string
}
|
||||
|
||||
// keyCache is the package-level signing key cache, created with an empty map
// so entries can be inserted without further initialization.
var keyCache = &signingKeyCache{
	cache: make(map[string]cachedKey),
}
|
||||
|
||||
// hmacSHA256 computes HMAC-SHA256 of data under key and returns the raw
// 32-byte digest.
func hmacSHA256(key, data []byte) []byte {
	mac := hmac.New(sha256.New, key)
	mac.Write(data)
	digest := mac.Sum(nil)
	return digest
}
|
||||
|
||||
// deriveSigningKey derives the AWS signing key
|
||||
func deriveSigningKey(secret, dateStamp, region, service string) []byte {
|
||||
kDate := hmacSHA256([]byte("AWS4"+secret), []byte(dateStamp))
|
||||
kRegion := hmacSHA256(kDate, []byte(region))
|
||||
kService := hmacSHA256(kRegion, []byte(service))
|
||||
kSigning := hmacSHA256(kService, []byte("aws4_request"))
|
||||
return kSigning
|
||||
}
|
||||
|
||||
// getSigningKey retrieves or computes the signing key with caching
|
||||
func getSigningKey(accessKey, secretKey, dateStamp, region, service string) []byte {
|
||||
cacheKey := fmt.Sprintf("%s/%s/%s/%s", accessKey, dateStamp, region, service)
|
||||
|
||||
keyCache.mu.RLock()
|
||||
if cached, ok := keyCache.cache[cacheKey]; ok && cached.accessKey == accessKey && cached.date == dateStamp {
|
||||
keyCache.mu.RUnlock()
|
||||
return cached.key
|
||||
}
|
||||
keyCache.mu.RUnlock()
|
||||
|
||||
keyCache.mu.Lock()
|
||||
defer keyCache.mu.Unlock()
|
||||
|
||||
// Double-check after acquiring write lock
|
||||
if cached, ok := keyCache.cache[cacheKey]; ok && cached.accessKey == accessKey && cached.date == dateStamp {
|
||||
return cached.key
|
||||
}
|
||||
|
||||
key := deriveSigningKey(secretKey, dateStamp, region, service)
|
||||
keyCache.cache[cacheKey] = cachedKey{
|
||||
key: key,
|
||||
date: dateStamp,
|
||||
accessKey: accessKey,
|
||||
}
|
||||
|
||||
return key
|
||||
}
|
||||
|
||||
// stripExcessSpaces removes excess spaces from a string: surrounding
// whitespace is trimmed and every run of consecutive ' ' characters is
// collapsed to a single space. Other whitespace inside the string is kept.
func stripExcessSpaces(str string) string {
	trimmed := strings.TrimSpace(str)
	if !strings.Contains(trimmed, "  ") {
		// No run of two or more spaces: nothing to collapse.
		return trimmed
	}

	var collapsed strings.Builder
	collapsed.Grow(len(trimmed))

	lastWasSpace := false
	for _, r := range trimmed {
		isSpace := r == ' '
		if isSpace && lastWasSpace {
			continue // drop every space after the first in a run
		}
		collapsed.WriteRune(r)
		lastWasSpace = isSpace
	}

	return collapsed.String()
}
|
||||
|
||||
// percentEncodeRFC3986 encodes a string per RFC 3986.
// Unreserved characters (A-Z, a-z, 0-9, '-', '_', '.', '~') are copied as-is;
// every other byte is written as %HH with uppercase hex digits.
func percentEncodeRFC3986(s string) string {
	const upperHexDigits = "0123456789ABCDEF"

	var encoded strings.Builder
	encoded.Grow(len(s))

	for i := 0; i < len(s); i++ {
		c := s[i]
		switch {
		case c >= 'A' && c <= 'Z',
			c >= 'a' && c <= 'z',
			c >= '0' && c <= '9',
			c == '-', c == '_', c == '.', c == '~':
			// RFC 3986 unreserved character.
			encoded.WriteByte(c)
		default:
			encoded.WriteByte('%')
			encoded.WriteByte(upperHexDigits[c>>4])
			encoded.WriteByte(upperHexDigits[c&0x0F])
		}
	}

	return encoded.String()
}
|
||||
|
||||
// uppercaseHex maps a nibble (0-15) to its uppercase hexadecimal digit:
// 0-9 become '0'-'9' and 10-15 become 'A'-'F'.
func uppercaseHex(b byte) byte {
	if b >= 10 {
		return b - 10 + 'A'
	}
	return b + '0'
}
|
||||
|
||||
// percentDecode decodes percent-encoded sequences in a string without treating + as space
|
||||
// This differs from url.QueryUnescape which uses form encoding (+ becomes space)
|
||||
func percentDecode(s string) string {
|
||||
// Quick check if there are any percent signs
|
||||
if !strings.Contains(s, "%") {
|
||||
return s
|
||||
}
|
||||
|
||||
var result strings.Builder
|
||||
result.Grow(len(s))
|
||||
|
||||
for i := 0; i < len(s); {
|
||||
if s[i] == '%' && i+2 < len(s) {
|
||||
// Try to decode the hex sequence
|
||||
if h1 := unhex(s[i+1]); h1 >= 0 {
|
||||
if h2 := unhex(s[i+2]); h2 >= 0 {
|
||||
result.WriteByte(byte(h1<<4 | h2))
|
||||
i += 3
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
result.WriteByte(s[i])
|
||||
i++
|
||||
}
|
||||
|
||||
return result.String()
|
||||
}
|
||||
|
||||
// unhex returns the numeric value (0-15) of the hexadecimal digit c, accepting
// both lowercase and uppercase letters, or -1 when c is not a hex digit.
func unhex(c byte) int {
	if c >= '0' && c <= '9' {
		return int(c - '0')
	}
	if c >= 'a' && c <= 'f' {
		return int(c-'a') + 10
	}
	if c >= 'A' && c <= 'F' {
		return int(c-'A') + 10
	}
	return -1
}
|
||||
|
||||
// queryPair represents a query parameter name-value pair.
// Both fields hold the already percent-encoded form, which is what
// buildCanonicalQueryString sorts on and writes out.
type queryPair struct {
	encodedName  string // percent-encoded parameter name
	encodedValue string // percent-encoded parameter value ("" when the parameter had no '=')
}
|
||||
|
||||
// buildCanonicalQueryString builds a canonical query string per AWS SigV4 spec
|
||||
// using proper RFC 3986 percent-encoding
|
||||
func buildCanonicalQueryString(queryString string) string {
|
||||
if queryString == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Split the raw query string on '&' into pairs
|
||||
rawPairs := strings.Split(queryString, "&")
|
||||
pairs := make([]queryPair, 0, len(rawPairs))
|
||||
|
||||
for _, rawPair := range rawPairs {
|
||||
if rawPair == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Split on the first '=' to get name and value
|
||||
var name, value string
|
||||
if idx := strings.IndexByte(rawPair, '='); idx >= 0 {
|
||||
name = rawPair[:idx]
|
||||
value = rawPair[idx+1:]
|
||||
} else {
|
||||
// No '=' means name only, empty value
|
||||
name = rawPair
|
||||
value = ""
|
||||
}
|
||||
|
||||
// Decode percent-encoded sequences first to normalize (handles already-encoded values)
|
||||
// then encode per RFC 3986 to ensure consistent encoding
|
||||
// Note: We use percentDecode instead of url.QueryUnescape because the latter
|
||||
// treats + as space (form encoding), but we need + to encode as %2B
|
||||
decodedName := percentDecode(name)
|
||||
decodedValue := percentDecode(value)
|
||||
|
||||
// Percent-encode name and value per RFC 3986
|
||||
encodedName := percentEncodeRFC3986(decodedName)
|
||||
encodedValue := percentEncodeRFC3986(decodedValue)
|
||||
|
||||
pairs = append(pairs, queryPair{
|
||||
encodedName: encodedName,
|
||||
encodedValue: encodedValue,
|
||||
})
|
||||
}
|
||||
|
||||
// Sort pairs lexicographically by encoded name, then by encoded value
|
||||
sort.Slice(pairs, func(i, j int) bool {
|
||||
if pairs[i].encodedName != pairs[j].encodedName {
|
||||
return pairs[i].encodedName < pairs[j].encodedName
|
||||
}
|
||||
return pairs[i].encodedValue < pairs[j].encodedValue
|
||||
})
|
||||
|
||||
// Join encoded pairs with '&'
|
||||
var result strings.Builder
|
||||
for i, pair := range pairs {
|
||||
if i > 0 {
|
||||
result.WriteByte('&')
|
||||
}
|
||||
result.WriteString(pair.encodedName)
|
||||
result.WriteByte('=')
|
||||
result.WriteString(pair.encodedValue)
|
||||
}
|
||||
|
||||
return result.String()
|
||||
}
|
||||
|
||||
// signAWSRequestFastHTTP signs a fasthttp request using AWS Signature Version 4.
// This is a native implementation that avoids allocating http.Request.
//
// The request is modified in place: Content-Type, Accept, the amzDateKey
// header, optionally the amzSecurityToken header, and finally Authorization
// are set on req.
//
// Parameters:
//   - ctx: used only when credentials have to be loaded from the default AWS config.
//   - req: the request to sign; its URI, method and headers feed the canonical request.
//   - body: the bytes whose SHA-256 becomes the payload hash.
//   - accessKey, secretKey: static credentials; when BOTH are empty the default
//     AWS credential chain is consulted instead.
//   - sessionToken: optional session token; replaced by the resolved one when
//     credentials come from the default chain and carry a token.
//   - region, service: components of the credential scope.
//
// Returns nil on success, or a BifrostError when loading the AWS config or
// retrieving credentials fails.
func signAWSRequestFastHTTP(
	ctx context.Context,
	req *fasthttp.Request,
	body []byte,
	accessKey, secretKey string,
	sessionToken *string,
	region, service string,
) *schemas.BifrostError {
	// Get AWS credentials if not provided
	if accessKey == "" && secretKey == "" {
		cfg, err := config.LoadDefaultConfig(ctx, config.WithRegion(region))
		if err != nil {
			return providerUtils.NewBifrostOperationError("failed to load aws config", err)
		}
		creds, err := cfg.Credentials.Retrieve(ctx)
		if err != nil {
			return providerUtils.NewBifrostOperationError("failed to retrieve aws credentials", err)
		}
		accessKey = creds.AccessKeyID
		secretKey = creds.SecretAccessKey
		if creds.SessionToken != "" {
			// Copy into a local so the pointer does not reference the creds struct field.
			st := creds.SessionToken
			sessionToken = &st
		}
	}

	// Get current time (UTC); one instant feeds both the full timestamp and the date stamp
	now := time.Now().UTC()
	amzDate := now.Format(timeFormat)
	dateStamp := now.Format(shortTimeFormat)

	// Parse URI
	uri := req.URI()
	host := string(uri.Host())
	path := string(uri.Path())
	if path == "" {
		path = "/"
	}
	queryString := string(uri.QueryString())

	// Escape path for canonical URI (Bedrock doesn't disable escaping)
	canonicalURI := httpbinding.EscapePath(path, false)

	// Calculate payload hash
	hash := sha256.Sum256(body)
	payloadHash := hex.EncodeToString(hash[:])

	// Set required headers. This happens before the header collection below,
	// so these headers become part of the signed header set unless they are
	// listed in ignoredHeaders.
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Accept", "application/json")
	req.Header.Set(amzDateKey, amzDate)
	if sessionToken != nil && *sessionToken != "" {
		req.Header.Set(amzSecurityToken, *sessionToken)
	}

	// Build canonical headers
	var headerNames []string
	headerMap := make(map[string][]string)

	// Always include host
	headerNames = append(headerNames, "host")
	headerMap["host"] = []string{host}

	// Include content-length when the request header reports a non-negative length
	if cl := req.Header.ContentLength(); cl >= 0 {
		headerNames = append(headerNames, "content-length")
		headerMap["content-length"] = []string{strconv.Itoa(cl)}
	}

	// Collect other headers
	for key, value := range req.Header.All() {
		keyStr := strings.ToLower(string(key))

		// Skip ignored headers
		if _, ignore := ignoredHeaders[keyStr]; ignore {
			continue
		}

		// Skip if already handled
		if keyStr == "host" || keyStr == "content-length" {
			continue
		}

		// Record each name once; repeated headers accumulate their values.
		if _, exists := headerMap[keyStr]; !exists {
			headerNames = append(headerNames, keyStr)
		}
		headerMap[keyStr] = append(headerMap[keyStr], string(value))
	}

	// Sort header names
	sort.Strings(headerNames)

	// Build canonical headers string: "name:value1,value2\n" per header
	var canonicalHeaders strings.Builder
	for _, name := range headerNames {
		canonicalHeaders.WriteString(name)
		canonicalHeaders.WriteRune(':')

		values := headerMap[name]
		for i, v := range values {
			cleanedValue := stripExcessSpaces(v)
			canonicalHeaders.WriteString(cleanedValue)
			if i < len(values)-1 {
				canonicalHeaders.WriteRune(',')
			}
		}
		canonicalHeaders.WriteRune('\n')
	}

	signedHeaders := strings.Join(headerNames, ";")

	// Build canonical query string using RFC 3986 encoding
	canonicalQueryString := buildCanonicalQueryString(queryString)

	// Build canonical request. canonicalHeaders already ends in '\n', so the
	// join leaves an empty line between the header block and signedHeaders.
	canonicalRequest := strings.Join([]string{
		string(req.Header.Method()),
		canonicalURI,
		canonicalQueryString,
		canonicalHeaders.String(),
		signedHeaders,
		payloadHash,
	}, "\n")

	// Build credential scope
	credentialScope := strings.Join([]string{
		dateStamp,
		region,
		service,
		"aws4_request",
	}, "/")

	// Build string to sign
	canonicalRequestHash := sha256.Sum256([]byte(canonicalRequest))
	stringToSign := strings.Join([]string{
		signingAlgorithm,
		amzDate,
		credentialScope,
		hex.EncodeToString(canonicalRequestHash[:]),
	}, "\n")

	// Calculate signature
	signingKey := getSigningKey(accessKey, secretKey, dateStamp, region, service)
	signature := hex.EncodeToString(hmacSHA256(signingKey, []byte(stringToSign)))

	// Build authorization header
	authHeader := fmt.Sprintf("%s Credential=%s/%s, SignedHeaders=%s, Signature=%s",
		signingAlgorithm,
		accessKey,
		credentialScope,
		signedHeaders,
		signature,
	)

	req.Header.Set("Authorization", authHeader)

	return nil
}
|
||||
229
core/providers/bedrock/text.go
Normal file
229
core/providers/bedrock/text.go
Normal file
@@ -0,0 +1,229 @@
|
||||
package bedrock
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/maximhq/bifrost/core/providers/anthropic"
|
||||
"github.com/maximhq/bifrost/core/providers/utils"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// ToBedrockTextCompletionRequest converts a Bifrost text completion request to Bedrock format
|
||||
func ToBedrockTextCompletionRequest(bifrostReq *schemas.BifrostTextCompletionRequest) *BedrockTextCompletionRequest {
|
||||
if bifrostReq == nil || (bifrostReq.Input.PromptStr == nil && len(bifrostReq.Input.PromptArray) == 0) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Extract the raw prompt from bifrostReq
|
||||
prompt := ""
|
||||
if bifrostReq.Input != nil {
|
||||
if bifrostReq.Input.PromptStr != nil {
|
||||
prompt = *bifrostReq.Input.PromptStr
|
||||
} else if len(bifrostReq.Input.PromptArray) > 0 && bifrostReq.Input.PromptArray != nil {
|
||||
prompt = strings.Join(bifrostReq.Input.PromptArray, "\n\n")
|
||||
}
|
||||
}
|
||||
|
||||
bedrockReq := &BedrockTextCompletionRequest{
|
||||
Prompt: prompt,
|
||||
}
|
||||
|
||||
// Apply parameters
|
||||
if bifrostReq.Params != nil {
|
||||
bedrockReq.Temperature = bifrostReq.Params.Temperature
|
||||
bedrockReq.TopP = bifrostReq.Params.TopP
|
||||
|
||||
if bifrostReq.Params.ExtraParams != nil {
|
||||
bedrockReq.ExtraParams = bifrostReq.Params.ExtraParams
|
||||
if topK, ok := schemas.SafeExtractIntPointer(bifrostReq.Params.ExtraParams["top_k"]); ok {
|
||||
delete(bedrockReq.ExtraParams, "top_k")
|
||||
bedrockReq.TopK = topK
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Apply model-specific formatting and field naming
|
||||
if strings.Contains(bifrostReq.Model, "anthropic.") || strings.Contains(bifrostReq.Model, "claude") {
|
||||
// For Claude models, wrap the prompt in Anthropic format and use Anthropic field names
|
||||
anthropicReq := anthropic.ToAnthropicTextCompletionRequest(bifrostReq)
|
||||
bedrockReq.Prompt = anthropicReq.Prompt
|
||||
bedrockReq.MaxTokensToSample = &anthropicReq.MaxTokensToSample
|
||||
bedrockReq.StopSequences = anthropicReq.StopSequences
|
||||
} else {
|
||||
// For other models, use standard field names with raw prompt
|
||||
if bifrostReq.Params != nil {
|
||||
bedrockReq.MaxTokens = bifrostReq.Params.MaxTokens
|
||||
bedrockReq.Stop = bifrostReq.Params.Stop
|
||||
}
|
||||
}
|
||||
|
||||
return bedrockReq
|
||||
}
|
||||
|
||||
// ToBifrostTextCompletionRequest converts a Bedrock text completion request to Bifrost format
|
||||
func (request *BedrockTextCompletionRequest) ToBifrostTextCompletionRequest(ctx *schemas.BifrostContext) *schemas.BifrostTextCompletionRequest {
|
||||
if request == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
prompt := request.Prompt
|
||||
// Fallback for Claude 3 Messages API
|
||||
if prompt == "" && len(request.Messages) > 0 {
|
||||
var parts []string
|
||||
for _, msg := range request.Messages {
|
||||
for _, content := range msg.Content {
|
||||
if content.Text != nil {
|
||||
parts = append(parts, *content.Text)
|
||||
}
|
||||
}
|
||||
}
|
||||
prompt = strings.Join(parts, "\n\n")
|
||||
}
|
||||
|
||||
provider, model := schemas.ParseModelString(request.ModelID, utils.CheckAndSetDefaultProvider(ctx, schemas.Bedrock))
|
||||
|
||||
bifrostReq := &schemas.BifrostTextCompletionRequest{
|
||||
Provider: provider,
|
||||
Model: model,
|
||||
Input: &schemas.TextCompletionInput{
|
||||
PromptStr: &prompt,
|
||||
},
|
||||
Params: &schemas.TextCompletionParameters{
|
||||
Temperature: request.Temperature,
|
||||
TopP: request.TopP,
|
||||
},
|
||||
}
|
||||
|
||||
if request.MaxTokens != nil {
|
||||
bifrostReq.Params.MaxTokens = request.MaxTokens
|
||||
} else if request.MaxTokensToSample != nil {
|
||||
bifrostReq.Params.MaxTokens = request.MaxTokensToSample
|
||||
}
|
||||
|
||||
if len(request.Stop) > 0 {
|
||||
bifrostReq.Params.Stop = request.Stop
|
||||
} else if len(request.StopSequences) > 0 {
|
||||
bifrostReq.Params.Stop = request.StopSequences
|
||||
}
|
||||
|
||||
return bifrostReq
|
||||
}
|
||||
|
||||
// ToBifrostTextCompletionResponse converts a Bedrock Anthropic text response to Bifrost format
|
||||
func (response *BedrockAnthropicTextResponse) ToBifrostTextCompletionResponse() *schemas.BifrostTextCompletionResponse {
|
||||
if response == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &schemas.BifrostTextCompletionResponse{
|
||||
Object: "text_completion",
|
||||
Choices: []schemas.BifrostResponseChoice{
|
||||
{
|
||||
Index: 0,
|
||||
TextCompletionResponseChoice: &schemas.TextCompletionResponseChoice{
|
||||
Text: &response.Completion,
|
||||
},
|
||||
FinishReason: &response.StopReason,
|
||||
},
|
||||
},
|
||||
ExtraFields: schemas.BifrostResponseExtraFields{
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// ToBifrostTextCompletionResponse converts a Bedrock Mistral text response to Bifrost format
|
||||
func (response *BedrockMistralTextResponse) ToBifrostTextCompletionResponse() *schemas.BifrostTextCompletionResponse {
|
||||
if response == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var choices []schemas.BifrostResponseChoice
|
||||
for i, output := range response.Outputs {
|
||||
choices = append(choices, schemas.BifrostResponseChoice{
|
||||
Index: i,
|
||||
TextCompletionResponseChoice: &schemas.TextCompletionResponseChoice{
|
||||
Text: &output.Text,
|
||||
},
|
||||
FinishReason: &output.StopReason,
|
||||
})
|
||||
}
|
||||
|
||||
return &schemas.BifrostTextCompletionResponse{
|
||||
Object: "text_completion",
|
||||
Choices: choices,
|
||||
ExtraFields: schemas.BifrostResponseExtraFields{
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// ToBedrockTextCompletionResponse converts a BifrostTextCompletionResponse back to Bedrock text completion format
|
||||
// Returns either *BedrockAnthropicTextResponse or *BedrockMistralTextResponse based on the model
|
||||
func ToBedrockTextCompletionResponse(bifrostResp *schemas.BifrostTextCompletionResponse) interface{} {
|
||||
if bifrostResp == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Determine response format based on resolved model identity.
|
||||
// Use ResolvedModelUsed (actual provider ID) for accurate family detection,
|
||||
// falling back to bifrostResp.Model, then OriginalModelRequested as a last resort.
|
||||
model := bifrostResp.Model
|
||||
if bifrostResp.ExtraFields.ResolvedModelUsed != "" {
|
||||
model = bifrostResp.ExtraFields.ResolvedModelUsed
|
||||
} else if model == "" && bifrostResp.ExtraFields.OriginalModelRequested != "" {
|
||||
model = bifrostResp.ExtraFields.OriginalModelRequested
|
||||
}
|
||||
|
||||
if strings.Contains(model, "anthropic.") || strings.Contains(model, "claude") {
|
||||
// Convert to Anthropic format
|
||||
bedrockResp := &BedrockAnthropicTextResponse{}
|
||||
|
||||
// Convert choices to completion text
|
||||
if len(bifrostResp.Choices) > 0 {
|
||||
choice := bifrostResp.Choices[0] // Anthropic text API typically returns one choice
|
||||
if choice.TextCompletionResponseChoice != nil && choice.TextCompletionResponseChoice.Text != nil {
|
||||
bedrockResp.Completion = *choice.TextCompletionResponseChoice.Text
|
||||
}
|
||||
if choice.FinishReason != nil {
|
||||
bedrockResp.StopReason = *choice.FinishReason
|
||||
}
|
||||
}
|
||||
|
||||
return bedrockResp
|
||||
} else if strings.Contains(model, "mistral.") {
|
||||
// Convert to Mistral format
|
||||
bedrockResp := &BedrockMistralTextResponse{}
|
||||
|
||||
// Convert choices to outputs
|
||||
for _, choice := range bifrostResp.Choices {
|
||||
var output struct {
|
||||
Text string `json:"text"`
|
||||
StopReason string `json:"stop_reason"`
|
||||
}
|
||||
|
||||
if choice.TextCompletionResponseChoice != nil && choice.TextCompletionResponseChoice.Text != nil {
|
||||
output.Text = *choice.TextCompletionResponseChoice.Text
|
||||
}
|
||||
if choice.FinishReason != nil {
|
||||
output.StopReason = *choice.FinishReason
|
||||
}
|
||||
|
||||
bedrockResp.Outputs = append(bedrockResp.Outputs, output)
|
||||
}
|
||||
|
||||
return bedrockResp
|
||||
}
|
||||
|
||||
// Default to Anthropic format if model type cannot be determined
|
||||
bedrockResp := &BedrockAnthropicTextResponse{}
|
||||
if len(bifrostResp.Choices) > 0 {
|
||||
choice := bifrostResp.Choices[0]
|
||||
if choice.TextCompletionResponseChoice != nil && choice.TextCompletionResponseChoice.Text != nil {
|
||||
bedrockResp.Completion = *choice.TextCompletionResponseChoice.Text
|
||||
}
|
||||
if choice.FinishReason != nil {
|
||||
bedrockResp.StopReason = *choice.FinishReason
|
||||
}
|
||||
}
|
||||
|
||||
return bedrockResp
|
||||
}
|
||||
714
core/providers/bedrock/transport_test.go
Normal file
714
core/providers/bedrock/transport_test.go
Normal file
@@ -0,0 +1,714 @@
|
||||
package bedrock
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"io"
|
||||
"math/big"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// redirectTransport is an http.RoundTripper that rewrites every request's
// host/scheme to a fixed target URL, used to redirect provider requests to a
// local httptest.Server without modifying provider code.
type redirectTransport struct {
	target    *url.URL          // destination whose Scheme and Host replace the request's
	transport http.RoundTripper // underlying transport that performs the rewritten request
}
|
||||
|
||||
func (r *redirectTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
cloned := req.Clone(req.Context())
|
||||
cloned.URL.Scheme = r.target.Scheme
|
||||
cloned.URL.Host = r.target.Host
|
||||
cloned.Host = r.target.Host
|
||||
return r.transport.RoundTrip(cloned)
|
||||
}
|
||||
|
||||
// noopLogger is a no-op schemas.Logger for use in tests.
// Every method discards its arguments.
type noopLogger struct{}

// Level-specific logging methods: all intentionally empty.
func (noopLogger) Debug(string, ...any) {}
func (noopLogger) Info(string, ...any)  {}
func (noopLogger) Warn(string, ...any)  {}
func (noopLogger) Error(string, ...any) {}
func (noopLogger) Fatal(string, ...any) {}

// Configuration setters: accepted and ignored.
func (noopLogger) SetLevel(schemas.LogLevel)              {}
func (noopLogger) SetOutputType(schemas.LoggerOutputType) {}

// LogHTTPRequest returns the shared no-op event builder.
func (noopLogger) LogHTTPRequest(schemas.LogLevel, string) schemas.LogEventBuilder {
	return schemas.NoopLogEvent
}
|
||||
|
||||
// newTestProviderWithServer returns a BedrockProvider whose HTTP client is
|
||||
// redirected to the given httptest.Server.
|
||||
func newTestProviderWithServer(t *testing.T, ts *httptest.Server) *BedrockProvider {
|
||||
t.Helper()
|
||||
config := &schemas.ProviderConfig{
|
||||
NetworkConfig: schemas.NetworkConfig{
|
||||
DefaultRequestTimeoutInSeconds: 5,
|
||||
},
|
||||
}
|
||||
config.CheckAndSetDefaults()
|
||||
provider, err := NewBedrockProvider(config, noopLogger{})
|
||||
require.NoError(t, err)
|
||||
|
||||
targetURL, err := url.Parse(ts.URL)
|
||||
require.NoError(t, err)
|
||||
|
||||
redirect := &redirectTransport{
|
||||
target: targetURL,
|
||||
transport: ts.Client().Transport,
|
||||
}
|
||||
provider.client = &http.Client{
|
||||
Transport: redirect,
|
||||
Timeout: 5 * time.Second,
|
||||
}
|
||||
// Streaming paths use streamingClient (no Timeout); redirect it to the
|
||||
// test server too, otherwise Bedrock streaming tests would hit the real
|
||||
// AWS endpoint.
|
||||
provider.streamingClient = &http.Client{Transport: redirect}
|
||||
return provider
|
||||
}
|
||||
|
||||
// testBedrockKey returns a minimal Key with a bearer value so makeStreamingRequest
|
||||
// skips IAM signing and proceeds to the HTTP call.
|
||||
func testBedrockKey() schemas.Key {
|
||||
region := schemas.NewEnvVar("us-east-1")
|
||||
return schemas.Key{
|
||||
Value: *schemas.NewEnvVar("test-api-key"),
|
||||
BedrockKeyConfig: &schemas.BedrockKeyConfig{
|
||||
Region: region,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// testBedrockCtx returns a BifrostContext suitable for unit tests: it wraps
// context.Background() and is created with schemas.NoDeadline.
func testBedrockCtx() *schemas.BifrostContext {
	return schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
}
|
||||
|
||||
// noopPostHookRunner is a PostHookRunner that passes through results unchanged:
// the context is ignored and result and err are returned exactly as received.
func noopPostHookRunner(_ *schemas.BifrostContext, result *schemas.BifrostResponse, err *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError) {
	return result, err
}
|
||||
|
||||
// testChatRequest returns a minimal BifrostChatRequest for streaming tests.
|
||||
func testChatRequest() *schemas.BifrostChatRequest {
|
||||
content := "hello"
|
||||
return &schemas.BifrostChatRequest{
|
||||
Model: "anthropic.claude-sonnet-4-5",
|
||||
Input: []schemas.ChatMessage{
|
||||
{
|
||||
Role: schemas.ChatMessageRoleUser,
|
||||
Content: &schemas.ChatMessageContent{ContentStr: &content},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// TestMakeStreamingRequest_StaleConnection_IsRetryable verifies that when the
// HTTP server closes the connection before sending a response (simulating a
// stale HTTP/2 connection), makeStreamingRequest returns a BifrostError with
// IsBifrostError:false so the retry gate in executeRequestWithRetries retries.
func TestMakeStreamingRequest_StaleConnection_IsRetryable(t *testing.T) {
	// Server that immediately closes the connection without sending anything.
	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		hj, ok := w.(http.Hijacker)
		if !ok {
			http.Error(w, "hijack not supported", http.StatusInternalServerError)
			return
		}
		// Take over the raw connection so it can be dropped with no HTTP response.
		conn, _, _ := hj.Hijack()
		conn.Close() // close without writing any response
	}))
	defer ts.Close()

	provider := newTestProviderWithServer(t, ts)
	ctx := testBedrockCtx()
	key := testBedrockKey()

	_, bifrostErr := provider.makeStreamingRequest(ctx, []byte(`{}`), key, "anthropic.claude-sonnet-4-5", "converse-stream")

	require.NotNil(t, bifrostErr, "expected error when server closes connection")
	assert.False(t, bifrostErr.IsBifrostError,
		"stale-connection error must be IsBifrostError:false so the retry gate can retry it")
	require.NotNil(t, bifrostErr.Error)
	// Either ErrProviderNetworkError (net.OpError) or ErrProviderDoRequest (EOF/connection-reset)
	// are both retryable — the key invariant is IsBifrostError:false.
	assert.Contains(t, []string{schemas.ErrProviderNetworkError, schemas.ErrProviderDoRequest}, bifrostErr.Error.Message,
		"stale-connection error must use a retryable error message")
}
|
||||
|
||||
// TestChatCompletionStream_StaleConnection_ChunkIsRetryable verifies that when
// the server returns HTTP 200 but closes the body immediately (simulating a
// stale connection mid-stream before any EventStream data arrives), the first
// chunk received from the stream channel carries a BifrostError with
// IsBifrostError:false so that CheckFirstStreamChunkForError + the retry gate
// can transparently retry the request.
func TestChatCompletionStream_StaleConnection_ChunkIsRetryable(t *testing.T) {
	// Server: returns 200 with the correct content-type but closes body immediately.
	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/vnd.amazon.eventstream")
		w.WriteHeader(http.StatusOK)
		// Flush so the status line and headers are sent before the connection is dropped.
		if f, ok := w.(http.Flusher); ok {
			f.Flush()
		}
		hj, ok := w.(http.Hijacker)
		if !ok {
			return
		}
		conn, _, _ := hj.Hijack()
		conn.Close() // close without any EventStream bytes
	}))
	defer ts.Close()

	provider := newTestProviderWithServer(t, ts)
	ctx := testBedrockCtx()
	key := testBedrockKey()

	streamChan, bifrostErr := provider.ChatCompletionStream(ctx, noopPostHookRunner, nil, key, testChatRequest())

	if bifrostErr != nil {
		// Error surfaced synchronously (e.g. connection refused before HTTP 200).
		assert.False(t, bifrostErr.IsBifrostError,
			"pre-stream network error must be IsBifrostError:false")
		return
	}

	// Error surfaced as the first stream chunk.
	require.NotNil(t, streamChan)
	chunk, ok := <-streamChan
	require.True(t, ok, "channel must not be empty")
	require.NotNil(t, chunk)
	require.NotNil(t, chunk.BifrostError, "expected an error chunk from the stream")

	assert.False(t, chunk.BifrostError.IsBifrostError,
		"stream transport error must be IsBifrostError:false so the retry gate can retry it")
	require.NotNil(t, chunk.BifrostError.Error)
	assert.Equal(t, schemas.ErrProviderNetworkError, chunk.BifrostError.Error.Message,
		"stream transport error must use ErrProviderNetworkError message")

	// Drain any remaining chunks.
	for range streamChan {
	}
}
|
||||
|
||||
// TestChatCompletionStream_NetOpError_ChunkIsRetryable verifies the specific
// "use of closed network connection" *net.OpError scenario from issue #2424:
// a successful HTTP connection that is then closed server-side produces a
// *net.OpError during EventStream decoding, which must arrive as a retryable
// IsBifrostError:false chunk.
func TestChatCompletionStream_NetOpError_ChunkIsRetryable(t *testing.T) {
	// Server: returns 200 + correct headers, writes a truncated EventStream
	// prelude (not a valid frame), then forcibly resets the TCP connection —
	// producing a *net.OpError on the client's read side.
	ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/vnd.amazon.eventstream")
		w.WriteHeader(http.StatusOK)
		if f, ok := w.(http.Flusher); ok {
			f.Flush()
		}
		// Write a partial EventStream frame header (3 bytes, not a valid frame).
		_, _ = w.Write([]byte{0x00, 0x00, 0x00})
		if f, ok := w.(http.Flusher); ok {
			f.Flush()
		}
		hj, ok := w.(http.Hijacker)
		if !ok {
			return
		}
		conn, _, _ := hj.Hijack()
		// RST instead of FIN — guarantees a *net.OpError on the client read.
		if tc, ok := conn.(*net.TCPConn); ok {
			_ = tc.SetLinger(0)
		}
		conn.Close()
	}))
	ts.Start()
	defer ts.Close()

	provider := newTestProviderWithServer(t, ts)
	ctx := testBedrockCtx()
	key := testBedrockKey()

	streamChan, bifrostErr := provider.ChatCompletionStream(ctx, noopPostHookRunner, nil, key, testChatRequest())
	if bifrostErr != nil {
		// Error surfaced synchronously, before a stream channel was handed back.
		assert.False(t, bifrostErr.IsBifrostError,
			"pre-stream network error must be IsBifrostError:false")
		return
	}

	require.NotNil(t, streamChan)

	// Collect chunks until we find an error chunk (may not be the very first
	// if the OS buffers the partial write, but it must appear before close).
	var errChunk *schemas.BifrostStreamChunk
	for chunk := range streamChan {
		if chunk != nil && chunk.BifrostError != nil {
			errChunk = chunk
			break
		}
	}
	// Drain remaining.
	for range streamChan {
	}

	require.NotNil(t, errChunk, "expected an error chunk from the stream")
	assert.False(t, errChunk.BifrostError.IsBifrostError,
		"net.OpError during EventStream decoding must be IsBifrostError:false so the retry gate can retry it")
	require.NotNil(t, errChunk.BifrostError.Error)
	assert.Equal(t, schemas.ErrProviderNetworkError, errChunk.BifrostError.Error.Message,
		"net.OpError during EventStream decoding must use ErrProviderNetworkError message")
}
|
||||
|
||||
// writeEventStreamException encodes a well-formed AWS EventStream exception
|
||||
// frame with the given exception type and message into w.
|
||||
// The frame format is: prelude (total_len + headers_len + CRC) + headers + payload + message_CRC.
|
||||
// We use the AWS SDK's eventstream.Encoder so the binary framing is correct.
|
||||
func writeEventStreamException(t *testing.T, w io.Writer, excType, msg string) {
|
||||
t.Helper()
|
||||
enc := eventstream.NewEncoder()
|
||||
payload, err := json.Marshal(map[string]string{"message": msg})
|
||||
require.NoError(t, err, "failed to marshal exception payload")
|
||||
headers := eventstream.Headers{
|
||||
{Name: ":message-type", Value: eventstream.StringValue("exception")},
|
||||
{Name: ":exception-type", Value: eventstream.StringValue(excType)},
|
||||
{Name: ":content-type", Value: eventstream.StringValue("application/json")},
|
||||
}
|
||||
err = enc.Encode(w, eventstream.Message{Headers: headers, Payload: payload})
|
||||
require.NoError(t, err, "failed to encode EventStream exception frame")
|
||||
}
|
||||
|
||||
// TestChatCompletionStream_RetryableException_ChunkIsRetryable verifies that
|
||||
// when AWS Bedrock sends a retryable exception (serviceUnavailableException,
|
||||
// throttlingException, etc.) through the EventStream, the resulting error chunk
|
||||
// has IsBifrostError:false and the correct HTTP StatusCode so that the retry
|
||||
// gate in executeRequestWithRetries can retry the request.
|
||||
func TestChatCompletionStream_RetryableException_ChunkIsRetryable(t *testing.T) {
|
||||
tests := []struct {
|
||||
excType string
|
||||
expectedStatus int
|
||||
}{
|
||||
{"serviceUnavailableException", 503},
|
||||
{"throttlingException", 429},
|
||||
{"modelNotReadyException", 503},
|
||||
{"internalServerException", 500},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.excType, func(t *testing.T) {
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/vnd.amazon.eventstream")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
if f, ok := w.(http.Flusher); ok {
|
||||
f.Flush()
|
||||
}
|
||||
writeEventStreamException(t, w, tc.excType, "service is unavailable, please retry")
|
||||
if f, ok := w.(http.Flusher); ok {
|
||||
f.Flush()
|
||||
}
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
provider := newTestProviderWithServer(t, ts)
|
||||
ctx := testBedrockCtx()
|
||||
key := testBedrockKey()
|
||||
|
||||
streamChan, bifrostErr := provider.ChatCompletionStream(ctx, noopPostHookRunner, nil, key, testChatRequest())
|
||||
require.Nil(t, bifrostErr, "expected EventStream exception to surface as a stream chunk")
|
||||
|
||||
require.NotNil(t, streamChan)
|
||||
|
||||
var errChunk *schemas.BifrostStreamChunk
|
||||
for chunk := range streamChan {
|
||||
if chunk != nil && chunk.BifrostError != nil {
|
||||
errChunk = chunk
|
||||
break
|
||||
}
|
||||
}
|
||||
for range streamChan {
|
||||
}
|
||||
|
||||
require.NotNil(t, errChunk, "expected error chunk for %s", tc.excType)
|
||||
assert.False(t, errChunk.BifrostError.IsBifrostError,
|
||||
"%s must be IsBifrostError:false so retry gate can retry it", tc.excType)
|
||||
require.NotNil(t, errChunk.BifrostError.StatusCode,
|
||||
"%s must carry a StatusCode for the retryableStatusCodes gate", tc.excType)
|
||||
assert.Equal(t, tc.expectedStatus, *errChunk.BifrostError.StatusCode,
|
||||
"%s must map to HTTP %d", tc.excType, tc.expectedStatus)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestChatCompletionStream_NonRetryableException_IsTerminal verifies that
|
||||
// non-retryable exception types (e.g. validationException, accessDeniedException)
|
||||
// continue to use ProcessAndSendError (IsBifrostError:true) and are NOT retried.
|
||||
func TestChatCompletionStream_NonRetryableException_IsTerminal(t *testing.T) {
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/vnd.amazon.eventstream")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
if f, ok := w.(http.Flusher); ok {
|
||||
f.Flush()
|
||||
}
|
||||
writeEventStreamException(t, w, "validationException", "input validation failed")
|
||||
if f, ok := w.(http.Flusher); ok {
|
||||
f.Flush()
|
||||
}
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
provider := newTestProviderWithServer(t, ts)
|
||||
ctx := testBedrockCtx()
|
||||
key := testBedrockKey()
|
||||
|
||||
streamChan, bifrostErr := provider.ChatCompletionStream(ctx, noopPostHookRunner, nil, key, testChatRequest())
|
||||
require.Nil(t, bifrostErr, "expected EventStream exception to surface as a stream chunk")
|
||||
|
||||
require.NotNil(t, streamChan)
|
||||
|
||||
var errChunk *schemas.BifrostStreamChunk
|
||||
for chunk := range streamChan {
|
||||
if chunk != nil && chunk.BifrostError != nil {
|
||||
errChunk = chunk
|
||||
break
|
||||
}
|
||||
}
|
||||
for range streamChan {
|
||||
}
|
||||
|
||||
require.NotNil(t, errChunk, "expected error chunk for validationException")
|
||||
assert.True(t, errChunk.BifrostError.IsBifrostError,
|
||||
"non-retryable validationException must remain IsBifrostError:true")
|
||||
}
|
||||
|
||||
// testTextCompletionRequest returns a minimal BifrostTextCompletionRequest for streaming tests.
|
||||
func testTextCompletionRequest() *schemas.BifrostTextCompletionRequest {
|
||||
prompt := "hello"
|
||||
return &schemas.BifrostTextCompletionRequest{
|
||||
Model: "anthropic.claude-sonnet-4-5",
|
||||
Input: &schemas.TextCompletionInput{PromptStr: &prompt},
|
||||
}
|
||||
}
|
||||
|
||||
// testResponsesRequest returns a minimal BifrostResponsesRequest for streaming tests.
|
||||
func testResponsesRequest() *schemas.BifrostResponsesRequest {
|
||||
msgType := schemas.ResponsesMessageType("message")
|
||||
roleUser := schemas.ResponsesMessageRoleType("user")
|
||||
content := "hello"
|
||||
return &schemas.BifrostResponsesRequest{
|
||||
Model: "anthropic.claude-sonnet-4-5",
|
||||
Input: []schemas.ResponsesMessage{
|
||||
{
|
||||
Type: &msgType,
|
||||
Role: &roleUser,
|
||||
Content: &schemas.ResponsesMessageContent{ContentStr: &content},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// assertRetryableExceptionChunk is the shared assertion helper for all three
|
||||
// streaming-method retryable-exception tests.
|
||||
func assertRetryableExceptionChunk(t *testing.T, streamChan chan *schemas.BifrostStreamChunk, bifrostErr *schemas.BifrostError, excType string, expectedStatus int) {
|
||||
t.Helper()
|
||||
require.Nil(t, bifrostErr, "expected EventStream exception to surface as a stream chunk, not a pre-stream error")
|
||||
require.NotNil(t, streamChan)
|
||||
|
||||
var errChunk *schemas.BifrostStreamChunk
|
||||
for chunk := range streamChan {
|
||||
if chunk != nil && chunk.BifrostError != nil {
|
||||
errChunk = chunk
|
||||
break
|
||||
}
|
||||
}
|
||||
for range streamChan {
|
||||
}
|
||||
|
||||
require.NotNil(t, errChunk, "expected error chunk for %s", excType)
|
||||
assert.False(t, errChunk.BifrostError.IsBifrostError,
|
||||
"%s must be IsBifrostError:false so retry gate can retry it", excType)
|
||||
require.NotNil(t, errChunk.BifrostError.StatusCode,
|
||||
"%s must carry a StatusCode for the retryableStatusCodes gate", excType)
|
||||
assert.Equal(t, expectedStatus, *errChunk.BifrostError.StatusCode,
|
||||
"%s must map to HTTP %d", excType, expectedStatus)
|
||||
}
|
||||
|
||||
// TestTextCompletionStream_RetryableException_ChunkIsRetryable mirrors the
|
||||
// ChatCompletionStream test for the TextCompletionStream path, which has
|
||||
// slightly different payload-parsing logic (extra BedrockError JSON unmarshal).
|
||||
func TestTextCompletionStream_RetryableException_ChunkIsRetryable(t *testing.T) {
|
||||
tests := []struct {
|
||||
excType string
|
||||
expectedStatus int
|
||||
}{
|
||||
{"serviceUnavailableException", 503},
|
||||
{"throttlingException", 429},
|
||||
{"modelNotReadyException", 503},
|
||||
{"internalServerException", 500},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.excType, func(t *testing.T) {
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/vnd.amazon.eventstream")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
if f, ok := w.(http.Flusher); ok {
|
||||
f.Flush()
|
||||
}
|
||||
writeEventStreamException(t, w, tc.excType, "service is unavailable, please retry")
|
||||
if f, ok := w.(http.Flusher); ok {
|
||||
f.Flush()
|
||||
}
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
provider := newTestProviderWithServer(t, ts)
|
||||
streamChan, bifrostErr := provider.TextCompletionStream(testBedrockCtx(), noopPostHookRunner, nil, testBedrockKey(), testTextCompletionRequest())
|
||||
assertRetryableExceptionChunk(t, streamChan, bifrostErr, tc.excType, tc.expectedStatus)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestResponsesStream_RetryableException_ChunkIsRetryable mirrors the
|
||||
// ChatCompletionStream test for the ResponsesStream path.
|
||||
func TestResponsesStream_RetryableException_ChunkIsRetryable(t *testing.T) {
|
||||
tests := []struct {
|
||||
excType string
|
||||
expectedStatus int
|
||||
}{
|
||||
{"serviceUnavailableException", 503},
|
||||
{"throttlingException", 429},
|
||||
{"modelNotReadyException", 503},
|
||||
{"internalServerException", 500},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.excType, func(t *testing.T) {
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/vnd.amazon.eventstream")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
if f, ok := w.(http.Flusher); ok {
|
||||
f.Flush()
|
||||
}
|
||||
writeEventStreamException(t, w, tc.excType, "service is unavailable, please retry")
|
||||
if f, ok := w.(http.Flusher); ok {
|
||||
f.Flush()
|
||||
}
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
provider := newTestProviderWithServer(t, ts)
|
||||
streamChan, bifrostErr := provider.ResponsesStream(testBedrockCtx(), noopPostHookRunner, nil, testBedrockKey(), testResponsesRequest())
|
||||
assertRetryableExceptionChunk(t, streamChan, bifrostErr, tc.excType, tc.expectedStatus)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func generateTestCACert(t *testing.T) string {
|
||||
t.Helper()
|
||||
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
require.NoError(t, err)
|
||||
|
||||
template := &x509.Certificate{
|
||||
SerialNumber: big.NewInt(1),
|
||||
Subject: pkix.Name{CommonName: "testca"},
|
||||
NotBefore: time.Now(),
|
||||
NotAfter: time.Now().Add(10 * 365 * 24 * time.Hour),
|
||||
IsCA: true,
|
||||
BasicConstraintsValid: true,
|
||||
}
|
||||
|
||||
certDER, err := x509.CreateCertificate(rand.Reader, template, template, &key.PublicKey, key)
|
||||
require.NoError(t, err)
|
||||
|
||||
certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER})
|
||||
return string(certPEM)
|
||||
}
|
||||
|
||||
func TestBedrockTransportHTTP2Config(t *testing.T) {
|
||||
config := &schemas.ProviderConfig{
|
||||
NetworkConfig: schemas.NetworkConfig{
|
||||
DefaultRequestTimeoutInSeconds: 30,
|
||||
MaxConnsPerHost: 5000,
|
||||
EnforceHTTP2: true,
|
||||
},
|
||||
}
|
||||
config.CheckAndSetDefaults()
|
||||
|
||||
provider, err := NewBedrockProvider(config, nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, provider)
|
||||
|
||||
transport, ok := provider.client.Transport.(*http.Transport)
|
||||
require.True(t, ok, "transport should be *http.Transport")
|
||||
|
||||
assert.Equal(t, 5000, transport.MaxConnsPerHost)
|
||||
assert.Equal(t, schemas.DefaultMaxIdleConnsPerHost, transport.MaxIdleConnsPerHost)
|
||||
assert.Equal(t, schemas.DefaultMaxIdleConnsPerHost, transport.MaxIdleConns)
|
||||
assert.True(t, transport.ForceAttemptHTTP2)
|
||||
}
|
||||
|
||||
func TestBedrockTransportCustomMaxConns(t *testing.T) {
|
||||
config := &schemas.ProviderConfig{
|
||||
NetworkConfig: schemas.NetworkConfig{
|
||||
DefaultRequestTimeoutInSeconds: 30,
|
||||
MaxConnsPerHost: 50,
|
||||
},
|
||||
}
|
||||
config.CheckAndSetDefaults()
|
||||
|
||||
provider, err := NewBedrockProvider(config, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
transport, ok := provider.client.Transport.(*http.Transport)
|
||||
require.True(t, ok)
|
||||
|
||||
assert.Equal(t, 50, transport.MaxConnsPerHost)
|
||||
assert.Equal(t, schemas.DefaultMaxIdleConnsPerHost, transport.MaxIdleConnsPerHost)
|
||||
assert.Equal(t, schemas.DefaultMaxIdleConnsPerHost, transport.MaxIdleConns)
|
||||
}
|
||||
|
||||
func TestBedrockTransportDefaultMaxConns(t *testing.T) {
|
||||
config := &schemas.ProviderConfig{
|
||||
NetworkConfig: schemas.NetworkConfig{
|
||||
DefaultRequestTimeoutInSeconds: 30,
|
||||
// MaxConnsPerHost left as 0 — should default to 5000
|
||||
},
|
||||
}
|
||||
config.CheckAndSetDefaults()
|
||||
|
||||
assert.Equal(t, schemas.DefaultMaxConnsPerHost, config.NetworkConfig.MaxConnsPerHost)
|
||||
|
||||
provider, err := NewBedrockProvider(config, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
transport, ok := provider.client.Transport.(*http.Transport)
|
||||
require.True(t, ok)
|
||||
|
||||
assert.Equal(t, schemas.DefaultMaxConnsPerHost, transport.MaxConnsPerHost)
|
||||
assert.Equal(t, schemas.DefaultMaxIdleConnsPerHost, transport.MaxIdleConnsPerHost)
|
||||
assert.Equal(t, schemas.DefaultMaxIdleConnsPerHost, transport.MaxIdleConns)
|
||||
}
|
||||
|
||||
func TestBedrockTransportTLSInsecureSkipVerify(t *testing.T) {
|
||||
config := &schemas.ProviderConfig{
|
||||
NetworkConfig: schemas.NetworkConfig{
|
||||
DefaultRequestTimeoutInSeconds: 30,
|
||||
InsecureSkipVerify: true,
|
||||
EnforceHTTP2: true,
|
||||
},
|
||||
}
|
||||
config.CheckAndSetDefaults()
|
||||
|
||||
provider, err := NewBedrockProvider(config, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
transport, ok := provider.client.Transport.(*http.Transport)
|
||||
require.True(t, ok)
|
||||
require.NotNil(t, transport.TLSClientConfig)
|
||||
assert.True(t, transport.TLSClientConfig.InsecureSkipVerify)
|
||||
assert.Equal(t, uint16(tls.VersionTLS12), transport.TLSClientConfig.MinVersion)
|
||||
// ForceAttemptHTTP2 should still be true even with custom TLS config
|
||||
assert.True(t, transport.ForceAttemptHTTP2)
|
||||
}
|
||||
|
||||
func TestBedrockTransportTLSCACert(t *testing.T) {
|
||||
testCACert := generateTestCACert(t)
|
||||
|
||||
config := &schemas.ProviderConfig{
|
||||
NetworkConfig: schemas.NetworkConfig{
|
||||
DefaultRequestTimeoutInSeconds: 30,
|
||||
CACertPEM: testCACert,
|
||||
EnforceHTTP2: true,
|
||||
},
|
||||
}
|
||||
config.CheckAndSetDefaults()
|
||||
|
||||
provider, err := NewBedrockProvider(config, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
transport, ok := provider.client.Transport.(*http.Transport)
|
||||
require.True(t, ok)
|
||||
require.NotNil(t, transport.TLSClientConfig)
|
||||
assert.NotNil(t, transport.TLSClientConfig.RootCAs)
|
||||
assert.Equal(t, uint16(tls.VersionTLS12), transport.TLSClientConfig.MinVersion)
|
||||
assert.True(t, transport.ForceAttemptHTTP2)
|
||||
}
|
||||
|
||||
func TestBedrockTransportDefaultTLS(t *testing.T) {
|
||||
config := &schemas.ProviderConfig{
|
||||
NetworkConfig: schemas.NetworkConfig{
|
||||
DefaultRequestTimeoutInSeconds: 30,
|
||||
// No TLS settings — should use system defaults
|
||||
},
|
||||
}
|
||||
config.CheckAndSetDefaults()
|
||||
|
||||
provider, err := NewBedrockProvider(config, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
transport, ok := provider.client.Transport.(*http.Transport)
|
||||
require.True(t, ok)
|
||||
// No custom TLS config should be set
|
||||
assert.Nil(t, transport.TLSClientConfig)
|
||||
// EnforceHTTP2 not set — ForceAttemptHTTP2 should be false
|
||||
assert.False(t, transport.ForceAttemptHTTP2)
|
||||
}
|
||||
|
||||
func TestBedrockTransportEnforceHTTP2(t *testing.T) {
|
||||
config := &schemas.ProviderConfig{
|
||||
NetworkConfig: schemas.NetworkConfig{
|
||||
DefaultRequestTimeoutInSeconds: 30,
|
||||
EnforceHTTP2: true,
|
||||
},
|
||||
}
|
||||
config.CheckAndSetDefaults()
|
||||
|
||||
provider, err := NewBedrockProvider(config, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
transport, ok := provider.client.Transport.(*http.Transport)
|
||||
require.True(t, ok)
|
||||
assert.True(t, transport.ForceAttemptHTTP2)
|
||||
// TLSNextProto should NOT be set when HTTP/2 is enforced, allowing ALPN negotiation
|
||||
assert.Nil(t, transport.TLSNextProto)
|
||||
}
|
||||
|
||||
func TestBedrockTransportEnforceHTTP2Disabled(t *testing.T) {
|
||||
config := &schemas.ProviderConfig{
|
||||
NetworkConfig: schemas.NetworkConfig{
|
||||
DefaultRequestTimeoutInSeconds: 30,
|
||||
EnforceHTTP2: false,
|
||||
},
|
||||
}
|
||||
config.CheckAndSetDefaults()
|
||||
|
||||
provider, err := NewBedrockProvider(config, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
transport, ok := provider.client.Transport.(*http.Transport)
|
||||
require.True(t, ok)
|
||||
assert.False(t, transport.ForceAttemptHTTP2)
|
||||
// TLSNextProto must be set to empty map to truly disable HTTP/2 ALPN negotiation
|
||||
assert.NotNil(t, transport.TLSNextProto)
|
||||
assert.Empty(t, transport.TLSNextProto)
|
||||
}
|
||||
1130
core/providers/bedrock/types.go
Normal file
1130
core/providers/bedrock/types.go
Normal file
File diff suppressed because it is too large
Load Diff
1843
core/providers/bedrock/utils.go
Normal file
1843
core/providers/bedrock/utils.go
Normal file
File diff suppressed because it is too large
Load Diff
405
core/providers/cerebras/cerebras.go
Normal file
405
core/providers/cerebras/cerebras.go
Normal file
@@ -0,0 +1,405 @@
|
||||
// Package cerebras implements the Cerebras LLM provider.
|
||||
package cerebras
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/maximhq/bifrost/core/providers/openai"
|
||||
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
|
||||
schemas "github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// CerebrasProvider implements the Provider interface for Cerebras's API.
|
||||
type CerebrasProvider struct {
|
||||
logger schemas.Logger // Logger for provider operations
|
||||
client *fasthttp.Client // HTTP client for unary API requests (ReadTimeout bounds overall response)
|
||||
streamingClient *fasthttp.Client // HTTP client for streaming API requests (no ReadTimeout; idle governed by NewIdleTimeoutReader)
|
||||
networkConfig schemas.NetworkConfig // Network configuration including extra headers
|
||||
sendBackRawRequest bool // Whether to include raw request in BifrostResponse
|
||||
sendBackRawResponse bool // Whether to include raw response in BifrostResponse
|
||||
}
|
||||
|
||||
// NewCerebrasProvider creates a new Cerebras provider instance.
|
||||
// It initializes the HTTP client with the provided configuration and sets up response pools.
|
||||
// The client is configured with timeouts, concurrency limits, and optional proxy settings.
|
||||
func NewCerebrasProvider(config *schemas.ProviderConfig, logger schemas.Logger) (*CerebrasProvider, error) {
|
||||
config.CheckAndSetDefaults()
|
||||
|
||||
requestTimeout := time.Second * time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds)
|
||||
client := &fasthttp.Client{
|
||||
ReadTimeout: requestTimeout,
|
||||
WriteTimeout: requestTimeout,
|
||||
MaxConnsPerHost: config.NetworkConfig.MaxConnsPerHost,
|
||||
MaxIdleConnDuration: 30 * time.Second,
|
||||
MaxConnWaitTimeout: requestTimeout,
|
||||
MaxConnDuration: time.Second * time.Duration(schemas.DefaultMaxConnDurationInSeconds),
|
||||
ConnPoolStrategy: fasthttp.FIFO,
|
||||
}
|
||||
|
||||
// Configure proxy and retry policy
|
||||
client = providerUtils.ConfigureProxy(client, config.ProxyConfig, logger)
|
||||
client = providerUtils.ConfigureDialer(client)
|
||||
client = providerUtils.ConfigureTLS(client, config.NetworkConfig, logger)
|
||||
streamingClient := providerUtils.BuildStreamingClient(client)
|
||||
// Set default BaseURL if not provided
|
||||
if config.NetworkConfig.BaseURL == "" {
|
||||
config.NetworkConfig.BaseURL = "https://api.cerebras.ai"
|
||||
}
|
||||
config.NetworkConfig.BaseURL = strings.TrimRight(config.NetworkConfig.BaseURL, "/")
|
||||
|
||||
return &CerebrasProvider{
|
||||
logger: logger,
|
||||
client: client,
|
||||
streamingClient: streamingClient,
|
||||
networkConfig: config.NetworkConfig,
|
||||
sendBackRawRequest: config.SendBackRawRequest,
|
||||
sendBackRawResponse: config.SendBackRawResponse,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetProviderKey returns the provider identifier for Cerebras.
|
||||
func (provider *CerebrasProvider) GetProviderKey() schemas.ModelProvider {
|
||||
return schemas.Cerebras
|
||||
}
|
||||
|
||||
// ListModels performs a list models request to Cerebras's API.
|
||||
func (provider *CerebrasProvider) ListModels(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) {
|
||||
return openai.HandleOpenAIListModelsRequest(
|
||||
ctx,
|
||||
provider.client,
|
||||
request,
|
||||
provider.networkConfig.BaseURL+providerUtils.GetPathFromContext(ctx, "/v1/models"),
|
||||
keys,
|
||||
provider.networkConfig.ExtraHeaders,
|
||||
provider.GetProviderKey(),
|
||||
providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest),
|
||||
providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse),
|
||||
)
|
||||
}
|
||||
|
||||
// TextCompletion performs a text completion request to Cerebras's API.
|
||||
// It formats the request, sends it to Cerebras, and processes the response.
|
||||
// Returns a BifrostResponse containing the completion results or an error if the request fails.
|
||||
func (provider *CerebrasProvider) TextCompletion(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (*schemas.BifrostTextCompletionResponse, *schemas.BifrostError) {
|
||||
return openai.HandleOpenAITextCompletionRequest(
|
||||
ctx,
|
||||
provider.client,
|
||||
provider.networkConfig.BaseURL+providerUtils.GetPathFromContext(ctx, "/v1/completions"),
|
||||
request,
|
||||
key,
|
||||
provider.networkConfig.ExtraHeaders,
|
||||
provider.GetProviderKey(),
|
||||
providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest),
|
||||
providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse),
|
||||
nil,
|
||||
nil,
|
||||
provider.logger,
|
||||
)
|
||||
}
|
||||
|
||||
// TextCompletionStream performs a streaming text completion request to Cerebras's API.
|
||||
// It formats the request, sends it to Cerebras, and processes the response.
|
||||
// Returns a channel of BifrostStreamChunk objects or an error if the request fails.
|
||||
func (provider *CerebrasProvider) TextCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostTextCompletionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
|
||||
var authHeader map[string]string
|
||||
if key.Value.GetValue() != "" {
|
||||
authHeader = map[string]string{"Authorization": "Bearer " + key.Value.GetValue()}
|
||||
}
|
||||
// Use shared OpenAI-compatible streaming logic
|
||||
return openai.HandleOpenAITextCompletionStreaming(
|
||||
ctx,
|
||||
provider.streamingClient,
|
||||
provider.networkConfig.BaseURL+"/v1/completions",
|
||||
request,
|
||||
authHeader,
|
||||
provider.networkConfig.ExtraHeaders,
|
||||
providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest),
|
||||
providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse),
|
||||
provider.GetProviderKey(),
|
||||
nil,
|
||||
postHookRunner,
|
||||
nil,
|
||||
nil,
|
||||
provider.logger,
|
||||
postHookSpanFinalizer,
|
||||
)
|
||||
}
|
||||
|
||||
// ChatCompletion performs a chat completion request to the Cerebras API.
|
||||
func (provider *CerebrasProvider) ChatCompletion(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError) {
|
||||
return openai.HandleOpenAIChatCompletionRequest(
|
||||
ctx,
|
||||
provider.client,
|
||||
provider.networkConfig.BaseURL+providerUtils.GetPathFromContext(ctx, "/v1/chat/completions"),
|
||||
request,
|
||||
key,
|
||||
provider.networkConfig.ExtraHeaders,
|
||||
providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest),
|
||||
providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse),
|
||||
provider.GetProviderKey(),
|
||||
nil,
|
||||
nil,
|
||||
provider.logger,
|
||||
)
|
||||
}
|
||||
|
||||
// ChatCompletionStream performs a streaming chat completion request to the Cerebras API.
|
||||
// It supports real-time streaming of responses using Server-Sent Events (SSE).
|
||||
// Uses Cerebras's OpenAI-compatible streaming format.
|
||||
// Returns a channel containing BifrostStreamChunk objects representing the stream or an error if the request fails.
|
||||
func (provider *CerebrasProvider) ChatCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
|
||||
var authHeader map[string]string
|
||||
if key.Value.GetValue() != "" {
|
||||
authHeader = map[string]string{"Authorization": "Bearer " + key.Value.GetValue()}
|
||||
}
|
||||
// Use shared OpenAI-compatible streaming logic
|
||||
return openai.HandleOpenAIChatCompletionStreaming(
|
||||
ctx,
|
||||
provider.streamingClient,
|
||||
provider.networkConfig.BaseURL+"/v1/chat/completions",
|
||||
request,
|
||||
authHeader,
|
||||
provider.networkConfig.ExtraHeaders,
|
||||
providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest),
|
||||
providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse),
|
||||
schemas.Cerebras,
|
||||
postHookRunner,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
provider.logger,
|
||||
postHookSpanFinalizer,
|
||||
)
|
||||
}
|
||||
|
||||
func (provider *CerebrasProvider) Responses(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) {
|
||||
chatResponse, err := provider.ChatCompletion(ctx, key, request.ToChatRequest())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
response := chatResponse.ToBifrostResponsesResponse()
|
||||
|
||||
return response, nil
|
||||
}
|
||||
|
||||
// ResponsesStream performs a streaming responses request to the Cerebras API.
|
||||
func (provider *CerebrasProvider) ResponsesStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
|
||||
ctx.SetValue(schemas.BifrostContextKeyIsResponsesToChatCompletionFallback, true)
|
||||
return provider.ChatCompletionStream(
|
||||
ctx,
|
||||
postHookRunner,
|
||||
postHookSpanFinalizer,
|
||||
key,
|
||||
request.ToChatRequest(),
|
||||
)
|
||||
}
|
||||
|
||||
// Embedding is not supported by the Cerebras provider.
|
||||
func (provider *CerebrasProvider) Embedding(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostEmbeddingRequest) (*schemas.BifrostEmbeddingResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.EmbeddingRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// Speech is not supported by the Cerebras provider.
|
||||
func (provider *CerebrasProvider) Speech(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostSpeechRequest) (*schemas.BifrostSpeechResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.SpeechRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// Rerank is not supported by the Cerebras provider.
|
||||
func (provider *CerebrasProvider) Rerank(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostRerankRequest) (*schemas.BifrostRerankResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.RerankRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// OCR is not supported by the Cerebras provider.
|
||||
func (provider *CerebrasProvider) OCR(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostOCRRequest) (*schemas.BifrostOCRResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.OCRRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// SpeechStream is not supported by the Cerebras provider.
|
||||
func (provider *CerebrasProvider) SpeechStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostSpeechRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.SpeechStreamRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// Transcription is not supported by the Cerebras provider.
|
||||
func (provider *CerebrasProvider) Transcription(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (*schemas.BifrostTranscriptionResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.TranscriptionRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// TranscriptionStream is not supported by the Cerebras provider.
|
||||
func (provider *CerebrasProvider) TranscriptionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.TranscriptionStreamRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// ImageGeneration is not supported by the Cerebras provider.
|
||||
func (provider *CerebrasProvider) ImageGeneration(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostImageGenerationRequest) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.ImageGenerationRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// ImageGenerationStream is not supported by the Cerebras provider.
|
||||
func (provider *CerebrasProvider) ImageGenerationStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostImageGenerationRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.ImageGenerationStreamRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// ImageEdit is not supported by the Cerebras provider.
|
||||
func (provider *CerebrasProvider) ImageEdit(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostImageEditRequest) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.ImageEditRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// ImageEditStream is not supported by the Cerebras provider.
|
||||
func (provider *CerebrasProvider) ImageEditStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostImageEditRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.ImageEditStreamRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// ImageVariation is not supported by the Cerebras provider.
|
||||
func (provider *CerebrasProvider) ImageVariation(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostImageVariationRequest) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.ImageVariationRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// VideoGeneration is not supported by the Cerebras provider.
|
||||
func (provider *CerebrasProvider) VideoGeneration(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostVideoGenerationRequest) (*schemas.BifrostVideoGenerationResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.VideoGenerationRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// VideoRetrieve is not supported by the Cerebras provider.
|
||||
func (provider *CerebrasProvider) VideoRetrieve(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostVideoRetrieveRequest) (*schemas.BifrostVideoGenerationResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.VideoRetrieveRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// VideoDownload is not supported by the Cerebras provider.
|
||||
func (provider *CerebrasProvider) VideoDownload(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostVideoDownloadRequest) (*schemas.BifrostVideoDownloadResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.VideoDownloadRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// VideoDelete is not supported by Cerebras provider.
|
||||
func (provider *CerebrasProvider) VideoDelete(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostVideoDeleteRequest) (*schemas.BifrostVideoDeleteResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.VideoDeleteRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// VideoList is not supported by Cerebras provider.
|
||||
func (provider *CerebrasProvider) VideoList(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostVideoListRequest) (*schemas.BifrostVideoListResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.VideoListRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// VideoRemix is not supported by Cerebras provider.
|
||||
func (provider *CerebrasProvider) VideoRemix(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostVideoRemixRequest) (*schemas.BifrostVideoGenerationResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.VideoRemixRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// FileUpload is not supported by Cerebras provider.
|
||||
func (provider *CerebrasProvider) FileUpload(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostFileUploadRequest) (*schemas.BifrostFileUploadResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.FileUploadRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// FileList is not supported by Cerebras provider.
|
||||
func (provider *CerebrasProvider) FileList(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostFileListRequest) (*schemas.BifrostFileListResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.FileListRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// FileRetrieve is not supported by Cerebras provider.
|
||||
func (provider *CerebrasProvider) FileRetrieve(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostFileRetrieveRequest) (*schemas.BifrostFileRetrieveResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.FileRetrieveRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// FileDelete is not supported by Cerebras provider.
|
||||
func (provider *CerebrasProvider) FileDelete(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostFileDeleteRequest) (*schemas.BifrostFileDeleteResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.FileDeleteRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// FileContent is not supported by Cerebras provider.
|
||||
func (provider *CerebrasProvider) FileContent(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostFileContentRequest) (*schemas.BifrostFileContentResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.FileContentRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// BatchCreate is not supported by Cerebras provider.
|
||||
func (provider *CerebrasProvider) BatchCreate(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostBatchCreateRequest) (*schemas.BifrostBatchCreateResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchCreateRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// BatchList is not supported by Cerebras provider.
|
||||
func (provider *CerebrasProvider) BatchList(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchListRequest) (*schemas.BifrostBatchListResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchListRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// BatchRetrieve is not supported by Cerebras provider.
|
||||
func (provider *CerebrasProvider) BatchRetrieve(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchRetrieveRequest) (*schemas.BifrostBatchRetrieveResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchRetrieveRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// BatchCancel is not supported by Cerebras provider.
|
||||
func (provider *CerebrasProvider) BatchCancel(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchCancelRequest) (*schemas.BifrostBatchCancelResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchCancelRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// BatchDelete is not supported by Cerebras provider.
|
||||
func (provider *CerebrasProvider) BatchDelete(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchDeleteRequest) (*schemas.BifrostBatchDeleteResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchDeleteRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// BatchResults is not supported by Cerebras provider.
|
||||
func (provider *CerebrasProvider) BatchResults(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchResultsRequest) (*schemas.BifrostBatchResultsResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchResultsRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// CountTokens is not supported by the Cerebras provider.
|
||||
func (provider *CerebrasProvider) CountTokens(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostResponsesRequest) (*schemas.BifrostCountTokensResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.CountTokensRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// ContainerCreate is not supported by the Cerebras provider.
|
||||
func (provider *CerebrasProvider) ContainerCreate(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostContainerCreateRequest) (*schemas.BifrostContainerCreateResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerCreateRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// ContainerList is not supported by the Cerebras provider.
|
||||
func (provider *CerebrasProvider) ContainerList(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostContainerListRequest) (*schemas.BifrostContainerListResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerListRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// ContainerRetrieve is not supported by the Cerebras provider.
|
||||
func (provider *CerebrasProvider) ContainerRetrieve(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostContainerRetrieveRequest) (*schemas.BifrostContainerRetrieveResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerRetrieveRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// ContainerDelete is not supported by the Cerebras provider.
|
||||
func (provider *CerebrasProvider) ContainerDelete(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostContainerDeleteRequest) (*schemas.BifrostContainerDeleteResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerDeleteRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// ContainerFileCreate is not supported by the Cerebras provider.
|
||||
func (provider *CerebrasProvider) ContainerFileCreate(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostContainerFileCreateRequest) (*schemas.BifrostContainerFileCreateResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerFileCreateRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// ContainerFileList is not supported by the Cerebras provider.
|
||||
func (provider *CerebrasProvider) ContainerFileList(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostContainerFileListRequest) (*schemas.BifrostContainerFileListResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerFileListRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// ContainerFileRetrieve is not supported by the Cerebras provider.
|
||||
func (provider *CerebrasProvider) ContainerFileRetrieve(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostContainerFileRetrieveRequest) (*schemas.BifrostContainerFileRetrieveResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerFileRetrieveRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// ContainerFileContent is not supported by the Cerebras provider.
|
||||
func (provider *CerebrasProvider) ContainerFileContent(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostContainerFileContentRequest) (*schemas.BifrostContainerFileContentResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerFileContentRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// ContainerFileDelete is not supported by the Cerebras provider.
|
||||
func (provider *CerebrasProvider) ContainerFileDelete(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostContainerFileDeleteRequest) (*schemas.BifrostContainerFileDeleteResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerFileDeleteRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// Passthrough is not supported by the Cerebras provider.
|
||||
func (provider *CerebrasProvider) Passthrough(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostPassthroughRequest) (*schemas.BifrostPassthroughResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.PassthroughRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
func (provider *CerebrasProvider) PassthroughStream(_ *schemas.BifrostContext, _ schemas.PostHookRunner, _ func(context.Context), _ schemas.Key, _ *schemas.BifrostPassthroughRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.PassthroughStreamRequest, provider.GetProviderKey())
|
||||
}
|
||||
60
core/providers/cerebras/cerebras_test.go
Normal file
60
core/providers/cerebras/cerebras_test.go
Normal file
@@ -0,0 +1,60 @@
|
||||
package cerebras_test
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/maximhq/bifrost/core/internal/llmtests"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
func TestCerebras(t *testing.T) {
|
||||
t.Parallel()
|
||||
if strings.TrimSpace(os.Getenv("CEREBRAS_API_KEY")) == "" {
|
||||
t.Skip("Skipping Cerebras tests because CEREBRAS_API_KEY is not set")
|
||||
}
|
||||
|
||||
client, ctx, cancel, err := llmtests.SetupTest()
|
||||
if err != nil {
|
||||
t.Fatalf("Error initializing test setup: %v", err)
|
||||
}
|
||||
defer cancel()
|
||||
defer client.Shutdown()
|
||||
|
||||
testConfig := llmtests.ComprehensiveTestConfig{
|
||||
Provider: schemas.Cerebras,
|
||||
ChatModel: "llama3.1-8b",
|
||||
Fallbacks: []schemas.Fallback{
|
||||
{Provider: schemas.Cerebras, Model: "llama3.1-8b"},
|
||||
{Provider: schemas.Cerebras, Model: "gpt-oss-120b"},
|
||||
},
|
||||
TextModel: "llama3.1-8b",
|
||||
EmbeddingModel: "", // Cerebras doesn't support embedding
|
||||
ReasoningModel: "gpt-oss-120b",
|
||||
Scenarios: llmtests.TestScenarios{
|
||||
TextCompletion: true,
|
||||
TextCompletionStream: true,
|
||||
SimpleChat: true,
|
||||
CompletionStream: true,
|
||||
MultiTurnConversation: true,
|
||||
ToolCalls: true,
|
||||
ToolCallsStreaming: true,
|
||||
MultipleToolCalls: false, // llama3.1-8b doesn't reliably produce parallel tool calls
|
||||
End2EndToolCalling: true,
|
||||
AutomaticFunctionCall: true,
|
||||
ImageURL: false,
|
||||
ImageBase64: false,
|
||||
MultipleImages: false,
|
||||
CompleteEnd2End: true,
|
||||
Embedding: false,
|
||||
ListModels: true,
|
||||
Reasoning: true,
|
||||
},
|
||||
}
|
||||
|
||||
t.Run("CerebrasTests", func(t *testing.T) {
|
||||
llmtests.RunAllComprehensiveTests(t, client, ctx, testConfig)
|
||||
})
|
||||
}
|
||||
722
core/providers/cohere/chat.go
Normal file
722
core/providers/cohere/chat.go
Normal file
@@ -0,0 +1,722 @@
|
||||
package cohere
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/maximhq/bifrost/core/providers/anthropic"
|
||||
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// ToCohereChatCompletionRequest converts a Bifrost request to Cohere v2 format
|
||||
func ToCohereChatCompletionRequest(bifrostReq *schemas.BifrostChatRequest) (*CohereChatRequest, error) {
|
||||
if bifrostReq == nil || bifrostReq.Input == nil {
|
||||
return nil, fmt.Errorf("bifrost request is nil")
|
||||
}
|
||||
|
||||
messages := bifrostReq.Input
|
||||
cohereReq := &CohereChatRequest{
|
||||
Model: bifrostReq.Model,
|
||||
}
|
||||
|
||||
// Convert messages to Cohere v2 format
|
||||
var cohereMessages []CohereMessage
|
||||
for _, msg := range messages {
|
||||
cohereMsg := CohereMessage{
|
||||
Role: string(msg.Role),
|
||||
}
|
||||
|
||||
// Convert content
|
||||
if msg.Content != nil && msg.Content.ContentStr != nil {
|
||||
cohereMsg.Content = NewStringContent(*msg.Content.ContentStr)
|
||||
} else if msg.Content != nil && msg.Content.ContentBlocks != nil {
|
||||
var contentBlocks []CohereContentBlock
|
||||
for _, block := range msg.Content.ContentBlocks {
|
||||
if block.Text != nil {
|
||||
contentBlocks = append(contentBlocks, CohereContentBlock{
|
||||
Type: CohereContentBlockTypeText,
|
||||
Text: block.Text,
|
||||
})
|
||||
} else if block.ImageURLStruct != nil {
|
||||
contentBlocks = append(contentBlocks, CohereContentBlock{
|
||||
Type: CohereContentBlockTypeImage,
|
||||
ImageURL: &CohereImageURL{
|
||||
URL: block.ImageURLStruct.URL,
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
if len(contentBlocks) > 0 {
|
||||
cohereMsg.Content = NewBlocksContent(contentBlocks)
|
||||
}
|
||||
}
|
||||
|
||||
// Convert tool calls for assistant messages
|
||||
if msg.ChatAssistantMessage != nil && msg.ChatAssistantMessage.ToolCalls != nil {
|
||||
var toolCalls []CohereToolCall
|
||||
for _, toolCall := range msg.ChatAssistantMessage.ToolCalls {
|
||||
// Safely extract function name and arguments
|
||||
var functionName *string
|
||||
var functionArguments string
|
||||
|
||||
if toolCall.Function.Name != nil {
|
||||
functionName = toolCall.Function.Name
|
||||
} else {
|
||||
// Use empty string if Name is nil
|
||||
functionName = schemas.Ptr("")
|
||||
}
|
||||
|
||||
// Arguments is a string, not a pointer, so it's safe to access directly
|
||||
// Default to "{}" if empty to ensure the field is always present.
|
||||
if toolCall.Function.Arguments == "" {
|
||||
functionArguments = "{}"
|
||||
} else {
|
||||
functionArguments = toolCall.Function.Arguments
|
||||
}
|
||||
|
||||
cohereToolCall := CohereToolCall{
|
||||
ID: toolCall.ID,
|
||||
Type: "function",
|
||||
Function: &CohereFunction{
|
||||
Name: functionName,
|
||||
Arguments: functionArguments,
|
||||
},
|
||||
}
|
||||
toolCalls = append(toolCalls, cohereToolCall)
|
||||
}
|
||||
cohereMsg.ToolCalls = toolCalls
|
||||
}
|
||||
|
||||
// Convert tool messages
|
||||
if msg.ChatToolMessage != nil && msg.ChatToolMessage.ToolCallID != nil {
|
||||
cohereMsg.ToolCallID = msg.ChatToolMessage.ToolCallID
|
||||
}
|
||||
|
||||
cohereMessages = append(cohereMessages, cohereMsg)
|
||||
}
|
||||
|
||||
cohereReq.Messages = cohereMessages
|
||||
|
||||
// Convert parameters
|
||||
if bifrostReq.Params != nil {
|
||||
cohereReq.MaxTokens = bifrostReq.Params.MaxCompletionTokens
|
||||
cohereReq.Temperature = bifrostReq.Params.Temperature
|
||||
cohereReq.P = bifrostReq.Params.TopP
|
||||
cohereReq.StopSequences = bifrostReq.Params.Stop
|
||||
cohereReq.FrequencyPenalty = bifrostReq.Params.FrequencyPenalty
|
||||
cohereReq.PresencePenalty = bifrostReq.Params.PresencePenalty
|
||||
|
||||
// Convert reasoning
|
||||
if bifrostReq.Params.Reasoning != nil {
|
||||
if bifrostReq.Params.Reasoning.MaxTokens != nil {
|
||||
thinking := &CohereThinking{
|
||||
Type: ThinkingTypeEnabled,
|
||||
}
|
||||
if *bifrostReq.Params.Reasoning.MaxTokens == -1 {
|
||||
// cohere does not support dynamic reasoning budget like gemini
|
||||
// setting it to minimum reasoning budget
|
||||
thinking.TokenBudget = schemas.Ptr(anthropic.MinimumReasoningMaxTokens)
|
||||
} else {
|
||||
thinking.TokenBudget = bifrostReq.Params.Reasoning.MaxTokens
|
||||
}
|
||||
cohereReq.Thinking = thinking
|
||||
} else if bifrostReq.Params.Reasoning.Effort != nil {
|
||||
if *bifrostReq.Params.Reasoning.Effort != "none" {
|
||||
maxCompletionTokens := providerUtils.GetMaxOutputTokensOrDefault(bifrostReq.Model, DefaultCompletionMaxTokens)
|
||||
if bifrostReq.Params.MaxCompletionTokens != nil {
|
||||
maxCompletionTokens = *bifrostReq.Params.MaxCompletionTokens
|
||||
}
|
||||
budgetTokens, err := providerUtils.GetBudgetTokensFromReasoningEffort(*bifrostReq.Params.Reasoning.Effort, MinimumReasoningMaxTokens, maxCompletionTokens)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cohereReq.Thinking = &CohereThinking{
|
||||
Type: ThinkingTypeEnabled,
|
||||
TokenBudget: schemas.Ptr(budgetTokens), // Max tokens for reasoning
|
||||
}
|
||||
} else {
|
||||
cohereReq.Thinking = &CohereThinking{
|
||||
Type: ThinkingTypeDisabled,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Convert response format
|
||||
if bifrostReq.Params.ResponseFormat != nil {
|
||||
cohereReq.ResponseFormat = convertResponseFormatToCohere(bifrostReq.Params.ResponseFormat)
|
||||
}
|
||||
|
||||
// Convert extra params
|
||||
if bifrostReq.Params.ExtraParams != nil {
|
||||
// Handle thinking parameter
|
||||
cohereReq.ExtraParams = bifrostReq.Params.ExtraParams
|
||||
if thinkingParam, ok := schemas.SafeExtractFromMap(bifrostReq.Params.ExtraParams, "thinking"); ok {
|
||||
if thinkingMap, ok := thinkingParam.(map[string]interface{}); ok {
|
||||
thinking := &CohereThinking{}
|
||||
if typeStr, ok := schemas.SafeExtractString(thinkingMap["type"]); ok {
|
||||
delete(thinkingMap, "type")
|
||||
thinking.Type = CohereThinkingType(typeStr)
|
||||
}
|
||||
if tokenBudget, ok := schemas.SafeExtractIntPointer(thinkingMap["token_budget"]); ok {
|
||||
delete(thinkingMap, "token_budget")
|
||||
thinking.TokenBudget = tokenBudget
|
||||
}
|
||||
cohereReq.Thinking = thinking
|
||||
cohereReq.ExtraParams["thinking"] = thinkingMap
|
||||
}
|
||||
}
|
||||
|
||||
// Handle other Cohere-specific extra params
|
||||
if safetyMode, ok := schemas.SafeExtractStringPointer(bifrostReq.Params.ExtraParams["safety_mode"]); ok {
|
||||
delete(cohereReq.ExtraParams, "safety_mode")
|
||||
cohereReq.SafetyMode = safetyMode
|
||||
}
|
||||
|
||||
if logProbs, ok := schemas.SafeExtractBoolPointer(bifrostReq.Params.ExtraParams["log_probs"]); ok {
|
||||
delete(cohereReq.ExtraParams, "log_probs")
|
||||
cohereReq.LogProbs = logProbs
|
||||
}
|
||||
|
||||
if strictToolChoice, ok := schemas.SafeExtractBoolPointer(bifrostReq.Params.ExtraParams["strict_tool_choice"]); ok {
|
||||
delete(cohereReq.ExtraParams, "strict_tool_choice")
|
||||
cohereReq.StrictToolChoice = strictToolChoice
|
||||
}
|
||||
}
|
||||
|
||||
// Convert tools to Cohere-specific format (without "strict" field)
|
||||
if bifrostReq.Params.Tools != nil {
|
||||
cohereTools := make([]CohereChatRequestTool, len(bifrostReq.Params.Tools))
|
||||
for i, tool := range bifrostReq.Params.Tools {
|
||||
cohereTools[i] = CohereChatRequestTool{
|
||||
Type: string(tool.Type),
|
||||
}
|
||||
if tool.Function != nil {
|
||||
cohereTools[i].Function = CohereChatRequestFunction{
|
||||
Name: tool.Function.Name,
|
||||
Description: tool.Function.Description,
|
||||
Parameters: tool.Function.Parameters, // Convert to map
|
||||
// Note: No "strict" field - Cohere doesn't support it
|
||||
}
|
||||
}
|
||||
}
|
||||
cohereReq.Tools = cohereTools
|
||||
}
|
||||
|
||||
// Convert tool choice
|
||||
if bifrostReq.Params.ToolChoice != nil {
|
||||
toolChoice := bifrostReq.Params.ToolChoice
|
||||
|
||||
if toolChoice.ChatToolChoiceStr != nil {
|
||||
switch schemas.ChatToolChoiceType(*toolChoice.ChatToolChoiceStr) {
|
||||
case schemas.ChatToolChoiceTypeNone:
|
||||
toolChoice := ToolChoiceNone
|
||||
cohereReq.ToolChoice = &toolChoice
|
||||
default:
|
||||
toolChoice := ToolChoiceRequired
|
||||
cohereReq.ToolChoice = &toolChoice
|
||||
}
|
||||
} else if toolChoice.ChatToolChoiceStruct != nil {
|
||||
switch toolChoice.ChatToolChoiceStruct.Type {
|
||||
case schemas.ChatToolChoiceTypeFunction:
|
||||
toolChoice := ToolChoiceRequired
|
||||
cohereReq.ToolChoice = &toolChoice
|
||||
default:
|
||||
toolChoice := ToolChoiceAuto
|
||||
cohereReq.ToolChoice = &toolChoice
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return cohereReq, nil
|
||||
}
|
||||
|
||||
// ToBifrostChatRequest converts a Cohere v2 chat request to Bifrost format
|
||||
func (req *CohereChatRequest) ToBifrostChatRequest(ctx *schemas.BifrostContext) *schemas.BifrostChatRequest {
|
||||
if req == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
provider, model := schemas.ParseModelString(req.Model, providerUtils.CheckAndSetDefaultProvider(ctx, schemas.Cohere))
|
||||
|
||||
bifrostReq := &schemas.BifrostChatRequest{
|
||||
Provider: provider,
|
||||
Model: model,
|
||||
Params: &schemas.ChatParameters{},
|
||||
}
|
||||
// Convert messages
|
||||
if req.Messages != nil {
|
||||
bifrostMessages := make([]schemas.ChatMessage, len(req.Messages))
|
||||
for i, message := range req.Messages {
|
||||
bifrostMessages[i] = *message.ToBifrostChatMessage()
|
||||
}
|
||||
bifrostReq.Input = bifrostMessages
|
||||
}
|
||||
// Convert parameters
|
||||
if req.MaxTokens != nil {
|
||||
bifrostReq.Params.MaxCompletionTokens = req.MaxTokens
|
||||
}
|
||||
if req.Temperature != nil {
|
||||
bifrostReq.Params.Temperature = req.Temperature
|
||||
}
|
||||
if req.P != nil {
|
||||
bifrostReq.Params.TopP = req.P
|
||||
}
|
||||
if req.StopSequences != nil {
|
||||
bifrostReq.Params.Stop = req.StopSequences
|
||||
}
|
||||
if req.FrequencyPenalty != nil {
|
||||
bifrostReq.Params.FrequencyPenalty = req.FrequencyPenalty
|
||||
}
|
||||
if req.PresencePenalty != nil {
|
||||
bifrostReq.Params.PresencePenalty = req.PresencePenalty
|
||||
}
|
||||
|
||||
// Convert reasoning
|
||||
if req.Thinking != nil {
|
||||
if req.Thinking.Type == ThinkingTypeDisabled {
|
||||
bifrostReq.Params.Reasoning = &schemas.ChatReasoning{
|
||||
Effort: schemas.Ptr("none"),
|
||||
}
|
||||
} else {
|
||||
bifrostReq.Params.Reasoning = &schemas.ChatReasoning{
|
||||
Effort: schemas.Ptr("auto"),
|
||||
}
|
||||
if req.Thinking.TokenBudget != nil {
|
||||
bifrostReq.Params.Reasoning.MaxTokens = req.Thinking.TokenBudget
|
||||
}
|
||||
}
|
||||
}
|
||||
if req.ResponseFormat != nil {
|
||||
bifrostReq.Params.ResponseFormat = convertCohereResponseFormatToBifrost(req.ResponseFormat)
|
||||
}
|
||||
|
||||
// Convert tools
|
||||
if req.Tools != nil {
|
||||
bifrostTools := make([]schemas.ChatTool, len(req.Tools))
|
||||
for i, tool := range req.Tools {
|
||||
bifrostTools[i] = schemas.ChatTool{
|
||||
Type: schemas.ChatToolTypeFunction,
|
||||
Function: &schemas.ChatToolFunction{
|
||||
Name: tool.Function.Name,
|
||||
Description: tool.Function.Description,
|
||||
Parameters: convertInterfaceToToolFunctionParameters(tool.Function.Parameters),
|
||||
},
|
||||
}
|
||||
}
|
||||
bifrostReq.Params.Tools = bifrostTools
|
||||
}
|
||||
|
||||
// Convert tool choice
|
||||
if req.ToolChoice != nil {
|
||||
switch *req.ToolChoice {
|
||||
case ToolChoiceNone:
|
||||
bifrostReq.Params.ToolChoice = &schemas.ChatToolChoice{
|
||||
ChatToolChoiceStr: schemas.Ptr(string(schemas.ChatToolChoiceTypeNone)),
|
||||
}
|
||||
case ToolChoiceRequired:
|
||||
bifrostReq.Params.ToolChoice = &schemas.ChatToolChoice{
|
||||
ChatToolChoiceStr: schemas.Ptr(string(schemas.ChatToolChoiceTypeRequired)),
|
||||
}
|
||||
case ToolChoiceAuto:
|
||||
bifrostReq.Params.ToolChoice = &schemas.ChatToolChoice{
|
||||
ChatToolChoiceStr: schemas.Ptr(string(schemas.ChatToolChoiceTypeAny)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Convert extra params
|
||||
extraParams := make(map[string]interface{})
|
||||
if req.SafetyMode != nil {
|
||||
extraParams["safety_mode"] = *req.SafetyMode
|
||||
}
|
||||
if req.LogProbs != nil {
|
||||
extraParams["log_probs"] = *req.LogProbs
|
||||
}
|
||||
if req.StrictToolChoice != nil {
|
||||
extraParams["strict_tool_choice"] = *req.StrictToolChoice
|
||||
}
|
||||
if req.Thinking != nil {
|
||||
thinkingMap := map[string]interface{}{
|
||||
"type": string(req.Thinking.Type),
|
||||
}
|
||||
if req.Thinking.TokenBudget != nil {
|
||||
thinkingMap["token_budget"] = *req.Thinking.TokenBudget
|
||||
}
|
||||
extraParams["thinking"] = thinkingMap
|
||||
}
|
||||
if len(extraParams) > 0 {
|
||||
bifrostReq.Params.ExtraParams = extraParams
|
||||
}
|
||||
|
||||
return bifrostReq
|
||||
}
|
||||
|
||||
// ToBifrostChatResponse converts a Cohere v2 response to Bifrost format
|
||||
func (response *CohereChatResponse) ToBifrostChatResponse(model string) *schemas.BifrostChatResponse {
|
||||
if response == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
bifrostResponse := &schemas.BifrostChatResponse{
|
||||
ID: response.ID,
|
||||
Model: model,
|
||||
Object: "chat.completion",
|
||||
Choices: []schemas.BifrostResponseChoice{
|
||||
{
|
||||
Index: 0,
|
||||
ChatNonStreamResponseChoice: &schemas.ChatNonStreamResponseChoice{},
|
||||
},
|
||||
},
|
||||
Created: int(time.Now().Unix()),
|
||||
ExtraFields: schemas.BifrostResponseExtraFields{
|
||||
},
|
||||
}
|
||||
|
||||
// Convert messages
|
||||
if response.Message != nil {
|
||||
bifrostMessage := response.Message.ToBifrostChatMessage()
|
||||
bifrostResponse.Choices[0].ChatNonStreamResponseChoice.Message = bifrostMessage
|
||||
}
|
||||
|
||||
// Convert finish reason
|
||||
if response.FinishReason != nil {
|
||||
finishReason := ConvertCohereFinishReasonToBifrost(*response.FinishReason)
|
||||
bifrostResponse.Choices[0].FinishReason = schemas.Ptr(finishReason)
|
||||
}
|
||||
|
||||
// Convert usage information
|
||||
if response.Usage != nil {
|
||||
usage := &schemas.BifrostLLMUsage{}
|
||||
|
||||
if response.Usage.Tokens != nil {
|
||||
if response.Usage.Tokens.InputTokens != nil {
|
||||
usage.PromptTokens = *response.Usage.Tokens.InputTokens
|
||||
}
|
||||
if response.Usage.Tokens.OutputTokens != nil {
|
||||
usage.CompletionTokens = *response.Usage.Tokens.OutputTokens
|
||||
}
|
||||
if response.Usage.CachedTokens != nil {
|
||||
usage.PromptTokensDetails = &schemas.ChatPromptTokensDetails{
|
||||
CachedReadTokens: *response.Usage.CachedTokens,
|
||||
}
|
||||
}
|
||||
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
|
||||
}
|
||||
|
||||
bifrostResponse.Usage = usage
|
||||
}
|
||||
|
||||
return bifrostResponse
|
||||
}
|
||||
|
||||
func (chunk *CohereStreamEvent) ToBifrostChatCompletionStream() (*schemas.BifrostChatResponse, *schemas.BifrostError, bool) {
|
||||
switch chunk.Type {
|
||||
case StreamEventMessageStart:
|
||||
if chunk.Delta != nil && chunk.Delta.Message != nil && chunk.Delta.Message.Role != nil {
|
||||
// Create streaming response for this delta
|
||||
streamResponse := &schemas.BifrostChatResponse{
|
||||
Object: "chat.completion.chunk",
|
||||
Choices: []schemas.BifrostResponseChoice{
|
||||
{
|
||||
Index: 0,
|
||||
ChatStreamResponseChoice: &schemas.ChatStreamResponseChoice{
|
||||
Delta: &schemas.ChatStreamResponseChoiceDelta{
|
||||
Role: chunk.Delta.Message.Role,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
return streamResponse, nil, false
|
||||
}
|
||||
|
||||
case StreamEventContentDelta:
|
||||
if chunk.Delta != nil &&
|
||||
chunk.Delta.Message != nil &&
|
||||
chunk.Delta.Message.Content != nil &&
|
||||
chunk.Delta.Message.Content.CohereStreamContentObject != nil {
|
||||
if chunk.Delta.Message.Content.CohereStreamContentObject.Text != nil {
|
||||
// Try to cast content to CohereStreamContent
|
||||
streamResponse := &schemas.BifrostChatResponse{
|
||||
Object: "chat.completion.chunk",
|
||||
Choices: []schemas.BifrostResponseChoice{
|
||||
{
|
||||
Index: 0,
|
||||
ChatStreamResponseChoice: &schemas.ChatStreamResponseChoice{
|
||||
Delta: &schemas.ChatStreamResponseChoiceDelta{
|
||||
Content: chunk.Delta.Message.Content.CohereStreamContentObject.Text,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
return streamResponse, nil, false
|
||||
} else if chunk.Delta.Message.Content.CohereStreamContentObject.Thinking != nil {
|
||||
thinkingText := *chunk.Delta.Message.Content.CohereStreamContentObject.Thinking
|
||||
streamResponse := &schemas.BifrostChatResponse{
|
||||
Object: "chat.completion.chunk",
|
||||
Choices: []schemas.BifrostResponseChoice{
|
||||
{
|
||||
Index: 0,
|
||||
ChatStreamResponseChoice: &schemas.ChatStreamResponseChoice{
|
||||
Delta: &schemas.ChatStreamResponseChoiceDelta{
|
||||
Reasoning: schemas.Ptr(thinkingText),
|
||||
ReasoningDetails: []schemas.ChatReasoningDetails{
|
||||
{
|
||||
Index: 0,
|
||||
Type: schemas.BifrostReasoningDetailsTypeText,
|
||||
Text: schemas.Ptr(thinkingText),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
return streamResponse, nil, false
|
||||
}
|
||||
}
|
||||
|
||||
case StreamEventToolPlanDelta:
|
||||
if chunk.Delta != nil && chunk.Delta.Message != nil && chunk.Delta.Message.ToolPlan != nil {
|
||||
streamResponse := &schemas.BifrostChatResponse{
|
||||
Object: "chat.completion.chunk",
|
||||
Choices: []schemas.BifrostResponseChoice{
|
||||
{
|
||||
Index: 0,
|
||||
ChatStreamResponseChoice: &schemas.ChatStreamResponseChoice{
|
||||
Delta: &schemas.ChatStreamResponseChoiceDelta{
|
||||
Reasoning: chunk.Delta.Message.ToolPlan,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
return streamResponse, nil, false
|
||||
}
|
||||
|
||||
case StreamEventContentStart:
|
||||
// Content start event - just continue, actual content comes in content-delta
|
||||
return nil, nil, false
|
||||
|
||||
case StreamEventToolCallStart, StreamEventToolCallDelta:
|
||||
if chunk.Delta != nil && chunk.Delta.Message != nil && chunk.Delta.Message.ToolCalls != nil && chunk.Delta.Message.ToolCalls.CohereToolCallObject != nil {
|
||||
// Handle single tool call object (tool-call-start/delta events)
|
||||
cohereToolCall := chunk.Delta.Message.ToolCalls.CohereToolCallObject
|
||||
toolCall := schemas.ChatAssistantMessageToolCall{}
|
||||
|
||||
if chunk.Index != nil {
|
||||
toolCall.Index = uint16(*chunk.Index)
|
||||
}
|
||||
|
||||
if cohereToolCall.ID != nil {
|
||||
toolCall.ID = cohereToolCall.ID
|
||||
}
|
||||
|
||||
if cohereToolCall.Function != nil {
|
||||
if cohereToolCall.Function.Name != nil {
|
||||
toolCall.Function.Name = cohereToolCall.Function.Name
|
||||
}
|
||||
toolCall.Function.Arguments = cohereToolCall.Function.Arguments
|
||||
}
|
||||
|
||||
streamResponse := &schemas.BifrostChatResponse{
|
||||
Object: "chat.completion.chunk",
|
||||
Choices: []schemas.BifrostResponseChoice{
|
||||
{
|
||||
Index: 0,
|
||||
ChatStreamResponseChoice: &schemas.ChatStreamResponseChoice{
|
||||
Delta: &schemas.ChatStreamResponseChoiceDelta{
|
||||
ToolCalls: []schemas.ChatAssistantMessageToolCall{toolCall},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
return streamResponse, nil, false
|
||||
}
|
||||
|
||||
case StreamEventToolCallEnd:
|
||||
return nil, nil, false
|
||||
|
||||
case StreamEventContentEnd:
|
||||
return nil, nil, false
|
||||
|
||||
case StreamEventMessageEnd:
|
||||
if chunk.Delta != nil {
|
||||
var finishReason string
|
||||
usage := &schemas.BifrostLLMUsage{}
|
||||
// Set finish reason
|
||||
if chunk.Delta.FinishReason != nil {
|
||||
finishReason = ConvertCohereFinishReasonToBifrost(*chunk.Delta.FinishReason)
|
||||
}
|
||||
|
||||
// Set usage information
|
||||
if chunk.Delta.Usage != nil {
|
||||
if chunk.Delta.Usage.Tokens != nil {
|
||||
if chunk.Delta.Usage.Tokens.InputTokens != nil {
|
||||
usage.PromptTokens = *chunk.Delta.Usage.Tokens.InputTokens
|
||||
}
|
||||
if chunk.Delta.Usage.Tokens.OutputTokens != nil {
|
||||
usage.CompletionTokens = *chunk.Delta.Usage.Tokens.OutputTokens
|
||||
}
|
||||
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
|
||||
}
|
||||
}
|
||||
|
||||
streamResponse := &schemas.BifrostChatResponse{
|
||||
Object: "chat.completion.chunk",
|
||||
Choices: []schemas.BifrostResponseChoice{
|
||||
{
|
||||
Index: 0,
|
||||
FinishReason: &finishReason,
|
||||
ChatStreamResponseChoice: &schemas.ChatStreamResponseChoice{
|
||||
Delta: &schemas.ChatStreamResponseChoiceDelta{},
|
||||
},
|
||||
},
|
||||
},
|
||||
Usage: usage,
|
||||
}
|
||||
|
||||
return streamResponse, nil, true
|
||||
}
|
||||
return nil, nil, false
|
||||
}
|
||||
|
||||
return nil, nil, false
|
||||
}
|
||||
|
||||
func (cm *CohereMessage) ToBifrostChatMessage() *schemas.ChatMessage {
|
||||
if cm == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var content *string
|
||||
var contentBlocks []schemas.ChatContentBlock
|
||||
var toolCalls []schemas.ChatAssistantMessageToolCall
|
||||
var reasoningDetails []schemas.ChatReasoningDetails
|
||||
var reasoningText string
|
||||
|
||||
// Convert message content
|
||||
if cm.Content != nil {
|
||||
if cm.Content.IsString() ||
|
||||
(cm.Content.IsBlocks() &&
|
||||
len(cm.Content.GetBlocks()) == 1 &&
|
||||
cm.Content.GetBlocks()[0].Type == CohereContentBlockTypeText) {
|
||||
if cm.Content.IsString() {
|
||||
content = cm.Content.GetString()
|
||||
} else {
|
||||
content = cm.Content.GetBlocks()[0].Text
|
||||
}
|
||||
} else if cm.Content.IsBlocks() {
|
||||
for _, block := range cm.Content.GetBlocks() {
|
||||
if block.Type == CohereContentBlockTypeText && block.Text != nil {
|
||||
contentBlocks = append(contentBlocks, schemas.ChatContentBlock{
|
||||
Type: schemas.ChatContentBlockTypeText,
|
||||
Text: block.Text,
|
||||
})
|
||||
} else if block.Type == CohereContentBlockTypeImage && block.ImageURL != nil {
|
||||
contentBlocks = append(contentBlocks, schemas.ChatContentBlock{
|
||||
Type: schemas.ChatContentBlockTypeImage,
|
||||
ImageURLStruct: &schemas.ChatInputImage{
|
||||
URL: block.ImageURL.URL,
|
||||
},
|
||||
})
|
||||
} else if block.Type == CohereContentBlockTypeThinking && block.Thinking != nil {
|
||||
reasoningDetails = append(reasoningDetails, schemas.ChatReasoningDetails{
|
||||
Index: len(reasoningDetails),
|
||||
Type: schemas.BifrostReasoningDetailsTypeText,
|
||||
Text: block.Thinking,
|
||||
})
|
||||
if len(reasoningText) > 0 {
|
||||
reasoningText += "\n"
|
||||
}
|
||||
reasoningText += *block.Thinking
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(contentBlocks) == 1 && contentBlocks[0].Type == schemas.ChatContentBlockTypeText {
|
||||
content = contentBlocks[0].Text
|
||||
contentBlocks = nil
|
||||
}
|
||||
|
||||
// Create the message content
|
||||
messageContent := &schemas.ChatMessageContent{
|
||||
ContentStr: content,
|
||||
ContentBlocks: contentBlocks,
|
||||
}
|
||||
|
||||
// Convert tool calls
|
||||
if cm.ToolCalls != nil {
|
||||
for _, toolCall := range cm.ToolCalls {
|
||||
// Check if Function is nil to avoid nil pointer dereference
|
||||
if toolCall.Function == nil {
|
||||
// Skip this tool call if Function is nil
|
||||
continue
|
||||
}
|
||||
|
||||
// Safely extract function name and arguments
|
||||
var functionName *string
|
||||
var functionArguments string
|
||||
|
||||
if toolCall.Function.Name != nil {
|
||||
functionName = toolCall.Function.Name
|
||||
} else {
|
||||
// Use empty string if Name is nil
|
||||
functionName = schemas.Ptr("")
|
||||
}
|
||||
|
||||
// Arguments is a string, not a pointer, so it's safe to access directly
|
||||
functionArguments = toolCall.Function.Arguments
|
||||
|
||||
bifrostToolCall := schemas.ChatAssistantMessageToolCall{
|
||||
Index: uint16(len(toolCalls)),
|
||||
ID: toolCall.ID,
|
||||
Function: schemas.ChatAssistantMessageToolCallFunction{
|
||||
Name: functionName,
|
||||
Arguments: functionArguments,
|
||||
},
|
||||
}
|
||||
toolCalls = append(toolCalls, bifrostToolCall)
|
||||
}
|
||||
}
|
||||
|
||||
// Create assistant message if we have tool calls
|
||||
var assistantMessage *schemas.ChatAssistantMessage
|
||||
if len(toolCalls) > 0 {
|
||||
assistantMessage = &schemas.ChatAssistantMessage{
|
||||
ToolCalls: toolCalls,
|
||||
}
|
||||
}
|
||||
|
||||
if len(reasoningDetails) > 0 {
|
||||
if assistantMessage == nil {
|
||||
assistantMessage = &schemas.ChatAssistantMessage{}
|
||||
}
|
||||
assistantMessage.ReasoningDetails = reasoningDetails
|
||||
assistantMessage.Reasoning = schemas.Ptr(reasoningText)
|
||||
}
|
||||
|
||||
bifrostMessage := &schemas.ChatMessage{
|
||||
Role: schemas.ChatMessageRole(cm.Role),
|
||||
Content: messageContent,
|
||||
ChatAssistantMessage: assistantMessage,
|
||||
}
|
||||
|
||||
if cm.Role == "tool" {
|
||||
bifrostMessage.ChatToolMessage = &schemas.ChatToolMessage{
|
||||
ToolCallID: cm.ToolCallID,
|
||||
}
|
||||
}
|
||||
return bifrostMessage
|
||||
}
|
||||
1266
core/providers/cohere/cohere.go
Normal file
1266
core/providers/cohere/cohere.go
Normal file
File diff suppressed because it is too large
Load Diff
62
core/providers/cohere/cohere_test.go
Normal file
62
core/providers/cohere/cohere_test.go
Normal file
@@ -0,0 +1,62 @@
|
||||
package cohere_test
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/maximhq/bifrost/core/internal/llmtests"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
func TestCohere(t *testing.T) {
|
||||
t.Parallel()
|
||||
if strings.TrimSpace(os.Getenv("COHERE_API_KEY")) == "" {
|
||||
t.Skip("Skipping Cohere tests because COHERE_API_KEY is not set")
|
||||
}
|
||||
|
||||
client, ctx, cancel, err := llmtests.SetupTest()
|
||||
if err != nil {
|
||||
t.Fatalf("Error initializing test setup: %v", err)
|
||||
}
|
||||
defer cancel()
|
||||
defer client.Shutdown()
|
||||
|
||||
testConfig := llmtests.ComprehensiveTestConfig{
|
||||
Provider: schemas.Cohere,
|
||||
ChatModel: "command-a-03-2025",
|
||||
VisionModel: "command-a-vision-07-2025", // Cohere's latest vision model
|
||||
TextModel: "", // Cohere focuses on chat
|
||||
EmbeddingModel: "embed-v4.0",
|
||||
RerankModel: "rerank-v3.5",
|
||||
ReasoningModel: "command-a-reasoning-08-2025",
|
||||
Scenarios: llmtests.TestScenarios{
|
||||
TextCompletion: false, // Not typical for Cohere
|
||||
SimpleChat: true,
|
||||
CompletionStream: true,
|
||||
MultiTurnConversation: true,
|
||||
ToolCalls: true,
|
||||
ToolCallsStreaming: true,
|
||||
MultipleToolCalls: true,
|
||||
MultipleToolCallsStreaming: true,
|
||||
End2EndToolCalling: true,
|
||||
AutomaticFunctionCall: true, // May not support automatic
|
||||
ImageURL: false, // Supported by c4ai-aya-vision-8b model
|
||||
ImageBase64: true, // Supported by c4ai-aya-vision-8b model
|
||||
MultipleImages: false, // Supported by c4ai-aya-vision-8b model
|
||||
FileBase64: false, // Not supported
|
||||
FileURL: false, // Not supported
|
||||
CompleteEnd2End: false,
|
||||
Embedding: true,
|
||||
Rerank: true,
|
||||
Reasoning: true,
|
||||
ListModels: true,
|
||||
CountTokens: true,
|
||||
},
|
||||
}
|
||||
|
||||
t.Run("CohereTests", func(t *testing.T) {
|
||||
llmtests.RunAllComprehensiveTests(t, client, ctx, testConfig)
|
||||
})
|
||||
}
|
||||
125
core/providers/cohere/count_tokens.go
Normal file
125
core/providers/cohere/count_tokens.go
Normal file
@@ -0,0 +1,125 @@
|
||||
package cohere
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/maximhq/bifrost/core/providers/utils"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// ToBifrostResponsesRequest converts a Cohere count tokens request to Bifrost format.
|
||||
func (req *CohereCountTokensRequest) ToBifrostResponsesRequest(ctx *schemas.BifrostContext) *schemas.BifrostResponsesRequest {
|
||||
if req == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
provider, model := schemas.ParseModelString(req.Model, utils.CheckAndSetDefaultProvider(ctx, schemas.Cohere))
|
||||
|
||||
userRole := schemas.ResponsesInputMessageRoleUser
|
||||
return &schemas.BifrostResponsesRequest{
|
||||
Provider: provider,
|
||||
Model: model,
|
||||
Input: []schemas.ResponsesMessage{
|
||||
{
|
||||
Role: &userRole,
|
||||
Content: &schemas.ResponsesMessageContent{
|
||||
ContentStr: &req.Text,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// ToCohereCountTokensRequest converts a Bifrost count tokens request to Cohere's tokenize payload.
|
||||
func ToCohereCountTokensRequest(bifrostReq *schemas.BifrostResponsesRequest) (*CohereCountTokensRequest, error) {
|
||||
if bifrostReq == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if bifrostReq.Input == nil {
|
||||
return nil, fmt.Errorf("count tokens input is not provided")
|
||||
}
|
||||
|
||||
text := buildCohereCountTokensText(bifrostReq.Input)
|
||||
trimmed := strings.TrimSpace(text)
|
||||
if trimmed == "" {
|
||||
return nil, fmt.Errorf("count tokens text is empty after conversion")
|
||||
}
|
||||
runeCount := utf8.RuneCountInString(trimmed)
|
||||
if runeCount < cohereTokenizeMinTextLength || runeCount > cohereTokenizeMaxTextLength {
|
||||
return nil, fmt.Errorf("count tokens text length must be between %d and %d characters", cohereTokenizeMinTextLength, cohereTokenizeMaxTextLength)
|
||||
}
|
||||
|
||||
cohereReq := &CohereCountTokensRequest{
|
||||
Model: bifrostReq.Model,
|
||||
Text: trimmed,
|
||||
}
|
||||
if bifrostReq.Params != nil {
|
||||
cohereReq.ExtraParams = bifrostReq.Params.ExtraParams
|
||||
}
|
||||
|
||||
return cohereReq, nil
|
||||
}
|
||||
|
||||
// ToBifrostCountTokensResponse converts a Cohere tokenize response to Bifrost format.
|
||||
func (resp *CohereCountTokensResponse) ToBifrostCountTokensResponse(model string) *schemas.BifrostCountTokensResponse {
|
||||
if resp == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
inputTokens := len(resp.Tokens)
|
||||
if inputTokens == 0 && len(resp.TokenStrings) > 0 {
|
||||
inputTokens = len(resp.TokenStrings)
|
||||
}
|
||||
totalTokens := inputTokens
|
||||
|
||||
return &schemas.BifrostCountTokensResponse{
|
||||
Model: model,
|
||||
InputTokens: inputTokens,
|
||||
TotalTokens: &totalTokens,
|
||||
TokenStrings: resp.TokenStrings,
|
||||
Tokens: resp.Tokens,
|
||||
Object: "response.input_tokens",
|
||||
}
|
||||
}
|
||||
|
||||
// buildCohereCountTokensText flattens Responses messages into a plain text payload for tokenization.
|
||||
func buildCohereCountTokensText(messages []schemas.ResponsesMessage) string {
|
||||
var parts []string
|
||||
|
||||
for _, msg := range messages {
|
||||
var contentParts []string
|
||||
|
||||
if msg.Content != nil {
|
||||
if msg.Content.ContentStr != nil {
|
||||
contentParts = append(contentParts, *msg.Content.ContentStr)
|
||||
}
|
||||
for _, block := range msg.Content.ContentBlocks {
|
||||
if block.Text != nil {
|
||||
contentParts = append(contentParts, *block.Text)
|
||||
}
|
||||
if block.ResponsesOutputMessageContentRefusal != nil && block.ResponsesOutputMessageContentRefusal.Refusal != "" {
|
||||
contentParts = append(contentParts, block.ResponsesOutputMessageContentRefusal.Refusal)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if msg.ResponsesReasoning != nil {
|
||||
for _, summary := range msg.ResponsesReasoning.Summary {
|
||||
if summary.Text != "" {
|
||||
contentParts = append(contentParts, summary.Text)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(contentParts) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
parts = append(parts, strings.Join(contentParts, "\n"))
|
||||
}
|
||||
|
||||
return strings.TrimSpace(strings.Join(parts, "\n"))
|
||||
}
|
||||
190
core/providers/cohere/embedding.go
Normal file
190
core/providers/cohere/embedding.go
Normal file
@@ -0,0 +1,190 @@
|
||||
package cohere
|
||||
|
||||
import (
|
||||
"github.com/maximhq/bifrost/core/providers/utils"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// ToCohereEmbeddingRequest converts a Bifrost embedding request to Cohere format
|
||||
func ToCohereEmbeddingRequest(bifrostReq *schemas.BifrostEmbeddingRequest) *CohereEmbeddingRequest {
|
||||
if bifrostReq == nil || bifrostReq.Input == nil || (bifrostReq.Input.Text == nil && bifrostReq.Input.Texts == nil) {
|
||||
return nil
|
||||
}
|
||||
|
||||
embeddingInput := bifrostReq.Input
|
||||
cohereReq := &CohereEmbeddingRequest{
|
||||
Model: bifrostReq.Model,
|
||||
}
|
||||
|
||||
texts := []string{}
|
||||
if embeddingInput.Text != nil {
|
||||
texts = append(texts, *embeddingInput.Text)
|
||||
} else {
|
||||
texts = embeddingInput.Texts
|
||||
}
|
||||
|
||||
// Convert texts from Bifrost format
|
||||
if len(texts) > 0 {
|
||||
cohereReq.Texts = texts
|
||||
}
|
||||
|
||||
// Set default input type if not specified in extra params
|
||||
cohereReq.InputType = "search_document" // Default value
|
||||
|
||||
if bifrostReq.Params != nil {
|
||||
cohereReq.OutputDimension = bifrostReq.Params.Dimensions
|
||||
cohereReq.ExtraParams = bifrostReq.Params.ExtraParams
|
||||
if bifrostReq.Params.ExtraParams != nil {
|
||||
if maxTokens, ok := schemas.SafeExtractIntPointer(bifrostReq.Params.ExtraParams["max_tokens"]); ok {
|
||||
delete(cohereReq.ExtraParams, "max_tokens")
|
||||
cohereReq.MaxTokens = maxTokens
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Handle extra params
|
||||
if bifrostReq.Params != nil && bifrostReq.Params.ExtraParams != nil {
|
||||
// Input type
|
||||
if inputType, ok := schemas.SafeExtractString(bifrostReq.Params.ExtraParams["input_type"]); ok {
|
||||
delete(cohereReq.ExtraParams, "input_type")
|
||||
cohereReq.InputType = inputType
|
||||
}
|
||||
|
||||
// Embedding types
|
||||
if embeddingTypes, ok := schemas.SafeExtractStringSlice(bifrostReq.Params.ExtraParams["embedding_types"]); ok {
|
||||
if len(embeddingTypes) > 0 {
|
||||
delete(cohereReq.ExtraParams, "embedding_types")
|
||||
cohereReq.EmbeddingTypes = embeddingTypes
|
||||
}
|
||||
}
|
||||
|
||||
// Truncate
|
||||
if truncate, ok := schemas.SafeExtractStringPointer(bifrostReq.Params.ExtraParams["truncate"]); ok {
|
||||
delete(cohereReq.ExtraParams, "truncate")
|
||||
cohereReq.Truncate = truncate
|
||||
}
|
||||
}
|
||||
|
||||
return cohereReq
|
||||
}
|
||||
|
||||
// ToBifrostEmbeddingRequest converts a Cohere embedding request to Bifrost format
|
||||
func (req *CohereEmbeddingRequest) ToBifrostEmbeddingRequest(ctx *schemas.BifrostContext) *schemas.BifrostEmbeddingRequest {
|
||||
if req == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
provider, model := schemas.ParseModelString(req.Model, utils.CheckAndSetDefaultProvider(ctx, schemas.Cohere))
|
||||
|
||||
bifrostReq := &schemas.BifrostEmbeddingRequest{
|
||||
Provider: provider,
|
||||
Model: model,
|
||||
Input: &schemas.EmbeddingInput{},
|
||||
Params: &schemas.EmbeddingParameters{},
|
||||
}
|
||||
|
||||
// Convert texts
|
||||
if len(req.Texts) > 0 {
|
||||
if len(req.Texts) == 1 {
|
||||
bifrostReq.Input.Text = &req.Texts[0]
|
||||
} else {
|
||||
bifrostReq.Input.Texts = req.Texts
|
||||
}
|
||||
}
|
||||
|
||||
// Convert parameters
|
||||
if req.OutputDimension != nil {
|
||||
bifrostReq.Params.Dimensions = req.OutputDimension
|
||||
}
|
||||
|
||||
// Convert extra params
|
||||
extraParams := make(map[string]interface{})
|
||||
if req.InputType != "" {
|
||||
extraParams["input_type"] = req.InputType
|
||||
}
|
||||
if req.EmbeddingTypes != nil {
|
||||
extraParams["embedding_types"] = req.EmbeddingTypes
|
||||
}
|
||||
if req.Truncate != nil {
|
||||
extraParams["truncate"] = *req.Truncate
|
||||
}
|
||||
if req.MaxTokens != nil {
|
||||
extraParams["max_tokens"] = *req.MaxTokens
|
||||
}
|
||||
if len(extraParams) > 0 {
|
||||
bifrostReq.Params.ExtraParams = extraParams
|
||||
}
|
||||
|
||||
return bifrostReq
|
||||
}
|
||||
|
||||
// ToBifrostEmbeddingResponse converts a Cohere embedding response to Bifrost format
|
||||
func (response *CohereEmbeddingResponse) ToBifrostEmbeddingResponse() *schemas.BifrostEmbeddingResponse {
|
||||
if response == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
bifrostResponse := &schemas.BifrostEmbeddingResponse{
|
||||
Object: "list",
|
||||
}
|
||||
|
||||
// Convert embeddings data
|
||||
if response.Embeddings != nil {
|
||||
var bifrostEmbeddings []schemas.EmbeddingData
|
||||
|
||||
// Handle different embedding types - prioritize float embeddings
|
||||
if response.Embeddings.Float != nil {
|
||||
for i, embedding := range response.Embeddings.Float {
|
||||
bifrostEmbedding := schemas.EmbeddingData{
|
||||
Object: "embedding",
|
||||
Index: i,
|
||||
Embedding: schemas.EmbeddingStruct{
|
||||
EmbeddingArray: embedding,
|
||||
},
|
||||
}
|
||||
bifrostEmbeddings = append(bifrostEmbeddings, bifrostEmbedding)
|
||||
}
|
||||
} else if response.Embeddings.Base64 != nil {
|
||||
// Handle base64 embeddings as strings
|
||||
for i, embedding := range response.Embeddings.Base64 {
|
||||
bifrostEmbedding := schemas.EmbeddingData{
|
||||
Object: "embedding",
|
||||
Index: i,
|
||||
Embedding: schemas.EmbeddingStruct{
|
||||
EmbeddingStr: &embedding,
|
||||
},
|
||||
}
|
||||
bifrostEmbeddings = append(bifrostEmbeddings, bifrostEmbedding)
|
||||
}
|
||||
}
|
||||
// Note: Int8, Uint8, Binary, Ubinary types would need special handling
|
||||
// depending on how Bifrost wants to represent them
|
||||
|
||||
bifrostResponse.Data = bifrostEmbeddings
|
||||
}
|
||||
|
||||
// Convert usage information
|
||||
if response.Meta != nil {
|
||||
if response.Meta.Tokens != nil {
|
||||
bifrostResponse.Usage = &schemas.BifrostLLMUsage{}
|
||||
if response.Meta.Tokens.InputTokens != nil {
|
||||
bifrostResponse.Usage.PromptTokens = int(*response.Meta.Tokens.InputTokens)
|
||||
}
|
||||
if response.Meta.Tokens.OutputTokens != nil {
|
||||
bifrostResponse.Usage.CompletionTokens = int(*response.Meta.Tokens.OutputTokens)
|
||||
}
|
||||
bifrostResponse.Usage.TotalTokens = bifrostResponse.Usage.PromptTokens + bifrostResponse.Usage.CompletionTokens
|
||||
} else if response.Meta.BilledUnits != nil {
|
||||
bifrostResponse.Usage = &schemas.BifrostLLMUsage{}
|
||||
if response.Meta.BilledUnits.InputTokens != nil {
|
||||
bifrostResponse.Usage.PromptTokens = int(*response.Meta.BilledUnits.InputTokens)
|
||||
}
|
||||
if response.Meta.BilledUnits.OutputTokens != nil {
|
||||
bifrostResponse.Usage.CompletionTokens = int(*response.Meta.BilledUnits.OutputTokens)
|
||||
}
|
||||
bifrostResponse.Usage.TotalTokens = bifrostResponse.Usage.PromptTokens + bifrostResponse.Usage.CompletionTokens
|
||||
}
|
||||
}
|
||||
|
||||
return bifrostResponse
|
||||
}
|
||||
90
core/providers/cohere/embedding_test.go
Normal file
90
core/providers/cohere/embedding_test.go
Normal file
@@ -0,0 +1,90 @@
|
||||
package cohere
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestToCohereEmbeddingRequest(t *testing.T) {
|
||||
t.Run("returns nil for missing input", func(t *testing.T) {
|
||||
assert.Nil(t, ToCohereEmbeddingRequest(nil))
|
||||
assert.Nil(t, ToCohereEmbeddingRequest(&schemas.BifrostEmbeddingRequest{}))
|
||||
assert.Nil(t, ToCohereEmbeddingRequest(&schemas.BifrostEmbeddingRequest{
|
||||
Input: &schemas.EmbeddingInput{},
|
||||
}))
|
||||
})
|
||||
|
||||
t.Run("single text keeps model in direct cohere body", func(t *testing.T) {
|
||||
text := "hello"
|
||||
truncate := "END"
|
||||
dimensions := 1024
|
||||
maxTokens := 256
|
||||
bifrostReq := &schemas.BifrostEmbeddingRequest{
|
||||
Model: "embed-v4.0",
|
||||
Input: &schemas.EmbeddingInput{Text: &text},
|
||||
Params: &schemas.EmbeddingParameters{
|
||||
Dimensions: &dimensions,
|
||||
ExtraParams: map[string]interface{}{
|
||||
"input_type": "classification",
|
||||
"embedding_types": []string{"float", "int8"},
|
||||
"truncate": truncate,
|
||||
"max_tokens": maxTokens,
|
||||
"priority": "high",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
req := ToCohereEmbeddingRequest(bifrostReq)
|
||||
require.NotNil(t, req)
|
||||
assert.Equal(t, "embed-v4.0", req.Model)
|
||||
assert.Equal(t, "classification", req.InputType)
|
||||
assert.Equal(t, []string{"hello"}, req.Texts)
|
||||
assert.Equal(t, []string{"float", "int8"}, req.EmbeddingTypes)
|
||||
assert.Equal(t, &dimensions, req.OutputDimension)
|
||||
assert.Equal(t, &maxTokens, req.MaxTokens)
|
||||
require.NotNil(t, req.Truncate)
|
||||
assert.Equal(t, truncate, *req.Truncate)
|
||||
assert.Equal(t, map[string]interface{}{"priority": "high"}, req.ExtraParams)
|
||||
})
|
||||
|
||||
t.Run("multiple texts use default input type", func(t *testing.T) {
|
||||
req := ToCohereEmbeddingRequest(&schemas.BifrostEmbeddingRequest{
|
||||
Model: "embed-english-v3.0",
|
||||
Input: &schemas.EmbeddingInput{Texts: []string{"hello", "world"}},
|
||||
})
|
||||
|
||||
require.NotNil(t, req)
|
||||
assert.Equal(t, "embed-english-v3.0", req.Model)
|
||||
assert.Equal(t, "search_document", req.InputType)
|
||||
assert.Equal(t, []string{"hello", "world"}, req.Texts)
|
||||
assert.Nil(t, req.ExtraParams)
|
||||
})
|
||||
}
|
||||
|
||||
func TestToCohereEmbeddingRequestBodyIncludesModelForDirectCohere(t *testing.T) {
|
||||
text := "hello"
|
||||
bifrostReq := &schemas.BifrostEmbeddingRequest{
|
||||
Model: "embed-v4.0",
|
||||
Input: &schemas.EmbeddingInput{Text: &text},
|
||||
}
|
||||
|
||||
wireBody, bifrostErr := providerUtils.CheckContextAndGetRequestBody(
|
||||
context.Background(),
|
||||
bifrostReq,
|
||||
func() (providerUtils.RequestBodyWithExtraParams, error) {
|
||||
return ToCohereEmbeddingRequest(bifrostReq), nil
|
||||
},
|
||||
schemas.Cohere,
|
||||
)
|
||||
require.Nil(t, bifrostErr)
|
||||
assert.JSONEq(t, `{
|
||||
"model": "embed-v4.0",
|
||||
"input_type": "search_document",
|
||||
"texts": ["hello"]
|
||||
}`, string(wireBody))
|
||||
}
|
||||
21
core/providers/cohere/errors.go
Normal file
21
core/providers/cohere/errors.go
Normal file
@@ -0,0 +1,21 @@
|
||||
package cohere
|
||||
|
||||
import (
|
||||
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
func parseCohereError(resp *fasthttp.Response) *schemas.BifrostError {
|
||||
var errorResp CohereError
|
||||
bifrostErr := providerUtils.HandleProviderAPIError(resp, &errorResp)
|
||||
bifrostErr.Type = &errorResp.Type
|
||||
if bifrostErr.Error == nil {
|
||||
bifrostErr.Error = &schemas.ErrorField{}
|
||||
}
|
||||
bifrostErr.Error.Message = errorResp.Message
|
||||
if errorResp.Code != nil {
|
||||
bifrostErr.Error.Code = errorResp.Code
|
||||
}
|
||||
return bifrostErr
|
||||
}
|
||||
92
core/providers/cohere/models.go
Normal file
92
core/providers/cohere/models.go
Normal file
@@ -0,0 +1,92 @@
|
||||
package cohere
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
|
||||
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// CohereRerankRequest represents a Cohere rerank API request.
type CohereRerankRequest struct {
	Model           string   `json:"model"`
	Query           string   `json:"query"`
	Documents       []string `json:"documents"`
	TopN            *int     `json:"top_n,omitempty"`
	MaxTokensPerDoc *int     `json:"max_tokens_per_doc,omitempty"`
	Priority        *int     `json:"priority,omitempty"`
	// ExtraParams is excluded from the JSON encoding of this struct
	// (`json:"-"`); it is exposed through GetExtraParams.
	ExtraParams map[string]interface{} `json:"-"`
}

// GetExtraParams returns the extra parameters of the rerank request.
//
// It is safe to call on a nil receiver and returns nil in that case. The
// converters in this package return a nil *CohereRerankRequest for a nil
// input, and such a typed nil can end up behind an interface value whose
// GetExtraParams method is then called.
func (r *CohereRerankRequest) GetExtraParams() map[string]interface{} {
	if r == nil {
		return nil
	}
	return r.ExtraParams
}
|
||||
|
||||
// CohereRerankResult is one entry of a Cohere rerank response.
type CohereRerankResult struct {
	// Index is decoded from the "index" field of the result.
	Index int `json:"index"`
	// RelevanceScore is decoded from the "relevance_score" field.
	RelevanceScore float64 `json:"relevance_score"`
	// Document holds the raw, undecoded "document" payload; it is omitted
	// from the encoding when empty.
	Document json.RawMessage `json:"document,omitempty"`
}
|
||||
|
||||
// CohereRerankResponse represents a Cohere rerank API response.
|
||||
type CohereRerankResponse struct {
|
||||
ID string `json:"id"`
|
||||
Results []CohereRerankResult `json:"results"`
|
||||
Meta *CohereRerankMeta `json:"meta,omitempty"`
|
||||
}
|
||||
|
||||
// CohereRerankMeta represents metadata in Cohere rerank response.
|
||||
type CohereRerankMeta struct {
|
||||
APIVersion *CohereEmbeddingAPIVersion `json:"api_version,omitempty"`
|
||||
BilledUnits *CohereBilledUnits `json:"billed_units,omitempty"`
|
||||
Tokens *CohereTokenUsage `json:"tokens,omitempty"`
|
||||
}
|
||||
|
||||
func (response *CohereListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, allowedModels schemas.WhiteList, blacklistedModels schemas.BlackList, aliases map[string]string, unfiltered bool) *schemas.BifrostListModelsResponse {
|
||||
if response == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
bifrostResponse := &schemas.BifrostListModelsResponse{
|
||||
Data: make([]schemas.Model, 0, len(response.Models)),
|
||||
}
|
||||
|
||||
pipeline := &providerUtils.ListModelsPipeline{
|
||||
AllowedModels: allowedModels,
|
||||
BlacklistedModels: blacklistedModels,
|
||||
Aliases: aliases,
|
||||
Unfiltered: unfiltered,
|
||||
ProviderKey: providerKey,
|
||||
MatchFns: providerUtils.DefaultMatchFns(),
|
||||
}
|
||||
if pipeline.ShouldEarlyExit() {
|
||||
return bifrostResponse
|
||||
}
|
||||
|
||||
included := make(map[string]bool)
|
||||
|
||||
for _, model := range response.Models {
|
||||
// Cohere uses model.Name as the model identifier
|
||||
for _, result := range pipeline.FilterModel(model.Name) {
|
||||
entry := schemas.Model{
|
||||
ID: string(providerKey) + "/" + result.ResolvedID,
|
||||
Name: schemas.Ptr(model.Name),
|
||||
ContextLength: schemas.Ptr(int(model.ContextLength)),
|
||||
SupportedMethods: model.Endpoints,
|
||||
}
|
||||
if result.AliasValue != "" {
|
||||
entry.Alias = schemas.Ptr(result.AliasValue)
|
||||
}
|
||||
bifrostResponse.Data = append(bifrostResponse.Data, entry)
|
||||
included[strings.ToLower(result.ResolvedID)] = true
|
||||
}
|
||||
}
|
||||
|
||||
bifrostResponse.Data = append(bifrostResponse.Data,
|
||||
pipeline.BackfillModels(included)...)
|
||||
|
||||
return bifrostResponse
|
||||
}
|
||||
209
core/providers/cohere/rerank.go
Normal file
209
core/providers/cohere/rerank.go
Normal file
@@ -0,0 +1,209 @@
|
||||
package cohere
|
||||
|
||||
import (
|
||||
"sort"
|
||||
|
||||
"github.com/bytedance/sonic"
|
||||
"github.com/maximhq/bifrost/core/providers/utils"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// ToCohereRerankRequest converts a Bifrost rerank request to Cohere format
|
||||
func ToCohereRerankRequest(bifrostReq *schemas.BifrostRerankRequest) *CohereRerankRequest {
|
||||
if bifrostReq == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
cohereReq := &CohereRerankRequest{
|
||||
Model: bifrostReq.Model,
|
||||
Query: bifrostReq.Query,
|
||||
}
|
||||
|
||||
// Cohere v2 expects documents as a list of strings.
|
||||
documents := make([]string, len(bifrostReq.Documents))
|
||||
for i, doc := range bifrostReq.Documents {
|
||||
documents[i] = formatCohereRerankDocument(doc)
|
||||
}
|
||||
cohereReq.Documents = documents
|
||||
|
||||
if bifrostReq.Params != nil {
|
||||
cohereReq.TopN = bifrostReq.Params.TopN
|
||||
cohereReq.MaxTokensPerDoc = bifrostReq.Params.MaxTokensPerDoc
|
||||
cohereReq.Priority = bifrostReq.Params.Priority
|
||||
cohereReq.ExtraParams = bifrostReq.Params.ExtraParams
|
||||
}
|
||||
|
||||
return cohereReq
|
||||
}
|
||||
|
||||
// ToBifrostRerankRequest converts a Cohere rerank request to Bifrost format
|
||||
func (req *CohereRerankRequest) ToBifrostRerankRequest(ctx *schemas.BifrostContext) *schemas.BifrostRerankRequest {
|
||||
if req == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
provider, model := schemas.ParseModelString(req.Model, utils.CheckAndSetDefaultProvider(ctx, schemas.Cohere))
|
||||
|
||||
bifrostReq := &schemas.BifrostRerankRequest{
|
||||
Provider: provider,
|
||||
Model: model,
|
||||
Query: req.Query,
|
||||
Params: &schemas.RerankParameters{},
|
||||
}
|
||||
|
||||
// Convert documents
|
||||
for _, doc := range req.Documents {
|
||||
bifrostReq.Documents = append(bifrostReq.Documents, schemas.RerankDocument{
|
||||
Text: doc,
|
||||
})
|
||||
}
|
||||
|
||||
if req.TopN != nil {
|
||||
bifrostReq.Params.TopN = req.TopN
|
||||
}
|
||||
if req.MaxTokensPerDoc != nil {
|
||||
bifrostReq.Params.MaxTokensPerDoc = req.MaxTokensPerDoc
|
||||
}
|
||||
if req.Priority != nil {
|
||||
bifrostReq.Params.Priority = req.Priority
|
||||
}
|
||||
if req.ExtraParams != nil {
|
||||
bifrostReq.Params.ExtraParams = req.ExtraParams
|
||||
}
|
||||
|
||||
return bifrostReq
|
||||
}
|
||||
|
||||
// ToBifrostRerankResponse converts a Cohere rerank response to Bifrost format.
|
||||
func (response *CohereRerankResponse) ToBifrostRerankResponse(documents []schemas.RerankDocument, returnDocuments bool) *schemas.BifrostRerankResponse {
|
||||
if response == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
bifrostResponse := &schemas.BifrostRerankResponse{
|
||||
ID: response.ID,
|
||||
}
|
||||
|
||||
// Convert results
|
||||
for _, result := range response.Results {
|
||||
rerankResult := schemas.RerankResult{
|
||||
Index: result.Index,
|
||||
RelevanceScore: result.RelevanceScore,
|
||||
}
|
||||
|
||||
// Convert document if present
|
||||
if len(result.Document) > 0 {
|
||||
var docMap map[string]interface{}
|
||||
if err := sonic.Unmarshal(result.Document, &docMap); err == nil {
|
||||
doc := &schemas.RerankDocument{}
|
||||
populated := false
|
||||
if text, ok := docMap["text"].(string); ok {
|
||||
doc.Text = text
|
||||
populated = true
|
||||
}
|
||||
if id, ok := docMap["id"].(string); ok {
|
||||
doc.ID = &id
|
||||
populated = true
|
||||
}
|
||||
// Collect metadata: unwrap "metadata"/"meta" keys to avoid nesting
|
||||
meta := make(map[string]interface{})
|
||||
if rawMeta, ok := docMap["metadata"].(map[string]interface{}); ok {
|
||||
for k, v := range rawMeta {
|
||||
meta[k] = v
|
||||
}
|
||||
} else if rawMeta, ok := docMap["meta"].(map[string]interface{}); ok {
|
||||
for k, v := range rawMeta {
|
||||
meta[k] = v
|
||||
}
|
||||
}
|
||||
for k, v := range docMap {
|
||||
if k != "text" && k != "id" && k != "metadata" && k != "meta" {
|
||||
meta[k] = v
|
||||
}
|
||||
}
|
||||
if len(meta) > 0 {
|
||||
doc.Meta = meta
|
||||
populated = true
|
||||
}
|
||||
if populated {
|
||||
rerankResult.Document = doc
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bifrostResponse.Results = append(bifrostResponse.Results, rerankResult)
|
||||
}
|
||||
sort.SliceStable(bifrostResponse.Results, func(i, j int) bool {
|
||||
if bifrostResponse.Results[i].RelevanceScore == bifrostResponse.Results[j].RelevanceScore {
|
||||
return bifrostResponse.Results[i].Index < bifrostResponse.Results[j].Index
|
||||
}
|
||||
return bifrostResponse.Results[i].RelevanceScore > bifrostResponse.Results[j].RelevanceScore
|
||||
})
|
||||
if returnDocuments {
|
||||
for i := range bifrostResponse.Results {
|
||||
resultIndex := bifrostResponse.Results[i].Index
|
||||
if resultIndex >= 0 && resultIndex < len(documents) {
|
||||
bifrostResponse.Results[i].Document = schemas.Ptr(documents[resultIndex])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Convert usage information
|
||||
if response.Meta != nil {
|
||||
promptTokens := 0
|
||||
completionTokens := 0
|
||||
hasTokenUsage := false
|
||||
if response.Meta.Tokens != nil {
|
||||
if response.Meta.Tokens.InputTokens != nil {
|
||||
promptTokens = int(*response.Meta.Tokens.InputTokens)
|
||||
hasTokenUsage = true
|
||||
}
|
||||
if response.Meta.Tokens.OutputTokens != nil {
|
||||
completionTokens = int(*response.Meta.Tokens.OutputTokens)
|
||||
hasTokenUsage = true
|
||||
}
|
||||
} else if response.Meta.BilledUnits != nil {
|
||||
if response.Meta.BilledUnits.InputTokens != nil {
|
||||
promptTokens = int(*response.Meta.BilledUnits.InputTokens)
|
||||
hasTokenUsage = true
|
||||
}
|
||||
if response.Meta.BilledUnits.OutputTokens != nil {
|
||||
completionTokens = int(*response.Meta.BilledUnits.OutputTokens)
|
||||
hasTokenUsage = true
|
||||
}
|
||||
}
|
||||
if hasTokenUsage {
|
||||
bifrostResponse.Usage = &schemas.BifrostLLMUsage{
|
||||
PromptTokens: promptTokens,
|
||||
CompletionTokens: completionTokens,
|
||||
TotalTokens: promptTokens + completionTokens,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return bifrostResponse
|
||||
}
|
||||
|
||||
func formatCohereRerankDocument(doc schemas.RerankDocument) string {
|
||||
if doc.ID == nil && len(doc.Meta) == 0 {
|
||||
return doc.Text
|
||||
}
|
||||
|
||||
// Keep metadata/id available by encoding a structured string document.
|
||||
documentPayload := map[string]interface{}{
|
||||
"text": doc.Text,
|
||||
}
|
||||
if doc.ID != nil {
|
||||
documentPayload["id"] = *doc.ID
|
||||
}
|
||||
if len(doc.Meta) > 0 {
|
||||
documentPayload["metadata"] = doc.Meta
|
||||
}
|
||||
|
||||
encoded, err := yaml.Marshal(documentPayload)
|
||||
if err != nil {
|
||||
return doc.Text
|
||||
}
|
||||
return string(encoded)
|
||||
}
|
||||
72
core/providers/cohere/rerank_test.go
Normal file
72
core/providers/cohere/rerank_test.go
Normal file
@@ -0,0 +1,72 @@
|
||||
package cohere
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestCohereRerankResponseToBifrostRerankResponse(t *testing.T) {
|
||||
response := (&CohereRerankResponse{
|
||||
ID: "rerank-response-id",
|
||||
Results: []CohereRerankResult{
|
||||
{
|
||||
Index: 1,
|
||||
RelevanceScore: 0.62,
|
||||
Document: json.RawMessage(`{"text":"provider-doc-1","id":"doc-1","topic":"geography"}`),
|
||||
},
|
||||
{
|
||||
Index: 0,
|
||||
RelevanceScore: 0.91,
|
||||
Document: json.RawMessage(`{"text":"provider-doc-0"}`),
|
||||
},
|
||||
},
|
||||
}).ToBifrostRerankResponse(nil, false)
|
||||
|
||||
require.NotNil(t, response)
|
||||
assert.Equal(t, "rerank-response-id", response.ID)
|
||||
require.Len(t, response.Results, 2)
|
||||
assert.Equal(t, 0, response.Results[0].Index)
|
||||
assert.Equal(t, 1, response.Results[1].Index)
|
||||
require.NotNil(t, response.Results[0].Document)
|
||||
require.NotNil(t, response.Results[1].Document)
|
||||
assert.Equal(t, "provider-doc-0", response.Results[0].Document.Text)
|
||||
assert.Equal(t, "provider-doc-1", response.Results[1].Document.Text)
|
||||
require.NotNil(t, response.Results[1].Document.ID)
|
||||
assert.Equal(t, "doc-1", *response.Results[1].Document.ID)
|
||||
assert.Equal(t, "geography", response.Results[1].Document.Meta["topic"])
|
||||
}
|
||||
|
||||
func TestCohereRerankResponseToBifrostRerankResponseReturnDocuments(t *testing.T) {
|
||||
requestDocs := []schemas.RerankDocument{
|
||||
{Text: "request-doc-0"},
|
||||
{Text: "request-doc-1"},
|
||||
}
|
||||
|
||||
response := (&CohereRerankResponse{
|
||||
Results: []CohereRerankResult{
|
||||
{
|
||||
Index: 1,
|
||||
RelevanceScore: 0.62,
|
||||
Document: json.RawMessage(`{"text":"provider-doc-1"}`),
|
||||
},
|
||||
{
|
||||
Index: 0,
|
||||
RelevanceScore: 0.91,
|
||||
Document: json.RawMessage(`{"text":"provider-doc-0"}`),
|
||||
},
|
||||
},
|
||||
}).ToBifrostRerankResponse(requestDocs, true)
|
||||
|
||||
require.NotNil(t, response)
|
||||
require.Len(t, response.Results, 2)
|
||||
require.NotNil(t, response.Results[0].Document)
|
||||
require.NotNil(t, response.Results[1].Document)
|
||||
assert.Equal(t, 0, response.Results[0].Index)
|
||||
assert.Equal(t, 1, response.Results[1].Index)
|
||||
assert.Equal(t, "request-doc-0", response.Results[0].Document.Text)
|
||||
assert.Equal(t, "request-doc-1", response.Results[1].Document.Text)
|
||||
}
|
||||
1726
core/providers/cohere/responses.go
Normal file
1726
core/providers/cohere/responses.go
Normal file
File diff suppressed because it is too large
Load Diff
616
core/providers/cohere/types.go
Normal file
616
core/providers/cohere/types.go
Normal file
@@ -0,0 +1,616 @@
|
||||
package cohere
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/bytedance/sonic"
|
||||
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// Token constants used for reasoning budget calculations.
const (
	// MinimumReasoningMaxTokens is the minimum value for reasoning max tokens.
	MinimumReasoningMaxTokens = 1
	// DefaultCompletionMaxTokens is only used for relative reasoning max token
	// calculation; it is not passed in the request body by default.
	DefaultCompletionMaxTokens = 4096
)
|
||||
|
||||
// Limits for the text accepted by the tokenize API call.
// See https://docs.cohere.com/reference/tokenize#request
const (
	// cohereTokenizeMinTextLength is the lower bound on the text length.
	cohereTokenizeMinTextLength = 1
	// cohereTokenizeMaxTextLength is the upper bound on the text length (65536).
	cohereTokenizeMaxTextLength = 1 << 16
)
|
||||
|
||||
// ==================== REQUEST TYPES ====================
|
||||
|
||||
// CohereChatRequest represents a Cohere chat completion request
|
||||
type CohereChatRequest struct {
|
||||
Model string `json:"model"` // Required: Model to use for chat completion
|
||||
Messages []CohereMessage `json:"messages"` // Required: Array of message objects
|
||||
Tools []CohereChatRequestTool `json:"tools,omitempty"` // Optional: Tools available for the model
|
||||
ToolChoice *CohereToolChoice `json:"tool_choice,omitempty"` // Optional: Tool choice configuration
|
||||
Temperature *float64 `json:"temperature,omitempty"` // Optional: Sampling temperature
|
||||
P *float64 `json:"p,omitempty"` // Optional: Top-p sampling
|
||||
K *int `json:"k,omitempty"` // Optional: Top-k sampling
|
||||
MaxTokens *int `json:"max_tokens,omitempty"` // Optional: Maximum tokens to generate
|
||||
StopSequences []string `json:"stop_sequences,omitempty"` // Optional: Stop sequences
|
||||
FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"` // Optional: Frequency penalty
|
||||
PresencePenalty *float64 `json:"presence_penalty,omitempty"` // Optional: Presence penalty
|
||||
Stream *bool `json:"stream,omitempty"` // Optional: Enable streaming
|
||||
SafetyMode *string `json:"safety_mode,omitempty"` // Optional: Safety mode
|
||||
LogProbs *bool `json:"log_probs,omitempty"` // Optional: Log probabilities
|
||||
StrictToolChoice *bool `json:"strict_tool_choice,omitempty"` // Optional: Strict tool choice
|
||||
Thinking *CohereThinking `json:"thinking,omitempty"` // Optional: Reasoning configuration
|
||||
ResponseFormat *CohereResponseFormat `json:"response_format,omitempty"` // Optional: Format for the response
|
||||
ExtraParams map[string]interface{} `json:"-"` // Optional: Extra parameters
|
||||
}
|
||||
|
||||
// IsStreamingRequested implements the StreamingRequest interface
|
||||
func (r *CohereChatRequest) IsStreamingRequested() bool {
|
||||
return r.Stream != nil && *r.Stream
|
||||
}
|
||||
|
||||
func (r *CohereChatRequest) GetExtraParams() map[string]interface{} {
|
||||
return r.ExtraParams
|
||||
}
|
||||
|
||||
type CohereChatRequestTool struct {
|
||||
Type string `json:"type"` // always "function"
|
||||
Function CohereChatRequestFunction `json:"function"`
|
||||
}
|
||||
|
||||
// CohereChatRequestFunction describes a function exposed as a tool in a chat request.
type CohereChatRequestFunction struct {
	// Name is the function name.
	Name string `json:"name"`
	// Parameters holds the function parameters (annotated in the original as a JSON string).
	Parameters interface{} `json:"parameters,omitempty"`
	// Description is an optional function description.
	Description *string `json:"description,omitempty"`
}
|
||||
|
||||
// CohereMessage represents a message in Cohere format
|
||||
type CohereMessage struct {
|
||||
Role string `json:"role"` // Required: Message role (system, user, assistant, tool)
|
||||
Content *CohereMessageContent `json:"content,omitempty"` // Optional: Message content (string or array of content blocks)
|
||||
ToolCalls []CohereToolCall `json:"tool_calls,omitempty"` // Optional: Tool calls (for assistant messages)
|
||||
ToolCallID *string `json:"tool_call_id,omitempty"` // Optional: Tool call ID (for tool messages)
|
||||
ToolPlan *string `json:"tool_plan,omitempty"` // Optional: Chain-of-thought style reflection (assistant only)
|
||||
}
|
||||
|
||||
// CohereMessageContent represents flexible content that can be string or content blocks
|
||||
type CohereMessageContent struct {
|
||||
// Use custom marshaling to handle string or []CohereContentBlock
|
||||
StringContent *string `json:"-"`
|
||||
BlocksContent []CohereContentBlock `json:"-"`
|
||||
}
|
||||
|
||||
// MarshalJSON implements custom JSON marshaling for CohereMessageContent
|
||||
func (c *CohereMessageContent) MarshalJSON() ([]byte, error) {
|
||||
if c.StringContent != nil {
|
||||
return providerUtils.MarshalSorted(*c.StringContent)
|
||||
}
|
||||
if c.BlocksContent != nil {
|
||||
return providerUtils.MarshalSorted(c.BlocksContent)
|
||||
}
|
||||
return []byte("null"), nil
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements custom JSON unmarshaling for CohereMessageContent
|
||||
func (c *CohereMessageContent) UnmarshalJSON(data []byte) error {
|
||||
// Try to unmarshal as string first
|
||||
var str string
|
||||
if err := sonic.Unmarshal(data, &str); err == nil {
|
||||
c.StringContent = &str
|
||||
return nil
|
||||
}
|
||||
|
||||
// Try to unmarshal as content blocks array
|
||||
var blocks []CohereContentBlock
|
||||
if err := sonic.Unmarshal(data, &blocks); err == nil {
|
||||
c.BlocksContent = blocks
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("content must be either string or array of content blocks")
|
||||
}
|
||||
|
||||
// Helper methods for CohereMessageContent
|
||||
|
||||
// NewStringContent creates a CohereMessageContent with string content
|
||||
func NewStringContent(content string) *CohereMessageContent {
|
||||
return &CohereMessageContent{
|
||||
StringContent: &content,
|
||||
}
|
||||
}
|
||||
|
||||
// NewBlocksContent creates a CohereMessageContent with content blocks
|
||||
func NewBlocksContent(blocks []CohereContentBlock) *CohereMessageContent {
|
||||
return &CohereMessageContent{
|
||||
BlocksContent: blocks,
|
||||
}
|
||||
}
|
||||
|
||||
// IsString returns true if content is a string
|
||||
func (c *CohereMessageContent) IsString() bool {
|
||||
return c.StringContent != nil
|
||||
}
|
||||
|
||||
// IsBlocks returns true if content is content blocks
|
||||
func (c *CohereMessageContent) IsBlocks() bool {
|
||||
return c.BlocksContent != nil
|
||||
}
|
||||
|
||||
// GetString returns the string content (nil if not string)
|
||||
func (c *CohereMessageContent) GetString() *string {
|
||||
return c.StringContent
|
||||
}
|
||||
|
||||
// GetBlocks returns the content blocks (nil if not blocks)
|
||||
func (c *CohereMessageContent) GetBlocks() []CohereContentBlock {
|
||||
return c.BlocksContent
|
||||
}
|
||||
|
||||
// CohereContentBlockType identifies the kind of a content block.
type CohereContentBlockType string

// Content block kinds and their wire values.
const (
	CohereContentBlockTypeText     CohereContentBlockType = "text"
	CohereContentBlockTypeImage    CohereContentBlockType = "image_url"
	CohereContentBlockTypeThinking CohereContentBlockType = "thinking"
	CohereContentBlockTypeDocument CohereContentBlockType = "document"
)
|
||||
|
||||
// CohereContentBlock represents a content block in Cohere format
|
||||
// This is a union type that can be text, image_url, thinking, or document
|
||||
type CohereContentBlock struct {
|
||||
Type CohereContentBlockType `json:"type"` // Required: Content block type
|
||||
|
||||
// Text content block
|
||||
Text *string `json:"text,omitempty"`
|
||||
|
||||
// Image URL content block
|
||||
ImageURL *CohereImageURL `json:"image_url,omitempty"`
|
||||
|
||||
// Thinking content block (assistant only)
|
||||
Thinking *string `json:"thinking,omitempty"`
|
||||
|
||||
// Document content block (tool messages)
|
||||
Document *CohereDocument `json:"document,omitempty"`
|
||||
}
|
||||
|
||||
// CohereImageURL is the payload of an image URL content block.
type CohereImageURL struct {
	// URL is required: the image URL.
	URL string `json:"url"`
}
|
||||
|
||||
// CohereDocument represents a document content block
|
||||
type CohereDocument struct {
|
||||
Data schemas.OrderedMap `json:"data"` // Required: Document data as key-value pairs
|
||||
ID *string `json:"id,omitempty"` // Optional: Document ID for citations
|
||||
}
|
||||
|
||||
// CohereThinking represents reasoning configuration
|
||||
type CohereThinking struct {
|
||||
Type CohereThinkingType `json:"type"` // Required: Reasoning type (enabled, disabled)
|
||||
TokenBudget *int `json:"token_budget,omitempty"` // Optional: Maximum thinking tokens (>=1)
|
||||
}
|
||||
|
||||
// CohereThinkingType is the type of reasoning.
type CohereThinkingType string

// Reasoning types and their wire values.
const (
	ThinkingTypeEnabled  CohereThinkingType = "enabled"
	ThinkingTypeDisabled CohereThinkingType = "disabled"
)
|
||||
|
||||
// CohereResponseFormat represents the response format configuration for Cohere chat requests
|
||||
type CohereResponseFormat struct {
|
||||
Type CohereResponseFormatType `json:"type"` // Required: Response format type
|
||||
JSONSchema *interface{} `json:"schema,omitempty"` // Optional: JSON schema for structured output (not used when type is "text")
|
||||
}
|
||||
|
||||
// CohereResponseFormatType is the type of response format.
type CohereResponseFormatType string

// Response format types and their wire values.
const (
	ResponseFormatTypeText       CohereResponseFormatType = "text"
	ResponseFormatTypeJSONObject CohereResponseFormatType = "json_object"
)
|
||||
|
||||
// CohereToolChoice is the tool choice configuration.
type CohereToolChoice string

// Tool choice values as sent on the wire.
const (
	ToolChoiceRequired CohereToolChoice = "REQUIRED"
	ToolChoiceNone     CohereToolChoice = "NONE"
	ToolChoiceAuto     CohereToolChoice = "AUTO"
)
|
||||
|
||||
// CohereToolCall represents a tool call in Cohere format
|
||||
type CohereToolCall struct {
|
||||
ID *string `json:"id,omitempty"` // Optional: Tool call ID
|
||||
Type string `json:"type"` // Required: Tool call type (must be "function")
|
||||
Function *CohereFunction `json:"function"` // Required: Function call details
|
||||
}
|
||||
|
||||
// CohereFunction is the function part of a tool call.
type CohereFunction struct {
	// Name is the optional function name.
	Name *string `json:"name,omitempty"`
	// Arguments holds the optional function arguments (JSON string).
	Arguments string `json:"arguments,omitempty"`
}
|
||||
|
||||
// CohereParameterDefinition is a parameter definition for a Cohere tool.
// It carries the parameter's type, its description, and whether it is required.
type CohereParameterDefinition struct {
	// Type is the type of the parameter.
	Type string `json:"type"`
	// Description is an optional description of the parameter.
	Description *string `json:"description,omitempty"`
	// Required tells whether the parameter is required.
	Required bool `json:"required"`
}
|
||||
|
||||
// CohereTool represents a tool definition for the Cohere API.
|
||||
// It includes the tool's name, description, and parameter definitions.
|
||||
type CohereTool struct {
|
||||
Name string `json:"name"` // Name of the tool
|
||||
Description string `json:"description"` // Description of the tool
|
||||
ParameterDefinitions map[string]CohereParameterDefinition `json:"parameter_definitions"` // Definitions of the tool's parameters
|
||||
}
|
||||
|
||||
// CohereCountTokensRequest is the request payload for the Cohere tokenize endpoint.
type CohereCountTokensRequest struct {
	// Model is required: the model whose tokenizer should be used.
	Model string `json:"model"`
	// Text is required: the text to tokenize (1-65536 chars).
	Text string `json:"text"`
	// ExtraParams optionally holds extra parameters; the field is excluded from JSON encoding.
	ExtraParams map[string]interface{} `json:"-"`
}
|
||||
|
||||
func (r *CohereCountTokensRequest) GetExtraParams() map[string]interface{} {
|
||||
return r.ExtraParams
|
||||
}
|
||||
|
||||
// CohereEmbeddingRequest represents a Cohere embedding request
|
||||
type CohereEmbeddingRequest struct {
|
||||
Model string `json:"model"` // Required: ID of embedding model
|
||||
InputType string `json:"input_type"` // Required: Type of input for v3+ models
|
||||
Texts []string `json:"texts,omitempty"` // Optional: Array of strings to embed (max 96)
|
||||
Images []string `json:"images,omitempty"` // Optional: Array of image data URIs (max 1)
|
||||
Inputs []CohereEmbeddingInput `json:"inputs,omitempty"` // Optional: Array of mixed text/image inputs (max 96)
|
||||
MaxTokens *int `json:"max_tokens,omitempty"` // Optional: Max tokens to embed per input
|
||||
OutputDimension *int `json:"output_dimension,omitempty"` // Optional: Embedding dimensions (256, 512, 1024, 1536)
|
||||
EmbeddingTypes []string `json:"embedding_types,omitempty"` // Optional: Types of embeddings to return
|
||||
Truncate *string `json:"truncate,omitempty"` // Optional: How to handle long inputs
|
||||
ExtraParams map[string]interface{} `json:"-"` // Optional: Extra parameters
|
||||
}
|
||||
|
||||
func (r *CohereEmbeddingRequest) GetExtraParams() map[string]interface{} {
|
||||
return r.ExtraParams
|
||||
}
|
||||
|
||||
// CohereEmbeddingInput represents a mixed text/image input
|
||||
type CohereEmbeddingInput struct {
|
||||
Content []CohereContentBlock `json:"content"` // Required: Array of content blocks (reuses chat content blocks)
|
||||
}
|
||||
|
||||
// CohereEmbeddingResponse represents a Cohere embedding response
|
||||
type CohereEmbeddingResponse struct {
|
||||
ID string `json:"id"` // Response ID
|
||||
Embeddings *CohereEmbeddingData `json:"embeddings,omitempty"` // Embedding data object
|
||||
ResponseType *string `json:"response_type,omitempty"` // Response type (embeddings_floats, embeddings_by_type)
|
||||
Texts []string `json:"texts,omitempty"` // Original text entries
|
||||
Images []CohereEmbeddingImageInfo `json:"images,omitempty"` // Original image entries
|
||||
Meta *CohereEmbeddingMeta `json:"meta,omitempty"` // Response metadata
|
||||
}
|
||||
|
||||
// CohereEmbeddingData is the embeddings object, with one field per embedding type.
type CohereEmbeddingData struct {
	// Float holds float embeddings.
	Float [][]float64 `json:"float,omitempty"`
	// Int8 holds int8 embeddings.
	Int8 [][]int8 `json:"int8,omitempty"`
	// Uint8 holds uint8 embeddings.
	Uint8 [][]uint8 `json:"uint8,omitempty"`
	// Binary holds binary embeddings.
	Binary [][]int8 `json:"binary,omitempty"`
	// Ubinary holds unsigned binary embeddings.
	Ubinary [][]uint8 `json:"ubinary,omitempty"`
	// Base64 holds base64 embeddings.
	Base64 []string `json:"base64,omitempty"`
}
|
||||
|
||||
// CohereEmbeddingImageInfo is the image information returned in an embedding response.
type CohereEmbeddingImageInfo struct {
	// Width is the width in pixels.
	Width int64 `json:"width"`
	// Height is the height in pixels.
	Height int64 `json:"height"`
	// Format is the image format.
	Format string `json:"format"`
	// BitDepth is the bit depth.
	BitDepth int64 `json:"bit_depth"`
}
|
||||
|
||||
// CohereEmbeddingMeta represents metadata in embedding response
|
||||
type CohereEmbeddingMeta struct {
|
||||
APIVersion *CohereEmbeddingAPIVersion `json:"api_version,omitempty"` // API version info
|
||||
BilledUnits *CohereBilledUnits `json:"billed_units,omitempty"` // Billing information
|
||||
Tokens *CohereTokenUsage `json:"tokens,omitempty"` // Token usage
|
||||
Warnings []string `json:"warnings,omitempty"` // Any warnings
|
||||
}
|
||||
|
||||
// CohereEmbeddingAPIVersion is the API version information of an embedding response.
type CohereEmbeddingAPIVersion struct {
	// Version is the API version.
	Version *string `json:"version,omitempty"`
	// IsDeprecated is the deprecation status.
	IsDeprecated *bool `json:"is_deprecated,omitempty"`
	// IsExperimental is the experimental status.
	IsExperimental *bool `json:"is_experimental,omitempty"`
}
|
||||
|
||||
// ==================== RESPONSE TYPES ====================
|
||||
|
||||
// CohereCountTokensResponse represents the response from the tokenize endpoint
|
||||
type CohereCountTokensResponse struct {
|
||||
Tokens []int `json:"tokens"`
|
||||
TokenStrings []string `json:"token_strings,omitempty"`
|
||||
Meta *CohereTokenizeMeta `json:"meta,omitempty"`
|
||||
}
|
||||
|
||||
// CohereTokenizeMeta captures metadata returned by the tokenize endpoint
|
||||
type CohereTokenizeMeta struct {
|
||||
APIVersion *CohereTokenizeAPIVersion `json:"api_version,omitempty"`
|
||||
}
|
||||
|
||||
// CohereTokenizeAPIVersion holds the API version metadata of a tokenize response.
type CohereTokenizeAPIVersion struct {
	// Version is decoded from the optional "version" field.
	Version *string `json:"version,omitempty"`
}
|
||||
|
||||
// CohereChatResponse represents a Cohere chat completion response
|
||||
type CohereChatResponse struct {
|
||||
ID string `json:"id"` // Unique identifier for the generated reply
|
||||
FinishReason *CohereFinishReason `json:"finish_reason,omitempty"` // Reason for completion
|
||||
Message *CohereMessage `json:"message,omitempty"` // Generated message from assistant
|
||||
Usage *CohereUsage `json:"usage,omitempty"` // Token usage information
|
||||
LogProbs []CohereLogProb `json:"logprobs,omitempty"` // Log probabilities (if requested)
|
||||
}
|
||||
|
||||
// CohereFinishReason is the reason a chat request has finished.
type CohereFinishReason string

// Finish reasons and their wire values.
const (
	// FinishReasonComplete: the model finished sending a complete message.
	FinishReasonComplete CohereFinishReason = "COMPLETE"
	// FinishReasonStopSequence: a stop sequence was reached.
	FinishReasonStopSequence CohereFinishReason = "STOP_SEQUENCE"
	// FinishReasonMaxTokens: max tokens exceeded.
	FinishReasonMaxTokens CohereFinishReason = "MAX_TOKENS"
	// FinishReasonToolCall: the model generated a tool call.
	FinishReasonToolCall CohereFinishReason = "TOOL_CALL"
	// FinishReasonError: generation failed due to an internal error.
	FinishReasonError CohereFinishReason = "ERROR"
	// FinishReasonTimeout: timeout.
	FinishReasonTimeout CohereFinishReason = "TIMEOUT"
)
|
||||
|
||||
// CohereUsage represents token usage information
|
||||
type CohereUsage struct {
|
||||
BilledUnits *CohereBilledUnits `json:"billed_units,omitempty"` // Billed usage information
|
||||
Tokens *CohereTokenUsage `json:"tokens,omitempty"` // Token usage details
|
||||
CachedTokens *int `json:"cached_tokens,omitempty"` // Cached tokens
|
||||
}
|
||||
|
||||
// CohereBilledUnits is the billed usage information.
type CohereBilledUnits struct {
	// InputTokens is the number of billed input tokens.
	InputTokens *int `json:"input_tokens,omitempty"`
	// OutputTokens is the number of billed output tokens.
	OutputTokens *int `json:"output_tokens,omitempty"`
	// SearchUnits is the number of billed search units.
	SearchUnits *int `json:"search_units,omitempty"`
	// Classifications is the number of billed classification units.
	Classifications *int `json:"classifications,omitempty"`
}
|
||||
|
||||
// CohereTokenUsage is the detailed token usage.
type CohereTokenUsage struct {
	// InputTokens is the number of input tokens used.
	InputTokens *int `json:"input_tokens"`
	// OutputTokens is the number of output tokens produced.
	OutputTokens *int `json:"output_tokens"`
}
|
||||
|
||||
// CohereLogProb is the log probability information for one text chunk.
type CohereLogProb struct {
	// TokenIDs holds the token IDs of each token in the text chunk.
	TokenIDs []int `json:"token_ids"`
	// Text is the text chunk for the log probabilities.
	Text *string `json:"text,omitempty"`
	// LogProbs holds the log probability of each token.
	LogProbs []float64 `json:"logprobs,omitempty"`
}
|
||||
|
||||
// CohereCitationType is the type of a citation.
type CohereCitationType string

// Citation types and their wire values.
const (
	CitationTypeTextContent     CohereCitationType = "TEXT_CONTENT"
	CitationTypeThinkingContent CohereCitationType = "THINKING_CONTENT"
	CitationTypePlan            CohereCitationType = "PLAN"
)
|
||||
|
||||
// CohereSourceType is the type of a citation source.
type CohereSourceType string

// Citation source types and their wire values.
const (
	SourceTypeTool     CohereSourceType = "tool"
	SourceTypeDocument CohereSourceType = "document"
)
|
||||
|
||||
// CohereCitation represents a citation in the response
|
||||
type CohereCitation struct {
|
||||
Start int `json:"start"` // Start position of cited text
|
||||
End int `json:"end"` // End position of cited text
|
||||
Text string `json:"text"` // Cited text
|
||||
Sources []CohereSource `json:"sources,omitempty"` // Citation sources
|
||||
ContentIndex int `json:"content_index"` // Content index of the citation
|
||||
Type CohereCitationType `json:"type"` // Type of citation
|
||||
}
|
||||
|
||||
// CohereSource represents a citation source
|
||||
type CohereSource struct {
|
||||
Type CohereSourceType `json:"type"` // Source type ("tool" or "document")
|
||||
ID *string `json:"id,omitempty"` // Source ID (nullable)
|
||||
ToolOutput *json.RawMessage `json:"tool_output,omitempty"` // Tool output (for tool sources, json.RawMessage preserves key ordering)
|
||||
Document *json.RawMessage `json:"document,omitempty"` // Document data (for document sources, json.RawMessage preserves key ordering)
|
||||
}
|
||||
|
||||
// ==================== STREAMING TYPES ====================
|
||||
|
||||
// CohereStreamEventType is the type of a streaming event.
type CohereStreamEventType string

// Streaming event types and their wire values.
const (
	StreamEventMessageStart  CohereStreamEventType = "message-start"
	StreamEventContentStart  CohereStreamEventType = "content-start"
	StreamEventContentDelta  CohereStreamEventType = "content-delta"
	StreamEventContentEnd    CohereStreamEventType = "content-end"
	StreamEventToolPlanDelta CohereStreamEventType = "tool-plan-delta"
	StreamEventToolCallStart CohereStreamEventType = "tool-call-start"
	StreamEventToolCallDelta CohereStreamEventType = "tool-call-delta"
	StreamEventToolCallEnd   CohereStreamEventType = "tool-call-end"
	StreamEventCitationStart CohereStreamEventType = "citation-start"
	StreamEventCitationEnd   CohereStreamEventType = "citation-end"
	StreamEventMessageEnd    CohereStreamEventType = "message-end"
	StreamEventDebug         CohereStreamEventType = "debug"
)
|
||||
|
||||
// CohereStreamEvent represents a unified streaming event from Cohere API
|
||||
type CohereStreamEvent struct {
|
||||
Type CohereStreamEventType `json:"type"`
|
||||
ID *string `json:"id,omitempty"` // For message-start
|
||||
Index *int `json:"index,omitempty"` // For indexed events
|
||||
Delta *CohereStreamDelta `json:"delta,omitempty"`
|
||||
}
|
||||
|
||||
// CohereStreamDelta represents the delta content in streaming events
|
||||
type CohereStreamDelta struct {
|
||||
Message *CohereStreamMessage `json:"message,omitempty"`
|
||||
FinishReason *CohereFinishReason `json:"finish_reason,omitempty"`
|
||||
Usage *CohereUsage `json:"usage,omitempty"`
|
||||
}
|
||||
|
||||
type CohereStreamToolCallStruct struct {
|
||||
CohereToolCallObject *CohereToolCall
|
||||
CohereToolCallArray []CohereToolCall
|
||||
}
|
||||
|
||||
// JSON marshaling for CohereStreamToolCall
|
||||
func (c *CohereStreamToolCallStruct) MarshalJSON() ([]byte, error) {
|
||||
if c.CohereToolCallObject != nil {
|
||||
return providerUtils.MarshalSorted(c.CohereToolCallObject)
|
||||
}
|
||||
if c.CohereToolCallArray != nil {
|
||||
return providerUtils.MarshalSorted(c.CohereToolCallArray)
|
||||
}
|
||||
return providerUtils.MarshalSorted(nil)
|
||||
}
|
||||
|
||||
func (c *CohereStreamToolCallStruct) UnmarshalJSON(data []byte) error {
|
||||
if string(data) == "null" {
|
||||
return nil
|
||||
}
|
||||
// Try to unmarshal as array first
|
||||
var toolCallArray []CohereToolCall
|
||||
if err := sonic.Unmarshal(data, &toolCallArray); err == nil {
|
||||
c.CohereToolCallArray = toolCallArray
|
||||
return nil
|
||||
}
|
||||
|
||||
// Try to unmarshal as single object
|
||||
var toolCallObject CohereToolCall
|
||||
if err := sonic.Unmarshal(data, &toolCallObject); err == nil {
|
||||
c.CohereToolCallObject = &toolCallObject
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("tool_calls field is neither array nor object")
|
||||
}
|
||||
|
||||
type CohereStreamContentStruct struct {
|
||||
CohereStreamContentObject *CohereStreamContent
|
||||
CohereStreamContentArray []CohereStreamContent
|
||||
}
|
||||
|
||||
func (c *CohereStreamContentStruct) MarshalJSON() ([]byte, error) {
|
||||
if c.CohereStreamContentObject != nil {
|
||||
return providerUtils.MarshalSorted(c.CohereStreamContentObject)
|
||||
}
|
||||
if c.CohereStreamContentArray != nil {
|
||||
return providerUtils.MarshalSorted(c.CohereStreamContentArray)
|
||||
}
|
||||
return providerUtils.MarshalSorted(nil)
|
||||
}
|
||||
|
||||
func (c *CohereStreamContentStruct) UnmarshalJSON(data []byte) error {
|
||||
if string(data) == "null" {
|
||||
return nil
|
||||
}
|
||||
// Try to unmarshal as array first
|
||||
var contentArray []CohereStreamContent
|
||||
if err := sonic.Unmarshal(data, &contentArray); err == nil {
|
||||
c.CohereStreamContentArray = contentArray
|
||||
return nil
|
||||
}
|
||||
|
||||
// Try to unmarshal as single object
|
||||
var contentObject CohereStreamContent
|
||||
if err := sonic.Unmarshal(data, &contentObject); err == nil {
|
||||
c.CohereStreamContentObject = &contentObject
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("content field is neither array nor object")
|
||||
}
|
||||
|
||||
type CohereStreamCitationStruct struct {
|
||||
CohereStreamCitationObject *CohereCitation
|
||||
CohereStreamCitationArray []CohereCitation
|
||||
}
|
||||
|
||||
func (c *CohereStreamCitationStruct) MarshalJSON() ([]byte, error) {
|
||||
if c.CohereStreamCitationObject != nil {
|
||||
return providerUtils.MarshalSorted(c.CohereStreamCitationObject)
|
||||
}
|
||||
if c.CohereStreamCitationArray != nil {
|
||||
return providerUtils.MarshalSorted(c.CohereStreamCitationArray)
|
||||
}
|
||||
return providerUtils.MarshalSorted(nil)
|
||||
}
|
||||
|
||||
func (c *CohereStreamCitationStruct) UnmarshalJSON(data []byte) error {
|
||||
if string(data) == "null" {
|
||||
return nil
|
||||
}
|
||||
// Try to unmarshal as array first
|
||||
var citationArray []CohereCitation
|
||||
if err := sonic.Unmarshal(data, &citationArray); err == nil {
|
||||
c.CohereStreamCitationArray = citationArray
|
||||
return nil
|
||||
}
|
||||
|
||||
// Try to unmarshal as single object
|
||||
var citationObject CohereCitation
|
||||
if err := sonic.Unmarshal(data, &citationObject); err == nil {
|
||||
c.CohereStreamCitationObject = &citationObject
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("citations field is neither array nor object")
|
||||
}
|
||||
|
||||
// CohereStreamMessage represents the message part of streaming deltas
|
||||
type CohereStreamMessage struct {
|
||||
Role *string `json:"role,omitempty"` // For message-start
|
||||
Content *CohereStreamContentStruct `json:"content,omitempty"` // For content events (object)
|
||||
ToolPlan *string `json:"tool_plan,omitempty"` // For tool-plan-delta
|
||||
ToolCalls *CohereStreamToolCallStruct `json:"tool_calls,omitempty"` // For tool-call events (flexible)
|
||||
Citations *CohereStreamCitationStruct `json:"citations,omitempty"` // For citation events
|
||||
}
|
||||
|
||||
// CohereStreamContent represents content in streaming events
|
||||
type CohereStreamContent struct {
|
||||
Type CohereContentBlockType `json:"type,omitempty"` // For content-start
|
||||
Text *string `json:"text,omitempty"` // For content deltas
|
||||
Thinking *string `json:"thinking,omitempty"` // For thinking deltas
|
||||
}
|
||||
|
||||
// ==================== ERROR TYPES ====================
|
||||
|
||||
// CohereError is the JSON error payload returned by the Cohere API.
type CohereError struct {
	// Type is the error type.
	Type string `json:"type"`
	// Message is the error message.
	Message string `json:"message"`
	// Code is an optional error code; it is omitted from JSON when nil.
	Code *string `json:"code,omitempty"`
}
|
||||
|
||||
// ==================== MODEL TYPES ====================
|
||||
|
||||
// CohereModel describes one model entry of a Cohere list-models response.
type CohereModel struct {
	// Name is decoded from the "name" key.
	Name string `json:"name"`
	// IsDeprecated is decoded from the "is_deprecated" key.
	IsDeprecated bool `json:"is_deprecated"`
	// Endpoints is decoded from the "endpoints" key.
	Endpoints []string `json:"endpoints"`
	// Finetuned is decoded from the "finetuned" key.
	Finetuned bool `json:"finetuned"`
	// ContextLength is decoded from the "context_length" key.
	ContextLength int `json:"context_length"`
	// TokenizerURL is decoded from the "tokenizer_url" key.
	TokenizerURL string `json:"tokenizer_url"`
	// DefaultEndpoints is decoded from the "default_endpoints" key.
	DefaultEndpoints []string `json:"default_endpoints"`
	// Features is decoded from the "features" key.
	Features []string `json:"features"`
}
|
||||
|
||||
type CohereListModelsResponse struct {
|
||||
Models []CohereModel `json:"models"`
|
||||
NextPageToken string `json:"next_page_token"`
|
||||
}
|
||||
293
core/providers/cohere/utils.go
Normal file
293
core/providers/cohere/utils.go
Normal file
@@ -0,0 +1,293 @@
|
||||
package cohere
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
var (
|
||||
// Maps provider-specific finish reasons to Bifrost format
|
||||
cohereFinishReasonToBifrost = map[CohereFinishReason]string{
|
||||
FinishReasonComplete: "stop",
|
||||
FinishReasonStopSequence: "stop",
|
||||
FinishReasonMaxTokens: "length",
|
||||
FinishReasonToolCall: "tool_calls",
|
||||
}
|
||||
)
|
||||
|
||||
// ConvertCohereFinishReasonToBifrost converts provider finish reasons to Bifrost format
|
||||
func ConvertCohereFinishReasonToBifrost(providerReason CohereFinishReason) string {
|
||||
if bifrostReason, ok := cohereFinishReasonToBifrost[providerReason]; ok {
|
||||
return bifrostReason
|
||||
}
|
||||
return string(providerReason)
|
||||
}
|
||||
|
||||
// convertInterfaceToToolFunctionParameters converts an interface{} to ToolFunctionParameters
|
||||
// This handles the conversion from Cohere's flexible parameter format to Bifrost's structured format
|
||||
func convertInterfaceToToolFunctionParameters(params interface{}) *schemas.ToolFunctionParameters {
|
||||
if params == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Try to convert from map[string]interface{}
|
||||
paramsMap, ok := params.(map[string]interface{})
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
result := &schemas.ToolFunctionParameters{}
|
||||
|
||||
// Extract type
|
||||
if typeVal, ok := paramsMap["type"].(string); ok {
|
||||
result.Type = typeVal
|
||||
}
|
||||
|
||||
// Extract description
|
||||
if descVal, ok := paramsMap["description"].(string); ok {
|
||||
result.Description = &descVal
|
||||
}
|
||||
|
||||
// Extract required
|
||||
if requiredVal, ok := paramsMap["required"].([]interface{}); ok {
|
||||
required := make([]string, 0, len(requiredVal))
|
||||
for _, v := range requiredVal {
|
||||
if s, ok := v.(string); ok {
|
||||
required = append(required, s)
|
||||
}
|
||||
}
|
||||
result.Required = required
|
||||
}
|
||||
|
||||
// Extract properties
|
||||
if orderedProps, ok := schemas.SafeExtractOrderedMap(paramsMap["properties"]); ok {
|
||||
result.Properties = orderedProps
|
||||
}
|
||||
|
||||
// Extract enum
|
||||
if enumVal, ok := paramsMap["enum"].([]interface{}); ok {
|
||||
enum := make([]string, 0, len(enumVal))
|
||||
for _, v := range enumVal {
|
||||
if s, ok := v.(string); ok {
|
||||
enum = append(enum, s)
|
||||
}
|
||||
}
|
||||
result.Enum = enum
|
||||
}
|
||||
|
||||
// Extract additionalProperties
|
||||
if addPropsVal, ok := paramsMap["additionalProperties"].(bool); ok {
|
||||
result.AdditionalProperties = &schemas.AdditionalPropertiesStruct{
|
||||
AdditionalPropertiesBool: &addPropsVal,
|
||||
}
|
||||
}
|
||||
|
||||
if addPropsVal, ok := schemas.SafeExtractOrderedMap(paramsMap["additionalProperties"]); ok {
|
||||
result.AdditionalProperties = &schemas.AdditionalPropertiesStruct{
|
||||
AdditionalPropertiesMap: addPropsVal,
|
||||
}
|
||||
}
|
||||
|
||||
// Extract $defs (JSON Schema draft 2019-09+)
|
||||
if defsVal, ok := schemas.SafeExtractOrderedMap(paramsMap["$defs"]); ok {
|
||||
result.Defs = defsVal
|
||||
}
|
||||
|
||||
// Extract definitions (legacy JSON Schema draft-07)
|
||||
if defsVal, ok := schemas.SafeExtractOrderedMap(paramsMap["definitions"]); ok {
|
||||
result.Definitions = defsVal
|
||||
}
|
||||
|
||||
// Extract $ref
|
||||
if refVal, ok := paramsMap["$ref"].(string); ok {
|
||||
result.Ref = &refVal
|
||||
}
|
||||
|
||||
// Extract items (array element schema)
|
||||
if itemsVal, ok := schemas.SafeExtractOrderedMap(paramsMap["items"]); ok {
|
||||
result.Items = itemsVal
|
||||
}
|
||||
|
||||
// Extract minItems
|
||||
if minItemsVal, ok := extractInt64(paramsMap["minItems"]); ok {
|
||||
result.MinItems = &minItemsVal
|
||||
}
|
||||
|
||||
// Extract maxItems
|
||||
if maxItemsVal, ok := extractInt64(paramsMap["maxItems"]); ok {
|
||||
result.MaxItems = &maxItemsVal
|
||||
}
|
||||
|
||||
// Extract anyOf
|
||||
if anyOfVal, ok := paramsMap["anyOf"].([]interface{}); ok {
|
||||
anyOf := make([]schemas.OrderedMap, 0, len(anyOfVal))
|
||||
for _, v := range anyOfVal {
|
||||
if m, ok := schemas.SafeExtractOrderedMap(v); ok {
|
||||
anyOf = append(anyOf, *m)
|
||||
}
|
||||
}
|
||||
result.AnyOf = anyOf
|
||||
}
|
||||
|
||||
// Extract oneOf
|
||||
if oneOfVal, ok := paramsMap["oneOf"].([]interface{}); ok {
|
||||
oneOf := make([]schemas.OrderedMap, 0, len(oneOfVal))
|
||||
for _, v := range oneOfVal {
|
||||
if m, ok := schemas.SafeExtractOrderedMap(v); ok {
|
||||
oneOf = append(oneOf, *m)
|
||||
}
|
||||
}
|
||||
result.OneOf = oneOf
|
||||
}
|
||||
|
||||
// Extract allOf
|
||||
if allOfVal, ok := paramsMap["allOf"].([]interface{}); ok {
|
||||
allOf := make([]schemas.OrderedMap, 0, len(allOfVal))
|
||||
for _, v := range allOfVal {
|
||||
if m, ok := schemas.SafeExtractOrderedMap(v); ok {
|
||||
allOf = append(allOf, *m)
|
||||
}
|
||||
}
|
||||
result.AllOf = allOf
|
||||
}
|
||||
|
||||
// Extract format
|
||||
if formatVal, ok := paramsMap["format"].(string); ok {
|
||||
result.Format = &formatVal
|
||||
}
|
||||
|
||||
// Extract pattern
|
||||
if patternVal, ok := paramsMap["pattern"].(string); ok {
|
||||
result.Pattern = &patternVal
|
||||
}
|
||||
|
||||
// Extract minLength
|
||||
if minLengthVal, ok := extractInt64(paramsMap["minLength"]); ok {
|
||||
result.MinLength = &minLengthVal
|
||||
}
|
||||
|
||||
// Extract maxLength
|
||||
if maxLengthVal, ok := extractInt64(paramsMap["maxLength"]); ok {
|
||||
result.MaxLength = &maxLengthVal
|
||||
}
|
||||
|
||||
// Extract minimum
|
||||
if minVal, ok := extractFloat64(paramsMap["minimum"]); ok {
|
||||
result.Minimum = &minVal
|
||||
}
|
||||
|
||||
// Extract maximum
|
||||
if maxVal, ok := extractFloat64(paramsMap["maximum"]); ok {
|
||||
result.Maximum = &maxVal
|
||||
}
|
||||
|
||||
// Extract title
|
||||
if titleVal, ok := paramsMap["title"].(string); ok {
|
||||
result.Title = &titleVal
|
||||
}
|
||||
|
||||
// Extract default
|
||||
if defaultVal, exists := paramsMap["default"]; exists {
|
||||
result.Default = defaultVal
|
||||
}
|
||||
|
||||
// Extract nullable
|
||||
if nullableVal, ok := paramsMap["nullable"].(bool); ok {
|
||||
result.Nullable = &nullableVal
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// extractInt64 converts a loosely typed numeric value to int64.
//
// int and int64 values are returned as-is. float32 and float64 values are
// truncated toward zero, but only when they lie inside the int64 range: NaN,
// ±Inf and out-of-range magnitudes report false, because Go leaves the result
// of such a float-to-integer conversion implementation-dependent. Any other
// dynamic type (including a nil interface) reports false.
func extractInt64(v interface{}) (int64, bool) {
	var f float64
	switch val := v.(type) {
	case int:
		return int64(val), true
	case int64:
		return val, true
	case float64:
		f = val
	case float32:
		f = float64(val)
	default:
		return 0, false
	}

	// -2^63 and 2^63 are exactly representable as float64, whereas MaxInt64 is
	// not, hence the half-open interval. NaN fails both comparisons.
	const lowerBound, upperBound = -9223372036854775808.0, 9223372036854775808.0
	if !(f >= lowerBound && f < upperBound) {
		return 0, false
	}
	return int64(f), true
}
|
||||
|
||||
// extractFloat64 converts a loosely typed numeric value to float64. It accepts
// float64, float32, int and int64; any other dynamic type (including a nil
// interface) yields (0, false).
func extractFloat64(v interface{}) (float64, bool) {
	if f, ok := v.(float64); ok {
		return f, true
	}
	if f, ok := v.(float32); ok {
		return float64(f), true
	}
	if i, ok := v.(int); ok {
		return float64(i), true
	}
	if i, ok := v.(int64); ok {
		return float64(i), true
	}
	return 0, false
}
|
||||
|
||||
// ConvertResponseFormatToCohere converts OpenAI-style response_format (interface{}) to Cohere's typed format
|
||||
// Input can be a map with structure: { type: "json_schema", json_schema: { schema: {...} } }
|
||||
// Output: CohereResponseFormat with flat structure: { type: "json_object", json_schema: {...} }
|
||||
func convertResponseFormatToCohere(responseFormat *interface{}) *CohereResponseFormat {
|
||||
if responseFormat == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Try to extract as map
|
||||
formatMap, ok := (*responseFormat).(map[string]interface{})
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
cohereFormat := &CohereResponseFormat{}
|
||||
|
||||
// Extract type
|
||||
typeStr, _ := formatMap["type"].(string)
|
||||
switch typeStr {
|
||||
case "text":
|
||||
cohereFormat.Type = ResponseFormatTypeText
|
||||
case "json_object", "json_schema":
|
||||
cohereFormat.Type = ResponseFormatTypeJSONObject
|
||||
|
||||
// Extract the nested schema
|
||||
// OpenAI format: { type: "json_schema", json_schema: { name: "X", strict: true, schema: {...} } }
|
||||
if jsonSchemaWrapper, ok := formatMap["json_schema"].(map[string]interface{}); ok {
|
||||
if schema, ok := jsonSchemaWrapper["schema"].(map[string]interface{}); ok {
|
||||
var schemaInterface interface{} = schema
|
||||
cohereFormat.JSONSchema = &schemaInterface
|
||||
}
|
||||
}
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
|
||||
return cohereFormat
|
||||
}
|
||||
|
||||
// convertCohereResponseFormatToBifrost converts Cohere's typed response_format back to interface{}
|
||||
func convertCohereResponseFormatToBifrost(cohereFormat *CohereResponseFormat) *interface{} {
|
||||
if cohereFormat == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Build JSON bytes with deterministic key order using sjson
|
||||
data := []byte(`{}`)
|
||||
if cohereFormat.JSONSchema != nil {
|
||||
data, _ = sjson.SetBytes(data, "type", "json_schema")
|
||||
schemaBytes, _ := schemas.MarshalSorted(cohereFormat.JSONSchema)
|
||||
data, _ = sjson.SetRawBytes(data, "json_schema", schemaBytes)
|
||||
} else {
|
||||
data, _ = sjson.SetBytes(data, "type", string(cohereFormat.Type))
|
||||
}
|
||||
|
||||
var resultInterface interface{} = json.RawMessage(data)
|
||||
return &resultInterface
|
||||
}
|
||||
933
core/providers/elevenlabs/elevenlabs.go
Normal file
933
core/providers/elevenlabs/elevenlabs.go
Normal file
@@ -0,0 +1,933 @@
|
||||
package elevenlabs
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"path"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/bytedance/sonic"
|
||||
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
|
||||
schemas "github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
type ElevenlabsProvider struct {
|
||||
logger schemas.Logger // Logger for provider operations
|
||||
client *fasthttp.Client // HTTP client for unary API requests (ReadTimeout bounds overall response)
|
||||
streamingClient *fasthttp.Client // HTTP client for streaming API requests (no ReadTimeout; idle governed by NewIdleTimeoutReader)
|
||||
networkConfig schemas.NetworkConfig // Network configuration including extra headers
|
||||
sendBackRawRequest bool // Whether to include raw request in BifrostResponse
|
||||
sendBackRawResponse bool // Whether to include raw response in BifrostResponse
|
||||
customProviderConfig *schemas.CustomProviderConfig // Custom provider config
|
||||
}
|
||||
|
||||
// NewElevenlabsProvider creates a new Elevenlabs provider instance.
|
||||
// It initializes the HTTP client with the provided configuration.
|
||||
// The client is configured with timeouts, concurrency limits, and optional proxy settings.
|
||||
func NewElevenlabsProvider(config *schemas.ProviderConfig, logger schemas.Logger) *ElevenlabsProvider {
|
||||
config.CheckAndSetDefaults()
|
||||
|
||||
requestTimeout := time.Second * time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds)
|
||||
client := &fasthttp.Client{
|
||||
ReadTimeout: requestTimeout,
|
||||
WriteTimeout: requestTimeout,
|
||||
MaxConnsPerHost: config.NetworkConfig.MaxConnsPerHost,
|
||||
MaxIdleConnDuration: 30 * time.Second,
|
||||
MaxConnWaitTimeout: requestTimeout,
|
||||
MaxConnDuration: time.Second * time.Duration(schemas.DefaultMaxConnDurationInSeconds),
|
||||
ConnPoolStrategy: fasthttp.FIFO,
|
||||
}
|
||||
|
||||
// Configure proxy and retry policy
|
||||
client = providerUtils.ConfigureProxy(client, config.ProxyConfig, logger)
|
||||
client = providerUtils.ConfigureDialer(client)
|
||||
client = providerUtils.ConfigureTLS(client, config.NetworkConfig, logger)
|
||||
streamingClient := providerUtils.BuildStreamingClient(client)
|
||||
// Set default BaseURL if not provided
|
||||
if config.NetworkConfig.BaseURL == "" {
|
||||
config.NetworkConfig.BaseURL = "https://api.elevenlabs.io"
|
||||
}
|
||||
config.NetworkConfig.BaseURL = strings.TrimRight(config.NetworkConfig.BaseURL, "/")
|
||||
|
||||
return &ElevenlabsProvider{
|
||||
logger: logger,
|
||||
client: client,
|
||||
streamingClient: streamingClient,
|
||||
networkConfig: config.NetworkConfig,
|
||||
customProviderConfig: config.CustomProviderConfig,
|
||||
sendBackRawRequest: config.SendBackRawRequest,
|
||||
sendBackRawResponse: config.SendBackRawResponse,
|
||||
}
|
||||
}
|
||||
|
||||
// GetProviderKey returns the provider identifier for Elevenlabs.
|
||||
func (provider *ElevenlabsProvider) GetProviderKey() schemas.ModelProvider {
|
||||
return providerUtils.GetProviderName(schemas.Elevenlabs, provider.customProviderConfig)
|
||||
}
|
||||
|
||||
// listModelsByKey performs a list models request for a single key.
|
||||
// Returns the response and latency, or an error if the request fails.
|
||||
func (provider *ElevenlabsProvider) listModelsByKey(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) {
|
||||
// Create request
|
||||
req := fasthttp.AcquireRequest()
|
||||
resp := fasthttp.AcquireResponse()
|
||||
defer fasthttp.ReleaseRequest(req)
|
||||
defer fasthttp.ReleaseResponse(resp)
|
||||
|
||||
// Set any extra headers from network config
|
||||
providerUtils.SetExtraHeaders(ctx, req, provider.networkConfig.ExtraHeaders, nil)
|
||||
|
||||
// Build URL using centralized URL construction
|
||||
req.SetRequestURI(provider.networkConfig.BaseURL + providerUtils.GetPathFromContext(ctx, "/v1/models"))
|
||||
req.Header.SetMethod(http.MethodGet)
|
||||
req.Header.SetContentType("application/json")
|
||||
|
||||
if key.Value.GetValue() != "" {
|
||||
req.Header.Set("xi-api-key", key.Value.GetValue())
|
||||
}
|
||||
|
||||
// Make request
|
||||
latency, bifrostErr, wait := providerUtils.MakeRequestWithContext(ctx, provider.client, req, resp)
|
||||
defer wait()
|
||||
if bifrostErr != nil {
|
||||
return nil, bifrostErr
|
||||
}
|
||||
// Extract and set provider response headers so they're available on error paths
|
||||
ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerUtils.ExtractProviderResponseHeaders(resp))
|
||||
if resp.StatusCode() != fasthttp.StatusOK {
|
||||
return nil, parseElevenlabsError(resp)
|
||||
}
|
||||
|
||||
var elevenlabsResponse ElevenlabsListModelsResponse
|
||||
rawRequest, rawResponse, bifrostErr := providerUtils.HandleProviderResponse(resp.Body(), &elevenlabsResponse, nil, providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse))
|
||||
if bifrostErr != nil {
|
||||
return nil, bifrostErr
|
||||
}
|
||||
|
||||
response := elevenlabsResponse.ToBifrostListModelsResponse(provider.GetProviderKey(), key.Models, key.BlacklistedModels, key.Aliases, request.Unfiltered)
|
||||
|
||||
response.ExtraFields.Latency = latency.Milliseconds()
|
||||
response.ExtraFields.ProviderResponseHeaders = providerUtils.ExtractProviderResponseHeaders(resp)
|
||||
|
||||
// Set raw request if enabled
|
||||
if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) {
|
||||
response.ExtraFields.RawRequest = rawRequest
|
||||
}
|
||||
|
||||
// Set raw response if enabled
|
||||
if providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) {
|
||||
response.ExtraFields.RawResponse = rawResponse
|
||||
}
|
||||
|
||||
return response, nil
|
||||
}
|
||||
|
||||
// ListModels performs a list models request to Elevenlabs' API.
|
||||
// Requests are made concurrently for improved performance.
|
||||
func (provider *ElevenlabsProvider) ListModels(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) {
|
||||
if err := providerUtils.CheckOperationAllowed(schemas.Elevenlabs, provider.customProviderConfig, schemas.ListModelsRequest); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return providerUtils.HandleMultipleListModelsRequests(
|
||||
ctx,
|
||||
keys,
|
||||
request,
|
||||
provider.listModelsByKey,
|
||||
)
|
||||
}
|
||||
|
||||
// TextCompletion is not supported by the Elevenlabs provider
|
||||
func (provider *ElevenlabsProvider) TextCompletion(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (*schemas.BifrostTextCompletionResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.TextCompletionRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// TextCompletionStream is not supported by the Elevenlabs provider
|
||||
func (provider *ElevenlabsProvider) TextCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostTextCompletionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.TextCompletionStreamRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// ChatCompletion is not supported by the Elevenlabs provider
|
||||
func (provider *ElevenlabsProvider) ChatCompletion(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.ChatCompletionRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// ChatCompletionStream is not supported by the Elevenlabs provider
|
||||
func (provider *ElevenlabsProvider) ChatCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.ChatCompletionStreamRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// Responses is not supported by the Elevenlabs provider
|
||||
func (provider *ElevenlabsProvider) Responses(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.ResponsesRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// ResponsesStream is not supported by the Elevenlabs provider
|
||||
func (provider *ElevenlabsProvider) ResponsesStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.ResponsesStreamRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// Embedding is not supported by the Elevenlabs provider.
|
||||
func (provider *ElevenlabsProvider) Embedding(ctx *schemas.BifrostContext, key schemas.Key, input *schemas.BifrostEmbeddingRequest) (*schemas.BifrostEmbeddingResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.EmbeddingRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// Speech performs a text to speech request
|
||||
func (provider *ElevenlabsProvider) Speech(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostSpeechRequest) (*schemas.BifrostSpeechResponse, *schemas.BifrostError) {
|
||||
if err := providerUtils.CheckOperationAllowed(schemas.Elevenlabs, provider.customProviderConfig, schemas.SpeechRequest); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Create request
|
||||
req := fasthttp.AcquireRequest()
|
||||
resp := fasthttp.AcquireResponse()
|
||||
defer fasthttp.ReleaseRequest(req)
|
||||
defer fasthttp.ReleaseResponse(resp)
|
||||
|
||||
// Set any extra headers from network config
|
||||
providerUtils.SetExtraHeaders(ctx, req, provider.networkConfig.ExtraHeaders, nil)
|
||||
|
||||
withTimestampsRequest := request.Params != nil && request.Params.WithTimestamps != nil && *request.Params.WithTimestamps
|
||||
|
||||
var endpoint string
|
||||
if request.Params != nil && request.Params.VoiceConfig != nil && request.Params.VoiceConfig.Voice != nil {
|
||||
voice := *request.Params.VoiceConfig.Voice
|
||||
// Determine if timestamps are requested
|
||||
if withTimestampsRequest {
|
||||
endpoint = "/v1/text-to-speech/" + voice + "/with-timestamps"
|
||||
} else {
|
||||
endpoint = "/v1/text-to-speech/" + voice
|
||||
}
|
||||
} else {
|
||||
return nil, providerUtils.NewBifrostOperationError("voice parameter is required", nil)
|
||||
}
|
||||
|
||||
requestURL := provider.buildBaseSpeechRequestURL(ctx, endpoint, schemas.SpeechRequest, request)
|
||||
req.SetRequestURI(requestURL)
|
||||
|
||||
req.Header.SetMethod(http.MethodPost)
|
||||
req.Header.SetContentType("application/json")
|
||||
if key.Value.GetValue() != "" {
|
||||
req.Header.Set("xi-api-key", key.Value.GetValue())
|
||||
}
|
||||
|
||||
jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody(
|
||||
ctx,
|
||||
request,
|
||||
func() (providerUtils.RequestBodyWithExtraParams, error) {
|
||||
return ToElevenlabsSpeechRequest(request), nil
|
||||
})
|
||||
|
||||
if bifrostErr != nil {
|
||||
return nil, bifrostErr
|
||||
}
|
||||
|
||||
if !providerUtils.ApplyLargePayloadRequestBody(ctx, req) {
|
||||
req.SetBody(jsonData)
|
||||
}
|
||||
|
||||
// Make request
|
||||
latency, bifrostErr, wait := providerUtils.MakeRequestWithContext(ctx, provider.client, req, resp)
|
||||
defer wait()
|
||||
if bifrostErr != nil {
|
||||
return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
|
||||
}
|
||||
// Extract and set provider response headers so they're available on error paths
|
||||
ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerUtils.ExtractProviderResponseHeaders(resp))
|
||||
|
||||
// Handle error response
|
||||
if resp.StatusCode() != fasthttp.StatusOK {
|
||||
return nil, providerUtils.EnrichError(ctx, parseElevenlabsError(resp), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
|
||||
}
|
||||
|
||||
// Get the response body
|
||||
body, err := providerUtils.CheckAndDecodeBody(resp)
|
||||
if err != nil {
|
||||
return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
|
||||
}
|
||||
|
||||
// Create response based on whether timestamps were requested
|
||||
bifrostResponse := &schemas.BifrostSpeechResponse{
|
||||
ExtraFields: schemas.BifrostResponseExtraFields{
|
||||
Latency: latency.Milliseconds(),
|
||||
ProviderResponseHeaders: providerUtils.ExtractProviderResponseHeaders(resp),
|
||||
},
|
||||
}
|
||||
|
||||
if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) {
|
||||
providerUtils.ParseAndSetRawRequest(&bifrostResponse.ExtraFields, jsonData)
|
||||
}
|
||||
|
||||
if withTimestampsRequest {
|
||||
var timestampResponse ElevenlabsSpeechWithTimestampsResponse
|
||||
if err := sonic.Unmarshal(body, ×tampResponse); err != nil {
|
||||
return nil, providerUtils.NewBifrostOperationError("failed to parse with-timestamps response", err)
|
||||
}
|
||||
|
||||
bifrostResponse.AudioBase64 = ×tampResponse.AudioBase64
|
||||
|
||||
if timestampResponse.Alignment != nil {
|
||||
bifrostResponse.Alignment = &schemas.SpeechAlignment{
|
||||
CharStartTimesMs: timestampResponse.Alignment.CharStartTimesMs,
|
||||
CharEndTimesMs: timestampResponse.Alignment.CharEndTimesMs,
|
||||
Characters: timestampResponse.Alignment.Characters,
|
||||
}
|
||||
}
|
||||
|
||||
if timestampResponse.NormalizedAlignment != nil {
|
||||
bifrostResponse.NormalizedAlignment = &schemas.SpeechAlignment{
|
||||
CharStartTimesMs: timestampResponse.NormalizedAlignment.CharStartTimesMs,
|
||||
CharEndTimesMs: timestampResponse.NormalizedAlignment.CharEndTimesMs,
|
||||
Characters: timestampResponse.NormalizedAlignment.Characters,
|
||||
}
|
||||
}
|
||||
|
||||
return bifrostResponse, nil
|
||||
}
|
||||
|
||||
bifrostResponse.Audio = body
|
||||
return bifrostResponse, nil
|
||||
}
|
||||
|
||||
// Rerank is not supported by the Elevenlabs provider.
|
||||
func (provider *ElevenlabsProvider) Rerank(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostRerankRequest) (*schemas.BifrostRerankResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.RerankRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// OCR is not supported by the Elevenlabs provider.
|
||||
func (provider *ElevenlabsProvider) OCR(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostOCRRequest) (*schemas.BifrostOCRResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.OCRRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// SpeechStream performs a text to speech stream request
|
||||
func (provider *ElevenlabsProvider) SpeechStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostSpeechRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
|
||||
if err := providerUtils.CheckOperationAllowed(schemas.Elevenlabs, provider.customProviderConfig, schemas.SpeechStreamRequest); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
jsonBody, bifrostErr := providerUtils.CheckContextAndGetRequestBody(
|
||||
ctx,
|
||||
request,
|
||||
func() (providerUtils.RequestBodyWithExtraParams, error) {
|
||||
return ToElevenlabsSpeechRequest(request), nil
|
||||
})
|
||||
|
||||
if bifrostErr != nil {
|
||||
return nil, bifrostErr
|
||||
}
|
||||
|
||||
// Create HTTP request for streaming
|
||||
req := fasthttp.AcquireRequest()
|
||||
resp := fasthttp.AcquireResponse()
|
||||
resp.StreamBody = true
|
||||
defer fasthttp.ReleaseRequest(req)
|
||||
|
||||
// Set any extra headers from network config
|
||||
providerUtils.SetExtraHeaders(ctx, req, provider.networkConfig.ExtraHeaders, nil)
|
||||
|
||||
if request.Params == nil || request.Params.VoiceConfig == nil || request.Params.VoiceConfig.Voice == nil {
|
||||
return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError("voice parameter is required", nil), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
|
||||
}
|
||||
|
||||
req.SetRequestURI(provider.buildBaseSpeechRequestURL(ctx, "/v1/text-to-speech/"+*request.Params.VoiceConfig.Voice+"/stream", schemas.SpeechStreamRequest, request))
|
||||
|
||||
req.Header.SetMethod(http.MethodPost)
|
||||
req.Header.SetContentType("application/json")
|
||||
if key.Value.GetValue() != "" {
|
||||
req.Header.Set("xi-api-key", key.Value.GetValue())
|
||||
}
|
||||
|
||||
if !providerUtils.ApplyLargePayloadRequestBody(ctx, req) {
|
||||
req.SetBody(jsonBody)
|
||||
}
|
||||
|
||||
// Make request
|
||||
startTime := time.Now()
|
||||
err := provider.streamingClient.Do(req, resp)
|
||||
if err != nil {
|
||||
defer providerUtils.ReleaseStreamingResponse(resp)
|
||||
if errors.Is(err, context.Canceled) {
|
||||
return nil, providerUtils.EnrichError(ctx, &schemas.BifrostError{
|
||||
IsBifrostError: false,
|
||||
Error: &schemas.ErrorField{
|
||||
Type: schemas.Ptr(schemas.RequestCancelled),
|
||||
Message: schemas.ErrRequestCancelled,
|
||||
Error: err,
|
||||
},
|
||||
}, jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
|
||||
}
|
||||
if errors.Is(err, fasthttp.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) {
|
||||
return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
|
||||
}
|
||||
return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
|
||||
}
|
||||
|
||||
// Extract provider response headers before status check so error responses also forward them
|
||||
ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerUtils.ExtractProviderResponseHeaders(resp))
|
||||
|
||||
// Check for HTTP errors
|
||||
if resp.StatusCode() != fasthttp.StatusOK {
|
||||
defer providerUtils.ReleaseStreamingResponse(resp)
|
||||
return nil, providerUtils.EnrichError(ctx, parseElevenlabsError(resp), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
|
||||
}
|
||||
|
||||
// Create response channel
|
||||
responseChan := make(chan *schemas.BifrostStreamChunk, schemas.DefaultStreamBufferSize)
|
||||
|
||||
providerUtils.SetStreamIdleTimeoutIfEmpty(ctx, provider.networkConfig.StreamIdleTimeoutInSeconds)
|
||||
|
||||
go func() {
|
||||
defer func() {
|
||||
if ctx.Err() == context.Canceled {
|
||||
providerUtils.HandleStreamCancellation(ctx, postHookRunner, responseChan, provider.logger, postHookSpanFinalizer)
|
||||
} else if ctx.Err() == context.DeadlineExceeded {
|
||||
providerUtils.HandleStreamTimeout(ctx, postHookRunner, responseChan, provider.logger, postHookSpanFinalizer)
|
||||
}
|
||||
close(responseChan)
|
||||
}()
|
||||
defer providerUtils.ReleaseStreamingResponse(resp)
|
||||
// Decompress gzip-encoded streams transparently (no-op for non-gzip)
|
||||
reader, releaseGzip := providerUtils.DecompressStreamBody(resp)
|
||||
defer releaseGzip()
|
||||
|
||||
// Wrap reader with idle timeout to detect stalled streams.
|
||||
reader, stopIdleTimeout := providerUtils.NewIdleTimeoutReader(reader, resp.BodyStream(), providerUtils.GetStreamIdleTimeout(ctx))
|
||||
defer stopIdleTimeout()
|
||||
|
||||
// Setup cancellation handler to close the raw network stream on ctx cancellation,
|
||||
// which immediately unblocks any in-progress read (including reads blocked inside a gzip decompression layer).
|
||||
stopCancellation := providerUtils.SetupStreamCancellation(ctx, resp.BodyStream(), provider.logger)
|
||||
defer stopCancellation()
|
||||
defer providerUtils.EnsureStreamFinalizerCalled(ctx, postHookSpanFinalizer)
|
||||
|
||||
// read binary audio chunks from the stream
|
||||
// 4KB buffer for reading chunks
|
||||
buffer := make([]byte, 4096)
|
||||
bodyStream := reader
|
||||
chunkIndex := -1
|
||||
lastChunkTime := time.Now()
|
||||
|
||||
for {
|
||||
// If context was cancelled/timed out, let defer handle it
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
n, err := bodyStream.Read(buffer)
|
||||
if err != nil {
|
||||
// If context was cancelled/timed out, let defer handle it
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
|
||||
provider.logger.Warn("Error reading stream: %v", err)
|
||||
providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, provider.logger, postHookSpanFinalizer)
|
||||
return
|
||||
}
|
||||
|
||||
if n > 0 {
|
||||
chunkIndex++
|
||||
audioChunk := make([]byte, n)
|
||||
copy(audioChunk, buffer[:n])
|
||||
|
||||
response := &schemas.BifrostSpeechStreamResponse{
|
||||
Type: schemas.SpeechStreamResponseTypeDelta,
|
||||
Audio: audioChunk,
|
||||
ExtraFields: schemas.BifrostResponseExtraFields{
|
||||
ChunkIndex: chunkIndex,
|
||||
Latency: time.Since(lastChunkTime).Milliseconds(),
|
||||
},
|
||||
}
|
||||
|
||||
lastChunkTime = time.Now()
|
||||
|
||||
if providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) {
|
||||
response.ExtraFields.RawResponse = audioChunk
|
||||
}
|
||||
|
||||
providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, nil, response, nil, nil), responseChan, postHookSpanFinalizer)
|
||||
}
|
||||
}
|
||||
|
||||
// Send final response after natural loop termination (similar to Gemini pattern)
|
||||
finalResponse := &schemas.BifrostSpeechStreamResponse{
|
||||
Type: schemas.SpeechStreamResponseTypeDone,
|
||||
Audio: []byte{},
|
||||
ExtraFields: schemas.BifrostResponseExtraFields{
|
||||
ChunkIndex: chunkIndex + 1,
|
||||
Latency: time.Since(startTime).Milliseconds(),
|
||||
},
|
||||
}
|
||||
|
||||
// Set raw request if enabled
|
||||
if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) {
|
||||
providerUtils.ParseAndSetRawRequest(&finalResponse.ExtraFields, jsonBody)
|
||||
}
|
||||
ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
|
||||
providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, nil, finalResponse, nil, nil), responseChan, postHookSpanFinalizer)
|
||||
}()
|
||||
|
||||
return responseChan, nil
|
||||
}
|
||||
|
||||
// Transcription performs a transcription request
|
||||
func (provider *ElevenlabsProvider) Transcription(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (*schemas.BifrostTranscriptionResponse, *schemas.BifrostError) {
|
||||
if err := providerUtils.CheckOperationAllowed(schemas.Elevenlabs, provider.customProviderConfig, schemas.TranscriptionRequest); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
reqBody := ToElevenlabsTranscriptionRequest(request)
|
||||
if reqBody == nil {
|
||||
return nil, providerUtils.NewBifrostOperationError("transcription request is not provided", nil)
|
||||
}
|
||||
|
||||
hasFile := len(reqBody.File) > 0
|
||||
hasURL := reqBody.CloudStorageURL != nil && strings.TrimSpace(*reqBody.CloudStorageURL) != ""
|
||||
if hasFile && hasURL {
|
||||
return nil, providerUtils.NewBifrostOperationError("provide either a file or cloud_storage_url, not both", nil)
|
||||
}
|
||||
if !hasFile && !hasURL {
|
||||
return nil, providerUtils.NewBifrostOperationError("either a transcription file or cloud_storage_url must be provided", nil)
|
||||
}
|
||||
|
||||
var body bytes.Buffer
|
||||
writer := multipart.NewWriter(&body)
|
||||
|
||||
if bifrostErr := writeTranscriptionMultipart(writer, reqBody); bifrostErr != nil {
|
||||
return nil, bifrostErr
|
||||
}
|
||||
|
||||
contentType := writer.FormDataContentType()
|
||||
if err := writer.Close(); err != nil {
|
||||
return nil, providerUtils.NewBifrostOperationError("failed to finalize multipart transcription request", err)
|
||||
}
|
||||
|
||||
req := fasthttp.AcquireRequest()
|
||||
resp := fasthttp.AcquireResponse()
|
||||
defer fasthttp.ReleaseRequest(req)
|
||||
defer fasthttp.ReleaseResponse(resp)
|
||||
|
||||
providerUtils.SetExtraHeaders(ctx, req, provider.networkConfig.ExtraHeaders, nil)
|
||||
|
||||
requestPath, isCompleteURL := providerUtils.GetRequestPath(ctx, "/v1/speech-to-text", provider.customProviderConfig, schemas.TranscriptionRequest)
|
||||
if isCompleteURL {
|
||||
req.SetRequestURI(requestPath)
|
||||
} else {
|
||||
req.SetRequestURI(provider.networkConfig.BaseURL + requestPath)
|
||||
}
|
||||
req.Header.SetMethod(http.MethodPost)
|
||||
req.Header.SetContentType(contentType)
|
||||
if key.Value.GetValue() != "" {
|
||||
req.Header.Set("xi-api-key", key.Value.GetValue())
|
||||
}
|
||||
req.SetBody(body.Bytes())
|
||||
|
||||
latency, bifrostErr, wait := providerUtils.MakeRequestWithContext(ctx, provider.client, req, resp)
|
||||
defer wait()
|
||||
if bifrostErr != nil {
|
||||
return nil, bifrostErr
|
||||
}
|
||||
// Extract and set provider response headers so they're available on error paths
|
||||
ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerUtils.ExtractProviderResponseHeaders(resp))
|
||||
if resp.StatusCode() != fasthttp.StatusOK {
|
||||
return nil, parseElevenlabsError(resp)
|
||||
}
|
||||
|
||||
responseBody, err := providerUtils.CheckAndDecodeBody(resp)
|
||||
if err != nil {
|
||||
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err)
|
||||
}
|
||||
|
||||
// Check for empty response
|
||||
trimmed := strings.TrimSpace(string(responseBody))
|
||||
if len(trimmed) == 0 {
|
||||
return nil, &schemas.BifrostError{
|
||||
IsBifrostError: true,
|
||||
Error: &schemas.ErrorField{
|
||||
Message: schemas.ErrProviderResponseEmpty,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
chunks, err := parseTranscriptionResponse(responseBody)
|
||||
if err != nil {
|
||||
return nil, providerUtils.NewBifrostOperationError(err.Error(), nil)
|
||||
}
|
||||
|
||||
if len(chunks) == 0 {
|
||||
return nil, providerUtils.NewBifrostOperationError("no chunks found in transcription response", nil)
|
||||
}
|
||||
|
||||
response := ToBifrostTranscriptionResponse(chunks)
|
||||
response.ExtraFields = schemas.BifrostResponseExtraFields{
|
||||
Latency: latency.Milliseconds(),
|
||||
ProviderResponseHeaders: providerUtils.ExtractProviderResponseHeaders(resp),
|
||||
}
|
||||
|
||||
if providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) {
|
||||
var rawResponse interface{}
|
||||
if err := sonic.Unmarshal(responseBody, &rawResponse); err != nil {
|
||||
rawResponse = string(responseBody)
|
||||
}
|
||||
response.ExtraFields.RawResponse = rawResponse
|
||||
}
|
||||
|
||||
return response, nil
|
||||
}
|
||||
|
||||
func writeTranscriptionMultipart(writer *multipart.Writer, reqBody *ElevenlabsTranscriptionRequest) *schemas.BifrostError {
|
||||
if err := writer.WriteField("model_id", reqBody.ModelID); err != nil {
|
||||
return providerUtils.NewBifrostOperationError("failed to write model_id field", err)
|
||||
}
|
||||
|
||||
if len(reqBody.File) > 0 {
|
||||
filename := reqBody.Filename
|
||||
if filename == "" {
|
||||
filename = providerUtils.AudioFilenameFromBytes(reqBody.File)
|
||||
}
|
||||
fileWriter, err := writer.CreateFormFile("file", filename)
|
||||
if err != nil {
|
||||
return providerUtils.NewBifrostOperationError("failed to create file field", err)
|
||||
}
|
||||
if _, err := fileWriter.Write(reqBody.File); err != nil {
|
||||
return providerUtils.NewBifrostOperationError("failed to write file data", err)
|
||||
}
|
||||
}
|
||||
|
||||
if reqBody.CloudStorageURL != nil && strings.TrimSpace(*reqBody.CloudStorageURL) != "" {
|
||||
if err := writer.WriteField("cloud_storage_url", *reqBody.CloudStorageURL); err != nil {
|
||||
return providerUtils.NewBifrostOperationError("failed to write cloud_storage_url field", err)
|
||||
}
|
||||
}
|
||||
|
||||
if reqBody.LanguageCode != nil && strings.TrimSpace(*reqBody.LanguageCode) != "" {
|
||||
if err := writer.WriteField("language_code", *reqBody.LanguageCode); err != nil {
|
||||
return providerUtils.NewBifrostOperationError("failed to write language_code field", err)
|
||||
}
|
||||
}
|
||||
|
||||
if reqBody.TagAudioEvents != nil {
|
||||
if err := writer.WriteField("tag_audio_events", strconv.FormatBool(*reqBody.TagAudioEvents)); err != nil {
|
||||
return providerUtils.NewBifrostOperationError("failed to write tag_audio_events field", err)
|
||||
}
|
||||
}
|
||||
|
||||
if reqBody.NumSpeakers != nil && *reqBody.NumSpeakers > 0 {
|
||||
if err := writer.WriteField("num_speakers", strconv.Itoa(*reqBody.NumSpeakers)); err != nil {
|
||||
return providerUtils.NewBifrostOperationError("failed to write num_speakers field", err)
|
||||
}
|
||||
}
|
||||
|
||||
if reqBody.TimestampsGranularity != nil && *reqBody.TimestampsGranularity != "" {
|
||||
if err := writer.WriteField("timestamps_granularity", string(*reqBody.TimestampsGranularity)); err != nil {
|
||||
return providerUtils.NewBifrostOperationError("failed to write timestamps_granularity field", err)
|
||||
}
|
||||
}
|
||||
|
||||
if reqBody.Diarize != nil {
|
||||
if err := writer.WriteField("diarize", strconv.FormatBool(*reqBody.Diarize)); err != nil {
|
||||
return providerUtils.NewBifrostOperationError("failed to write diarize field", err)
|
||||
}
|
||||
}
|
||||
|
||||
if reqBody.DiarizationThreshold != nil {
|
||||
if err := writer.WriteField("diarization_threshold", strconv.FormatFloat(*reqBody.DiarizationThreshold, 'f', -1, 64)); err != nil {
|
||||
return providerUtils.NewBifrostOperationError("failed to write diarization_threshold field", err)
|
||||
}
|
||||
}
|
||||
|
||||
if len(reqBody.AdditionalFormats) > 0 {
|
||||
payload, err := providerUtils.MarshalSorted(reqBody.AdditionalFormats)
|
||||
if err != nil {
|
||||
return providerUtils.NewBifrostOperationError("failed to marshal additional_formats", err)
|
||||
}
|
||||
if err := writer.WriteField("additional_formats", string(payload)); err != nil {
|
||||
return providerUtils.NewBifrostOperationError("failed to write additional_formats field", err)
|
||||
}
|
||||
}
|
||||
|
||||
if reqBody.FileFormat != nil && *reqBody.FileFormat != "" {
|
||||
if err := writer.WriteField("file_format", string(*reqBody.FileFormat)); err != nil {
|
||||
return providerUtils.NewBifrostOperationError("failed to write file_format field", err)
|
||||
}
|
||||
}
|
||||
|
||||
if reqBody.Webhook != nil {
|
||||
if err := writer.WriteField("webhook", strconv.FormatBool(*reqBody.Webhook)); err != nil {
|
||||
return providerUtils.NewBifrostOperationError("failed to write webhook field", err)
|
||||
}
|
||||
}
|
||||
|
||||
if reqBody.WebhookID != nil && strings.TrimSpace(*reqBody.WebhookID) != "" {
|
||||
if err := writer.WriteField("webhook_id", *reqBody.WebhookID); err != nil {
|
||||
return providerUtils.NewBifrostOperationError("failed to write webhook_id field", err)
|
||||
}
|
||||
}
|
||||
|
||||
if reqBody.Temperature != nil {
|
||||
if err := writer.WriteField("temperature", strconv.FormatFloat(*reqBody.Temperature, 'f', -1, 64)); err != nil {
|
||||
return providerUtils.NewBifrostOperationError("failed to write temperature field", err)
|
||||
}
|
||||
}
|
||||
|
||||
if reqBody.Seed != nil {
|
||||
if err := writer.WriteField("seed", strconv.Itoa(*reqBody.Seed)); err != nil {
|
||||
return providerUtils.NewBifrostOperationError("failed to write seed field", err)
|
||||
}
|
||||
}
|
||||
|
||||
if reqBody.UseMultiChannel != nil {
|
||||
if err := writer.WriteField("use_multi_channel", strconv.FormatBool(*reqBody.UseMultiChannel)); err != nil {
|
||||
return providerUtils.NewBifrostOperationError("failed to write use_multi_channel field", err)
|
||||
}
|
||||
}
|
||||
|
||||
if reqBody.WebhookMetadata != nil {
|
||||
switch v := reqBody.WebhookMetadata.(type) {
|
||||
case string:
|
||||
if strings.TrimSpace(v) != "" {
|
||||
if err := writer.WriteField("webhook_metadata", v); err != nil {
|
||||
return providerUtils.NewBifrostOperationError("failed to write webhook_metadata field", err)
|
||||
}
|
||||
}
|
||||
default:
|
||||
payload, err := providerUtils.MarshalSorted(v)
|
||||
if err != nil {
|
||||
return providerUtils.NewBifrostOperationError("failed to marshal webhook_metadata", err)
|
||||
}
|
||||
if err := writer.WriteField("webhook_metadata", string(payload)); err != nil {
|
||||
return providerUtils.NewBifrostOperationError("failed to write webhook_metadata field", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// TranscriptionStream is not supported by the Elevenlabs provider
|
||||
func (provider *ElevenlabsProvider) TranscriptionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.TranscriptionStreamRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// ImageGeneration is not supported by the Elevenlabs provider.
|
||||
func (provider *ElevenlabsProvider) ImageGeneration(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostImageGenerationRequest) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.ImageGenerationRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// ImageGenerationStream is not supported by the Elevenlabs provider.
|
||||
func (provider *ElevenlabsProvider) ImageGenerationStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostImageGenerationRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.ImageGenerationStreamRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// ImageEdit is not supported by the Elevenlabs provider.
|
||||
func (provider *ElevenlabsProvider) ImageEdit(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostImageEditRequest) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.ImageEditRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// ImageEditStream is not supported by the Elevenlabs provider.
|
||||
func (provider *ElevenlabsProvider) ImageEditStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostImageEditRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.ImageEditStreamRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// ImageVariation is not supported by the Elevenlabs provider.
|
||||
func (provider *ElevenlabsProvider) ImageVariation(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostImageVariationRequest) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.ImageVariationRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// VideoGeneration is not supported by the ElevenLabs provider.
|
||||
func (provider *ElevenlabsProvider) VideoGeneration(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostVideoGenerationRequest) (*schemas.BifrostVideoGenerationResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.VideoGenerationRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// VideoRetrieve is not supported by the ElevenLabs provider.
|
||||
func (provider *ElevenlabsProvider) VideoRetrieve(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostVideoRetrieveRequest) (*schemas.BifrostVideoGenerationResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.VideoRetrieveRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// VideoDownload is not supported by the ElevenLabs provider.
|
||||
func (provider *ElevenlabsProvider) VideoDownload(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostVideoDownloadRequest) (*schemas.BifrostVideoDownloadResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.VideoDownloadRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// VideoDelete is not supported by Elevenlabs provider.
|
||||
func (provider *ElevenlabsProvider) VideoDelete(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostVideoDeleteRequest) (*schemas.BifrostVideoDeleteResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.VideoDeleteRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// VideoList is not supported by Elevenlabs provider.
|
||||
func (provider *ElevenlabsProvider) VideoList(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostVideoListRequest) (*schemas.BifrostVideoListResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.VideoListRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// VideoRemix is not supported by Elevenlabs provider.
|
||||
func (provider *ElevenlabsProvider) VideoRemix(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostVideoRemixRequest) (*schemas.BifrostVideoGenerationResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.VideoRemixRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// buildSpeechRequestURL constructs the full request URL using the provider's configuration for speech.
|
||||
func (provider *ElevenlabsProvider) buildBaseSpeechRequestURL(ctx *schemas.BifrostContext, defaultPath string, requestType schemas.RequestType, request *schemas.BifrostSpeechRequest) string {
|
||||
baseURL := provider.networkConfig.BaseURL
|
||||
requestPath, isCompleteURL := providerUtils.GetRequestPath(ctx, defaultPath, provider.customProviderConfig, requestType)
|
||||
|
||||
var finalURL string
|
||||
if isCompleteURL {
|
||||
finalURL = requestPath
|
||||
} else {
|
||||
u, parseErr := url.Parse(baseURL)
|
||||
if parseErr != nil {
|
||||
finalURL = baseURL + requestPath
|
||||
} else {
|
||||
u.Path = path.Join(u.Path, requestPath)
|
||||
finalURL = u.String()
|
||||
}
|
||||
}
|
||||
|
||||
// Parse the final URL to add query parameters
|
||||
u, parseErr := url.Parse(finalURL)
|
||||
if parseErr != nil {
|
||||
return finalURL
|
||||
}
|
||||
|
||||
q := u.Query()
|
||||
|
||||
if request.Params != nil {
|
||||
if request.Params.EnableLogging != nil {
|
||||
q.Set("enable_logging", strconv.FormatBool(*request.Params.EnableLogging))
|
||||
}
|
||||
|
||||
convertedFormat := ConvertBifrostSpeechFormatToElevenlabs(request.Params.ResponseFormat)
|
||||
if convertedFormat != "" {
|
||||
q.Set("output_format", convertedFormat)
|
||||
}
|
||||
|
||||
if request.Params.OptimizeStreamingLatency != nil {
|
||||
q.Set("optimize_streaming_latency", strconv.FormatBool(*request.Params.OptimizeStreamingLatency))
|
||||
}
|
||||
}
|
||||
|
||||
u.RawQuery = q.Encode()
|
||||
return u.String()
|
||||
}
|
||||
|
||||
// BatchCreate is not supported by Elevenlabs provider.
|
||||
func (provider *ElevenlabsProvider) BatchCreate(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostBatchCreateRequest) (*schemas.BifrostBatchCreateResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchCreateRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// BatchList is not supported by Elevenlabs provider.
|
||||
func (provider *ElevenlabsProvider) BatchList(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchListRequest) (*schemas.BifrostBatchListResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchListRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// BatchRetrieve is not supported by Elevenlabs provider.
|
||||
func (provider *ElevenlabsProvider) BatchRetrieve(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchRetrieveRequest) (*schemas.BifrostBatchRetrieveResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchRetrieveRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// BatchCancel is not supported by Elevenlabs provider.
|
||||
func (provider *ElevenlabsProvider) BatchCancel(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchCancelRequest) (*schemas.BifrostBatchCancelResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchCancelRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// BatchDelete is not supported by Elevenlabs provider.
|
||||
func (provider *ElevenlabsProvider) BatchDelete(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchDeleteRequest) (*schemas.BifrostBatchDeleteResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchDeleteRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// BatchResults is not supported by Elevenlabs provider.
|
||||
func (provider *ElevenlabsProvider) BatchResults(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchResultsRequest) (*schemas.BifrostBatchResultsResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchResultsRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// FileUpload is not supported by Elevenlabs provider.
|
||||
func (provider *ElevenlabsProvider) FileUpload(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostFileUploadRequest) (*schemas.BifrostFileUploadResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.FileUploadRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// FileList is not supported by Elevenlabs provider.
|
||||
func (provider *ElevenlabsProvider) FileList(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostFileListRequest) (*schemas.BifrostFileListResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.FileListRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// FileRetrieve is not supported by Elevenlabs provider.
|
||||
func (provider *ElevenlabsProvider) FileRetrieve(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostFileRetrieveRequest) (*schemas.BifrostFileRetrieveResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.FileRetrieveRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// FileDelete is not supported by Elevenlabs provider.
|
||||
func (provider *ElevenlabsProvider) FileDelete(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostFileDeleteRequest) (*schemas.BifrostFileDeleteResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.FileDeleteRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// FileContent is not supported by Elevenlabs provider.
|
||||
func (provider *ElevenlabsProvider) FileContent(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostFileContentRequest) (*schemas.BifrostFileContentResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.FileContentRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// CountTokens is not supported by the Elevenlabs provider.
|
||||
func (provider *ElevenlabsProvider) CountTokens(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostResponsesRequest) (*schemas.BifrostCountTokensResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.CountTokensRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// ContainerCreate is not supported by the Elevenlabs provider.
|
||||
func (provider *ElevenlabsProvider) ContainerCreate(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostContainerCreateRequest) (*schemas.BifrostContainerCreateResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerCreateRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// ContainerList is not supported by the Elevenlabs provider.
|
||||
func (provider *ElevenlabsProvider) ContainerList(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostContainerListRequest) (*schemas.BifrostContainerListResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerListRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// ContainerRetrieve is not supported by the Elevenlabs provider.
|
||||
func (provider *ElevenlabsProvider) ContainerRetrieve(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostContainerRetrieveRequest) (*schemas.BifrostContainerRetrieveResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerRetrieveRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// ContainerDelete is not supported by the Elevenlabs provider.
|
||||
func (provider *ElevenlabsProvider) ContainerDelete(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostContainerDeleteRequest) (*schemas.BifrostContainerDeleteResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerDeleteRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// ContainerFileCreate is not supported by the Elevenlabs provider.
|
||||
func (provider *ElevenlabsProvider) ContainerFileCreate(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostContainerFileCreateRequest) (*schemas.BifrostContainerFileCreateResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerFileCreateRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// ContainerFileList is not supported by the Elevenlabs provider.
|
||||
func (provider *ElevenlabsProvider) ContainerFileList(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostContainerFileListRequest) (*schemas.BifrostContainerFileListResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerFileListRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// ContainerFileRetrieve is not supported by the Elevenlabs provider.
|
||||
func (provider *ElevenlabsProvider) ContainerFileRetrieve(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostContainerFileRetrieveRequest) (*schemas.BifrostContainerFileRetrieveResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerFileRetrieveRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// ContainerFileContent is not supported by the Elevenlabs provider.
|
||||
func (provider *ElevenlabsProvider) ContainerFileContent(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostContainerFileContentRequest) (*schemas.BifrostContainerFileContentResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerFileContentRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// ContainerFileDelete is not supported by the Elevenlabs provider.
|
||||
func (provider *ElevenlabsProvider) ContainerFileDelete(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostContainerFileDeleteRequest) (*schemas.BifrostContainerFileDeleteResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerFileDeleteRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// Passthrough is not supported by the Elevenlabs provider.
|
||||
func (provider *ElevenlabsProvider) Passthrough(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostPassthroughRequest) (*schemas.BifrostPassthroughResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.PassthroughRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
func (provider *ElevenlabsProvider) PassthroughStream(_ *schemas.BifrostContext, _ schemas.PostHookRunner, _ func(context.Context), _ schemas.Key, _ *schemas.BifrostPassthroughRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.PassthroughStreamRequest, provider.GetProviderKey())
|
||||
}
|
||||
62
core/providers/elevenlabs/elevenlabs_test.go
Normal file
62
core/providers/elevenlabs/elevenlabs_test.go
Normal file
@@ -0,0 +1,62 @@
|
||||
package elevenlabs_test
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/maximhq/bifrost/core/internal/llmtests"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
func TestElevenlabs(t *testing.T) {
|
||||
t.Parallel()
|
||||
if strings.TrimSpace(os.Getenv("ELEVENLABS_API_KEY")) == "" {
|
||||
t.Skip("Skipping Elevenlabs tests because ELEVENLABS_API_KEY is not set")
|
||||
}
|
||||
|
||||
client, ctx, cancel, err := llmtests.SetupTest()
|
||||
if err != nil {
|
||||
t.Fatalf("Error initializing test setup: %v", err)
|
||||
}
|
||||
defer cancel()
|
||||
defer client.Shutdown()
|
||||
|
||||
realtimeAgentID := strings.TrimSpace(os.Getenv("ELEVENLABS_AGENT_ID"))
|
||||
hasRealtimeAgent := false
|
||||
|
||||
testConfig := llmtests.ComprehensiveTestConfig{
|
||||
Provider: schemas.Elevenlabs,
|
||||
SpeechSynthesisModel: "eleven_turbo_v2_5",
|
||||
TranscriptionModel: "scribe_v1",
|
||||
RealtimeModel: realtimeAgentID,
|
||||
Scenarios: llmtests.TestScenarios{
|
||||
TextCompletion: false,
|
||||
TextCompletionStream: false,
|
||||
SimpleChat: false,
|
||||
CompletionStream: false,
|
||||
MultiTurnConversation: false,
|
||||
ToolCalls: false,
|
||||
MultipleToolCalls: false,
|
||||
End2EndToolCalling: false,
|
||||
AutomaticFunctionCall: false,
|
||||
ImageURL: false,
|
||||
ImageBase64: false,
|
||||
MultipleImages: false,
|
||||
CompleteEnd2End: false,
|
||||
SpeechSynthesis: true,
|
||||
SpeechSynthesisStream: true,
|
||||
Transcription: true,
|
||||
TranscriptionStream: false,
|
||||
Embedding: false,
|
||||
Reasoning: false,
|
||||
ListModels: false,
|
||||
Realtime: hasRealtimeAgent,
|
||||
},
|
||||
}
|
||||
|
||||
t.Run("ElevenlabsTests", func(t *testing.T) {
|
||||
llmtests.RunAllComprehensiveTests(t, client, ctx, testConfig)
|
||||
})
|
||||
}
|
||||
90
core/providers/elevenlabs/errors.go
Normal file
90
core/providers/elevenlabs/errors.go
Normal file
@@ -0,0 +1,90 @@
|
||||
package elevenlabs
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
|
||||
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
|
||||
schemas "github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
func parseElevenlabsError(resp *fasthttp.Response) *schemas.BifrostError {
|
||||
var errorResp ElevenlabsError
|
||||
bifrostErr := providerUtils.HandleProviderAPIError(resp, &errorResp)
|
||||
if errorResp.Detail != nil {
|
||||
var message string
|
||||
// Handle validation errors (array format)
|
||||
if len(errorResp.Detail.ValidationErrors) > 0 {
|
||||
var messages []string
|
||||
var locations []string
|
||||
var errorTypes []string
|
||||
|
||||
for _, validationErr := range errorResp.Detail.ValidationErrors {
|
||||
// Get message from either Message or Msg field
|
||||
msg := validationErr.Message
|
||||
if msg == "" {
|
||||
msg = validationErr.Msg
|
||||
}
|
||||
if msg != "" {
|
||||
messages = append(messages, msg)
|
||||
}
|
||||
|
||||
// Collect location if available
|
||||
if len(validationErr.Loc) > 0 {
|
||||
locations = append(locations, strings.Join(validationErr.Loc, "."))
|
||||
}
|
||||
|
||||
// Collect error type if available
|
||||
if validationErr.Type != "" {
|
||||
errorTypes = append(errorTypes, validationErr.Type)
|
||||
}
|
||||
}
|
||||
|
||||
// Build combined message
|
||||
if len(messages) > 0 {
|
||||
message = strings.Join(messages, "; ")
|
||||
}
|
||||
if len(locations) > 0 {
|
||||
locationStr := strings.Join(locations, ", ")
|
||||
message = message + " [" + locationStr + "]"
|
||||
}
|
||||
|
||||
errorType := ""
|
||||
if len(errorTypes) > 0 {
|
||||
errorType = strings.Join(errorTypes, ", ")
|
||||
}
|
||||
|
||||
if message != "" {
|
||||
result := &schemas.BifrostError{
|
||||
IsBifrostError: false,
|
||||
StatusCode: schemas.Ptr(resp.StatusCode()),
|
||||
Error: &schemas.ErrorField{
|
||||
Type: schemas.Ptr(errorType),
|
||||
Message: message,
|
||||
},
|
||||
}
|
||||
return result
|
||||
}
|
||||
}
|
||||
|
||||
// Handle non-validation errors (single object format)
|
||||
if errorResp.Detail.Message != nil {
|
||||
message = *errorResp.Detail.Message
|
||||
}
|
||||
|
||||
errorType := ""
|
||||
if errorResp.Detail.Status != nil {
|
||||
errorType = *errorResp.Detail.Status
|
||||
}
|
||||
|
||||
if message != "" {
|
||||
if bifrostErr.Error == nil {
|
||||
bifrostErr.Error = &schemas.ErrorField{}
|
||||
}
|
||||
bifrostErr.Error.Type = schemas.Ptr(errorType)
|
||||
bifrostErr.Error.Message = message
|
||||
}
|
||||
}
|
||||
return bifrostErr
|
||||
}
|
||||
51
core/providers/elevenlabs/models.go
Normal file
51
core/providers/elevenlabs/models.go
Normal file
@@ -0,0 +1,51 @@
|
||||
package elevenlabs
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
func (response *ElevenlabsListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, allowedModels schemas.WhiteList, blacklistedModels schemas.BlackList, aliases map[string]string, unfiltered bool) *schemas.BifrostListModelsResponse {
|
||||
if response == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
bifrostResponse := &schemas.BifrostListModelsResponse{
|
||||
Data: make([]schemas.Model, 0, len(*response)),
|
||||
}
|
||||
|
||||
pipeline := &providerUtils.ListModelsPipeline{
|
||||
AllowedModels: allowedModels,
|
||||
BlacklistedModels: blacklistedModels,
|
||||
Aliases: aliases,
|
||||
Unfiltered: unfiltered,
|
||||
ProviderKey: providerKey,
|
||||
MatchFns: providerUtils.DefaultMatchFns(),
|
||||
}
|
||||
if pipeline.ShouldEarlyExit() {
|
||||
return bifrostResponse
|
||||
}
|
||||
|
||||
included := make(map[string]bool)
|
||||
|
||||
for _, model := range *response {
|
||||
for _, result := range pipeline.FilterModel(model.ModelID) {
|
||||
entry := schemas.Model{
|
||||
ID: string(providerKey) + "/" + result.ResolvedID,
|
||||
Name: schemas.Ptr(model.Name),
|
||||
}
|
||||
if result.AliasValue != "" {
|
||||
entry.Alias = schemas.Ptr(result.AliasValue)
|
||||
}
|
||||
bifrostResponse.Data = append(bifrostResponse.Data, entry)
|
||||
included[strings.ToLower(result.ResolvedID)] = true
|
||||
}
|
||||
}
|
||||
|
||||
bifrostResponse.Data = append(bifrostResponse.Data,
|
||||
pipeline.BackfillModels(included)...)
|
||||
|
||||
return bifrostResponse
|
||||
}
|
||||
257
core/providers/elevenlabs/realtime.go
Normal file
257
core/providers/elevenlabs/realtime.go
Normal file
@@ -0,0 +1,257 @@
|
||||
package elevenlabs
|
||||
|
||||
import (
	"encoding/json"
	"fmt"
	"net/url"
	"strings"

	"github.com/maximhq/bifrost/core/schemas"

	providerUtils "github.com/maximhq/bifrost/core/providers/utils"
)
|
||||
|
||||
// SupportsRealtimeAPI returns true since ElevenLabs supports Conversational AI via WebSocket.
|
||||
func (provider *ElevenlabsProvider) SupportsRealtimeAPI() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// RealtimeWebSocketURL returns the WSS URL for the ElevenLabs Conversational AI endpoint.
|
||||
// The model parameter is used as the agent_id query parameter.
|
||||
// Format: wss://api.elevenlabs.io/v1/convai/conversation?agent_id=<model>
|
||||
func (provider *ElevenlabsProvider) RealtimeWebSocketURL(key schemas.Key, model string) string {
|
||||
base := provider.networkConfig.BaseURL
|
||||
base = strings.Replace(base, "https://", "wss://", 1)
|
||||
base = strings.Replace(base, "http://", "ws://", 1)
|
||||
return base + "/v1/convai/conversation?agent_id=" + model
|
||||
}
|
||||
|
||||
// RealtimeHeaders returns the headers required for the ElevenLabs Conversational AI WebSocket.
|
||||
func (provider *ElevenlabsProvider) RealtimeHeaders(key schemas.Key) map[string]string {
|
||||
headers := map[string]string{
|
||||
"xi-api-key": key.Value.GetValue(),
|
||||
}
|
||||
for k, v := range provider.networkConfig.ExtraHeaders {
|
||||
if strings.EqualFold(k, "xi-api-key") {
|
||||
continue
|
||||
}
|
||||
headers[k] = v
|
||||
}
|
||||
return headers
|
||||
}
|
||||
|
||||
// SupportsRealtimeWebRTC returns false — ElevenLabs WebRTC SDP exchange is not yet implemented.
|
||||
func (provider *ElevenlabsProvider) SupportsRealtimeWebRTC() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// ExchangeRealtimeWebRTCSDP is not yet implemented for ElevenLabs.
|
||||
func (provider *ElevenlabsProvider) ExchangeRealtimeWebRTCSDP(_ *schemas.BifrostContext, _ schemas.Key, _ string, _ string, _ json.RawMessage) (string, *schemas.BifrostError) {
|
||||
return "", &schemas.BifrostError{
|
||||
IsBifrostError: true,
|
||||
StatusCode: schemas.Ptr(400),
|
||||
Error: &schemas.ErrorField{Type: schemas.Ptr("invalid_request_error"), Message: "WebRTC SDP exchange is not yet implemented for ElevenLabs"},
|
||||
}
|
||||
}
|
||||
|
||||
func (provider *ElevenlabsProvider) ShouldStartRealtimeTurn(event *schemas.BifrostRealtimeEvent) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (provider *ElevenlabsProvider) RealtimeTurnFinalEvent() schemas.RealtimeEventType {
|
||||
return schemas.RTEventResponseDone
|
||||
}
|
||||
|
||||
func (provider *ElevenlabsProvider) RealtimeWebRTCDataChannelLabel() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (provider *ElevenlabsProvider) RealtimeWebSocketSubprotocol() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (provider *ElevenlabsProvider) ShouldForwardRealtimeEvent(event *schemas.BifrostRealtimeEvent) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (provider *ElevenlabsProvider) ShouldAccumulateRealtimeOutput(eventType schemas.RealtimeEventType) bool {
|
||||
return eventType == schemas.RTEventResponseDone
|
||||
}
|
||||
|
||||
// ElevenLabs Conversational AI WebSocket event types.

// Event types recognised by ToBifrostRealtimeEvent when converting events
// received from the provider.
const (
	elConversationInitMetadata = "conversation_initiation_metadata"
	elPing                     = "ping"
	elAudio                    = "audio"
	elUserTranscript           = "user_transcript"
	elAgentResponse            = "agent_response"
	elAgentResponseCorrection  = "agent_response_correction"
	elInterruption             = "interruption"
	elClientToolCall           = "client_tool_call"
)

// Event types for the opposite direction (presumably client-to-provider
// messages; only elUserAudioChunk is referenced by ToProviderRealtimeEvent).
const (
	elUserAudioChunk   = "user_audio_chunk"
	elPong             = "pong"
	elClientToolResult = "client_tool_result"
	elContextualUpdate = "contextual_update"
)
|
||||
|
||||
// elevenlabsEvent is the raw envelope of an ElevenLabs Conversational AI
// WebSocket event. Type names the event; each payload is kept as undecoded
// JSON under its own key so it can be decoded lazily once Type is known.
type elevenlabsEvent struct {
	// Type is the event discriminator (see the el* constants).
	Type string `json:"type"`

	// Payloads of server events.
	ConversationInitMetadata json.RawMessage `json:"conversation_initiation_metadata_event,omitempty"`
	Audio                    json.RawMessage `json:"audio_event,omitempty"`
	UserTranscript           json.RawMessage `json:"user_transcription_event,omitempty"`
	AgentResponse            json.RawMessage `json:"agent_response_event,omitempty"`
	AgentResponseCorrection  json.RawMessage `json:"agent_response_correction_event,omitempty"`
	ClientToolCall           json.RawMessage `json:"client_tool_call,omitempty"`
	PingEvent                json.RawMessage `json:"ping_event,omitempty"`

	// Payloads of client events.
	UserAudioChunk json.RawMessage `json:"user_audio_chunk,omitempty"`
}
|
||||
|
||||
// elevenlabsAudioEvent is the payload of an ElevenLabs "audio" event.
type elevenlabsAudioEvent struct {
	// Audio is the value of the "audio_base_64" key.
	Audio string `json:"audio_base_64,omitempty"`
	// Alignment is kept as undecoded JSON.
	Alignment json.RawMessage `json:"alignment,omitempty"`
}
|
||||
|
||||
// elevenlabsTranscriptEvent is the payload shared by the user transcript and
// agent response events from ElevenLabs; each event fills only its own fields.
type elevenlabsTranscriptEvent struct {
	UserTranscript  string `json:"user_transcript,omitempty"`
	AgentResponse   string `json:"agent_response,omitempty"`
	AgentResponseID string `json:"agent_response_id,omitempty"`
}
|
||||
|
||||
// elevenlabsCorrectionEvent is the payload of an ElevenLabs agent response
// correction event: the original agent text and its corrected replacement.
type elevenlabsCorrectionEvent struct {
	OriginalAgentResponse  string `json:"original_agent_response,omitempty"`
	CorrectedAgentResponse string `json:"corrected_agent_response,omitempty"`
}
|
||||
|
||||
// ToBifrostRealtimeEvent converts an ElevenLabs Conversational AI event to the unified Bifrost format.
|
||||
func (provider *ElevenlabsProvider) ToBifrostRealtimeEvent(providerEvent json.RawMessage) (*schemas.BifrostRealtimeEvent, error) {
|
||||
var raw elevenlabsEvent
|
||||
if err := json.Unmarshal(providerEvent, &raw); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal ElevenLabs realtime event: %w", err)
|
||||
}
|
||||
|
||||
event := &schemas.BifrostRealtimeEvent{
|
||||
RawData: providerEvent,
|
||||
}
|
||||
|
||||
switch raw.Type {
|
||||
case elConversationInitMetadata:
|
||||
event.Type = schemas.RTEventSessionCreated
|
||||
event.Session = &schemas.RealtimeSession{}
|
||||
|
||||
case elPing:
|
||||
event.Type = schemas.RealtimeEventType("ping")
|
||||
|
||||
case elAudio:
|
||||
event.Type = schemas.RTEventResponseAudioDelta
|
||||
if raw.Audio != nil {
|
||||
var audioEvt elevenlabsAudioEvent
|
||||
if err := json.Unmarshal(raw.Audio, &audioEvt); err == nil {
|
||||
event.Delta = &schemas.RealtimeDelta{
|
||||
Audio: audioEvt.Audio,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
case elUserTranscript:
|
||||
event.Type = schemas.RTEventInputAudioTransCompleted
|
||||
if raw.UserTranscript != nil {
|
||||
var transcript elevenlabsTranscriptEvent
|
||||
if err := json.Unmarshal(raw.UserTranscript, &transcript); err == nil {
|
||||
event.Delta = &schemas.RealtimeDelta{
|
||||
Transcript: transcript.UserTranscript,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
case elAgentResponse:
|
||||
event.Type = schemas.RTEventResponseDone
|
||||
if raw.AgentResponse != nil {
|
||||
var agentResp elevenlabsTranscriptEvent
|
||||
if err := json.Unmarshal(raw.AgentResponse, &agentResp); err == nil {
|
||||
event.Delta = &schemas.RealtimeDelta{
|
||||
Text: agentResp.AgentResponse,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
case elAgentResponseCorrection:
|
||||
event.Type = schemas.RTEventResponseTextDelta
|
||||
if raw.AgentResponseCorrection != nil {
|
||||
var correction elevenlabsCorrectionEvent
|
||||
if err := json.Unmarshal(raw.AgentResponseCorrection, &correction); err == nil {
|
||||
event.Delta = &schemas.RealtimeDelta{
|
||||
Text: correction.CorrectedAgentResponse,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
case elInterruption:
|
||||
event.Type = schemas.RTEventResponseCancel
|
||||
|
||||
case elClientToolCall:
|
||||
event.Type = schemas.RealtimeEventType("client_tool_call")
|
||||
if raw.ClientToolCall != nil {
|
||||
var toolCall struct {
|
||||
ToolName string `json:"tool_name"`
|
||||
Parameters json.RawMessage `json:"parameters"`
|
||||
ToolCallID string `json:"tool_call_id"`
|
||||
}
|
||||
if err := json.Unmarshal(raw.ClientToolCall, &toolCall); err == nil {
|
||||
args := string(toolCall.Parameters)
|
||||
if len(toolCall.Parameters) > 0 {
|
||||
var parsed interface{}
|
||||
if err := json.Unmarshal(toolCall.Parameters, &parsed); err == nil {
|
||||
if sorted, err := providerUtils.MarshalSorted(parsed); err == nil {
|
||||
args = string(sorted)
|
||||
}
|
||||
}
|
||||
}
|
||||
event.Item = &schemas.RealtimeItem{
|
||||
Type: "function_call",
|
||||
Name: toolCall.ToolName,
|
||||
CallID: toolCall.ToolCallID,
|
||||
Arguments: args,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
default:
|
||||
event.Type = schemas.RealtimeEventType(raw.Type)
|
||||
}
|
||||
|
||||
return event, nil
|
||||
}
|
||||
|
||||
// ToProviderRealtimeEvent converts a unified Bifrost Realtime event to ElevenLabs' native JSON.
|
||||
func (provider *ElevenlabsProvider) ToProviderRealtimeEvent(bifrostEvent *schemas.BifrostRealtimeEvent) (json.RawMessage, error) {
|
||||
switch bifrostEvent.Type {
|
||||
case schemas.RTEventInputAudioAppend:
|
||||
if bifrostEvent.Delta == nil {
|
||||
return nil, fmt.Errorf("delta must be set for input_audio_buffer.append events")
|
||||
}
|
||||
out := map[string]interface{}{
|
||||
"type": elUserAudioChunk,
|
||||
"user_audio_chunk": bifrostEvent.Delta.Audio,
|
||||
}
|
||||
return schemas.MarshalSorted(out)
|
||||
|
||||
case schemas.RealtimeEventType("pong"):
|
||||
return schemas.MarshalSorted(map[string]interface{}{
|
||||
"type": "pong",
|
||||
})
|
||||
|
||||
default:
|
||||
out := map[string]interface{}{
|
||||
"type": string(bifrostEvent.Type),
|
||||
}
|
||||
return schemas.MarshalSorted(out)
|
||||
}
|
||||
}
|
||||
102
core/providers/elevenlabs/speech.go
Normal file
102
core/providers/elevenlabs/speech.go
Normal file
@@ -0,0 +1,102 @@
|
||||
package elevenlabs
|
||||
|
||||
import (
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
func ToElevenlabsSpeechRequest(bifrostReq *schemas.BifrostSpeechRequest) *ElevenlabsSpeechRequest {
|
||||
if bifrostReq == nil || bifrostReq.Input == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
elevenlabsReq := &ElevenlabsSpeechRequest{
|
||||
ModelID: bifrostReq.Model,
|
||||
Text: bifrostReq.Input.Input,
|
||||
}
|
||||
|
||||
if bifrostReq.Params != nil {
|
||||
elevenlabsReq.ExtraParams = bifrostReq.Params.ExtraParams
|
||||
voiceSettings := ElevenlabsVoiceSettings{}
|
||||
hasVoiceSettings := false
|
||||
|
||||
if bifrostReq.Params.Speed != nil {
|
||||
voiceSettings.Speed = *bifrostReq.Params.Speed
|
||||
hasVoiceSettings = true
|
||||
}
|
||||
|
||||
if bifrostReq.Params.ExtraParams != nil {
|
||||
if stability, ok := schemas.SafeExtractFloat64Pointer(bifrostReq.Params.ExtraParams["stability"]); ok {
|
||||
delete(elevenlabsReq.ExtraParams, "stability")
|
||||
voiceSettings.Stability = *stability
|
||||
hasVoiceSettings = true
|
||||
}
|
||||
if useSpeakerBoost, ok := schemas.SafeExtractBoolPointer(bifrostReq.Params.ExtraParams["use_speaker_boost"]); ok {
|
||||
delete(elevenlabsReq.ExtraParams, "use_speaker_boost")
|
||||
voiceSettings.UseSpeakerBoost = *useSpeakerBoost
|
||||
hasVoiceSettings = true
|
||||
}
|
||||
if similarityBoost, ok := schemas.SafeExtractFloat64Pointer(bifrostReq.Params.ExtraParams["similarity_boost"]); ok {
|
||||
delete(elevenlabsReq.ExtraParams, "similarity_boost")
|
||||
voiceSettings.SimilarityBoost = *similarityBoost
|
||||
hasVoiceSettings = true
|
||||
}
|
||||
if style, ok := schemas.SafeExtractFloat64Pointer(bifrostReq.Params.ExtraParams["style"]); ok {
|
||||
delete(elevenlabsReq.ExtraParams, "style")
|
||||
voiceSettings.Style = *style
|
||||
hasVoiceSettings = true
|
||||
}
|
||||
if seed, ok := schemas.SafeExtractIntPointer(bifrostReq.Params.ExtraParams["seed"]); ok {
|
||||
delete(elevenlabsReq.ExtraParams, "seed")
|
||||
elevenlabsReq.Seed = seed
|
||||
}
|
||||
if previousText, ok := schemas.SafeExtractStringPointer(bifrostReq.Params.ExtraParams["previous_text"]); ok {
|
||||
delete(elevenlabsReq.ExtraParams, "previous_text")
|
||||
elevenlabsReq.PreviousText = previousText
|
||||
}
|
||||
if nextText, ok := schemas.SafeExtractStringPointer(bifrostReq.Params.ExtraParams["next_text"]); ok {
|
||||
delete(elevenlabsReq.ExtraParams, "next_text")
|
||||
elevenlabsReq.NextText = nextText
|
||||
}
|
||||
if previousRequestIDs, ok := schemas.SafeExtractStringSlice(bifrostReq.Params.ExtraParams["previous_request_ids"]); ok {
|
||||
delete(elevenlabsReq.ExtraParams, "previous_request_ids")
|
||||
elevenlabsReq.PreviousRequestIDs = previousRequestIDs
|
||||
}
|
||||
if nextRequestIDs, ok := schemas.SafeExtractStringSlice(bifrostReq.Params.ExtraParams["next_request_ids"]); ok {
|
||||
delete(elevenlabsReq.ExtraParams, "next_request_ids")
|
||||
elevenlabsReq.NextRequestIDs = nextRequestIDs
|
||||
}
|
||||
if applyTextNormalization, ok := schemas.SafeExtractStringPointer(bifrostReq.Params.ExtraParams["apply_text_normalization"]); ok {
|
||||
delete(elevenlabsReq.ExtraParams, "apply_text_normalization")
|
||||
elevenlabsReq.ApplyTextNormalization = applyTextNormalization
|
||||
}
|
||||
if applyLanguageTextNormalization, ok := schemas.SafeExtractBoolPointer(bifrostReq.Params.ExtraParams["apply_language_text_normalization"]); ok {
|
||||
delete(elevenlabsReq.ExtraParams, "apply_language_text_normalization")
|
||||
elevenlabsReq.ApplyLanguageTextNormalization = applyLanguageTextNormalization
|
||||
}
|
||||
if usePVCAsIVC, ok := schemas.SafeExtractBoolPointer(bifrostReq.Params.ExtraParams["use_pvc_as_ivc"]); ok {
|
||||
delete(elevenlabsReq.ExtraParams, "use_pvc_as_ivc")
|
||||
elevenlabsReq.UsePVCAsIVC = usePVCAsIVC
|
||||
}
|
||||
}
|
||||
|
||||
if hasVoiceSettings {
|
||||
elevenlabsReq.VoiceSettings = &voiceSettings
|
||||
}
|
||||
|
||||
if bifrostReq.Params.LanguageCode != nil {
|
||||
elevenlabsReq.LanguageCode = bifrostReq.Params.LanguageCode
|
||||
}
|
||||
|
||||
if len(bifrostReq.Params.PronunciationDictionaryLocators) > 0 {
|
||||
elevenlabsReq.PronunciationDictionaryLocators = make([]ElevenlabsPronunciationDictionaryLocator, len(bifrostReq.Params.PronunciationDictionaryLocators))
|
||||
for i, locator := range bifrostReq.Params.PronunciationDictionaryLocators {
|
||||
elevenlabsReq.PronunciationDictionaryLocators[i] = ElevenlabsPronunciationDictionaryLocator{
|
||||
PronunciationDictionaryID: locator.PronunciationDictionaryID,
|
||||
VersionID: locator.VersionID,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return elevenlabsReq
|
||||
}
|
||||
269
core/providers/elevenlabs/transcription.go
Normal file
269
core/providers/elevenlabs/transcription.go
Normal file
@@ -0,0 +1,269 @@
|
||||
package elevenlabs
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strings"
|
||||
|
||||
"github.com/bytedance/sonic"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
func ToElevenlabsTranscriptionRequest(bifrostReq *schemas.BifrostTranscriptionRequest) *ElevenlabsTranscriptionRequest {
|
||||
if bifrostReq == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
req := &ElevenlabsTranscriptionRequest{
|
||||
ModelID: bifrostReq.Model,
|
||||
}
|
||||
|
||||
if bifrostReq.Input != nil && len(bifrostReq.Input.File) > 0 {
|
||||
req.File = bifrostReq.Input.File
|
||||
req.Filename = bifrostReq.Input.Filename
|
||||
}
|
||||
|
||||
if bifrostReq.Params == nil {
|
||||
return req
|
||||
}
|
||||
|
||||
params := bifrostReq.Params
|
||||
|
||||
if params.Language != nil {
|
||||
req.LanguageCode = params.Language
|
||||
}
|
||||
|
||||
if params.ExtraParams != nil {
|
||||
if tagAudioEvents, ok := schemas.SafeExtractBoolPointer(params.ExtraParams["tag_audio_events"]); ok {
|
||||
delete(params.ExtraParams, "tag_audio_events")
|
||||
req.TagAudioEvents = tagAudioEvents
|
||||
}
|
||||
if numSpeakers, ok := schemas.SafeExtractIntPointer(params.ExtraParams["num_speakers"]); ok {
|
||||
delete(params.ExtraParams, "num_speakers")
|
||||
req.NumSpeakers = numSpeakers
|
||||
}
|
||||
if timestampsGranularity, ok := schemas.SafeExtractStringPointer(params.ExtraParams["timestamps_granularity"]); ok {
|
||||
granularity := ElevenlabsTimestampsGranularity(*timestampsGranularity)
|
||||
delete(params.ExtraParams, "timestamps_granularity")
|
||||
req.TimestampsGranularity = &granularity
|
||||
}
|
||||
if diarize, ok := schemas.SafeExtractBoolPointer(params.ExtraParams["diarize"]); ok {
|
||||
delete(params.ExtraParams, "diarize")
|
||||
req.Diarize = diarize
|
||||
}
|
||||
if diarizationThreshold, ok := schemas.SafeExtractFloat64Pointer(params.ExtraParams["diarization_threshold"]); ok {
|
||||
delete(params.ExtraParams, "diarization_threshold")
|
||||
req.DiarizationThreshold = diarizationThreshold
|
||||
}
|
||||
if fileFormat, ok := schemas.SafeExtractStringPointer(params.ExtraParams["file_format"]); ok {
|
||||
fileFormat := ElevenlabsFileFormat(*fileFormat)
|
||||
delete(params.ExtraParams, "file_format")
|
||||
req.FileFormat = &fileFormat
|
||||
}
|
||||
if cloudStorageURL, ok := schemas.SafeExtractStringPointer(params.ExtraParams["cloud_storage_url"]); ok {
|
||||
delete(params.ExtraParams, "cloud_storage_url")
|
||||
req.CloudStorageURL = cloudStorageURL
|
||||
}
|
||||
if webhook, ok := schemas.SafeExtractBoolPointer(params.ExtraParams["webhook"]); ok {
|
||||
delete(params.ExtraParams, "webhook")
|
||||
req.Webhook = webhook
|
||||
}
|
||||
if webhookID, ok := schemas.SafeExtractStringPointer(params.ExtraParams["webhook_id"]); ok {
|
||||
delete(params.ExtraParams, "webhook_id")
|
||||
req.WebhookID = webhookID
|
||||
}
|
||||
if temperature, ok := schemas.SafeExtractFloat64Pointer(params.ExtraParams["temperature"]); ok {
|
||||
delete(params.ExtraParams, "temperature")
|
||||
req.Temperature = temperature
|
||||
}
|
||||
if seed, ok := schemas.SafeExtractIntPointer(params.ExtraParams["seed"]); ok {
|
||||
delete(params.ExtraParams, "seed")
|
||||
req.Seed = seed
|
||||
}
|
||||
if useMultiChannel, ok := schemas.SafeExtractBoolPointer(params.ExtraParams["use_multi_channel"]); ok {
|
||||
delete(params.ExtraParams, "use_multi_channel")
|
||||
req.UseMultiChannel = useMultiChannel
|
||||
}
|
||||
req.ExtraParams = bifrostReq.Params.ExtraParams
|
||||
}
|
||||
|
||||
if len(params.AdditionalFormats) > 0 {
|
||||
additionalFormats := make([]ElevenlabsAdditionalFormat, 0, len(params.AdditionalFormats))
|
||||
for _, format := range params.AdditionalFormats {
|
||||
if converted, ok := convertAdditionalFormat(format); ok {
|
||||
additionalFormats = append(additionalFormats, converted)
|
||||
}
|
||||
}
|
||||
if len(additionalFormats) > 0 {
|
||||
req.AdditionalFormats = additionalFormats
|
||||
}
|
||||
}
|
||||
|
||||
if params.WebhookMetadata != nil {
|
||||
if metadataMap, ok := params.WebhookMetadata.(map[string]interface{}); ok {
|
||||
if len(metadataMap) > 0 {
|
||||
req.WebhookMetadata = metadataMap
|
||||
}
|
||||
} else {
|
||||
req.WebhookMetadata = params.WebhookMetadata
|
||||
}
|
||||
}
|
||||
|
||||
return req
|
||||
}
|
||||
|
||||
func ToBifrostTranscriptionResponse(chunks []ElevenlabsSpeechToTextChunkResponse) *schemas.BifrostTranscriptionResponse {
|
||||
if len(chunks) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
textParts := make([]string, 0, len(chunks))
|
||||
allWords := make([]schemas.TranscriptionWord, 0)
|
||||
allLogProbs := make([]schemas.TranscriptionLogProb, 0)
|
||||
|
||||
var language *string
|
||||
var overallDuration *float64
|
||||
|
||||
for _, chunk := range chunks {
|
||||
textParts = append(textParts, chunk.Text)
|
||||
|
||||
words, logProbs, chunkDuration := convertWords(chunk.Words)
|
||||
allWords = append(allWords, words...)
|
||||
allLogProbs = append(allLogProbs, logProbs...)
|
||||
|
||||
if language == nil && chunk.LanguageCode != "" {
|
||||
lc := chunk.LanguageCode
|
||||
language = &lc
|
||||
}
|
||||
|
||||
if chunkDuration != nil {
|
||||
if overallDuration == nil || *chunkDuration > *overallDuration {
|
||||
val := *chunkDuration
|
||||
overallDuration = &val
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
text := strings.Join(textParts, "\n")
|
||||
|
||||
response := &schemas.BifrostTranscriptionResponse{
|
||||
Text: text,
|
||||
Words: allWords,
|
||||
LogProbs: allLogProbs,
|
||||
}
|
||||
|
||||
if language != nil {
|
||||
response.Language = language
|
||||
}
|
||||
|
||||
if overallDuration != nil {
|
||||
response.Duration = overallDuration
|
||||
}
|
||||
|
||||
return response
|
||||
|
||||
}
|
||||
|
||||
func convertAdditionalFormat(format schemas.TranscriptionAdditionalFormat) (ElevenlabsAdditionalFormat, bool) {
|
||||
if format.Format == "" {
|
||||
return ElevenlabsAdditionalFormat{}, false
|
||||
}
|
||||
|
||||
converted := ElevenlabsAdditionalFormat{
|
||||
Format: ElevenlabsExportOptions(format.Format),
|
||||
}
|
||||
|
||||
if format.IncludeSpeakers != nil {
|
||||
converted.IncludeSpeakers = format.IncludeSpeakers
|
||||
}
|
||||
|
||||
if format.IncludeTimestamps != nil {
|
||||
converted.IncludeTimestamps = format.IncludeTimestamps
|
||||
}
|
||||
|
||||
if format.SegmentOnSilenceLongerThanS != nil {
|
||||
converted.SegmentOnSilenceLongerThanS = format.SegmentOnSilenceLongerThanS
|
||||
}
|
||||
|
||||
if format.MaxSegmentDurationS != nil {
|
||||
converted.MaxSegmentDurationS = format.MaxSegmentDurationS
|
||||
}
|
||||
|
||||
if format.MaxSegmentChars != nil {
|
||||
converted.MaxSegmentChars = format.MaxSegmentChars
|
||||
}
|
||||
|
||||
if format.MaxCharactersPerLine != nil {
|
||||
converted.MaxCharactersPerLine = format.MaxCharactersPerLine
|
||||
}
|
||||
|
||||
return converted, true
|
||||
}
|
||||
|
||||
func convertWords(words []ElevenlabsSpeechToTextWord) ([]schemas.TranscriptionWord, []schemas.TranscriptionLogProb, *float64) {
|
||||
if len(words) == 0 {
|
||||
return nil, nil, nil
|
||||
}
|
||||
|
||||
convertedWords := make([]schemas.TranscriptionWord, 0, len(words))
|
||||
logProbs := make([]schemas.TranscriptionLogProb, 0, len(words))
|
||||
|
||||
var maxEnd float64
|
||||
var hasEnd bool
|
||||
|
||||
for _, word := range words {
|
||||
trimmed := strings.TrimSpace(word.Text)
|
||||
if word.Type == "spacing" && trimmed == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
transcriptionWord := schemas.TranscriptionWord{
|
||||
Word: word.Text,
|
||||
}
|
||||
|
||||
if word.Start != nil {
|
||||
transcriptionWord.Start = *word.Start
|
||||
}
|
||||
|
||||
if word.End != nil {
|
||||
transcriptionWord.End = *word.End
|
||||
if !hasEnd || *word.End > maxEnd {
|
||||
maxEnd = *word.End
|
||||
hasEnd = true
|
||||
}
|
||||
}
|
||||
|
||||
convertedWords = append(convertedWords, transcriptionWord)
|
||||
logProbs = append(logProbs, schemas.TranscriptionLogProb{
|
||||
Token: word.Text,
|
||||
LogProb: word.LogProb,
|
||||
})
|
||||
}
|
||||
|
||||
if !hasEnd {
|
||||
return convertedWords, logProbs, nil
|
||||
}
|
||||
|
||||
duration := maxEnd
|
||||
return convertedWords, logProbs, &duration
|
||||
}
|
||||
|
||||
func parseTranscriptionResponse(body []byte) ([]ElevenlabsSpeechToTextChunkResponse, error) {
|
||||
var multichannel ElevenlabsMultichannelSpeechToTextResponse
|
||||
if err := sonic.Unmarshal(body, &multichannel); err == nil && len(multichannel.Transcripts) > 0 {
|
||||
return multichannel.Transcripts, nil
|
||||
}
|
||||
|
||||
var single ElevenlabsSpeechToTextChunkResponse
|
||||
if err := sonic.Unmarshal(body, &single); err == nil {
|
||||
if single.LanguageCode != "" || single.Text != "" || len(single.Words) > 0 {
|
||||
return []ElevenlabsSpeechToTextChunkResponse{single}, nil
|
||||
}
|
||||
}
|
||||
|
||||
var webhook ElevenlabsSpeechToTextWebhookResponse
|
||||
if err := sonic.Unmarshal(body, &webhook); err == nil && strings.TrimSpace(webhook.Message) != "" {
|
||||
return nil, errors.New(webhook.Message)
|
||||
}
|
||||
|
||||
return nil, errors.New("unexpected Elevenlabs transcription response format")
|
||||
}
|
||||
289
core/providers/elevenlabs/types.go
Normal file
289
core/providers/elevenlabs/types.go
Normal file
@@ -0,0 +1,289 @@
|
||||
package elevenlabs
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/bytedance/sonic"
|
||||
)
|
||||
|
||||
// SPEECH TYPES
|
||||
|
||||
type ElevenlabsSpeechRequest struct {
|
||||
Text string `json:"text"`
|
||||
ModelID string `json:"model_id"` // defaults to "eleven_multilingual_v2"
|
||||
LanguageCode *string `json:"language_code,omitempty"`
|
||||
VoiceSettings *ElevenlabsVoiceSettings `json:"voice_settings,omitempty"`
|
||||
PronunciationDictionaryLocators []ElevenlabsPronunciationDictionaryLocator `json:"pronunciation_dictionary_locators"`
|
||||
Seed *int `json:"seed,omitempty"`
|
||||
PreviousText *string `json:"previous_text,omitempty"`
|
||||
NextText *string `json:"next_text,omitempty"`
|
||||
PreviousRequestIDs []string `json:"previous_request_ids"`
|
||||
NextRequestIDs []string `json:"next_request_ids"`
|
||||
ApplyTextNormalization *string `json:"apply_text_normalization,omitempty"`
|
||||
ApplyLanguageTextNormalization *bool `json:"apply_language_text_normalization,omitempty"`
|
||||
UsePVCAsIVC *bool `json:"use_pvc_as_ivc,omitempty"` // deprecated
|
||||
ExtraParams map[string]interface{} `json:"-"`
|
||||
}
|
||||
|
||||
// GetExtraParams implements the providerUtils.RequestBodyWithExtraParams interface.
|
||||
func (r *ElevenlabsSpeechRequest) GetExtraParams() map[string]interface{} {
|
||||
return r.ExtraParams
|
||||
}
|
||||
|
||||
// ElevenlabsSpeechWithTimestampsResponse represents the response from the with-timestamps endpoint.
// Both alignment fields are pointers and stay nil when absent from the response.
type ElevenlabsSpeechWithTimestampsResponse struct {
	AudioBase64         string               `json:"audio_base64"` // presumably the generated audio, base64-encoded — confirm against the API docs
	Alignment           *ElevenlabsAlignment `json:"alignment,omitempty"`
	NormalizedAlignment *ElevenlabsAlignment `json:"normalized_alignment,omitempty"`
}
|
||||
|
||||
// ElevenlabsAlignment represents character-level timing information.
// The three slices are decoded independently; presumably they are parallel
// (one entry per character) — confirm against the API docs before indexing
// them together.
type ElevenlabsAlignment struct {
	CharStartTimesMs []float64 `json:"char_start_times_ms"`
	CharEndTimesMs   []float64 `json:"char_end_times_ms"`
	Characters       []string  `json:"characters"`
}
|
||||
|
||||
// ElevenlabsVoiceSettings holds the voice tuning values serialized under
// "voice_settings" in ElevenlabsSpeechRequest. None of the fields uses
// omitempty, so a zero value is sent explicitly rather than being omitted;
// the defaults noted below are not applied by this struct.
type ElevenlabsVoiceSettings struct {
	Stability       float64 `json:"stability"`         // 0-1, default 0.5
	UseSpeakerBoost bool    `json:"use_speaker_boost"` // default true
	SimilarityBoost float64 `json:"similarity_boost"`  // 0-1, default 0.75
	Style           float64 `json:"style"`             // default 0
	Speed           float64 `json:"speed"`             // default 1
}
|
||||
|
||||
// ElevenlabsPronunciationDictionaryLocator identifies a pronunciation
// dictionary by ID; it is the element type of
// ElevenlabsSpeechRequest.PronunciationDictionaryLocators.
type ElevenlabsPronunciationDictionaryLocator struct {
	PronunciationDictionaryID string  `json:"pronunciation_dictionary_id"`
	VersionID                 *string `json:"version_id,omitempty"` // omitted from JSON when nil
}
|
||||
|
||||
// TRANSCRIPTION TYPES
|
||||
// ElevenlabsTranscriptionRequest describes an ElevenLabs transcription request.
// File, Filename and ExtraParams are tagged `json:"-"` and are therefore never
// part of the JSON encoding of this struct; how they reach the API is decided
// by the caller that builds the HTTP request. Every optional field is a
// pointer (or slice/interface) tagged omitempty and is dropped when unset.
type ElevenlabsTranscriptionRequest struct {
	ModelID               string                           `json:"model_id"`
	File                  []byte                           `json:"-"` // raw file bytes; excluded from JSON
	Filename              string                           `json:"-"` // Original filename, used to preserve file format extension
	LanguageCode          *string                          `json:"language_code,omitempty"`
	TagAudioEvents        *bool                            `json:"tag_audio_events,omitempty"`
	NumSpeakers           *int                             `json:"num_speakers,omitempty"`
	TimestampsGranularity *ElevenlabsTimestampsGranularity `json:"timestamps_granularity,omitempty"`
	Diarize               *bool                            `json:"diarize,omitempty"`
	DiarizationThreshold  *float64                         `json:"diarization_threshold,omitempty"`
	AdditionalFormats     []ElevenlabsAdditionalFormat     `json:"additional_formats,omitempty"`
	FileFormat            *ElevenlabsFileFormat            `json:"file_format,omitempty"`
	CloudStorageURL       *string                          `json:"cloud_storage_url,omitempty"`
	Webhook               *bool                            `json:"webhook,omitempty"`
	WebhookID             *string                          `json:"webhook_id,omitempty"`
	Temperature           *float64                         `json:"temperature,omitempty"`
	Seed                  *int                             `json:"seed,omitempty"`
	UseMultiChannel       *bool                            `json:"use_multi_channel,omitempty"`
	WebhookMetadata       interface{}                      `json:"webhook_metadata,omitempty"` // untyped; any JSON-encodable value
	ExtraParams           map[string]interface{}           `json:"-"`                          // exposed via GetExtraParams
}
|
||||
|
||||
// GetExtraParams implements the RequestBodyWithExtraParams interface
|
||||
func (req *ElevenlabsTranscriptionRequest) GetExtraParams() map[string]interface{} {
|
||||
return req.ExtraParams
|
||||
}
|
||||
|
||||
// ElevenlabsTimestampsGranularity is the string enum used for
// ElevenlabsTranscriptionRequest.TimestampsGranularity.
type ElevenlabsTimestampsGranularity string

// Values declared for ElevenlabsTimestampsGranularity.
const (
	ElevenlabsTimestampsGranularityNone      ElevenlabsTimestampsGranularity = "none"
	ElevenlabsTimestampsGranularityWord      ElevenlabsTimestampsGranularity = "word"
	ElevenlabsTimestampsGranularityCharacter ElevenlabsTimestampsGranularity = "character"
)
|
||||
|
||||
// ElevenlabsFileFormat is the string enum used for
// ElevenlabsTranscriptionRequest.FileFormat.
type ElevenlabsFileFormat string

// Values declared for ElevenlabsFileFormat.
const (
	ElevenlabsFileFormatPcmS16le16 ElevenlabsFileFormat = "pcm_s16le_16"
	ElevenlabsFileFormatOther      ElevenlabsFileFormat = "other"
)
|
||||
|
||||
// ElevenlabsAdditionalFormat is one entry of
// ElevenlabsTranscriptionRequest.AdditionalFormats. Format is always
// serialized; every other option is a pointer that is omitted when nil.
type ElevenlabsAdditionalFormat struct {
	Format                      ElevenlabsExportOptions `json:"format"`
	IncludeSpeakers             *bool                   `json:"include_speakers,omitempty"`
	IncludeTimestamps           *bool                   `json:"include_timestamps,omitempty"`
	SegmentOnSilenceLongerThanS *float64                `json:"segment_on_silence_longer_than_s,omitempty"`
	MaxSegmentDurationS         *float64                `json:"max_segment_duration_s,omitempty"`
	MaxSegmentChars             *int                    `json:"max_segment_chars,omitempty"`
	MaxCharactersPerLine        *int                    `json:"max_characters_per_line,omitempty"`
}
|
||||
|
||||
// ElevenlabsExportOptions is the string enum used for
// ElevenlabsAdditionalFormat.Format.
type ElevenlabsExportOptions string

// Values declared for ElevenlabsExportOptions.
const (
	ElevenlabsExportOptionsSegmentedJson ElevenlabsExportOptions = "segmented_json"
	ElevenlabsExportOptionsDocx          ElevenlabsExportOptions = "docx"
	ElevenlabsExportOptionsPdf           ElevenlabsExportOptions = "pdf"
	ElevenlabsExportOptionsTxt           ElevenlabsExportOptions = "txt"
	ElevenlabsExportOptionsHtml          ElevenlabsExportOptions = "html"
	ElevenlabsExportOptionsSrt           ElevenlabsExportOptions = "srt"
)
|
||||
|
||||
// ElevenlabsSpeechToTextChunkResponse is a single transcript as returned by
// the speech-to-text API. It is used both on its own and as the element type
// of ElevenlabsMultichannelSpeechToTextResponse.Transcripts.
type ElevenlabsSpeechToTextChunkResponse struct {
	LanguageCode        string                                `json:"language_code"`
	LanguageProbability *float64                              `json:"language_probability,omitempty"`
	Text                string                                `json:"text"`
	Words               []ElevenlabsSpeechToTextWord          `json:"words"`
	ChannelIndex        *int                                  `json:"channel_index,omitempty"`
	AdditionalFormats   []*ElevenlabsAdditionalFormatResponse `json:"additional_formats,omitempty"` // pointer elements, so entries may be nil
	TranscriptionID     *string                               `json:"transcription_id,omitempty"`
}
|
||||
|
||||
// ElevenlabsSpeechToTextWord is one entry of
// ElevenlabsSpeechToTextChunkResponse.Words. Start, End and SpeakerID are
// pointers and remain nil when the response does not include them.
type ElevenlabsSpeechToTextWord struct {
	Text       string                            `json:"text"`
	Start      *float64                          `json:"start,omitempty"` // unit not established by this file — presumably seconds, confirm
	End        *float64                          `json:"end,omitempty"`
	Type       string                            `json:"type"`
	SpeakerID  *string                           `json:"speaker_id,omitempty"`
	LogProb    float64                           `json:"logprob"`
	Characters []ElevenlabsSpeechToTextCharacter `json:"characters,omitempty"`
}
|
||||
|
||||
// ElevenlabsSpeechToTextCharacter is one entry of
// ElevenlabsSpeechToTextWord.Characters; Start and End stay nil when absent.
type ElevenlabsSpeechToTextCharacter struct {
	Text  string   `json:"text"`
	Start *float64 `json:"start,omitempty"`
	End   *float64 `json:"end,omitempty"`
}
|
||||
|
||||
// ElevenlabsAdditionalFormatResponse is one rendered export returned alongside
// a transcript (see ElevenlabsSpeechToTextChunkResponse.AdditionalFormats).
type ElevenlabsAdditionalFormatResponse struct {
	RequestedFormat string `json:"requested_format"`
	FileExtension   string `json:"file_extension"`
	ContentType     string `json:"content_type"`
	IsBase64Encoded bool   `json:"is_base64_encoded"` // presumably tells whether Content must be base64-decoded — confirm
	Content         string `json:"content"`
}
|
||||
|
||||
// ElevenlabsMultichannelSpeechToTextResponse is the response shape that wraps
// several transcripts under a "transcripts" array.
type ElevenlabsMultichannelSpeechToTextResponse struct {
	Transcripts     []ElevenlabsSpeechToTextChunkResponse `json:"transcripts"`
	TranscriptionID *string                               `json:"transcription_id,omitempty"`
}
|
||||
|
||||
// ElevenlabsSpeechToTextWebhookResponse is the response shape that carries a
// message and request ID instead of a transcript.
type ElevenlabsSpeechToTextWebhookResponse struct {
	Message         string  `json:"message"`
	RequestID       string  `json:"request_id"`
	TranscriptionID *string `json:"transcription_id,omitempty"`
}
|
||||
|
||||
// ERROR TYPES
|
||||
// ElevenlabsError is the top-level error envelope; the payload lives under
// the "detail" key and is decoded by ElevenlabsErrorDetail.UnmarshalJSON.
type ElevenlabsError struct {
	Detail *ElevenlabsErrorDetail `json:"detail,omitempty"`
}
|
||||
|
||||
// ElevenlabsErrorDetail handles both single object (non-validation errors) and
// array of objects (validation errors) formats from ElevenLabs API.
// Decoding goes through the custom UnmarshalJSON method, which fills
// ValidationErrors itself; that field is tagged `json:"-"` and is never read
// from or written to JSON directly.
type ElevenlabsErrorDetail struct {
	// Non-validation error fields (when detail is a single object)
	Status  *string `json:"status,omitempty"`
	Message *string `json:"message,omitempty"`

	// Validation error fields (when detail is an array)
	ValidationErrors []ElevenlabsValidationError `json:"-"`
}
|
||||
|
||||
// ElevenlabsValidationError represents a single validation error entry.
// NOTE(review): Loc is []string, so a "loc" array containing non-string
// elements (for example a numeric index) would fail to decode — confirm the
// API only ever emits strings there.
type ElevenlabsValidationError struct {
	Loc     []string `json:"loc"`
	Msg     string   `json:"msg"`
	Message string   `json:"message"` // Some APIs use "message" instead of "msg"
	Type    string   `json:"type"`
}
|
||||
|
||||
// UnmarshalJSON implements custom JSON unmarshaling to handle both
|
||||
// single object and array formats from ElevenLabs API.
|
||||
func (d *ElevenlabsErrorDetail) UnmarshalJSON(data []byte) error {
|
||||
// First, try to unmarshal as an array (validation errors)
|
||||
// Check if it's an array by looking at the first non-whitespace character
|
||||
trimmed := strings.TrimSpace(string(data))
|
||||
if len(trimmed) > 0 && trimmed[0] == '[' {
|
||||
var validationErrors []ElevenlabsValidationError
|
||||
if err := sonic.Unmarshal(data, &validationErrors); err != nil {
|
||||
return err
|
||||
}
|
||||
d.ValidationErrors = validationErrors
|
||||
// Extract message from first validation error if available
|
||||
if len(validationErrors) > 0 {
|
||||
if validationErrors[0].Message != "" {
|
||||
d.Message = &validationErrors[0].Message
|
||||
} else if validationErrors[0].Msg != "" {
|
||||
d.Message = &validationErrors[0].Msg
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// If not an array, try to unmarshal as a single object (non-validation error)
|
||||
var obj struct {
|
||||
Type *string `json:"type,omitempty"`
|
||||
Loc []string `json:"loc,omitempty"`
|
||||
Message *string `json:"message,omitempty"`
|
||||
Status *string `json:"status,omitempty"`
|
||||
Msg *string `json:"msg,omitempty"` // Some APIs use "msg" instead of "message"
|
||||
}
|
||||
if err := sonic.Unmarshal(data, &obj); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Populate non-validation error fields
|
||||
d.Status = obj.Status
|
||||
if obj.Message != nil {
|
||||
d.Message = obj.Message
|
||||
} else if obj.Msg != nil {
|
||||
d.Message = obj.Msg
|
||||
}
|
||||
|
||||
// If this object has validation-like fields (Loc, Type), treat it as a single validation error
|
||||
if len(obj.Loc) > 0 || obj.Type != nil {
|
||||
validationErr := ElevenlabsValidationError{
|
||||
Loc: obj.Loc,
|
||||
Type: func() string {
|
||||
if obj.Type != nil {
|
||||
return *obj.Type
|
||||
}
|
||||
return ""
|
||||
}(),
|
||||
}
|
||||
if obj.Message != nil {
|
||||
validationErr.Message = *obj.Message
|
||||
} else if obj.Msg != nil {
|
||||
validationErr.Msg = *obj.Msg
|
||||
validationErr.Message = *obj.Msg
|
||||
}
|
||||
d.ValidationErrors = []ElevenlabsValidationError{validationErr}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// MODEL TYPES
|
||||
// ElevenlabsModel is one model entry as decoded from the ElevenLabs model
// listing (see ElevenlabsListModelsResponse). All fields are plain values, so
// keys missing from the response decode to their zero values.
type ElevenlabsModel struct {
	ModelID                            string               `json:"model_id"`
	Name                               string               `json:"name"`
	Description                        string               `json:"description"`
	ServesProVoices                    bool                 `json:"serves_pro_voices"`
	TokenCostFactor                    float64              `json:"token_cost_factor"`
	CanBeFinetuned                     bool                 `json:"can_be_finetuned"`
	CanDoTextToSpeech                  bool                 `json:"can_do_text_to_speech"`
	CanDoVoiceConversion               bool                 `json:"can_do_voice_conversion"`
	CanUseStyle                        bool                 `json:"can_use_style"`
	CanUseSpeakerBoost                 bool                 `json:"can_use_speaker_boost"`
	Languages                          []ElevenlabsLanguage `json:"languages"`
	RequiresAlphaAccess                bool                 `json:"requires_alpha_access"`
	MaxCharactersRequestFreeUser       int                  `json:"max_characters_request_free_user"`
	MaxCharactersRequestSubscribedUser int                  `json:"max_characters_request_subscribed_user"`
	MaxTextLengthPerRequest            int                  `json:"maximum_text_length_per_request"` // note: JSON key is "maximum_…", unlike the Go field name
	ModelRates                         ElevenlabsModelRate  `json:"model_rates"`
	ConcurrencyGroup                   string               `json:"concurrency_group"`
}
|
||||
|
||||
// ElevenlabsLanguage is one entry of ElevenlabsModel.Languages.
type ElevenlabsLanguage struct {
	LanguageID string `json:"language_id"`
	Name       string `json:"name"`
}
|
||||
|
||||
// ElevenlabsModelRate is the value of ElevenlabsModel.ModelRates.
type ElevenlabsModelRate struct {
	CharacterCostMultiplier float64 `json:"character_cost_multiplier"`
}
|
||||
|
||||
// ElevenlabsListModelsResponse is the model-listing response: a bare JSON
// array of ElevenlabsModel entries with no wrapping object.
type ElevenlabsListModelsResponse []ElevenlabsModel
|
||||
35
core/providers/elevenlabs/utils.go
Normal file
35
core/providers/elevenlabs/utils.go
Normal file
@@ -0,0 +1,35 @@
|
||||
package elevenlabs
|
||||
|
||||
var (
	// bifrostToElevenlabsSpeechFormat maps a Bifrost speech format name to the
	// ElevenLabs output format identifier. The empty string maps to the mp3
	// identifier, and both "wav" and "pcm" map to "pcm_44100".
	bifrostToElevenlabsSpeechFormat = map[string]string{
		"":     "mp3_44100_128",
		"mp3":  "mp3_44100_128",
		"opus": "opus_48000_128",
		"wav":  "pcm_44100",
		"pcm":  "pcm_44100",
	}

	// elevenlabsSpeechFormatToBifrost maps an ElevenLabs output format
	// identifier back to a Bifrost speech format name. "pcm_44100" is reported
	// as "wav", so a "pcm" request does not round-trip through both maps.
	elevenlabsSpeechFormatToBifrost = map[string]string{
		"mp3_44100_128":  "mp3",
		"opus_48000_128": "opus",
		"pcm_44100":      "wav",
	}
)
|
||||
|
||||
// ConvertBifrostSpeechFormatToElevenlabs converts Bifrost speech format to Elevenlabs format
|
||||
func ConvertBifrostSpeechFormatToElevenlabs(format string) string {
|
||||
if elevenlabsFormat, ok := bifrostToElevenlabsSpeechFormat[format]; ok {
|
||||
return elevenlabsFormat
|
||||
}
|
||||
return format
|
||||
}
|
||||
|
||||
// ConvertElevenlabsSpeechFormatToBifrost converts Elevenlabs speech format to Bifrost format
|
||||
func ConvertElevenlabsSpeechFormatToBifrost(format string) string {
|
||||
if bifrostFormat, ok := elevenlabsSpeechFormatToBifrost[format]; ok {
|
||||
return bifrostFormat
|
||||
}
|
||||
return format
|
||||
}
|
||||
435
core/providers/fireworks/fireworks.go
Normal file
435
core/providers/fireworks/fireworks.go
Normal file
@@ -0,0 +1,435 @@
|
||||
// Package fireworks implements the Fireworks AI provider and its utility functions.
|
||||
package fireworks
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/maximhq/bifrost/core/providers/openai"
|
||||
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
|
||||
schemas "github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// FireworksProvider implements the Provider interface for Fireworks AI's API.
// Its methods forward supported operations to the shared OpenAI-compatible
// handlers in the openai package and return an unsupported-operation error for
// everything else. Instances are created with NewFireworksProvider.
type FireworksProvider struct {
	logger              schemas.Logger        // Logger for provider operations
	client              *fasthttp.Client      // HTTP client for unary API requests (ReadTimeout bounds overall response)
	streamingClient     *fasthttp.Client      // HTTP client for streaming API requests (no ReadTimeout; idle governed by NewIdleTimeoutReader)
	networkConfig       schemas.NetworkConfig // Network configuration including extra headers
	sendBackRawRequest  bool                  // Whether to include raw request in BifrostResponse
	sendBackRawResponse bool                  // Whether to include raw response in BifrostResponse
}
|
||||
|
||||
// NewFireworksProvider creates a new Fireworks AI provider instance.
|
||||
// It initializes the HTTP client with the provided configuration and sets up response pools.
|
||||
// The client is configured with timeouts, concurrency limits, and optional proxy settings.
|
||||
func NewFireworksProvider(config *schemas.ProviderConfig, logger schemas.Logger) (*FireworksProvider, error) {
|
||||
config.CheckAndSetDefaults()
|
||||
|
||||
requestTimeout := time.Second * time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds)
|
||||
client := &fasthttp.Client{
|
||||
ReadTimeout: requestTimeout,
|
||||
WriteTimeout: requestTimeout,
|
||||
MaxConnsPerHost: config.NetworkConfig.MaxConnsPerHost,
|
||||
MaxIdleConnDuration: 30 * time.Second,
|
||||
MaxConnWaitTimeout: requestTimeout,
|
||||
MaxConnDuration: time.Second * time.Duration(schemas.DefaultMaxConnDurationInSeconds),
|
||||
ConnPoolStrategy: fasthttp.FIFO,
|
||||
}
|
||||
|
||||
// Configure proxy and retry policy
|
||||
client = providerUtils.ConfigureProxy(client, config.ProxyConfig, logger)
|
||||
client = providerUtils.ConfigureDialer(client)
|
||||
client = providerUtils.ConfigureTLS(client, config.NetworkConfig, logger)
|
||||
streamingClient := providerUtils.BuildStreamingClient(client)
|
||||
// Set default BaseURL if not provided
|
||||
if config.NetworkConfig.BaseURL == "" {
|
||||
config.NetworkConfig.BaseURL = "https://api.fireworks.ai/inference"
|
||||
}
|
||||
config.NetworkConfig.BaseURL = strings.TrimRight(config.NetworkConfig.BaseURL, "/")
|
||||
|
||||
return &FireworksProvider{
|
||||
logger: logger,
|
||||
client: client,
|
||||
streamingClient: streamingClient,
|
||||
networkConfig: config.NetworkConfig,
|
||||
sendBackRawRequest: config.SendBackRawRequest,
|
||||
sendBackRawResponse: config.SendBackRawResponse,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetProviderKey returns the provider identifier for Fireworks AI.
|
||||
func (provider *FireworksProvider) GetProviderKey() schemas.ModelProvider {
|
||||
return schemas.Fireworks
|
||||
}
|
||||
|
||||
// ListModels performs a list models request to Fireworks AI's API.
|
||||
func (provider *FireworksProvider) ListModels(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) {
|
||||
return openai.HandleOpenAIListModelsRequest(
|
||||
ctx,
|
||||
provider.client,
|
||||
request,
|
||||
provider.networkConfig.BaseURL+providerUtils.GetPathFromContext(ctx, "/v1/models"),
|
||||
keys,
|
||||
provider.networkConfig.ExtraHeaders,
|
||||
schemas.Fireworks,
|
||||
providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest),
|
||||
providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse),
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
// TextCompletion performs a text completion request to the Fireworks AI API.
|
||||
func (provider *FireworksProvider) TextCompletion(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (*schemas.BifrostTextCompletionResponse, *schemas.BifrostError) {
|
||||
return openai.HandleOpenAITextCompletionRequest(
|
||||
ctx,
|
||||
provider.client,
|
||||
provider.networkConfig.BaseURL+providerUtils.GetPathFromContext(ctx, "/v1/completions"),
|
||||
request,
|
||||
key,
|
||||
provider.networkConfig.ExtraHeaders,
|
||||
provider.GetProviderKey(),
|
||||
providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest),
|
||||
providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse),
|
||||
nil,
|
||||
nil,
|
||||
provider.logger,
|
||||
)
|
||||
}
|
||||
|
||||
// TextCompletionStream performs a streaming text completion request to the Fireworks AI API.
|
||||
func (provider *FireworksProvider) TextCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostTextCompletionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
|
||||
var authHeader map[string]string
|
||||
if v := key.Value.GetValue(); v != "" {
|
||||
authHeader = map[string]string{"Authorization": "Bearer " + v}
|
||||
}
|
||||
return openai.HandleOpenAITextCompletionStreaming(
|
||||
ctx,
|
||||
provider.streamingClient,
|
||||
provider.networkConfig.BaseURL+providerUtils.GetPathFromContext(ctx, "/v1/completions"),
|
||||
request,
|
||||
authHeader,
|
||||
provider.networkConfig.ExtraHeaders,
|
||||
providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest),
|
||||
providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse),
|
||||
provider.GetProviderKey(),
|
||||
nil,
|
||||
postHookRunner,
|
||||
nil,
|
||||
nil,
|
||||
provider.logger,
|
||||
postHookSpanFinalizer,
|
||||
)
|
||||
}
|
||||
|
||||
// ChatCompletion performs a chat completion request to the Fireworks AI API.
|
||||
func (provider *FireworksProvider) ChatCompletion(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError) {
|
||||
return openai.HandleOpenAIChatCompletionRequest(
|
||||
ctx,
|
||||
provider.client,
|
||||
provider.networkConfig.BaseURL+providerUtils.GetPathFromContext(ctx, "/v1/chat/completions"),
|
||||
request,
|
||||
key,
|
||||
provider.networkConfig.ExtraHeaders,
|
||||
providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest),
|
||||
providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse),
|
||||
provider.GetProviderKey(),
|
||||
nil,
|
||||
nil,
|
||||
provider.logger,
|
||||
)
|
||||
}
|
||||
|
||||
// ChatCompletionStream performs a streaming chat completion request to the Fireworks AI API.
|
||||
// It supports real-time streaming of responses using Server-Sent Events (SSE).
|
||||
// Uses Fireworks AI's OpenAI-compatible streaming format.
|
||||
// Returns a channel containing BifrostStreamChunk objects representing the stream or an error if the request fails.
|
||||
func (provider *FireworksProvider) ChatCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
|
||||
var authHeader map[string]string
|
||||
if v := key.Value.GetValue(); v != "" {
|
||||
authHeader = map[string]string{"Authorization": "Bearer " + v}
|
||||
}
|
||||
// Use shared OpenAI-compatible streaming logic
|
||||
return openai.HandleOpenAIChatCompletionStreaming(
|
||||
ctx,
|
||||
provider.streamingClient,
|
||||
provider.networkConfig.BaseURL+providerUtils.GetPathFromContext(ctx, "/v1/chat/completions"),
|
||||
request,
|
||||
authHeader,
|
||||
provider.networkConfig.ExtraHeaders,
|
||||
providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest),
|
||||
providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse),
|
||||
schemas.Fireworks,
|
||||
postHookRunner,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
provider.logger,
|
||||
postHookSpanFinalizer,
|
||||
)
|
||||
}
|
||||
|
||||
// Responses performs a responses request to the Fireworks AI API.
|
||||
func (provider *FireworksProvider) Responses(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) {
|
||||
return openai.HandleOpenAIResponsesRequest(
|
||||
ctx,
|
||||
provider.client,
|
||||
provider.networkConfig.BaseURL+providerUtils.GetPathFromContext(ctx, "/v1/responses"),
|
||||
request,
|
||||
key,
|
||||
provider.networkConfig.ExtraHeaders,
|
||||
providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest),
|
||||
providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse),
|
||||
provider.GetProviderKey(),
|
||||
nil,
|
||||
nil,
|
||||
provider.logger,
|
||||
)
|
||||
}
|
||||
|
||||
// ResponsesStream performs a streaming responses request to the Fireworks AI API.
|
||||
func (provider *FireworksProvider) ResponsesStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
|
||||
var authHeader map[string]string
|
||||
if v := key.Value.GetValue(); v != "" {
|
||||
authHeader = map[string]string{"Authorization": "Bearer " + v}
|
||||
}
|
||||
return openai.HandleOpenAIResponsesStreaming(
|
||||
ctx,
|
||||
provider.streamingClient,
|
||||
provider.networkConfig.BaseURL+providerUtils.GetPathFromContext(ctx, "/v1/responses"),
|
||||
request,
|
||||
authHeader,
|
||||
provider.networkConfig.ExtraHeaders,
|
||||
providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest),
|
||||
providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse),
|
||||
provider.GetProviderKey(),
|
||||
postHookRunner,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
provider.logger,
|
||||
postHookSpanFinalizer,
|
||||
)
|
||||
}
|
||||
|
||||
// Embedding performs an embedding request to the Fireworks AI provider.
|
||||
func (provider *FireworksProvider) Embedding(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostEmbeddingRequest) (*schemas.BifrostEmbeddingResponse, *schemas.BifrostError) {
|
||||
return openai.HandleOpenAIEmbeddingRequest(
|
||||
ctx,
|
||||
provider.client,
|
||||
provider.networkConfig.BaseURL+providerUtils.GetPathFromContext(ctx, "/v1/embeddings"),
|
||||
request,
|
||||
key,
|
||||
provider.networkConfig.ExtraHeaders,
|
||||
provider.GetProviderKey(),
|
||||
providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest),
|
||||
providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse),
|
||||
nil,
|
||||
provider.logger,
|
||||
)
|
||||
}
|
||||
|
||||
// Speech is not supported by the Fireworks AI provider.
|
||||
func (provider *FireworksProvider) Speech(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostSpeechRequest) (*schemas.BifrostSpeechResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.SpeechRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// Rerank is not supported by the Fireworks AI provider.
|
||||
func (provider *FireworksProvider) Rerank(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostRerankRequest) (*schemas.BifrostRerankResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.RerankRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// OCR is not supported by the Fireworks provider.
|
||||
func (provider *FireworksProvider) OCR(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostOCRRequest) (*schemas.BifrostOCRResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.OCRRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// SpeechStream is not supported by the Fireworks AI provider.
|
||||
func (provider *FireworksProvider) SpeechStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostSpeechRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.SpeechStreamRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// Transcription is not supported by the Fireworks AI provider.
|
||||
func (provider *FireworksProvider) Transcription(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (*schemas.BifrostTranscriptionResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.TranscriptionRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// TranscriptionStream is not supported by the Fireworks AI provider.
|
||||
func (provider *FireworksProvider) TranscriptionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.TranscriptionStreamRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// ImageGeneration is not supported by the Fireworks AI provider.
|
||||
func (provider *FireworksProvider) ImageGeneration(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostImageGenerationRequest) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.ImageGenerationRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// ImageGenerationStream is not supported by the Fireworks AI provider.
|
||||
func (provider *FireworksProvider) ImageGenerationStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostImageGenerationRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.ImageGenerationStreamRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// ImageEdit is not supported by the Fireworks AI provider.
|
||||
func (provider *FireworksProvider) ImageEdit(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostImageEditRequest) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.ImageEditRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// ImageEditStream is not supported by the Fireworks AI provider.
|
||||
func (provider *FireworksProvider) ImageEditStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostImageEditRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.ImageEditStreamRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// ImageVariation is not supported by the Fireworks AI provider.
|
||||
func (provider *FireworksProvider) ImageVariation(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostImageVariationRequest) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.ImageVariationRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// VideoGeneration is not supported by the Fireworks AI provider.
|
||||
func (provider *FireworksProvider) VideoGeneration(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostVideoGenerationRequest) (*schemas.BifrostVideoGenerationResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.VideoGenerationRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// VideoRetrieve is not supported by the Fireworks AI provider.
|
||||
func (provider *FireworksProvider) VideoRetrieve(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostVideoRetrieveRequest) (*schemas.BifrostVideoGenerationResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.VideoRetrieveRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// VideoDownload is not supported by the Fireworks AI provider.
|
||||
func (provider *FireworksProvider) VideoDownload(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostVideoDownloadRequest) (*schemas.BifrostVideoDownloadResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.VideoDownloadRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// VideoDelete is not supported by Fireworks AI provider.
|
||||
func (provider *FireworksProvider) VideoDelete(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostVideoDeleteRequest) (*schemas.BifrostVideoDeleteResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.VideoDeleteRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// VideoList is not supported by Fireworks AI provider.
|
||||
func (provider *FireworksProvider) VideoList(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostVideoListRequest) (*schemas.BifrostVideoListResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.VideoListRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// VideoRemix is not supported by Fireworks AI provider.
|
||||
func (provider *FireworksProvider) VideoRemix(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostVideoRemixRequest) (*schemas.BifrostVideoGenerationResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.VideoRemixRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// BatchCreate is not supported by Fireworks AI provider.
|
||||
func (provider *FireworksProvider) BatchCreate(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostBatchCreateRequest) (*schemas.BifrostBatchCreateResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchCreateRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// BatchList is not supported by Fireworks AI provider.
|
||||
func (provider *FireworksProvider) BatchList(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchListRequest) (*schemas.BifrostBatchListResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchListRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// BatchRetrieve is not supported by Fireworks AI provider.
|
||||
func (provider *FireworksProvider) BatchRetrieve(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchRetrieveRequest) (*schemas.BifrostBatchRetrieveResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchRetrieveRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// BatchCancel is not supported by Fireworks AI provider.
|
||||
func (provider *FireworksProvider) BatchCancel(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchCancelRequest) (*schemas.BifrostBatchCancelResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchCancelRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// BatchDelete is not supported by Fireworks AI provider.
|
||||
func (provider *FireworksProvider) BatchDelete(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchDeleteRequest) (*schemas.BifrostBatchDeleteResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchDeleteRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// BatchResults is not supported by Fireworks AI provider.
|
||||
func (provider *FireworksProvider) BatchResults(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchResultsRequest) (*schemas.BifrostBatchResultsResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchResultsRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// FileUpload is not supported by Fireworks AI provider.
|
||||
func (provider *FireworksProvider) FileUpload(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostFileUploadRequest) (*schemas.BifrostFileUploadResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.FileUploadRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// FileList is not supported by Fireworks AI provider.
|
||||
func (provider *FireworksProvider) FileList(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostFileListRequest) (*schemas.BifrostFileListResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.FileListRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// FileRetrieve is not supported by Fireworks AI provider.
|
||||
func (provider *FireworksProvider) FileRetrieve(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostFileRetrieveRequest) (*schemas.BifrostFileRetrieveResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.FileRetrieveRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// FileDelete is not supported by Fireworks AI provider.
|
||||
func (provider *FireworksProvider) FileDelete(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostFileDeleteRequest) (*schemas.BifrostFileDeleteResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.FileDeleteRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// FileContent is not supported by Fireworks AI provider.
|
||||
func (provider *FireworksProvider) FileContent(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostFileContentRequest) (*schemas.BifrostFileContentResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.FileContentRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// CountTokens is not supported by the Fireworks AI provider.
|
||||
func (provider *FireworksProvider) CountTokens(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostResponsesRequest) (*schemas.BifrostCountTokensResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.CountTokensRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// ContainerCreate is not supported by the Fireworks AI provider.
|
||||
func (provider *FireworksProvider) ContainerCreate(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostContainerCreateRequest) (*schemas.BifrostContainerCreateResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerCreateRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// ContainerList is not supported by the Fireworks AI provider.
|
||||
func (provider *FireworksProvider) ContainerList(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostContainerListRequest) (*schemas.BifrostContainerListResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerListRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// ContainerRetrieve is not supported by the Fireworks AI provider.
|
||||
func (provider *FireworksProvider) ContainerRetrieve(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostContainerRetrieveRequest) (*schemas.BifrostContainerRetrieveResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerRetrieveRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// ContainerDelete is not supported by the Fireworks AI provider.
|
||||
func (provider *FireworksProvider) ContainerDelete(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostContainerDeleteRequest) (*schemas.BifrostContainerDeleteResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerDeleteRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// ContainerFileCreate is not supported by the Fireworks AI provider.
|
||||
func (provider *FireworksProvider) ContainerFileCreate(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostContainerFileCreateRequest) (*schemas.BifrostContainerFileCreateResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerFileCreateRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// ContainerFileList is not supported by the Fireworks AI provider.
|
||||
func (provider *FireworksProvider) ContainerFileList(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostContainerFileListRequest) (*schemas.BifrostContainerFileListResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerFileListRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// ContainerFileRetrieve is not supported by the Fireworks AI provider.
|
||||
func (provider *FireworksProvider) ContainerFileRetrieve(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostContainerFileRetrieveRequest) (*schemas.BifrostContainerFileRetrieveResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerFileRetrieveRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// ContainerFileContent is not supported by the Fireworks AI provider.
|
||||
func (provider *FireworksProvider) ContainerFileContent(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostContainerFileContentRequest) (*schemas.BifrostContainerFileContentResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerFileContentRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// ContainerFileDelete is not supported by the Fireworks AI provider.
|
||||
func (provider *FireworksProvider) ContainerFileDelete(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostContainerFileDeleteRequest) (*schemas.BifrostContainerFileDeleteResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerFileDeleteRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// Passthrough is not supported by the Fireworks AI provider.
|
||||
func (provider *FireworksProvider) Passthrough(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostPassthroughRequest) (*schemas.BifrostPassthroughResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.PassthroughRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// PassthroughStream is not supported by the Fireworks AI provider.
|
||||
func (provider *FireworksProvider) PassthroughStream(_ *schemas.BifrostContext, _ schemas.PostHookRunner, _ func(context.Context), _ schemas.Key, _ *schemas.BifrostPassthroughRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.PassthroughStreamRequest, provider.GetProviderKey())
|
||||
}
|
||||
443
core/providers/fireworks/fireworks_test.go
Normal file
443
core/providers/fireworks/fireworks_test.go
Normal file
@@ -0,0 +1,443 @@
|
||||
package fireworks_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
bifrost "github.com/maximhq/bifrost/core"
|
||||
"github.com/maximhq/bifrost/core/internal/llmtests"
|
||||
fireworksprovider "github.com/maximhq/bifrost/core/providers/fireworks"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
func TestFireworks(t *testing.T) {
|
||||
t.Parallel()
|
||||
if strings.TrimSpace(os.Getenv("FIREWORKS_API_KEY")) == "" {
|
||||
t.Skip("Skipping Fireworks tests because FIREWORKS_API_KEY is not set")
|
||||
}
|
||||
|
||||
client, ctx, cancel, err := llmtests.SetupTest()
|
||||
if err != nil {
|
||||
t.Fatalf("Error initializing test setup: %v", err)
|
||||
}
|
||||
defer cancel()
|
||||
defer client.Shutdown()
|
||||
|
||||
chatModel, textModel, embeddingModel := resolveFireworksModels(t, client, ctx)
|
||||
|
||||
testConfig := llmtests.ComprehensiveTestConfig{
|
||||
Provider: schemas.Fireworks,
|
||||
ChatModel: chatModel,
|
||||
Fallbacks: []schemas.Fallback{},
|
||||
TextModel: textModel,
|
||||
TextCompletionFallbacks: []schemas.Fallback{},
|
||||
EmbeddingModel: embeddingModel,
|
||||
ReasoningModel: "",
|
||||
TranscriptionModel: "",
|
||||
SpeechSynthesisModel: "",
|
||||
Scenarios: llmtests.TestScenarios{
|
||||
TextCompletion: textModel != "",
|
||||
TextCompletionStream: textModel != "",
|
||||
SimpleChat: true,
|
||||
CompletionStream: true,
|
||||
MultiTurnConversation: true,
|
||||
ToolCalls: true,
|
||||
ToolCallsStreaming: true,
|
||||
MultipleToolCalls: false,
|
||||
End2EndToolCalling: false,
|
||||
AutomaticFunctionCall: false,
|
||||
ImageURL: false,
|
||||
ImageBase64: false,
|
||||
MultipleImages: false,
|
||||
FileBase64: false,
|
||||
FileURL: false,
|
||||
CompleteEnd2End: true,
|
||||
Embedding: embeddingModel != "",
|
||||
ListModels: true,
|
||||
Reasoning: false,
|
||||
Transcription: false,
|
||||
SpeechSynthesis: false,
|
||||
PromptCaching: false,
|
||||
},
|
||||
}
|
||||
t.Run("FireworksTests", func(t *testing.T) {
|
||||
llmtests.RunAllComprehensiveTests(t, client, ctx, testConfig)
|
||||
})
|
||||
}
|
||||
|
||||
// resolveFireworksModels discovers live Fireworks models for chat, completions, and embeddings.
|
||||
func resolveFireworksModels(t *testing.T, client *bifrost.Bifrost, ctx context.Context) (string, string, string) {
|
||||
t.Helper()
|
||||
|
||||
requestedChatModel := normalizeFireworksModelID(os.Getenv("FIREWORKS_CHAT_MODEL"))
|
||||
requestedTextModel := normalizeFireworksModelID(os.Getenv("FIREWORKS_TEXT_MODEL"))
|
||||
requestedEmbeddingModel := normalizeFireworksModelID(os.Getenv("FIREWORKS_EMBEDDING_MODEL"))
|
||||
|
||||
chatModel := requestedChatModel
|
||||
textModel := requestedTextModel
|
||||
embeddingModel := requestedEmbeddingModel
|
||||
|
||||
if requestedChatModel != "" {
|
||||
t.Logf("Using FIREWORKS_CHAT_MODEL=%q override", requestedChatModel)
|
||||
}
|
||||
if requestedTextModel != "" {
|
||||
t.Logf("Using FIREWORKS_TEXT_MODEL=%q override", requestedTextModel)
|
||||
}
|
||||
if requestedEmbeddingModel != "" {
|
||||
t.Logf("Using FIREWORKS_EMBEDDING_MODEL=%q override", requestedEmbeddingModel)
|
||||
}
|
||||
|
||||
if chatModel == "" || textModel == "" || embeddingModel == "" {
|
||||
pageToken := ""
|
||||
for page := 0; page < 5; page++ {
|
||||
req := &schemas.BifrostListModelsRequest{
|
||||
Provider: schemas.Fireworks,
|
||||
PageSize: 200,
|
||||
PageToken: pageToken,
|
||||
}
|
||||
|
||||
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
resp, bifrostErr := client.ListModelsRequest(bfCtx, req)
|
||||
if bifrostErr != nil {
|
||||
if chatModel == "" {
|
||||
t.Fatalf("Failed to list Fireworks models for test discovery: %v", llmtests.GetErrorMessage(bifrostErr))
|
||||
}
|
||||
t.Logf("Fireworks model discovery failed: %v", llmtests.GetErrorMessage(bifrostErr))
|
||||
break
|
||||
}
|
||||
|
||||
if chatModel == "" {
|
||||
chatModel = pickFireworksChatModel(resp.Data)
|
||||
}
|
||||
if textModel == "" {
|
||||
// Fireworks text completions currently reuse the chat-capable model pool;
|
||||
// a later probe verifies that the selected model accepts /v1/completions.
|
||||
textModel = pickFireworksChatModel(resp.Data)
|
||||
}
|
||||
if embeddingModel == "" {
|
||||
embeddingModel = pickFireworksEmbeddingModel(resp.Data)
|
||||
}
|
||||
|
||||
if chatModel != "" && textModel != "" && embeddingModel != "" {
|
||||
break
|
||||
}
|
||||
if resp.NextPageToken == "" {
|
||||
break
|
||||
}
|
||||
pageToken = resp.NextPageToken
|
||||
}
|
||||
}
|
||||
|
||||
if chatModel == "" {
|
||||
t.Fatal("Unable to discover a Fireworks chat model from /v1/models; set FIREWORKS_CHAT_MODEL to override")
|
||||
}
|
||||
if textModel != "" && !fireworksModelSupportsTextCompletions(t, client, ctx, textModel) {
|
||||
t.Logf("Skipping Fireworks text completion scenarios because model %q did not accept /v1/completions", textModel)
|
||||
textModel = ""
|
||||
}
|
||||
if embeddingModel != "" && !fireworksModelSupportsEmbeddings(t, client, ctx, embeddingModel) {
|
||||
t.Logf("Skipping Fireworks embedding scenario because model %q did not accept /v1/embeddings", embeddingModel)
|
||||
embeddingModel = ""
|
||||
}
|
||||
if textModel == "" {
|
||||
t.Log("No Fireworks completions-capable model discovered from /v1/models; text completion scenarios will be skipped unless FIREWORKS_TEXT_MODEL is set")
|
||||
}
|
||||
if embeddingModel == "" {
|
||||
t.Log("No Fireworks embedding model discovered from /v1/models; embedding scenario will be skipped unless FIREWORKS_EMBEDDING_MODEL is set")
|
||||
}
|
||||
|
||||
return chatModel, textModel, embeddingModel
|
||||
}
|
||||
|
||||
// fireworksModelSupportsTextCompletions validates that the selected model actually accepts Fireworks /v1/completions.
|
||||
func fireworksModelSupportsTextCompletions(t *testing.T, client *bifrost.Bifrost, ctx context.Context, model string) bool {
|
||||
t.Helper()
|
||||
|
||||
prompt := "Say ok"
|
||||
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
resp, bifrostErr := client.TextCompletionRequest(bfCtx, &schemas.BifrostTextCompletionRequest{
|
||||
Provider: schemas.Fireworks,
|
||||
Model: model,
|
||||
Input: &schemas.TextCompletionInput{
|
||||
PromptStr: &prompt,
|
||||
},
|
||||
Params: &schemas.TextCompletionParameters{
|
||||
MaxTokens: schemas.Ptr(8),
|
||||
},
|
||||
})
|
||||
if bifrostErr != nil {
|
||||
t.Logf("Fireworks /v1/completions probe failed for %q: %v", model, llmtests.GetErrorMessage(bifrostErr))
|
||||
return false
|
||||
}
|
||||
|
||||
return resp != nil && len(resp.Choices) > 0
|
||||
}
|
||||
|
||||
// fireworksModelSupportsEmbeddings validates that the selected model actually accepts Fireworks /v1/embeddings.
|
||||
func fireworksModelSupportsEmbeddings(t *testing.T, client *bifrost.Bifrost, ctx context.Context, model string) bool {
|
||||
t.Helper()
|
||||
|
||||
text := "embedding probe"
|
||||
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
|
||||
resp, bifrostErr := client.EmbeddingRequest(bfCtx, &schemas.BifrostEmbeddingRequest{
|
||||
Provider: schemas.Fireworks,
|
||||
Model: model,
|
||||
Input: &schemas.EmbeddingInput{
|
||||
Text: &text,
|
||||
},
|
||||
})
|
||||
if bifrostErr != nil {
|
||||
t.Logf("Fireworks /v1/embeddings probe failed for %q: %v", model, llmtests.GetErrorMessage(bifrostErr))
|
||||
return false
|
||||
}
|
||||
|
||||
return resp != nil && len(resp.Data) > 0
|
||||
}
|
||||
|
||||
// pickFireworksChatModel selects a text-capable Fireworks model from ListModels output.
|
||||
func pickFireworksChatModel(models []schemas.Model) string {
|
||||
for _, model := range models {
|
||||
normalized := normalizeFireworksModelID(model.ID)
|
||||
if isFireworksTextCapable(normalized) {
|
||||
return normalized
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// pickFireworksEmbeddingModel selects an embedding-capable Fireworks model from ListModels output.
|
||||
func pickFireworksEmbeddingModel(models []schemas.Model) string {
|
||||
for _, model := range models {
|
||||
normalized := normalizeFireworksModelID(model.ID)
|
||||
lower := strings.ToLower(normalized)
|
||||
if strings.Contains(lower, "embedding") || strings.Contains(lower, "embed") {
|
||||
return normalized
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// normalizeFireworksModelID strips any provider prefix so tests can pass raw Fireworks model IDs to Bifrost requests.
|
||||
func normalizeFireworksModelID(modelID string) string {
|
||||
modelID = strings.TrimSpace(modelID)
|
||||
if modelID == "" {
|
||||
return ""
|
||||
}
|
||||
_, normalized := schemas.ParseModelString(modelID, schemas.Fireworks)
|
||||
return normalized
|
||||
}
|
||||
|
||||
// isFireworksTextCapable applies a conservative, case-insensitive name-based
// heuristic for text/chat-capable Fireworks models.
//
// A model ID is rejected when it contains any exclusion hint (image, audio,
// embedding, or rerank families). Otherwise it is accepted only when it
// contains one of the known chat-model family hints. IDs matching neither list
// are rejected.
func isFireworksTextCapable(modelID string) bool {
	lower := strings.ToLower(modelID)

	// containsAny reports whether the lower-cased ID contains one of hints.
	containsAny := func(hints ...string) bool {
		for _, hint := range hints {
			if strings.Contains(lower, hint) {
				return true
			}
		}
		return false
	}

	// Exclusions take precedence over the preferred list. "embed" also covers
	// "embedding", so the longer spelling needs no entry of its own.
	if containsAny(
		"flux",
		"whisper",
		"audio",
		"speech",
		"transcrib",
		"embed",
		"rerank",
	) {
		return false
	}

	return containsAny(
		"instruct",
		"chat",
		"gpt-oss",
		"deepseek",
		"qwen",
		"llama",
		"glm",
		"mixtral",
		"mistral",
		"cogito",
		"gemma",
	)
}
|
||||
|
||||
// TestFireworksProviderUsesNativeEndpoints verifies that the Fireworks provider targets native completions, responses, and embeddings endpoints.
|
||||
func TestFireworksProviderUsesNativeEndpoints(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
expectedPath string
|
||||
run func(t *testing.T, provider *fireworksprovider.FireworksProvider, ctx *schemas.BifrostContext, key schemas.Key)
|
||||
}{
|
||||
{
|
||||
name: "TextCompletion",
|
||||
expectedPath: "/v1/completions",
|
||||
run: func(t *testing.T, provider *fireworksprovider.FireworksProvider, ctx *schemas.BifrostContext, key schemas.Key) {
|
||||
prompt := "A is for apple and B is for"
|
||||
resp, err := provider.TextCompletion(ctx, key, &schemas.BifrostTextCompletionRequest{
|
||||
Provider: schemas.Fireworks,
|
||||
Model: "accounts/fireworks/models/deepseek-v3p2",
|
||||
Input: &schemas.TextCompletionInput{
|
||||
PromptStr: &prompt,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("TextCompletion returned error: %v", llmtests.GetErrorMessage(err))
|
||||
}
|
||||
if resp == nil || len(resp.Choices) == 0 || resp.Choices[0].Text == nil || *resp.Choices[0].Text == "" {
|
||||
t.Fatalf("unexpected text completion response: %#v", resp)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Responses",
|
||||
expectedPath: "/v1/responses",
|
||||
run: func(t *testing.T, provider *fireworksprovider.FireworksProvider, ctx *schemas.BifrostContext, key schemas.Key) {
|
||||
resp, err := provider.Responses(ctx, key, &schemas.BifrostResponsesRequest{
|
||||
Provider: schemas.Fireworks,
|
||||
Model: "accounts/fireworks/models/deepseek-v3p2",
|
||||
Input: []schemas.ResponsesMessage{
|
||||
llmtests.CreateBasicResponsesMessage("hello"),
|
||||
},
|
||||
Params: &schemas.ResponsesParameters{
|
||||
PreviousResponseID: schemas.Ptr("resp_previous"),
|
||||
MaxToolCalls: schemas.Ptr(2),
|
||||
Store: schemas.Ptr(true),
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Responses returned error: %v", llmtests.GetErrorMessage(err))
|
||||
}
|
||||
if resp == nil || resp.PreviousResponseID == nil || *resp.PreviousResponseID != "resp_previous" {
|
||||
t.Fatalf("unexpected responses response: %#v", resp)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Embedding",
|
||||
expectedPath: "/v1/embeddings",
|
||||
run: func(t *testing.T, provider *fireworksprovider.FireworksProvider, ctx *schemas.BifrostContext, key schemas.Key) {
|
||||
resp, err := provider.Embedding(ctx, key, &schemas.BifrostEmbeddingRequest{
|
||||
Provider: schemas.Fireworks,
|
||||
Model: "accounts/fireworks/models/nomic-embed-text-v1.5",
|
||||
Input: &schemas.EmbeddingInput{
|
||||
Text: schemas.Ptr("embedding test"),
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Embedding returned error: %v", llmtests.GetErrorMessage(err))
|
||||
}
|
||||
if resp == nil || len(resp.Data) != 1 || len(resp.Data[0].Embedding.EmbeddingArray) != 3 {
|
||||
t.Fatalf("unexpected embedding response: %#v", resp)
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var requestedPath string
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
requestedPath = r.URL.Path
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
switch r.URL.Path {
|
||||
case "/v1/completions":
|
||||
_, _ = fmt.Fprint(w, `{"id":"cmpl_1","object":"text_completion","created":1,"model":"accounts/fireworks/models/deepseek-v3p2","choices":[{"text":" banana","index":0,"finish_reason":"stop"}],"usage":{"prompt_tokens":4,"completion_tokens":1,"total_tokens":5}}`)
|
||||
case "/v1/responses":
|
||||
_, _ = fmt.Fprint(w, `{"id":"resp_1","object":"response","created_at":1,"status":"completed","model":"accounts/fireworks/models/deepseek-v3p2","output":[{"id":"msg_1","type":"message","status":"completed","role":"assistant","content":[{"type":"output_text","text":"hello","annotations":[],"logprobs":[]}]}],"previous_response_id":"resp_previous","max_tool_calls":2,"store":true,"usage":{"input_tokens":1,"input_tokens_details":{"cached_tokens":0,"cached_read_tokens":0,"cached_write_tokens":0},"output_tokens":1,"total_tokens":2}}`)
|
||||
case "/v1/embeddings":
|
||||
_, _ = fmt.Fprint(w, `{"object":"list","model":"accounts/fireworks/models/nomic-embed-text-v1.5","data":[{"object":"embedding","index":0,"embedding":[0.1,0.2,0.3]}],"usage":{"prompt_tokens":2,"total_tokens":2}}`)
|
||||
default:
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
provider := newTestFireworksProvider(t, server.URL)
|
||||
ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
|
||||
key := schemas.Key{Value: *schemas.NewEnvVar("test-key")}
|
||||
|
||||
tt.run(t, provider, ctx, key)
|
||||
|
||||
if requestedPath != tt.expectedPath {
|
||||
t.Fatalf("expected request path %q, got %q", tt.expectedPath, requestedPath)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestFireworksResponsesStreamUsesNativeResponsesEndpoint verifies that Fireworks responses streaming targets the native responses endpoint.
|
||||
func TestFireworksResponsesStreamUsesNativeResponsesEndpoint(t *testing.T) {
|
||||
var requestedPath string
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
requestedPath = r.URL.Path
|
||||
if r.URL.Path != "/v1/responses" {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
_, _ = fmt.Fprint(w, "data: {\"type\":\"response.completed\",\"sequence_number\":0,\"response\":{\"id\":\"resp_1\",\"object\":\"response\",\"created_at\":1,\"status\":\"completed\",\"model\":\"accounts/fireworks/models/deepseek-v3p2\",\"output\":[{\"id\":\"msg_1\",\"type\":\"message\",\"status\":\"completed\",\"role\":\"assistant\",\"content\":[{\"type\":\"output_text\",\"text\":\"hello\",\"annotations\":[],\"logprobs\":[]}]}],\"usage\":{\"input_tokens\":1,\"input_tokens_details\":{\"cached_tokens\":0,\"cached_read_tokens\":0,\"cached_write_tokens\":0},\"output_tokens\":1,\"total_tokens\":2}}}\n\n")
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
provider := newTestFireworksProvider(t, server.URL)
|
||||
ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
|
||||
key := schemas.Key{Value: *schemas.NewEnvVar("test-key")}
|
||||
postHookRunner := func(_ *schemas.BifrostContext, result *schemas.BifrostResponse, err *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError) {
|
||||
return result, err
|
||||
}
|
||||
|
||||
stream, err := provider.ResponsesStream(ctx, postHookRunner, nil, key, &schemas.BifrostResponsesRequest{
|
||||
Provider: schemas.Fireworks,
|
||||
Model: "accounts/fireworks/models/deepseek-v3p2",
|
||||
Input: []schemas.ResponsesMessage{
|
||||
llmtests.CreateBasicResponsesMessage("hello"),
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("ResponsesStream returned error: %v", llmtests.GetErrorMessage(err))
|
||||
}
|
||||
|
||||
sawCompleted := false
|
||||
for chunk := range stream {
|
||||
if chunk != nil && chunk.BifrostResponsesStreamResponse != nil &&
|
||||
chunk.BifrostResponsesStreamResponse.Type == schemas.ResponsesStreamResponseTypeCompleted {
|
||||
sawCompleted = true
|
||||
}
|
||||
}
|
||||
|
||||
if requestedPath != "/v1/responses" {
|
||||
t.Fatalf("expected responses stream to hit /v1/responses, got %q", requestedPath)
|
||||
}
|
||||
if !sawCompleted {
|
||||
t.Fatal("expected a completed responses stream chunk")
|
||||
}
|
||||
}
|
||||
|
||||
// newTestFireworksProvider creates a Fireworks provider configured to hit a local test server.
|
||||
func newTestFireworksProvider(t *testing.T, baseURL string) *fireworksprovider.FireworksProvider {
|
||||
t.Helper()
|
||||
|
||||
provider, err := fireworksprovider.NewFireworksProvider(&schemas.ProviderConfig{
|
||||
NetworkConfig: schemas.NetworkConfig{
|
||||
BaseURL: baseURL,
|
||||
DefaultRequestTimeoutInSeconds: 30,
|
||||
},
|
||||
}, bifrost.NewNoOpLogger())
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create Fireworks provider: %v", err)
|
||||
}
|
||||
return provider
|
||||
}
|
||||
370
core/providers/gemini/batch.go
Normal file
370
core/providers/gemini/batch.go
Normal file
@@ -0,0 +1,370 @@
|
||||
package gemini
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/bytedance/sonic"
|
||||
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// ToBifrostBatchStatus converts Gemini batch job state to Bifrost status.
|
||||
func ToBifrostBatchStatus(geminiState string) schemas.BatchStatus {
|
||||
switch geminiState {
|
||||
case GeminiBatchStatePending, GeminiBatchStateRunning:
|
||||
return schemas.BatchStatusInProgress
|
||||
case GeminiBatchStateSucceeded:
|
||||
return schemas.BatchStatusCompleted
|
||||
case GeminiBatchStateFailed:
|
||||
return schemas.BatchStatusFailed
|
||||
case GeminiBatchStateCancelling:
|
||||
return schemas.BatchStatusCancelling
|
||||
case GeminiBatchStateCancelled:
|
||||
return schemas.BatchStatusCancelled
|
||||
case GeminiBatchStateExpired:
|
||||
return schemas.BatchStatusExpired
|
||||
default:
|
||||
return schemas.BatchStatus(geminiState)
|
||||
}
|
||||
}
|
||||
|
||||
// ToGeminiBatchStatus converts Bifrost batch status to Gemini batch job state.
|
||||
func ToGeminiBatchStatus(status schemas.BatchStatus) string {
|
||||
switch status {
|
||||
case schemas.BatchStatusValidating, schemas.BatchStatusInProgress:
|
||||
return GeminiBatchStateRunning
|
||||
case schemas.BatchStatusFinalizing:
|
||||
return GeminiBatchStateRunning
|
||||
case schemas.BatchStatusCompleted, schemas.BatchStatusEnded:
|
||||
return GeminiBatchStateSucceeded
|
||||
case schemas.BatchStatusFailed:
|
||||
return GeminiBatchStateFailed
|
||||
case schemas.BatchStatusCancelling:
|
||||
return GeminiBatchStateCancelling
|
||||
case schemas.BatchStatusCancelled:
|
||||
return GeminiBatchStateCancelled
|
||||
case schemas.BatchStatusExpired:
|
||||
return GeminiBatchStateExpired
|
||||
default:
|
||||
return GeminiBatchStateUnspecified
|
||||
}
|
||||
}
|
||||
|
||||
// ToGeminiBatchJobResponse converts Bifrost batch create response to Gemini batch job response format.
|
||||
func ToGeminiBatchJobResponse(resp *schemas.BifrostBatchCreateResponse) *GeminiBatchJobResponse {
|
||||
if resp == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
succeededCount := resp.RequestCounts.Succeeded
|
||||
if succeededCount == 0 {
|
||||
succeededCount = resp.RequestCounts.Completed
|
||||
}
|
||||
|
||||
geminiResp := &GeminiBatchJobResponse{
|
||||
Name: resp.ID,
|
||||
Metadata: &GeminiBatchMetadata{
|
||||
Name: resp.ID,
|
||||
Type: "type.googleapis.com/google.ai.generativelanguage.v1beta.BatchPredictionJob",
|
||||
CreateTime: formatGeminiTimestamp(resp.CreatedAt),
|
||||
UpdateTime: formatGeminiTimestamp(resp.CreatedAt),
|
||||
State: ToGeminiBatchStatus(resp.Status),
|
||||
BatchStats: &GeminiBatchStats{
|
||||
RequestCount: resp.RequestCounts.Total,
|
||||
PendingRequestCount: max(0, resp.RequestCounts.Total-succeededCount-resp.RequestCounts.Failed),
|
||||
SuccessfulRequestCount: succeededCount,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if resp.OperationName != nil && *resp.OperationName != "" {
|
||||
geminiResp.Metadata.Name = *resp.OperationName
|
||||
geminiResp.Name = *resp.OperationName
|
||||
}
|
||||
|
||||
if resp.InputFileID != "" {
|
||||
geminiResp.Metadata.InputConfig = &GeminiBatchMetadataInputConfig{
|
||||
FileName: resp.InputFileID,
|
||||
}
|
||||
}
|
||||
|
||||
if resp.OutputFileID != nil && *resp.OutputFileID != "" {
|
||||
geminiResp.Dest = &GeminiBatchDest{
|
||||
FileName: *resp.OutputFileID,
|
||||
}
|
||||
geminiResp.Metadata.Output = &GeminiBatchMetadataOutputConfig{
|
||||
ResponsesFile: *resp.OutputFileID,
|
||||
}
|
||||
}
|
||||
|
||||
if resp.Status == schemas.BatchStatusCompleted ||
|
||||
resp.Status == schemas.BatchStatusEnded ||
|
||||
resp.Status == schemas.BatchStatusFailed ||
|
||||
resp.Status == schemas.BatchStatusExpired ||
|
||||
resp.Status == schemas.BatchStatusCancelled {
|
||||
geminiResp.Done = true
|
||||
}
|
||||
|
||||
return geminiResp
|
||||
}
|
||||
|
||||
// ToGeminiBatchRetrieveResponse converts a Bifrost batch retrieve response to Gemini batch job response format.
|
||||
func ToGeminiBatchRetrieveResponse(resp *schemas.BifrostBatchRetrieveResponse) *GeminiBatchJobResponse {
|
||||
if resp == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
succeededCount := resp.RequestCounts.Succeeded
|
||||
if succeededCount == 0 {
|
||||
succeededCount = resp.RequestCounts.Completed
|
||||
}
|
||||
|
||||
pendingCount := resp.RequestCounts.Pending
|
||||
if pendingCount == 0 && resp.RequestCounts.Total > 0 {
|
||||
processedCount := resp.RequestCounts.Completed
|
||||
if processedCount == 0 {
|
||||
processedCount = succeededCount
|
||||
}
|
||||
pendingCount = resp.RequestCounts.Total - processedCount - resp.RequestCounts.Failed
|
||||
if pendingCount < 0 {
|
||||
pendingCount = 0
|
||||
}
|
||||
}
|
||||
|
||||
geminiResp := &GeminiBatchJobResponse{
|
||||
Name: resp.ID,
|
||||
Metadata: &GeminiBatchMetadata{
|
||||
Name: resp.ID,
|
||||
Type: "type.googleapis.com/google.ai.generativelanguage.v1beta.BatchPredictionJob",
|
||||
CreateTime: formatGeminiTimestamp(resp.CreatedAt),
|
||||
UpdateTime: formatGeminiTimestamp(resp.CreatedAt),
|
||||
State: ToGeminiBatchStatus(resp.Status),
|
||||
BatchStats: &GeminiBatchStats{
|
||||
RequestCount: resp.RequestCounts.Total,
|
||||
PendingRequestCount: pendingCount,
|
||||
SuccessfulRequestCount: succeededCount,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if resp.OperationName != nil && *resp.OperationName != "" {
|
||||
geminiResp.Metadata.Name = *resp.OperationName
|
||||
geminiResp.Name = *resp.OperationName
|
||||
}
|
||||
|
||||
if resp.Done != nil {
|
||||
geminiResp.Done = *resp.Done
|
||||
} else {
|
||||
geminiResp.Done = resp.Status == schemas.BatchStatusCompleted ||
|
||||
resp.Status == schemas.BatchStatusEnded ||
|
||||
resp.Status == schemas.BatchStatusFailed ||
|
||||
resp.Status == schemas.BatchStatusExpired ||
|
||||
resp.Status == schemas.BatchStatusCancelled
|
||||
}
|
||||
|
||||
if resp.InputFileID != "" {
|
||||
geminiResp.Metadata.InputConfig = &GeminiBatchMetadataInputConfig{
|
||||
FileName: resp.InputFileID,
|
||||
}
|
||||
}
|
||||
|
||||
if resp.OutputFileID != nil && *resp.OutputFileID != "" {
|
||||
geminiResp.Dest = &GeminiBatchDest{
|
||||
FileName: *resp.OutputFileID,
|
||||
}
|
||||
geminiResp.Metadata.Output = &GeminiBatchMetadataOutputConfig{
|
||||
ResponsesFile: *resp.OutputFileID,
|
||||
}
|
||||
}
|
||||
|
||||
// Set end time from the most relevant terminal timestamp
|
||||
var endTime int64
|
||||
if resp.CompletedAt != nil {
|
||||
endTime = *resp.CompletedAt
|
||||
} else if resp.FailedAt != nil {
|
||||
endTime = *resp.FailedAt
|
||||
} else if resp.ExpiredAt != nil {
|
||||
endTime = *resp.ExpiredAt
|
||||
} else if resp.CancelledAt != nil {
|
||||
endTime = *resp.CancelledAt
|
||||
}
|
||||
if endTime > 0 {
|
||||
geminiResp.Metadata.EndTime = formatGeminiTimestamp(endTime)
|
||||
}
|
||||
|
||||
return geminiResp
|
||||
}
|
||||
|
||||
// ToGeminiBatchListResponse converts a Bifrost batch list response to Gemini format.
|
||||
func ToGeminiBatchListResponse(resp *schemas.BifrostBatchListResponse) *GeminiBatchListResponse {
|
||||
if resp == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
operations := make([]GeminiBatchJobResponse, 0, len(resp.Data))
|
||||
for i := range resp.Data {
|
||||
if geminiResp := ToGeminiBatchRetrieveResponse(&resp.Data[i]); geminiResp != nil {
|
||||
operations = append(operations, *geminiResp)
|
||||
}
|
||||
}
|
||||
|
||||
geminiListResp := &GeminiBatchListResponse{
|
||||
Operations: operations,
|
||||
}
|
||||
|
||||
if resp.NextCursor != nil {
|
||||
geminiListResp.NextPageToken = *resp.NextCursor
|
||||
}
|
||||
|
||||
return geminiListResp
|
||||
}
|
||||
|
||||
// parseGeminiTimestamp parses an RFC3339 timestamp string and returns it as
// Unix seconds. Any input that does not parse, including the empty string,
// yields 0.
func parseGeminiTimestamp(timestamp string) int64 {
	// time.Parse rejects the empty string, so no separate empty check is needed.
	if parsed, err := time.Parse(time.RFC3339, timestamp); err == nil {
		return parsed.Unix()
	}
	return 0
}
|
||||
|
||||
// extractBatchIDFromName returns the last "/"-separated segment of a resource
// name, e.g. "batches/abc123" -> "abc123". A name without a slash is returned
// unchanged, and an empty name yields "".
func extractBatchIDFromName(name string) string {
	// LastIndex is -1 when there is no slash, which makes the slice start at 0.
	return name[strings.LastIndex(name, "/")+1:]
}
|
||||
|
||||
// downloadBatchResultsFile downloads and parses a batch results file from Gemini.
|
||||
// Returns the parsed result items from the JSONL file and any parse errors encountered.
|
||||
func (provider *GeminiProvider) downloadBatchResultsFile(ctx context.Context, key schemas.Key, fileName string) ([]schemas.BatchResultItem, []schemas.BatchError, *schemas.BifrostError) {
|
||||
// Create request to download the file
|
||||
req := fasthttp.AcquireRequest()
|
||||
resp := fasthttp.AcquireResponse()
|
||||
defer fasthttp.ReleaseRequest(req)
|
||||
defer fasthttp.ReleaseResponse(resp)
|
||||
|
||||
// Build download URL - use the download endpoint with alt=media
|
||||
// The base URL is like https://generativelanguage.googleapis.com/v1beta
|
||||
// We need to change it to https://generativelanguage.googleapis.com/download/v1beta
|
||||
baseURL := strings.Replace(provider.networkConfig.BaseURL, "/v1beta", "/download/v1beta", 1)
|
||||
|
||||
// Ensure fileName has proper format
|
||||
fileID := fileName
|
||||
if !strings.HasPrefix(fileID, "files/") {
|
||||
fileID = "files/" + fileID
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s/%s:download?alt=media", baseURL, fileID)
|
||||
|
||||
provider.logger.Debug("gemini batch results file download url: " + url)
|
||||
providerUtils.SetExtraHeaders(ctx, req, provider.networkConfig.ExtraHeaders, nil)
|
||||
req.SetRequestURI(url)
|
||||
req.Header.SetMethod(http.MethodGet)
|
||||
if key.Value.GetValue() != "" {
|
||||
req.Header.Set("x-goog-api-key", key.Value.GetValue())
|
||||
}
|
||||
|
||||
// Make request
|
||||
_, bifrostErr, wait := providerUtils.MakeRequestWithContext(ctx, provider.client, req, resp)
|
||||
defer wait()
|
||||
if bifrostErr != nil {
|
||||
return nil, nil, bifrostErr
|
||||
}
|
||||
|
||||
// Handle error response
|
||||
if resp.StatusCode() != fasthttp.StatusOK {
|
||||
return nil, nil, parseGeminiError(resp)
|
||||
}
|
||||
|
||||
body, err := providerUtils.CheckAndDecodeBody(resp)
|
||||
if err != nil {
|
||||
return nil, nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err)
|
||||
}
|
||||
|
||||
// Parse JSONL content - each line is a separate JSON object
|
||||
// Use streaming parser to avoid string conversion and collect parse errors
|
||||
results := make([]schemas.BatchResultItem, 0)
|
||||
|
||||
parseResult := providerUtils.ParseJSONL(body, func(line []byte) error {
|
||||
var resultLine GeminiBatchFileResultLine
|
||||
if err := sonic.Unmarshal(line, &resultLine); err != nil {
|
||||
provider.logger.Warn("gemini batch results file parse error: " + err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
customID := resultLine.Key
|
||||
if customID == "" {
|
||||
customID = fmt.Sprintf("request-%d", len(results))
|
||||
}
|
||||
|
||||
resultItem := schemas.BatchResultItem{
|
||||
CustomID: customID,
|
||||
}
|
||||
|
||||
if resultLine.Error != nil {
|
||||
resultItem.Error = &schemas.BatchResultError{
|
||||
Code: fmt.Sprintf("%d", resultLine.Error.Code),
|
||||
Message: resultLine.Error.Message,
|
||||
}
|
||||
} else if resultLine.Response != nil {
|
||||
// Convert the response to a map for the Body field
|
||||
respBody := make(map[string]interface{})
|
||||
if len(resultLine.Response.Candidates) > 0 {
|
||||
candidate := resultLine.Response.Candidates[0]
|
||||
if candidate.Content != nil && len(candidate.Content.Parts) > 0 {
|
||||
var textParts []string
|
||||
for _, part := range candidate.Content.Parts {
|
||||
if part.Text != "" {
|
||||
textParts = append(textParts, part.Text)
|
||||
}
|
||||
}
|
||||
if len(textParts) > 0 {
|
||||
respBody["text"] = strings.Join(textParts, "")
|
||||
}
|
||||
}
|
||||
respBody["finish_reason"] = string(candidate.FinishReason)
|
||||
}
|
||||
if resultLine.Response.UsageMetadata != nil {
|
||||
respBody["usage"] = map[string]interface{}{
|
||||
"prompt_tokens": resultLine.Response.UsageMetadata.PromptTokenCount,
|
||||
"completion_tokens": resultLine.Response.UsageMetadata.CandidatesTokenCount,
|
||||
"total_tokens": resultLine.Response.UsageMetadata.TotalTokenCount,
|
||||
}
|
||||
}
|
||||
|
||||
resultItem.Response = &schemas.BatchResultResponse{
|
||||
StatusCode: 200,
|
||||
Body: respBody,
|
||||
}
|
||||
}
|
||||
|
||||
results = append(results, resultItem)
|
||||
return nil
|
||||
})
|
||||
|
||||
return results, parseResult.Errors, nil
|
||||
}
|
||||
|
||||
// extractGeminiUsageMetadata extracts usage metadata (as ints) from Gemini response
|
||||
func extractGeminiUsageMetadata(geminiResponse *GenerateContentResponse) (int, int, int) {
|
||||
var inputTokens, outputTokens, totalTokens int
|
||||
if geminiResponse.UsageMetadata != nil {
|
||||
usageMetadata := geminiResponse.UsageMetadata
|
||||
inputTokens = int(usageMetadata.PromptTokenCount)
|
||||
outputTokens = int(usageMetadata.CandidatesTokenCount)
|
||||
totalTokens = int(usageMetadata.TotalTokenCount)
|
||||
}
|
||||
return inputTokens, outputTokens, totalTokens
|
||||
}
|
||||
558
core/providers/gemini/chat.go
Normal file
558
core/providers/gemini/chat.go
Normal file
@@ -0,0 +1,558 @@
|
||||
package gemini
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// ToGeminiChatCompletionRequest converts a BifrostChatRequest to Gemini's generation request format for chat completion
|
||||
func ToGeminiChatCompletionRequest(bifrostReq *schemas.BifrostChatRequest) (*GeminiGenerationRequest, error) {
|
||||
if bifrostReq == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Create the base Gemini generation request
|
||||
geminiReq := &GeminiGenerationRequest{
|
||||
Model: bifrostReq.Model,
|
||||
}
|
||||
|
||||
// Convert parameters to generation config
|
||||
if bifrostReq.Params != nil {
|
||||
geminiReq.ExtraParams = bifrostReq.Params.ExtraParams
|
||||
var err error
|
||||
geminiReq.GenerationConfig, err = convertParamsToGenerationConfig(bifrostReq.Params, []string{}, bifrostReq.Model)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Handle tool-related parameters
|
||||
if len(bifrostReq.Params.Tools) > 0 {
|
||||
geminiReq.Tools = convertBifrostToolsToGemini(bifrostReq.Params.Tools)
|
||||
|
||||
// Convert tool choice to tool config
|
||||
if bifrostReq.Params.ToolChoice != nil {
|
||||
geminiReq.ToolConfig = convertToolChoiceToToolConfig(bifrostReq.Params.ToolChoice)
|
||||
}
|
||||
}
|
||||
|
||||
// Handle extra parameters
|
||||
if bifrostReq.Params.ExtraParams != nil {
|
||||
// Safety settings
|
||||
if safetySettings, ok := schemas.SafeExtractFromMap(bifrostReq.Params.ExtraParams, "safety_settings"); ok {
|
||||
delete(geminiReq.ExtraParams, "safety_settings")
|
||||
if settings, ok := SafeExtractSafetySettings(safetySettings); ok {
|
||||
geminiReq.SafetySettings = settings
|
||||
}
|
||||
}
|
||||
|
||||
// Cached content
|
||||
if cachedContent, ok := schemas.SafeExtractString(bifrostReq.Params.ExtraParams["cached_content"]); ok {
|
||||
delete(geminiReq.ExtraParams, "cached_content")
|
||||
geminiReq.CachedContent = cachedContent
|
||||
}
|
||||
|
||||
// Labels
|
||||
if labels, ok := schemas.SafeExtractFromMap(bifrostReq.Params.ExtraParams, "labels"); ok {
|
||||
delete(geminiReq.ExtraParams, "labels")
|
||||
if labelMap, ok := schemas.SafeExtractStringMap(labels); ok {
|
||||
geminiReq.Labels = labelMap
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// Convert chat completion messages to Gemini format
|
||||
contents, systemInstruction := convertBifrostMessagesToGemini(bifrostReq.Input)
|
||||
if systemInstruction != nil {
|
||||
geminiReq.SystemInstruction = systemInstruction
|
||||
}
|
||||
geminiReq.Contents = contents
|
||||
return geminiReq, nil
|
||||
}
|
||||
|
||||
// ToBifrostChatResponse converts a GenerateContentResponse to a BifrostChatResponse
|
||||
func (response *GenerateContentResponse) ToBifrostChatResponse() *schemas.BifrostChatResponse {
|
||||
bifrostResp := &schemas.BifrostChatResponse{
|
||||
ID: response.ResponseID,
|
||||
Model: response.ModelVersion,
|
||||
Object: "chat.completion",
|
||||
}
|
||||
|
||||
// Set creation timestamp if available
|
||||
if !response.CreateTime.IsZero() {
|
||||
bifrostResp.Created = int(response.CreateTime.Unix())
|
||||
}
|
||||
|
||||
// Handle empty candidates (filtered/malformed responses)
|
||||
if len(response.Candidates) == 0 {
|
||||
finishReason := ConvertGeminiFinishReasonToBifrost(FinishReasonMalformedFunctionCall)
|
||||
return createErrorResponse(response, finishReason, false)
|
||||
}
|
||||
|
||||
candidate := response.Candidates[0]
|
||||
|
||||
// Check for filtered finish reasons that indicate errors
|
||||
if isErrorFinishReason(candidate.FinishReason) {
|
||||
finishReason := ConvertGeminiFinishReasonToBifrost(candidate.FinishReason)
|
||||
return createErrorResponse(response, finishReason, false)
|
||||
}
|
||||
|
||||
// Collect all content and tool calls into a single message
|
||||
var toolCalls []schemas.ChatAssistantMessageToolCall
|
||||
var contentBlocks []schemas.ChatContentBlock
|
||||
var reasoningDetails []schemas.ChatReasoningDetails
|
||||
var contentStr *string
|
||||
|
||||
// Process candidate content to extract text, tool calls, and reasoning
|
||||
if candidate.Content != nil && len(candidate.Content.Parts) > 0 {
|
||||
for _, part := range candidate.Content.Parts {
|
||||
// Handle thought/reasoning text separately - add to reasoning details
|
||||
if part.Text != "" && part.Thought {
|
||||
reasoningDetails = append(reasoningDetails, schemas.ChatReasoningDetails{
|
||||
Index: len(reasoningDetails),
|
||||
Type: schemas.BifrostReasoningDetailsTypeText,
|
||||
Text: &part.Text,
|
||||
})
|
||||
continue
|
||||
}
|
||||
// Handle regular text
|
||||
if part.Text != "" {
|
||||
contentBlocks = append(contentBlocks, schemas.ChatContentBlock{
|
||||
Type: schemas.ChatContentBlockTypeText,
|
||||
Text: &part.Text,
|
||||
})
|
||||
// Add thought signature to reasoning details if present with text
|
||||
if len(part.ThoughtSignature) > 0 {
|
||||
thoughtSig := base64.StdEncoding.EncodeToString(part.ThoughtSignature)
|
||||
reasoningDetails = append(reasoningDetails, schemas.ChatReasoningDetails{
|
||||
Index: len(reasoningDetails),
|
||||
Type: schemas.BifrostReasoningDetailsTypeEncrypted,
|
||||
Signature: &thoughtSig,
|
||||
})
|
||||
}
|
||||
}
|
||||
if part.FunctionCall != nil {
|
||||
function := schemas.ChatAssistantMessageToolCallFunction{
|
||||
Name: &part.FunctionCall.Name,
|
||||
}
|
||||
|
||||
if len(part.FunctionCall.Args) > 0 {
|
||||
function.Arguments = string(part.FunctionCall.Args)
|
||||
}
|
||||
|
||||
callID := part.FunctionCall.Name
|
||||
if part.FunctionCall.ID != "" {
|
||||
callID = part.FunctionCall.ID
|
||||
}
|
||||
|
||||
// Embed thought signature into CallID if present (matches responses.go pattern)
|
||||
if len(part.ThoughtSignature) > 0 && !strings.Contains(callID, thoughtSignatureSeparator) {
|
||||
encoded := base64.RawURLEncoding.EncodeToString(part.ThoughtSignature)
|
||||
callID = fmt.Sprintf("%s%s%s", callID, thoughtSignatureSeparator, encoded)
|
||||
}
|
||||
|
||||
toolCall := schemas.ChatAssistantMessageToolCall{
|
||||
Index: uint16(len(toolCalls)),
|
||||
Type: schemas.Ptr(string(schemas.ChatToolChoiceTypeFunction)),
|
||||
ID: &callID,
|
||||
Function: function,
|
||||
}
|
||||
|
||||
toolCalls = append(toolCalls, toolCall)
|
||||
|
||||
// Also add to reasoning details for backward compatibility
|
||||
if len(part.ThoughtSignature) > 0 {
|
||||
thoughtSig := base64.StdEncoding.EncodeToString(part.ThoughtSignature)
|
||||
// Extract base ID without signature for reasoning detail lookup
|
||||
baseCallID := callID
|
||||
if strings.Contains(callID, thoughtSignatureSeparator) {
|
||||
parts := strings.SplitN(callID, thoughtSignatureSeparator, 2)
|
||||
if len(parts) == 2 {
|
||||
baseCallID = parts[0]
|
||||
}
|
||||
}
|
||||
reasoningDetails = append(reasoningDetails, schemas.ChatReasoningDetails{
|
||||
Index: len(reasoningDetails),
|
||||
Type: schemas.BifrostReasoningDetailsTypeEncrypted,
|
||||
Signature: &thoughtSig,
|
||||
ID: schemas.Ptr(fmt.Sprintf("tool_call_%s", baseCallID)),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
if part.FunctionResponse != nil {
|
||||
// Extract the output from the response
|
||||
output := extractFunctionResponseOutput(part.FunctionResponse)
|
||||
|
||||
// Add as text content block
|
||||
if output != "" {
|
||||
contentBlocks = append(contentBlocks, schemas.ChatContentBlock{
|
||||
Type: schemas.ChatContentBlockTypeText,
|
||||
Text: &output,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Handle code execution results
|
||||
if part.CodeExecutionResult != nil {
|
||||
output := part.CodeExecutionResult.Output
|
||||
if part.CodeExecutionResult.Outcome != OutcomeOK {
|
||||
output = "Error: " + output
|
||||
}
|
||||
if output != "" {
|
||||
contentBlocks = append(contentBlocks, schemas.ChatContentBlock{
|
||||
Type: schemas.ChatContentBlockTypeText,
|
||||
Text: &output,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Handle executable code
|
||||
if part.ExecutableCode != nil {
|
||||
codeContent := "```" + part.ExecutableCode.Language + "\n" + part.ExecutableCode.Code + "\n```"
|
||||
contentBlocks = append(contentBlocks, schemas.ChatContentBlock{
|
||||
Type: schemas.ChatContentBlockTypeText,
|
||||
Text: &codeContent,
|
||||
})
|
||||
}
|
||||
|
||||
// Handle standalone thought signature (not associated with function call or text)
|
||||
if len(part.ThoughtSignature) > 0 && part.FunctionCall == nil && part.Text == "" {
|
||||
thoughtSig := base64.StdEncoding.EncodeToString(part.ThoughtSignature)
|
||||
reasoningDetails = append(reasoningDetails, schemas.ChatReasoningDetails{
|
||||
Index: len(reasoningDetails),
|
||||
Type: schemas.BifrostReasoningDetailsTypeEncrypted,
|
||||
Signature: &thoughtSig,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Build the choice with message
|
||||
message := &schemas.ChatMessage{
|
||||
Role: schemas.ChatMessageRoleAssistant,
|
||||
}
|
||||
|
||||
if len(contentBlocks) == 1 && contentBlocks[0].Type == schemas.ChatContentBlockTypeText {
|
||||
contentStr = contentBlocks[0].Text
|
||||
contentBlocks = nil
|
||||
}
|
||||
|
||||
message.Content = &schemas.ChatMessageContent{
|
||||
ContentStr: contentStr,
|
||||
ContentBlocks: contentBlocks,
|
||||
}
|
||||
|
||||
if len(toolCalls) > 0 || len(reasoningDetails) > 0 {
|
||||
message.ChatAssistantMessage = &schemas.ChatAssistantMessage{
|
||||
ToolCalls: toolCalls,
|
||||
ReasoningDetails: reasoningDetails,
|
||||
}
|
||||
}
|
||||
|
||||
// Convert finish reason to Bifrost format.
|
||||
// Gemini uses "STOP" for both normal text completions and tool call responses —
|
||||
// it has no dedicated finish reason for tool calls. Override to "tool_calls" when
|
||||
// tool calls are present so downstream consumers see a uniform signal.
|
||||
finishReason := ConvertGeminiFinishReasonToBifrost(candidate.FinishReason)
|
||||
if len(toolCalls) > 0 && finishReason == "stop" {
|
||||
finishReason = "tool_calls"
|
||||
}
|
||||
|
||||
bifrostResp.Choices = append(bifrostResp.Choices, schemas.BifrostResponseChoice{
|
||||
Index: 0,
|
||||
FinishReason: &finishReason,
|
||||
LogProbs: ConvertGeminiLogprobsResultToBifrost(candidate.LogprobsResult),
|
||||
ChatNonStreamResponseChoice: &schemas.ChatNonStreamResponseChoice{
|
||||
Message: message,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// Set usage information
|
||||
bifrostResp.Usage = ConvertGeminiUsageMetadataToChatUsage(response.UsageMetadata)
|
||||
|
||||
return bifrostResp
|
||||
}
|
||||
|
||||
// GeminiStreamState carries per-stream bookkeeping between successive chunk
// conversions of one streaming response. Its zero value is ready to use.
type GeminiStreamState struct {
	// nextToolCallIndex is the index assigned to the next tool call emitted in this stream.
	nextToolCallIndex int
	// hadToolCalls is set once any chunk of this stream has carried a tool call.
	hadToolCalls bool
}
|
||||
|
||||
// NewGeminiStreamState returns initialised stream state for one streaming response.
|
||||
func NewGeminiStreamState() *GeminiStreamState {
|
||||
return &GeminiStreamState{}
|
||||
}
|
||||
|
||||
// ToBifrostChatCompletionStream converts a Gemini streaming response to a Bifrost Chat Completion Stream response
|
||||
// Returns the response, error (if any), and a boolean indicating if this is the last chunk
|
||||
func (response *GenerateContentResponse) ToBifrostChatCompletionStream(state *GeminiStreamState) (*schemas.BifrostChatResponse, *schemas.BifrostError, bool) {
|
||||
if response == nil {
|
||||
return nil, nil, false
|
||||
}
|
||||
|
||||
if state == nil {
|
||||
state = NewGeminiStreamState()
|
||||
}
|
||||
|
||||
// Handle empty candidates (filtered/malformed responses)
|
||||
if len(response.Candidates) == 0 {
|
||||
finishReason := ConvertGeminiFinishReasonToBifrost(FinishReasonMalformedFunctionCall)
|
||||
return createErrorResponse(response, finishReason, true), nil, true
|
||||
}
|
||||
|
||||
candidate := response.Candidates[0]
|
||||
|
||||
// Check for filtered finish reasons that indicate errors
|
||||
if isErrorFinishReason(candidate.FinishReason) {
|
||||
finishReason := ConvertGeminiFinishReasonToBifrost(candidate.FinishReason)
|
||||
return createErrorResponse(response, finishReason, true), nil, true
|
||||
}
|
||||
|
||||
// Determine if this is the last chunk based on finish reason and usage metadata
|
||||
isLastChunk := candidate.FinishReason != "" && response.UsageMetadata != nil
|
||||
|
||||
// Create the streaming response
|
||||
streamResponse := &schemas.BifrostChatResponse{
|
||||
ID: response.ResponseID,
|
||||
Model: response.ModelVersion,
|
||||
Object: "chat.completion.chunk",
|
||||
}
|
||||
|
||||
// Set creation timestamp if available
|
||||
if !response.CreateTime.IsZero() {
|
||||
streamResponse.Created = int(response.CreateTime.Unix())
|
||||
}
|
||||
|
||||
// Build delta content
|
||||
delta := &schemas.ChatStreamResponseChoiceDelta{}
|
||||
|
||||
// Process content parts
|
||||
if candidate.Content != nil && len(candidate.Content.Parts) > 0 {
|
||||
// Set role from the first chunk (Gemini uses "model" for assistant)
|
||||
if candidate.Content.Role != "" {
|
||||
role := candidate.Content.Role
|
||||
if role == string(RoleModel) {
|
||||
role = string(schemas.ChatMessageRoleAssistant)
|
||||
}
|
||||
delta.Role = &role
|
||||
}
|
||||
|
||||
var textContent string
|
||||
var toolCalls []schemas.ChatAssistantMessageToolCall
|
||||
var reasoningDetails []schemas.ChatReasoningDetails
|
||||
|
||||
for _, part := range candidate.Content.Parts {
|
||||
switch {
|
||||
case part.Text != "" && part.Thought:
|
||||
// Thought/reasoning content - add to reasoning details
|
||||
reasoningDetails = append(reasoningDetails, schemas.ChatReasoningDetails{
|
||||
Index: len(reasoningDetails),
|
||||
Type: schemas.BifrostReasoningDetailsTypeText,
|
||||
Text: &part.Text,
|
||||
})
|
||||
|
||||
case part.Text != "":
|
||||
// Regular text content
|
||||
textContent += part.Text
|
||||
|
||||
case part.FunctionCall != nil:
|
||||
// Function call
|
||||
jsonArgs := ""
|
||||
if len(part.FunctionCall.Args) > 0 {
|
||||
jsonArgs = string(part.FunctionCall.Args)
|
||||
}
|
||||
|
||||
// Use ID if available, otherwise use function name
|
||||
callID := part.FunctionCall.Name
|
||||
if part.FunctionCall.ID != "" {
|
||||
callID = part.FunctionCall.ID
|
||||
}
|
||||
|
||||
// Embed thought signature into CallID if present
|
||||
if len(part.ThoughtSignature) > 0 && !strings.Contains(callID, thoughtSignatureSeparator) {
|
||||
encoded := base64.RawURLEncoding.EncodeToString(part.ThoughtSignature)
|
||||
callID = fmt.Sprintf("%s%s%s", callID, thoughtSignatureSeparator, encoded)
|
||||
}
|
||||
|
||||
toolCallIdx := state.nextToolCallIndex
|
||||
state.nextToolCallIndex++
|
||||
|
||||
toolCall := schemas.ChatAssistantMessageToolCall{
|
||||
Index: uint16(toolCallIdx),
|
||||
Type: schemas.Ptr(string(schemas.ChatToolTypeFunction)),
|
||||
ID: &callID,
|
||||
Function: schemas.ChatAssistantMessageToolCallFunction{
|
||||
Name: &part.FunctionCall.Name,
|
||||
Arguments: jsonArgs,
|
||||
},
|
||||
}
|
||||
|
||||
toolCalls = append(toolCalls, toolCall)
|
||||
|
||||
// Also add thought signature to reasoning details if present
|
||||
if len(part.ThoughtSignature) > 0 {
|
||||
thoughtSig := base64.StdEncoding.EncodeToString(part.ThoughtSignature)
|
||||
// Extract base ID without signature for reasoning detail lookup
|
||||
baseCallID := callID
|
||||
if strings.Contains(callID, thoughtSignatureSeparator) {
|
||||
parts := strings.SplitN(callID, thoughtSignatureSeparator, 2)
|
||||
if len(parts) == 2 {
|
||||
baseCallID = parts[0]
|
||||
}
|
||||
}
|
||||
reasoningDetails = append(reasoningDetails, schemas.ChatReasoningDetails{
|
||||
Index: len(reasoningDetails),
|
||||
Type: schemas.BifrostReasoningDetailsTypeEncrypted,
|
||||
Signature: &thoughtSig,
|
||||
ID: schemas.Ptr(fmt.Sprintf("tool_call_%s", baseCallID)),
|
||||
})
|
||||
}
|
||||
|
||||
case part.FunctionResponse != nil:
|
||||
// Extract the output from the response and add to text content
|
||||
output := extractFunctionResponseOutput(part.FunctionResponse)
|
||||
if output != "" {
|
||||
textContent += output
|
||||
}
|
||||
case part.CodeExecutionResult != nil:
|
||||
output := part.CodeExecutionResult.Output
|
||||
if part.CodeExecutionResult.Outcome != OutcomeOK {
|
||||
output = "Error: " + output
|
||||
}
|
||||
if output != "" {
|
||||
textContent += output
|
||||
}
|
||||
case part.ExecutableCode != nil:
|
||||
codeContent := "```" + part.ExecutableCode.Language + "\n" + part.ExecutableCode.Code + "\n```"
|
||||
textContent += codeContent
|
||||
}
|
||||
|
||||
// Handle thought signature separately (not part of the switch since it can co-exist with other types)
|
||||
if len(part.ThoughtSignature) > 0 && part.FunctionCall == nil {
|
||||
thoughtSig := base64.StdEncoding.EncodeToString(part.ThoughtSignature)
|
||||
reasoningDetails = append(reasoningDetails, schemas.ChatReasoningDetails{
|
||||
Index: len(reasoningDetails),
|
||||
Type: schemas.BifrostReasoningDetailsTypeEncrypted,
|
||||
Signature: &thoughtSig,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Set text content if present
|
||||
if textContent != "" {
|
||||
delta.Content = &textContent
|
||||
}
|
||||
|
||||
// Set reasoning details if present
|
||||
if len(reasoningDetails) > 0 {
|
||||
delta.ReasoningDetails = reasoningDetails
|
||||
}
|
||||
|
||||
// Set tool calls if present
|
||||
if len(toolCalls) > 0 {
|
||||
delta.ToolCalls = toolCalls
|
||||
state.hadToolCalls = true
|
||||
}
|
||||
}
|
||||
|
||||
// Check if delta has any content - if not and it's not the last chunk, skip it
|
||||
hasDeltaContent := delta.Role != nil || delta.Content != nil || len(delta.ToolCalls) > 0 || len(delta.ReasoningDetails) > 0
|
||||
if !hasDeltaContent && !isLastChunk {
|
||||
return nil, nil, false
|
||||
}
|
||||
|
||||
// Build the choice
|
||||
var finishReason *string
|
||||
if isLastChunk && candidate.FinishReason != "" {
|
||||
reason := ConvertGeminiFinishReasonToBifrost(candidate.FinishReason)
|
||||
// Gemini uses "STOP" for both text completions and tool call responses.
|
||||
// Override to "tool_calls" when tool calls were seen in this stream for uniformity.
|
||||
if (len(delta.ToolCalls) > 0 || state.hadToolCalls) && reason == "stop" {
|
||||
reason = "tool_calls"
|
||||
}
|
||||
finishReason = &reason
|
||||
}
|
||||
|
||||
choice := schemas.BifrostResponseChoice{
|
||||
Index: int(candidate.Index),
|
||||
FinishReason: finishReason,
|
||||
LogProbs: ConvertGeminiLogprobsResultToBifrost(candidate.LogprobsResult),
|
||||
ChatStreamResponseChoice: &schemas.ChatStreamResponseChoice{
|
||||
Delta: delta,
|
||||
},
|
||||
}
|
||||
|
||||
streamResponse.Choices = []schemas.BifrostResponseChoice{choice}
|
||||
|
||||
// Add usage information if this is the last chunk
|
||||
if isLastChunk && response.UsageMetadata != nil {
|
||||
streamResponse.Usage = ConvertGeminiUsageMetadataToChatUsage(response.UsageMetadata)
|
||||
}
|
||||
|
||||
return streamResponse, nil, isLastChunk
|
||||
}
|
||||
|
||||
// isErrorFinishReason checks if a finish reason indicates a filtered or error response
|
||||
func isErrorFinishReason(reason FinishReason) bool {
|
||||
return reason == FinishReasonSafety ||
|
||||
reason == FinishReasonRecitation ||
|
||||
reason == FinishReasonMalformedFunctionCall ||
|
||||
reason == FinishReasonBlocklist ||
|
||||
reason == FinishReasonProhibitedContent ||
|
||||
reason == FinishReasonSPII ||
|
||||
reason == FinishReasonImageSafety ||
|
||||
reason == FinishReasonUnexpectedToolCall ||
|
||||
reason == FinishReasonMissingThoughtSignature ||
|
||||
reason == FinishReasonMalformedResponse ||
|
||||
reason == FinishReasonImageProhibitedContent ||
|
||||
reason == FinishReasonImageRecitation ||
|
||||
reason == FinishReasonTooManyToolCalls ||
|
||||
reason == FinishReasonNoImage
|
||||
}
|
||||
|
||||
// createErrorResponse creates a complete BifrostChatResponse for error cases
|
||||
func createErrorResponse(response *GenerateContentResponse, finishReason string, isStream bool) *schemas.BifrostChatResponse {
|
||||
var choice schemas.BifrostResponseChoice
|
||||
if isStream {
|
||||
choice = schemas.BifrostResponseChoice{
|
||||
Index: 0,
|
||||
FinishReason: &finishReason,
|
||||
ChatStreamResponseChoice: &schemas.ChatStreamResponseChoice{
|
||||
Delta: &schemas.ChatStreamResponseChoiceDelta{},
|
||||
},
|
||||
}
|
||||
} else {
|
||||
choice = schemas.BifrostResponseChoice{
|
||||
Index: 0,
|
||||
FinishReason: &finishReason,
|
||||
ChatNonStreamResponseChoice: &schemas.ChatNonStreamResponseChoice{
|
||||
Message: &schemas.ChatMessage{
|
||||
Role: schemas.ChatMessageRoleAssistant,
|
||||
Content: &schemas.ChatMessageContent{},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
objectType := "chat.completion"
|
||||
if isStream {
|
||||
objectType = "chat.completion.chunk"
|
||||
}
|
||||
|
||||
errorResp := &schemas.BifrostChatResponse{
|
||||
ID: response.ResponseID,
|
||||
Model: response.ModelVersion,
|
||||
Object: objectType,
|
||||
Choices: []schemas.BifrostResponseChoice{choice},
|
||||
Usage: ConvertGeminiUsageMetadataToChatUsage(response.UsageMetadata),
|
||||
}
|
||||
|
||||
if !response.CreateTime.IsZero() {
|
||||
errorResp.Created = int(response.CreateTime.Unix())
|
||||
}
|
||||
|
||||
return errorResp
|
||||
}
|
||||
80
core/providers/gemini/count_tokens.go
Normal file
80
core/providers/gemini/count_tokens.go
Normal file
@@ -0,0 +1,80 @@
|
||||
package gemini
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// ToBifrostCountTokensResponse converts a Gemini count tokens response to Bifrost format.
|
||||
func (resp *GeminiCountTokensResponse) ToBifrostCountTokensResponse(model string) *schemas.BifrostCountTokensResponse {
|
||||
if resp == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Sum prompt tokens and map modality-specific counts
|
||||
inputTokens := 0
|
||||
inputDetails := &schemas.ResponsesResponseInputTokens{}
|
||||
|
||||
for _, m := range resp.PromptTokensDetails {
|
||||
if m == nil {
|
||||
continue
|
||||
}
|
||||
inputTokens += int(m.TokenCount)
|
||||
mod := strings.ToLower(string(m.Modality))
|
||||
// handle audio modality
|
||||
if strings.Contains(mod, "audio") {
|
||||
inputDetails.AudioTokens += int(m.TokenCount)
|
||||
}
|
||||
}
|
||||
|
||||
// Set cached tokens from top-level field if present
|
||||
if resp.CachedContentTokenCount != 0 {
|
||||
inputDetails.CachedReadTokens = int(resp.CachedContentTokenCount)
|
||||
} else if resp.CacheTokensDetails != nil {
|
||||
// If cache tokens details present, sum them
|
||||
cachedSum := 0
|
||||
for _, m := range resp.CacheTokensDetails {
|
||||
if m == nil {
|
||||
continue
|
||||
}
|
||||
cachedSum += int(m.TokenCount)
|
||||
if strings.Contains(strings.ToLower(string(m.Modality)), "audio") {
|
||||
// also populate audio tokens from cache into AudioTokens (additive)
|
||||
inputDetails.AudioTokens += int(m.TokenCount)
|
||||
}
|
||||
}
|
||||
inputDetails.CachedReadTokens = cachedSum
|
||||
}
|
||||
|
||||
total := int(resp.TotalTokens)
|
||||
|
||||
return &schemas.BifrostCountTokensResponse{
|
||||
Model: model,
|
||||
Object: "response.input_tokens",
|
||||
InputTokens: inputTokens,
|
||||
InputTokensDetails: inputDetails,
|
||||
TotalTokens: &total,
|
||||
ExtraFields: schemas.BifrostResponseExtraFields{},
|
||||
}
|
||||
}
|
||||
|
||||
// ToGeminiCountTokensResponse converts a Bifrost count tokens response to Gemini format.
|
||||
func ToGeminiCountTokensResponse(bifrostResp *schemas.BifrostCountTokensResponse) *GeminiCountTokensResponse {
|
||||
if bifrostResp == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
response := &GeminiCountTokensResponse{
|
||||
TotalTokens: int32(bifrostResp.InputTokens),
|
||||
}
|
||||
|
||||
// Map cached content token count if available
|
||||
if bifrostResp.InputTokensDetails != nil && bifrostResp.InputTokensDetails.CachedReadTokens > 0 {
|
||||
response.CachedContentTokenCount = int32(bifrostResp.InputTokensDetails.CachedReadTokens)
|
||||
} else {
|
||||
response.CachedContentTokenCount = 0
|
||||
}
|
||||
|
||||
return response
|
||||
}
|
||||
247
core/providers/gemini/embedding.go
Normal file
247
core/providers/gemini/embedding.go
Normal file
@@ -0,0 +1,247 @@
|
||||
package gemini
|
||||
|
||||
import (
|
||||
"github.com/maximhq/bifrost/core/providers/utils"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// ToGeminiEmbeddingRequest converts a BifrostRequest with embedding input to Gemini's batch embedding request format
|
||||
// GeminiGenerationRequest contains requests array for batch embed content endpoint
|
||||
func ToGeminiEmbeddingRequest(bifrostReq *schemas.BifrostEmbeddingRequest) *GeminiBatchEmbeddingRequest {
|
||||
if bifrostReq == nil || bifrostReq.Input == nil || (bifrostReq.Input.Text == nil && bifrostReq.Input.Texts == nil) {
|
||||
return nil
|
||||
}
|
||||
|
||||
embeddingInput := bifrostReq.Input
|
||||
|
||||
// Collect all texts to embed
|
||||
var texts []string
|
||||
if embeddingInput.Text != nil {
|
||||
texts = append(texts, *embeddingInput.Text)
|
||||
}
|
||||
if len(embeddingInput.Texts) > 0 {
|
||||
texts = append(texts, embeddingInput.Texts...)
|
||||
}
|
||||
|
||||
if len(texts) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Create batch embedding request with one request per text
|
||||
batchRequest := &GeminiBatchEmbeddingRequest{
|
||||
Requests: make([]GeminiEmbeddingRequest, len(texts)),
|
||||
}
|
||||
if bifrostReq.Params != nil {
|
||||
batchRequest.ExtraParams = bifrostReq.Params.ExtraParams
|
||||
}
|
||||
|
||||
// Create individual embedding requests for each text
|
||||
for i, text := range texts {
|
||||
embeddingReq := GeminiEmbeddingRequest{
|
||||
Model: "models/" + bifrostReq.Model,
|
||||
Content: &Content{
|
||||
Parts: []*Part{
|
||||
{
|
||||
Text: text,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Add parameters if available
|
||||
if bifrostReq.Params != nil {
|
||||
if bifrostReq.Params.Dimensions != nil {
|
||||
embeddingReq.OutputDimensionality = bifrostReq.Params.Dimensions
|
||||
}
|
||||
|
||||
// Handle extra parameters
|
||||
if bifrostReq.Params.ExtraParams != nil {
|
||||
if taskType, ok := schemas.SafeExtractStringPointer(bifrostReq.Params.ExtraParams["taskType"]); ok {
|
||||
delete(batchRequest.ExtraParams, "taskType")
|
||||
embeddingReq.TaskType = taskType
|
||||
}
|
||||
if title, ok := schemas.SafeExtractStringPointer(bifrostReq.Params.ExtraParams["title"]); ok {
|
||||
delete(batchRequest.ExtraParams, "title")
|
||||
embeddingReq.Title = title
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
batchRequest.Requests[i] = embeddingReq
|
||||
}
|
||||
|
||||
return batchRequest
|
||||
}
|
||||
|
||||
// ToGeminiEmbeddingResponse converts a BifrostResponse with embedding data to Gemini's embedding response format
|
||||
func ToGeminiEmbeddingResponse(bifrostResp *schemas.BifrostEmbeddingResponse) *GeminiEmbeddingResponse {
|
||||
if bifrostResp == nil || len(bifrostResp.Data) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
geminiResp := &GeminiEmbeddingResponse{
|
||||
Embeddings: make([]GeminiEmbedding, len(bifrostResp.Data)),
|
||||
}
|
||||
|
||||
// Convert each embedding from Bifrost format to Gemini format
|
||||
for i, embedding := range bifrostResp.Data {
|
||||
var values []float64
|
||||
|
||||
// Extract embedding values from BifrostEmbeddingResponse
|
||||
if embedding.Embedding.EmbeddingArray != nil {
|
||||
values = append([]float64(nil), embedding.Embedding.EmbeddingArray...)
|
||||
} else if len(embedding.Embedding.Embedding2DArray) > 0 {
|
||||
// If it's a 2D array, take the first array
|
||||
values = append([]float64(nil), embedding.Embedding.Embedding2DArray[0]...)
|
||||
}
|
||||
|
||||
geminiEmbedding := GeminiEmbedding{
|
||||
Values: values,
|
||||
}
|
||||
|
||||
// Add statistics if available (token count from usage metadata)
|
||||
if bifrostResp.Usage != nil {
|
||||
geminiEmbedding.Statistics = &ContentEmbeddingStatistics{
|
||||
TokenCount: int32(bifrostResp.Usage.PromptTokens),
|
||||
}
|
||||
}
|
||||
|
||||
geminiResp.Embeddings[i] = geminiEmbedding
|
||||
}
|
||||
|
||||
// Set metadata if available (for Vertex API compatibility)
|
||||
if bifrostResp.Usage != nil {
|
||||
geminiResp.Metadata = &EmbedContentMetadata{
|
||||
BillableCharacterCount: int32(bifrostResp.Usage.PromptTokens),
|
||||
}
|
||||
}
|
||||
|
||||
return geminiResp
|
||||
}
|
||||
|
||||
// ToBifrostEmbeddingResponse converts a Gemini embedding response to BifrostEmbeddingResponse format
|
||||
func ToBifrostEmbeddingResponse(geminiResp *GeminiEmbeddingResponse, model string) *schemas.BifrostEmbeddingResponse {
|
||||
if geminiResp == nil || len(geminiResp.Embeddings) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
bifrostResp := &schemas.BifrostEmbeddingResponse{
|
||||
Data: make([]schemas.EmbeddingData, len(geminiResp.Embeddings)),
|
||||
Model: model,
|
||||
Object: "list",
|
||||
}
|
||||
|
||||
// Convert each embedding from Gemini format to Bifrost format
|
||||
for i, geminiEmbedding := range geminiResp.Embeddings {
|
||||
embeddingData := schemas.EmbeddingData{
|
||||
Index: i,
|
||||
Object: "embedding",
|
||||
Embedding: schemas.EmbeddingStruct{
|
||||
EmbeddingArray: geminiEmbedding.Values,
|
||||
},
|
||||
}
|
||||
|
||||
bifrostResp.Data[i] = embeddingData
|
||||
}
|
||||
|
||||
// Convert usage metadata if available
|
||||
if geminiResp.Metadata != nil || (len(geminiResp.Embeddings) > 0 && geminiResp.Embeddings[0].Statistics != nil) {
|
||||
bifrostResp.Usage = &schemas.BifrostLLMUsage{}
|
||||
|
||||
// Use statistics from the first embedding if available
|
||||
if geminiResp.Embeddings[0].Statistics != nil {
|
||||
bifrostResp.Usage.PromptTokens = int(geminiResp.Embeddings[0].Statistics.TokenCount)
|
||||
} else if geminiResp.Metadata != nil {
|
||||
// Fall back to metadata if statistics are not available
|
||||
bifrostResp.Usage.PromptTokens = int(geminiResp.Metadata.BillableCharacterCount)
|
||||
}
|
||||
|
||||
// Set total tokens same as prompt tokens for embeddings
|
||||
bifrostResp.Usage.TotalTokens = bifrostResp.Usage.PromptTokens
|
||||
}
|
||||
|
||||
return bifrostResp
|
||||
}
|
||||
|
||||
// ToBifrostEmbeddingRequest converts a GeminiGenerationRequest to BifrostEmbeddingRequest format
|
||||
func (request *GeminiGenerationRequest) ToBifrostEmbeddingRequest(ctx *schemas.BifrostContext) *schemas.BifrostEmbeddingRequest {
|
||||
if request == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
provider, model := schemas.ParseModelString(request.Model, utils.CheckAndSetDefaultProvider(ctx, schemas.Gemini))
|
||||
|
||||
// Create the embedding request
|
||||
bifrostReq := &schemas.BifrostEmbeddingRequest{
|
||||
Provider: provider,
|
||||
Model: model,
|
||||
Fallbacks: schemas.ParseFallbacks(request.Fallbacks),
|
||||
}
|
||||
|
||||
// SDK batch embedding request contains multiple embedding requests with same parameters but different text fields.
|
||||
if len(request.Requests) > 0 {
|
||||
var texts []string
|
||||
for _, req := range request.Requests {
|
||||
if req.Content != nil && len(req.Content.Parts) > 0 {
|
||||
for _, part := range req.Content.Parts {
|
||||
if part != nil && part.Text != "" {
|
||||
texts = append(texts, part.Text)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(texts) > 0 {
|
||||
bifrostReq.Input = &schemas.EmbeddingInput{}
|
||||
if len(texts) == 1 {
|
||||
bifrostReq.Input.Text = &texts[0]
|
||||
} else {
|
||||
bifrostReq.Input.Texts = texts
|
||||
}
|
||||
}
|
||||
|
||||
embeddingRequest := request.Requests[0]
|
||||
|
||||
// Convert parameters
|
||||
if embeddingRequest.OutputDimensionality != nil || embeddingRequest.TaskType != nil || embeddingRequest.Title != nil {
|
||||
bifrostReq.Params = &schemas.EmbeddingParameters{}
|
||||
|
||||
if embeddingRequest.OutputDimensionality != nil {
|
||||
bifrostReq.Params.Dimensions = embeddingRequest.OutputDimensionality
|
||||
}
|
||||
|
||||
// Handle extra parameters
|
||||
if embeddingRequest.TaskType != nil || embeddingRequest.Title != nil {
|
||||
bifrostReq.Params.ExtraParams = make(map[string]interface{})
|
||||
if embeddingRequest.TaskType != nil {
|
||||
bifrostReq.Params.ExtraParams["taskType"] = embeddingRequest.TaskType
|
||||
}
|
||||
if embeddingRequest.Title != nil {
|
||||
bifrostReq.Params.ExtraParams["title"] = embeddingRequest.Title
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Generation-style requests (e.g., non-Imagen :predict) carry text in contents[].parts[].
|
||||
// If no SDK requests[] were provided, derive embedding input from contents.
|
||||
if bifrostReq.Input == nil {
|
||||
var texts []string
|
||||
for _, content := range request.Contents {
|
||||
for _, part := range content.Parts {
|
||||
if part != nil && part.Text != "" {
|
||||
texts = append(texts, part.Text)
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(texts) > 0 {
|
||||
bifrostReq.Input = &schemas.EmbeddingInput{}
|
||||
if len(texts) == 1 {
|
||||
bifrostReq.Input.Text = &texts[0]
|
||||
} else {
|
||||
bifrostReq.Input.Texts = texts
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return bifrostReq
|
||||
}
|
||||
79
core/providers/gemini/errors.go
Normal file
79
core/providers/gemini/errors.go
Normal file
@@ -0,0 +1,79 @@
|
||||
package gemini
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// ToGeminiError derives a GeminiGenerationError from a BifrostError
|
||||
func ToGeminiError(bifrostErr *schemas.BifrostError) *GeminiGenerationError {
|
||||
if bifrostErr == nil {
|
||||
return nil
|
||||
}
|
||||
code := 500
|
||||
status := ""
|
||||
if bifrostErr.Error != nil && bifrostErr.Error.Type != nil {
|
||||
status = *bifrostErr.Error.Type
|
||||
}
|
||||
message := ""
|
||||
if bifrostErr.Error != nil && bifrostErr.Error.Message != "" {
|
||||
message = bifrostErr.Error.Message
|
||||
}
|
||||
if bifrostErr.StatusCode != nil {
|
||||
code = *bifrostErr.StatusCode
|
||||
}
|
||||
return &GeminiGenerationError{
|
||||
Error: &GeminiGenerationErrorStruct{
|
||||
Code: code,
|
||||
Message: message,
|
||||
Status: status,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// parseGeminiError parses Gemini error responses
|
||||
func parseGeminiError(resp *fasthttp.Response) *schemas.BifrostError {
|
||||
// Try to parse as []GeminiGenerationError
|
||||
var errorResps []GeminiGenerationError
|
||||
bifrostErr := providerUtils.HandleProviderAPIError(resp, &errorResps)
|
||||
if len(errorResps) > 0 {
|
||||
var message string
|
||||
var firstError *GeminiGenerationErrorStruct
|
||||
for _, errorResp := range errorResps {
|
||||
if errorResp.Error != nil {
|
||||
if firstError == nil {
|
||||
firstError = errorResp.Error
|
||||
}
|
||||
message = message + errorResp.Error.Message + "\n"
|
||||
}
|
||||
}
|
||||
// Trim trailing newline
|
||||
message = strings.TrimSuffix(message, "\n")
|
||||
if bifrostErr.Error == nil {
|
||||
bifrostErr.Error = &schemas.ErrorField{}
|
||||
}
|
||||
// Set Code from first error if available
|
||||
if firstError != nil {
|
||||
bifrostErr.Error.Code = schemas.Ptr(strconv.Itoa(firstError.Code))
|
||||
}
|
||||
// Set Message to trimmed concatenated message
|
||||
bifrostErr.Error.Message = message
|
||||
return bifrostErr
|
||||
}
|
||||
|
||||
// Try to parse as GeminiGenerationError
|
||||
var errorResp GeminiGenerationError
|
||||
bifrostErr = providerUtils.HandleProviderAPIError(resp, &errorResp)
|
||||
if errorResp.Error != nil {
|
||||
if bifrostErr.Error == nil {
|
||||
bifrostErr.Error = &schemas.ErrorField{}
|
||||
}
|
||||
bifrostErr.Error.Code = schemas.Ptr(strconv.Itoa(errorResp.Error.Code))
|
||||
bifrostErr.Error.Message = errorResp.Error.Message
|
||||
}
|
||||
return bifrostErr
|
||||
}
|
||||
145
core/providers/gemini/files.go
Normal file
145
core/providers/gemini/files.go
Normal file
@@ -0,0 +1,145 @@
|
||||
package gemini
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// Gemini Files API types
|
||||
// The Gemini Files API allows uploading files for use with multimodal models.
|
||||
|
||||
// GeminiFileResponse represents a file object from Gemini's API.
|
||||
type GeminiFileResponse struct {
|
||||
Name string `json:"name"` // Resource name (e.g., "files/abc123")
|
||||
DisplayName string `json:"displayName"` // User-provided display name
|
||||
MimeType string `json:"mimeType"` // MIME type of the file
|
||||
SizeBytes string `json:"sizeBytes"` // Size in bytes (as string)
|
||||
CreateTime string `json:"createTime"` // RFC3339 timestamp
|
||||
UpdateTime string `json:"updateTime"` // RFC3339 timestamp
|
||||
ExpirationTime string `json:"expirationTime,omitempty"` // RFC3339 timestamp when file will be deleted
|
||||
SHA256Hash string `json:"sha256Hash"` // Base64 encoded SHA256 hash
|
||||
URI string `json:"uri"` // URI for accessing the file
|
||||
State string `json:"state"` // "PROCESSING", "ACTIVE", "FAILED"
|
||||
VideoMetadata *GeminiFileVideoMetadata `json:"videoMetadata,omitempty"`
|
||||
}
|
||||
|
||||
// GeminiFileVideoMetadata holds the video-specific metadata attached to a
// Gemini file object.
type GeminiFileVideoMetadata struct {
	// VideoDuration is the duration in seconds, encoded as a string.
	VideoDuration string `json:"videoDuration"`
}
|
||||
|
||||
// GeminiFileListResponse represents the response from listing files.
|
||||
type GeminiFileListResponse struct {
|
||||
Files []GeminiFileResponse `json:"files"`
|
||||
NextPageToken string `json:"nextPageToken,omitempty"`
|
||||
}
|
||||
|
||||
// ToBifrostFileStatus converts Gemini file state to Bifrost status.
|
||||
func ToBifrostFileStatus(state string) schemas.FileStatus {
|
||||
switch state {
|
||||
case "PROCESSING":
|
||||
return schemas.FileStatusProcessing
|
||||
case "ACTIVE":
|
||||
return schemas.FileStatusProcessed
|
||||
case "FAILED":
|
||||
return schemas.FileStatusError
|
||||
default:
|
||||
return schemas.FileStatus(strings.ToLower(state))
|
||||
}
|
||||
}
|
||||
|
||||
// ToGeminiFileListResponse converts a Bifrost file list response to Gemini format.
|
||||
func ToGeminiFileListResponse(resp *schemas.BifrostFileListResponse) *GeminiFileListResponse {
|
||||
files := make([]GeminiFileResponse, len(resp.Data))
|
||||
for i, f := range resp.Data {
|
||||
updateAt := f.UpdatedAt
|
||||
if updateAt == 0 {
|
||||
updateAt = f.CreatedAt
|
||||
}
|
||||
files[i] = GeminiFileResponse{
|
||||
Name: f.ID,
|
||||
DisplayName: f.Filename,
|
||||
SizeBytes: fmt.Sprintf("%d", f.Bytes),
|
||||
CreateTime: formatGeminiTimestamp(f.CreatedAt),
|
||||
UpdateTime: formatGeminiTimestamp(updateAt),
|
||||
State: toGeminiFileState(f.Status),
|
||||
ExpirationTime: formatGeminiTimestamp(safeDerefInt64(f.ExpiresAt)),
|
||||
}
|
||||
}
|
||||
result := &GeminiFileListResponse{Files: files}
|
||||
if resp.After != nil && *resp.After != "" {
|
||||
result.NextPageToken = *resp.After
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// ToGeminiFileRetrieveResponse converts a Bifrost file retrieve response to Gemini format.
|
||||
func ToGeminiFileRetrieveResponse(resp *schemas.BifrostFileRetrieveResponse) *GeminiFileResponse {
|
||||
updateAt := resp.UpdatedAt
|
||||
if updateAt == 0 {
|
||||
updateAt = resp.CreatedAt
|
||||
}
|
||||
return &GeminiFileResponse{
|
||||
Name: resp.ID,
|
||||
DisplayName: resp.Filename,
|
||||
SizeBytes: fmt.Sprintf("%d", resp.Bytes),
|
||||
CreateTime: formatGeminiTimestamp(resp.CreatedAt),
|
||||
UpdateTime: formatGeminiTimestamp(updateAt),
|
||||
State: toGeminiFileState(resp.Status),
|
||||
URI: resp.StorageURI,
|
||||
ExpirationTime: formatGeminiTimestamp(safeDerefInt64(resp.ExpiresAt)),
|
||||
}
|
||||
}
|
||||
|
||||
// toGeminiFileState converts Bifrost file status to Gemini state.
|
||||
func toGeminiFileState(status schemas.FileStatus) string {
|
||||
switch status {
|
||||
case schemas.FileStatusProcessing:
|
||||
return "PROCESSING"
|
||||
case schemas.FileStatusProcessed:
|
||||
return "ACTIVE"
|
||||
case schemas.FileStatusError:
|
||||
return "FAILED"
|
||||
default:
|
||||
return strings.ToUpper(string(status))
|
||||
}
|
||||
}
|
||||
|
||||
// formatGeminiTimestamp renders a Unix timestamp (seconds) as an RFC3339
// string in UTC. A zero timestamp is treated as "unset" and yields "".
func formatGeminiTimestamp(unixTime int64) string {
	if unixTime != 0 {
		return time.Unix(unixTime, 0).UTC().Format(time.RFC3339)
	}
	return ""
}
|
||||
|
||||
// safeDerefInt64 returns the value ptr points to, or 0 when ptr is nil.
func safeDerefInt64(ptr *int64) int64 {
	var value int64
	if ptr != nil {
		value = *ptr
	}
	return value
}
|
||||
|
||||
// ToGeminiFileUploadResponse converts a Bifrost file upload response to Gemini format.
|
||||
func ToGeminiFileUploadResponse(resp *schemas.BifrostFileUploadResponse) map[string]interface{} {
|
||||
file := map[string]interface{}{
|
||||
"name": resp.ID,
|
||||
"displayName": resp.Filename,
|
||||
"mimeType": "application/octet-stream",
|
||||
"sizeBytes": fmt.Sprintf("%d", resp.Bytes),
|
||||
"createTime": formatGeminiTimestamp(resp.CreatedAt),
|
||||
"updateTime": formatGeminiTimestamp(resp.CreatedAt),
|
||||
"state": toGeminiFileState(resp.Status),
|
||||
"uri": resp.StorageURI,
|
||||
}
|
||||
if exp := formatGeminiTimestamp(safeDerefInt64(resp.ExpiresAt)); exp != "" {
|
||||
file["expirationTime"] = exp
|
||||
}
|
||||
return map[string]interface{}{
|
||||
"file": file,
|
||||
}
|
||||
}
|
||||
4306
core/providers/gemini/gemini.go
Normal file
4306
core/providers/gemini/gemini.go
Normal file
File diff suppressed because it is too large
Load Diff
62
core/providers/gemini/gemini_stream_reader_test.go
Normal file
62
core/providers/gemini/gemini_stream_reader_test.go
Normal file
@@ -0,0 +1,62 @@
|
||||
package gemini
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"io"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestReadNextSSEDataLine_SkipInlineDataOnGzipReader(t *testing.T) {
|
||||
var compressed bytes.Buffer
|
||||
gz := gzip.NewWriter(&compressed)
|
||||
payload := "data: {\"candidates\":[{\"content\":{\"parts\":[{\"inlineData\":{\"data\":\"abc\"}}]}}]}\n" +
|
||||
"data: {\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"ok\"}]}}]}\n"
|
||||
if _, err := gz.Write([]byte(payload)); err != nil {
|
||||
t.Fatalf("failed to write gzip payload: %v", err)
|
||||
}
|
||||
if err := gz.Close(); err != nil {
|
||||
t.Fatalf("failed to close gzip writer: %v", err)
|
||||
}
|
||||
|
||||
reader, err := gzip.NewReader(bytes.NewReader(compressed.Bytes()))
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create gzip reader: %v", err)
|
||||
}
|
||||
defer reader.Close()
|
||||
|
||||
line, err := readNextSSEDataLine(bufio.NewReaderSize(reader, 64*1024), true)
|
||||
if err != nil {
|
||||
t.Fatalf("expected next non-inline line, got error: %v", err)
|
||||
}
|
||||
|
||||
want := []byte(`{"candidates":[{"content":{"parts":[{"text":"ok"}]}}]}`)
|
||||
if !bytes.Equal(line, want) {
|
||||
t.Fatalf("expected %q, got %q", string(want), string(line))
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadNextSSEDataLine_SkipInlineDataContinuedLine(t *testing.T) {
|
||||
longInline := bytes.Repeat([]byte("x"), 70*1024)
|
||||
var stream bytes.Buffer
|
||||
stream.WriteString("data: {\"candidates\":[{\"content\":{\"parts\":[{\"inlineData\":{\"data\":\"")
|
||||
stream.Write(longInline)
|
||||
stream.WriteString("\"}}]}}]}\n")
|
||||
stream.WriteString("data: {\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"ok\"}]}}]}\n")
|
||||
|
||||
line, err := readNextSSEDataLine(bufio.NewReaderSize(bytes.NewReader(stream.Bytes()), 64*1024), true)
|
||||
if err != nil {
|
||||
t.Fatalf("expected next non-inline line, got error: %v", err)
|
||||
}
|
||||
|
||||
want := []byte(`{"candidates":[{"content":{"parts":[{"text":"ok"}]}}]}`)
|
||||
if !bytes.Equal(line, want) {
|
||||
t.Fatalf("expected %q, got %q", string(want), string(line))
|
||||
}
|
||||
|
||||
_, err = readNextSSEDataLine(bufio.NewReaderSize(bytes.NewReader(nil), 64*1024), true)
|
||||
if err != io.EOF {
|
||||
t.Fatalf("expected EOF on empty reader, got %v", err)
|
||||
}
|
||||
}
|
||||
2988
core/providers/gemini/gemini_test.go
Normal file
2988
core/providers/gemini/gemini_test.go
Normal file
File diff suppressed because it is too large
Load Diff
1095
core/providers/gemini/images.go
Normal file
1095
core/providers/gemini/images.go
Normal file
File diff suppressed because it is too large
Load Diff
61
core/providers/gemini/list_models_single_payload_test.go
Normal file
61
core/providers/gemini/list_models_single_payload_test.go
Normal file
@@ -0,0 +1,61 @@
|
||||
package gemini
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type testNoopLogger struct{}
|
||||
|
||||
func (testNoopLogger) Debug(string, ...any) {}
|
||||
func (testNoopLogger) Info(string, ...any) {}
|
||||
func (testNoopLogger) Warn(string, ...any) {}
|
||||
func (testNoopLogger) Error(string, ...any) {}
|
||||
func (testNoopLogger) Fatal(string, ...any) {}
|
||||
func (testNoopLogger) SetLevel(schemas.LogLevel) {}
|
||||
func (testNoopLogger) SetOutputType(schemas.LoggerOutputType) {}
|
||||
func (testNoopLogger) LogHTTPRequest(schemas.LogLevel, string) schemas.LogEventBuilder {
|
||||
return schemas.NoopLogEvent
|
||||
}
|
||||
|
||||
func TestListModelsByKey_ParsesSingleModelPayload(t *testing.T) {
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
http.Error(w, "unexpected method", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
if r.URL.Path != "/models/gemini-2.5-pro" {
|
||||
http.Error(w, "unexpected path", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte(`{"name":"models/gemini-2.5-pro","displayName":"Gemini 2.5 Pro","description":"test","inputTokenLimit":1048576,"outputTokenLimit":8192,"supportedGenerationMethods":["generateContent"]}`))
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
provider := NewGeminiProvider(&schemas.ProviderConfig{
|
||||
NetworkConfig: schemas.NetworkConfig{BaseURL: ts.URL},
|
||||
}, testNoopLogger{})
|
||||
|
||||
ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
|
||||
ctx.SetValue(schemas.BifrostContextKeyURLPath, "/models/gemini-2.5-pro")
|
||||
|
||||
key := schemas.Key{Value: *schemas.NewEnvVar("dummy-key")}
|
||||
// Unfiltered=true bypasses the allowed/alias/blacklist filter pipeline so
|
||||
// this test can focus on the single-model-payload parsing code path in
|
||||
// listModelsByKey (gemini.go:215-220).
|
||||
resp, err := provider.listModelsByKey(ctx, key, &schemas.BifrostListModelsRequest{Provider: schemas.Gemini, Unfiltered: true})
|
||||
require.Nil(t, err)
|
||||
require.NotNil(t, resp)
|
||||
require.Len(t, resp.Data, 1)
|
||||
assert.Equal(t, "gemini/gemini-2.5-pro", resp.Data[0].ID)
|
||||
require.NotNil(t, resp.Data[0].Name)
|
||||
assert.Equal(t, "Gemini 2.5 Pro", *resp.Data[0].Name)
|
||||
}
|
||||
105
core/providers/gemini/models.go
Normal file
105
core/providers/gemini/models.go
Normal file
@@ -0,0 +1,105 @@
|
||||
package gemini
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// toGeminiModelResourceName returns the Gemini resource name ("models/<id>")
// for a model ID.
//
// An ID already starting with "models/" is returned unchanged. Otherwise, when
// the ID contains a "/" followed by at least one character, everything up to
// and including the first "/" is dropped (e.g. "gemini/gemini-2.5-pro" becomes
// "models/gemini-2.5-pro"). In all remaining cases the whole ID is prefixed
// with "models/".
func toGeminiModelResourceName(modelID string) string {
	const prefix = "models/"
	if strings.HasPrefix(modelID, prefix) {
		return modelID
	}
	if _, rest, found := strings.Cut(modelID, "/"); found && rest != "" {
		return prefix + rest
	}
	return prefix + modelID
}
|
||||
|
||||
func (response *GeminiListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, allowedModels schemas.WhiteList, blacklistedModels schemas.BlackList, aliases map[string]string, unfiltered bool) *schemas.BifrostListModelsResponse {
|
||||
if response == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
bifrostResponse := &schemas.BifrostListModelsResponse{
|
||||
Data: make([]schemas.Model, 0, len(response.Models)),
|
||||
}
|
||||
|
||||
pipeline := &providerUtils.ListModelsPipeline{
|
||||
AllowedModels: allowedModels,
|
||||
BlacklistedModels: blacklistedModels,
|
||||
Aliases: aliases,
|
||||
Unfiltered: unfiltered,
|
||||
ProviderKey: providerKey,
|
||||
MatchFns: providerUtils.DefaultMatchFns(),
|
||||
}
|
||||
if pipeline.ShouldEarlyExit() {
|
||||
return bifrostResponse
|
||||
}
|
||||
|
||||
included := make(map[string]bool)
|
||||
|
||||
for _, model := range response.Models {
|
||||
contextLength := model.InputTokenLimit + model.OutputTokenLimit
|
||||
// Gemini returns model names with a "models/" prefix — strip it before filtering
|
||||
// so that allowedModels entries like "gemini-1.5-pro" match correctly.
|
||||
modelName := strings.TrimPrefix(model.Name, "models/")
|
||||
|
||||
for _, result := range pipeline.FilterModel(modelName) {
|
||||
entry := schemas.Model{
|
||||
ID: string(providerKey) + "/" + result.ResolvedID,
|
||||
Name: schemas.Ptr(model.DisplayName),
|
||||
Description: schemas.Ptr(model.Description),
|
||||
ContextLength: schemas.Ptr(int(contextLength)),
|
||||
MaxInputTokens: schemas.Ptr(model.InputTokenLimit),
|
||||
MaxOutputTokens: schemas.Ptr(model.OutputTokenLimit),
|
||||
SupportedMethods: model.SupportedGenerationMethods,
|
||||
}
|
||||
if result.AliasValue != "" {
|
||||
entry.Alias = schemas.Ptr(result.AliasValue)
|
||||
}
|
||||
bifrostResponse.Data = append(bifrostResponse.Data, entry)
|
||||
included[strings.ToLower(result.ResolvedID)] = true
|
||||
}
|
||||
}
|
||||
|
||||
bifrostResponse.Data = append(bifrostResponse.Data,
|
||||
pipeline.BackfillModels(included)...)
|
||||
|
||||
return bifrostResponse
|
||||
}
|
||||
|
||||
func ToGeminiListModelsResponse(resp *schemas.BifrostListModelsResponse) *GeminiListModelsResponse {
|
||||
if resp == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
geminiResponse := &GeminiListModelsResponse{
|
||||
Models: make([]GeminiModel, 0, len(resp.Data)),
|
||||
NextPageToken: resp.NextPageToken,
|
||||
}
|
||||
|
||||
for _, model := range resp.Data {
|
||||
geminiModel := GeminiModel{
|
||||
Name: toGeminiModelResourceName(model.ID),
|
||||
SupportedGenerationMethods: model.SupportedMethods,
|
||||
}
|
||||
if model.Name != nil {
|
||||
geminiModel.DisplayName = *model.Name
|
||||
}
|
||||
if model.Description != nil {
|
||||
geminiModel.Description = *model.Description
|
||||
}
|
||||
if model.MaxInputTokens != nil {
|
||||
geminiModel.InputTokenLimit = *model.MaxInputTokens
|
||||
}
|
||||
if model.MaxOutputTokens != nil {
|
||||
geminiModel.OutputTokenLimit = *model.MaxOutputTokens
|
||||
}
|
||||
|
||||
geminiResponse.Models = append(geminiResponse.Models, geminiModel)
|
||||
}
|
||||
|
||||
return geminiResponse
|
||||
}
|
||||
41
core/providers/gemini/models_test.go
Normal file
41
core/providers/gemini/models_test.go
Normal file
@@ -0,0 +1,41 @@
|
||||
package gemini
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestToGeminiModelResourceName(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
want string
|
||||
}{
|
||||
{name: "already native", input: "models/gemini-2.5-pro", want: "models/gemini-2.5-pro"},
|
||||
{name: "provider prefixed", input: "gemini/gemini-2.5-pro", want: "models/gemini-2.5-pro"},
|
||||
{name: "bare model", input: "gemini-2.5-pro", want: "models/gemini-2.5-pro"},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
assert.Equal(t, tc.want, toGeminiModelResourceName(tc.input))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestToGeminiListModelsResponse_UsesNativeModelResourceName(t *testing.T) {
|
||||
resp := &schemas.BifrostListModelsResponse{
|
||||
Data: []schemas.Model{
|
||||
{ID: "gemini/gemini-2.5-pro"},
|
||||
{ID: "models/gemini-2.5-flash"},
|
||||
},
|
||||
}
|
||||
|
||||
converted := ToGeminiListModelsResponse(resp)
|
||||
if assert.Len(t, converted.Models, 2) {
|
||||
assert.Equal(t, "models/gemini-2.5-pro", converted.Models[0].Name)
|
||||
assert.Equal(t, "models/gemini-2.5-flash", converted.Models[1].Name)
|
||||
}
|
||||
}
|
||||
56
core/providers/gemini/payload_ordering_test.go
Normal file
56
core/providers/gemini/payload_ordering_test.go
Normal file
@@ -0,0 +1,56 @@
|
||||
package gemini
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
|
||||
schemas "github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestPayloadOrdering_GeminiGenerationRequest(t *testing.T) {
|
||||
req := &GeminiGenerationRequest{
|
||||
Model: "gemini-2.0-flash",
|
||||
Contents: []Content{
|
||||
{
|
||||
Parts: []*Part{{Text: "hello"}},
|
||||
Role: "user",
|
||||
},
|
||||
},
|
||||
GenerationConfig: GenerationConfig{
|
||||
Temperature: schemas.Ptr(float64(0.7)),
|
||||
},
|
||||
Tools: []Tool{
|
||||
{
|
||||
FunctionDeclarations: []*FunctionDeclaration{
|
||||
{
|
||||
Name: "get_weather",
|
||||
Description: "Get weather",
|
||||
Parameters: &Schema{
|
||||
Type: "OBJECT",
|
||||
Properties: map[string]*Schema{
|
||||
"location": {Type: "STRING"},
|
||||
},
|
||||
Required: []string{"location"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result, err := providerUtils.MarshalSorted(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
golden := `{"model":"gemini-2.0-flash","contents":[{"parts":[{"text":"hello"}],"role":"user"}],"generationConfig":{"temperature":0.7},"tools":[{"functionDeclarations":[{"description":"Get weather","name":"get_weather","parameters":{"properties":{"location":{"type":"STRING"}},"required":["location"],"type":"OBJECT"}}]}]}`
|
||||
|
||||
assert.Equal(t, golden, string(result), "payload field ordering changed — if intentional, update the golden string")
|
||||
|
||||
// Determinism: 100 iterations must produce identical bytes
|
||||
for i := 0; i < 100; i++ {
|
||||
iter, err := providerUtils.MarshalSorted(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, string(result), string(iter), "non-deterministic marshal output on iteration %d", i)
|
||||
}
|
||||
}
|
||||
3652
core/providers/gemini/responses.go
Normal file
3652
core/providers/gemini/responses.go
Normal file
File diff suppressed because it is too large
Load Diff
200
core/providers/gemini/speech.go
Normal file
200
core/providers/gemini/speech.go
Normal file
@@ -0,0 +1,200 @@
|
||||
package gemini
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/maximhq/bifrost/core/providers/utils"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// ToBifrostSpeechRequest converts a GeminiGenerationRequest to a BifrostSpeechRequest
|
||||
func (request *GeminiGenerationRequest) ToBifrostSpeechRequest(ctx *schemas.BifrostContext) *schemas.BifrostSpeechRequest {
|
||||
provider, model := schemas.ParseModelString(request.Model, utils.CheckAndSetDefaultProvider(ctx, schemas.Gemini))
|
||||
|
||||
bifrostReq := &schemas.BifrostSpeechRequest{
|
||||
Provider: provider,
|
||||
Model: model,
|
||||
}
|
||||
|
||||
// Extract text input from contents
|
||||
var textInput string
|
||||
for _, content := range request.Contents {
|
||||
for _, part := range content.Parts {
|
||||
if part.Text != "" {
|
||||
textInput += part.Text
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bifrostReq.Input = &schemas.SpeechInput{
|
||||
Input: textInput,
|
||||
}
|
||||
|
||||
// Convert generation config to parameters
|
||||
if request.GenerationConfig.SpeechConfig != nil || len(request.GenerationConfig.ResponseModalities) > 0 {
|
||||
bifrostReq.Params = &schemas.SpeechParameters{}
|
||||
|
||||
// Extract voice config from speech config
|
||||
if request.GenerationConfig.SpeechConfig != nil {
|
||||
// Handle single-speaker voice config
|
||||
if request.GenerationConfig.SpeechConfig.VoiceConfig != nil {
|
||||
bifrostReq.Params.VoiceConfig = &schemas.SpeechVoiceInput{}
|
||||
|
||||
if request.GenerationConfig.SpeechConfig.VoiceConfig.PrebuiltVoiceConfig != nil {
|
||||
voiceName := request.GenerationConfig.SpeechConfig.VoiceConfig.PrebuiltVoiceConfig.VoiceName
|
||||
bifrostReq.Params.VoiceConfig.Voice = &voiceName
|
||||
}
|
||||
} else if request.GenerationConfig.SpeechConfig.MultiSpeakerVoiceConfig != nil {
|
||||
// Handle multi-speaker voice config
|
||||
// Convert to Bifrost's MultiVoiceConfig format
|
||||
if len(request.GenerationConfig.SpeechConfig.MultiSpeakerVoiceConfig.SpeakerVoiceConfigs) > 0 {
|
||||
bifrostReq.Params.VoiceConfig = &schemas.SpeechVoiceInput{}
|
||||
multiVoiceConfig := make([]schemas.VoiceConfig, 0, len(request.GenerationConfig.SpeechConfig.MultiSpeakerVoiceConfig.SpeakerVoiceConfigs))
|
||||
|
||||
for _, speakerConfig := range request.GenerationConfig.SpeechConfig.MultiSpeakerVoiceConfig.SpeakerVoiceConfigs {
|
||||
if speakerConfig.VoiceConfig != nil && speakerConfig.VoiceConfig.PrebuiltVoiceConfig != nil {
|
||||
multiVoiceConfig = append(multiVoiceConfig, schemas.VoiceConfig{
|
||||
Speaker: speakerConfig.Speaker,
|
||||
Voice: speakerConfig.VoiceConfig.PrebuiltVoiceConfig.VoiceName,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
bifrostReq.Params.VoiceConfig.MultiVoiceConfig = multiVoiceConfig
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Store response modalities in extra params if needed
|
||||
if len(request.GenerationConfig.ResponseModalities) > 0 {
|
||||
if bifrostReq.Params.ExtraParams == nil {
|
||||
bifrostReq.Params.ExtraParams = make(map[string]interface{})
|
||||
}
|
||||
modalities := make([]string, len(request.GenerationConfig.ResponseModalities))
|
||||
for i, mod := range request.GenerationConfig.ResponseModalities {
|
||||
modalities[i] = string(mod)
|
||||
}
|
||||
bifrostReq.Params.ExtraParams["response_modalities"] = modalities
|
||||
}
|
||||
}
|
||||
|
||||
return bifrostReq
|
||||
}
|
||||
|
||||
// ToGeminiSpeechRequest converts a BifrostSpeechRequest to a GeminiGenerationRequest
|
||||
func ToGeminiSpeechRequest(bifrostReq *schemas.BifrostSpeechRequest) (*GeminiGenerationRequest, error) {
|
||||
if bifrostReq == nil {
|
||||
return nil, fmt.Errorf("bifrostReq is nil")
|
||||
}
|
||||
// Here we confirm if the response_format is wav or empty string
|
||||
// If its anything else, we will return an error
|
||||
if bifrostReq.Params != nil && bifrostReq.Params.ResponseFormat != "" && bifrostReq.Params.ResponseFormat != "wav" {
|
||||
return nil, fmt.Errorf("gemini does not support response_format: %s. Only wav or empty string is supported which defaults to wav", bifrostReq.Params.ResponseFormat)
|
||||
}
|
||||
// Create the base Gemini generation request
|
||||
geminiReq := &GeminiGenerationRequest{
|
||||
Model: bifrostReq.Model,
|
||||
}
|
||||
// Convert parameters to generation config
|
||||
geminiReq.GenerationConfig.ResponseModalities = []Modality{ModalityAudio}
|
||||
// Convert speech input to Gemini format
|
||||
if bifrostReq.Input != nil && bifrostReq.Input.Input != "" {
|
||||
geminiReq.Contents = []Content{
|
||||
{
|
||||
Parts: []*Part{
|
||||
{
|
||||
Text: bifrostReq.Input.Input,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
// Add speech config to generation config if voice config is provided
|
||||
if bifrostReq.Params != nil && bifrostReq.Params.VoiceConfig != nil {
|
||||
// Handle both single voice and multi-voice configurations
|
||||
if bifrostReq.Params.VoiceConfig.Voice != nil || len(bifrostReq.Params.VoiceConfig.MultiVoiceConfig) > 0 {
|
||||
addSpeechConfigToGenerationConfig(&geminiReq.GenerationConfig, bifrostReq.Params.VoiceConfig)
|
||||
}
|
||||
geminiReq.ExtraParams = bifrostReq.Params.ExtraParams
|
||||
}
|
||||
}
|
||||
return geminiReq, nil
|
||||
}
|
||||
|
||||
// ToBifrostSpeechResponse converts a GenerateContentResponse to a BifrostSpeechResponse
|
||||
func (response *GenerateContentResponse) ToBifrostSpeechResponse(ctx context.Context) (*schemas.BifrostSpeechResponse, error) {
|
||||
bifrostResp := &schemas.BifrostSpeechResponse{}
|
||||
|
||||
// Process candidates to extract audio content
|
||||
if len(response.Candidates) > 0 {
|
||||
candidate := response.Candidates[0]
|
||||
if candidate.Content != nil && len(candidate.Content.Parts) > 0 {
|
||||
var audioData []byte
|
||||
// Extract audio data from all parts
|
||||
for _, part := range candidate.Content.Parts {
|
||||
if part.InlineData != nil && len(part.InlineData.Data) > 0 {
|
||||
// Check if this is audio data
|
||||
if strings.HasPrefix(part.InlineData.MIMEType, "audio/") {
|
||||
decodedData, err := decodeBase64StringToBytes(part.InlineData.Data)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode base64 audio data: %v", err)
|
||||
}
|
||||
audioData = append(audioData, decodedData...)
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(audioData) > 0 {
|
||||
responseFormat := ctx.Value(BifrostContextKeyResponseFormat).(string)
|
||||
// Gemini returns PCM audio (s16le, 24000 Hz, mono)
|
||||
// Convert to WAV for standard playable output format
|
||||
if responseFormat == "wav" {
|
||||
wavData, err := utils.ConvertPCMToWAV(audioData, utils.DefaultGeminiPCMConfig())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to convert PCM to WAV: %v", err)
|
||||
}
|
||||
bifrostResp.Audio = wavData
|
||||
} else {
|
||||
bifrostResp.Audio = audioData
|
||||
}
|
||||
}
|
||||
|
||||
// Set usage information
|
||||
if response.UsageMetadata != nil {
|
||||
bifrostResp.Usage = convertGeminiUsageMetadataToSpeechUsage(response.UsageMetadata)
|
||||
}
|
||||
}
|
||||
}
|
||||
return bifrostResp, nil
|
||||
}
|
||||
|
||||
// ToGeminiSpeechResponse converts a BifrostSpeechResponse to Gemini's GenerateContentResponse
|
||||
func ToGeminiSpeechResponse(bifrostResp *schemas.BifrostSpeechResponse) *GenerateContentResponse {
|
||||
if bifrostResp == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
genaiResp := &GenerateContentResponse{}
|
||||
|
||||
candidate := &Candidate{
|
||||
Content: &Content{
|
||||
Parts: []*Part{
|
||||
{
|
||||
InlineData: &Blob{
|
||||
Data: encodeBytesToBase64String(bifrostResp.Audio),
|
||||
MIMEType: utils.DetectAudioMimeType(bifrostResp.Audio),
|
||||
},
|
||||
},
|
||||
},
|
||||
Role: string(RoleModel),
|
||||
},
|
||||
}
|
||||
|
||||
// Set usage metadata if present
|
||||
if bifrostResp.Usage != nil {
|
||||
genaiResp.UsageMetadata = convertBifrostSpeechUsageToGeminiUsageMetadata(bifrostResp.Usage)
|
||||
}
|
||||
|
||||
genaiResp.Candidates = []*Candidate{candidate}
|
||||
return genaiResp
|
||||
}
|
||||
233
core/providers/gemini/transcription.go
Normal file
233
core/providers/gemini/transcription.go
Normal file
@@ -0,0 +1,233 @@
|
||||
package gemini
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/maximhq/bifrost/core/providers/utils"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// ToBifrostTranscriptionRequest converts a GeminiGenerationRequest to a BifrostTranscriptionRequest
|
||||
func (request *GeminiGenerationRequest) ToBifrostTranscriptionRequest(ctx *schemas.BifrostContext) (*schemas.BifrostTranscriptionRequest, error) {
|
||||
provider, model := schemas.ParseModelString(request.Model, utils.CheckAndSetDefaultProvider(ctx, schemas.Gemini))
|
||||
|
||||
bifrostReq := &schemas.BifrostTranscriptionRequest{
|
||||
Provider: provider,
|
||||
Model: model,
|
||||
}
|
||||
|
||||
// Extract audio data and prompt from contents
|
||||
var promptText string
|
||||
var audioData []byte
|
||||
var audioMimeType string
|
||||
|
||||
for _, content := range request.Contents {
|
||||
for _, part := range content.Parts {
|
||||
// Extract text prompt
|
||||
if part.Text != "" {
|
||||
if promptText != "" {
|
||||
promptText += " "
|
||||
}
|
||||
promptText += part.Text
|
||||
}
|
||||
|
||||
// Extract audio data from inline data
|
||||
if part.InlineData != nil && strings.HasPrefix(strings.ToLower(part.InlineData.MIMEType), "audio/") {
|
||||
decodedData, err := decodeBase64StringToBytes(part.InlineData.Data)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode base64 audio data: %v", err)
|
||||
}
|
||||
audioData = append(audioData, decodedData...)
|
||||
if audioMimeType == "" {
|
||||
audioMimeType = part.InlineData.MIMEType
|
||||
}
|
||||
}
|
||||
|
||||
// Extract audio data from file data (would need to be fetched separately in real scenario)
|
||||
// For now, we just note the file URI in extra params
|
||||
if part.FileData != nil && strings.HasPrefix(strings.ToLower(part.FileData.MIMEType), "audio/") {
|
||||
if bifrostReq.Params == nil {
|
||||
bifrostReq.Params = &schemas.TranscriptionParameters{}
|
||||
}
|
||||
if bifrostReq.Params.ExtraParams == nil {
|
||||
bifrostReq.Params.ExtraParams = make(map[string]interface{})
|
||||
}
|
||||
bifrostReq.Params.ExtraParams["file_uri"] = part.FileData.FileURI
|
||||
if audioMimeType == "" {
|
||||
audioMimeType = part.FileData.MIMEType
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Set the audio input
|
||||
bifrostReq.Input = &schemas.TranscriptionInput{
|
||||
File: audioData,
|
||||
}
|
||||
|
||||
// Set parameters
|
||||
if bifrostReq.Params == nil {
|
||||
bifrostReq.Params = &schemas.TranscriptionParameters{}
|
||||
}
|
||||
|
||||
// Set prompt if provided
|
||||
if promptText != "" {
|
||||
bifrostReq.Params.Prompt = &promptText
|
||||
}
|
||||
|
||||
// Handle safety settings from request
|
||||
if len(request.SafetySettings) > 0 {
|
||||
if bifrostReq.Params.ExtraParams == nil {
|
||||
bifrostReq.Params.ExtraParams = make(map[string]interface{})
|
||||
}
|
||||
bifrostReq.Params.ExtraParams["safety_settings"] = request.SafetySettings
|
||||
}
|
||||
|
||||
// Handle cached content
|
||||
if request.CachedContent != "" {
|
||||
if bifrostReq.Params.ExtraParams == nil {
|
||||
bifrostReq.Params.ExtraParams = make(map[string]interface{})
|
||||
}
|
||||
bifrostReq.Params.ExtraParams["cached_content"] = request.CachedContent
|
||||
}
|
||||
|
||||
// Handle labels
|
||||
if len(request.Labels) > 0 {
|
||||
if bifrostReq.Params.ExtraParams == nil {
|
||||
bifrostReq.Params.ExtraParams = make(map[string]interface{})
|
||||
}
|
||||
bifrostReq.Params.ExtraParams["labels"] = request.Labels
|
||||
}
|
||||
|
||||
return bifrostReq, nil
|
||||
}
|
||||
|
||||
func ToGeminiTranscriptionRequest(bifrostReq *schemas.BifrostTranscriptionRequest) *GeminiGenerationRequest {
|
||||
if bifrostReq == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Create the base Gemini generation request
|
||||
geminiReq := &GeminiGenerationRequest{
|
||||
Model: bifrostReq.Model,
|
||||
}
|
||||
|
||||
// Convert parameters to generation config
|
||||
if bifrostReq.Params != nil {
|
||||
geminiReq.ExtraParams = bifrostReq.Params.ExtraParams
|
||||
// Handle extra parameters
|
||||
if bifrostReq.Params.ExtraParams != nil {
|
||||
// Safety settings
|
||||
if safetySettings, ok := schemas.SafeExtractFromMap(bifrostReq.Params.ExtraParams, "safety_settings"); ok {
|
||||
delete(geminiReq.ExtraParams, "safety_settings")
|
||||
if settings, ok := SafeExtractSafetySettings(safetySettings); ok {
|
||||
geminiReq.SafetySettings = settings
|
||||
}
|
||||
}
|
||||
|
||||
// Cached content
|
||||
if cachedContent, ok := schemas.SafeExtractString(bifrostReq.Params.ExtraParams["cached_content"]); ok {
|
||||
delete(geminiReq.ExtraParams, "cached_content")
|
||||
geminiReq.CachedContent = cachedContent
|
||||
}
|
||||
|
||||
// Labels
|
||||
if labels, ok := schemas.SafeExtractFromMap(bifrostReq.Params.ExtraParams, "labels"); ok {
|
||||
if labelMap, ok := schemas.SafeExtractStringMap(labels); ok {
|
||||
delete(geminiReq.ExtraParams, "labels")
|
||||
geminiReq.Labels = labelMap
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Determine the prompt text
|
||||
var prompt string
|
||||
if bifrostReq.Params != nil && bifrostReq.Params.Prompt != nil {
|
||||
prompt = *bifrostReq.Params.Prompt
|
||||
} else {
|
||||
prompt = "Generate a transcript of the speech."
|
||||
}
|
||||
|
||||
// Create parts for the transcription request
|
||||
parts := []*Part{
|
||||
{
|
||||
Text: prompt,
|
||||
},
|
||||
}
|
||||
|
||||
// Add audio file if present
|
||||
if len(bifrostReq.Input.File) > 0 {
|
||||
parts = append(parts, &Part{
|
||||
InlineData: &Blob{
|
||||
MIMEType: utils.DetectAudioMimeType(bifrostReq.Input.File),
|
||||
Data: encodeBytesToBase64String(bifrostReq.Input.File),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
geminiReq.Contents = []Content{
|
||||
{
|
||||
Parts: parts,
|
||||
},
|
||||
}
|
||||
|
||||
return geminiReq
|
||||
}
|
||||
|
||||
// ToBifrostTranscriptionResponse converts a GenerateContentResponse to a BifrostTranscriptionResponse
|
||||
func (response *GenerateContentResponse) ToBifrostTranscriptionResponse() *schemas.BifrostTranscriptionResponse {
|
||||
bifrostResp := &schemas.BifrostTranscriptionResponse{}
|
||||
|
||||
// Process candidates to extract text content
|
||||
if len(response.Candidates) > 0 {
|
||||
candidate := response.Candidates[0]
|
||||
if candidate.Content != nil && len(candidate.Content.Parts) > 0 {
|
||||
var textContent string
|
||||
|
||||
// Extract text content from all parts
|
||||
for _, part := range candidate.Content.Parts {
|
||||
if part.Text != "" {
|
||||
textContent += part.Text
|
||||
}
|
||||
}
|
||||
|
||||
if textContent != "" {
|
||||
bifrostResp.Text = textContent
|
||||
bifrostResp.Task = schemas.Ptr("transcribe")
|
||||
|
||||
// Set usage information with modality details
|
||||
bifrostResp.Usage = convertGeminiUsageMetadataToTranscriptionUsage(response.UsageMetadata)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return bifrostResp
|
||||
}
|
||||
|
||||
// ToGeminiTranscriptionResponse converts a BifrostTranscriptionResponse to Gemini's GenerateContentResponse
|
||||
func ToGeminiTranscriptionResponse(bifrostResp *schemas.BifrostTranscriptionResponse) *GenerateContentResponse {
|
||||
if bifrostResp == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
genaiResp := &GenerateContentResponse{}
|
||||
|
||||
candidate := &Candidate{
|
||||
Content: &Content{
|
||||
Parts: []*Part{
|
||||
{
|
||||
Text: bifrostResp.Text,
|
||||
},
|
||||
},
|
||||
Role: string(RoleModel),
|
||||
},
|
||||
}
|
||||
|
||||
// Set usage metadata from transcription usage with modality details
|
||||
genaiResp.UsageMetadata = convertBifrostTranscriptionUsageToGeminiUsageMetadata(bifrostResp.Usage)
|
||||
|
||||
genaiResp.Candidates = []*Candidate{candidate}
|
||||
return genaiResp
|
||||
}
|
||||
2745
core/providers/gemini/types.go
Normal file
2745
core/providers/gemini/types.go
Normal file
File diff suppressed because it is too large
Load Diff
2523
core/providers/gemini/utils.go
Normal file
2523
core/providers/gemini/utils.go
Normal file
File diff suppressed because it is too large
Load Diff
595
core/providers/gemini/videos.go
Normal file
595
core/providers/gemini/videos.go
Normal file
@@ -0,0 +1,595 @@
|
||||
package gemini
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/bytedance/sonic"
|
||||
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// defaultVideoContentType is the content type assigned to a video output when
// the provider response does not supply a non-blank MIME type.
const defaultVideoContentType = "video/mp4"
|
||||
|
||||
// sizeToAspectRatio maps an OpenAI-style "WIDTHxHEIGHT" size string to a
// Gemini aspect ratio. The two portrait sizes ("720x1280", "1024x1792") map
// to "9:16"; every other value, including unknown sizes, maps to "16:9".
func sizeToAspectRatio(size string) string {
	if size == "720x1280" || size == "1024x1792" {
		return "9:16"
	}
	// Landscape sizes ("1280x720", "1792x1024") and the fallback share a result.
	return "16:9"
}
|
||||
|
||||
func addVideoURLOutput(uri, contentType string) *schemas.VideoOutput {
|
||||
if uri == "" {
|
||||
return nil
|
||||
}
|
||||
if strings.TrimSpace(contentType) == "" {
|
||||
contentType = defaultVideoContentType
|
||||
}
|
||||
return &schemas.VideoOutput{
|
||||
Type: schemas.VideoOutputTypeURL,
|
||||
URL: schemas.Ptr(uri),
|
||||
ContentType: contentType,
|
||||
}
|
||||
}
|
||||
|
||||
func addVideoBase64Output(base64Value, contentType string) *schemas.VideoOutput {
|
||||
if base64Value == "" {
|
||||
return nil
|
||||
}
|
||||
if strings.TrimSpace(contentType) == "" {
|
||||
contentType = defaultVideoContentType
|
||||
}
|
||||
return &schemas.VideoOutput{
|
||||
Type: schemas.VideoOutputTypeBase64,
|
||||
Base64Data: schemas.Ptr(base64Value),
|
||||
ContentType: contentType,
|
||||
}
|
||||
}
|
||||
|
||||
// parseVideoDataURL splits a "data:" URL into its MIME type and payload.
//
// The header is everything between "data:" and the first comma; the MIME type
// is the header up to its first ';' (or the whole header when there is none),
// so it may be empty. The payload is everything after the first comma and is
// returned as-is, without decoding. ok is false when data does not start with
// "data:", contains no comma, or has an empty payload.
func parseVideoDataURL(data string) (mimeType string, base64Payload string, ok bool) {
	const scheme = "data:"
	if !strings.HasPrefix(data, scheme) {
		return "", "", false
	}

	header, payload, hasComma := strings.Cut(data[len(scheme):], ",")
	if !hasComma || payload == "" {
		return "", "", false
	}

	// Drop any parameters such as ";base64" that follow the MIME type.
	if semi := strings.IndexByte(header, ';'); semi >= 0 {
		header = header[:semi]
	}
	return header, payload, true
}
|
||||
|
||||
// ToGeminiVideoGenerationRequest converts a Bifrost video generation request to Gemini REST API format
|
||||
// This creates the request body for POST /models/{model}:predictLongRunning
|
||||
func ToGeminiVideoGenerationRequest(bifrostReq *schemas.BifrostVideoGenerationRequest) (*GeminiVideoGenerationRequest, error) {
|
||||
if bifrostReq == nil || bifrostReq.Input == nil {
|
||||
return nil, fmt.Errorf("bifrost request or input is nil")
|
||||
}
|
||||
|
||||
// Create the instance with prompt
|
||||
instance := &GeminiVideoGenerationInstance{
|
||||
Prompt: bifrostReq.Input.Prompt,
|
||||
}
|
||||
|
||||
// Handle input reference (image for image-to-video)
|
||||
if bifrostReq.Input.InputReference != nil && *bifrostReq.Input.InputReference != "" {
|
||||
// extract mime type and base64 string from input reference
|
||||
sanitizedURL, err := schemas.SanitizeImageURL(*bifrostReq.Input.InputReference)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid input reference: %w", err)
|
||||
}
|
||||
urlInfo := schemas.ExtractURLTypeInfo(sanitizedURL)
|
||||
|
||||
image := &VideoImageData{}
|
||||
|
||||
if urlInfo.DataURLWithoutPrefix != nil {
|
||||
image.BytesBase64Encoded = urlInfo.DataURLWithoutPrefix
|
||||
}
|
||||
image.MimeType = schemas.Ptr("image/png")
|
||||
if urlInfo.MediaType != nil {
|
||||
image.MimeType = urlInfo.MediaType
|
||||
}
|
||||
|
||||
instance.Image = image
|
||||
}
|
||||
|
||||
if bifrostReq.Params != nil && bifrostReq.Params.VideoURI != nil {
|
||||
instance.Video = &VideoGenerationVideoInput{
|
||||
URI: bifrostReq.Params.VideoURI,
|
||||
}
|
||||
}
|
||||
|
||||
req := &GeminiVideoGenerationRequest{
|
||||
Instances: []GeminiVideoGenerationInstance{*instance},
|
||||
}
|
||||
|
||||
// Map parameters if provided
|
||||
if bifrostReq.Params != nil {
|
||||
params := &VideoGenerationParameters{}
|
||||
|
||||
// Extract all video generation parameters from ExtraParams
|
||||
if bifrostReq.Params.NegativePrompt != nil {
|
||||
params.NegativePrompt = bifrostReq.Params.NegativePrompt
|
||||
}
|
||||
|
||||
if bifrostReq.Params.Seconds != nil {
|
||||
seconds, err := strconv.Atoi(*bifrostReq.Params.Seconds)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid seconds value: %w", err)
|
||||
}
|
||||
params.DurationSeconds = &seconds
|
||||
}
|
||||
|
||||
if bifrostReq.Params.Seed != nil {
|
||||
params.Seed = bifrostReq.Params.Seed
|
||||
}
|
||||
|
||||
if bifrostReq.Params.Audio != nil {
|
||||
params.GenerateAudio = bifrostReq.Params.Audio
|
||||
}
|
||||
|
||||
if bifrostReq.Params.ExtraParams != nil {
|
||||
req.ExtraParams = bifrostReq.Params.ExtraParams
|
||||
if aspectRatio, ok := schemas.SafeExtractStringPointer(bifrostReq.Params.ExtraParams["aspectRatio"]); ok {
|
||||
params.AspectRatio = aspectRatio
|
||||
}
|
||||
if resolution, ok := schemas.SafeExtractStringPointer(bifrostReq.Params.ExtraParams["resolution"]); ok {
|
||||
params.Resolution = resolution
|
||||
}
|
||||
|
||||
if sampleCount, ok := schemas.SafeExtractIntPointer(bifrostReq.Params.ExtraParams["sampleCount"]); ok {
|
||||
params.SampleCount = sampleCount
|
||||
}
|
||||
|
||||
if personGeneration, ok := schemas.SafeExtractStringPointer(bifrostReq.Params.ExtraParams["personGeneration"]); ok {
|
||||
params.PersonGeneration = personGeneration
|
||||
}
|
||||
|
||||
if numberOfVideos, ok := schemas.SafeExtractIntPointer(bifrostReq.Params.ExtraParams["numberOfVideos"]); ok {
|
||||
params.NumberOfVideos = numberOfVideos
|
||||
}
|
||||
if storageURI, ok := schemas.SafeExtractStringPointer(bifrostReq.Params.ExtraParams["storageURI"]); ok {
|
||||
params.StorageURI = storageURI
|
||||
}
|
||||
if compressionQuality, ok := schemas.SafeExtractStringPointer(bifrostReq.Params.ExtraParams["compressionQuality"]); ok {
|
||||
params.CompressionQuality = compressionQuality
|
||||
}
|
||||
if enhancePrompt, ok := schemas.SafeExtractBoolPointer(bifrostReq.Params.ExtraParams["enhancePrompt"]); ok {
|
||||
params.EnhancePrompt = enhancePrompt
|
||||
}
|
||||
if resizeMode, ok := schemas.SafeExtractStringPointer(bifrostReq.Params.ExtraParams["resizeMode"]); ok {
|
||||
params.ResizeMode = resizeMode
|
||||
}
|
||||
if referenceImages, ok := bifrostReq.Params.ExtraParams["referenceImages"]; ok {
|
||||
if referenceImages, ok := referenceImages.([]VideoReferenceImage); ok && referenceImages != nil {
|
||||
params.ReferenceImages = referenceImages
|
||||
} else if data, err := providerUtils.MarshalSorted(referenceImages); err == nil {
|
||||
var referenceImages []VideoReferenceImage
|
||||
if sonic.Unmarshal(data, &referenceImages) == nil {
|
||||
params.ReferenceImages = referenceImages
|
||||
}
|
||||
}
|
||||
}
|
||||
if lastFrame, ok := bifrostReq.Params.ExtraParams["lastFrame"]; ok {
|
||||
if lastFrame, ok := lastFrame.(*VideoImageData); ok {
|
||||
params.LastFrame = lastFrame
|
||||
} else if data, err := providerUtils.MarshalSorted(lastFrame); err == nil {
|
||||
var lastFrame VideoImageData
|
||||
if sonic.Unmarshal(data, &lastFrame) == nil {
|
||||
params.LastFrame = &lastFrame
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Convert size to aspect ratio if size is provided and aspect ratio is not already set
|
||||
if params.AspectRatio == nil && bifrostReq.Params.Size != "" {
|
||||
aspectRatio := sizeToAspectRatio(bifrostReq.Params.Size)
|
||||
if aspectRatio != "" {
|
||||
params.AspectRatio = &aspectRatio
|
||||
}
|
||||
}
|
||||
|
||||
req.Parameters = params
|
||||
}
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
// ToBifrostVideoGenerationResponse converts Gemini operation response to Bifrost format
|
||||
func ToBifrostVideoGenerationResponse(operation *GenerateVideosOperation, model string) (*schemas.BifrostVideoGenerationResponse, *schemas.BifrostError) {
|
||||
if operation == nil {
|
||||
return nil, providerUtils.NewBifrostOperationError("operation is nil", nil)
|
||||
}
|
||||
|
||||
response := &schemas.BifrostVideoGenerationResponse{
|
||||
ID: operation.Name,
|
||||
Object: "video",
|
||||
CreatedAt: time.Now().Unix(),
|
||||
}
|
||||
if model != "" {
|
||||
response.Model = model
|
||||
}
|
||||
|
||||
// Set status based on operation state
|
||||
if !operation.Done {
|
||||
response.Status = schemas.VideoStatusInProgress
|
||||
if operation.Metadata != nil {
|
||||
if p := providerUtils.GetJSONField([]byte(operation.Metadata), "progress"); p.Exists() {
|
||||
progress := p.Float()
|
||||
response.Progress = &progress
|
||||
}
|
||||
}
|
||||
} else if operation.Error != nil {
|
||||
response.Status = schemas.VideoStatusFailed
|
||||
code := providerUtils.GetJSONField(operation.Error, "code").String()
|
||||
message := providerUtils.GetJSONField(operation.Error, "message").String()
|
||||
if code == "" {
|
||||
code = "video_generation_failed"
|
||||
}
|
||||
if message == "" {
|
||||
message = string(operation.Error)
|
||||
}
|
||||
response.Error = &schemas.VideoCreateError{
|
||||
Code: code,
|
||||
Message: message,
|
||||
}
|
||||
} else if operation.Response != nil {
|
||||
// Check new response format with content filtering support
|
||||
if genVideoResp := operation.Response.GenerateVideoResponse; genVideoResp != nil {
|
||||
// Check for content filtering
|
||||
if genVideoResp.RAIMediaFilteredCount > 0 {
|
||||
response.Status = schemas.VideoStatusFailed
|
||||
response.ContentFilter = &schemas.ContentFilterInfo{
|
||||
FilteredCount: int(genVideoResp.RAIMediaFilteredCount),
|
||||
Reasons: genVideoResp.RAIMediaFilteredReasons,
|
||||
}
|
||||
errorMsg := "Content filtered by safety policies"
|
||||
if len(genVideoResp.RAIMediaFilteredReasons) > 0 {
|
||||
errorMsg = genVideoResp.RAIMediaFilteredReasons[0]
|
||||
}
|
||||
response.Error = &schemas.VideoCreateError{
|
||||
Code: "content_filtered",
|
||||
Message: errorMsg,
|
||||
}
|
||||
} else {
|
||||
response.Status = schemas.VideoStatusCompleted
|
||||
|
||||
// Collect all generated videos from multiple possible locations.
|
||||
var videos []schemas.VideoOutput
|
||||
|
||||
// Priority 1: GeneratedSamples
|
||||
if len(genVideoResp.GeneratedSamples) > 0 {
|
||||
for _, sample := range genVideoResp.GeneratedSamples {
|
||||
if sample == nil || sample.Video == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if sample.Video.URI != "" {
|
||||
videoOutput := addVideoURLOutput(sample.Video.URI, sample.Video.MIMEType)
|
||||
if videoOutput != nil {
|
||||
videos = append(videos, *videoOutput)
|
||||
}
|
||||
}
|
||||
if len(sample.Video.VideoBytes) > 0 {
|
||||
videoOutput := addVideoBase64Output(
|
||||
base64.StdEncoding.EncodeToString(sample.Video.VideoBytes),
|
||||
sample.Video.MIMEType,
|
||||
)
|
||||
if videoOutput != nil {
|
||||
videos = append(videos, *videoOutput)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(videos) > 0 {
|
||||
response.Videos = videos
|
||||
}
|
||||
}
|
||||
} else if len(operation.Response.GeneratedVideos) > 0 {
|
||||
// Backward compatibility for older response shapes
|
||||
response.Status = schemas.VideoStatusCompleted
|
||||
var videos []schemas.VideoOutput
|
||||
for _, genVideo := range operation.Response.GeneratedVideos {
|
||||
if genVideo == nil || genVideo.Video == nil {
|
||||
continue
|
||||
}
|
||||
if genVideo.Video.URI != "" {
|
||||
videoOutput := addVideoURLOutput(genVideo.Video.URI, genVideo.Video.MIMEType)
|
||||
if videoOutput != nil {
|
||||
videos = append(videos, *videoOutput)
|
||||
}
|
||||
}
|
||||
if len(genVideo.Video.VideoBytes) > 0 {
|
||||
videoOutput := addVideoBase64Output(
|
||||
base64.StdEncoding.EncodeToString(genVideo.Video.VideoBytes),
|
||||
genVideo.Video.MIMEType,
|
||||
)
|
||||
if videoOutput != nil {
|
||||
videos = append(videos, *videoOutput)
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(videos) > 0 {
|
||||
response.Videos = videos
|
||||
}
|
||||
} else if len(operation.Response.Videos) > 0 {
|
||||
response.Status = schemas.VideoStatusCompleted
|
||||
var videos []schemas.VideoOutput
|
||||
for _, video := range operation.Response.Videos {
|
||||
if video.GCSURI != nil && *video.GCSURI != "" {
|
||||
mimeType := defaultVideoContentType
|
||||
if video.MIMEType != nil && *video.MIMEType != "" {
|
||||
mimeType = *video.MIMEType
|
||||
}
|
||||
videoOutput := addVideoURLOutput(*video.GCSURI, mimeType)
|
||||
if videoOutput != nil {
|
||||
videos = append(videos, *videoOutput)
|
||||
}
|
||||
} else if video.BytesBase64Encoded != nil && *video.BytesBase64Encoded != "" {
|
||||
mimeType := defaultVideoContentType
|
||||
if video.MIMEType != nil && *video.MIMEType != "" {
|
||||
mimeType = *video.MIMEType
|
||||
}
|
||||
videoOutput := addVideoBase64Output(*video.BytesBase64Encoded, mimeType)
|
||||
if videoOutput != nil {
|
||||
videos = append(videos, *videoOutput)
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(videos) > 0 {
|
||||
response.Videos = videos
|
||||
}
|
||||
} else {
|
||||
response.Status = schemas.VideoStatusCompleted
|
||||
}
|
||||
} else {
|
||||
response.Status = schemas.VideoStatusCompleted
|
||||
}
|
||||
|
||||
// Try to extract timestamps from metadata
|
||||
if operation.Metadata != nil {
|
||||
if ct := providerUtils.GetJSONField([]byte(operation.Metadata), "createTime"); ct.Exists() {
|
||||
if t, err := time.Parse(time.RFC3339, ct.String()); err == nil {
|
||||
response.CreatedAt = t.Unix()
|
||||
}
|
||||
}
|
||||
if ut := providerUtils.GetJSONField([]byte(operation.Metadata), "updateTime"); ut.Exists() {
|
||||
if t, err := time.Parse(time.RFC3339, ut.String()); err == nil && operation.Done {
|
||||
response.CompletedAt = schemas.Ptr(t.Unix())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return response, nil
|
||||
}
|
||||
|
||||
func (request *GeminiVideoGenerationRequest) ToBifrostVideoGenerationRequest(ctx *schemas.BifrostContext) (*schemas.BifrostVideoGenerationRequest, error) {
|
||||
if request == nil || len(request.Instances) == 0 {
|
||||
return nil, fmt.Errorf("request is nil or has no instances")
|
||||
}
|
||||
|
||||
// Use the first instance for the main input
|
||||
instance := request.Instances[0]
|
||||
|
||||
provider, model := schemas.ParseModelString(request.Model, providerUtils.CheckAndSetDefaultProvider(ctx, schemas.Gemini))
|
||||
|
||||
bifrostReq := &schemas.BifrostVideoGenerationRequest{
|
||||
Provider: provider,
|
||||
Model: model,
|
||||
Input: &schemas.VideoGenerationInput{
|
||||
Prompt: instance.Prompt,
|
||||
},
|
||||
}
|
||||
|
||||
// Handle image input for image-to-video
|
||||
if instance.Image != nil && instance.Image.BytesBase64Encoded != nil && *instance.Image.BytesBase64Encoded != "" {
|
||||
// attach mime type and base64 string to input reference
|
||||
mimeType := "image/png"
|
||||
if instance.Image.MimeType != nil && *instance.Image.MimeType != "" {
|
||||
mimeType = *instance.Image.MimeType
|
||||
}
|
||||
bifrostReq.Input.InputReference = schemas.Ptr(fmt.Sprintf("data:%s;base64,%s", mimeType, *instance.Image.BytesBase64Encoded))
|
||||
}
|
||||
|
||||
// Helper to ensure params are initialized
|
||||
ensureParams := func() {
|
||||
if bifrostReq.Params == nil {
|
||||
bifrostReq.Params = &schemas.VideoGenerationParameters{
|
||||
ExtraParams: make(map[string]any),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Handle reference images
|
||||
if len(instance.ReferenceImages) > 0 {
|
||||
ensureParams()
|
||||
bifrostReq.Params.ExtraParams["referenceImages"] = instance.ReferenceImages
|
||||
}
|
||||
|
||||
// Handle video URI
|
||||
if instance.Video != nil && instance.Video.URI != nil {
|
||||
ensureParams()
|
||||
bifrostReq.Params.VideoURI = instance.Video.URI
|
||||
}
|
||||
|
||||
// Handle last frame
|
||||
if instance.LastFrame != nil {
|
||||
ensureParams()
|
||||
bifrostReq.Params.ExtraParams["lastFrame"] = instance.LastFrame
|
||||
}
|
||||
|
||||
// Map parameters if provided
|
||||
if request.Parameters != nil {
|
||||
ensureParams()
|
||||
params := bifrostReq.Params
|
||||
|
||||
if request.Parameters.NegativePrompt != nil {
|
||||
params.NegativePrompt = request.Parameters.NegativePrompt
|
||||
}
|
||||
if request.Parameters.DurationSeconds != nil {
|
||||
seconds := strconv.Itoa(*request.Parameters.DurationSeconds)
|
||||
params.Seconds = &seconds
|
||||
}
|
||||
if request.Parameters.Seed != nil {
|
||||
params.Seed = request.Parameters.Seed
|
||||
}
|
||||
if request.Parameters.GenerateAudio != nil {
|
||||
params.Audio = request.Parameters.GenerateAudio
|
||||
}
|
||||
if request.Parameters.AspectRatio != nil {
|
||||
params.ExtraParams["aspectRatio"] = *request.Parameters.AspectRatio
|
||||
}
|
||||
if request.Parameters.Resolution != nil {
|
||||
params.ExtraParams["resolution"] = *request.Parameters.Resolution
|
||||
}
|
||||
if request.Parameters.SampleCount != nil {
|
||||
params.ExtraParams["sampleCount"] = *request.Parameters.SampleCount
|
||||
}
|
||||
if request.Parameters.PersonGeneration != nil {
|
||||
params.ExtraParams["personGeneration"] = *request.Parameters.PersonGeneration
|
||||
}
|
||||
if request.Parameters.NumberOfVideos != nil {
|
||||
params.ExtraParams["numberOfVideos"] = *request.Parameters.NumberOfVideos
|
||||
}
|
||||
if request.Parameters.StorageURI != nil {
|
||||
params.ExtraParams["storageURI"] = *request.Parameters.StorageURI
|
||||
}
|
||||
if request.Parameters.CompressionQuality != nil {
|
||||
params.ExtraParams["compressionQuality"] = *request.Parameters.CompressionQuality
|
||||
}
|
||||
if request.Parameters.EnhancePrompt != nil {
|
||||
params.ExtraParams["enhancePrompt"] = *request.Parameters.EnhancePrompt
|
||||
}
|
||||
if request.Parameters.ResizeMode != nil {
|
||||
params.ExtraParams["resizeMode"] = *request.Parameters.ResizeMode
|
||||
}
|
||||
}
|
||||
|
||||
return bifrostReq, nil
|
||||
}
|
||||
|
||||
func ToGeminiVideoGenerationResponse(response *schemas.BifrostVideoGenerationResponse) *GenerateVideosOperation {
|
||||
if response == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
decodedID := response.ID
|
||||
if decoded, err := url.PathUnescape(decodedID); err == nil {
|
||||
decodedID = decoded
|
||||
}
|
||||
|
||||
// if id is in gemini or vertex format, set name in format models/model/operations/operation_id:provider
|
||||
// else make the id in gemini format
|
||||
if !(strings.HasPrefix(decodedID, "models/") && strings.Contains(decodedID, response.Model) && strings.Contains(decodedID, "operations/")) {
|
||||
// url encode model
|
||||
encodedModel := url.PathEscape(response.Model)
|
||||
decodedID = "models/" + encodedModel + "/operations/" + decodedID
|
||||
}
|
||||
operation := &GenerateVideosOperation{
|
||||
Name: decodedID,
|
||||
}
|
||||
|
||||
switch response.Status {
|
||||
case schemas.VideoStatusCompleted:
|
||||
operation.Done = true
|
||||
if len(response.Videos) > 0 {
|
||||
generatedSamples := make([]*GeneratedVideo, 0, len(response.Videos))
|
||||
for _, output := range response.Videos {
|
||||
var video *Video
|
||||
|
||||
switch output.Type {
|
||||
case schemas.VideoOutputTypeURL:
|
||||
if output.URL == nil || *output.URL == "" {
|
||||
continue
|
||||
}
|
||||
video = &Video{
|
||||
URI: *output.URL,
|
||||
}
|
||||
if output.ContentType != "" {
|
||||
video.MIMEType = output.ContentType
|
||||
}
|
||||
case schemas.VideoOutputTypeBase64:
|
||||
if output.Base64Data == nil || *output.Base64Data == "" {
|
||||
continue
|
||||
}
|
||||
base64Payload := *output.Base64Data
|
||||
mimeType := output.ContentType
|
||||
if parsedMimeType, payload, ok := parseVideoDataURL(*output.Base64Data); ok {
|
||||
base64Payload = payload
|
||||
if mimeType == "" {
|
||||
mimeType = parsedMimeType
|
||||
}
|
||||
}
|
||||
decoded, err := base64.StdEncoding.DecodeString(base64Payload)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if mimeType == "" {
|
||||
mimeType = defaultVideoContentType
|
||||
}
|
||||
video = &Video{
|
||||
VideoBytes: decoded,
|
||||
MIMEType: mimeType,
|
||||
}
|
||||
default:
|
||||
continue
|
||||
}
|
||||
|
||||
if video == nil {
|
||||
continue
|
||||
}
|
||||
generatedSamples = append(generatedSamples, &GeneratedVideo{
|
||||
Video: video,
|
||||
})
|
||||
}
|
||||
if len(generatedSamples) > 0 {
|
||||
operation.Response = &GenerateVideosOperationResponse{
|
||||
GenerateVideoResponse: &GenerateVideoResponse{
|
||||
GeneratedSamples: generatedSamples,
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
case schemas.VideoStatusFailed:
|
||||
operation.Done = true
|
||||
// Check if this is a content filtering case
|
||||
if response.ContentFilter != nil && response.ContentFilter.FilteredCount > 0 {
|
||||
operation.Response = &GenerateVideosOperationResponse{
|
||||
GenerateVideoResponse: &GenerateVideoResponse{
|
||||
RAIMediaFilteredCount: int32(response.ContentFilter.FilteredCount),
|
||||
RAIMediaFilteredReasons: response.ContentFilter.Reasons,
|
||||
},
|
||||
}
|
||||
} else if response.Error != nil {
|
||||
errBytes, _ := providerUtils.MarshalSorted(map[string]any{
|
||||
"message": response.Error.Message,
|
||||
"code": response.Error.Code,
|
||||
})
|
||||
operation.Error = json.RawMessage(errBytes)
|
||||
}
|
||||
default:
|
||||
operation.Done = false
|
||||
}
|
||||
|
||||
return operation
|
||||
}
|
||||
402
core/providers/groq/groq.go
Normal file
402
core/providers/groq/groq.go
Normal file
@@ -0,0 +1,402 @@
|
||||
// Package groq implements the Groq provider and its utility functions.
|
||||
package groq
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/maximhq/bifrost/core/providers/openai"
|
||||
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
|
||||
schemas "github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// GroqProvider implements the Provider interface for Groq's API.
|
||||
type GroqProvider struct {
|
||||
logger schemas.Logger // Logger for provider operations
|
||||
client *fasthttp.Client // HTTP client for unary API requests (ReadTimeout bounds overall response)
|
||||
streamingClient *fasthttp.Client // HTTP client for streaming API requests (no ReadTimeout; idle governed by NewIdleTimeoutReader)
|
||||
networkConfig schemas.NetworkConfig // Network configuration including extra headers
|
||||
sendBackRawRequest bool // Whether to include raw request in BifrostResponse
|
||||
sendBackRawResponse bool // Whether to include raw response in BifrostResponse
|
||||
}
|
||||
|
||||
// NewGroqProvider creates a new Groq provider instance.
|
||||
// It initializes the HTTP client with the provided configuration and sets up response pools.
|
||||
// The client is configured with timeouts, concurrency limits, and optional proxy settings.
|
||||
func NewGroqProvider(config *schemas.ProviderConfig, logger schemas.Logger) (*GroqProvider, error) {
|
||||
config.CheckAndSetDefaults()
|
||||
|
||||
requestTimeout := time.Second * time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds)
|
||||
client := &fasthttp.Client{
|
||||
ReadTimeout: requestTimeout,
|
||||
WriteTimeout: requestTimeout,
|
||||
MaxConnsPerHost: config.NetworkConfig.MaxConnsPerHost,
|
||||
MaxIdleConnDuration: 30 * time.Second,
|
||||
MaxConnWaitTimeout: requestTimeout,
|
||||
MaxConnDuration: time.Second * time.Duration(schemas.DefaultMaxConnDurationInSeconds),
|
||||
ConnPoolStrategy: fasthttp.FIFO,
|
||||
}
|
||||
|
||||
// // Pre-warm response pools
|
||||
// for range config.ConcurrencyAndBufferSize.Concurrency {
|
||||
// groqResponsePool.Put(&schemas.BifrostResponse{})
|
||||
// }
|
||||
|
||||
// Configure proxy and retry policy
|
||||
client = providerUtils.ConfigureProxy(client, config.ProxyConfig, logger)
|
||||
client = providerUtils.ConfigureDialer(client)
|
||||
client = providerUtils.ConfigureTLS(client, config.NetworkConfig, logger)
|
||||
streamingClient := providerUtils.BuildStreamingClient(client)
|
||||
// Set default BaseURL if not provided
|
||||
if config.NetworkConfig.BaseURL == "" {
|
||||
config.NetworkConfig.BaseURL = "https://api.groq.com/openai"
|
||||
}
|
||||
config.NetworkConfig.BaseURL = strings.TrimRight(config.NetworkConfig.BaseURL, "/")
|
||||
|
||||
return &GroqProvider{
|
||||
logger: logger,
|
||||
client: client,
|
||||
streamingClient: streamingClient,
|
||||
networkConfig: config.NetworkConfig,
|
||||
sendBackRawRequest: config.SendBackRawRequest,
|
||||
sendBackRawResponse: config.SendBackRawResponse,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetProviderKey returns the provider identifier for Groq.
|
||||
func (provider *GroqProvider) GetProviderKey() schemas.ModelProvider {
|
||||
return schemas.Groq
|
||||
}
|
||||
|
||||
// ListModels performs a list models request to Groq's API.
|
||||
func (provider *GroqProvider) ListModels(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) {
|
||||
return openai.HandleOpenAIListModelsRequest(
|
||||
ctx,
|
||||
provider.client,
|
||||
request,
|
||||
provider.networkConfig.BaseURL+providerUtils.GetPathFromContext(ctx, "/v1/models"),
|
||||
keys,
|
||||
provider.networkConfig.ExtraHeaders,
|
||||
schemas.Groq,
|
||||
providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest),
|
||||
providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse),
|
||||
)
|
||||
}
|
||||
|
||||
// TextCompletion is not supported by the Groq provider.
|
||||
func (provider *GroqProvider) TextCompletion(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (*schemas.BifrostTextCompletionResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError("text completion", "groq")
|
||||
}
|
||||
|
||||
// TextCompletionStream performs a streaming text completion request to Groq's API.
|
||||
// It formats the request, sends it to Groq, and processes the response.
|
||||
// Returns a channel of BifrostStreamChunk objects or an error if the request fails.
|
||||
func (provider *GroqProvider) TextCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostTextCompletionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError("text completion", "groq")
|
||||
}
|
||||
|
||||
// ChatCompletion performs a chat completion request to the Groq API.
|
||||
func (provider *GroqProvider) ChatCompletion(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError) {
|
||||
return openai.HandleOpenAIChatCompletionRequest(
|
||||
ctx,
|
||||
provider.client,
|
||||
provider.networkConfig.BaseURL+providerUtils.GetPathFromContext(ctx, "/v1/chat/completions"),
|
||||
request,
|
||||
key,
|
||||
provider.networkConfig.ExtraHeaders,
|
||||
providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest),
|
||||
providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse),
|
||||
provider.GetProviderKey(),
|
||||
nil,
|
||||
nil,
|
||||
provider.logger,
|
||||
)
|
||||
}
|
||||
|
||||
// ChatCompletionStream performs a streaming chat completion request to the Groq API.
|
||||
// It supports real-time streaming of responses using Server-Sent Events (SSE).
|
||||
// Uses Groq's OpenAI-compatible streaming format.
|
||||
// Returns a channel containing BifrostStreamChunk objects representing the stream or an error if the request fails.
|
||||
func (provider *GroqProvider) ChatCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
|
||||
var authHeader map[string]string
|
||||
if v := key.Value.GetValue(); v != "" {
|
||||
authHeader = map[string]string{"Authorization": "Bearer " + v}
|
||||
}
|
||||
// Use shared OpenAI-compatible streaming logic
|
||||
return openai.HandleOpenAIChatCompletionStreaming(
|
||||
ctx,
|
||||
provider.streamingClient,
|
||||
provider.networkConfig.BaseURL+"/v1/chat/completions",
|
||||
request,
|
||||
authHeader,
|
||||
provider.networkConfig.ExtraHeaders,
|
||||
providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest),
|
||||
providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse),
|
||||
schemas.Groq,
|
||||
postHookRunner,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
provider.logger,
|
||||
postHookSpanFinalizer,
|
||||
)
|
||||
}
|
||||
|
||||
// Responses performs a responses request to the Groq API.
|
||||
func (provider *GroqProvider) Responses(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) {
|
||||
chatResponse, err := provider.ChatCompletion(ctx, key, request.ToChatRequest())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
response := chatResponse.ToBifrostResponsesResponse()
|
||||
|
||||
return response, nil
|
||||
}
|
||||
|
||||
// ResponsesStream performs a streaming responses request to the Groq API.
|
||||
func (provider *GroqProvider) ResponsesStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
|
||||
ctx.SetValue(schemas.BifrostContextKeyIsResponsesToChatCompletionFallback, true)
|
||||
return provider.ChatCompletionStream(
|
||||
ctx,
|
||||
postHookRunner,
|
||||
postHookSpanFinalizer,
|
||||
key,
|
||||
request.ToChatRequest(),
|
||||
)
|
||||
}
|
||||
|
||||
// Embedding is not supported by the Groq provider.
|
||||
func (provider *GroqProvider) Embedding(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostEmbeddingRequest) (*schemas.BifrostEmbeddingResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.EmbeddingRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// Speech handles non-streaming speech synthesis requests.
|
||||
// It formats the request body, makes the API call, and returns the response.
|
||||
// Returns the response and any error that occurred.
|
||||
func (provider *GroqProvider) Speech(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostSpeechRequest) (*schemas.BifrostSpeechResponse, *schemas.BifrostError) {
|
||||
return openai.HandleOpenAISpeechRequest(
|
||||
ctx,
|
||||
provider.client,
|
||||
provider.networkConfig.BaseURL+providerUtils.GetPathFromContext(ctx, "/v1/audio/speech"),
|
||||
request,
|
||||
key,
|
||||
provider.networkConfig.ExtraHeaders,
|
||||
schemas.Groq,
|
||||
providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest),
|
||||
providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse),
|
||||
nil,
|
||||
provider.logger,
|
||||
)
|
||||
}
|
||||
|
||||
// Rerank is not supported by the Groq provider.
|
||||
func (provider *GroqProvider) Rerank(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostRerankRequest) (*schemas.BifrostRerankResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.RerankRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// OCR is not supported by the Groq provider.
|
||||
func (provider *GroqProvider) OCR(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostOCRRequest) (*schemas.BifrostOCRResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.OCRRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// SpeechStream is not supported by the Groq provider.
|
||||
func (provider *GroqProvider) SpeechStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostSpeechRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.SpeechStreamRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// Transcription handles non-streaming transcription requests.
|
||||
// It creates a multipart form, adds fields, makes the API call, and returns the response.
|
||||
// Returns the response and any error that occurred.
|
||||
func (provider *GroqProvider) Transcription(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (*schemas.BifrostTranscriptionResponse, *schemas.BifrostError) {
|
||||
return openai.HandleOpenAITranscriptionRequest(
|
||||
ctx,
|
||||
provider.client,
|
||||
provider.networkConfig.BaseURL+providerUtils.GetPathFromContext(ctx, "/v1/audio/transcriptions"),
|
||||
request,
|
||||
key,
|
||||
provider.networkConfig.ExtraHeaders,
|
||||
schemas.Groq,
|
||||
providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse),
|
||||
nil,
|
||||
provider.logger,
|
||||
)
|
||||
}
|
||||
|
||||
// TranscriptionStream is not supported by the Groq provider.
|
||||
func (provider *GroqProvider) TranscriptionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.TranscriptionStreamRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// ImageGeneration is not supported by the Groq provider.
|
||||
func (provider *GroqProvider) ImageGeneration(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostImageGenerationRequest) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.ImageGenerationRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// ImageGenerationStream is not supported by the Groq provider.
|
||||
func (provider *GroqProvider) ImageGenerationStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostImageGenerationRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.ImageGenerationStreamRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// ImageEdit is not supported by the Groq provider.
|
||||
func (provider *GroqProvider) ImageEdit(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostImageEditRequest) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.ImageEditRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// ImageEditStream is not supported by the Groq provider.
|
||||
func (provider *GroqProvider) ImageEditStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostImageEditRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.ImageEditStreamRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// ImageVariation is not supported by the Groq provider.
|
||||
func (provider *GroqProvider) ImageVariation(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostImageVariationRequest) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.ImageVariationRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// VideoGeneration is not supported by the Groq provider.
|
||||
func (provider *GroqProvider) VideoGeneration(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostVideoGenerationRequest) (*schemas.BifrostVideoGenerationResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.VideoGenerationRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// VideoRetrieve is not supported by the Groq provider.
|
||||
func (provider *GroqProvider) VideoRetrieve(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostVideoRetrieveRequest) (*schemas.BifrostVideoGenerationResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.VideoRetrieveRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// VideoDownload is not supported by the Groq provider.
|
||||
func (provider *GroqProvider) VideoDownload(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostVideoDownloadRequest) (*schemas.BifrostVideoDownloadResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.VideoDownloadRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// VideoDelete is not supported by Groq provider.
|
||||
func (provider *GroqProvider) VideoDelete(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostVideoDeleteRequest) (*schemas.BifrostVideoDeleteResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.VideoDeleteRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// VideoList is not supported by Groq provider.
|
||||
func (provider *GroqProvider) VideoList(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostVideoListRequest) (*schemas.BifrostVideoListResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.VideoListRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// VideoRemix is not supported by Groq provider.
|
||||
func (provider *GroqProvider) VideoRemix(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostVideoRemixRequest) (*schemas.BifrostVideoGenerationResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.VideoRemixRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// BatchCreate is not supported by Groq provider.
|
||||
func (provider *GroqProvider) BatchCreate(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostBatchCreateRequest) (*schemas.BifrostBatchCreateResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchCreateRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// BatchList is not supported by Groq provider.
|
||||
func (provider *GroqProvider) BatchList(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchListRequest) (*schemas.BifrostBatchListResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchListRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// BatchRetrieve is not supported by Groq provider.
|
||||
func (provider *GroqProvider) BatchRetrieve(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchRetrieveRequest) (*schemas.BifrostBatchRetrieveResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchRetrieveRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// BatchCancel is not supported by Groq provider.
|
||||
func (provider *GroqProvider) BatchCancel(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchCancelRequest) (*schemas.BifrostBatchCancelResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchCancelRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// BatchDelete is not supported by Groq provider.
|
||||
func (provider *GroqProvider) BatchDelete(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchDeleteRequest) (*schemas.BifrostBatchDeleteResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchDeleteRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// BatchResults is not supported by Groq provider.
|
||||
func (provider *GroqProvider) BatchResults(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchResultsRequest) (*schemas.BifrostBatchResultsResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchResultsRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// FileUpload is not supported by Groq provider.
|
||||
func (provider *GroqProvider) FileUpload(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostFileUploadRequest) (*schemas.BifrostFileUploadResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.FileUploadRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// FileList is not supported by Groq provider.
|
||||
func (provider *GroqProvider) FileList(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostFileListRequest) (*schemas.BifrostFileListResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.FileListRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// FileRetrieve is not supported by Groq provider.
|
||||
func (provider *GroqProvider) FileRetrieve(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostFileRetrieveRequest) (*schemas.BifrostFileRetrieveResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.FileRetrieveRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// FileDelete is not supported by Groq provider.
|
||||
func (provider *GroqProvider) FileDelete(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostFileDeleteRequest) (*schemas.BifrostFileDeleteResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.FileDeleteRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// FileContent is not supported by Groq provider.
|
||||
func (provider *GroqProvider) FileContent(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostFileContentRequest) (*schemas.BifrostFileContentResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.FileContentRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// CountTokens is not supported by the Groq provider.
|
||||
func (provider *GroqProvider) CountTokens(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostResponsesRequest) (*schemas.BifrostCountTokensResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.CountTokensRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// ContainerCreate is not supported by the Groq provider.
|
||||
func (provider *GroqProvider) ContainerCreate(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostContainerCreateRequest) (*schemas.BifrostContainerCreateResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerCreateRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// ContainerList is not supported by the Groq provider.
|
||||
func (provider *GroqProvider) ContainerList(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostContainerListRequest) (*schemas.BifrostContainerListResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerListRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// ContainerRetrieve is not supported by the Groq provider.
|
||||
func (provider *GroqProvider) ContainerRetrieve(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostContainerRetrieveRequest) (*schemas.BifrostContainerRetrieveResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerRetrieveRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// ContainerDelete is not supported by the Groq provider.
|
||||
func (provider *GroqProvider) ContainerDelete(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostContainerDeleteRequest) (*schemas.BifrostContainerDeleteResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerDeleteRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// ContainerFileCreate is not supported by the Groq provider.
|
||||
func (provider *GroqProvider) ContainerFileCreate(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostContainerFileCreateRequest) (*schemas.BifrostContainerFileCreateResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerFileCreateRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// ContainerFileList is not supported by the Groq provider.
|
||||
func (provider *GroqProvider) ContainerFileList(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostContainerFileListRequest) (*schemas.BifrostContainerFileListResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerFileListRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// ContainerFileRetrieve is not supported by the Groq provider.
|
||||
func (provider *GroqProvider) ContainerFileRetrieve(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostContainerFileRetrieveRequest) (*schemas.BifrostContainerFileRetrieveResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerFileRetrieveRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// ContainerFileContent is not supported by the Groq provider.
|
||||
func (provider *GroqProvider) ContainerFileContent(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostContainerFileContentRequest) (*schemas.BifrostContainerFileContentResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerFileContentRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// ContainerFileDelete is not supported by the Groq provider.
|
||||
func (provider *GroqProvider) ContainerFileDelete(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostContainerFileDeleteRequest) (*schemas.BifrostContainerFileDeleteResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerFileDeleteRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
// Passthrough is not supported by the Groq provider.
|
||||
func (provider *GroqProvider) Passthrough(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostPassthroughRequest) (*schemas.BifrostPassthroughResponse, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.PassthroughRequest, provider.GetProviderKey())
|
||||
}
|
||||
|
||||
func (provider *GroqProvider) PassthroughStream(_ *schemas.BifrostContext, _ schemas.PostHookRunner, _ func(context.Context), _ schemas.Key, _ *schemas.BifrostPassthroughRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
|
||||
return nil, providerUtils.NewUnsupportedOperationError(schemas.PassthroughStreamRequest, provider.GetProviderKey())
|
||||
}
|
||||
68
core/providers/groq/groq_test.go
Normal file
68
core/providers/groq/groq_test.go
Normal file
@@ -0,0 +1,68 @@
|
||||
package groq_test
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/maximhq/bifrost/core/internal/llmtests"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
func TestGroq(t *testing.T) {
|
||||
t.Parallel()
|
||||
if strings.TrimSpace(os.Getenv("GROQ_API_KEY")) == "" {
|
||||
t.Skip("Skipping Groq tests because GROQ_API_KEY is not set")
|
||||
}
|
||||
|
||||
client, ctx, cancel, err := llmtests.SetupTest()
|
||||
if err != nil {
|
||||
t.Fatalf("Error initializing test setup: %v", err)
|
||||
}
|
||||
defer cancel()
|
||||
defer client.Shutdown()
|
||||
|
||||
testConfig := llmtests.ComprehensiveTestConfig{
|
||||
Provider: schemas.Groq,
|
||||
ChatModel: "llama-3.3-70b-versatile",
|
||||
Fallbacks: []schemas.Fallback{
|
||||
{Provider: schemas.Groq, Model: "openai/gpt-oss-120b"},
|
||||
},
|
||||
TextModel: "llama-3.3-70b-versatile",
|
||||
TextCompletionFallbacks: []schemas.Fallback{
|
||||
{Provider: schemas.Groq, Model: "openai/gpt-oss-20b"},
|
||||
},
|
||||
EmbeddingModel: "", // Groq doesn't support embedding
|
||||
ReasoningModel: "openai/gpt-oss-120b",
|
||||
TranscriptionModel: "whisper-large-v3",
|
||||
SpeechSynthesisModel: "canopylabs/orpheus-v1-english",
|
||||
Scenarios: llmtests.TestScenarios{
|
||||
TextCompletion: false,
|
||||
TextCompletionStream: false,
|
||||
SimpleChat: true,
|
||||
CompletionStream: true,
|
||||
MultiTurnConversation: true,
|
||||
ToolCalls: true,
|
||||
ToolCallsStreaming: true,
|
||||
MultipleToolCalls: true,
|
||||
MultipleToolCallsStreaming: true,
|
||||
End2EndToolCalling: true,
|
||||
AutomaticFunctionCall: true,
|
||||
ImageURL: false,
|
||||
ImageBase64: false,
|
||||
MultipleImages: false,
|
||||
FileBase64: false, // Not supported
|
||||
FileURL: false, // Not supported
|
||||
CompleteEnd2End: true,
|
||||
Embedding: false,
|
||||
ListModels: true,
|
||||
Reasoning: true,
|
||||
Transcription: true,
|
||||
SpeechSynthesis: true,
|
||||
},
|
||||
}
|
||||
t.Run("GroqTests", func(t *testing.T) {
|
||||
llmtests.RunAllComprehensiveTests(t, client, ctx, testConfig)
|
||||
})
|
||||
}
|
||||
144
core/providers/huggingface/chat.go
Normal file
144
core/providers/huggingface/chat.go
Normal file
@@ -0,0 +1,144 @@
|
||||
// Package huggingface provides a HuggingFace chat provider.
|
||||
package huggingface
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/bytedance/sonic"
|
||||
|
||||
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
|
||||
schemas "github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// sanitizeMessagesForHuggingFace removes unsupported ChatAssistantMessage fields
|
||||
// from chat messages. HuggingFace's OpenAI-compatible API doesn't support fields
|
||||
// like reasoning_details, reasoning, annotations, audio, and refusal.
|
||||
// Only ToolCalls is preserved from ChatAssistantMessage.
|
||||
func sanitizeMessagesForHuggingFace(messages []schemas.ChatMessage) []schemas.ChatMessage {
|
||||
sanitized := make([]schemas.ChatMessage, len(messages))
|
||||
for i, msg := range messages {
|
||||
sanitized[i] = schemas.ChatMessage{
|
||||
Name: msg.Name,
|
||||
Role: msg.Role,
|
||||
Content: msg.Content,
|
||||
ChatToolMessage: msg.ChatToolMessage,
|
||||
}
|
||||
// Only preserve ToolCalls from ChatAssistantMessage
|
||||
if msg.ChatAssistantMessage != nil && len(msg.ChatAssistantMessage.ToolCalls) > 0 {
|
||||
sanitized[i].ChatAssistantMessage = &schemas.ChatAssistantMessage{
|
||||
ToolCalls: msg.ChatAssistantMessage.ToolCalls,
|
||||
}
|
||||
}
|
||||
}
|
||||
return sanitized
|
||||
}
|
||||
|
||||
func ToHuggingFaceChatCompletionRequest(bifrostReq *schemas.BifrostChatRequest) (*HuggingFaceChatRequest, error) {
|
||||
if bifrostReq == nil || bifrostReq.Input == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Create the HuggingFace request
|
||||
// Sanitize messages to remove unsupported fields like reasoning_details
|
||||
hfReq := &HuggingFaceChatRequest{
|
||||
Messages: sanitizeMessagesForHuggingFace(bifrostReq.Input),
|
||||
Model: bifrostReq.Model,
|
||||
}
|
||||
|
||||
// Map parameters if present
|
||||
if bifrostReq.Params != nil {
|
||||
params := bifrostReq.Params
|
||||
|
||||
if params.FrequencyPenalty != nil {
|
||||
hfReq.FrequencyPenalty = params.FrequencyPenalty
|
||||
}
|
||||
if params.LogProbs != nil {
|
||||
hfReq.Logprobs = params.LogProbs
|
||||
}
|
||||
if params.MaxCompletionTokens != nil {
|
||||
hfReq.MaxTokens = params.MaxCompletionTokens
|
||||
}
|
||||
if params.PresencePenalty != nil {
|
||||
hfReq.PresencePenalty = params.PresencePenalty
|
||||
}
|
||||
if params.Seed != nil {
|
||||
hfReq.Seed = params.Seed
|
||||
}
|
||||
if len(params.Stop) > 0 {
|
||||
hfReq.Stop = params.Stop
|
||||
}
|
||||
if params.Temperature != nil {
|
||||
hfReq.Temperature = params.Temperature
|
||||
}
|
||||
if params.TopLogProbs != nil {
|
||||
hfReq.TopLogprobs = params.TopLogProbs
|
||||
}
|
||||
if params.TopP != nil {
|
||||
hfReq.TopP = params.TopP
|
||||
}
|
||||
|
||||
// Handle response format (direct type assertion to avoid marshal→unmarshal round-trip)
|
||||
if params.ResponseFormat != nil {
|
||||
var hfRF *HuggingFaceResponseFormat
|
||||
if rfMap, ok := (*params.ResponseFormat).(map[string]interface{}); ok {
|
||||
hfRF = &HuggingFaceResponseFormat{}
|
||||
if t, ok := rfMap["type"].(string); ok {
|
||||
hfRF.Type = t
|
||||
}
|
||||
if jsVal, ok := rfMap["json_schema"]; ok {
|
||||
jsBytes, err := providerUtils.MarshalSorted(jsVal)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal json_schema: %w", err)
|
||||
}
|
||||
var hfSchema HuggingFaceJSONSchema
|
||||
if err := sonic.Unmarshal(jsBytes, &hfSchema); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal json_schema: %w", err)
|
||||
}
|
||||
hfRF.JSONSchema = &hfSchema
|
||||
}
|
||||
} else if converted, err := schemas.ConvertViaJSON[HuggingFaceResponseFormat](*params.ResponseFormat); err == nil {
|
||||
hfRF = &converted
|
||||
}
|
||||
hfReq.ResponseFormat = hfRF
|
||||
}
|
||||
|
||||
// Handle stream options
|
||||
if params.StreamOptions != nil {
|
||||
hfReq.StreamOptions = &schemas.ChatStreamOptions{
|
||||
IncludeUsage: params.StreamOptions.IncludeUsage,
|
||||
}
|
||||
}
|
||||
|
||||
hfReq.Tools = params.Tools
|
||||
|
||||
// Handle tool choice
|
||||
if params.ToolChoice != nil {
|
||||
hfToolChoice := &HuggingFaceToolChoice{}
|
||||
if params.ToolChoice.ChatToolChoiceStr != nil {
|
||||
switch *params.ToolChoice.ChatToolChoiceStr {
|
||||
case "auto":
|
||||
auto := EnumStringTypeAuto
|
||||
hfToolChoice.EnumValue = &auto
|
||||
case "none":
|
||||
none := EnumStringTypeNone
|
||||
hfToolChoice.EnumValue = &none
|
||||
case "required":
|
||||
required := EnumStringTypeRequired
|
||||
hfToolChoice.EnumValue = &required
|
||||
}
|
||||
} else if params.ToolChoice.ChatToolChoiceStruct != nil {
|
||||
if params.ToolChoice.ChatToolChoiceStruct.Type == schemas.ChatToolChoiceTypeFunction && params.ToolChoice.ChatToolChoiceStruct.Function != nil {
|
||||
hfToolChoice.Function = &schemas.ChatToolChoiceFunction{
|
||||
Name: params.ToolChoice.ChatToolChoiceStruct.Function.Name,
|
||||
}
|
||||
}
|
||||
}
|
||||
if hfToolChoice.EnumValue != nil || hfToolChoice.Function != nil {
|
||||
hfReq.ToolChoice = hfToolChoice
|
||||
}
|
||||
}
|
||||
hfReq.ExtraParams = bifrostReq.Params.ExtraParams
|
||||
}
|
||||
|
||||
return hfReq, nil
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user