210 lines
5.7 KiB
Go
210 lines
5.7 KiB
Go
package cohere
|
|
|
|
import (
|
|
"sort"
|
|
|
|
"github.com/bytedance/sonic"
|
|
"github.com/maximhq/bifrost/core/providers/utils"
|
|
"github.com/maximhq/bifrost/core/schemas"
|
|
"gopkg.in/yaml.v3"
|
|
)
|
|
|
|
// ToCohereRerankRequest converts a Bifrost rerank request to Cohere format
|
|
func ToCohereRerankRequest(bifrostReq *schemas.BifrostRerankRequest) *CohereRerankRequest {
|
|
if bifrostReq == nil {
|
|
return nil
|
|
}
|
|
|
|
cohereReq := &CohereRerankRequest{
|
|
Model: bifrostReq.Model,
|
|
Query: bifrostReq.Query,
|
|
}
|
|
|
|
// Cohere v2 expects documents as a list of strings.
|
|
documents := make([]string, len(bifrostReq.Documents))
|
|
for i, doc := range bifrostReq.Documents {
|
|
documents[i] = formatCohereRerankDocument(doc)
|
|
}
|
|
cohereReq.Documents = documents
|
|
|
|
if bifrostReq.Params != nil {
|
|
cohereReq.TopN = bifrostReq.Params.TopN
|
|
cohereReq.MaxTokensPerDoc = bifrostReq.Params.MaxTokensPerDoc
|
|
cohereReq.Priority = bifrostReq.Params.Priority
|
|
cohereReq.ExtraParams = bifrostReq.Params.ExtraParams
|
|
}
|
|
|
|
return cohereReq
|
|
}
|
|
|
|
// ToBifrostRerankRequest converts a Cohere rerank request to Bifrost format
|
|
func (req *CohereRerankRequest) ToBifrostRerankRequest(ctx *schemas.BifrostContext) *schemas.BifrostRerankRequest {
|
|
if req == nil {
|
|
return nil
|
|
}
|
|
|
|
provider, model := schemas.ParseModelString(req.Model, utils.CheckAndSetDefaultProvider(ctx, schemas.Cohere))
|
|
|
|
bifrostReq := &schemas.BifrostRerankRequest{
|
|
Provider: provider,
|
|
Model: model,
|
|
Query: req.Query,
|
|
Params: &schemas.RerankParameters{},
|
|
}
|
|
|
|
// Convert documents
|
|
for _, doc := range req.Documents {
|
|
bifrostReq.Documents = append(bifrostReq.Documents, schemas.RerankDocument{
|
|
Text: doc,
|
|
})
|
|
}
|
|
|
|
if req.TopN != nil {
|
|
bifrostReq.Params.TopN = req.TopN
|
|
}
|
|
if req.MaxTokensPerDoc != nil {
|
|
bifrostReq.Params.MaxTokensPerDoc = req.MaxTokensPerDoc
|
|
}
|
|
if req.Priority != nil {
|
|
bifrostReq.Params.Priority = req.Priority
|
|
}
|
|
if req.ExtraParams != nil {
|
|
bifrostReq.Params.ExtraParams = req.ExtraParams
|
|
}
|
|
|
|
return bifrostReq
|
|
}
|
|
|
|
// ToBifrostRerankResponse converts a Cohere rerank response to Bifrost format.
|
|
func (response *CohereRerankResponse) ToBifrostRerankResponse(documents []schemas.RerankDocument, returnDocuments bool) *schemas.BifrostRerankResponse {
|
|
if response == nil {
|
|
return nil
|
|
}
|
|
|
|
bifrostResponse := &schemas.BifrostRerankResponse{
|
|
ID: response.ID,
|
|
}
|
|
|
|
// Convert results
|
|
for _, result := range response.Results {
|
|
rerankResult := schemas.RerankResult{
|
|
Index: result.Index,
|
|
RelevanceScore: result.RelevanceScore,
|
|
}
|
|
|
|
// Convert document if present
|
|
if len(result.Document) > 0 {
|
|
var docMap map[string]interface{}
|
|
if err := sonic.Unmarshal(result.Document, &docMap); err == nil {
|
|
doc := &schemas.RerankDocument{}
|
|
populated := false
|
|
if text, ok := docMap["text"].(string); ok {
|
|
doc.Text = text
|
|
populated = true
|
|
}
|
|
if id, ok := docMap["id"].(string); ok {
|
|
doc.ID = &id
|
|
populated = true
|
|
}
|
|
// Collect metadata: unwrap "metadata"/"meta" keys to avoid nesting
|
|
meta := make(map[string]interface{})
|
|
if rawMeta, ok := docMap["metadata"].(map[string]interface{}); ok {
|
|
for k, v := range rawMeta {
|
|
meta[k] = v
|
|
}
|
|
} else if rawMeta, ok := docMap["meta"].(map[string]interface{}); ok {
|
|
for k, v := range rawMeta {
|
|
meta[k] = v
|
|
}
|
|
}
|
|
for k, v := range docMap {
|
|
if k != "text" && k != "id" && k != "metadata" && k != "meta" {
|
|
meta[k] = v
|
|
}
|
|
}
|
|
if len(meta) > 0 {
|
|
doc.Meta = meta
|
|
populated = true
|
|
}
|
|
if populated {
|
|
rerankResult.Document = doc
|
|
}
|
|
}
|
|
}
|
|
|
|
bifrostResponse.Results = append(bifrostResponse.Results, rerankResult)
|
|
}
|
|
sort.SliceStable(bifrostResponse.Results, func(i, j int) bool {
|
|
if bifrostResponse.Results[i].RelevanceScore == bifrostResponse.Results[j].RelevanceScore {
|
|
return bifrostResponse.Results[i].Index < bifrostResponse.Results[j].Index
|
|
}
|
|
return bifrostResponse.Results[i].RelevanceScore > bifrostResponse.Results[j].RelevanceScore
|
|
})
|
|
if returnDocuments {
|
|
for i := range bifrostResponse.Results {
|
|
resultIndex := bifrostResponse.Results[i].Index
|
|
if resultIndex >= 0 && resultIndex < len(documents) {
|
|
bifrostResponse.Results[i].Document = schemas.Ptr(documents[resultIndex])
|
|
}
|
|
}
|
|
}
|
|
|
|
// Convert usage information
|
|
if response.Meta != nil {
|
|
promptTokens := 0
|
|
completionTokens := 0
|
|
hasTokenUsage := false
|
|
if response.Meta.Tokens != nil {
|
|
if response.Meta.Tokens.InputTokens != nil {
|
|
promptTokens = int(*response.Meta.Tokens.InputTokens)
|
|
hasTokenUsage = true
|
|
}
|
|
if response.Meta.Tokens.OutputTokens != nil {
|
|
completionTokens = int(*response.Meta.Tokens.OutputTokens)
|
|
hasTokenUsage = true
|
|
}
|
|
} else if response.Meta.BilledUnits != nil {
|
|
if response.Meta.BilledUnits.InputTokens != nil {
|
|
promptTokens = int(*response.Meta.BilledUnits.InputTokens)
|
|
hasTokenUsage = true
|
|
}
|
|
if response.Meta.BilledUnits.OutputTokens != nil {
|
|
completionTokens = int(*response.Meta.BilledUnits.OutputTokens)
|
|
hasTokenUsage = true
|
|
}
|
|
}
|
|
if hasTokenUsage {
|
|
bifrostResponse.Usage = &schemas.BifrostLLMUsage{
|
|
PromptTokens: promptTokens,
|
|
CompletionTokens: completionTokens,
|
|
TotalTokens: promptTokens + completionTokens,
|
|
}
|
|
}
|
|
}
|
|
|
|
return bifrostResponse
|
|
}
|
|
|
|
func formatCohereRerankDocument(doc schemas.RerankDocument) string {
|
|
if doc.ID == nil && len(doc.Meta) == 0 {
|
|
return doc.Text
|
|
}
|
|
|
|
// Keep metadata/id available by encoding a structured string document.
|
|
documentPayload := map[string]interface{}{
|
|
"text": doc.Text,
|
|
}
|
|
if doc.ID != nil {
|
|
documentPayload["id"] = *doc.ID
|
|
}
|
|
if len(doc.Meta) > 0 {
|
|
documentPayload["metadata"] = doc.Meta
|
|
}
|
|
|
|
encoded, err := yaml.Marshal(documentPayload)
|
|
if err != nil {
|
|
return doc.Text
|
|
}
|
|
return string(encoded)
|
|
}
|