Files
bifrost/core/providers/cohere/count_tokens.go
Beyhan Oğur 880f412e2c first commit
2026-04-26 21:52:23 +03:00

126 lines
3.5 KiB
Go

package cohere
import (
"fmt"
"strings"
"unicode/utf8"
"github.com/maximhq/bifrost/core/providers/utils"
"github.com/maximhq/bifrost/core/schemas"
)
// ToBifrostResponsesRequest converts a Cohere count tokens request to Bifrost format.
func (req *CohereCountTokensRequest) ToBifrostResponsesRequest(ctx *schemas.BifrostContext) *schemas.BifrostResponsesRequest {
if req == nil {
return nil
}
provider, model := schemas.ParseModelString(req.Model, utils.CheckAndSetDefaultProvider(ctx, schemas.Cohere))
userRole := schemas.ResponsesInputMessageRoleUser
return &schemas.BifrostResponsesRequest{
Provider: provider,
Model: model,
Input: []schemas.ResponsesMessage{
{
Role: &userRole,
Content: &schemas.ResponsesMessageContent{
ContentStr: &req.Text,
},
},
},
}
}
// ToCohereCountTokensRequest converts a Bifrost count tokens request to Cohere's tokenize payload.
func ToCohereCountTokensRequest(bifrostReq *schemas.BifrostResponsesRequest) (*CohereCountTokensRequest, error) {
if bifrostReq == nil {
return nil, nil
}
if bifrostReq.Input == nil {
return nil, fmt.Errorf("count tokens input is not provided")
}
text := buildCohereCountTokensText(bifrostReq.Input)
trimmed := strings.TrimSpace(text)
if trimmed == "" {
return nil, fmt.Errorf("count tokens text is empty after conversion")
}
runeCount := utf8.RuneCountInString(trimmed)
if runeCount < cohereTokenizeMinTextLength || runeCount > cohereTokenizeMaxTextLength {
return nil, fmt.Errorf("count tokens text length must be between %d and %d characters", cohereTokenizeMinTextLength, cohereTokenizeMaxTextLength)
}
cohereReq := &CohereCountTokensRequest{
Model: bifrostReq.Model,
Text: trimmed,
}
if bifrostReq.Params != nil {
cohereReq.ExtraParams = bifrostReq.Params.ExtraParams
}
return cohereReq, nil
}
// ToBifrostCountTokensResponse converts a Cohere tokenize response to Bifrost format.
func (resp *CohereCountTokensResponse) ToBifrostCountTokensResponse(model string) *schemas.BifrostCountTokensResponse {
if resp == nil {
return nil
}
inputTokens := len(resp.Tokens)
if inputTokens == 0 && len(resp.TokenStrings) > 0 {
inputTokens = len(resp.TokenStrings)
}
totalTokens := inputTokens
return &schemas.BifrostCountTokensResponse{
Model: model,
InputTokens: inputTokens,
TotalTokens: &totalTokens,
TokenStrings: resp.TokenStrings,
Tokens: resp.Tokens,
Object: "response.input_tokens",
}
}
// buildCohereCountTokensText flattens Responses messages into a plain text payload for tokenization.
func buildCohereCountTokensText(messages []schemas.ResponsesMessage) string {
var parts []string
for _, msg := range messages {
var contentParts []string
if msg.Content != nil {
if msg.Content.ContentStr != nil {
contentParts = append(contentParts, *msg.Content.ContentStr)
}
for _, block := range msg.Content.ContentBlocks {
if block.Text != nil {
contentParts = append(contentParts, *block.Text)
}
if block.ResponsesOutputMessageContentRefusal != nil && block.ResponsesOutputMessageContentRefusal.Refusal != "" {
contentParts = append(contentParts, block.ResponsesOutputMessageContentRefusal.Refusal)
}
}
}
if msg.ResponsesReasoning != nil {
for _, summary := range msg.ResponsesReasoning.Summary {
if summary.Text != "" {
contentParts = append(contentParts, summary.Text)
}
}
}
if len(contentParts) == 0 {
continue
}
parts = append(parts, strings.Join(contentParts, "\n"))
}
return strings.TrimSpace(strings.Join(parts, "\n"))
}