first commit
This commit is contained in:
209
core/providers/cohere/rerank.go
Normal file
209
core/providers/cohere/rerank.go
Normal file
@@ -0,0 +1,209 @@
|
||||
package cohere
|
||||
|
||||
import (
|
||||
"sort"
|
||||
|
||||
"github.com/bytedance/sonic"
|
||||
"github.com/maximhq/bifrost/core/providers/utils"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// ToCohereRerankRequest converts a Bifrost rerank request to Cohere format
|
||||
func ToCohereRerankRequest(bifrostReq *schemas.BifrostRerankRequest) *CohereRerankRequest {
|
||||
if bifrostReq == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
cohereReq := &CohereRerankRequest{
|
||||
Model: bifrostReq.Model,
|
||||
Query: bifrostReq.Query,
|
||||
}
|
||||
|
||||
// Cohere v2 expects documents as a list of strings.
|
||||
documents := make([]string, len(bifrostReq.Documents))
|
||||
for i, doc := range bifrostReq.Documents {
|
||||
documents[i] = formatCohereRerankDocument(doc)
|
||||
}
|
||||
cohereReq.Documents = documents
|
||||
|
||||
if bifrostReq.Params != nil {
|
||||
cohereReq.TopN = bifrostReq.Params.TopN
|
||||
cohereReq.MaxTokensPerDoc = bifrostReq.Params.MaxTokensPerDoc
|
||||
cohereReq.Priority = bifrostReq.Params.Priority
|
||||
cohereReq.ExtraParams = bifrostReq.Params.ExtraParams
|
||||
}
|
||||
|
||||
return cohereReq
|
||||
}
|
||||
|
||||
// ToBifrostRerankRequest converts a Cohere rerank request to Bifrost format
|
||||
func (req *CohereRerankRequest) ToBifrostRerankRequest(ctx *schemas.BifrostContext) *schemas.BifrostRerankRequest {
|
||||
if req == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
provider, model := schemas.ParseModelString(req.Model, utils.CheckAndSetDefaultProvider(ctx, schemas.Cohere))
|
||||
|
||||
bifrostReq := &schemas.BifrostRerankRequest{
|
||||
Provider: provider,
|
||||
Model: model,
|
||||
Query: req.Query,
|
||||
Params: &schemas.RerankParameters{},
|
||||
}
|
||||
|
||||
// Convert documents
|
||||
for _, doc := range req.Documents {
|
||||
bifrostReq.Documents = append(bifrostReq.Documents, schemas.RerankDocument{
|
||||
Text: doc,
|
||||
})
|
||||
}
|
||||
|
||||
if req.TopN != nil {
|
||||
bifrostReq.Params.TopN = req.TopN
|
||||
}
|
||||
if req.MaxTokensPerDoc != nil {
|
||||
bifrostReq.Params.MaxTokensPerDoc = req.MaxTokensPerDoc
|
||||
}
|
||||
if req.Priority != nil {
|
||||
bifrostReq.Params.Priority = req.Priority
|
||||
}
|
||||
if req.ExtraParams != nil {
|
||||
bifrostReq.Params.ExtraParams = req.ExtraParams
|
||||
}
|
||||
|
||||
return bifrostReq
|
||||
}
|
||||
|
||||
// ToBifrostRerankResponse converts a Cohere rerank response to Bifrost format.
|
||||
func (response *CohereRerankResponse) ToBifrostRerankResponse(documents []schemas.RerankDocument, returnDocuments bool) *schemas.BifrostRerankResponse {
|
||||
if response == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
bifrostResponse := &schemas.BifrostRerankResponse{
|
||||
ID: response.ID,
|
||||
}
|
||||
|
||||
// Convert results
|
||||
for _, result := range response.Results {
|
||||
rerankResult := schemas.RerankResult{
|
||||
Index: result.Index,
|
||||
RelevanceScore: result.RelevanceScore,
|
||||
}
|
||||
|
||||
// Convert document if present
|
||||
if len(result.Document) > 0 {
|
||||
var docMap map[string]interface{}
|
||||
if err := sonic.Unmarshal(result.Document, &docMap); err == nil {
|
||||
doc := &schemas.RerankDocument{}
|
||||
populated := false
|
||||
if text, ok := docMap["text"].(string); ok {
|
||||
doc.Text = text
|
||||
populated = true
|
||||
}
|
||||
if id, ok := docMap["id"].(string); ok {
|
||||
doc.ID = &id
|
||||
populated = true
|
||||
}
|
||||
// Collect metadata: unwrap "metadata"/"meta" keys to avoid nesting
|
||||
meta := make(map[string]interface{})
|
||||
if rawMeta, ok := docMap["metadata"].(map[string]interface{}); ok {
|
||||
for k, v := range rawMeta {
|
||||
meta[k] = v
|
||||
}
|
||||
} else if rawMeta, ok := docMap["meta"].(map[string]interface{}); ok {
|
||||
for k, v := range rawMeta {
|
||||
meta[k] = v
|
||||
}
|
||||
}
|
||||
for k, v := range docMap {
|
||||
if k != "text" && k != "id" && k != "metadata" && k != "meta" {
|
||||
meta[k] = v
|
||||
}
|
||||
}
|
||||
if len(meta) > 0 {
|
||||
doc.Meta = meta
|
||||
populated = true
|
||||
}
|
||||
if populated {
|
||||
rerankResult.Document = doc
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bifrostResponse.Results = append(bifrostResponse.Results, rerankResult)
|
||||
}
|
||||
sort.SliceStable(bifrostResponse.Results, func(i, j int) bool {
|
||||
if bifrostResponse.Results[i].RelevanceScore == bifrostResponse.Results[j].RelevanceScore {
|
||||
return bifrostResponse.Results[i].Index < bifrostResponse.Results[j].Index
|
||||
}
|
||||
return bifrostResponse.Results[i].RelevanceScore > bifrostResponse.Results[j].RelevanceScore
|
||||
})
|
||||
if returnDocuments {
|
||||
for i := range bifrostResponse.Results {
|
||||
resultIndex := bifrostResponse.Results[i].Index
|
||||
if resultIndex >= 0 && resultIndex < len(documents) {
|
||||
bifrostResponse.Results[i].Document = schemas.Ptr(documents[resultIndex])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Convert usage information
|
||||
if response.Meta != nil {
|
||||
promptTokens := 0
|
||||
completionTokens := 0
|
||||
hasTokenUsage := false
|
||||
if response.Meta.Tokens != nil {
|
||||
if response.Meta.Tokens.InputTokens != nil {
|
||||
promptTokens = int(*response.Meta.Tokens.InputTokens)
|
||||
hasTokenUsage = true
|
||||
}
|
||||
if response.Meta.Tokens.OutputTokens != nil {
|
||||
completionTokens = int(*response.Meta.Tokens.OutputTokens)
|
||||
hasTokenUsage = true
|
||||
}
|
||||
} else if response.Meta.BilledUnits != nil {
|
||||
if response.Meta.BilledUnits.InputTokens != nil {
|
||||
promptTokens = int(*response.Meta.BilledUnits.InputTokens)
|
||||
hasTokenUsage = true
|
||||
}
|
||||
if response.Meta.BilledUnits.OutputTokens != nil {
|
||||
completionTokens = int(*response.Meta.BilledUnits.OutputTokens)
|
||||
hasTokenUsage = true
|
||||
}
|
||||
}
|
||||
if hasTokenUsage {
|
||||
bifrostResponse.Usage = &schemas.BifrostLLMUsage{
|
||||
PromptTokens: promptTokens,
|
||||
CompletionTokens: completionTokens,
|
||||
TotalTokens: promptTokens + completionTokens,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return bifrostResponse
|
||||
}
|
||||
|
||||
func formatCohereRerankDocument(doc schemas.RerankDocument) string {
|
||||
if doc.ID == nil && len(doc.Meta) == 0 {
|
||||
return doc.Text
|
||||
}
|
||||
|
||||
// Keep metadata/id available by encoding a structured string document.
|
||||
documentPayload := map[string]interface{}{
|
||||
"text": doc.Text,
|
||||
}
|
||||
if doc.ID != nil {
|
||||
documentPayload["id"] = *doc.ID
|
||||
}
|
||||
if len(doc.Meta) > 0 {
|
||||
documentPayload["metadata"] = doc.Meta
|
||||
}
|
||||
|
||||
encoded, err := yaml.Marshal(documentPayload)
|
||||
if err != nil {
|
||||
return doc.Text
|
||||
}
|
||||
return string(encoded)
|
||||
}
|
||||
Reference in New Issue
Block a user