first commit

This commit is contained in:
Beyhan Oğur
2026-04-26 21:52:23 +03:00
commit 880f412e2c
2662 changed files with 866266 additions and 0 deletions

View File

@@ -0,0 +1,722 @@
package cohere
import (
"fmt"
"time"
"github.com/maximhq/bifrost/core/providers/anthropic"
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
"github.com/maximhq/bifrost/core/schemas"
)
// ToCohereChatCompletionRequest converts a Bifrost request to Cohere v2 format
func ToCohereChatCompletionRequest(bifrostReq *schemas.BifrostChatRequest) (*CohereChatRequest, error) {
if bifrostReq == nil || bifrostReq.Input == nil {
return nil, fmt.Errorf("bifrost request is nil")
}
messages := bifrostReq.Input
cohereReq := &CohereChatRequest{
Model: bifrostReq.Model,
}
// Convert messages to Cohere v2 format
var cohereMessages []CohereMessage
for _, msg := range messages {
cohereMsg := CohereMessage{
Role: string(msg.Role),
}
// Convert content
if msg.Content != nil && msg.Content.ContentStr != nil {
cohereMsg.Content = NewStringContent(*msg.Content.ContentStr)
} else if msg.Content != nil && msg.Content.ContentBlocks != nil {
var contentBlocks []CohereContentBlock
for _, block := range msg.Content.ContentBlocks {
if block.Text != nil {
contentBlocks = append(contentBlocks, CohereContentBlock{
Type: CohereContentBlockTypeText,
Text: block.Text,
})
} else if block.ImageURLStruct != nil {
contentBlocks = append(contentBlocks, CohereContentBlock{
Type: CohereContentBlockTypeImage,
ImageURL: &CohereImageURL{
URL: block.ImageURLStruct.URL,
},
})
}
}
if len(contentBlocks) > 0 {
cohereMsg.Content = NewBlocksContent(contentBlocks)
}
}
// Convert tool calls for assistant messages
if msg.ChatAssistantMessage != nil && msg.ChatAssistantMessage.ToolCalls != nil {
var toolCalls []CohereToolCall
for _, toolCall := range msg.ChatAssistantMessage.ToolCalls {
// Safely extract function name and arguments
var functionName *string
var functionArguments string
if toolCall.Function.Name != nil {
functionName = toolCall.Function.Name
} else {
// Use empty string if Name is nil
functionName = schemas.Ptr("")
}
// Arguments is a string, not a pointer, so it's safe to access directly
// Default to "{}" if empty to ensure the field is always present.
if toolCall.Function.Arguments == "" {
functionArguments = "{}"
} else {
functionArguments = toolCall.Function.Arguments
}
cohereToolCall := CohereToolCall{
ID: toolCall.ID,
Type: "function",
Function: &CohereFunction{
Name: functionName,
Arguments: functionArguments,
},
}
toolCalls = append(toolCalls, cohereToolCall)
}
cohereMsg.ToolCalls = toolCalls
}
// Convert tool messages
if msg.ChatToolMessage != nil && msg.ChatToolMessage.ToolCallID != nil {
cohereMsg.ToolCallID = msg.ChatToolMessage.ToolCallID
}
cohereMessages = append(cohereMessages, cohereMsg)
}
cohereReq.Messages = cohereMessages
// Convert parameters
if bifrostReq.Params != nil {
cohereReq.MaxTokens = bifrostReq.Params.MaxCompletionTokens
cohereReq.Temperature = bifrostReq.Params.Temperature
cohereReq.P = bifrostReq.Params.TopP
cohereReq.StopSequences = bifrostReq.Params.Stop
cohereReq.FrequencyPenalty = bifrostReq.Params.FrequencyPenalty
cohereReq.PresencePenalty = bifrostReq.Params.PresencePenalty
// Convert reasoning
if bifrostReq.Params.Reasoning != nil {
if bifrostReq.Params.Reasoning.MaxTokens != nil {
thinking := &CohereThinking{
Type: ThinkingTypeEnabled,
}
if *bifrostReq.Params.Reasoning.MaxTokens == -1 {
// cohere does not support dynamic reasoning budget like gemini
// setting it to minimum reasoning budget
thinking.TokenBudget = schemas.Ptr(anthropic.MinimumReasoningMaxTokens)
} else {
thinking.TokenBudget = bifrostReq.Params.Reasoning.MaxTokens
}
cohereReq.Thinking = thinking
} else if bifrostReq.Params.Reasoning.Effort != nil {
if *bifrostReq.Params.Reasoning.Effort != "none" {
maxCompletionTokens := providerUtils.GetMaxOutputTokensOrDefault(bifrostReq.Model, DefaultCompletionMaxTokens)
if bifrostReq.Params.MaxCompletionTokens != nil {
maxCompletionTokens = *bifrostReq.Params.MaxCompletionTokens
}
budgetTokens, err := providerUtils.GetBudgetTokensFromReasoningEffort(*bifrostReq.Params.Reasoning.Effort, MinimumReasoningMaxTokens, maxCompletionTokens)
if err != nil {
return nil, err
}
cohereReq.Thinking = &CohereThinking{
Type: ThinkingTypeEnabled,
TokenBudget: schemas.Ptr(budgetTokens), // Max tokens for reasoning
}
} else {
cohereReq.Thinking = &CohereThinking{
Type: ThinkingTypeDisabled,
}
}
}
}
// Convert response format
if bifrostReq.Params.ResponseFormat != nil {
cohereReq.ResponseFormat = convertResponseFormatToCohere(bifrostReq.Params.ResponseFormat)
}
// Convert extra params
if bifrostReq.Params.ExtraParams != nil {
// Handle thinking parameter
cohereReq.ExtraParams = bifrostReq.Params.ExtraParams
if thinkingParam, ok := schemas.SafeExtractFromMap(bifrostReq.Params.ExtraParams, "thinking"); ok {
if thinkingMap, ok := thinkingParam.(map[string]interface{}); ok {
thinking := &CohereThinking{}
if typeStr, ok := schemas.SafeExtractString(thinkingMap["type"]); ok {
delete(thinkingMap, "type")
thinking.Type = CohereThinkingType(typeStr)
}
if tokenBudget, ok := schemas.SafeExtractIntPointer(thinkingMap["token_budget"]); ok {
delete(thinkingMap, "token_budget")
thinking.TokenBudget = tokenBudget
}
cohereReq.Thinking = thinking
cohereReq.ExtraParams["thinking"] = thinkingMap
}
}
// Handle other Cohere-specific extra params
if safetyMode, ok := schemas.SafeExtractStringPointer(bifrostReq.Params.ExtraParams["safety_mode"]); ok {
delete(cohereReq.ExtraParams, "safety_mode")
cohereReq.SafetyMode = safetyMode
}
if logProbs, ok := schemas.SafeExtractBoolPointer(bifrostReq.Params.ExtraParams["log_probs"]); ok {
delete(cohereReq.ExtraParams, "log_probs")
cohereReq.LogProbs = logProbs
}
if strictToolChoice, ok := schemas.SafeExtractBoolPointer(bifrostReq.Params.ExtraParams["strict_tool_choice"]); ok {
delete(cohereReq.ExtraParams, "strict_tool_choice")
cohereReq.StrictToolChoice = strictToolChoice
}
}
// Convert tools to Cohere-specific format (without "strict" field)
if bifrostReq.Params.Tools != nil {
cohereTools := make([]CohereChatRequestTool, len(bifrostReq.Params.Tools))
for i, tool := range bifrostReq.Params.Tools {
cohereTools[i] = CohereChatRequestTool{
Type: string(tool.Type),
}
if tool.Function != nil {
cohereTools[i].Function = CohereChatRequestFunction{
Name: tool.Function.Name,
Description: tool.Function.Description,
Parameters: tool.Function.Parameters, // Convert to map
// Note: No "strict" field - Cohere doesn't support it
}
}
}
cohereReq.Tools = cohereTools
}
// Convert tool choice
if bifrostReq.Params.ToolChoice != nil {
toolChoice := bifrostReq.Params.ToolChoice
if toolChoice.ChatToolChoiceStr != nil {
switch schemas.ChatToolChoiceType(*toolChoice.ChatToolChoiceStr) {
case schemas.ChatToolChoiceTypeNone:
toolChoice := ToolChoiceNone
cohereReq.ToolChoice = &toolChoice
default:
toolChoice := ToolChoiceRequired
cohereReq.ToolChoice = &toolChoice
}
} else if toolChoice.ChatToolChoiceStruct != nil {
switch toolChoice.ChatToolChoiceStruct.Type {
case schemas.ChatToolChoiceTypeFunction:
toolChoice := ToolChoiceRequired
cohereReq.ToolChoice = &toolChoice
default:
toolChoice := ToolChoiceAuto
cohereReq.ToolChoice = &toolChoice
}
}
}
}
return cohereReq, nil
}
// ToBifrostChatRequest converts a Cohere v2 chat request to Bifrost format
func (req *CohereChatRequest) ToBifrostChatRequest(ctx *schemas.BifrostContext) *schemas.BifrostChatRequest {
if req == nil {
return nil
}
provider, model := schemas.ParseModelString(req.Model, providerUtils.CheckAndSetDefaultProvider(ctx, schemas.Cohere))
bifrostReq := &schemas.BifrostChatRequest{
Provider: provider,
Model: model,
Params: &schemas.ChatParameters{},
}
// Convert messages
if req.Messages != nil {
bifrostMessages := make([]schemas.ChatMessage, len(req.Messages))
for i, message := range req.Messages {
bifrostMessages[i] = *message.ToBifrostChatMessage()
}
bifrostReq.Input = bifrostMessages
}
// Convert parameters
if req.MaxTokens != nil {
bifrostReq.Params.MaxCompletionTokens = req.MaxTokens
}
if req.Temperature != nil {
bifrostReq.Params.Temperature = req.Temperature
}
if req.P != nil {
bifrostReq.Params.TopP = req.P
}
if req.StopSequences != nil {
bifrostReq.Params.Stop = req.StopSequences
}
if req.FrequencyPenalty != nil {
bifrostReq.Params.FrequencyPenalty = req.FrequencyPenalty
}
if req.PresencePenalty != nil {
bifrostReq.Params.PresencePenalty = req.PresencePenalty
}
// Convert reasoning
if req.Thinking != nil {
if req.Thinking.Type == ThinkingTypeDisabled {
bifrostReq.Params.Reasoning = &schemas.ChatReasoning{
Effort: schemas.Ptr("none"),
}
} else {
bifrostReq.Params.Reasoning = &schemas.ChatReasoning{
Effort: schemas.Ptr("auto"),
}
if req.Thinking.TokenBudget != nil {
bifrostReq.Params.Reasoning.MaxTokens = req.Thinking.TokenBudget
}
}
}
if req.ResponseFormat != nil {
bifrostReq.Params.ResponseFormat = convertCohereResponseFormatToBifrost(req.ResponseFormat)
}
// Convert tools
if req.Tools != nil {
bifrostTools := make([]schemas.ChatTool, len(req.Tools))
for i, tool := range req.Tools {
bifrostTools[i] = schemas.ChatTool{
Type: schemas.ChatToolTypeFunction,
Function: &schemas.ChatToolFunction{
Name: tool.Function.Name,
Description: tool.Function.Description,
Parameters: convertInterfaceToToolFunctionParameters(tool.Function.Parameters),
},
}
}
bifrostReq.Params.Tools = bifrostTools
}
// Convert tool choice
if req.ToolChoice != nil {
switch *req.ToolChoice {
case ToolChoiceNone:
bifrostReq.Params.ToolChoice = &schemas.ChatToolChoice{
ChatToolChoiceStr: schemas.Ptr(string(schemas.ChatToolChoiceTypeNone)),
}
case ToolChoiceRequired:
bifrostReq.Params.ToolChoice = &schemas.ChatToolChoice{
ChatToolChoiceStr: schemas.Ptr(string(schemas.ChatToolChoiceTypeRequired)),
}
case ToolChoiceAuto:
bifrostReq.Params.ToolChoice = &schemas.ChatToolChoice{
ChatToolChoiceStr: schemas.Ptr(string(schemas.ChatToolChoiceTypeAny)),
}
}
}
// Convert extra params
extraParams := make(map[string]interface{})
if req.SafetyMode != nil {
extraParams["safety_mode"] = *req.SafetyMode
}
if req.LogProbs != nil {
extraParams["log_probs"] = *req.LogProbs
}
if req.StrictToolChoice != nil {
extraParams["strict_tool_choice"] = *req.StrictToolChoice
}
if req.Thinking != nil {
thinkingMap := map[string]interface{}{
"type": string(req.Thinking.Type),
}
if req.Thinking.TokenBudget != nil {
thinkingMap["token_budget"] = *req.Thinking.TokenBudget
}
extraParams["thinking"] = thinkingMap
}
if len(extraParams) > 0 {
bifrostReq.Params.ExtraParams = extraParams
}
return bifrostReq
}
// ToBifrostChatResponse converts a Cohere v2 response to Bifrost format
func (response *CohereChatResponse) ToBifrostChatResponse(model string) *schemas.BifrostChatResponse {
if response == nil {
return nil
}
bifrostResponse := &schemas.BifrostChatResponse{
ID: response.ID,
Model: model,
Object: "chat.completion",
Choices: []schemas.BifrostResponseChoice{
{
Index: 0,
ChatNonStreamResponseChoice: &schemas.ChatNonStreamResponseChoice{},
},
},
Created: int(time.Now().Unix()),
ExtraFields: schemas.BifrostResponseExtraFields{
},
}
// Convert messages
if response.Message != nil {
bifrostMessage := response.Message.ToBifrostChatMessage()
bifrostResponse.Choices[0].ChatNonStreamResponseChoice.Message = bifrostMessage
}
// Convert finish reason
if response.FinishReason != nil {
finishReason := ConvertCohereFinishReasonToBifrost(*response.FinishReason)
bifrostResponse.Choices[0].FinishReason = schemas.Ptr(finishReason)
}
// Convert usage information
if response.Usage != nil {
usage := &schemas.BifrostLLMUsage{}
if response.Usage.Tokens != nil {
if response.Usage.Tokens.InputTokens != nil {
usage.PromptTokens = *response.Usage.Tokens.InputTokens
}
if response.Usage.Tokens.OutputTokens != nil {
usage.CompletionTokens = *response.Usage.Tokens.OutputTokens
}
if response.Usage.CachedTokens != nil {
usage.PromptTokensDetails = &schemas.ChatPromptTokensDetails{
CachedReadTokens: *response.Usage.CachedTokens,
}
}
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
}
bifrostResponse.Usage = usage
}
return bifrostResponse
}
func (chunk *CohereStreamEvent) ToBifrostChatCompletionStream() (*schemas.BifrostChatResponse, *schemas.BifrostError, bool) {
switch chunk.Type {
case StreamEventMessageStart:
if chunk.Delta != nil && chunk.Delta.Message != nil && chunk.Delta.Message.Role != nil {
// Create streaming response for this delta
streamResponse := &schemas.BifrostChatResponse{
Object: "chat.completion.chunk",
Choices: []schemas.BifrostResponseChoice{
{
Index: 0,
ChatStreamResponseChoice: &schemas.ChatStreamResponseChoice{
Delta: &schemas.ChatStreamResponseChoiceDelta{
Role: chunk.Delta.Message.Role,
},
},
},
},
}
return streamResponse, nil, false
}
case StreamEventContentDelta:
if chunk.Delta != nil &&
chunk.Delta.Message != nil &&
chunk.Delta.Message.Content != nil &&
chunk.Delta.Message.Content.CohereStreamContentObject != nil {
if chunk.Delta.Message.Content.CohereStreamContentObject.Text != nil {
// Try to cast content to CohereStreamContent
streamResponse := &schemas.BifrostChatResponse{
Object: "chat.completion.chunk",
Choices: []schemas.BifrostResponseChoice{
{
Index: 0,
ChatStreamResponseChoice: &schemas.ChatStreamResponseChoice{
Delta: &schemas.ChatStreamResponseChoiceDelta{
Content: chunk.Delta.Message.Content.CohereStreamContentObject.Text,
},
},
},
},
}
return streamResponse, nil, false
} else if chunk.Delta.Message.Content.CohereStreamContentObject.Thinking != nil {
thinkingText := *chunk.Delta.Message.Content.CohereStreamContentObject.Thinking
streamResponse := &schemas.BifrostChatResponse{
Object: "chat.completion.chunk",
Choices: []schemas.BifrostResponseChoice{
{
Index: 0,
ChatStreamResponseChoice: &schemas.ChatStreamResponseChoice{
Delta: &schemas.ChatStreamResponseChoiceDelta{
Reasoning: schemas.Ptr(thinkingText),
ReasoningDetails: []schemas.ChatReasoningDetails{
{
Index: 0,
Type: schemas.BifrostReasoningDetailsTypeText,
Text: schemas.Ptr(thinkingText),
},
},
},
},
},
},
}
return streamResponse, nil, false
}
}
case StreamEventToolPlanDelta:
if chunk.Delta != nil && chunk.Delta.Message != nil && chunk.Delta.Message.ToolPlan != nil {
streamResponse := &schemas.BifrostChatResponse{
Object: "chat.completion.chunk",
Choices: []schemas.BifrostResponseChoice{
{
Index: 0,
ChatStreamResponseChoice: &schemas.ChatStreamResponseChoice{
Delta: &schemas.ChatStreamResponseChoiceDelta{
Reasoning: chunk.Delta.Message.ToolPlan,
},
},
},
},
}
return streamResponse, nil, false
}
case StreamEventContentStart:
// Content start event - just continue, actual content comes in content-delta
return nil, nil, false
case StreamEventToolCallStart, StreamEventToolCallDelta:
if chunk.Delta != nil && chunk.Delta.Message != nil && chunk.Delta.Message.ToolCalls != nil && chunk.Delta.Message.ToolCalls.CohereToolCallObject != nil {
// Handle single tool call object (tool-call-start/delta events)
cohereToolCall := chunk.Delta.Message.ToolCalls.CohereToolCallObject
toolCall := schemas.ChatAssistantMessageToolCall{}
if chunk.Index != nil {
toolCall.Index = uint16(*chunk.Index)
}
if cohereToolCall.ID != nil {
toolCall.ID = cohereToolCall.ID
}
if cohereToolCall.Function != nil {
if cohereToolCall.Function.Name != nil {
toolCall.Function.Name = cohereToolCall.Function.Name
}
toolCall.Function.Arguments = cohereToolCall.Function.Arguments
}
streamResponse := &schemas.BifrostChatResponse{
Object: "chat.completion.chunk",
Choices: []schemas.BifrostResponseChoice{
{
Index: 0,
ChatStreamResponseChoice: &schemas.ChatStreamResponseChoice{
Delta: &schemas.ChatStreamResponseChoiceDelta{
ToolCalls: []schemas.ChatAssistantMessageToolCall{toolCall},
},
},
},
},
}
return streamResponse, nil, false
}
case StreamEventToolCallEnd:
return nil, nil, false
case StreamEventContentEnd:
return nil, nil, false
case StreamEventMessageEnd:
if chunk.Delta != nil {
var finishReason string
usage := &schemas.BifrostLLMUsage{}
// Set finish reason
if chunk.Delta.FinishReason != nil {
finishReason = ConvertCohereFinishReasonToBifrost(*chunk.Delta.FinishReason)
}
// Set usage information
if chunk.Delta.Usage != nil {
if chunk.Delta.Usage.Tokens != nil {
if chunk.Delta.Usage.Tokens.InputTokens != nil {
usage.PromptTokens = *chunk.Delta.Usage.Tokens.InputTokens
}
if chunk.Delta.Usage.Tokens.OutputTokens != nil {
usage.CompletionTokens = *chunk.Delta.Usage.Tokens.OutputTokens
}
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
}
}
streamResponse := &schemas.BifrostChatResponse{
Object: "chat.completion.chunk",
Choices: []schemas.BifrostResponseChoice{
{
Index: 0,
FinishReason: &finishReason,
ChatStreamResponseChoice: &schemas.ChatStreamResponseChoice{
Delta: &schemas.ChatStreamResponseChoiceDelta{},
},
},
},
Usage: usage,
}
return streamResponse, nil, true
}
return nil, nil, false
}
return nil, nil, false
}
func (cm *CohereMessage) ToBifrostChatMessage() *schemas.ChatMessage {
if cm == nil {
return nil
}
var content *string
var contentBlocks []schemas.ChatContentBlock
var toolCalls []schemas.ChatAssistantMessageToolCall
var reasoningDetails []schemas.ChatReasoningDetails
var reasoningText string
// Convert message content
if cm.Content != nil {
if cm.Content.IsString() ||
(cm.Content.IsBlocks() &&
len(cm.Content.GetBlocks()) == 1 &&
cm.Content.GetBlocks()[0].Type == CohereContentBlockTypeText) {
if cm.Content.IsString() {
content = cm.Content.GetString()
} else {
content = cm.Content.GetBlocks()[0].Text
}
} else if cm.Content.IsBlocks() {
for _, block := range cm.Content.GetBlocks() {
if block.Type == CohereContentBlockTypeText && block.Text != nil {
contentBlocks = append(contentBlocks, schemas.ChatContentBlock{
Type: schemas.ChatContentBlockTypeText,
Text: block.Text,
})
} else if block.Type == CohereContentBlockTypeImage && block.ImageURL != nil {
contentBlocks = append(contentBlocks, schemas.ChatContentBlock{
Type: schemas.ChatContentBlockTypeImage,
ImageURLStruct: &schemas.ChatInputImage{
URL: block.ImageURL.URL,
},
})
} else if block.Type == CohereContentBlockTypeThinking && block.Thinking != nil {
reasoningDetails = append(reasoningDetails, schemas.ChatReasoningDetails{
Index: len(reasoningDetails),
Type: schemas.BifrostReasoningDetailsTypeText,
Text: block.Thinking,
})
if len(reasoningText) > 0 {
reasoningText += "\n"
}
reasoningText += *block.Thinking
}
}
}
}
if len(contentBlocks) == 1 && contentBlocks[0].Type == schemas.ChatContentBlockTypeText {
content = contentBlocks[0].Text
contentBlocks = nil
}
// Create the message content
messageContent := &schemas.ChatMessageContent{
ContentStr: content,
ContentBlocks: contentBlocks,
}
// Convert tool calls
if cm.ToolCalls != nil {
for _, toolCall := range cm.ToolCalls {
// Check if Function is nil to avoid nil pointer dereference
if toolCall.Function == nil {
// Skip this tool call if Function is nil
continue
}
// Safely extract function name and arguments
var functionName *string
var functionArguments string
if toolCall.Function.Name != nil {
functionName = toolCall.Function.Name
} else {
// Use empty string if Name is nil
functionName = schemas.Ptr("")
}
// Arguments is a string, not a pointer, so it's safe to access directly
functionArguments = toolCall.Function.Arguments
bifrostToolCall := schemas.ChatAssistantMessageToolCall{
Index: uint16(len(toolCalls)),
ID: toolCall.ID,
Function: schemas.ChatAssistantMessageToolCallFunction{
Name: functionName,
Arguments: functionArguments,
},
}
toolCalls = append(toolCalls, bifrostToolCall)
}
}
// Create assistant message if we have tool calls
var assistantMessage *schemas.ChatAssistantMessage
if len(toolCalls) > 0 {
assistantMessage = &schemas.ChatAssistantMessage{
ToolCalls: toolCalls,
}
}
if len(reasoningDetails) > 0 {
if assistantMessage == nil {
assistantMessage = &schemas.ChatAssistantMessage{}
}
assistantMessage.ReasoningDetails = reasoningDetails
assistantMessage.Reasoning = schemas.Ptr(reasoningText)
}
bifrostMessage := &schemas.ChatMessage{
Role: schemas.ChatMessageRole(cm.Role),
Content: messageContent,
ChatAssistantMessage: assistantMessage,
}
if cm.Role == "tool" {
bifrostMessage.ChatToolMessage = &schemas.ChatToolMessage{
ToolCallID: cm.ToolCallID,
}
}
return bifrostMessage
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,62 @@
package cohere_test
import (
"os"
"strings"
"testing"
"github.com/maximhq/bifrost/core/internal/llmtests"
"github.com/maximhq/bifrost/core/schemas"
)
func TestCohere(t *testing.T) {
t.Parallel()
if strings.TrimSpace(os.Getenv("COHERE_API_KEY")) == "" {
t.Skip("Skipping Cohere tests because COHERE_API_KEY is not set")
}
client, ctx, cancel, err := llmtests.SetupTest()
if err != nil {
t.Fatalf("Error initializing test setup: %v", err)
}
defer cancel()
defer client.Shutdown()
testConfig := llmtests.ComprehensiveTestConfig{
Provider: schemas.Cohere,
ChatModel: "command-a-03-2025",
VisionModel: "command-a-vision-07-2025", // Cohere's latest vision model
TextModel: "", // Cohere focuses on chat
EmbeddingModel: "embed-v4.0",
RerankModel: "rerank-v3.5",
ReasoningModel: "command-a-reasoning-08-2025",
Scenarios: llmtests.TestScenarios{
TextCompletion: false, // Not typical for Cohere
SimpleChat: true,
CompletionStream: true,
MultiTurnConversation: true,
ToolCalls: true,
ToolCallsStreaming: true,
MultipleToolCalls: true,
MultipleToolCallsStreaming: true,
End2EndToolCalling: true,
AutomaticFunctionCall: true, // May not support automatic
ImageURL: false, // Supported by c4ai-aya-vision-8b model
ImageBase64: true, // Supported by c4ai-aya-vision-8b model
MultipleImages: false, // Supported by c4ai-aya-vision-8b model
FileBase64: false, // Not supported
FileURL: false, // Not supported
CompleteEnd2End: false,
Embedding: true,
Rerank: true,
Reasoning: true,
ListModels: true,
CountTokens: true,
},
}
t.Run("CohereTests", func(t *testing.T) {
llmtests.RunAllComprehensiveTests(t, client, ctx, testConfig)
})
}

View File

@@ -0,0 +1,125 @@
package cohere
import (
"fmt"
"strings"
"unicode/utf8"
"github.com/maximhq/bifrost/core/providers/utils"
"github.com/maximhq/bifrost/core/schemas"
)
// ToBifrostResponsesRequest converts a Cohere count tokens request to Bifrost format.
func (req *CohereCountTokensRequest) ToBifrostResponsesRequest(ctx *schemas.BifrostContext) *schemas.BifrostResponsesRequest {
if req == nil {
return nil
}
provider, model := schemas.ParseModelString(req.Model, utils.CheckAndSetDefaultProvider(ctx, schemas.Cohere))
userRole := schemas.ResponsesInputMessageRoleUser
return &schemas.BifrostResponsesRequest{
Provider: provider,
Model: model,
Input: []schemas.ResponsesMessage{
{
Role: &userRole,
Content: &schemas.ResponsesMessageContent{
ContentStr: &req.Text,
},
},
},
}
}
// ToCohereCountTokensRequest converts a Bifrost count tokens request to Cohere's tokenize payload.
func ToCohereCountTokensRequest(bifrostReq *schemas.BifrostResponsesRequest) (*CohereCountTokensRequest, error) {
if bifrostReq == nil {
return nil, nil
}
if bifrostReq.Input == nil {
return nil, fmt.Errorf("count tokens input is not provided")
}
text := buildCohereCountTokensText(bifrostReq.Input)
trimmed := strings.TrimSpace(text)
if trimmed == "" {
return nil, fmt.Errorf("count tokens text is empty after conversion")
}
runeCount := utf8.RuneCountInString(trimmed)
if runeCount < cohereTokenizeMinTextLength || runeCount > cohereTokenizeMaxTextLength {
return nil, fmt.Errorf("count tokens text length must be between %d and %d characters", cohereTokenizeMinTextLength, cohereTokenizeMaxTextLength)
}
cohereReq := &CohereCountTokensRequest{
Model: bifrostReq.Model,
Text: trimmed,
}
if bifrostReq.Params != nil {
cohereReq.ExtraParams = bifrostReq.Params.ExtraParams
}
return cohereReq, nil
}
// ToBifrostCountTokensResponse converts a Cohere tokenize response to Bifrost format.
func (resp *CohereCountTokensResponse) ToBifrostCountTokensResponse(model string) *schemas.BifrostCountTokensResponse {
if resp == nil {
return nil
}
inputTokens := len(resp.Tokens)
if inputTokens == 0 && len(resp.TokenStrings) > 0 {
inputTokens = len(resp.TokenStrings)
}
totalTokens := inputTokens
return &schemas.BifrostCountTokensResponse{
Model: model,
InputTokens: inputTokens,
TotalTokens: &totalTokens,
TokenStrings: resp.TokenStrings,
Tokens: resp.Tokens,
Object: "response.input_tokens",
}
}
// buildCohereCountTokensText flattens Responses messages into a plain text payload for tokenization.
func buildCohereCountTokensText(messages []schemas.ResponsesMessage) string {
var parts []string
for _, msg := range messages {
var contentParts []string
if msg.Content != nil {
if msg.Content.ContentStr != nil {
contentParts = append(contentParts, *msg.Content.ContentStr)
}
for _, block := range msg.Content.ContentBlocks {
if block.Text != nil {
contentParts = append(contentParts, *block.Text)
}
if block.ResponsesOutputMessageContentRefusal != nil && block.ResponsesOutputMessageContentRefusal.Refusal != "" {
contentParts = append(contentParts, block.ResponsesOutputMessageContentRefusal.Refusal)
}
}
}
if msg.ResponsesReasoning != nil {
for _, summary := range msg.ResponsesReasoning.Summary {
if summary.Text != "" {
contentParts = append(contentParts, summary.Text)
}
}
}
if len(contentParts) == 0 {
continue
}
parts = append(parts, strings.Join(contentParts, "\n"))
}
return strings.TrimSpace(strings.Join(parts, "\n"))
}

View File

@@ -0,0 +1,190 @@
package cohere
import (
"github.com/maximhq/bifrost/core/providers/utils"
"github.com/maximhq/bifrost/core/schemas"
)
// ToCohereEmbeddingRequest converts a Bifrost embedding request to Cohere format
func ToCohereEmbeddingRequest(bifrostReq *schemas.BifrostEmbeddingRequest) *CohereEmbeddingRequest {
if bifrostReq == nil || bifrostReq.Input == nil || (bifrostReq.Input.Text == nil && bifrostReq.Input.Texts == nil) {
return nil
}
embeddingInput := bifrostReq.Input
cohereReq := &CohereEmbeddingRequest{
Model: bifrostReq.Model,
}
texts := []string{}
if embeddingInput.Text != nil {
texts = append(texts, *embeddingInput.Text)
} else {
texts = embeddingInput.Texts
}
// Convert texts from Bifrost format
if len(texts) > 0 {
cohereReq.Texts = texts
}
// Set default input type if not specified in extra params
cohereReq.InputType = "search_document" // Default value
if bifrostReq.Params != nil {
cohereReq.OutputDimension = bifrostReq.Params.Dimensions
cohereReq.ExtraParams = bifrostReq.Params.ExtraParams
if bifrostReq.Params.ExtraParams != nil {
if maxTokens, ok := schemas.SafeExtractIntPointer(bifrostReq.Params.ExtraParams["max_tokens"]); ok {
delete(cohereReq.ExtraParams, "max_tokens")
cohereReq.MaxTokens = maxTokens
}
}
}
// Handle extra params
if bifrostReq.Params != nil && bifrostReq.Params.ExtraParams != nil {
// Input type
if inputType, ok := schemas.SafeExtractString(bifrostReq.Params.ExtraParams["input_type"]); ok {
delete(cohereReq.ExtraParams, "input_type")
cohereReq.InputType = inputType
}
// Embedding types
if embeddingTypes, ok := schemas.SafeExtractStringSlice(bifrostReq.Params.ExtraParams["embedding_types"]); ok {
if len(embeddingTypes) > 0 {
delete(cohereReq.ExtraParams, "embedding_types")
cohereReq.EmbeddingTypes = embeddingTypes
}
}
// Truncate
if truncate, ok := schemas.SafeExtractStringPointer(bifrostReq.Params.ExtraParams["truncate"]); ok {
delete(cohereReq.ExtraParams, "truncate")
cohereReq.Truncate = truncate
}
}
return cohereReq
}
// ToBifrostEmbeddingRequest converts a Cohere embedding request to Bifrost format
func (req *CohereEmbeddingRequest) ToBifrostEmbeddingRequest(ctx *schemas.BifrostContext) *schemas.BifrostEmbeddingRequest {
if req == nil {
return nil
}
provider, model := schemas.ParseModelString(req.Model, utils.CheckAndSetDefaultProvider(ctx, schemas.Cohere))
bifrostReq := &schemas.BifrostEmbeddingRequest{
Provider: provider,
Model: model,
Input: &schemas.EmbeddingInput{},
Params: &schemas.EmbeddingParameters{},
}
// Convert texts
if len(req.Texts) > 0 {
if len(req.Texts) == 1 {
bifrostReq.Input.Text = &req.Texts[0]
} else {
bifrostReq.Input.Texts = req.Texts
}
}
// Convert parameters
if req.OutputDimension != nil {
bifrostReq.Params.Dimensions = req.OutputDimension
}
// Convert extra params
extraParams := make(map[string]interface{})
if req.InputType != "" {
extraParams["input_type"] = req.InputType
}
if req.EmbeddingTypes != nil {
extraParams["embedding_types"] = req.EmbeddingTypes
}
if req.Truncate != nil {
extraParams["truncate"] = *req.Truncate
}
if req.MaxTokens != nil {
extraParams["max_tokens"] = *req.MaxTokens
}
if len(extraParams) > 0 {
bifrostReq.Params.ExtraParams = extraParams
}
return bifrostReq
}
// ToBifrostEmbeddingResponse converts a Cohere embedding response to Bifrost format
func (response *CohereEmbeddingResponse) ToBifrostEmbeddingResponse() *schemas.BifrostEmbeddingResponse {
if response == nil {
return nil
}
bifrostResponse := &schemas.BifrostEmbeddingResponse{
Object: "list",
}
// Convert embeddings data
if response.Embeddings != nil {
var bifrostEmbeddings []schemas.EmbeddingData
// Handle different embedding types - prioritize float embeddings
if response.Embeddings.Float != nil {
for i, embedding := range response.Embeddings.Float {
bifrostEmbedding := schemas.EmbeddingData{
Object: "embedding",
Index: i,
Embedding: schemas.EmbeddingStruct{
EmbeddingArray: embedding,
},
}
bifrostEmbeddings = append(bifrostEmbeddings, bifrostEmbedding)
}
} else if response.Embeddings.Base64 != nil {
// Handle base64 embeddings as strings
for i, embedding := range response.Embeddings.Base64 {
bifrostEmbedding := schemas.EmbeddingData{
Object: "embedding",
Index: i,
Embedding: schemas.EmbeddingStruct{
EmbeddingStr: &embedding,
},
}
bifrostEmbeddings = append(bifrostEmbeddings, bifrostEmbedding)
}
}
// Note: Int8, Uint8, Binary, Ubinary types would need special handling
// depending on how Bifrost wants to represent them
bifrostResponse.Data = bifrostEmbeddings
}
// Convert usage information
if response.Meta != nil {
if response.Meta.Tokens != nil {
bifrostResponse.Usage = &schemas.BifrostLLMUsage{}
if response.Meta.Tokens.InputTokens != nil {
bifrostResponse.Usage.PromptTokens = int(*response.Meta.Tokens.InputTokens)
}
if response.Meta.Tokens.OutputTokens != nil {
bifrostResponse.Usage.CompletionTokens = int(*response.Meta.Tokens.OutputTokens)
}
bifrostResponse.Usage.TotalTokens = bifrostResponse.Usage.PromptTokens + bifrostResponse.Usage.CompletionTokens
} else if response.Meta.BilledUnits != nil {
bifrostResponse.Usage = &schemas.BifrostLLMUsage{}
if response.Meta.BilledUnits.InputTokens != nil {
bifrostResponse.Usage.PromptTokens = int(*response.Meta.BilledUnits.InputTokens)
}
if response.Meta.BilledUnits.OutputTokens != nil {
bifrostResponse.Usage.CompletionTokens = int(*response.Meta.BilledUnits.OutputTokens)
}
bifrostResponse.Usage.TotalTokens = bifrostResponse.Usage.PromptTokens + bifrostResponse.Usage.CompletionTokens
}
}
return bifrostResponse
}

View File

@@ -0,0 +1,90 @@
package cohere
import (
"context"
"testing"
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
"github.com/maximhq/bifrost/core/schemas"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestToCohereEmbeddingRequest(t *testing.T) {
t.Run("returns nil for missing input", func(t *testing.T) {
assert.Nil(t, ToCohereEmbeddingRequest(nil))
assert.Nil(t, ToCohereEmbeddingRequest(&schemas.BifrostEmbeddingRequest{}))
assert.Nil(t, ToCohereEmbeddingRequest(&schemas.BifrostEmbeddingRequest{
Input: &schemas.EmbeddingInput{},
}))
})
t.Run("single text keeps model in direct cohere body", func(t *testing.T) {
text := "hello"
truncate := "END"
dimensions := 1024
maxTokens := 256
bifrostReq := &schemas.BifrostEmbeddingRequest{
Model: "embed-v4.0",
Input: &schemas.EmbeddingInput{Text: &text},
Params: &schemas.EmbeddingParameters{
Dimensions: &dimensions,
ExtraParams: map[string]interface{}{
"input_type": "classification",
"embedding_types": []string{"float", "int8"},
"truncate": truncate,
"max_tokens": maxTokens,
"priority": "high",
},
},
}
req := ToCohereEmbeddingRequest(bifrostReq)
require.NotNil(t, req)
assert.Equal(t, "embed-v4.0", req.Model)
assert.Equal(t, "classification", req.InputType)
assert.Equal(t, []string{"hello"}, req.Texts)
assert.Equal(t, []string{"float", "int8"}, req.EmbeddingTypes)
assert.Equal(t, &dimensions, req.OutputDimension)
assert.Equal(t, &maxTokens, req.MaxTokens)
require.NotNil(t, req.Truncate)
assert.Equal(t, truncate, *req.Truncate)
assert.Equal(t, map[string]interface{}{"priority": "high"}, req.ExtraParams)
})
t.Run("multiple texts use default input type", func(t *testing.T) {
req := ToCohereEmbeddingRequest(&schemas.BifrostEmbeddingRequest{
Model: "embed-english-v3.0",
Input: &schemas.EmbeddingInput{Texts: []string{"hello", "world"}},
})
require.NotNil(t, req)
assert.Equal(t, "embed-english-v3.0", req.Model)
assert.Equal(t, "search_document", req.InputType)
assert.Equal(t, []string{"hello", "world"}, req.Texts)
assert.Nil(t, req.ExtraParams)
})
}
func TestToCohereEmbeddingRequestBodyIncludesModelForDirectCohere(t *testing.T) {
text := "hello"
bifrostReq := &schemas.BifrostEmbeddingRequest{
Model: "embed-v4.0",
Input: &schemas.EmbeddingInput{Text: &text},
}
wireBody, bifrostErr := providerUtils.CheckContextAndGetRequestBody(
context.Background(),
bifrostReq,
func() (providerUtils.RequestBodyWithExtraParams, error) {
return ToCohereEmbeddingRequest(bifrostReq), nil
},
schemas.Cohere,
)
require.Nil(t, bifrostErr)
assert.JSONEq(t, `{
"model": "embed-v4.0",
"input_type": "search_document",
"texts": ["hello"]
}`, string(wireBody))
}

View File

@@ -0,0 +1,21 @@
package cohere
import (
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
"github.com/maximhq/bifrost/core/schemas"
"github.com/valyala/fasthttp"
)
func parseCohereError(resp *fasthttp.Response) *schemas.BifrostError {
var errorResp CohereError
bifrostErr := providerUtils.HandleProviderAPIError(resp, &errorResp)
bifrostErr.Type = &errorResp.Type
if bifrostErr.Error == nil {
bifrostErr.Error = &schemas.ErrorField{}
}
bifrostErr.Error.Message = errorResp.Message
if errorResp.Code != nil {
bifrostErr.Error.Code = errorResp.Code
}
return bifrostErr
}

View File

@@ -0,0 +1,92 @@
package cohere
import (
"encoding/json"
"strings"
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
"github.com/maximhq/bifrost/core/schemas"
)
// CohereRerankRequest represents a Cohere rerank API request.
type CohereRerankRequest struct {
Model string `json:"model"`
Query string `json:"query"`
Documents []string `json:"documents"`
TopN *int `json:"top_n,omitempty"`
MaxTokensPerDoc *int `json:"max_tokens_per_doc,omitempty"`
Priority *int `json:"priority,omitempty"`
ExtraParams map[string]interface{} `json:"-"`
}
// GetExtraParams returns extra parameters for the rerank request.
func (r *CohereRerankRequest) GetExtraParams() map[string]interface{} {
return r.ExtraParams
}
// CohereRerankResult represents a single result from Cohere rerank.
type CohereRerankResult struct {
Index int `json:"index"`
RelevanceScore float64 `json:"relevance_score"`
Document json.RawMessage `json:"document,omitempty"`
}
// CohereRerankResponse represents a Cohere rerank API response.
type CohereRerankResponse struct {
ID string `json:"id"`
Results []CohereRerankResult `json:"results"`
Meta *CohereRerankMeta `json:"meta,omitempty"`
}
// CohereRerankMeta represents metadata in Cohere rerank response.
type CohereRerankMeta struct {
APIVersion *CohereEmbeddingAPIVersion `json:"api_version,omitempty"`
BilledUnits *CohereBilledUnits `json:"billed_units,omitempty"`
Tokens *CohereTokenUsage `json:"tokens,omitempty"`
}
func (response *CohereListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, allowedModels schemas.WhiteList, blacklistedModels schemas.BlackList, aliases map[string]string, unfiltered bool) *schemas.BifrostListModelsResponse {
if response == nil {
return nil
}
bifrostResponse := &schemas.BifrostListModelsResponse{
Data: make([]schemas.Model, 0, len(response.Models)),
}
pipeline := &providerUtils.ListModelsPipeline{
AllowedModels: allowedModels,
BlacklistedModels: blacklistedModels,
Aliases: aliases,
Unfiltered: unfiltered,
ProviderKey: providerKey,
MatchFns: providerUtils.DefaultMatchFns(),
}
if pipeline.ShouldEarlyExit() {
return bifrostResponse
}
included := make(map[string]bool)
for _, model := range response.Models {
// Cohere uses model.Name as the model identifier
for _, result := range pipeline.FilterModel(model.Name) {
entry := schemas.Model{
ID: string(providerKey) + "/" + result.ResolvedID,
Name: schemas.Ptr(model.Name),
ContextLength: schemas.Ptr(int(model.ContextLength)),
SupportedMethods: model.Endpoints,
}
if result.AliasValue != "" {
entry.Alias = schemas.Ptr(result.AliasValue)
}
bifrostResponse.Data = append(bifrostResponse.Data, entry)
included[strings.ToLower(result.ResolvedID)] = true
}
}
bifrostResponse.Data = append(bifrostResponse.Data,
pipeline.BackfillModels(included)...)
return bifrostResponse
}

View File

@@ -0,0 +1,209 @@
package cohere
import (
"sort"
"github.com/bytedance/sonic"
"github.com/maximhq/bifrost/core/providers/utils"
"github.com/maximhq/bifrost/core/schemas"
"gopkg.in/yaml.v3"
)
// ToCohereRerankRequest converts a Bifrost rerank request to Cohere format
func ToCohereRerankRequest(bifrostReq *schemas.BifrostRerankRequest) *CohereRerankRequest {
if bifrostReq == nil {
return nil
}
cohereReq := &CohereRerankRequest{
Model: bifrostReq.Model,
Query: bifrostReq.Query,
}
// Cohere v2 expects documents as a list of strings.
documents := make([]string, len(bifrostReq.Documents))
for i, doc := range bifrostReq.Documents {
documents[i] = formatCohereRerankDocument(doc)
}
cohereReq.Documents = documents
if bifrostReq.Params != nil {
cohereReq.TopN = bifrostReq.Params.TopN
cohereReq.MaxTokensPerDoc = bifrostReq.Params.MaxTokensPerDoc
cohereReq.Priority = bifrostReq.Params.Priority
cohereReq.ExtraParams = bifrostReq.Params.ExtraParams
}
return cohereReq
}
// ToBifrostRerankRequest converts a Cohere rerank request to Bifrost format
func (req *CohereRerankRequest) ToBifrostRerankRequest(ctx *schemas.BifrostContext) *schemas.BifrostRerankRequest {
if req == nil {
return nil
}
provider, model := schemas.ParseModelString(req.Model, utils.CheckAndSetDefaultProvider(ctx, schemas.Cohere))
bifrostReq := &schemas.BifrostRerankRequest{
Provider: provider,
Model: model,
Query: req.Query,
Params: &schemas.RerankParameters{},
}
// Convert documents
for _, doc := range req.Documents {
bifrostReq.Documents = append(bifrostReq.Documents, schemas.RerankDocument{
Text: doc,
})
}
if req.TopN != nil {
bifrostReq.Params.TopN = req.TopN
}
if req.MaxTokensPerDoc != nil {
bifrostReq.Params.MaxTokensPerDoc = req.MaxTokensPerDoc
}
if req.Priority != nil {
bifrostReq.Params.Priority = req.Priority
}
if req.ExtraParams != nil {
bifrostReq.Params.ExtraParams = req.ExtraParams
}
return bifrostReq
}
// ToBifrostRerankResponse converts a Cohere rerank response to Bifrost format.
func (response *CohereRerankResponse) ToBifrostRerankResponse(documents []schemas.RerankDocument, returnDocuments bool) *schemas.BifrostRerankResponse {
if response == nil {
return nil
}
bifrostResponse := &schemas.BifrostRerankResponse{
ID: response.ID,
}
// Convert results
for _, result := range response.Results {
rerankResult := schemas.RerankResult{
Index: result.Index,
RelevanceScore: result.RelevanceScore,
}
// Convert document if present
if len(result.Document) > 0 {
var docMap map[string]interface{}
if err := sonic.Unmarshal(result.Document, &docMap); err == nil {
doc := &schemas.RerankDocument{}
populated := false
if text, ok := docMap["text"].(string); ok {
doc.Text = text
populated = true
}
if id, ok := docMap["id"].(string); ok {
doc.ID = &id
populated = true
}
// Collect metadata: unwrap "metadata"/"meta" keys to avoid nesting
meta := make(map[string]interface{})
if rawMeta, ok := docMap["metadata"].(map[string]interface{}); ok {
for k, v := range rawMeta {
meta[k] = v
}
} else if rawMeta, ok := docMap["meta"].(map[string]interface{}); ok {
for k, v := range rawMeta {
meta[k] = v
}
}
for k, v := range docMap {
if k != "text" && k != "id" && k != "metadata" && k != "meta" {
meta[k] = v
}
}
if len(meta) > 0 {
doc.Meta = meta
populated = true
}
if populated {
rerankResult.Document = doc
}
}
}
bifrostResponse.Results = append(bifrostResponse.Results, rerankResult)
}
sort.SliceStable(bifrostResponse.Results, func(i, j int) bool {
if bifrostResponse.Results[i].RelevanceScore == bifrostResponse.Results[j].RelevanceScore {
return bifrostResponse.Results[i].Index < bifrostResponse.Results[j].Index
}
return bifrostResponse.Results[i].RelevanceScore > bifrostResponse.Results[j].RelevanceScore
})
if returnDocuments {
for i := range bifrostResponse.Results {
resultIndex := bifrostResponse.Results[i].Index
if resultIndex >= 0 && resultIndex < len(documents) {
bifrostResponse.Results[i].Document = schemas.Ptr(documents[resultIndex])
}
}
}
// Convert usage information
if response.Meta != nil {
promptTokens := 0
completionTokens := 0
hasTokenUsage := false
if response.Meta.Tokens != nil {
if response.Meta.Tokens.InputTokens != nil {
promptTokens = int(*response.Meta.Tokens.InputTokens)
hasTokenUsage = true
}
if response.Meta.Tokens.OutputTokens != nil {
completionTokens = int(*response.Meta.Tokens.OutputTokens)
hasTokenUsage = true
}
} else if response.Meta.BilledUnits != nil {
if response.Meta.BilledUnits.InputTokens != nil {
promptTokens = int(*response.Meta.BilledUnits.InputTokens)
hasTokenUsage = true
}
if response.Meta.BilledUnits.OutputTokens != nil {
completionTokens = int(*response.Meta.BilledUnits.OutputTokens)
hasTokenUsage = true
}
}
if hasTokenUsage {
bifrostResponse.Usage = &schemas.BifrostLLMUsage{
PromptTokens: promptTokens,
CompletionTokens: completionTokens,
TotalTokens: promptTokens + completionTokens,
}
}
}
return bifrostResponse
}
func formatCohereRerankDocument(doc schemas.RerankDocument) string {
if doc.ID == nil && len(doc.Meta) == 0 {
return doc.Text
}
// Keep metadata/id available by encoding a structured string document.
documentPayload := map[string]interface{}{
"text": doc.Text,
}
if doc.ID != nil {
documentPayload["id"] = *doc.ID
}
if len(doc.Meta) > 0 {
documentPayload["metadata"] = doc.Meta
}
encoded, err := yaml.Marshal(documentPayload)
if err != nil {
return doc.Text
}
return string(encoded)
}

View File

@@ -0,0 +1,72 @@
package cohere
import (
"encoding/json"
"testing"
"github.com/maximhq/bifrost/core/schemas"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestCohereRerankResponseToBifrostRerankResponse(t *testing.T) {
response := (&CohereRerankResponse{
ID: "rerank-response-id",
Results: []CohereRerankResult{
{
Index: 1,
RelevanceScore: 0.62,
Document: json.RawMessage(`{"text":"provider-doc-1","id":"doc-1","topic":"geography"}`),
},
{
Index: 0,
RelevanceScore: 0.91,
Document: json.RawMessage(`{"text":"provider-doc-0"}`),
},
},
}).ToBifrostRerankResponse(nil, false)
require.NotNil(t, response)
assert.Equal(t, "rerank-response-id", response.ID)
require.Len(t, response.Results, 2)
assert.Equal(t, 0, response.Results[0].Index)
assert.Equal(t, 1, response.Results[1].Index)
require.NotNil(t, response.Results[0].Document)
require.NotNil(t, response.Results[1].Document)
assert.Equal(t, "provider-doc-0", response.Results[0].Document.Text)
assert.Equal(t, "provider-doc-1", response.Results[1].Document.Text)
require.NotNil(t, response.Results[1].Document.ID)
assert.Equal(t, "doc-1", *response.Results[1].Document.ID)
assert.Equal(t, "geography", response.Results[1].Document.Meta["topic"])
}
func TestCohereRerankResponseToBifrostRerankResponseReturnDocuments(t *testing.T) {
requestDocs := []schemas.RerankDocument{
{Text: "request-doc-0"},
{Text: "request-doc-1"},
}
response := (&CohereRerankResponse{
Results: []CohereRerankResult{
{
Index: 1,
RelevanceScore: 0.62,
Document: json.RawMessage(`{"text":"provider-doc-1"}`),
},
{
Index: 0,
RelevanceScore: 0.91,
Document: json.RawMessage(`{"text":"provider-doc-0"}`),
},
},
}).ToBifrostRerankResponse(requestDocs, true)
require.NotNil(t, response)
require.Len(t, response.Results, 2)
require.NotNil(t, response.Results[0].Document)
require.NotNil(t, response.Results[1].Document)
assert.Equal(t, 0, response.Results[0].Index)
assert.Equal(t, 1, response.Results[1].Index)
assert.Equal(t, "request-doc-0", response.Results[0].Document.Text)
assert.Equal(t, "request-doc-1", response.Results[1].Document.Text)
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,616 @@
package cohere
import (
"encoding/json"
"fmt"
"github.com/bytedance/sonic"
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
"github.com/maximhq/bifrost/core/schemas"
)
const (
MinimumReasoningMaxTokens = 1
DefaultCompletionMaxTokens = 4096 // Only used for relative reasoning max token calculation - not passed in body by default
)
// Limits for tokenize input api call https://docs.cohere.com/reference/tokenize#request
const (
cohereTokenizeMinTextLength = 1
cohereTokenizeMaxTextLength = 65536
)
// ==================== REQUEST TYPES ====================
// CohereChatRequest represents a Cohere chat completion request
type CohereChatRequest struct {
Model string `json:"model"` // Required: Model to use for chat completion
Messages []CohereMessage `json:"messages"` // Required: Array of message objects
Tools []CohereChatRequestTool `json:"tools,omitempty"` // Optional: Tools available for the model
ToolChoice *CohereToolChoice `json:"tool_choice,omitempty"` // Optional: Tool choice configuration
Temperature *float64 `json:"temperature,omitempty"` // Optional: Sampling temperature
P *float64 `json:"p,omitempty"` // Optional: Top-p sampling
K *int `json:"k,omitempty"` // Optional: Top-k sampling
MaxTokens *int `json:"max_tokens,omitempty"` // Optional: Maximum tokens to generate
StopSequences []string `json:"stop_sequences,omitempty"` // Optional: Stop sequences
FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"` // Optional: Frequency penalty
PresencePenalty *float64 `json:"presence_penalty,omitempty"` // Optional: Presence penalty
Stream *bool `json:"stream,omitempty"` // Optional: Enable streaming
SafetyMode *string `json:"safety_mode,omitempty"` // Optional: Safety mode
LogProbs *bool `json:"log_probs,omitempty"` // Optional: Log probabilities
StrictToolChoice *bool `json:"strict_tool_choice,omitempty"` // Optional: Strict tool choice
Thinking *CohereThinking `json:"thinking,omitempty"` // Optional: Reasoning configuration
ResponseFormat *CohereResponseFormat `json:"response_format,omitempty"` // Optional: Format for the response
ExtraParams map[string]interface{} `json:"-"` // Optional: Extra parameters
}
// IsStreamingRequested implements the StreamingRequest interface
func (r *CohereChatRequest) IsStreamingRequested() bool {
return r.Stream != nil && *r.Stream
}
func (r *CohereChatRequest) GetExtraParams() map[string]interface{} {
return r.ExtraParams
}
type CohereChatRequestTool struct {
Type string `json:"type"` // always "function"
Function CohereChatRequestFunction `json:"function"`
}
type CohereChatRequestFunction struct {
Name string `json:"name"` // Function name
Parameters interface{} `json:"parameters,omitempty"` // Function parameters (JSON string)
Description *string `json:"description,omitempty"` // Optional: Function description
}
// CohereMessage represents a message in Cohere format
type CohereMessage struct {
Role string `json:"role"` // Required: Message role (system, user, assistant, tool)
Content *CohereMessageContent `json:"content,omitempty"` // Optional: Message content (string or array of content blocks)
ToolCalls []CohereToolCall `json:"tool_calls,omitempty"` // Optional: Tool calls (for assistant messages)
ToolCallID *string `json:"tool_call_id,omitempty"` // Optional: Tool call ID (for tool messages)
ToolPlan *string `json:"tool_plan,omitempty"` // Optional: Chain-of-thought style reflection (assistant only)
}
// CohereMessageContent represents flexible content that can be string or content blocks
type CohereMessageContent struct {
// Use custom marshaling to handle string or []CohereContentBlock
StringContent *string `json:"-"`
BlocksContent []CohereContentBlock `json:"-"`
}
// MarshalJSON implements custom JSON marshaling for CohereMessageContent
func (c *CohereMessageContent) MarshalJSON() ([]byte, error) {
if c.StringContent != nil {
return providerUtils.MarshalSorted(*c.StringContent)
}
if c.BlocksContent != nil {
return providerUtils.MarshalSorted(c.BlocksContent)
}
return []byte("null"), nil
}
// UnmarshalJSON implements custom JSON unmarshaling for CohereMessageContent
func (c *CohereMessageContent) UnmarshalJSON(data []byte) error {
// Try to unmarshal as string first
var str string
if err := sonic.Unmarshal(data, &str); err == nil {
c.StringContent = &str
return nil
}
// Try to unmarshal as content blocks array
var blocks []CohereContentBlock
if err := sonic.Unmarshal(data, &blocks); err == nil {
c.BlocksContent = blocks
return nil
}
return fmt.Errorf("content must be either string or array of content blocks")
}
// Helper methods for CohereMessageContent
// NewStringContent creates a CohereMessageContent with string content
func NewStringContent(content string) *CohereMessageContent {
return &CohereMessageContent{
StringContent: &content,
}
}
// NewBlocksContent creates a CohereMessageContent with content blocks
func NewBlocksContent(blocks []CohereContentBlock) *CohereMessageContent {
return &CohereMessageContent{
BlocksContent: blocks,
}
}
// IsString returns true if content is a string
func (c *CohereMessageContent) IsString() bool {
return c.StringContent != nil
}
// IsBlocks returns true if content is content blocks
func (c *CohereMessageContent) IsBlocks() bool {
return c.BlocksContent != nil
}
// GetString returns the string content (nil if not string)
func (c *CohereMessageContent) GetString() *string {
return c.StringContent
}
// GetBlocks returns the content blocks (nil if not blocks)
func (c *CohereMessageContent) GetBlocks() []CohereContentBlock {
return c.BlocksContent
}
type CohereContentBlockType string
const (
CohereContentBlockTypeText CohereContentBlockType = "text"
CohereContentBlockTypeImage CohereContentBlockType = "image_url"
CohereContentBlockTypeThinking CohereContentBlockType = "thinking"
CohereContentBlockTypeDocument CohereContentBlockType = "document"
)
// CohereContentBlock represents a content block in Cohere format
// This is a union type that can be text, image_url, thinking, or document
type CohereContentBlock struct {
Type CohereContentBlockType `json:"type"` // Required: Content block type
// Text content block
Text *string `json:"text,omitempty"`
// Image URL content block
ImageURL *CohereImageURL `json:"image_url,omitempty"`
// Thinking content block (assistant only)
Thinking *string `json:"thinking,omitempty"`
// Document content block (tool messages)
Document *CohereDocument `json:"document,omitempty"`
}
// CohereImageURL represents an image URL content block
type CohereImageURL struct {
URL string `json:"url"` // Required: Image URL
}
// CohereDocument represents a document content block
type CohereDocument struct {
Data schemas.OrderedMap `json:"data"` // Required: Document data as key-value pairs
ID *string `json:"id,omitempty"` // Optional: Document ID for citations
}
// CohereThinking represents reasoning configuration
type CohereThinking struct {
Type CohereThinkingType `json:"type"` // Required: Reasoning type (enabled, disabled)
TokenBudget *int `json:"token_budget,omitempty"` // Optional: Maximum thinking tokens (>=1)
}
// CohereThinkingType represents the type of reasoning
type CohereThinkingType string
const (
ThinkingTypeEnabled CohereThinkingType = "enabled"
ThinkingTypeDisabled CohereThinkingType = "disabled"
)
// CohereResponseFormat represents the response format configuration for Cohere chat requests
type CohereResponseFormat struct {
Type CohereResponseFormatType `json:"type"` // Required: Response format type
JSONSchema *interface{} `json:"schema,omitempty"` // Optional: JSON schema for structured output (not used when type is "text")
}
// CohereResponseFormatType represents the type of response format
type CohereResponseFormatType string
const (
ResponseFormatTypeText CohereResponseFormatType = "text"
ResponseFormatTypeJSONObject CohereResponseFormatType = "json_object"
)
// CohereToolChoice represents tool choice configuration
type CohereToolChoice string
const (
ToolChoiceRequired CohereToolChoice = "REQUIRED"
ToolChoiceNone CohereToolChoice = "NONE"
ToolChoiceAuto CohereToolChoice = "AUTO"
)
// CohereToolCall represents a tool call in Cohere format
type CohereToolCall struct {
ID *string `json:"id,omitempty"` // Optional: Tool call ID
Type string `json:"type"` // Required: Tool call type (must be "function")
Function *CohereFunction `json:"function"` // Required: Function call details
}
// CohereFunction represents a function call
type CohereFunction struct {
Name *string `json:"name,omitempty"` // Optional: Function name
Arguments string `json:"arguments,omitempty"` // Optional: Function arguments (JSON string)
}
// CohereParameterDefinition represents a parameter definition for a Cohere tool.
// It defines the type, description, and whether the parameter is required.
type CohereParameterDefinition struct {
Type string `json:"type"` // Type of the parameter
Description *string `json:"description,omitempty"` // Optional description of the parameter
Required bool `json:"required"` // Whether the parameter is required
}
// CohereTool represents a tool definition for the Cohere API.
// It includes the tool's name, description, and parameter definitions.
type CohereTool struct {
Name string `json:"name"` // Name of the tool
Description string `json:"description"` // Description of the tool
ParameterDefinitions map[string]CohereParameterDefinition `json:"parameter_definitions"` // Definitions of the tool's parameters
}
// CohereCountTokensRequest represents a Cohere tokenize request
type CohereCountTokensRequest struct {
Model string `json:"model"` // Required: Model whose tokenizer should be used
Text string `json:"text"` // Required: Text to tokenize (1-65536 chars)
ExtraParams map[string]interface{} `json:"-"` // Optional: Extra parameters
}
func (r *CohereCountTokensRequest) GetExtraParams() map[string]interface{} {
return r.ExtraParams
}
// CohereEmbeddingRequest represents a Cohere embedding request
type CohereEmbeddingRequest struct {
Model string `json:"model"` // Required: ID of embedding model
InputType string `json:"input_type"` // Required: Type of input for v3+ models
Texts []string `json:"texts,omitempty"` // Optional: Array of strings to embed (max 96)
Images []string `json:"images,omitempty"` // Optional: Array of image data URIs (max 1)
Inputs []CohereEmbeddingInput `json:"inputs,omitempty"` // Optional: Array of mixed text/image inputs (max 96)
MaxTokens *int `json:"max_tokens,omitempty"` // Optional: Max tokens to embed per input
OutputDimension *int `json:"output_dimension,omitempty"` // Optional: Embedding dimensions (256, 512, 1024, 1536)
EmbeddingTypes []string `json:"embedding_types,omitempty"` // Optional: Types of embeddings to return
Truncate *string `json:"truncate,omitempty"` // Optional: How to handle long inputs
ExtraParams map[string]interface{} `json:"-"` // Optional: Extra parameters
}
func (r *CohereEmbeddingRequest) GetExtraParams() map[string]interface{} {
return r.ExtraParams
}
// CohereEmbeddingInput represents a mixed text/image input
type CohereEmbeddingInput struct {
Content []CohereContentBlock `json:"content"` // Required: Array of content blocks (reuses chat content blocks)
}
// CohereEmbeddingResponse represents a Cohere embedding response
type CohereEmbeddingResponse struct {
ID string `json:"id"` // Response ID
Embeddings *CohereEmbeddingData `json:"embeddings,omitempty"` // Embedding data object
ResponseType *string `json:"response_type,omitempty"` // Response type (embeddings_floats, embeddings_by_type)
Texts []string `json:"texts,omitempty"` // Original text entries
Images []CohereEmbeddingImageInfo `json:"images,omitempty"` // Original image entries
Meta *CohereEmbeddingMeta `json:"meta,omitempty"` // Response metadata
}
// CohereEmbeddingData represents the embeddings object with different types
type CohereEmbeddingData struct {
Float [][]float64 `json:"float,omitempty"` // Float embeddings
Int8 [][]int8 `json:"int8,omitempty"` // Int8 embeddings
Uint8 [][]uint8 `json:"uint8,omitempty"` // Uint8 embeddings
Binary [][]int8 `json:"binary,omitempty"` // Binary embeddings
Ubinary [][]uint8 `json:"ubinary,omitempty"` // Unsigned binary embeddings
Base64 []string `json:"base64,omitempty"` // Base64 embeddings
}
// CohereEmbeddingImageInfo represents image information in the response
type CohereEmbeddingImageInfo struct {
Width int64 `json:"width"` // Width in pixels
Height int64 `json:"height"` // Height in pixels
Format string `json:"format"` // Image format
BitDepth int64 `json:"bit_depth"` // Bit depth
}
// CohereEmbeddingMeta represents metadata in embedding response
type CohereEmbeddingMeta struct {
APIVersion *CohereEmbeddingAPIVersion `json:"api_version,omitempty"` // API version info
BilledUnits *CohereBilledUnits `json:"billed_units,omitempty"` // Billing information
Tokens *CohereTokenUsage `json:"tokens,omitempty"` // Token usage
Warnings []string `json:"warnings,omitempty"` // Any warnings
}
// CohereEmbeddingAPIVersion represents API version information
type CohereEmbeddingAPIVersion struct {
Version *string `json:"version,omitempty"` // API version
IsDeprecated *bool `json:"is_deprecated,omitempty"` // Deprecation status
IsExperimental *bool `json:"is_experimental,omitempty"` // Experimental status
}
// ==================== RESPONSE TYPES ====================
// CohereCountTokensResponse represents the response from the tokenize endpoint
type CohereCountTokensResponse struct {
Tokens []int `json:"tokens"`
TokenStrings []string `json:"token_strings,omitempty"`
Meta *CohereTokenizeMeta `json:"meta,omitempty"`
}
// CohereTokenizeMeta captures metadata returned by the tokenize endpoint
type CohereTokenizeMeta struct {
APIVersion *CohereTokenizeAPIVersion `json:"api_version,omitempty"`
}
// CohereTokenizeAPIVersion describes API version metadata
type CohereTokenizeAPIVersion struct {
Version *string `json:"version,omitempty"`
}
// CohereChatResponse represents a Cohere chat completion response
type CohereChatResponse struct {
ID string `json:"id"` // Unique identifier for the generated reply
FinishReason *CohereFinishReason `json:"finish_reason,omitempty"` // Reason for completion
Message *CohereMessage `json:"message,omitempty"` // Generated message from assistant
Usage *CohereUsage `json:"usage,omitempty"` // Token usage information
LogProbs []CohereLogProb `json:"logprobs,omitempty"` // Log probabilities (if requested)
}
// CohereFinishReason represents the reason a chat request has finished
type CohereFinishReason string
const (
FinishReasonComplete CohereFinishReason = "COMPLETE" // Model finished sending complete message
FinishReasonStopSequence CohereFinishReason = "STOP_SEQUENCE" // Stop sequence was reached
FinishReasonMaxTokens CohereFinishReason = "MAX_TOKENS" // Max tokens exceeded
FinishReasonToolCall CohereFinishReason = "TOOL_CALL" // Model generated tool call
FinishReasonError CohereFinishReason = "ERROR" // Generation failed due to internal error
FinishReasonTimeout CohereFinishReason = "TIMEOUT" // Timeout
)
// CohereUsage represents token usage information
type CohereUsage struct {
BilledUnits *CohereBilledUnits `json:"billed_units,omitempty"` // Billed usage information
Tokens *CohereTokenUsage `json:"tokens,omitempty"` // Token usage details
CachedTokens *int `json:"cached_tokens,omitempty"` // Cached tokens
}
// CohereBilledUnits represents billed usage information
type CohereBilledUnits struct {
InputTokens *int `json:"input_tokens,omitempty"` // Number of billed input tokens
OutputTokens *int `json:"output_tokens,omitempty"` // Number of billed output tokens
SearchUnits *int `json:"search_units,omitempty"` // Number of billed search units
Classifications *int `json:"classifications,omitempty"` // Number of billed classification units
}
// CohereTokenUsage represents detailed token usage
type CohereTokenUsage struct {
InputTokens *int `json:"input_tokens"` // Number of input tokens used
OutputTokens *int `json:"output_tokens"` // Number of output tokens produced
}
// CohereLogProb represents log probability information
type CohereLogProb struct {
TokenIDs []int `json:"token_ids"` // Token IDs of each token in text chunk
Text *string `json:"text,omitempty"` // Text chunk for log probabilities
LogProbs []float64 `json:"logprobs,omitempty"` // Log probability of each token
}
type CohereCitationType string
const (
CitationTypeTextContent CohereCitationType = "TEXT_CONTENT"
CitationTypeThinkingContent CohereCitationType = "THINKING_CONTENT"
CitationTypePlan CohereCitationType = "PLAN"
)
type CohereSourceType string
const (
SourceTypeTool CohereSourceType = "tool"
SourceTypeDocument CohereSourceType = "document"
)
// CohereCitation represents a citation in the response
type CohereCitation struct {
Start int `json:"start"` // Start position of cited text
End int `json:"end"` // End position of cited text
Text string `json:"text"` // Cited text
Sources []CohereSource `json:"sources,omitempty"` // Citation sources
ContentIndex int `json:"content_index"` // Content index of the citation
Type CohereCitationType `json:"type"` // Type of citation
}
// CohereSource represents a citation source
type CohereSource struct {
Type CohereSourceType `json:"type"` // Source type ("tool" or "document")
ID *string `json:"id,omitempty"` // Source ID (nullable)
ToolOutput *json.RawMessage `json:"tool_output,omitempty"` // Tool output (for tool sources, json.RawMessage preserves key ordering)
Document *json.RawMessage `json:"document,omitempty"` // Document data (for document sources, json.RawMessage preserves key ordering)
}
// ==================== STREAMING TYPES ====================
// CohereStreamEventType represents the type of streaming event
type CohereStreamEventType string
const (
StreamEventMessageStart CohereStreamEventType = "message-start"
StreamEventContentStart CohereStreamEventType = "content-start"
StreamEventContentDelta CohereStreamEventType = "content-delta"
StreamEventContentEnd CohereStreamEventType = "content-end"
StreamEventToolPlanDelta CohereStreamEventType = "tool-plan-delta"
StreamEventToolCallStart CohereStreamEventType = "tool-call-start"
StreamEventToolCallDelta CohereStreamEventType = "tool-call-delta"
StreamEventToolCallEnd CohereStreamEventType = "tool-call-end"
StreamEventCitationStart CohereStreamEventType = "citation-start"
StreamEventCitationEnd CohereStreamEventType = "citation-end"
StreamEventMessageEnd CohereStreamEventType = "message-end"
StreamEventDebug CohereStreamEventType = "debug"
)
// CohereStreamEvent represents a unified streaming event from Cohere API
type CohereStreamEvent struct {
Type CohereStreamEventType `json:"type"`
ID *string `json:"id,omitempty"` // For message-start
Index *int `json:"index,omitempty"` // For indexed events
Delta *CohereStreamDelta `json:"delta,omitempty"`
}
// CohereStreamDelta represents the delta content in streaming events
type CohereStreamDelta struct {
Message *CohereStreamMessage `json:"message,omitempty"`
FinishReason *CohereFinishReason `json:"finish_reason,omitempty"`
Usage *CohereUsage `json:"usage,omitempty"`
}
type CohereStreamToolCallStruct struct {
CohereToolCallObject *CohereToolCall
CohereToolCallArray []CohereToolCall
}
// JSON marshaling for CohereStreamToolCall
func (c *CohereStreamToolCallStruct) MarshalJSON() ([]byte, error) {
if c.CohereToolCallObject != nil {
return providerUtils.MarshalSorted(c.CohereToolCallObject)
}
if c.CohereToolCallArray != nil {
return providerUtils.MarshalSorted(c.CohereToolCallArray)
}
return providerUtils.MarshalSorted(nil)
}
func (c *CohereStreamToolCallStruct) UnmarshalJSON(data []byte) error {
if string(data) == "null" {
return nil
}
// Try to unmarshal as array first
var toolCallArray []CohereToolCall
if err := sonic.Unmarshal(data, &toolCallArray); err == nil {
c.CohereToolCallArray = toolCallArray
return nil
}
// Try to unmarshal as single object
var toolCallObject CohereToolCall
if err := sonic.Unmarshal(data, &toolCallObject); err == nil {
c.CohereToolCallObject = &toolCallObject
return nil
}
return fmt.Errorf("tool_calls field is neither array nor object")
}
type CohereStreamContentStruct struct {
CohereStreamContentObject *CohereStreamContent
CohereStreamContentArray []CohereStreamContent
}
func (c *CohereStreamContentStruct) MarshalJSON() ([]byte, error) {
if c.CohereStreamContentObject != nil {
return providerUtils.MarshalSorted(c.CohereStreamContentObject)
}
if c.CohereStreamContentArray != nil {
return providerUtils.MarshalSorted(c.CohereStreamContentArray)
}
return providerUtils.MarshalSorted(nil)
}
func (c *CohereStreamContentStruct) UnmarshalJSON(data []byte) error {
if string(data) == "null" {
return nil
}
// Try to unmarshal as array first
var contentArray []CohereStreamContent
if err := sonic.Unmarshal(data, &contentArray); err == nil {
c.CohereStreamContentArray = contentArray
return nil
}
// Try to unmarshal as single object
var contentObject CohereStreamContent
if err := sonic.Unmarshal(data, &contentObject); err == nil {
c.CohereStreamContentObject = &contentObject
return nil
}
return fmt.Errorf("content field is neither array nor object")
}
type CohereStreamCitationStruct struct {
CohereStreamCitationObject *CohereCitation
CohereStreamCitationArray []CohereCitation
}
func (c *CohereStreamCitationStruct) MarshalJSON() ([]byte, error) {
if c.CohereStreamCitationObject != nil {
return providerUtils.MarshalSorted(c.CohereStreamCitationObject)
}
if c.CohereStreamCitationArray != nil {
return providerUtils.MarshalSorted(c.CohereStreamCitationArray)
}
return providerUtils.MarshalSorted(nil)
}
func (c *CohereStreamCitationStruct) UnmarshalJSON(data []byte) error {
if string(data) == "null" {
return nil
}
// Try to unmarshal as array first
var citationArray []CohereCitation
if err := sonic.Unmarshal(data, &citationArray); err == nil {
c.CohereStreamCitationArray = citationArray
return nil
}
// Try to unmarshal as single object
var citationObject CohereCitation
if err := sonic.Unmarshal(data, &citationObject); err == nil {
c.CohereStreamCitationObject = &citationObject
return nil
}
return fmt.Errorf("citations field is neither array nor object")
}
// CohereStreamMessage represents the message part of streaming deltas
type CohereStreamMessage struct {
Role *string `json:"role,omitempty"` // For message-start
Content *CohereStreamContentStruct `json:"content,omitempty"` // For content events (object)
ToolPlan *string `json:"tool_plan,omitempty"` // For tool-plan-delta
ToolCalls *CohereStreamToolCallStruct `json:"tool_calls,omitempty"` // For tool-call events (flexible)
Citations *CohereStreamCitationStruct `json:"citations,omitempty"` // For citation events
}
// CohereStreamContent represents content in streaming events
type CohereStreamContent struct {
Type CohereContentBlockType `json:"type,omitempty"` // For content-start
Text *string `json:"text,omitempty"` // For content deltas
Thinking *string `json:"thinking,omitempty"` // For thinking deltas
}
// ==================== ERROR TYPES ====================
// CohereError represents an error response from the Cohere API
type CohereError struct {
Type string `json:"type"` // Error type
Message string `json:"message"` // Error message
Code *string `json:"code,omitempty"` // Optional error code
}
// ==================== MODEL TYPES ====================
type CohereModel struct {
Name string `json:"name"`
IsDeprecated bool `json:"is_deprecated"`
Endpoints []string `json:"endpoints"`
Finetuned bool `json:"finetuned"`
ContextLength int `json:"context_length"`
TokenizerURL string `json:"tokenizer_url"`
DefaultEndpoints []string `json:"default_endpoints"`
Features []string `json:"features"`
}
type CohereListModelsResponse struct {
Models []CohereModel `json:"models"`
NextPageToken string `json:"next_page_token"`
}

View File

@@ -0,0 +1,293 @@
package cohere
import (
"encoding/json"
"github.com/maximhq/bifrost/core/schemas"
"github.com/tidwall/sjson"
)
var (
// Maps provider-specific finish reasons to Bifrost format
cohereFinishReasonToBifrost = map[CohereFinishReason]string{
FinishReasonComplete: "stop",
FinishReasonStopSequence: "stop",
FinishReasonMaxTokens: "length",
FinishReasonToolCall: "tool_calls",
}
)
// ConvertCohereFinishReasonToBifrost converts provider finish reasons to Bifrost format
func ConvertCohereFinishReasonToBifrost(providerReason CohereFinishReason) string {
if bifrostReason, ok := cohereFinishReasonToBifrost[providerReason]; ok {
return bifrostReason
}
return string(providerReason)
}
// convertInterfaceToToolFunctionParameters converts an interface{} to ToolFunctionParameters
// This handles the conversion from Cohere's flexible parameter format to Bifrost's structured format
func convertInterfaceToToolFunctionParameters(params interface{}) *schemas.ToolFunctionParameters {
if params == nil {
return nil
}
// Try to convert from map[string]interface{}
paramsMap, ok := params.(map[string]interface{})
if !ok {
return nil
}
result := &schemas.ToolFunctionParameters{}
// Extract type
if typeVal, ok := paramsMap["type"].(string); ok {
result.Type = typeVal
}
// Extract description
if descVal, ok := paramsMap["description"].(string); ok {
result.Description = &descVal
}
// Extract required
if requiredVal, ok := paramsMap["required"].([]interface{}); ok {
required := make([]string, 0, len(requiredVal))
for _, v := range requiredVal {
if s, ok := v.(string); ok {
required = append(required, s)
}
}
result.Required = required
}
// Extract properties
if orderedProps, ok := schemas.SafeExtractOrderedMap(paramsMap["properties"]); ok {
result.Properties = orderedProps
}
// Extract enum
if enumVal, ok := paramsMap["enum"].([]interface{}); ok {
enum := make([]string, 0, len(enumVal))
for _, v := range enumVal {
if s, ok := v.(string); ok {
enum = append(enum, s)
}
}
result.Enum = enum
}
// Extract additionalProperties
if addPropsVal, ok := paramsMap["additionalProperties"].(bool); ok {
result.AdditionalProperties = &schemas.AdditionalPropertiesStruct{
AdditionalPropertiesBool: &addPropsVal,
}
}
if addPropsVal, ok := schemas.SafeExtractOrderedMap(paramsMap["additionalProperties"]); ok {
result.AdditionalProperties = &schemas.AdditionalPropertiesStruct{
AdditionalPropertiesMap: addPropsVal,
}
}
// Extract $defs (JSON Schema draft 2019-09+)
if defsVal, ok := schemas.SafeExtractOrderedMap(paramsMap["$defs"]); ok {
result.Defs = defsVal
}
// Extract definitions (legacy JSON Schema draft-07)
if defsVal, ok := schemas.SafeExtractOrderedMap(paramsMap["definitions"]); ok {
result.Definitions = defsVal
}
// Extract $ref
if refVal, ok := paramsMap["$ref"].(string); ok {
result.Ref = &refVal
}
// Extract items (array element schema)
if itemsVal, ok := schemas.SafeExtractOrderedMap(paramsMap["items"]); ok {
result.Items = itemsVal
}
// Extract minItems
if minItemsVal, ok := extractInt64(paramsMap["minItems"]); ok {
result.MinItems = &minItemsVal
}
// Extract maxItems
if maxItemsVal, ok := extractInt64(paramsMap["maxItems"]); ok {
result.MaxItems = &maxItemsVal
}
// Extract anyOf
if anyOfVal, ok := paramsMap["anyOf"].([]interface{}); ok {
anyOf := make([]schemas.OrderedMap, 0, len(anyOfVal))
for _, v := range anyOfVal {
if m, ok := schemas.SafeExtractOrderedMap(v); ok {
anyOf = append(anyOf, *m)
}
}
result.AnyOf = anyOf
}
// Extract oneOf
if oneOfVal, ok := paramsMap["oneOf"].([]interface{}); ok {
oneOf := make([]schemas.OrderedMap, 0, len(oneOfVal))
for _, v := range oneOfVal {
if m, ok := schemas.SafeExtractOrderedMap(v); ok {
oneOf = append(oneOf, *m)
}
}
result.OneOf = oneOf
}
// Extract allOf
if allOfVal, ok := paramsMap["allOf"].([]interface{}); ok {
allOf := make([]schemas.OrderedMap, 0, len(allOfVal))
for _, v := range allOfVal {
if m, ok := schemas.SafeExtractOrderedMap(v); ok {
allOf = append(allOf, *m)
}
}
result.AllOf = allOf
}
// Extract format
if formatVal, ok := paramsMap["format"].(string); ok {
result.Format = &formatVal
}
// Extract pattern
if patternVal, ok := paramsMap["pattern"].(string); ok {
result.Pattern = &patternVal
}
// Extract minLength
if minLengthVal, ok := extractInt64(paramsMap["minLength"]); ok {
result.MinLength = &minLengthVal
}
// Extract maxLength
if maxLengthVal, ok := extractInt64(paramsMap["maxLength"]); ok {
result.MaxLength = &maxLengthVal
}
// Extract minimum
if minVal, ok := extractFloat64(paramsMap["minimum"]); ok {
result.Minimum = &minVal
}
// Extract maximum
if maxVal, ok := extractFloat64(paramsMap["maximum"]); ok {
result.Maximum = &maxVal
}
// Extract title
if titleVal, ok := paramsMap["title"].(string); ok {
result.Title = &titleVal
}
// Extract default
if defaultVal, exists := paramsMap["default"]; exists {
result.Default = defaultVal
}
// Extract nullable
if nullableVal, ok := paramsMap["nullable"].(bool); ok {
result.Nullable = &nullableVal
}
return result
}
// extractInt64 extracts an int64 from various numeric types
func extractInt64(v interface{}) (int64, bool) {
switch val := v.(type) {
case int:
return int64(val), true
case int64:
return val, true
case float64:
return int64(val), true
case float32:
return int64(val), true
default:
return 0, false
}
}
// extractFloat64 extracts a float64 from various numeric types
func extractFloat64(v interface{}) (float64, bool) {
switch val := v.(type) {
case float64:
return val, true
case float32:
return float64(val), true
case int:
return float64(val), true
case int64:
return float64(val), true
default:
return 0, false
}
}
// ConvertResponseFormatToCohere converts OpenAI-style response_format (interface{}) to Cohere's typed format
// Input can be a map with structure: { type: "json_schema", json_schema: { schema: {...} } }
// Output: CohereResponseFormat with flat structure: { type: "json_object", json_schema: {...} }
func convertResponseFormatToCohere(responseFormat *interface{}) *CohereResponseFormat {
if responseFormat == nil {
return nil
}
// Try to extract as map
formatMap, ok := (*responseFormat).(map[string]interface{})
if !ok {
return nil
}
cohereFormat := &CohereResponseFormat{}
// Extract type
typeStr, _ := formatMap["type"].(string)
switch typeStr {
case "text":
cohereFormat.Type = ResponseFormatTypeText
case "json_object", "json_schema":
cohereFormat.Type = ResponseFormatTypeJSONObject
// Extract the nested schema
// OpenAI format: { type: "json_schema", json_schema: { name: "X", strict: true, schema: {...} } }
if jsonSchemaWrapper, ok := formatMap["json_schema"].(map[string]interface{}); ok {
if schema, ok := jsonSchemaWrapper["schema"].(map[string]interface{}); ok {
var schemaInterface interface{} = schema
cohereFormat.JSONSchema = &schemaInterface
}
}
default:
return nil
}
return cohereFormat
}
// convertCohereResponseFormatToBifrost converts Cohere's typed response_format back to interface{}
func convertCohereResponseFormatToBifrost(cohereFormat *CohereResponseFormat) *interface{} {
if cohereFormat == nil {
return nil
}
// Build JSON bytes with deterministic key order using sjson
data := []byte(`{}`)
if cohereFormat.JSONSchema != nil {
data, _ = sjson.SetBytes(data, "type", "json_schema")
schemaBytes, _ := schemas.MarshalSorted(cohereFormat.JSONSchema)
data, _ = sjson.SetRawBytes(data, "json_schema", schemaBytes)
} else {
data, _ = sjson.SetBytes(data, "type", string(cohereFormat.Type))
}
var resultInterface interface{} = json.RawMessage(data)
return &resultInterface
}