Files
bifrost/core/providers/vllm/rerank.go
Beyhan Oğur 880f412e2c first commit
2026-04-26 21:52:23 +03:00

150 lines
4.3 KiB
Go

package vllm
import (
"fmt"
"sort"
schemas "github.com/maximhq/bifrost/core/schemas"
)
// ToVLLMRerankRequest maps a Bifrost rerank request onto the vLLM rerank
// request payload.
//
// Only the Text of each Bifrost document is forwarded, in the original order.
// When Params is set, TopN, MaxTokensPerDoc, Priority and ExtraParams are
// copied by plain assignment (no deep copy). A nil request yields nil.
func ToVLLMRerankRequest(bifrostReq *schemas.BifrostRerankRequest) *vLLMRerankRequest {
	if bifrostReq == nil {
		return nil
	}

	texts := make([]string, 0, len(bifrostReq.Documents))
	for _, document := range bifrostReq.Documents {
		texts = append(texts, document.Text)
	}

	out := &vLLMRerankRequest{
		Model:     bifrostReq.Model,
		Query:     bifrostReq.Query,
		Documents: texts,
	}

	if params := bifrostReq.Params; params != nil {
		out.TopN = params.TopN
		out.MaxTokensPerDoc = params.MaxTokensPerDoc
		out.Priority = params.Priority
		out.ExtraParams = params.ExtraParams
	}
	return out
}
// ToBifrostRerankResponse converts a decoded vLLM rerank response payload into
// a Bifrost rerank response.
//
// The payload must contain a "results" array of objects. Each object needs an
// integer "index" into documents and a numeric "relevance_score". If that is
// absent or not numeric, "score" is used instead. An index may appear at most
// once. Any violation returns an error and a nil response.
//
// Optional "id", "model" and "usage" fields are copied when they parse. When
// returnDocuments is true, each result carries a pointer to a copy of the
// matching entry in documents.
//
// Results are ordered by descending relevance score, with ties broken by
// ascending index.
func ToBifrostRerankResponse(payload map[string]interface{}, documents []schemas.RerankDocument, returnDocuments bool) (*schemas.BifrostRerankResponse, error) {
	if payload == nil {
		return nil, fmt.Errorf("vllm rerank response is nil")
	}

	var items []interface{}
	switch raw := payload["results"].(type) {
	case nil:
		return nil, fmt.Errorf("invalid vllm rerank response: missing results")
	case []interface{}:
		items = raw
	default:
		return nil, fmt.Errorf("invalid vllm rerank response: results must be an array")
	}

	seen := make(map[int]bool, len(items))
	results := make([]schemas.RerankResult, 0, len(items))
	for _, raw := range items {
		entry, isObject := raw.(map[string]interface{})
		if !isObject {
			return nil, fmt.Errorf("invalid vllm rerank response: result item must be an object")
		}

		idx, hasIndex := schemas.SafeExtractInt(entry["index"])
		switch {
		case !hasIndex:
			return nil, fmt.Errorf("invalid vllm rerank response: result index is required")
		case idx < 0 || idx >= len(documents):
			return nil, fmt.Errorf("invalid vllm rerank response: result index %d out of range", idx)
		case seen[idx]:
			return nil, fmt.Errorf("invalid vllm rerank response: duplicate index %d", idx)
		}
		seen[idx] = true

		// Prefer "relevance_score"; fall back to "score" when it is absent or non-numeric.
		score, hasScore := schemas.SafeExtractFloat64(entry["relevance_score"])
		if !hasScore {
			score, hasScore = schemas.SafeExtractFloat64(entry["score"])
		}
		if !hasScore {
			return nil, fmt.Errorf("invalid vllm rerank response: relevance_score/score is required")
		}

		res := schemas.RerankResult{Index: idx, RelevanceScore: score}
		if returnDocuments {
			docCopy := documents[idx]
			res.Document = &docCopy
		}
		results = append(results, res)
	}

	sort.SliceStable(results, func(a, b int) bool {
		left, right := results[a], results[b]
		if left.RelevanceScore != right.RelevanceScore {
			return left.RelevanceScore > right.RelevanceScore
		}
		return left.Index < right.Index
	})

	out := &schemas.BifrostRerankResponse{Results: results}
	if id, ok := schemas.SafeExtractString(payload["id"]); ok {
		out.ID = id
	}
	if model, ok := schemas.SafeExtractString(payload["model"]); ok {
		out.Model = model
	}
	if usage, ok := parseVLLMUsage(payload["usage"]); ok {
		out.Usage = usage
	}
	return out, nil
}
// parseVLLMUsage extracts token usage from the "usage" object of a vLLM
// response.
//
// Prompt tokens come from "prompt_tokens" or, if that does not parse as an
// integer, from "input_tokens". Completion tokens come from
// "completion_tokens", falling back to "output_tokens". Total tokens come from
// "total_tokens"; if that does not parse, the total is the sum of prompt and
// completion tokens.
//
// It returns (nil, false) when rawUsage is not a JSON object or when every
// count resolves to zero.
func parseVLLMUsage(rawUsage interface{}) (*schemas.BifrostLLMUsage, bool) {
	usageMap, ok := rawUsage.(map[string]interface{})
	if !ok {
		return nil, false
	}

	// firstInt returns the value of the first key whose value parses as an
	// integer. Choosing by successful parse (not mere key presence) means an
	// explicit null or malformed primary key still falls back to its alias.
	firstInt := func(keys ...string) (int, bool) {
		for _, key := range keys {
			if v, found := schemas.SafeExtractInt(usageMap[key]); found {
				return v, true
			}
		}
		return 0, false
	}

	// A missing prompt/completion count is deliberately treated as zero.
	promptTokens, _ := firstInt("prompt_tokens", "input_tokens")
	completionTokens, _ := firstInt("completion_tokens", "output_tokens")

	totalTokens, hasTotal := firstInt("total_tokens")
	if !hasTotal {
		totalTokens = promptTokens + completionTokens
	}

	if promptTokens == 0 && completionTokens == 0 && totalTokens == 0 {
		return nil, false
	}
	return &schemas.BifrostLLMUsage{
		PromptTokens:     promptTokens,
		CompletionTokens: completionTokens,
		TotalTokens:      totalTokens,
	}, true
}