Files
bifrost/core/providers/cohere/rerank.go
Beyhan Oğur 880f412e2c first commit
2026-04-26 21:52:23 +03:00

210 lines
5.7 KiB
Go

package cohere
import (
"sort"
"github.com/bytedance/sonic"
"github.com/maximhq/bifrost/core/providers/utils"
"github.com/maximhq/bifrost/core/schemas"
"gopkg.in/yaml.v3"
)
// ToCohereRerankRequest builds the Cohere rerank payload for a Bifrost rerank
// request. It returns nil when bifrostReq is nil.
//
// Every Bifrost document is flattened to a single string via
// formatCohereRerankDocument, because Cohere v2 expects documents as a list of
// strings. Optional parameters are copied over only when Params is set.
func ToCohereRerankRequest(bifrostReq *schemas.BifrostRerankRequest) *CohereRerankRequest {
	if bifrostReq == nil {
		return nil
	}

	// Always produce a non-nil slice, even when there are no documents.
	docs := make([]string, 0, len(bifrostReq.Documents))
	for idx := range bifrostReq.Documents {
		docs = append(docs, formatCohereRerankDocument(bifrostReq.Documents[idx]))
	}

	out := &CohereRerankRequest{
		Model:     bifrostReq.Model,
		Query:     bifrostReq.Query,
		Documents: docs,
	}

	if params := bifrostReq.Params; params != nil {
		out.TopN = params.TopN
		out.MaxTokensPerDoc = params.MaxTokensPerDoc
		out.Priority = params.Priority
		out.ExtraParams = params.ExtraParams
	}

	return out
}
// ToBifrostRerankRequest maps a Cohere rerank request onto the Bifrost rerank
// request shape. It returns nil when the receiver is nil.
//
// The provider and model are derived from req.Model, using the default
// provider resolved from ctx (with Cohere as the fallback passed to
// utils.CheckAndSetDefaultProvider). Params is always non-nil in the result;
// its fields are filled only from the optional values present on req.
func (req *CohereRerankRequest) ToBifrostRerankRequest(ctx *schemas.BifrostContext) *schemas.BifrostRerankRequest {
	if req == nil {
		return nil
	}

	defaultProvider := utils.CheckAndSetDefaultProvider(ctx, schemas.Cohere)
	provider, model := schemas.ParseModelString(req.Model, defaultProvider)

	// Carry over only the optional parameters that were actually supplied.
	params := &schemas.RerankParameters{}
	if req.TopN != nil {
		params.TopN = req.TopN
	}
	if req.MaxTokensPerDoc != nil {
		params.MaxTokensPerDoc = req.MaxTokensPerDoc
	}
	if req.Priority != nil {
		params.Priority = req.Priority
	}
	if req.ExtraParams != nil {
		params.ExtraParams = req.ExtraParams
	}

	out := &schemas.BifrostRerankRequest{
		Provider: provider,
		Model:    model,
		Query:    req.Query,
		Params:   params,
	}

	// Cohere documents are plain strings; wrap each one as a text-only document.
	for idx := range req.Documents {
		out.Documents = append(out.Documents, schemas.RerankDocument{
			Text: req.Documents[idx],
		})
	}

	return out
}
// ToBifrostRerankResponse converts a Cohere rerank response to Bifrost format.
//
// It returns nil when the receiver is nil. Results are ordered by descending
// relevance score, with ties broken by ascending document index.
//
// documents are the documents of the originating request. When returnDocuments
// is true, every result whose index is within range of documents gets the
// corresponding request document attached, replacing any document decoded
// from the Cohere payload. Results with an out-of-range index keep whatever
// was decoded from the payload (possibly nothing).
//
// Usage is set only when the response metadata reports at least one token
// count.
func (response *CohereRerankResponse) ToBifrostRerankResponse(documents []schemas.RerankDocument, returnDocuments bool) *schemas.BifrostRerankResponse {
	if response == nil {
		return nil
	}

	bifrostResponse := &schemas.BifrostRerankResponse{
		ID: response.ID,
	}

	// Convert results.
	for _, result := range response.Results {
		rerankResult := schemas.RerankResult{
			Index:          result.Index,
			RelevanceScore: result.RelevanceScore,
		}
		if len(result.Document) > 0 {
			if doc := parseCohereRerankDocument(result.Document); doc != nil {
				rerankResult.Document = doc
			}
		}
		bifrostResponse.Results = append(bifrostResponse.Results, rerankResult)
	}

	// Highest score first; equal scores keep ascending index order.
	sort.SliceStable(bifrostResponse.Results, func(i, j int) bool {
		left, right := bifrostResponse.Results[i], bifrostResponse.Results[j]
		if left.RelevanceScore == right.RelevanceScore {
			return left.Index < right.Index
		}
		return left.RelevanceScore > right.RelevanceScore
	})

	if returnDocuments {
		for i := range bifrostResponse.Results {
			resultIndex := bifrostResponse.Results[i].Index
			// Ignore indexes that do not refer to a request document.
			if resultIndex >= 0 && resultIndex < len(documents) {
				bifrostResponse.Results[i].Document = schemas.Ptr(documents[resultIndex])
			}
		}
	}

	// Convert usage information.
	if usage := response.toBifrostRerankUsage(); usage != nil {
		bifrostResponse.Usage = usage
	}

	return bifrostResponse
}

// parseCohereRerankDocument decodes a JSON document object from a Cohere
// rerank result into a Bifrost rerank document.
//
// It returns nil when raw is not a JSON object or when the object yields no
// text, id, or metadata. Decoding is best effort: a malformed document must
// not fail the whole response, so the decode error is intentionally dropped.
func parseCohereRerankDocument(raw []byte) *schemas.RerankDocument {
	var docMap map[string]any
	if err := sonic.Unmarshal(raw, &docMap); err != nil {
		return nil
	}

	doc := &schemas.RerankDocument{}
	populated := false

	if text, ok := docMap["text"].(string); ok {
		doc.Text = text
		populated = true
	}
	if id, ok := docMap["id"].(string); ok {
		doc.ID = &id
		populated = true
	}

	// Collect metadata: unwrap "metadata" (preferred) or "meta" so the result
	// is flat instead of nested one level deeper.
	meta := make(map[string]any)
	nested, ok := docMap["metadata"].(map[string]any)
	if !ok {
		nested, _ = docMap["meta"].(map[string]any)
	}
	for k, v := range nested {
		meta[k] = v
	}
	// Any other top-level key is treated as metadata and wins over a nested
	// key of the same name.
	for k, v := range docMap {
		switch k {
		case "text", "id", "metadata", "meta":
			// Already handled above.
		default:
			meta[k] = v
		}
	}
	if len(meta) > 0 {
		doc.Meta = meta
		populated = true
	}

	if !populated {
		return nil
	}
	return doc
}

// toBifrostRerankUsage maps the token counts in the response metadata to
// Bifrost usage. It returns nil when there is no metadata or no token count.
//
// Counts from Meta.Tokens take precedence. Meta.BilledUnits is used as a
// fallback whenever Tokens provides no counts, including the case where
// Tokens is present but empty, so available usage is not silently dropped.
func (response *CohereRerankResponse) toBifrostRerankUsage() *schemas.BifrostLLMUsage {
	if response == nil || response.Meta == nil {
		return nil
	}
	meta := response.Meta

	promptTokens := 0
	completionTokens := 0
	hasTokenUsage := false

	if meta.Tokens != nil {
		if meta.Tokens.InputTokens != nil {
			promptTokens = int(*meta.Tokens.InputTokens)
			hasTokenUsage = true
		}
		if meta.Tokens.OutputTokens != nil {
			completionTokens = int(*meta.Tokens.OutputTokens)
			hasTokenUsage = true
		}
	}

	if !hasTokenUsage && meta.BilledUnits != nil {
		if meta.BilledUnits.InputTokens != nil {
			promptTokens = int(*meta.BilledUnits.InputTokens)
			hasTokenUsage = true
		}
		if meta.BilledUnits.OutputTokens != nil {
			completionTokens = int(*meta.BilledUnits.OutputTokens)
			hasTokenUsage = true
		}
	}

	if !hasTokenUsage {
		return nil
	}
	return &schemas.BifrostLLMUsage{
		PromptTokens:     promptTokens,
		CompletionTokens: completionTokens,
		TotalTokens:      promptTokens + completionTokens,
	}
}
// formatCohereRerankDocument renders a Bifrost rerank document as the single
// string sent to Cohere.
//
// A document with neither an ID nor metadata is sent as its raw text.
// Otherwise the text, ID, and metadata are YAML-encoded into one structured
// string so that the extra fields travel with the document. If encoding
// fails, the raw text is used as a best-effort fallback.
func formatCohereRerankDocument(doc schemas.RerankDocument) string {
	hasID := doc.ID != nil
	hasMeta := len(doc.Meta) > 0
	if !hasID && !hasMeta {
		return doc.Text
	}

	payload := make(map[string]interface{}, 3)
	payload["text"] = doc.Text
	if hasMeta {
		payload["metadata"] = doc.Meta
	}
	if hasID {
		payload["id"] = *doc.ID
	}

	if out, err := yaml.Marshal(payload); err == nil {
		return string(out)
	}
	return doc.Text
}