package cohere import ( "sort" "github.com/bytedance/sonic" "github.com/maximhq/bifrost/core/providers/utils" "github.com/maximhq/bifrost/core/schemas" "gopkg.in/yaml.v3" ) // ToCohereRerankRequest converts a Bifrost rerank request to Cohere format func ToCohereRerankRequest(bifrostReq *schemas.BifrostRerankRequest) *CohereRerankRequest { if bifrostReq == nil { return nil } cohereReq := &CohereRerankRequest{ Model: bifrostReq.Model, Query: bifrostReq.Query, } // Cohere v2 expects documents as a list of strings. documents := make([]string, len(bifrostReq.Documents)) for i, doc := range bifrostReq.Documents { documents[i] = formatCohereRerankDocument(doc) } cohereReq.Documents = documents if bifrostReq.Params != nil { cohereReq.TopN = bifrostReq.Params.TopN cohereReq.MaxTokensPerDoc = bifrostReq.Params.MaxTokensPerDoc cohereReq.Priority = bifrostReq.Params.Priority cohereReq.ExtraParams = bifrostReq.Params.ExtraParams } return cohereReq } // ToBifrostRerankRequest converts a Cohere rerank request to Bifrost format func (req *CohereRerankRequest) ToBifrostRerankRequest(ctx *schemas.BifrostContext) *schemas.BifrostRerankRequest { if req == nil { return nil } provider, model := schemas.ParseModelString(req.Model, utils.CheckAndSetDefaultProvider(ctx, schemas.Cohere)) bifrostReq := &schemas.BifrostRerankRequest{ Provider: provider, Model: model, Query: req.Query, Params: &schemas.RerankParameters{}, } // Convert documents for _, doc := range req.Documents { bifrostReq.Documents = append(bifrostReq.Documents, schemas.RerankDocument{ Text: doc, }) } if req.TopN != nil { bifrostReq.Params.TopN = req.TopN } if req.MaxTokensPerDoc != nil { bifrostReq.Params.MaxTokensPerDoc = req.MaxTokensPerDoc } if req.Priority != nil { bifrostReq.Params.Priority = req.Priority } if req.ExtraParams != nil { bifrostReq.Params.ExtraParams = req.ExtraParams } return bifrostReq } // ToBifrostRerankResponse converts a Cohere rerank response to Bifrost format. func (response *CohereRerankResponse) ToBifrostRerankResponse(documents []schemas.RerankDocument, returnDocuments bool) *schemas.BifrostRerankResponse { if response == nil { return nil } bifrostResponse := &schemas.BifrostRerankResponse{ ID: response.ID, } // Convert results for _, result := range response.Results { rerankResult := schemas.RerankResult{ Index: result.Index, RelevanceScore: result.RelevanceScore, } // Convert document if present if len(result.Document) > 0 { var docMap map[string]interface{} if err := sonic.Unmarshal(result.Document, &docMap); err == nil { doc := &schemas.RerankDocument{} populated := false if text, ok := docMap["text"].(string); ok { doc.Text = text populated = true } if id, ok := docMap["id"].(string); ok { doc.ID = &id populated = true } // Collect metadata: unwrap "metadata"/"meta" keys to avoid nesting meta := make(map[string]interface{}) if rawMeta, ok := docMap["metadata"].(map[string]interface{}); ok { for k, v := range rawMeta { meta[k] = v } } else if rawMeta, ok := docMap["meta"].(map[string]interface{}); ok { for k, v := range rawMeta { meta[k] = v } } for k, v := range docMap { if k != "text" && k != "id" && k != "metadata" && k != "meta" { meta[k] = v } } if len(meta) > 0 { doc.Meta = meta populated = true } if populated { rerankResult.Document = doc } } } bifrostResponse.Results = append(bifrostResponse.Results, rerankResult) } sort.SliceStable(bifrostResponse.Results, func(i, j int) bool { if bifrostResponse.Results[i].RelevanceScore == bifrostResponse.Results[j].RelevanceScore { return bifrostResponse.Results[i].Index < bifrostResponse.Results[j].Index } return bifrostResponse.Results[i].RelevanceScore > bifrostResponse.Results[j].RelevanceScore }) if returnDocuments { for i := range bifrostResponse.Results { resultIndex := bifrostResponse.Results[i].Index if resultIndex >= 0 && resultIndex < len(documents) { bifrostResponse.Results[i].Document = schemas.Ptr(documents[resultIndex]) } } } // Convert usage information if response.Meta != nil { promptTokens := 0 completionTokens := 0 hasTokenUsage := false if response.Meta.Tokens != nil { if response.Meta.Tokens.InputTokens != nil { promptTokens = int(*response.Meta.Tokens.InputTokens) hasTokenUsage = true } if response.Meta.Tokens.OutputTokens != nil { completionTokens = int(*response.Meta.Tokens.OutputTokens) hasTokenUsage = true } } else if response.Meta.BilledUnits != nil { if response.Meta.BilledUnits.InputTokens != nil { promptTokens = int(*response.Meta.BilledUnits.InputTokens) hasTokenUsage = true } if response.Meta.BilledUnits.OutputTokens != nil { completionTokens = int(*response.Meta.BilledUnits.OutputTokens) hasTokenUsage = true } } if hasTokenUsage { bifrostResponse.Usage = &schemas.BifrostLLMUsage{ PromptTokens: promptTokens, CompletionTokens: completionTokens, TotalTokens: promptTokens + completionTokens, } } } return bifrostResponse } func formatCohereRerankDocument(doc schemas.RerankDocument) string { if doc.ID == nil && len(doc.Meta) == 0 { return doc.Text } // Keep metadata/id available by encoding a structured string document. documentPayload := map[string]interface{}{ "text": doc.Text, } if doc.ID != nil { documentPayload["id"] = *doc.ID } if len(doc.Meta) > 0 { documentPayload["metadata"] = doc.Meta } encoded, err := yaml.Marshal(documentPayload) if err != nil { return doc.Text } return string(encoded) }