723 lines
22 KiB
Go
723 lines
22 KiB
Go
package cohere
|
|
|
|
import (
|
|
"fmt"
|
|
"time"
|
|
|
|
"github.com/maximhq/bifrost/core/providers/anthropic"
|
|
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
|
|
"github.com/maximhq/bifrost/core/schemas"
|
|
)
|
|
|
|
// ToCohereChatCompletionRequest converts a Bifrost request to Cohere v2 format
|
|
func ToCohereChatCompletionRequest(bifrostReq *schemas.BifrostChatRequest) (*CohereChatRequest, error) {
|
|
if bifrostReq == nil || bifrostReq.Input == nil {
|
|
return nil, fmt.Errorf("bifrost request is nil")
|
|
}
|
|
|
|
messages := bifrostReq.Input
|
|
cohereReq := &CohereChatRequest{
|
|
Model: bifrostReq.Model,
|
|
}
|
|
|
|
// Convert messages to Cohere v2 format
|
|
var cohereMessages []CohereMessage
|
|
for _, msg := range messages {
|
|
cohereMsg := CohereMessage{
|
|
Role: string(msg.Role),
|
|
}
|
|
|
|
// Convert content
|
|
if msg.Content != nil && msg.Content.ContentStr != nil {
|
|
cohereMsg.Content = NewStringContent(*msg.Content.ContentStr)
|
|
} else if msg.Content != nil && msg.Content.ContentBlocks != nil {
|
|
var contentBlocks []CohereContentBlock
|
|
for _, block := range msg.Content.ContentBlocks {
|
|
if block.Text != nil {
|
|
contentBlocks = append(contentBlocks, CohereContentBlock{
|
|
Type: CohereContentBlockTypeText,
|
|
Text: block.Text,
|
|
})
|
|
} else if block.ImageURLStruct != nil {
|
|
contentBlocks = append(contentBlocks, CohereContentBlock{
|
|
Type: CohereContentBlockTypeImage,
|
|
ImageURL: &CohereImageURL{
|
|
URL: block.ImageURLStruct.URL,
|
|
},
|
|
})
|
|
}
|
|
}
|
|
if len(contentBlocks) > 0 {
|
|
cohereMsg.Content = NewBlocksContent(contentBlocks)
|
|
}
|
|
}
|
|
|
|
// Convert tool calls for assistant messages
|
|
if msg.ChatAssistantMessage != nil && msg.ChatAssistantMessage.ToolCalls != nil {
|
|
var toolCalls []CohereToolCall
|
|
for _, toolCall := range msg.ChatAssistantMessage.ToolCalls {
|
|
// Safely extract function name and arguments
|
|
var functionName *string
|
|
var functionArguments string
|
|
|
|
if toolCall.Function.Name != nil {
|
|
functionName = toolCall.Function.Name
|
|
} else {
|
|
// Use empty string if Name is nil
|
|
functionName = schemas.Ptr("")
|
|
}
|
|
|
|
// Arguments is a string, not a pointer, so it's safe to access directly
|
|
// Default to "{}" if empty to ensure the field is always present.
|
|
if toolCall.Function.Arguments == "" {
|
|
functionArguments = "{}"
|
|
} else {
|
|
functionArguments = toolCall.Function.Arguments
|
|
}
|
|
|
|
cohereToolCall := CohereToolCall{
|
|
ID: toolCall.ID,
|
|
Type: "function",
|
|
Function: &CohereFunction{
|
|
Name: functionName,
|
|
Arguments: functionArguments,
|
|
},
|
|
}
|
|
toolCalls = append(toolCalls, cohereToolCall)
|
|
}
|
|
cohereMsg.ToolCalls = toolCalls
|
|
}
|
|
|
|
// Convert tool messages
|
|
if msg.ChatToolMessage != nil && msg.ChatToolMessage.ToolCallID != nil {
|
|
cohereMsg.ToolCallID = msg.ChatToolMessage.ToolCallID
|
|
}
|
|
|
|
cohereMessages = append(cohereMessages, cohereMsg)
|
|
}
|
|
|
|
cohereReq.Messages = cohereMessages
|
|
|
|
// Convert parameters
|
|
if bifrostReq.Params != nil {
|
|
cohereReq.MaxTokens = bifrostReq.Params.MaxCompletionTokens
|
|
cohereReq.Temperature = bifrostReq.Params.Temperature
|
|
cohereReq.P = bifrostReq.Params.TopP
|
|
cohereReq.StopSequences = bifrostReq.Params.Stop
|
|
cohereReq.FrequencyPenalty = bifrostReq.Params.FrequencyPenalty
|
|
cohereReq.PresencePenalty = bifrostReq.Params.PresencePenalty
|
|
|
|
// Convert reasoning
|
|
if bifrostReq.Params.Reasoning != nil {
|
|
if bifrostReq.Params.Reasoning.MaxTokens != nil {
|
|
thinking := &CohereThinking{
|
|
Type: ThinkingTypeEnabled,
|
|
}
|
|
if *bifrostReq.Params.Reasoning.MaxTokens == -1 {
|
|
// cohere does not support dynamic reasoning budget like gemini
|
|
// setting it to minimum reasoning budget
|
|
thinking.TokenBudget = schemas.Ptr(anthropic.MinimumReasoningMaxTokens)
|
|
} else {
|
|
thinking.TokenBudget = bifrostReq.Params.Reasoning.MaxTokens
|
|
}
|
|
cohereReq.Thinking = thinking
|
|
} else if bifrostReq.Params.Reasoning.Effort != nil {
|
|
if *bifrostReq.Params.Reasoning.Effort != "none" {
|
|
maxCompletionTokens := providerUtils.GetMaxOutputTokensOrDefault(bifrostReq.Model, DefaultCompletionMaxTokens)
|
|
if bifrostReq.Params.MaxCompletionTokens != nil {
|
|
maxCompletionTokens = *bifrostReq.Params.MaxCompletionTokens
|
|
}
|
|
budgetTokens, err := providerUtils.GetBudgetTokensFromReasoningEffort(*bifrostReq.Params.Reasoning.Effort, MinimumReasoningMaxTokens, maxCompletionTokens)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
cohereReq.Thinking = &CohereThinking{
|
|
Type: ThinkingTypeEnabled,
|
|
TokenBudget: schemas.Ptr(budgetTokens), // Max tokens for reasoning
|
|
}
|
|
} else {
|
|
cohereReq.Thinking = &CohereThinking{
|
|
Type: ThinkingTypeDisabled,
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Convert response format
|
|
if bifrostReq.Params.ResponseFormat != nil {
|
|
cohereReq.ResponseFormat = convertResponseFormatToCohere(bifrostReq.Params.ResponseFormat)
|
|
}
|
|
|
|
// Convert extra params
|
|
if bifrostReq.Params.ExtraParams != nil {
|
|
// Handle thinking parameter
|
|
cohereReq.ExtraParams = bifrostReq.Params.ExtraParams
|
|
if thinkingParam, ok := schemas.SafeExtractFromMap(bifrostReq.Params.ExtraParams, "thinking"); ok {
|
|
if thinkingMap, ok := thinkingParam.(map[string]interface{}); ok {
|
|
thinking := &CohereThinking{}
|
|
if typeStr, ok := schemas.SafeExtractString(thinkingMap["type"]); ok {
|
|
delete(thinkingMap, "type")
|
|
thinking.Type = CohereThinkingType(typeStr)
|
|
}
|
|
if tokenBudget, ok := schemas.SafeExtractIntPointer(thinkingMap["token_budget"]); ok {
|
|
delete(thinkingMap, "token_budget")
|
|
thinking.TokenBudget = tokenBudget
|
|
}
|
|
cohereReq.Thinking = thinking
|
|
cohereReq.ExtraParams["thinking"] = thinkingMap
|
|
}
|
|
}
|
|
|
|
// Handle other Cohere-specific extra params
|
|
if safetyMode, ok := schemas.SafeExtractStringPointer(bifrostReq.Params.ExtraParams["safety_mode"]); ok {
|
|
delete(cohereReq.ExtraParams, "safety_mode")
|
|
cohereReq.SafetyMode = safetyMode
|
|
}
|
|
|
|
if logProbs, ok := schemas.SafeExtractBoolPointer(bifrostReq.Params.ExtraParams["log_probs"]); ok {
|
|
delete(cohereReq.ExtraParams, "log_probs")
|
|
cohereReq.LogProbs = logProbs
|
|
}
|
|
|
|
if strictToolChoice, ok := schemas.SafeExtractBoolPointer(bifrostReq.Params.ExtraParams["strict_tool_choice"]); ok {
|
|
delete(cohereReq.ExtraParams, "strict_tool_choice")
|
|
cohereReq.StrictToolChoice = strictToolChoice
|
|
}
|
|
}
|
|
|
|
// Convert tools to Cohere-specific format (without "strict" field)
|
|
if bifrostReq.Params.Tools != nil {
|
|
cohereTools := make([]CohereChatRequestTool, len(bifrostReq.Params.Tools))
|
|
for i, tool := range bifrostReq.Params.Tools {
|
|
cohereTools[i] = CohereChatRequestTool{
|
|
Type: string(tool.Type),
|
|
}
|
|
if tool.Function != nil {
|
|
cohereTools[i].Function = CohereChatRequestFunction{
|
|
Name: tool.Function.Name,
|
|
Description: tool.Function.Description,
|
|
Parameters: tool.Function.Parameters, // Convert to map
|
|
// Note: No "strict" field - Cohere doesn't support it
|
|
}
|
|
}
|
|
}
|
|
cohereReq.Tools = cohereTools
|
|
}
|
|
|
|
// Convert tool choice
|
|
if bifrostReq.Params.ToolChoice != nil {
|
|
toolChoice := bifrostReq.Params.ToolChoice
|
|
|
|
if toolChoice.ChatToolChoiceStr != nil {
|
|
switch schemas.ChatToolChoiceType(*toolChoice.ChatToolChoiceStr) {
|
|
case schemas.ChatToolChoiceTypeNone:
|
|
toolChoice := ToolChoiceNone
|
|
cohereReq.ToolChoice = &toolChoice
|
|
default:
|
|
toolChoice := ToolChoiceRequired
|
|
cohereReq.ToolChoice = &toolChoice
|
|
}
|
|
} else if toolChoice.ChatToolChoiceStruct != nil {
|
|
switch toolChoice.ChatToolChoiceStruct.Type {
|
|
case schemas.ChatToolChoiceTypeFunction:
|
|
toolChoice := ToolChoiceRequired
|
|
cohereReq.ToolChoice = &toolChoice
|
|
default:
|
|
toolChoice := ToolChoiceAuto
|
|
cohereReq.ToolChoice = &toolChoice
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
return cohereReq, nil
|
|
}
|
|
|
|
// ToBifrostChatRequest converts a Cohere v2 chat request to Bifrost format
|
|
func (req *CohereChatRequest) ToBifrostChatRequest(ctx *schemas.BifrostContext) *schemas.BifrostChatRequest {
|
|
if req == nil {
|
|
return nil
|
|
}
|
|
|
|
provider, model := schemas.ParseModelString(req.Model, providerUtils.CheckAndSetDefaultProvider(ctx, schemas.Cohere))
|
|
|
|
bifrostReq := &schemas.BifrostChatRequest{
|
|
Provider: provider,
|
|
Model: model,
|
|
Params: &schemas.ChatParameters{},
|
|
}
|
|
// Convert messages
|
|
if req.Messages != nil {
|
|
bifrostMessages := make([]schemas.ChatMessage, len(req.Messages))
|
|
for i, message := range req.Messages {
|
|
bifrostMessages[i] = *message.ToBifrostChatMessage()
|
|
}
|
|
bifrostReq.Input = bifrostMessages
|
|
}
|
|
// Convert parameters
|
|
if req.MaxTokens != nil {
|
|
bifrostReq.Params.MaxCompletionTokens = req.MaxTokens
|
|
}
|
|
if req.Temperature != nil {
|
|
bifrostReq.Params.Temperature = req.Temperature
|
|
}
|
|
if req.P != nil {
|
|
bifrostReq.Params.TopP = req.P
|
|
}
|
|
if req.StopSequences != nil {
|
|
bifrostReq.Params.Stop = req.StopSequences
|
|
}
|
|
if req.FrequencyPenalty != nil {
|
|
bifrostReq.Params.FrequencyPenalty = req.FrequencyPenalty
|
|
}
|
|
if req.PresencePenalty != nil {
|
|
bifrostReq.Params.PresencePenalty = req.PresencePenalty
|
|
}
|
|
|
|
// Convert reasoning
|
|
if req.Thinking != nil {
|
|
if req.Thinking.Type == ThinkingTypeDisabled {
|
|
bifrostReq.Params.Reasoning = &schemas.ChatReasoning{
|
|
Effort: schemas.Ptr("none"),
|
|
}
|
|
} else {
|
|
bifrostReq.Params.Reasoning = &schemas.ChatReasoning{
|
|
Effort: schemas.Ptr("auto"),
|
|
}
|
|
if req.Thinking.TokenBudget != nil {
|
|
bifrostReq.Params.Reasoning.MaxTokens = req.Thinking.TokenBudget
|
|
}
|
|
}
|
|
}
|
|
if req.ResponseFormat != nil {
|
|
bifrostReq.Params.ResponseFormat = convertCohereResponseFormatToBifrost(req.ResponseFormat)
|
|
}
|
|
|
|
// Convert tools
|
|
if req.Tools != nil {
|
|
bifrostTools := make([]schemas.ChatTool, len(req.Tools))
|
|
for i, tool := range req.Tools {
|
|
bifrostTools[i] = schemas.ChatTool{
|
|
Type: schemas.ChatToolTypeFunction,
|
|
Function: &schemas.ChatToolFunction{
|
|
Name: tool.Function.Name,
|
|
Description: tool.Function.Description,
|
|
Parameters: convertInterfaceToToolFunctionParameters(tool.Function.Parameters),
|
|
},
|
|
}
|
|
}
|
|
bifrostReq.Params.Tools = bifrostTools
|
|
}
|
|
|
|
// Convert tool choice
|
|
if req.ToolChoice != nil {
|
|
switch *req.ToolChoice {
|
|
case ToolChoiceNone:
|
|
bifrostReq.Params.ToolChoice = &schemas.ChatToolChoice{
|
|
ChatToolChoiceStr: schemas.Ptr(string(schemas.ChatToolChoiceTypeNone)),
|
|
}
|
|
case ToolChoiceRequired:
|
|
bifrostReq.Params.ToolChoice = &schemas.ChatToolChoice{
|
|
ChatToolChoiceStr: schemas.Ptr(string(schemas.ChatToolChoiceTypeRequired)),
|
|
}
|
|
case ToolChoiceAuto:
|
|
bifrostReq.Params.ToolChoice = &schemas.ChatToolChoice{
|
|
ChatToolChoiceStr: schemas.Ptr(string(schemas.ChatToolChoiceTypeAny)),
|
|
}
|
|
}
|
|
}
|
|
|
|
// Convert extra params
|
|
extraParams := make(map[string]interface{})
|
|
if req.SafetyMode != nil {
|
|
extraParams["safety_mode"] = *req.SafetyMode
|
|
}
|
|
if req.LogProbs != nil {
|
|
extraParams["log_probs"] = *req.LogProbs
|
|
}
|
|
if req.StrictToolChoice != nil {
|
|
extraParams["strict_tool_choice"] = *req.StrictToolChoice
|
|
}
|
|
if req.Thinking != nil {
|
|
thinkingMap := map[string]interface{}{
|
|
"type": string(req.Thinking.Type),
|
|
}
|
|
if req.Thinking.TokenBudget != nil {
|
|
thinkingMap["token_budget"] = *req.Thinking.TokenBudget
|
|
}
|
|
extraParams["thinking"] = thinkingMap
|
|
}
|
|
if len(extraParams) > 0 {
|
|
bifrostReq.Params.ExtraParams = extraParams
|
|
}
|
|
|
|
return bifrostReq
|
|
}
|
|
|
|
// ToBifrostChatResponse converts a Cohere v2 response to Bifrost format
|
|
func (response *CohereChatResponse) ToBifrostChatResponse(model string) *schemas.BifrostChatResponse {
|
|
if response == nil {
|
|
return nil
|
|
}
|
|
|
|
bifrostResponse := &schemas.BifrostChatResponse{
|
|
ID: response.ID,
|
|
Model: model,
|
|
Object: "chat.completion",
|
|
Choices: []schemas.BifrostResponseChoice{
|
|
{
|
|
Index: 0,
|
|
ChatNonStreamResponseChoice: &schemas.ChatNonStreamResponseChoice{},
|
|
},
|
|
},
|
|
Created: int(time.Now().Unix()),
|
|
ExtraFields: schemas.BifrostResponseExtraFields{
|
|
},
|
|
}
|
|
|
|
// Convert messages
|
|
if response.Message != nil {
|
|
bifrostMessage := response.Message.ToBifrostChatMessage()
|
|
bifrostResponse.Choices[0].ChatNonStreamResponseChoice.Message = bifrostMessage
|
|
}
|
|
|
|
// Convert finish reason
|
|
if response.FinishReason != nil {
|
|
finishReason := ConvertCohereFinishReasonToBifrost(*response.FinishReason)
|
|
bifrostResponse.Choices[0].FinishReason = schemas.Ptr(finishReason)
|
|
}
|
|
|
|
// Convert usage information
|
|
if response.Usage != nil {
|
|
usage := &schemas.BifrostLLMUsage{}
|
|
|
|
if response.Usage.Tokens != nil {
|
|
if response.Usage.Tokens.InputTokens != nil {
|
|
usage.PromptTokens = *response.Usage.Tokens.InputTokens
|
|
}
|
|
if response.Usage.Tokens.OutputTokens != nil {
|
|
usage.CompletionTokens = *response.Usage.Tokens.OutputTokens
|
|
}
|
|
if response.Usage.CachedTokens != nil {
|
|
usage.PromptTokensDetails = &schemas.ChatPromptTokensDetails{
|
|
CachedReadTokens: *response.Usage.CachedTokens,
|
|
}
|
|
}
|
|
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
|
|
}
|
|
|
|
bifrostResponse.Usage = usage
|
|
}
|
|
|
|
return bifrostResponse
|
|
}
|
|
|
|
func (chunk *CohereStreamEvent) ToBifrostChatCompletionStream() (*schemas.BifrostChatResponse, *schemas.BifrostError, bool) {
|
|
switch chunk.Type {
|
|
case StreamEventMessageStart:
|
|
if chunk.Delta != nil && chunk.Delta.Message != nil && chunk.Delta.Message.Role != nil {
|
|
// Create streaming response for this delta
|
|
streamResponse := &schemas.BifrostChatResponse{
|
|
Object: "chat.completion.chunk",
|
|
Choices: []schemas.BifrostResponseChoice{
|
|
{
|
|
Index: 0,
|
|
ChatStreamResponseChoice: &schemas.ChatStreamResponseChoice{
|
|
Delta: &schemas.ChatStreamResponseChoiceDelta{
|
|
Role: chunk.Delta.Message.Role,
|
|
},
|
|
},
|
|
},
|
|
},
|
|
}
|
|
|
|
return streamResponse, nil, false
|
|
}
|
|
|
|
case StreamEventContentDelta:
|
|
if chunk.Delta != nil &&
|
|
chunk.Delta.Message != nil &&
|
|
chunk.Delta.Message.Content != nil &&
|
|
chunk.Delta.Message.Content.CohereStreamContentObject != nil {
|
|
if chunk.Delta.Message.Content.CohereStreamContentObject.Text != nil {
|
|
// Try to cast content to CohereStreamContent
|
|
streamResponse := &schemas.BifrostChatResponse{
|
|
Object: "chat.completion.chunk",
|
|
Choices: []schemas.BifrostResponseChoice{
|
|
{
|
|
Index: 0,
|
|
ChatStreamResponseChoice: &schemas.ChatStreamResponseChoice{
|
|
Delta: &schemas.ChatStreamResponseChoiceDelta{
|
|
Content: chunk.Delta.Message.Content.CohereStreamContentObject.Text,
|
|
},
|
|
},
|
|
},
|
|
},
|
|
}
|
|
|
|
return streamResponse, nil, false
|
|
} else if chunk.Delta.Message.Content.CohereStreamContentObject.Thinking != nil {
|
|
thinkingText := *chunk.Delta.Message.Content.CohereStreamContentObject.Thinking
|
|
streamResponse := &schemas.BifrostChatResponse{
|
|
Object: "chat.completion.chunk",
|
|
Choices: []schemas.BifrostResponseChoice{
|
|
{
|
|
Index: 0,
|
|
ChatStreamResponseChoice: &schemas.ChatStreamResponseChoice{
|
|
Delta: &schemas.ChatStreamResponseChoiceDelta{
|
|
Reasoning: schemas.Ptr(thinkingText),
|
|
ReasoningDetails: []schemas.ChatReasoningDetails{
|
|
{
|
|
Index: 0,
|
|
Type: schemas.BifrostReasoningDetailsTypeText,
|
|
Text: schemas.Ptr(thinkingText),
|
|
},
|
|
},
|
|
},
|
|
},
|
|
},
|
|
},
|
|
}
|
|
|
|
return streamResponse, nil, false
|
|
}
|
|
}
|
|
|
|
case StreamEventToolPlanDelta:
|
|
if chunk.Delta != nil && chunk.Delta.Message != nil && chunk.Delta.Message.ToolPlan != nil {
|
|
streamResponse := &schemas.BifrostChatResponse{
|
|
Object: "chat.completion.chunk",
|
|
Choices: []schemas.BifrostResponseChoice{
|
|
{
|
|
Index: 0,
|
|
ChatStreamResponseChoice: &schemas.ChatStreamResponseChoice{
|
|
Delta: &schemas.ChatStreamResponseChoiceDelta{
|
|
Reasoning: chunk.Delta.Message.ToolPlan,
|
|
},
|
|
},
|
|
},
|
|
},
|
|
}
|
|
|
|
return streamResponse, nil, false
|
|
}
|
|
|
|
case StreamEventContentStart:
|
|
// Content start event - just continue, actual content comes in content-delta
|
|
return nil, nil, false
|
|
|
|
case StreamEventToolCallStart, StreamEventToolCallDelta:
|
|
if chunk.Delta != nil && chunk.Delta.Message != nil && chunk.Delta.Message.ToolCalls != nil && chunk.Delta.Message.ToolCalls.CohereToolCallObject != nil {
|
|
// Handle single tool call object (tool-call-start/delta events)
|
|
cohereToolCall := chunk.Delta.Message.ToolCalls.CohereToolCallObject
|
|
toolCall := schemas.ChatAssistantMessageToolCall{}
|
|
|
|
if chunk.Index != nil {
|
|
toolCall.Index = uint16(*chunk.Index)
|
|
}
|
|
|
|
if cohereToolCall.ID != nil {
|
|
toolCall.ID = cohereToolCall.ID
|
|
}
|
|
|
|
if cohereToolCall.Function != nil {
|
|
if cohereToolCall.Function.Name != nil {
|
|
toolCall.Function.Name = cohereToolCall.Function.Name
|
|
}
|
|
toolCall.Function.Arguments = cohereToolCall.Function.Arguments
|
|
}
|
|
|
|
streamResponse := &schemas.BifrostChatResponse{
|
|
Object: "chat.completion.chunk",
|
|
Choices: []schemas.BifrostResponseChoice{
|
|
{
|
|
Index: 0,
|
|
ChatStreamResponseChoice: &schemas.ChatStreamResponseChoice{
|
|
Delta: &schemas.ChatStreamResponseChoiceDelta{
|
|
ToolCalls: []schemas.ChatAssistantMessageToolCall{toolCall},
|
|
},
|
|
},
|
|
},
|
|
},
|
|
}
|
|
|
|
return streamResponse, nil, false
|
|
}
|
|
|
|
case StreamEventToolCallEnd:
|
|
return nil, nil, false
|
|
|
|
case StreamEventContentEnd:
|
|
return nil, nil, false
|
|
|
|
case StreamEventMessageEnd:
|
|
if chunk.Delta != nil {
|
|
var finishReason string
|
|
usage := &schemas.BifrostLLMUsage{}
|
|
// Set finish reason
|
|
if chunk.Delta.FinishReason != nil {
|
|
finishReason = ConvertCohereFinishReasonToBifrost(*chunk.Delta.FinishReason)
|
|
}
|
|
|
|
// Set usage information
|
|
if chunk.Delta.Usage != nil {
|
|
if chunk.Delta.Usage.Tokens != nil {
|
|
if chunk.Delta.Usage.Tokens.InputTokens != nil {
|
|
usage.PromptTokens = *chunk.Delta.Usage.Tokens.InputTokens
|
|
}
|
|
if chunk.Delta.Usage.Tokens.OutputTokens != nil {
|
|
usage.CompletionTokens = *chunk.Delta.Usage.Tokens.OutputTokens
|
|
}
|
|
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
|
|
}
|
|
}
|
|
|
|
streamResponse := &schemas.BifrostChatResponse{
|
|
Object: "chat.completion.chunk",
|
|
Choices: []schemas.BifrostResponseChoice{
|
|
{
|
|
Index: 0,
|
|
FinishReason: &finishReason,
|
|
ChatStreamResponseChoice: &schemas.ChatStreamResponseChoice{
|
|
Delta: &schemas.ChatStreamResponseChoiceDelta{},
|
|
},
|
|
},
|
|
},
|
|
Usage: usage,
|
|
}
|
|
|
|
return streamResponse, nil, true
|
|
}
|
|
return nil, nil, false
|
|
}
|
|
|
|
return nil, nil, false
|
|
}
|
|
|
|
func (cm *CohereMessage) ToBifrostChatMessage() *schemas.ChatMessage {
|
|
if cm == nil {
|
|
return nil
|
|
}
|
|
|
|
var content *string
|
|
var contentBlocks []schemas.ChatContentBlock
|
|
var toolCalls []schemas.ChatAssistantMessageToolCall
|
|
var reasoningDetails []schemas.ChatReasoningDetails
|
|
var reasoningText string
|
|
|
|
// Convert message content
|
|
if cm.Content != nil {
|
|
if cm.Content.IsString() ||
|
|
(cm.Content.IsBlocks() &&
|
|
len(cm.Content.GetBlocks()) == 1 &&
|
|
cm.Content.GetBlocks()[0].Type == CohereContentBlockTypeText) {
|
|
if cm.Content.IsString() {
|
|
content = cm.Content.GetString()
|
|
} else {
|
|
content = cm.Content.GetBlocks()[0].Text
|
|
}
|
|
} else if cm.Content.IsBlocks() {
|
|
for _, block := range cm.Content.GetBlocks() {
|
|
if block.Type == CohereContentBlockTypeText && block.Text != nil {
|
|
contentBlocks = append(contentBlocks, schemas.ChatContentBlock{
|
|
Type: schemas.ChatContentBlockTypeText,
|
|
Text: block.Text,
|
|
})
|
|
} else if block.Type == CohereContentBlockTypeImage && block.ImageURL != nil {
|
|
contentBlocks = append(contentBlocks, schemas.ChatContentBlock{
|
|
Type: schemas.ChatContentBlockTypeImage,
|
|
ImageURLStruct: &schemas.ChatInputImage{
|
|
URL: block.ImageURL.URL,
|
|
},
|
|
})
|
|
} else if block.Type == CohereContentBlockTypeThinking && block.Thinking != nil {
|
|
reasoningDetails = append(reasoningDetails, schemas.ChatReasoningDetails{
|
|
Index: len(reasoningDetails),
|
|
Type: schemas.BifrostReasoningDetailsTypeText,
|
|
Text: block.Thinking,
|
|
})
|
|
if len(reasoningText) > 0 {
|
|
reasoningText += "\n"
|
|
}
|
|
reasoningText += *block.Thinking
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
if len(contentBlocks) == 1 && contentBlocks[0].Type == schemas.ChatContentBlockTypeText {
|
|
content = contentBlocks[0].Text
|
|
contentBlocks = nil
|
|
}
|
|
|
|
// Create the message content
|
|
messageContent := &schemas.ChatMessageContent{
|
|
ContentStr: content,
|
|
ContentBlocks: contentBlocks,
|
|
}
|
|
|
|
// Convert tool calls
|
|
if cm.ToolCalls != nil {
|
|
for _, toolCall := range cm.ToolCalls {
|
|
// Check if Function is nil to avoid nil pointer dereference
|
|
if toolCall.Function == nil {
|
|
// Skip this tool call if Function is nil
|
|
continue
|
|
}
|
|
|
|
// Safely extract function name and arguments
|
|
var functionName *string
|
|
var functionArguments string
|
|
|
|
if toolCall.Function.Name != nil {
|
|
functionName = toolCall.Function.Name
|
|
} else {
|
|
// Use empty string if Name is nil
|
|
functionName = schemas.Ptr("")
|
|
}
|
|
|
|
// Arguments is a string, not a pointer, so it's safe to access directly
|
|
functionArguments = toolCall.Function.Arguments
|
|
|
|
bifrostToolCall := schemas.ChatAssistantMessageToolCall{
|
|
Index: uint16(len(toolCalls)),
|
|
ID: toolCall.ID,
|
|
Function: schemas.ChatAssistantMessageToolCallFunction{
|
|
Name: functionName,
|
|
Arguments: functionArguments,
|
|
},
|
|
}
|
|
toolCalls = append(toolCalls, bifrostToolCall)
|
|
}
|
|
}
|
|
|
|
// Create assistant message if we have tool calls
|
|
var assistantMessage *schemas.ChatAssistantMessage
|
|
if len(toolCalls) > 0 {
|
|
assistantMessage = &schemas.ChatAssistantMessage{
|
|
ToolCalls: toolCalls,
|
|
}
|
|
}
|
|
|
|
if len(reasoningDetails) > 0 {
|
|
if assistantMessage == nil {
|
|
assistantMessage = &schemas.ChatAssistantMessage{}
|
|
}
|
|
assistantMessage.ReasoningDetails = reasoningDetails
|
|
assistantMessage.Reasoning = schemas.Ptr(reasoningText)
|
|
}
|
|
|
|
bifrostMessage := &schemas.ChatMessage{
|
|
Role: schemas.ChatMessageRole(cm.Role),
|
|
Content: messageContent,
|
|
ChatAssistantMessage: assistantMessage,
|
|
}
|
|
|
|
if cm.Role == "tool" {
|
|
bifrostMessage.ChatToolMessage = &schemas.ChatToolMessage{
|
|
ToolCallID: cm.ToolCallID,
|
|
}
|
|
}
|
|
return bifrostMessage
|
|
}
|