first commit
This commit is contained in:
722
core/providers/cohere/chat.go
Normal file
722
core/providers/cohere/chat.go
Normal file
@@ -0,0 +1,722 @@
|
||||
package cohere
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/maximhq/bifrost/core/providers/anthropic"
|
||||
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// ToCohereChatCompletionRequest converts a Bifrost request to Cohere v2 format
|
||||
func ToCohereChatCompletionRequest(bifrostReq *schemas.BifrostChatRequest) (*CohereChatRequest, error) {
|
||||
if bifrostReq == nil || bifrostReq.Input == nil {
|
||||
return nil, fmt.Errorf("bifrost request is nil")
|
||||
}
|
||||
|
||||
messages := bifrostReq.Input
|
||||
cohereReq := &CohereChatRequest{
|
||||
Model: bifrostReq.Model,
|
||||
}
|
||||
|
||||
// Convert messages to Cohere v2 format
|
||||
var cohereMessages []CohereMessage
|
||||
for _, msg := range messages {
|
||||
cohereMsg := CohereMessage{
|
||||
Role: string(msg.Role),
|
||||
}
|
||||
|
||||
// Convert content
|
||||
if msg.Content != nil && msg.Content.ContentStr != nil {
|
||||
cohereMsg.Content = NewStringContent(*msg.Content.ContentStr)
|
||||
} else if msg.Content != nil && msg.Content.ContentBlocks != nil {
|
||||
var contentBlocks []CohereContentBlock
|
||||
for _, block := range msg.Content.ContentBlocks {
|
||||
if block.Text != nil {
|
||||
contentBlocks = append(contentBlocks, CohereContentBlock{
|
||||
Type: CohereContentBlockTypeText,
|
||||
Text: block.Text,
|
||||
})
|
||||
} else if block.ImageURLStruct != nil {
|
||||
contentBlocks = append(contentBlocks, CohereContentBlock{
|
||||
Type: CohereContentBlockTypeImage,
|
||||
ImageURL: &CohereImageURL{
|
||||
URL: block.ImageURLStruct.URL,
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
if len(contentBlocks) > 0 {
|
||||
cohereMsg.Content = NewBlocksContent(contentBlocks)
|
||||
}
|
||||
}
|
||||
|
||||
// Convert tool calls for assistant messages
|
||||
if msg.ChatAssistantMessage != nil && msg.ChatAssistantMessage.ToolCalls != nil {
|
||||
var toolCalls []CohereToolCall
|
||||
for _, toolCall := range msg.ChatAssistantMessage.ToolCalls {
|
||||
// Safely extract function name and arguments
|
||||
var functionName *string
|
||||
var functionArguments string
|
||||
|
||||
if toolCall.Function.Name != nil {
|
||||
functionName = toolCall.Function.Name
|
||||
} else {
|
||||
// Use empty string if Name is nil
|
||||
functionName = schemas.Ptr("")
|
||||
}
|
||||
|
||||
// Arguments is a string, not a pointer, so it's safe to access directly
|
||||
// Default to "{}" if empty to ensure the field is always present.
|
||||
if toolCall.Function.Arguments == "" {
|
||||
functionArguments = "{}"
|
||||
} else {
|
||||
functionArguments = toolCall.Function.Arguments
|
||||
}
|
||||
|
||||
cohereToolCall := CohereToolCall{
|
||||
ID: toolCall.ID,
|
||||
Type: "function",
|
||||
Function: &CohereFunction{
|
||||
Name: functionName,
|
||||
Arguments: functionArguments,
|
||||
},
|
||||
}
|
||||
toolCalls = append(toolCalls, cohereToolCall)
|
||||
}
|
||||
cohereMsg.ToolCalls = toolCalls
|
||||
}
|
||||
|
||||
// Convert tool messages
|
||||
if msg.ChatToolMessage != nil && msg.ChatToolMessage.ToolCallID != nil {
|
||||
cohereMsg.ToolCallID = msg.ChatToolMessage.ToolCallID
|
||||
}
|
||||
|
||||
cohereMessages = append(cohereMessages, cohereMsg)
|
||||
}
|
||||
|
||||
cohereReq.Messages = cohereMessages
|
||||
|
||||
// Convert parameters
|
||||
if bifrostReq.Params != nil {
|
||||
cohereReq.MaxTokens = bifrostReq.Params.MaxCompletionTokens
|
||||
cohereReq.Temperature = bifrostReq.Params.Temperature
|
||||
cohereReq.P = bifrostReq.Params.TopP
|
||||
cohereReq.StopSequences = bifrostReq.Params.Stop
|
||||
cohereReq.FrequencyPenalty = bifrostReq.Params.FrequencyPenalty
|
||||
cohereReq.PresencePenalty = bifrostReq.Params.PresencePenalty
|
||||
|
||||
// Convert reasoning
|
||||
if bifrostReq.Params.Reasoning != nil {
|
||||
if bifrostReq.Params.Reasoning.MaxTokens != nil {
|
||||
thinking := &CohereThinking{
|
||||
Type: ThinkingTypeEnabled,
|
||||
}
|
||||
if *bifrostReq.Params.Reasoning.MaxTokens == -1 {
|
||||
// cohere does not support dynamic reasoning budget like gemini
|
||||
// setting it to minimum reasoning budget
|
||||
thinking.TokenBudget = schemas.Ptr(anthropic.MinimumReasoningMaxTokens)
|
||||
} else {
|
||||
thinking.TokenBudget = bifrostReq.Params.Reasoning.MaxTokens
|
||||
}
|
||||
cohereReq.Thinking = thinking
|
||||
} else if bifrostReq.Params.Reasoning.Effort != nil {
|
||||
if *bifrostReq.Params.Reasoning.Effort != "none" {
|
||||
maxCompletionTokens := providerUtils.GetMaxOutputTokensOrDefault(bifrostReq.Model, DefaultCompletionMaxTokens)
|
||||
if bifrostReq.Params.MaxCompletionTokens != nil {
|
||||
maxCompletionTokens = *bifrostReq.Params.MaxCompletionTokens
|
||||
}
|
||||
budgetTokens, err := providerUtils.GetBudgetTokensFromReasoningEffort(*bifrostReq.Params.Reasoning.Effort, MinimumReasoningMaxTokens, maxCompletionTokens)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cohereReq.Thinking = &CohereThinking{
|
||||
Type: ThinkingTypeEnabled,
|
||||
TokenBudget: schemas.Ptr(budgetTokens), // Max tokens for reasoning
|
||||
}
|
||||
} else {
|
||||
cohereReq.Thinking = &CohereThinking{
|
||||
Type: ThinkingTypeDisabled,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Convert response format
|
||||
if bifrostReq.Params.ResponseFormat != nil {
|
||||
cohereReq.ResponseFormat = convertResponseFormatToCohere(bifrostReq.Params.ResponseFormat)
|
||||
}
|
||||
|
||||
// Convert extra params
|
||||
if bifrostReq.Params.ExtraParams != nil {
|
||||
// Handle thinking parameter
|
||||
cohereReq.ExtraParams = bifrostReq.Params.ExtraParams
|
||||
if thinkingParam, ok := schemas.SafeExtractFromMap(bifrostReq.Params.ExtraParams, "thinking"); ok {
|
||||
if thinkingMap, ok := thinkingParam.(map[string]interface{}); ok {
|
||||
thinking := &CohereThinking{}
|
||||
if typeStr, ok := schemas.SafeExtractString(thinkingMap["type"]); ok {
|
||||
delete(thinkingMap, "type")
|
||||
thinking.Type = CohereThinkingType(typeStr)
|
||||
}
|
||||
if tokenBudget, ok := schemas.SafeExtractIntPointer(thinkingMap["token_budget"]); ok {
|
||||
delete(thinkingMap, "token_budget")
|
||||
thinking.TokenBudget = tokenBudget
|
||||
}
|
||||
cohereReq.Thinking = thinking
|
||||
cohereReq.ExtraParams["thinking"] = thinkingMap
|
||||
}
|
||||
}
|
||||
|
||||
// Handle other Cohere-specific extra params
|
||||
if safetyMode, ok := schemas.SafeExtractStringPointer(bifrostReq.Params.ExtraParams["safety_mode"]); ok {
|
||||
delete(cohereReq.ExtraParams, "safety_mode")
|
||||
cohereReq.SafetyMode = safetyMode
|
||||
}
|
||||
|
||||
if logProbs, ok := schemas.SafeExtractBoolPointer(bifrostReq.Params.ExtraParams["log_probs"]); ok {
|
||||
delete(cohereReq.ExtraParams, "log_probs")
|
||||
cohereReq.LogProbs = logProbs
|
||||
}
|
||||
|
||||
if strictToolChoice, ok := schemas.SafeExtractBoolPointer(bifrostReq.Params.ExtraParams["strict_tool_choice"]); ok {
|
||||
delete(cohereReq.ExtraParams, "strict_tool_choice")
|
||||
cohereReq.StrictToolChoice = strictToolChoice
|
||||
}
|
||||
}
|
||||
|
||||
// Convert tools to Cohere-specific format (without "strict" field)
|
||||
if bifrostReq.Params.Tools != nil {
|
||||
cohereTools := make([]CohereChatRequestTool, len(bifrostReq.Params.Tools))
|
||||
for i, tool := range bifrostReq.Params.Tools {
|
||||
cohereTools[i] = CohereChatRequestTool{
|
||||
Type: string(tool.Type),
|
||||
}
|
||||
if tool.Function != nil {
|
||||
cohereTools[i].Function = CohereChatRequestFunction{
|
||||
Name: tool.Function.Name,
|
||||
Description: tool.Function.Description,
|
||||
Parameters: tool.Function.Parameters, // Convert to map
|
||||
// Note: No "strict" field - Cohere doesn't support it
|
||||
}
|
||||
}
|
||||
}
|
||||
cohereReq.Tools = cohereTools
|
||||
}
|
||||
|
||||
// Convert tool choice
|
||||
if bifrostReq.Params.ToolChoice != nil {
|
||||
toolChoice := bifrostReq.Params.ToolChoice
|
||||
|
||||
if toolChoice.ChatToolChoiceStr != nil {
|
||||
switch schemas.ChatToolChoiceType(*toolChoice.ChatToolChoiceStr) {
|
||||
case schemas.ChatToolChoiceTypeNone:
|
||||
toolChoice := ToolChoiceNone
|
||||
cohereReq.ToolChoice = &toolChoice
|
||||
default:
|
||||
toolChoice := ToolChoiceRequired
|
||||
cohereReq.ToolChoice = &toolChoice
|
||||
}
|
||||
} else if toolChoice.ChatToolChoiceStruct != nil {
|
||||
switch toolChoice.ChatToolChoiceStruct.Type {
|
||||
case schemas.ChatToolChoiceTypeFunction:
|
||||
toolChoice := ToolChoiceRequired
|
||||
cohereReq.ToolChoice = &toolChoice
|
||||
default:
|
||||
toolChoice := ToolChoiceAuto
|
||||
cohereReq.ToolChoice = &toolChoice
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return cohereReq, nil
|
||||
}
|
||||
|
||||
// ToBifrostChatRequest converts a Cohere v2 chat request to Bifrost format
|
||||
func (req *CohereChatRequest) ToBifrostChatRequest(ctx *schemas.BifrostContext) *schemas.BifrostChatRequest {
|
||||
if req == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
provider, model := schemas.ParseModelString(req.Model, providerUtils.CheckAndSetDefaultProvider(ctx, schemas.Cohere))
|
||||
|
||||
bifrostReq := &schemas.BifrostChatRequest{
|
||||
Provider: provider,
|
||||
Model: model,
|
||||
Params: &schemas.ChatParameters{},
|
||||
}
|
||||
// Convert messages
|
||||
if req.Messages != nil {
|
||||
bifrostMessages := make([]schemas.ChatMessage, len(req.Messages))
|
||||
for i, message := range req.Messages {
|
||||
bifrostMessages[i] = *message.ToBifrostChatMessage()
|
||||
}
|
||||
bifrostReq.Input = bifrostMessages
|
||||
}
|
||||
// Convert parameters
|
||||
if req.MaxTokens != nil {
|
||||
bifrostReq.Params.MaxCompletionTokens = req.MaxTokens
|
||||
}
|
||||
if req.Temperature != nil {
|
||||
bifrostReq.Params.Temperature = req.Temperature
|
||||
}
|
||||
if req.P != nil {
|
||||
bifrostReq.Params.TopP = req.P
|
||||
}
|
||||
if req.StopSequences != nil {
|
||||
bifrostReq.Params.Stop = req.StopSequences
|
||||
}
|
||||
if req.FrequencyPenalty != nil {
|
||||
bifrostReq.Params.FrequencyPenalty = req.FrequencyPenalty
|
||||
}
|
||||
if req.PresencePenalty != nil {
|
||||
bifrostReq.Params.PresencePenalty = req.PresencePenalty
|
||||
}
|
||||
|
||||
// Convert reasoning
|
||||
if req.Thinking != nil {
|
||||
if req.Thinking.Type == ThinkingTypeDisabled {
|
||||
bifrostReq.Params.Reasoning = &schemas.ChatReasoning{
|
||||
Effort: schemas.Ptr("none"),
|
||||
}
|
||||
} else {
|
||||
bifrostReq.Params.Reasoning = &schemas.ChatReasoning{
|
||||
Effort: schemas.Ptr("auto"),
|
||||
}
|
||||
if req.Thinking.TokenBudget != nil {
|
||||
bifrostReq.Params.Reasoning.MaxTokens = req.Thinking.TokenBudget
|
||||
}
|
||||
}
|
||||
}
|
||||
if req.ResponseFormat != nil {
|
||||
bifrostReq.Params.ResponseFormat = convertCohereResponseFormatToBifrost(req.ResponseFormat)
|
||||
}
|
||||
|
||||
// Convert tools
|
||||
if req.Tools != nil {
|
||||
bifrostTools := make([]schemas.ChatTool, len(req.Tools))
|
||||
for i, tool := range req.Tools {
|
||||
bifrostTools[i] = schemas.ChatTool{
|
||||
Type: schemas.ChatToolTypeFunction,
|
||||
Function: &schemas.ChatToolFunction{
|
||||
Name: tool.Function.Name,
|
||||
Description: tool.Function.Description,
|
||||
Parameters: convertInterfaceToToolFunctionParameters(tool.Function.Parameters),
|
||||
},
|
||||
}
|
||||
}
|
||||
bifrostReq.Params.Tools = bifrostTools
|
||||
}
|
||||
|
||||
// Convert tool choice
|
||||
if req.ToolChoice != nil {
|
||||
switch *req.ToolChoice {
|
||||
case ToolChoiceNone:
|
||||
bifrostReq.Params.ToolChoice = &schemas.ChatToolChoice{
|
||||
ChatToolChoiceStr: schemas.Ptr(string(schemas.ChatToolChoiceTypeNone)),
|
||||
}
|
||||
case ToolChoiceRequired:
|
||||
bifrostReq.Params.ToolChoice = &schemas.ChatToolChoice{
|
||||
ChatToolChoiceStr: schemas.Ptr(string(schemas.ChatToolChoiceTypeRequired)),
|
||||
}
|
||||
case ToolChoiceAuto:
|
||||
bifrostReq.Params.ToolChoice = &schemas.ChatToolChoice{
|
||||
ChatToolChoiceStr: schemas.Ptr(string(schemas.ChatToolChoiceTypeAny)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Convert extra params
|
||||
extraParams := make(map[string]interface{})
|
||||
if req.SafetyMode != nil {
|
||||
extraParams["safety_mode"] = *req.SafetyMode
|
||||
}
|
||||
if req.LogProbs != nil {
|
||||
extraParams["log_probs"] = *req.LogProbs
|
||||
}
|
||||
if req.StrictToolChoice != nil {
|
||||
extraParams["strict_tool_choice"] = *req.StrictToolChoice
|
||||
}
|
||||
if req.Thinking != nil {
|
||||
thinkingMap := map[string]interface{}{
|
||||
"type": string(req.Thinking.Type),
|
||||
}
|
||||
if req.Thinking.TokenBudget != nil {
|
||||
thinkingMap["token_budget"] = *req.Thinking.TokenBudget
|
||||
}
|
||||
extraParams["thinking"] = thinkingMap
|
||||
}
|
||||
if len(extraParams) > 0 {
|
||||
bifrostReq.Params.ExtraParams = extraParams
|
||||
}
|
||||
|
||||
return bifrostReq
|
||||
}
|
||||
|
||||
// ToBifrostChatResponse converts a Cohere v2 response to Bifrost format
|
||||
func (response *CohereChatResponse) ToBifrostChatResponse(model string) *schemas.BifrostChatResponse {
|
||||
if response == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
bifrostResponse := &schemas.BifrostChatResponse{
|
||||
ID: response.ID,
|
||||
Model: model,
|
||||
Object: "chat.completion",
|
||||
Choices: []schemas.BifrostResponseChoice{
|
||||
{
|
||||
Index: 0,
|
||||
ChatNonStreamResponseChoice: &schemas.ChatNonStreamResponseChoice{},
|
||||
},
|
||||
},
|
||||
Created: int(time.Now().Unix()),
|
||||
ExtraFields: schemas.BifrostResponseExtraFields{
|
||||
},
|
||||
}
|
||||
|
||||
// Convert messages
|
||||
if response.Message != nil {
|
||||
bifrostMessage := response.Message.ToBifrostChatMessage()
|
||||
bifrostResponse.Choices[0].ChatNonStreamResponseChoice.Message = bifrostMessage
|
||||
}
|
||||
|
||||
// Convert finish reason
|
||||
if response.FinishReason != nil {
|
||||
finishReason := ConvertCohereFinishReasonToBifrost(*response.FinishReason)
|
||||
bifrostResponse.Choices[0].FinishReason = schemas.Ptr(finishReason)
|
||||
}
|
||||
|
||||
// Convert usage information
|
||||
if response.Usage != nil {
|
||||
usage := &schemas.BifrostLLMUsage{}
|
||||
|
||||
if response.Usage.Tokens != nil {
|
||||
if response.Usage.Tokens.InputTokens != nil {
|
||||
usage.PromptTokens = *response.Usage.Tokens.InputTokens
|
||||
}
|
||||
if response.Usage.Tokens.OutputTokens != nil {
|
||||
usage.CompletionTokens = *response.Usage.Tokens.OutputTokens
|
||||
}
|
||||
if response.Usage.CachedTokens != nil {
|
||||
usage.PromptTokensDetails = &schemas.ChatPromptTokensDetails{
|
||||
CachedReadTokens: *response.Usage.CachedTokens,
|
||||
}
|
||||
}
|
||||
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
|
||||
}
|
||||
|
||||
bifrostResponse.Usage = usage
|
||||
}
|
||||
|
||||
return bifrostResponse
|
||||
}
|
||||
|
||||
func (chunk *CohereStreamEvent) ToBifrostChatCompletionStream() (*schemas.BifrostChatResponse, *schemas.BifrostError, bool) {
|
||||
switch chunk.Type {
|
||||
case StreamEventMessageStart:
|
||||
if chunk.Delta != nil && chunk.Delta.Message != nil && chunk.Delta.Message.Role != nil {
|
||||
// Create streaming response for this delta
|
||||
streamResponse := &schemas.BifrostChatResponse{
|
||||
Object: "chat.completion.chunk",
|
||||
Choices: []schemas.BifrostResponseChoice{
|
||||
{
|
||||
Index: 0,
|
||||
ChatStreamResponseChoice: &schemas.ChatStreamResponseChoice{
|
||||
Delta: &schemas.ChatStreamResponseChoiceDelta{
|
||||
Role: chunk.Delta.Message.Role,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
return streamResponse, nil, false
|
||||
}
|
||||
|
||||
case StreamEventContentDelta:
|
||||
if chunk.Delta != nil &&
|
||||
chunk.Delta.Message != nil &&
|
||||
chunk.Delta.Message.Content != nil &&
|
||||
chunk.Delta.Message.Content.CohereStreamContentObject != nil {
|
||||
if chunk.Delta.Message.Content.CohereStreamContentObject.Text != nil {
|
||||
// Try to cast content to CohereStreamContent
|
||||
streamResponse := &schemas.BifrostChatResponse{
|
||||
Object: "chat.completion.chunk",
|
||||
Choices: []schemas.BifrostResponseChoice{
|
||||
{
|
||||
Index: 0,
|
||||
ChatStreamResponseChoice: &schemas.ChatStreamResponseChoice{
|
||||
Delta: &schemas.ChatStreamResponseChoiceDelta{
|
||||
Content: chunk.Delta.Message.Content.CohereStreamContentObject.Text,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
return streamResponse, nil, false
|
||||
} else if chunk.Delta.Message.Content.CohereStreamContentObject.Thinking != nil {
|
||||
thinkingText := *chunk.Delta.Message.Content.CohereStreamContentObject.Thinking
|
||||
streamResponse := &schemas.BifrostChatResponse{
|
||||
Object: "chat.completion.chunk",
|
||||
Choices: []schemas.BifrostResponseChoice{
|
||||
{
|
||||
Index: 0,
|
||||
ChatStreamResponseChoice: &schemas.ChatStreamResponseChoice{
|
||||
Delta: &schemas.ChatStreamResponseChoiceDelta{
|
||||
Reasoning: schemas.Ptr(thinkingText),
|
||||
ReasoningDetails: []schemas.ChatReasoningDetails{
|
||||
{
|
||||
Index: 0,
|
||||
Type: schemas.BifrostReasoningDetailsTypeText,
|
||||
Text: schemas.Ptr(thinkingText),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
return streamResponse, nil, false
|
||||
}
|
||||
}
|
||||
|
||||
case StreamEventToolPlanDelta:
|
||||
if chunk.Delta != nil && chunk.Delta.Message != nil && chunk.Delta.Message.ToolPlan != nil {
|
||||
streamResponse := &schemas.BifrostChatResponse{
|
||||
Object: "chat.completion.chunk",
|
||||
Choices: []schemas.BifrostResponseChoice{
|
||||
{
|
||||
Index: 0,
|
||||
ChatStreamResponseChoice: &schemas.ChatStreamResponseChoice{
|
||||
Delta: &schemas.ChatStreamResponseChoiceDelta{
|
||||
Reasoning: chunk.Delta.Message.ToolPlan,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
return streamResponse, nil, false
|
||||
}
|
||||
|
||||
case StreamEventContentStart:
|
||||
// Content start event - just continue, actual content comes in content-delta
|
||||
return nil, nil, false
|
||||
|
||||
case StreamEventToolCallStart, StreamEventToolCallDelta:
|
||||
if chunk.Delta != nil && chunk.Delta.Message != nil && chunk.Delta.Message.ToolCalls != nil && chunk.Delta.Message.ToolCalls.CohereToolCallObject != nil {
|
||||
// Handle single tool call object (tool-call-start/delta events)
|
||||
cohereToolCall := chunk.Delta.Message.ToolCalls.CohereToolCallObject
|
||||
toolCall := schemas.ChatAssistantMessageToolCall{}
|
||||
|
||||
if chunk.Index != nil {
|
||||
toolCall.Index = uint16(*chunk.Index)
|
||||
}
|
||||
|
||||
if cohereToolCall.ID != nil {
|
||||
toolCall.ID = cohereToolCall.ID
|
||||
}
|
||||
|
||||
if cohereToolCall.Function != nil {
|
||||
if cohereToolCall.Function.Name != nil {
|
||||
toolCall.Function.Name = cohereToolCall.Function.Name
|
||||
}
|
||||
toolCall.Function.Arguments = cohereToolCall.Function.Arguments
|
||||
}
|
||||
|
||||
streamResponse := &schemas.BifrostChatResponse{
|
||||
Object: "chat.completion.chunk",
|
||||
Choices: []schemas.BifrostResponseChoice{
|
||||
{
|
||||
Index: 0,
|
||||
ChatStreamResponseChoice: &schemas.ChatStreamResponseChoice{
|
||||
Delta: &schemas.ChatStreamResponseChoiceDelta{
|
||||
ToolCalls: []schemas.ChatAssistantMessageToolCall{toolCall},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
return streamResponse, nil, false
|
||||
}
|
||||
|
||||
case StreamEventToolCallEnd:
|
||||
return nil, nil, false
|
||||
|
||||
case StreamEventContentEnd:
|
||||
return nil, nil, false
|
||||
|
||||
case StreamEventMessageEnd:
|
||||
if chunk.Delta != nil {
|
||||
var finishReason string
|
||||
usage := &schemas.BifrostLLMUsage{}
|
||||
// Set finish reason
|
||||
if chunk.Delta.FinishReason != nil {
|
||||
finishReason = ConvertCohereFinishReasonToBifrost(*chunk.Delta.FinishReason)
|
||||
}
|
||||
|
||||
// Set usage information
|
||||
if chunk.Delta.Usage != nil {
|
||||
if chunk.Delta.Usage.Tokens != nil {
|
||||
if chunk.Delta.Usage.Tokens.InputTokens != nil {
|
||||
usage.PromptTokens = *chunk.Delta.Usage.Tokens.InputTokens
|
||||
}
|
||||
if chunk.Delta.Usage.Tokens.OutputTokens != nil {
|
||||
usage.CompletionTokens = *chunk.Delta.Usage.Tokens.OutputTokens
|
||||
}
|
||||
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
|
||||
}
|
||||
}
|
||||
|
||||
streamResponse := &schemas.BifrostChatResponse{
|
||||
Object: "chat.completion.chunk",
|
||||
Choices: []schemas.BifrostResponseChoice{
|
||||
{
|
||||
Index: 0,
|
||||
FinishReason: &finishReason,
|
||||
ChatStreamResponseChoice: &schemas.ChatStreamResponseChoice{
|
||||
Delta: &schemas.ChatStreamResponseChoiceDelta{},
|
||||
},
|
||||
},
|
||||
},
|
||||
Usage: usage,
|
||||
}
|
||||
|
||||
return streamResponse, nil, true
|
||||
}
|
||||
return nil, nil, false
|
||||
}
|
||||
|
||||
return nil, nil, false
|
||||
}
|
||||
|
||||
func (cm *CohereMessage) ToBifrostChatMessage() *schemas.ChatMessage {
|
||||
if cm == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var content *string
|
||||
var contentBlocks []schemas.ChatContentBlock
|
||||
var toolCalls []schemas.ChatAssistantMessageToolCall
|
||||
var reasoningDetails []schemas.ChatReasoningDetails
|
||||
var reasoningText string
|
||||
|
||||
// Convert message content
|
||||
if cm.Content != nil {
|
||||
if cm.Content.IsString() ||
|
||||
(cm.Content.IsBlocks() &&
|
||||
len(cm.Content.GetBlocks()) == 1 &&
|
||||
cm.Content.GetBlocks()[0].Type == CohereContentBlockTypeText) {
|
||||
if cm.Content.IsString() {
|
||||
content = cm.Content.GetString()
|
||||
} else {
|
||||
content = cm.Content.GetBlocks()[0].Text
|
||||
}
|
||||
} else if cm.Content.IsBlocks() {
|
||||
for _, block := range cm.Content.GetBlocks() {
|
||||
if block.Type == CohereContentBlockTypeText && block.Text != nil {
|
||||
contentBlocks = append(contentBlocks, schemas.ChatContentBlock{
|
||||
Type: schemas.ChatContentBlockTypeText,
|
||||
Text: block.Text,
|
||||
})
|
||||
} else if block.Type == CohereContentBlockTypeImage && block.ImageURL != nil {
|
||||
contentBlocks = append(contentBlocks, schemas.ChatContentBlock{
|
||||
Type: schemas.ChatContentBlockTypeImage,
|
||||
ImageURLStruct: &schemas.ChatInputImage{
|
||||
URL: block.ImageURL.URL,
|
||||
},
|
||||
})
|
||||
} else if block.Type == CohereContentBlockTypeThinking && block.Thinking != nil {
|
||||
reasoningDetails = append(reasoningDetails, schemas.ChatReasoningDetails{
|
||||
Index: len(reasoningDetails),
|
||||
Type: schemas.BifrostReasoningDetailsTypeText,
|
||||
Text: block.Thinking,
|
||||
})
|
||||
if len(reasoningText) > 0 {
|
||||
reasoningText += "\n"
|
||||
}
|
||||
reasoningText += *block.Thinking
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(contentBlocks) == 1 && contentBlocks[0].Type == schemas.ChatContentBlockTypeText {
|
||||
content = contentBlocks[0].Text
|
||||
contentBlocks = nil
|
||||
}
|
||||
|
||||
// Create the message content
|
||||
messageContent := &schemas.ChatMessageContent{
|
||||
ContentStr: content,
|
||||
ContentBlocks: contentBlocks,
|
||||
}
|
||||
|
||||
// Convert tool calls
|
||||
if cm.ToolCalls != nil {
|
||||
for _, toolCall := range cm.ToolCalls {
|
||||
// Check if Function is nil to avoid nil pointer dereference
|
||||
if toolCall.Function == nil {
|
||||
// Skip this tool call if Function is nil
|
||||
continue
|
||||
}
|
||||
|
||||
// Safely extract function name and arguments
|
||||
var functionName *string
|
||||
var functionArguments string
|
||||
|
||||
if toolCall.Function.Name != nil {
|
||||
functionName = toolCall.Function.Name
|
||||
} else {
|
||||
// Use empty string if Name is nil
|
||||
functionName = schemas.Ptr("")
|
||||
}
|
||||
|
||||
// Arguments is a string, not a pointer, so it's safe to access directly
|
||||
functionArguments = toolCall.Function.Arguments
|
||||
|
||||
bifrostToolCall := schemas.ChatAssistantMessageToolCall{
|
||||
Index: uint16(len(toolCalls)),
|
||||
ID: toolCall.ID,
|
||||
Function: schemas.ChatAssistantMessageToolCallFunction{
|
||||
Name: functionName,
|
||||
Arguments: functionArguments,
|
||||
},
|
||||
}
|
||||
toolCalls = append(toolCalls, bifrostToolCall)
|
||||
}
|
||||
}
|
||||
|
||||
// Create assistant message if we have tool calls
|
||||
var assistantMessage *schemas.ChatAssistantMessage
|
||||
if len(toolCalls) > 0 {
|
||||
assistantMessage = &schemas.ChatAssistantMessage{
|
||||
ToolCalls: toolCalls,
|
||||
}
|
||||
}
|
||||
|
||||
if len(reasoningDetails) > 0 {
|
||||
if assistantMessage == nil {
|
||||
assistantMessage = &schemas.ChatAssistantMessage{}
|
||||
}
|
||||
assistantMessage.ReasoningDetails = reasoningDetails
|
||||
assistantMessage.Reasoning = schemas.Ptr(reasoningText)
|
||||
}
|
||||
|
||||
bifrostMessage := &schemas.ChatMessage{
|
||||
Role: schemas.ChatMessageRole(cm.Role),
|
||||
Content: messageContent,
|
||||
ChatAssistantMessage: assistantMessage,
|
||||
}
|
||||
|
||||
if cm.Role == "tool" {
|
||||
bifrostMessage.ChatToolMessage = &schemas.ChatToolMessage{
|
||||
ToolCallID: cm.ToolCallID,
|
||||
}
|
||||
}
|
||||
return bifrostMessage
|
||||
}
|
||||
1266
core/providers/cohere/cohere.go
Normal file
1266
core/providers/cohere/cohere.go
Normal file
File diff suppressed because it is too large
Load Diff
62
core/providers/cohere/cohere_test.go
Normal file
62
core/providers/cohere/cohere_test.go
Normal file
@@ -0,0 +1,62 @@
|
||||
package cohere_test
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/maximhq/bifrost/core/internal/llmtests"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
func TestCohere(t *testing.T) {
|
||||
t.Parallel()
|
||||
if strings.TrimSpace(os.Getenv("COHERE_API_KEY")) == "" {
|
||||
t.Skip("Skipping Cohere tests because COHERE_API_KEY is not set")
|
||||
}
|
||||
|
||||
client, ctx, cancel, err := llmtests.SetupTest()
|
||||
if err != nil {
|
||||
t.Fatalf("Error initializing test setup: %v", err)
|
||||
}
|
||||
defer cancel()
|
||||
defer client.Shutdown()
|
||||
|
||||
testConfig := llmtests.ComprehensiveTestConfig{
|
||||
Provider: schemas.Cohere,
|
||||
ChatModel: "command-a-03-2025",
|
||||
VisionModel: "command-a-vision-07-2025", // Cohere's latest vision model
|
||||
TextModel: "", // Cohere focuses on chat
|
||||
EmbeddingModel: "embed-v4.0",
|
||||
RerankModel: "rerank-v3.5",
|
||||
ReasoningModel: "command-a-reasoning-08-2025",
|
||||
Scenarios: llmtests.TestScenarios{
|
||||
TextCompletion: false, // Not typical for Cohere
|
||||
SimpleChat: true,
|
||||
CompletionStream: true,
|
||||
MultiTurnConversation: true,
|
||||
ToolCalls: true,
|
||||
ToolCallsStreaming: true,
|
||||
MultipleToolCalls: true,
|
||||
MultipleToolCallsStreaming: true,
|
||||
End2EndToolCalling: true,
|
||||
AutomaticFunctionCall: true, // May not support automatic
|
||||
ImageURL: false, // Supported by c4ai-aya-vision-8b model
|
||||
ImageBase64: true, // Supported by c4ai-aya-vision-8b model
|
||||
MultipleImages: false, // Supported by c4ai-aya-vision-8b model
|
||||
FileBase64: false, // Not supported
|
||||
FileURL: false, // Not supported
|
||||
CompleteEnd2End: false,
|
||||
Embedding: true,
|
||||
Rerank: true,
|
||||
Reasoning: true,
|
||||
ListModels: true,
|
||||
CountTokens: true,
|
||||
},
|
||||
}
|
||||
|
||||
t.Run("CohereTests", func(t *testing.T) {
|
||||
llmtests.RunAllComprehensiveTests(t, client, ctx, testConfig)
|
||||
})
|
||||
}
|
||||
125
core/providers/cohere/count_tokens.go
Normal file
125
core/providers/cohere/count_tokens.go
Normal file
@@ -0,0 +1,125 @@
|
||||
package cohere
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/maximhq/bifrost/core/providers/utils"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// ToBifrostResponsesRequest converts a Cohere count tokens request to Bifrost format.
|
||||
func (req *CohereCountTokensRequest) ToBifrostResponsesRequest(ctx *schemas.BifrostContext) *schemas.BifrostResponsesRequest {
|
||||
if req == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
provider, model := schemas.ParseModelString(req.Model, utils.CheckAndSetDefaultProvider(ctx, schemas.Cohere))
|
||||
|
||||
userRole := schemas.ResponsesInputMessageRoleUser
|
||||
return &schemas.BifrostResponsesRequest{
|
||||
Provider: provider,
|
||||
Model: model,
|
||||
Input: []schemas.ResponsesMessage{
|
||||
{
|
||||
Role: &userRole,
|
||||
Content: &schemas.ResponsesMessageContent{
|
||||
ContentStr: &req.Text,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// ToCohereCountTokensRequest converts a Bifrost count tokens request to Cohere's tokenize payload.
|
||||
func ToCohereCountTokensRequest(bifrostReq *schemas.BifrostResponsesRequest) (*CohereCountTokensRequest, error) {
|
||||
if bifrostReq == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if bifrostReq.Input == nil {
|
||||
return nil, fmt.Errorf("count tokens input is not provided")
|
||||
}
|
||||
|
||||
text := buildCohereCountTokensText(bifrostReq.Input)
|
||||
trimmed := strings.TrimSpace(text)
|
||||
if trimmed == "" {
|
||||
return nil, fmt.Errorf("count tokens text is empty after conversion")
|
||||
}
|
||||
runeCount := utf8.RuneCountInString(trimmed)
|
||||
if runeCount < cohereTokenizeMinTextLength || runeCount > cohereTokenizeMaxTextLength {
|
||||
return nil, fmt.Errorf("count tokens text length must be between %d and %d characters", cohereTokenizeMinTextLength, cohereTokenizeMaxTextLength)
|
||||
}
|
||||
|
||||
cohereReq := &CohereCountTokensRequest{
|
||||
Model: bifrostReq.Model,
|
||||
Text: trimmed,
|
||||
}
|
||||
if bifrostReq.Params != nil {
|
||||
cohereReq.ExtraParams = bifrostReq.Params.ExtraParams
|
||||
}
|
||||
|
||||
return cohereReq, nil
|
||||
}
|
||||
|
||||
// ToBifrostCountTokensResponse converts a Cohere tokenize response to Bifrost format.
|
||||
func (resp *CohereCountTokensResponse) ToBifrostCountTokensResponse(model string) *schemas.BifrostCountTokensResponse {
|
||||
if resp == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
inputTokens := len(resp.Tokens)
|
||||
if inputTokens == 0 && len(resp.TokenStrings) > 0 {
|
||||
inputTokens = len(resp.TokenStrings)
|
||||
}
|
||||
totalTokens := inputTokens
|
||||
|
||||
return &schemas.BifrostCountTokensResponse{
|
||||
Model: model,
|
||||
InputTokens: inputTokens,
|
||||
TotalTokens: &totalTokens,
|
||||
TokenStrings: resp.TokenStrings,
|
||||
Tokens: resp.Tokens,
|
||||
Object: "response.input_tokens",
|
||||
}
|
||||
}
|
||||
|
||||
// buildCohereCountTokensText flattens Responses messages into a plain text payload for tokenization.
|
||||
func buildCohereCountTokensText(messages []schemas.ResponsesMessage) string {
|
||||
var parts []string
|
||||
|
||||
for _, msg := range messages {
|
||||
var contentParts []string
|
||||
|
||||
if msg.Content != nil {
|
||||
if msg.Content.ContentStr != nil {
|
||||
contentParts = append(contentParts, *msg.Content.ContentStr)
|
||||
}
|
||||
for _, block := range msg.Content.ContentBlocks {
|
||||
if block.Text != nil {
|
||||
contentParts = append(contentParts, *block.Text)
|
||||
}
|
||||
if block.ResponsesOutputMessageContentRefusal != nil && block.ResponsesOutputMessageContentRefusal.Refusal != "" {
|
||||
contentParts = append(contentParts, block.ResponsesOutputMessageContentRefusal.Refusal)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if msg.ResponsesReasoning != nil {
|
||||
for _, summary := range msg.ResponsesReasoning.Summary {
|
||||
if summary.Text != "" {
|
||||
contentParts = append(contentParts, summary.Text)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(contentParts) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
parts = append(parts, strings.Join(contentParts, "\n"))
|
||||
}
|
||||
|
||||
return strings.TrimSpace(strings.Join(parts, "\n"))
|
||||
}
|
||||
190
core/providers/cohere/embedding.go
Normal file
190
core/providers/cohere/embedding.go
Normal file
@@ -0,0 +1,190 @@
|
||||
package cohere
|
||||
|
||||
import (
|
||||
"github.com/maximhq/bifrost/core/providers/utils"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// ToCohereEmbeddingRequest converts a Bifrost embedding request to Cohere format
|
||||
func ToCohereEmbeddingRequest(bifrostReq *schemas.BifrostEmbeddingRequest) *CohereEmbeddingRequest {
|
||||
if bifrostReq == nil || bifrostReq.Input == nil || (bifrostReq.Input.Text == nil && bifrostReq.Input.Texts == nil) {
|
||||
return nil
|
||||
}
|
||||
|
||||
embeddingInput := bifrostReq.Input
|
||||
cohereReq := &CohereEmbeddingRequest{
|
||||
Model: bifrostReq.Model,
|
||||
}
|
||||
|
||||
texts := []string{}
|
||||
if embeddingInput.Text != nil {
|
||||
texts = append(texts, *embeddingInput.Text)
|
||||
} else {
|
||||
texts = embeddingInput.Texts
|
||||
}
|
||||
|
||||
// Convert texts from Bifrost format
|
||||
if len(texts) > 0 {
|
||||
cohereReq.Texts = texts
|
||||
}
|
||||
|
||||
// Set default input type if not specified in extra params
|
||||
cohereReq.InputType = "search_document" // Default value
|
||||
|
||||
if bifrostReq.Params != nil {
|
||||
cohereReq.OutputDimension = bifrostReq.Params.Dimensions
|
||||
cohereReq.ExtraParams = bifrostReq.Params.ExtraParams
|
||||
if bifrostReq.Params.ExtraParams != nil {
|
||||
if maxTokens, ok := schemas.SafeExtractIntPointer(bifrostReq.Params.ExtraParams["max_tokens"]); ok {
|
||||
delete(cohereReq.ExtraParams, "max_tokens")
|
||||
cohereReq.MaxTokens = maxTokens
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Handle extra params
|
||||
if bifrostReq.Params != nil && bifrostReq.Params.ExtraParams != nil {
|
||||
// Input type
|
||||
if inputType, ok := schemas.SafeExtractString(bifrostReq.Params.ExtraParams["input_type"]); ok {
|
||||
delete(cohereReq.ExtraParams, "input_type")
|
||||
cohereReq.InputType = inputType
|
||||
}
|
||||
|
||||
// Embedding types
|
||||
if embeddingTypes, ok := schemas.SafeExtractStringSlice(bifrostReq.Params.ExtraParams["embedding_types"]); ok {
|
||||
if len(embeddingTypes) > 0 {
|
||||
delete(cohereReq.ExtraParams, "embedding_types")
|
||||
cohereReq.EmbeddingTypes = embeddingTypes
|
||||
}
|
||||
}
|
||||
|
||||
// Truncate
|
||||
if truncate, ok := schemas.SafeExtractStringPointer(bifrostReq.Params.ExtraParams["truncate"]); ok {
|
||||
delete(cohereReq.ExtraParams, "truncate")
|
||||
cohereReq.Truncate = truncate
|
||||
}
|
||||
}
|
||||
|
||||
return cohereReq
|
||||
}
|
||||
|
||||
// ToBifrostEmbeddingRequest converts a Cohere embedding request to Bifrost format
|
||||
func (req *CohereEmbeddingRequest) ToBifrostEmbeddingRequest(ctx *schemas.BifrostContext) *schemas.BifrostEmbeddingRequest {
|
||||
if req == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
provider, model := schemas.ParseModelString(req.Model, utils.CheckAndSetDefaultProvider(ctx, schemas.Cohere))
|
||||
|
||||
bifrostReq := &schemas.BifrostEmbeddingRequest{
|
||||
Provider: provider,
|
||||
Model: model,
|
||||
Input: &schemas.EmbeddingInput{},
|
||||
Params: &schemas.EmbeddingParameters{},
|
||||
}
|
||||
|
||||
// Convert texts
|
||||
if len(req.Texts) > 0 {
|
||||
if len(req.Texts) == 1 {
|
||||
bifrostReq.Input.Text = &req.Texts[0]
|
||||
} else {
|
||||
bifrostReq.Input.Texts = req.Texts
|
||||
}
|
||||
}
|
||||
|
||||
// Convert parameters
|
||||
if req.OutputDimension != nil {
|
||||
bifrostReq.Params.Dimensions = req.OutputDimension
|
||||
}
|
||||
|
||||
// Convert extra params
|
||||
extraParams := make(map[string]interface{})
|
||||
if req.InputType != "" {
|
||||
extraParams["input_type"] = req.InputType
|
||||
}
|
||||
if req.EmbeddingTypes != nil {
|
||||
extraParams["embedding_types"] = req.EmbeddingTypes
|
||||
}
|
||||
if req.Truncate != nil {
|
||||
extraParams["truncate"] = *req.Truncate
|
||||
}
|
||||
if req.MaxTokens != nil {
|
||||
extraParams["max_tokens"] = *req.MaxTokens
|
||||
}
|
||||
if len(extraParams) > 0 {
|
||||
bifrostReq.Params.ExtraParams = extraParams
|
||||
}
|
||||
|
||||
return bifrostReq
|
||||
}
|
||||
|
||||
// ToBifrostEmbeddingResponse converts a Cohere embedding response to Bifrost format
|
||||
func (response *CohereEmbeddingResponse) ToBifrostEmbeddingResponse() *schemas.BifrostEmbeddingResponse {
|
||||
if response == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
bifrostResponse := &schemas.BifrostEmbeddingResponse{
|
||||
Object: "list",
|
||||
}
|
||||
|
||||
// Convert embeddings data
|
||||
if response.Embeddings != nil {
|
||||
var bifrostEmbeddings []schemas.EmbeddingData
|
||||
|
||||
// Handle different embedding types - prioritize float embeddings
|
||||
if response.Embeddings.Float != nil {
|
||||
for i, embedding := range response.Embeddings.Float {
|
||||
bifrostEmbedding := schemas.EmbeddingData{
|
||||
Object: "embedding",
|
||||
Index: i,
|
||||
Embedding: schemas.EmbeddingStruct{
|
||||
EmbeddingArray: embedding,
|
||||
},
|
||||
}
|
||||
bifrostEmbeddings = append(bifrostEmbeddings, bifrostEmbedding)
|
||||
}
|
||||
} else if response.Embeddings.Base64 != nil {
|
||||
// Handle base64 embeddings as strings
|
||||
for i, embedding := range response.Embeddings.Base64 {
|
||||
bifrostEmbedding := schemas.EmbeddingData{
|
||||
Object: "embedding",
|
||||
Index: i,
|
||||
Embedding: schemas.EmbeddingStruct{
|
||||
EmbeddingStr: &embedding,
|
||||
},
|
||||
}
|
||||
bifrostEmbeddings = append(bifrostEmbeddings, bifrostEmbedding)
|
||||
}
|
||||
}
|
||||
// Note: Int8, Uint8, Binary, Ubinary types would need special handling
|
||||
// depending on how Bifrost wants to represent them
|
||||
|
||||
bifrostResponse.Data = bifrostEmbeddings
|
||||
}
|
||||
|
||||
// Convert usage information
|
||||
if response.Meta != nil {
|
||||
if response.Meta.Tokens != nil {
|
||||
bifrostResponse.Usage = &schemas.BifrostLLMUsage{}
|
||||
if response.Meta.Tokens.InputTokens != nil {
|
||||
bifrostResponse.Usage.PromptTokens = int(*response.Meta.Tokens.InputTokens)
|
||||
}
|
||||
if response.Meta.Tokens.OutputTokens != nil {
|
||||
bifrostResponse.Usage.CompletionTokens = int(*response.Meta.Tokens.OutputTokens)
|
||||
}
|
||||
bifrostResponse.Usage.TotalTokens = bifrostResponse.Usage.PromptTokens + bifrostResponse.Usage.CompletionTokens
|
||||
} else if response.Meta.BilledUnits != nil {
|
||||
bifrostResponse.Usage = &schemas.BifrostLLMUsage{}
|
||||
if response.Meta.BilledUnits.InputTokens != nil {
|
||||
bifrostResponse.Usage.PromptTokens = int(*response.Meta.BilledUnits.InputTokens)
|
||||
}
|
||||
if response.Meta.BilledUnits.OutputTokens != nil {
|
||||
bifrostResponse.Usage.CompletionTokens = int(*response.Meta.BilledUnits.OutputTokens)
|
||||
}
|
||||
bifrostResponse.Usage.TotalTokens = bifrostResponse.Usage.PromptTokens + bifrostResponse.Usage.CompletionTokens
|
||||
}
|
||||
}
|
||||
|
||||
return bifrostResponse
|
||||
}
|
||||
90
core/providers/cohere/embedding_test.go
Normal file
90
core/providers/cohere/embedding_test.go
Normal file
@@ -0,0 +1,90 @@
|
||||
package cohere
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestToCohereEmbeddingRequest(t *testing.T) {
|
||||
t.Run("returns nil for missing input", func(t *testing.T) {
|
||||
assert.Nil(t, ToCohereEmbeddingRequest(nil))
|
||||
assert.Nil(t, ToCohereEmbeddingRequest(&schemas.BifrostEmbeddingRequest{}))
|
||||
assert.Nil(t, ToCohereEmbeddingRequest(&schemas.BifrostEmbeddingRequest{
|
||||
Input: &schemas.EmbeddingInput{},
|
||||
}))
|
||||
})
|
||||
|
||||
t.Run("single text keeps model in direct cohere body", func(t *testing.T) {
|
||||
text := "hello"
|
||||
truncate := "END"
|
||||
dimensions := 1024
|
||||
maxTokens := 256
|
||||
bifrostReq := &schemas.BifrostEmbeddingRequest{
|
||||
Model: "embed-v4.0",
|
||||
Input: &schemas.EmbeddingInput{Text: &text},
|
||||
Params: &schemas.EmbeddingParameters{
|
||||
Dimensions: &dimensions,
|
||||
ExtraParams: map[string]interface{}{
|
||||
"input_type": "classification",
|
||||
"embedding_types": []string{"float", "int8"},
|
||||
"truncate": truncate,
|
||||
"max_tokens": maxTokens,
|
||||
"priority": "high",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
req := ToCohereEmbeddingRequest(bifrostReq)
|
||||
require.NotNil(t, req)
|
||||
assert.Equal(t, "embed-v4.0", req.Model)
|
||||
assert.Equal(t, "classification", req.InputType)
|
||||
assert.Equal(t, []string{"hello"}, req.Texts)
|
||||
assert.Equal(t, []string{"float", "int8"}, req.EmbeddingTypes)
|
||||
assert.Equal(t, &dimensions, req.OutputDimension)
|
||||
assert.Equal(t, &maxTokens, req.MaxTokens)
|
||||
require.NotNil(t, req.Truncate)
|
||||
assert.Equal(t, truncate, *req.Truncate)
|
||||
assert.Equal(t, map[string]interface{}{"priority": "high"}, req.ExtraParams)
|
||||
})
|
||||
|
||||
t.Run("multiple texts use default input type", func(t *testing.T) {
|
||||
req := ToCohereEmbeddingRequest(&schemas.BifrostEmbeddingRequest{
|
||||
Model: "embed-english-v3.0",
|
||||
Input: &schemas.EmbeddingInput{Texts: []string{"hello", "world"}},
|
||||
})
|
||||
|
||||
require.NotNil(t, req)
|
||||
assert.Equal(t, "embed-english-v3.0", req.Model)
|
||||
assert.Equal(t, "search_document", req.InputType)
|
||||
assert.Equal(t, []string{"hello", "world"}, req.Texts)
|
||||
assert.Nil(t, req.ExtraParams)
|
||||
})
|
||||
}
|
||||
|
||||
func TestToCohereEmbeddingRequestBodyIncludesModelForDirectCohere(t *testing.T) {
|
||||
text := "hello"
|
||||
bifrostReq := &schemas.BifrostEmbeddingRequest{
|
||||
Model: "embed-v4.0",
|
||||
Input: &schemas.EmbeddingInput{Text: &text},
|
||||
}
|
||||
|
||||
wireBody, bifrostErr := providerUtils.CheckContextAndGetRequestBody(
|
||||
context.Background(),
|
||||
bifrostReq,
|
||||
func() (providerUtils.RequestBodyWithExtraParams, error) {
|
||||
return ToCohereEmbeddingRequest(bifrostReq), nil
|
||||
},
|
||||
schemas.Cohere,
|
||||
)
|
||||
require.Nil(t, bifrostErr)
|
||||
assert.JSONEq(t, `{
|
||||
"model": "embed-v4.0",
|
||||
"input_type": "search_document",
|
||||
"texts": ["hello"]
|
||||
}`, string(wireBody))
|
||||
}
|
||||
21
core/providers/cohere/errors.go
Normal file
21
core/providers/cohere/errors.go
Normal file
@@ -0,0 +1,21 @@
|
||||
package cohere
|
||||
|
||||
import (
|
||||
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
func parseCohereError(resp *fasthttp.Response) *schemas.BifrostError {
|
||||
var errorResp CohereError
|
||||
bifrostErr := providerUtils.HandleProviderAPIError(resp, &errorResp)
|
||||
bifrostErr.Type = &errorResp.Type
|
||||
if bifrostErr.Error == nil {
|
||||
bifrostErr.Error = &schemas.ErrorField{}
|
||||
}
|
||||
bifrostErr.Error.Message = errorResp.Message
|
||||
if errorResp.Code != nil {
|
||||
bifrostErr.Error.Code = errorResp.Code
|
||||
}
|
||||
return bifrostErr
|
||||
}
|
||||
92
core/providers/cohere/models.go
Normal file
92
core/providers/cohere/models.go
Normal file
@@ -0,0 +1,92 @@
|
||||
package cohere
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
|
||||
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// CohereRerankRequest represents a Cohere rerank API request.
// Model, Query and Documents are always serialized; the pointer fields are
// omitted from the JSON body when nil.
type CohereRerankRequest struct {
	Model           string                 `json:"model"`
	Query           string                 `json:"query"`
	Documents       []string               `json:"documents"`
	TopN            *int                   `json:"top_n,omitempty"`
	MaxTokensPerDoc *int                   `json:"max_tokens_per_doc,omitempty"`
	Priority        *int                   `json:"priority,omitempty"`
	ExtraParams     map[string]interface{} `json:"-"` // excluded from this struct's JSON encoding; exposed via GetExtraParams
}
|
||||
|
||||
// GetExtraParams returns the request's ExtraParams map as-is (it may be nil).
func (r *CohereRerankRequest) GetExtraParams() map[string]interface{} {
	return r.ExtraParams
}
|
||||
|
||||
// CohereRerankResult represents a single result from Cohere rerank.
type CohereRerankResult struct {
	Index          int             `json:"index"`
	RelevanceScore float64         `json:"relevance_score"`
	Document       json.RawMessage `json:"document,omitempty"` // kept as raw JSON; decoding is deferred to the consumer
}
|
||||
|
||||
// CohereRerankResponse represents a Cohere rerank API response.
type CohereRerankResponse struct {
	ID      string               `json:"id"`
	Results []CohereRerankResult `json:"results"`
	Meta    *CohereRerankMeta    `json:"meta,omitempty"` // optional metadata; nil when absent
}
|
||||
|
||||
// CohereRerankMeta represents metadata in Cohere rerank response.
// All fields are optional and are nil when absent from the payload.
type CohereRerankMeta struct {
	APIVersion  *CohereEmbeddingAPIVersion `json:"api_version,omitempty"`
	BilledUnits *CohereBilledUnits         `json:"billed_units,omitempty"`
	Tokens      *CohereTokenUsage          `json:"tokens,omitempty"`
}
|
||||
|
||||
func (response *CohereListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, allowedModels schemas.WhiteList, blacklistedModels schemas.BlackList, aliases map[string]string, unfiltered bool) *schemas.BifrostListModelsResponse {
|
||||
if response == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
bifrostResponse := &schemas.BifrostListModelsResponse{
|
||||
Data: make([]schemas.Model, 0, len(response.Models)),
|
||||
}
|
||||
|
||||
pipeline := &providerUtils.ListModelsPipeline{
|
||||
AllowedModels: allowedModels,
|
||||
BlacklistedModels: blacklistedModels,
|
||||
Aliases: aliases,
|
||||
Unfiltered: unfiltered,
|
||||
ProviderKey: providerKey,
|
||||
MatchFns: providerUtils.DefaultMatchFns(),
|
||||
}
|
||||
if pipeline.ShouldEarlyExit() {
|
||||
return bifrostResponse
|
||||
}
|
||||
|
||||
included := make(map[string]bool)
|
||||
|
||||
for _, model := range response.Models {
|
||||
// Cohere uses model.Name as the model identifier
|
||||
for _, result := range pipeline.FilterModel(model.Name) {
|
||||
entry := schemas.Model{
|
||||
ID: string(providerKey) + "/" + result.ResolvedID,
|
||||
Name: schemas.Ptr(model.Name),
|
||||
ContextLength: schemas.Ptr(int(model.ContextLength)),
|
||||
SupportedMethods: model.Endpoints,
|
||||
}
|
||||
if result.AliasValue != "" {
|
||||
entry.Alias = schemas.Ptr(result.AliasValue)
|
||||
}
|
||||
bifrostResponse.Data = append(bifrostResponse.Data, entry)
|
||||
included[strings.ToLower(result.ResolvedID)] = true
|
||||
}
|
||||
}
|
||||
|
||||
bifrostResponse.Data = append(bifrostResponse.Data,
|
||||
pipeline.BackfillModels(included)...)
|
||||
|
||||
return bifrostResponse
|
||||
}
|
||||
209
core/providers/cohere/rerank.go
Normal file
209
core/providers/cohere/rerank.go
Normal file
@@ -0,0 +1,209 @@
|
||||
package cohere
|
||||
|
||||
import (
|
||||
"sort"
|
||||
|
||||
"github.com/bytedance/sonic"
|
||||
"github.com/maximhq/bifrost/core/providers/utils"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// ToCohereRerankRequest converts a Bifrost rerank request to Cohere format
|
||||
func ToCohereRerankRequest(bifrostReq *schemas.BifrostRerankRequest) *CohereRerankRequest {
|
||||
if bifrostReq == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
cohereReq := &CohereRerankRequest{
|
||||
Model: bifrostReq.Model,
|
||||
Query: bifrostReq.Query,
|
||||
}
|
||||
|
||||
// Cohere v2 expects documents as a list of strings.
|
||||
documents := make([]string, len(bifrostReq.Documents))
|
||||
for i, doc := range bifrostReq.Documents {
|
||||
documents[i] = formatCohereRerankDocument(doc)
|
||||
}
|
||||
cohereReq.Documents = documents
|
||||
|
||||
if bifrostReq.Params != nil {
|
||||
cohereReq.TopN = bifrostReq.Params.TopN
|
||||
cohereReq.MaxTokensPerDoc = bifrostReq.Params.MaxTokensPerDoc
|
||||
cohereReq.Priority = bifrostReq.Params.Priority
|
||||
cohereReq.ExtraParams = bifrostReq.Params.ExtraParams
|
||||
}
|
||||
|
||||
return cohereReq
|
||||
}
|
||||
|
||||
// ToBifrostRerankRequest converts a Cohere rerank request to Bifrost format
|
||||
func (req *CohereRerankRequest) ToBifrostRerankRequest(ctx *schemas.BifrostContext) *schemas.BifrostRerankRequest {
|
||||
if req == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
provider, model := schemas.ParseModelString(req.Model, utils.CheckAndSetDefaultProvider(ctx, schemas.Cohere))
|
||||
|
||||
bifrostReq := &schemas.BifrostRerankRequest{
|
||||
Provider: provider,
|
||||
Model: model,
|
||||
Query: req.Query,
|
||||
Params: &schemas.RerankParameters{},
|
||||
}
|
||||
|
||||
// Convert documents
|
||||
for _, doc := range req.Documents {
|
||||
bifrostReq.Documents = append(bifrostReq.Documents, schemas.RerankDocument{
|
||||
Text: doc,
|
||||
})
|
||||
}
|
||||
|
||||
if req.TopN != nil {
|
||||
bifrostReq.Params.TopN = req.TopN
|
||||
}
|
||||
if req.MaxTokensPerDoc != nil {
|
||||
bifrostReq.Params.MaxTokensPerDoc = req.MaxTokensPerDoc
|
||||
}
|
||||
if req.Priority != nil {
|
||||
bifrostReq.Params.Priority = req.Priority
|
||||
}
|
||||
if req.ExtraParams != nil {
|
||||
bifrostReq.Params.ExtraParams = req.ExtraParams
|
||||
}
|
||||
|
||||
return bifrostReq
|
||||
}
|
||||
|
||||
// ToBifrostRerankResponse converts a Cohere rerank response to Bifrost format.
|
||||
func (response *CohereRerankResponse) ToBifrostRerankResponse(documents []schemas.RerankDocument, returnDocuments bool) *schemas.BifrostRerankResponse {
|
||||
if response == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
bifrostResponse := &schemas.BifrostRerankResponse{
|
||||
ID: response.ID,
|
||||
}
|
||||
|
||||
// Convert results
|
||||
for _, result := range response.Results {
|
||||
rerankResult := schemas.RerankResult{
|
||||
Index: result.Index,
|
||||
RelevanceScore: result.RelevanceScore,
|
||||
}
|
||||
|
||||
// Convert document if present
|
||||
if len(result.Document) > 0 {
|
||||
var docMap map[string]interface{}
|
||||
if err := sonic.Unmarshal(result.Document, &docMap); err == nil {
|
||||
doc := &schemas.RerankDocument{}
|
||||
populated := false
|
||||
if text, ok := docMap["text"].(string); ok {
|
||||
doc.Text = text
|
||||
populated = true
|
||||
}
|
||||
if id, ok := docMap["id"].(string); ok {
|
||||
doc.ID = &id
|
||||
populated = true
|
||||
}
|
||||
// Collect metadata: unwrap "metadata"/"meta" keys to avoid nesting
|
||||
meta := make(map[string]interface{})
|
||||
if rawMeta, ok := docMap["metadata"].(map[string]interface{}); ok {
|
||||
for k, v := range rawMeta {
|
||||
meta[k] = v
|
||||
}
|
||||
} else if rawMeta, ok := docMap["meta"].(map[string]interface{}); ok {
|
||||
for k, v := range rawMeta {
|
||||
meta[k] = v
|
||||
}
|
||||
}
|
||||
for k, v := range docMap {
|
||||
if k != "text" && k != "id" && k != "metadata" && k != "meta" {
|
||||
meta[k] = v
|
||||
}
|
||||
}
|
||||
if len(meta) > 0 {
|
||||
doc.Meta = meta
|
||||
populated = true
|
||||
}
|
||||
if populated {
|
||||
rerankResult.Document = doc
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bifrostResponse.Results = append(bifrostResponse.Results, rerankResult)
|
||||
}
|
||||
sort.SliceStable(bifrostResponse.Results, func(i, j int) bool {
|
||||
if bifrostResponse.Results[i].RelevanceScore == bifrostResponse.Results[j].RelevanceScore {
|
||||
return bifrostResponse.Results[i].Index < bifrostResponse.Results[j].Index
|
||||
}
|
||||
return bifrostResponse.Results[i].RelevanceScore > bifrostResponse.Results[j].RelevanceScore
|
||||
})
|
||||
if returnDocuments {
|
||||
for i := range bifrostResponse.Results {
|
||||
resultIndex := bifrostResponse.Results[i].Index
|
||||
if resultIndex >= 0 && resultIndex < len(documents) {
|
||||
bifrostResponse.Results[i].Document = schemas.Ptr(documents[resultIndex])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Convert usage information
|
||||
if response.Meta != nil {
|
||||
promptTokens := 0
|
||||
completionTokens := 0
|
||||
hasTokenUsage := false
|
||||
if response.Meta.Tokens != nil {
|
||||
if response.Meta.Tokens.InputTokens != nil {
|
||||
promptTokens = int(*response.Meta.Tokens.InputTokens)
|
||||
hasTokenUsage = true
|
||||
}
|
||||
if response.Meta.Tokens.OutputTokens != nil {
|
||||
completionTokens = int(*response.Meta.Tokens.OutputTokens)
|
||||
hasTokenUsage = true
|
||||
}
|
||||
} else if response.Meta.BilledUnits != nil {
|
||||
if response.Meta.BilledUnits.InputTokens != nil {
|
||||
promptTokens = int(*response.Meta.BilledUnits.InputTokens)
|
||||
hasTokenUsage = true
|
||||
}
|
||||
if response.Meta.BilledUnits.OutputTokens != nil {
|
||||
completionTokens = int(*response.Meta.BilledUnits.OutputTokens)
|
||||
hasTokenUsage = true
|
||||
}
|
||||
}
|
||||
if hasTokenUsage {
|
||||
bifrostResponse.Usage = &schemas.BifrostLLMUsage{
|
||||
PromptTokens: promptTokens,
|
||||
CompletionTokens: completionTokens,
|
||||
TotalTokens: promptTokens + completionTokens,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return bifrostResponse
|
||||
}
|
||||
|
||||
func formatCohereRerankDocument(doc schemas.RerankDocument) string {
|
||||
if doc.ID == nil && len(doc.Meta) == 0 {
|
||||
return doc.Text
|
||||
}
|
||||
|
||||
// Keep metadata/id available by encoding a structured string document.
|
||||
documentPayload := map[string]interface{}{
|
||||
"text": doc.Text,
|
||||
}
|
||||
if doc.ID != nil {
|
||||
documentPayload["id"] = *doc.ID
|
||||
}
|
||||
if len(doc.Meta) > 0 {
|
||||
documentPayload["metadata"] = doc.Meta
|
||||
}
|
||||
|
||||
encoded, err := yaml.Marshal(documentPayload)
|
||||
if err != nil {
|
||||
return doc.Text
|
||||
}
|
||||
return string(encoded)
|
||||
}
|
||||
72
core/providers/cohere/rerank_test.go
Normal file
72
core/providers/cohere/rerank_test.go
Normal file
@@ -0,0 +1,72 @@
|
||||
package cohere
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestCohereRerankResponseToBifrostRerankResponse(t *testing.T) {
|
||||
response := (&CohereRerankResponse{
|
||||
ID: "rerank-response-id",
|
||||
Results: []CohereRerankResult{
|
||||
{
|
||||
Index: 1,
|
||||
RelevanceScore: 0.62,
|
||||
Document: json.RawMessage(`{"text":"provider-doc-1","id":"doc-1","topic":"geography"}`),
|
||||
},
|
||||
{
|
||||
Index: 0,
|
||||
RelevanceScore: 0.91,
|
||||
Document: json.RawMessage(`{"text":"provider-doc-0"}`),
|
||||
},
|
||||
},
|
||||
}).ToBifrostRerankResponse(nil, false)
|
||||
|
||||
require.NotNil(t, response)
|
||||
assert.Equal(t, "rerank-response-id", response.ID)
|
||||
require.Len(t, response.Results, 2)
|
||||
assert.Equal(t, 0, response.Results[0].Index)
|
||||
assert.Equal(t, 1, response.Results[1].Index)
|
||||
require.NotNil(t, response.Results[0].Document)
|
||||
require.NotNil(t, response.Results[1].Document)
|
||||
assert.Equal(t, "provider-doc-0", response.Results[0].Document.Text)
|
||||
assert.Equal(t, "provider-doc-1", response.Results[1].Document.Text)
|
||||
require.NotNil(t, response.Results[1].Document.ID)
|
||||
assert.Equal(t, "doc-1", *response.Results[1].Document.ID)
|
||||
assert.Equal(t, "geography", response.Results[1].Document.Meta["topic"])
|
||||
}
|
||||
|
||||
func TestCohereRerankResponseToBifrostRerankResponseReturnDocuments(t *testing.T) {
|
||||
requestDocs := []schemas.RerankDocument{
|
||||
{Text: "request-doc-0"},
|
||||
{Text: "request-doc-1"},
|
||||
}
|
||||
|
||||
response := (&CohereRerankResponse{
|
||||
Results: []CohereRerankResult{
|
||||
{
|
||||
Index: 1,
|
||||
RelevanceScore: 0.62,
|
||||
Document: json.RawMessage(`{"text":"provider-doc-1"}`),
|
||||
},
|
||||
{
|
||||
Index: 0,
|
||||
RelevanceScore: 0.91,
|
||||
Document: json.RawMessage(`{"text":"provider-doc-0"}`),
|
||||
},
|
||||
},
|
||||
}).ToBifrostRerankResponse(requestDocs, true)
|
||||
|
||||
require.NotNil(t, response)
|
||||
require.Len(t, response.Results, 2)
|
||||
require.NotNil(t, response.Results[0].Document)
|
||||
require.NotNil(t, response.Results[1].Document)
|
||||
assert.Equal(t, 0, response.Results[0].Index)
|
||||
assert.Equal(t, 1, response.Results[1].Index)
|
||||
assert.Equal(t, "request-doc-0", response.Results[0].Document.Text)
|
||||
assert.Equal(t, "request-doc-1", response.Results[1].Document.Text)
|
||||
}
|
||||
1726
core/providers/cohere/responses.go
Normal file
1726
core/providers/cohere/responses.go
Normal file
File diff suppressed because it is too large
Load Diff
616
core/providers/cohere/types.go
Normal file
616
core/providers/cohere/types.go
Normal file
@@ -0,0 +1,616 @@
|
||||
package cohere
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/bytedance/sonic"
|
||||
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// Token defaults used for reasoning-budget calculations.
const (
	// MinimumReasoningMaxTokens is presumably the smallest reasoning token
	// budget accepted — its use is outside this view; confirm against callers.
	MinimumReasoningMaxTokens  = 1
	DefaultCompletionMaxTokens = 4096 // Only used for relative reasoning max token calculation - not passed in body by default
)
|
||||
|
||||
// Limits on the text length accepted by the tokenize API call; see
// https://docs.cohere.com/reference/tokenize#request
const (
	cohereTokenizeMinTextLength = 1
	cohereTokenizeMaxTextLength = 65536
)
|
||||
|
||||
// ==================== REQUEST TYPES ====================
|
||||
|
||||
// CohereChatRequest represents a Cohere chat completion request.
// Model and Messages are always serialized; every other field is omitted from
// the JSON body when unset. ExtraParams is excluded from this struct's JSON
// encoding and is exposed through GetExtraParams.
type CohereChatRequest struct {
	Model            string                  `json:"model"`                        // Required: Model to use for chat completion
	Messages         []CohereMessage         `json:"messages"`                     // Required: Array of message objects
	Tools            []CohereChatRequestTool `json:"tools,omitempty"`              // Optional: Tools available for the model
	ToolChoice       *CohereToolChoice       `json:"tool_choice,omitempty"`        // Optional: Tool choice configuration
	Temperature      *float64                `json:"temperature,omitempty"`        // Optional: Sampling temperature
	P                *float64                `json:"p,omitempty"`                  // Optional: Top-p sampling
	K                *int                    `json:"k,omitempty"`                  // Optional: Top-k sampling
	MaxTokens        *int                    `json:"max_tokens,omitempty"`         // Optional: Maximum tokens to generate
	StopSequences    []string                `json:"stop_sequences,omitempty"`     // Optional: Stop sequences
	FrequencyPenalty *float64                `json:"frequency_penalty,omitempty"`  // Optional: Frequency penalty
	PresencePenalty  *float64                `json:"presence_penalty,omitempty"`   // Optional: Presence penalty
	Stream           *bool                   `json:"stream,omitempty"`             // Optional: Enable streaming
	SafetyMode       *string                 `json:"safety_mode,omitempty"`        // Optional: Safety mode
	LogProbs         *bool                   `json:"log_probs,omitempty"`          // Optional: Log probabilities
	StrictToolChoice *bool                   `json:"strict_tool_choice,omitempty"` // Optional: Strict tool choice
	Thinking         *CohereThinking         `json:"thinking,omitempty"`           // Optional: Reasoning configuration
	ResponseFormat   *CohereResponseFormat   `json:"response_format,omitempty"`    // Optional: Format for the response
	ExtraParams      map[string]interface{}  `json:"-"`                            // Optional: Extra parameters
}
|
||||
|
||||
// IsStreamingRequested implements the StreamingRequest interface
|
||||
func (r *CohereChatRequest) IsStreamingRequested() bool {
|
||||
return r.Stream != nil && *r.Stream
|
||||
}
|
||||
|
||||
// GetExtraParams returns the chat request's ExtraParams map as-is (it may be nil).
func (r *CohereChatRequest) GetExtraParams() map[string]interface{} {
	return r.ExtraParams
}
|
||||
|
||||
// CohereChatRequestTool describes one tool offered to the model in a chat request.
type CohereChatRequestTool struct {
	Type     string                    `json:"type"` // always "function"
	Function CohereChatRequestFunction `json:"function"`
}
|
||||
|
||||
// CohereChatRequestFunction is the function definition carried by a
// CohereChatRequestTool.
type CohereChatRequestFunction struct {
	Name        string      `json:"name"`                  // Function name
	Parameters  interface{} `json:"parameters,omitempty"`  // Function parameters (JSON string)
	Description *string     `json:"description,omitempty"` // Optional: Function description
}
|
||||
|
||||
// CohereMessage represents a message in Cohere format.
// Only Role is always serialized; the remaining fields are omitted when unset.
type CohereMessage struct {
	Role       string                `json:"role"`                   // Required: Message role (system, user, assistant, tool)
	Content    *CohereMessageContent `json:"content,omitempty"`      // Optional: Message content (string or array of content blocks)
	ToolCalls  []CohereToolCall      `json:"tool_calls,omitempty"`   // Optional: Tool calls (for assistant messages)
	ToolCallID *string               `json:"tool_call_id,omitempty"` // Optional: Tool call ID (for tool messages)
	ToolPlan   *string               `json:"tool_plan,omitempty"`    // Optional: Chain-of-thought style reflection (assistant only)
}
|
||||
|
||||
// CohereMessageContent represents flexible content that can be either a plain
// string or an array of content blocks.
//
// Both fields are hidden from the default JSON encoding (json:"-"); the custom
// MarshalJSON and UnmarshalJSON methods choose the wire representation.
type CohereMessageContent struct {
	// Use custom marshaling to handle string or []CohereContentBlock
	StringContent *string              `json:"-"`
	BlocksContent []CohereContentBlock `json:"-"`
}
|
||||
|
||||
// MarshalJSON implements custom JSON marshaling for CohereMessageContent
|
||||
func (c *CohereMessageContent) MarshalJSON() ([]byte, error) {
|
||||
if c.StringContent != nil {
|
||||
return providerUtils.MarshalSorted(*c.StringContent)
|
||||
}
|
||||
if c.BlocksContent != nil {
|
||||
return providerUtils.MarshalSorted(c.BlocksContent)
|
||||
}
|
||||
return []byte("null"), nil
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements custom JSON unmarshaling for CohereMessageContent
|
||||
func (c *CohereMessageContent) UnmarshalJSON(data []byte) error {
|
||||
// Try to unmarshal as string first
|
||||
var str string
|
||||
if err := sonic.Unmarshal(data, &str); err == nil {
|
||||
c.StringContent = &str
|
||||
return nil
|
||||
}
|
||||
|
||||
// Try to unmarshal as content blocks array
|
||||
var blocks []CohereContentBlock
|
||||
if err := sonic.Unmarshal(data, &blocks); err == nil {
|
||||
c.BlocksContent = blocks
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("content must be either string or array of content blocks")
|
||||
}
|
||||
|
||||
// Helper methods for CohereMessageContent
|
||||
|
||||
// NewStringContent creates a CohereMessageContent with string content
|
||||
func NewStringContent(content string) *CohereMessageContent {
|
||||
return &CohereMessageContent{
|
||||
StringContent: &content,
|
||||
}
|
||||
}
|
||||
|
||||
// NewBlocksContent creates a CohereMessageContent with content blocks
|
||||
func NewBlocksContent(blocks []CohereContentBlock) *CohereMessageContent {
|
||||
return &CohereMessageContent{
|
||||
BlocksContent: blocks,
|
||||
}
|
||||
}
|
||||
|
||||
// IsString returns true if content is a string, i.e. StringContent is set.
func (c *CohereMessageContent) IsString() bool {
	return c.StringContent != nil
}
|
||||
|
||||
// IsBlocks returns true if content is content blocks, i.e. BlocksContent is
// non-nil (an empty but non-nil slice still counts).
func (c *CohereMessageContent) IsBlocks() bool {
	return c.BlocksContent != nil
}
|
||||
|
||||
// GetString returns the string content (nil if not string).
// The returned pointer is the stored one, not a copy.
func (c *CohereMessageContent) GetString() *string {
	return c.StringContent
}
|
||||
|
||||
// GetBlocks returns the content blocks (nil if not blocks).
// The returned slice is the stored one, not a copy.
func (c *CohereMessageContent) GetBlocks() []CohereContentBlock {
	return c.BlocksContent
}
|
||||
|
||||
// CohereContentBlockType is the "type" discriminator of a CohereContentBlock.
type CohereContentBlockType string

// Content block type values.
const (
	CohereContentBlockTypeText     CohereContentBlockType = "text"
	CohereContentBlockTypeImage    CohereContentBlockType = "image_url"
	CohereContentBlockTypeThinking CohereContentBlockType = "thinking"
	CohereContentBlockTypeDocument CohereContentBlockType = "document"
)
|
||||
|
||||
// CohereContentBlock represents a content block in Cohere format.
// This is a union type that can be text, image_url, thinking, or document:
// Type names the variant, and the payload fields are omitted from JSON when nil.
type CohereContentBlock struct {
	Type CohereContentBlockType `json:"type"` // Required: Content block type

	// Text content block
	Text *string `json:"text,omitempty"`

	// Image URL content block
	ImageURL *CohereImageURL `json:"image_url,omitempty"`

	// Thinking content block (assistant only)
	Thinking *string `json:"thinking,omitempty"`

	// Document content block (tool messages)
	Document *CohereDocument `json:"document,omitempty"`
}
|
||||
|
||||
// CohereImageURL is the payload of an image URL content block.
type CohereImageURL struct {
	URL string `json:"url"` // Required: location of the image
}
|
||||
|
||||
// CohereDocument represents a document content block
|
||||
type CohereDocument struct {
|
||||
Data schemas.OrderedMap `json:"data"` // Required: Document data as key-value pairs
|
||||
ID *string `json:"id,omitempty"` // Optional: Document ID for citations
|
||||
}
|
||||
|
||||
// CohereThinking represents reasoning configuration
|
||||
type CohereThinking struct {
|
||||
Type CohereThinkingType `json:"type"` // Required: Reasoning type (enabled, disabled)
|
||||
TokenBudget *int `json:"token_budget,omitempty"` // Optional: Maximum thinking tokens (>=1)
|
||||
}
|
||||
|
||||
// CohereThinkingType selects the reasoning mode in CohereThinking.
type CohereThinkingType string

// Known reasoning modes.
const (
	ThinkingTypeEnabled  CohereThinkingType = "enabled"
	ThinkingTypeDisabled CohereThinkingType = "disabled"
)
|
||||
|
||||
// CohereResponseFormat represents the response format configuration for Cohere chat requests
|
||||
type CohereResponseFormat struct {
|
||||
Type CohereResponseFormatType `json:"type"` // Required: Response format type
|
||||
JSONSchema *interface{} `json:"schema,omitempty"` // Optional: JSON schema for structured output (not used when type is "text")
|
||||
}
|
||||
|
||||
// CohereResponseFormatType is the discriminator of CohereResponseFormat.
type CohereResponseFormatType string

// Known response format types.
const (
	ResponseFormatTypeText       CohereResponseFormatType = "text"
	ResponseFormatTypeJSONObject CohereResponseFormatType = "json_object"
)
|
||||
|
||||
// CohereToolChoice is the tool choice setting of a request.
type CohereToolChoice string

// Known tool choice values (upper-case on the wire).
const (
	ToolChoiceRequired CohereToolChoice = "REQUIRED"
	ToolChoiceNone     CohereToolChoice = "NONE"
	ToolChoiceAuto     CohereToolChoice = "AUTO"
)
|
||||
|
||||
// CohereToolCall represents a tool call in Cohere format
|
||||
type CohereToolCall struct {
|
||||
ID *string `json:"id,omitempty"` // Optional: Tool call ID
|
||||
Type string `json:"type"` // Required: Tool call type (must be "function")
|
||||
Function *CohereFunction `json:"function"` // Required: Function call details
|
||||
}
|
||||
|
||||
// CohereFunction describes the function targeted by a tool call.
type CohereFunction struct {
	Name      *string `json:"name,omitempty"`      // Optional: function name
	Arguments string  `json:"arguments,omitempty"` // Optional: function arguments, encoded as a JSON string
}
|
||||
|
||||
// CohereParameterDefinition describes one parameter of a Cohere tool:
// its type, an optional description, and whether it must be supplied.
type CohereParameterDefinition struct {
	Type        string  `json:"type"`                  // Parameter type
	Description *string `json:"description,omitempty"` // Optional human-readable description
	Required    bool    `json:"required"`              // True when the parameter is mandatory
}
|
||||
|
||||
// CohereTool represents a tool definition for the Cohere API.
|
||||
// It includes the tool's name, description, and parameter definitions.
|
||||
type CohereTool struct {
|
||||
Name string `json:"name"` // Name of the tool
|
||||
Description string `json:"description"` // Description of the tool
|
||||
ParameterDefinitions map[string]CohereParameterDefinition `json:"parameter_definitions"` // Definitions of the tool's parameters
|
||||
}
|
||||
|
||||
// CohereCountTokensRequest is the body of a Cohere tokenize request.
// ExtraParams is excluded from JSON encoding.
type CohereCountTokensRequest struct {
	Model       string                 `json:"model"` // Required: model whose tokenizer should be used
	Text        string                 `json:"text"`  // Required: text to tokenize (1-65536 chars)
	ExtraParams map[string]interface{} `json:"-"`     // Optional: extra parameters
}

// GetExtraParams returns the request's ExtraParams map as stored (no copy).
func (r *CohereCountTokensRequest) GetExtraParams() map[string]interface{} {
	extras := r.ExtraParams
	return extras
}
|
||||
|
||||
// CohereEmbeddingRequest represents a Cohere embedding request
|
||||
type CohereEmbeddingRequest struct {
|
||||
Model string `json:"model"` // Required: ID of embedding model
|
||||
InputType string `json:"input_type"` // Required: Type of input for v3+ models
|
||||
Texts []string `json:"texts,omitempty"` // Optional: Array of strings to embed (max 96)
|
||||
Images []string `json:"images,omitempty"` // Optional: Array of image data URIs (max 1)
|
||||
Inputs []CohereEmbeddingInput `json:"inputs,omitempty"` // Optional: Array of mixed text/image inputs (max 96)
|
||||
MaxTokens *int `json:"max_tokens,omitempty"` // Optional: Max tokens to embed per input
|
||||
OutputDimension *int `json:"output_dimension,omitempty"` // Optional: Embedding dimensions (256, 512, 1024, 1536)
|
||||
EmbeddingTypes []string `json:"embedding_types,omitempty"` // Optional: Types of embeddings to return
|
||||
Truncate *string `json:"truncate,omitempty"` // Optional: How to handle long inputs
|
||||
ExtraParams map[string]interface{} `json:"-"` // Optional: Extra parameters
|
||||
}
|
||||
|
||||
func (r *CohereEmbeddingRequest) GetExtraParams() map[string]interface{} {
|
||||
return r.ExtraParams
|
||||
}
|
||||
|
||||
// CohereEmbeddingInput represents a mixed text/image input
|
||||
type CohereEmbeddingInput struct {
|
||||
Content []CohereContentBlock `json:"content"` // Required: Array of content blocks (reuses chat content blocks)
|
||||
}
|
||||
|
||||
// CohereEmbeddingResponse represents a Cohere embedding response
|
||||
type CohereEmbeddingResponse struct {
|
||||
ID string `json:"id"` // Response ID
|
||||
Embeddings *CohereEmbeddingData `json:"embeddings,omitempty"` // Embedding data object
|
||||
ResponseType *string `json:"response_type,omitempty"` // Response type (embeddings_floats, embeddings_by_type)
|
||||
Texts []string `json:"texts,omitempty"` // Original text entries
|
||||
Images []CohereEmbeddingImageInfo `json:"images,omitempty"` // Original image entries
|
||||
Meta *CohereEmbeddingMeta `json:"meta,omitempty"` // Response metadata
|
||||
}
|
||||
|
||||
// CohereEmbeddingData is the embeddings object of a response, with one
// optional field per embedding encoding.
type CohereEmbeddingData struct {
	Float   [][]float64 `json:"float,omitempty"`   // Float embeddings
	Int8    [][]int8    `json:"int8,omitempty"`    // Int8 embeddings
	Uint8   [][]uint8   `json:"uint8,omitempty"`   // Uint8 embeddings
	Binary  [][]int8    `json:"binary,omitempty"`  // Binary embeddings
	Ubinary [][]uint8   `json:"ubinary,omitempty"` // Unsigned binary embeddings
	Base64  []string    `json:"base64,omitempty"`  // Base64 embeddings
}
|
||||
|
||||
// CohereEmbeddingImageInfo describes an image entry echoed in an embedding response.
type CohereEmbeddingImageInfo struct {
	Width    int64  `json:"width"`     // Width in pixels
	Height   int64  `json:"height"`    // Height in pixels
	Format   string `json:"format"`    // Image format
	BitDepth int64  `json:"bit_depth"` // Bit depth
}
|
||||
|
||||
// CohereEmbeddingMeta represents metadata in embedding response
|
||||
type CohereEmbeddingMeta struct {
|
||||
APIVersion *CohereEmbeddingAPIVersion `json:"api_version,omitempty"` // API version info
|
||||
BilledUnits *CohereBilledUnits `json:"billed_units,omitempty"` // Billing information
|
||||
Tokens *CohereTokenUsage `json:"tokens,omitempty"` // Token usage
|
||||
Warnings []string `json:"warnings,omitempty"` // Any warnings
|
||||
}
|
||||
|
||||
// CohereEmbeddingAPIVersion is the API version information in embedding metadata.
type CohereEmbeddingAPIVersion struct {
	Version        *string `json:"version,omitempty"`         // API version
	IsDeprecated   *bool   `json:"is_deprecated,omitempty"`   // Deprecation status
	IsExperimental *bool   `json:"is_experimental,omitempty"` // Experimental status
}
|
||||
|
||||
// ==================== RESPONSE TYPES ====================
|
||||
|
||||
// CohereCountTokensResponse represents the response from the tokenize endpoint
|
||||
type CohereCountTokensResponse struct {
|
||||
Tokens []int `json:"tokens"`
|
||||
TokenStrings []string `json:"token_strings,omitempty"`
|
||||
Meta *CohereTokenizeMeta `json:"meta,omitempty"`
|
||||
}
|
||||
|
||||
// CohereTokenizeMeta captures metadata returned by the tokenize endpoint
|
||||
type CohereTokenizeMeta struct {
|
||||
APIVersion *CohereTokenizeAPIVersion `json:"api_version,omitempty"`
|
||||
}
|
||||
|
||||
// CohereTokenizeAPIVersion is the API version information in tokenize metadata.
type CohereTokenizeAPIVersion struct {
	Version *string `json:"version,omitempty"` // Optional API version
}
|
||||
|
||||
// CohereChatResponse represents a Cohere chat completion response
|
||||
type CohereChatResponse struct {
|
||||
ID string `json:"id"` // Unique identifier for the generated reply
|
||||
FinishReason *CohereFinishReason `json:"finish_reason,omitempty"` // Reason for completion
|
||||
Message *CohereMessage `json:"message,omitempty"` // Generated message from assistant
|
||||
Usage *CohereUsage `json:"usage,omitempty"` // Token usage information
|
||||
LogProbs []CohereLogProb `json:"logprobs,omitempty"` // Log probabilities (if requested)
|
||||
}
|
||||
|
||||
// CohereFinishReason is the reason a chat request finished.
type CohereFinishReason string

// Known finish reasons.
const (
	FinishReasonComplete     CohereFinishReason = "COMPLETE"      // The model finished sending a complete message
	FinishReasonStopSequence CohereFinishReason = "STOP_SEQUENCE" // A stop sequence was reached
	FinishReasonMaxTokens    CohereFinishReason = "MAX_TOKENS"    // The max token limit was exceeded
	FinishReasonToolCall     CohereFinishReason = "TOOL_CALL"     // The model generated a tool call
	FinishReasonError        CohereFinishReason = "ERROR"         // Generation failed due to an internal error
	FinishReasonTimeout      CohereFinishReason = "TIMEOUT"       // Generation timed out
)
|
||||
|
||||
// CohereUsage represents token usage information
|
||||
type CohereUsage struct {
|
||||
BilledUnits *CohereBilledUnits `json:"billed_units,omitempty"` // Billed usage information
|
||||
Tokens *CohereTokenUsage `json:"tokens,omitempty"` // Token usage details
|
||||
CachedTokens *int `json:"cached_tokens,omitempty"` // Cached tokens
|
||||
}
|
||||
|
||||
// CohereBilledUnits is the billed usage information; every field is optional.
type CohereBilledUnits struct {
	InputTokens     *int `json:"input_tokens,omitempty"`    // Billed input tokens
	OutputTokens    *int `json:"output_tokens,omitempty"`   // Billed output tokens
	SearchUnits     *int `json:"search_units,omitempty"`    // Billed search units
	Classifications *int `json:"classifications,omitempty"` // Billed classification units
}
|
||||
|
||||
// CohereTokenUsage is the detailed token usage. Both keys are always emitted
// when encoding (no omitempty), so nil pointers encode as null.
type CohereTokenUsage struct {
	InputTokens  *int `json:"input_tokens"`  // Input tokens used
	OutputTokens *int `json:"output_tokens"` // Output tokens produced
}
|
||||
|
||||
// CohereLogProb is the log probability information for one text chunk.
type CohereLogProb struct {
	TokenIDs []int     `json:"token_ids"`          // Token IDs of each token in the text chunk
	Text     *string   `json:"text,omitempty"`     // Text chunk the log probabilities refer to
	LogProbs []float64 `json:"logprobs,omitempty"` // Log probability of each token
}
|
||||
|
||||
// CohereCitationType is the value stored in CohereCitation.Type.
type CohereCitationType string

// Known citation types.
const (
	CitationTypeTextContent     CohereCitationType = "TEXT_CONTENT"
	CitationTypeThinkingContent CohereCitationType = "THINKING_CONTENT"
	CitationTypePlan            CohereCitationType = "PLAN"
)
|
||||
|
||||
// CohereSourceType is the value stored in CohereSource.Type.
type CohereSourceType string

// Known citation source types.
const (
	SourceTypeTool     CohereSourceType = "tool"
	SourceTypeDocument CohereSourceType = "document"
)
|
||||
|
||||
// CohereCitation represents a citation in the response
|
||||
type CohereCitation struct {
|
||||
Start int `json:"start"` // Start position of cited text
|
||||
End int `json:"end"` // End position of cited text
|
||||
Text string `json:"text"` // Cited text
|
||||
Sources []CohereSource `json:"sources,omitempty"` // Citation sources
|
||||
ContentIndex int `json:"content_index"` // Content index of the citation
|
||||
Type CohereCitationType `json:"type"` // Type of citation
|
||||
}
|
||||
|
||||
// CohereSource represents a citation source
|
||||
type CohereSource struct {
|
||||
Type CohereSourceType `json:"type"` // Source type ("tool" or "document")
|
||||
ID *string `json:"id,omitempty"` // Source ID (nullable)
|
||||
ToolOutput *json.RawMessage `json:"tool_output,omitempty"` // Tool output (for tool sources, json.RawMessage preserves key ordering)
|
||||
Document *json.RawMessage `json:"document,omitempty"` // Document data (for document sources, json.RawMessage preserves key ordering)
|
||||
}
|
||||
|
||||
// ==================== STREAMING TYPES ====================
|
||||
|
||||
// CohereStreamEventType is the discriminator of a streaming event.
type CohereStreamEventType string

// Known streaming event types.
const (
	StreamEventMessageStart  CohereStreamEventType = "message-start"
	StreamEventContentStart  CohereStreamEventType = "content-start"
	StreamEventContentDelta  CohereStreamEventType = "content-delta"
	StreamEventContentEnd    CohereStreamEventType = "content-end"
	StreamEventToolPlanDelta CohereStreamEventType = "tool-plan-delta"
	StreamEventToolCallStart CohereStreamEventType = "tool-call-start"
	StreamEventToolCallDelta CohereStreamEventType = "tool-call-delta"
	StreamEventToolCallEnd   CohereStreamEventType = "tool-call-end"
	StreamEventCitationStart CohereStreamEventType = "citation-start"
	StreamEventCitationEnd   CohereStreamEventType = "citation-end"
	StreamEventMessageEnd    CohereStreamEventType = "message-end"
	StreamEventDebug         CohereStreamEventType = "debug"
)
|
||||
|
||||
// CohereStreamEvent represents a unified streaming event from Cohere API
|
||||
type CohereStreamEvent struct {
|
||||
Type CohereStreamEventType `json:"type"`
|
||||
ID *string `json:"id,omitempty"` // For message-start
|
||||
Index *int `json:"index,omitempty"` // For indexed events
|
||||
Delta *CohereStreamDelta `json:"delta,omitempty"`
|
||||
}
|
||||
|
||||
// CohereStreamDelta represents the delta content in streaming events
|
||||
type CohereStreamDelta struct {
|
||||
Message *CohereStreamMessage `json:"message,omitempty"`
|
||||
FinishReason *CohereFinishReason `json:"finish_reason,omitempty"`
|
||||
Usage *CohereUsage `json:"usage,omitempty"`
|
||||
}
|
||||
|
||||
type CohereStreamToolCallStruct struct {
|
||||
CohereToolCallObject *CohereToolCall
|
||||
CohereToolCallArray []CohereToolCall
|
||||
}
|
||||
|
||||
// JSON marshaling for CohereStreamToolCall
|
||||
func (c *CohereStreamToolCallStruct) MarshalJSON() ([]byte, error) {
|
||||
if c.CohereToolCallObject != nil {
|
||||
return providerUtils.MarshalSorted(c.CohereToolCallObject)
|
||||
}
|
||||
if c.CohereToolCallArray != nil {
|
||||
return providerUtils.MarshalSorted(c.CohereToolCallArray)
|
||||
}
|
||||
return providerUtils.MarshalSorted(nil)
|
||||
}
|
||||
|
||||
func (c *CohereStreamToolCallStruct) UnmarshalJSON(data []byte) error {
|
||||
if string(data) == "null" {
|
||||
return nil
|
||||
}
|
||||
// Try to unmarshal as array first
|
||||
var toolCallArray []CohereToolCall
|
||||
if err := sonic.Unmarshal(data, &toolCallArray); err == nil {
|
||||
c.CohereToolCallArray = toolCallArray
|
||||
return nil
|
||||
}
|
||||
|
||||
// Try to unmarshal as single object
|
||||
var toolCallObject CohereToolCall
|
||||
if err := sonic.Unmarshal(data, &toolCallObject); err == nil {
|
||||
c.CohereToolCallObject = &toolCallObject
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("tool_calls field is neither array nor object")
|
||||
}
|
||||
|
||||
type CohereStreamContentStruct struct {
|
||||
CohereStreamContentObject *CohereStreamContent
|
||||
CohereStreamContentArray []CohereStreamContent
|
||||
}
|
||||
|
||||
func (c *CohereStreamContentStruct) MarshalJSON() ([]byte, error) {
|
||||
if c.CohereStreamContentObject != nil {
|
||||
return providerUtils.MarshalSorted(c.CohereStreamContentObject)
|
||||
}
|
||||
if c.CohereStreamContentArray != nil {
|
||||
return providerUtils.MarshalSorted(c.CohereStreamContentArray)
|
||||
}
|
||||
return providerUtils.MarshalSorted(nil)
|
||||
}
|
||||
|
||||
func (c *CohereStreamContentStruct) UnmarshalJSON(data []byte) error {
|
||||
if string(data) == "null" {
|
||||
return nil
|
||||
}
|
||||
// Try to unmarshal as array first
|
||||
var contentArray []CohereStreamContent
|
||||
if err := sonic.Unmarshal(data, &contentArray); err == nil {
|
||||
c.CohereStreamContentArray = contentArray
|
||||
return nil
|
||||
}
|
||||
|
||||
// Try to unmarshal as single object
|
||||
var contentObject CohereStreamContent
|
||||
if err := sonic.Unmarshal(data, &contentObject); err == nil {
|
||||
c.CohereStreamContentObject = &contentObject
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("content field is neither array nor object")
|
||||
}
|
||||
|
||||
type CohereStreamCitationStruct struct {
|
||||
CohereStreamCitationObject *CohereCitation
|
||||
CohereStreamCitationArray []CohereCitation
|
||||
}
|
||||
|
||||
func (c *CohereStreamCitationStruct) MarshalJSON() ([]byte, error) {
|
||||
if c.CohereStreamCitationObject != nil {
|
||||
return providerUtils.MarshalSorted(c.CohereStreamCitationObject)
|
||||
}
|
||||
if c.CohereStreamCitationArray != nil {
|
||||
return providerUtils.MarshalSorted(c.CohereStreamCitationArray)
|
||||
}
|
||||
return providerUtils.MarshalSorted(nil)
|
||||
}
|
||||
|
||||
func (c *CohereStreamCitationStruct) UnmarshalJSON(data []byte) error {
|
||||
if string(data) == "null" {
|
||||
return nil
|
||||
}
|
||||
// Try to unmarshal as array first
|
||||
var citationArray []CohereCitation
|
||||
if err := sonic.Unmarshal(data, &citationArray); err == nil {
|
||||
c.CohereStreamCitationArray = citationArray
|
||||
return nil
|
||||
}
|
||||
|
||||
// Try to unmarshal as single object
|
||||
var citationObject CohereCitation
|
||||
if err := sonic.Unmarshal(data, &citationObject); err == nil {
|
||||
c.CohereStreamCitationObject = &citationObject
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("citations field is neither array nor object")
|
||||
}
|
||||
|
||||
// CohereStreamMessage represents the message part of streaming deltas
|
||||
type CohereStreamMessage struct {
|
||||
Role *string `json:"role,omitempty"` // For message-start
|
||||
Content *CohereStreamContentStruct `json:"content,omitempty"` // For content events (object)
|
||||
ToolPlan *string `json:"tool_plan,omitempty"` // For tool-plan-delta
|
||||
ToolCalls *CohereStreamToolCallStruct `json:"tool_calls,omitempty"` // For tool-call events (flexible)
|
||||
Citations *CohereStreamCitationStruct `json:"citations,omitempty"` // For citation events
|
||||
}
|
||||
|
||||
// CohereStreamContent represents content in streaming events
|
||||
type CohereStreamContent struct {
|
||||
Type CohereContentBlockType `json:"type,omitempty"` // For content-start
|
||||
Text *string `json:"text,omitempty"` // For content deltas
|
||||
Thinking *string `json:"thinking,omitempty"` // For thinking deltas
|
||||
}
|
||||
|
||||
// ==================== ERROR TYPES ====================
|
||||
|
||||
// CohereError is the error body returned by the Cohere API.
type CohereError struct {
	Type    string  `json:"type"`           // Error type
	Message string  `json:"message"`        // Error message
	Code    *string `json:"code,omitempty"` // Optional error code
}
|
||||
|
||||
// ==================== MODEL TYPES ====================
|
||||
|
||||
// CohereModel is one entry of a CohereListModelsResponse.
type CohereModel struct {
	Name             string   `json:"name"`
	IsDeprecated     bool     `json:"is_deprecated"`
	Endpoints        []string `json:"endpoints"`
	Finetuned        bool     `json:"finetuned"`
	ContextLength    int      `json:"context_length"`
	TokenizerURL     string   `json:"tokenizer_url"`
	DefaultEndpoints []string `json:"default_endpoints"`
	Features         []string `json:"features"`
}
|
||||
|
||||
type CohereListModelsResponse struct {
|
||||
Models []CohereModel `json:"models"`
|
||||
NextPageToken string `json:"next_page_token"`
|
||||
}
|
||||
293
core/providers/cohere/utils.go
Normal file
293
core/providers/cohere/utils.go
Normal file
@@ -0,0 +1,293 @@
|
||||
package cohere
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
var (
|
||||
// Maps provider-specific finish reasons to Bifrost format
|
||||
cohereFinishReasonToBifrost = map[CohereFinishReason]string{
|
||||
FinishReasonComplete: "stop",
|
||||
FinishReasonStopSequence: "stop",
|
||||
FinishReasonMaxTokens: "length",
|
||||
FinishReasonToolCall: "tool_calls",
|
||||
}
|
||||
)
|
||||
|
||||
// ConvertCohereFinishReasonToBifrost converts provider finish reasons to Bifrost format
|
||||
func ConvertCohereFinishReasonToBifrost(providerReason CohereFinishReason) string {
|
||||
if bifrostReason, ok := cohereFinishReasonToBifrost[providerReason]; ok {
|
||||
return bifrostReason
|
||||
}
|
||||
return string(providerReason)
|
||||
}
|
||||
|
||||
// convertInterfaceToToolFunctionParameters converts an interface{} to ToolFunctionParameters
|
||||
// This handles the conversion from Cohere's flexible parameter format to Bifrost's structured format
|
||||
func convertInterfaceToToolFunctionParameters(params interface{}) *schemas.ToolFunctionParameters {
|
||||
if params == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Try to convert from map[string]interface{}
|
||||
paramsMap, ok := params.(map[string]interface{})
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
result := &schemas.ToolFunctionParameters{}
|
||||
|
||||
// Extract type
|
||||
if typeVal, ok := paramsMap["type"].(string); ok {
|
||||
result.Type = typeVal
|
||||
}
|
||||
|
||||
// Extract description
|
||||
if descVal, ok := paramsMap["description"].(string); ok {
|
||||
result.Description = &descVal
|
||||
}
|
||||
|
||||
// Extract required
|
||||
if requiredVal, ok := paramsMap["required"].([]interface{}); ok {
|
||||
required := make([]string, 0, len(requiredVal))
|
||||
for _, v := range requiredVal {
|
||||
if s, ok := v.(string); ok {
|
||||
required = append(required, s)
|
||||
}
|
||||
}
|
||||
result.Required = required
|
||||
}
|
||||
|
||||
// Extract properties
|
||||
if orderedProps, ok := schemas.SafeExtractOrderedMap(paramsMap["properties"]); ok {
|
||||
result.Properties = orderedProps
|
||||
}
|
||||
|
||||
// Extract enum
|
||||
if enumVal, ok := paramsMap["enum"].([]interface{}); ok {
|
||||
enum := make([]string, 0, len(enumVal))
|
||||
for _, v := range enumVal {
|
||||
if s, ok := v.(string); ok {
|
||||
enum = append(enum, s)
|
||||
}
|
||||
}
|
||||
result.Enum = enum
|
||||
}
|
||||
|
||||
// Extract additionalProperties
|
||||
if addPropsVal, ok := paramsMap["additionalProperties"].(bool); ok {
|
||||
result.AdditionalProperties = &schemas.AdditionalPropertiesStruct{
|
||||
AdditionalPropertiesBool: &addPropsVal,
|
||||
}
|
||||
}
|
||||
|
||||
if addPropsVal, ok := schemas.SafeExtractOrderedMap(paramsMap["additionalProperties"]); ok {
|
||||
result.AdditionalProperties = &schemas.AdditionalPropertiesStruct{
|
||||
AdditionalPropertiesMap: addPropsVal,
|
||||
}
|
||||
}
|
||||
|
||||
// Extract $defs (JSON Schema draft 2019-09+)
|
||||
if defsVal, ok := schemas.SafeExtractOrderedMap(paramsMap["$defs"]); ok {
|
||||
result.Defs = defsVal
|
||||
}
|
||||
|
||||
// Extract definitions (legacy JSON Schema draft-07)
|
||||
if defsVal, ok := schemas.SafeExtractOrderedMap(paramsMap["definitions"]); ok {
|
||||
result.Definitions = defsVal
|
||||
}
|
||||
|
||||
// Extract $ref
|
||||
if refVal, ok := paramsMap["$ref"].(string); ok {
|
||||
result.Ref = &refVal
|
||||
}
|
||||
|
||||
// Extract items (array element schema)
|
||||
if itemsVal, ok := schemas.SafeExtractOrderedMap(paramsMap["items"]); ok {
|
||||
result.Items = itemsVal
|
||||
}
|
||||
|
||||
// Extract minItems
|
||||
if minItemsVal, ok := extractInt64(paramsMap["minItems"]); ok {
|
||||
result.MinItems = &minItemsVal
|
||||
}
|
||||
|
||||
// Extract maxItems
|
||||
if maxItemsVal, ok := extractInt64(paramsMap["maxItems"]); ok {
|
||||
result.MaxItems = &maxItemsVal
|
||||
}
|
||||
|
||||
// Extract anyOf
|
||||
if anyOfVal, ok := paramsMap["anyOf"].([]interface{}); ok {
|
||||
anyOf := make([]schemas.OrderedMap, 0, len(anyOfVal))
|
||||
for _, v := range anyOfVal {
|
||||
if m, ok := schemas.SafeExtractOrderedMap(v); ok {
|
||||
anyOf = append(anyOf, *m)
|
||||
}
|
||||
}
|
||||
result.AnyOf = anyOf
|
||||
}
|
||||
|
||||
// Extract oneOf
|
||||
if oneOfVal, ok := paramsMap["oneOf"].([]interface{}); ok {
|
||||
oneOf := make([]schemas.OrderedMap, 0, len(oneOfVal))
|
||||
for _, v := range oneOfVal {
|
||||
if m, ok := schemas.SafeExtractOrderedMap(v); ok {
|
||||
oneOf = append(oneOf, *m)
|
||||
}
|
||||
}
|
||||
result.OneOf = oneOf
|
||||
}
|
||||
|
||||
// Extract allOf
|
||||
if allOfVal, ok := paramsMap["allOf"].([]interface{}); ok {
|
||||
allOf := make([]schemas.OrderedMap, 0, len(allOfVal))
|
||||
for _, v := range allOfVal {
|
||||
if m, ok := schemas.SafeExtractOrderedMap(v); ok {
|
||||
allOf = append(allOf, *m)
|
||||
}
|
||||
}
|
||||
result.AllOf = allOf
|
||||
}
|
||||
|
||||
// Extract format
|
||||
if formatVal, ok := paramsMap["format"].(string); ok {
|
||||
result.Format = &formatVal
|
||||
}
|
||||
|
||||
// Extract pattern
|
||||
if patternVal, ok := paramsMap["pattern"].(string); ok {
|
||||
result.Pattern = &patternVal
|
||||
}
|
||||
|
||||
// Extract minLength
|
||||
if minLengthVal, ok := extractInt64(paramsMap["minLength"]); ok {
|
||||
result.MinLength = &minLengthVal
|
||||
}
|
||||
|
||||
// Extract maxLength
|
||||
if maxLengthVal, ok := extractInt64(paramsMap["maxLength"]); ok {
|
||||
result.MaxLength = &maxLengthVal
|
||||
}
|
||||
|
||||
// Extract minimum
|
||||
if minVal, ok := extractFloat64(paramsMap["minimum"]); ok {
|
||||
result.Minimum = &minVal
|
||||
}
|
||||
|
||||
// Extract maximum
|
||||
if maxVal, ok := extractFloat64(paramsMap["maximum"]); ok {
|
||||
result.Maximum = &maxVal
|
||||
}
|
||||
|
||||
// Extract title
|
||||
if titleVal, ok := paramsMap["title"].(string); ok {
|
||||
result.Title = &titleVal
|
||||
}
|
||||
|
||||
// Extract default
|
||||
if defaultVal, exists := paramsMap["default"]; exists {
|
||||
result.Default = defaultVal
|
||||
}
|
||||
|
||||
// Extract nullable
|
||||
if nullableVal, ok := paramsMap["nullable"].(bool); ok {
|
||||
result.Nullable = &nullableVal
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// extractInt64 extracts an int64 from an int, int64, float64 or float32 value.
//
// Floating-point inputs are truncated toward zero. NaN, infinities and values
// outside the int64 range are rejected (0, false): converting them to an
// integer is implementation-defined in Go, so they must not be reported as a
// successful extraction. Any other dynamic type, including nil, yields
// (0, false).
func extractInt64(v interface{}) (int64, bool) {
	switch val := v.(type) {
	case int:
		return int64(val), true
	case int64:
		return val, true
	case float64:
		return truncateFloatToInt64(val)
	case float32:
		return truncateFloatToInt64(float64(val))
	default:
		return 0, false
	}
}

// truncateFloatToInt64 converts f to int64 by truncation when f lies in
// [-2^63, 2^63). NaN fails both comparisons and is therefore rejected as well.
func truncateFloatToInt64(f float64) (int64, bool) {
	if f >= -9223372036854775808.0 && f < 9223372036854775808.0 {
		return int64(f), true
	}
	return 0, false
}
|
||||
|
||||
// extractFloat64 extracts a float64 from a float64, float32, int or int64
// value. Any other dynamic type, including nil, yields (0, false).
func extractFloat64(v interface{}) (float64, bool) {
	// float64 needs no conversion, so it is handled up front.
	if f, isFloat64 := v.(float64); isFloat64 {
		return f, true
	}

	switch n := v.(type) {
	case int64:
		return float64(n), true
	case int:
		return float64(n), true
	case float32:
		return float64(n), true
	}
	return 0, false
}
|
||||
|
||||
// ConvertResponseFormatToCohere converts OpenAI-style response_format (interface{}) to Cohere's typed format
|
||||
// Input can be a map with structure: { type: "json_schema", json_schema: { schema: {...} } }
|
||||
// Output: CohereResponseFormat with flat structure: { type: "json_object", json_schema: {...} }
|
||||
func convertResponseFormatToCohere(responseFormat *interface{}) *CohereResponseFormat {
|
||||
if responseFormat == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Try to extract as map
|
||||
formatMap, ok := (*responseFormat).(map[string]interface{})
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
cohereFormat := &CohereResponseFormat{}
|
||||
|
||||
// Extract type
|
||||
typeStr, _ := formatMap["type"].(string)
|
||||
switch typeStr {
|
||||
case "text":
|
||||
cohereFormat.Type = ResponseFormatTypeText
|
||||
case "json_object", "json_schema":
|
||||
cohereFormat.Type = ResponseFormatTypeJSONObject
|
||||
|
||||
// Extract the nested schema
|
||||
// OpenAI format: { type: "json_schema", json_schema: { name: "X", strict: true, schema: {...} } }
|
||||
if jsonSchemaWrapper, ok := formatMap["json_schema"].(map[string]interface{}); ok {
|
||||
if schema, ok := jsonSchemaWrapper["schema"].(map[string]interface{}); ok {
|
||||
var schemaInterface interface{} = schema
|
||||
cohereFormat.JSONSchema = &schemaInterface
|
||||
}
|
||||
}
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
|
||||
return cohereFormat
|
||||
}
|
||||
|
||||
// convertCohereResponseFormatToBifrost converts Cohere's typed response_format back to interface{}
|
||||
func convertCohereResponseFormatToBifrost(cohereFormat *CohereResponseFormat) *interface{} {
|
||||
if cohereFormat == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Build JSON bytes with deterministic key order using sjson
|
||||
data := []byte(`{}`)
|
||||
if cohereFormat.JSONSchema != nil {
|
||||
data, _ = sjson.SetBytes(data, "type", "json_schema")
|
||||
schemaBytes, _ := schemas.MarshalSorted(cohereFormat.JSONSchema)
|
||||
data, _ = sjson.SetRawBytes(data, "json_schema", schemaBytes)
|
||||
} else {
|
||||
data, _ = sjson.SetBytes(data, "type", string(cohereFormat.Type))
|
||||
}
|
||||
|
||||
var resultInterface interface{} = json.RawMessage(data)
|
||||
return &resultInterface
|
||||
}
|
||||
Reference in New Issue
Block a user