150 lines
4.3 KiB
Go
150 lines
4.3 KiB
Go
package vllm
|
|
|
|
import (
|
|
"fmt"
|
|
"sort"
|
|
|
|
schemas "github.com/maximhq/bifrost/core/schemas"
|
|
)
|
|
|
|
// ToVLLMRerankRequest converts a Bifrost rerank request to vLLM format.
|
|
func ToVLLMRerankRequest(bifrostReq *schemas.BifrostRerankRequest) *vLLMRerankRequest {
|
|
if bifrostReq == nil {
|
|
return nil
|
|
}
|
|
|
|
vllmReq := &vLLMRerankRequest{
|
|
Model: bifrostReq.Model,
|
|
Query: bifrostReq.Query,
|
|
Documents: make([]string, len(bifrostReq.Documents)),
|
|
}
|
|
|
|
for i, doc := range bifrostReq.Documents {
|
|
vllmReq.Documents[i] = doc.Text
|
|
}
|
|
|
|
if bifrostReq.Params != nil {
|
|
vllmReq.TopN = bifrostReq.Params.TopN
|
|
vllmReq.MaxTokensPerDoc = bifrostReq.Params.MaxTokensPerDoc
|
|
vllmReq.Priority = bifrostReq.Params.Priority
|
|
vllmReq.ExtraParams = bifrostReq.Params.ExtraParams
|
|
}
|
|
|
|
return vllmReq
|
|
}
|
|
|
|
// ToBifrostRerankResponse converts a vLLM rerank response payload to Bifrost format.
|
|
func ToBifrostRerankResponse(payload map[string]interface{}, documents []schemas.RerankDocument, returnDocuments bool) (*schemas.BifrostRerankResponse, error) {
|
|
if payload == nil {
|
|
return nil, fmt.Errorf("vllm rerank response is nil")
|
|
}
|
|
|
|
response := &schemas.BifrostRerankResponse{}
|
|
|
|
if id, ok := schemas.SafeExtractString(payload["id"]); ok {
|
|
response.ID = id
|
|
}
|
|
if model, ok := schemas.SafeExtractString(payload["model"]); ok {
|
|
response.Model = model
|
|
}
|
|
if usage, ok := parseVLLMUsage(payload["usage"]); ok {
|
|
response.Usage = usage
|
|
}
|
|
|
|
resultsRaw := payload["results"]
|
|
if resultsRaw == nil {
|
|
return nil, fmt.Errorf("invalid vllm rerank response: missing results")
|
|
}
|
|
|
|
resultItems, ok := resultsRaw.([]interface{})
|
|
if !ok {
|
|
return nil, fmt.Errorf("invalid vllm rerank response: results must be an array")
|
|
}
|
|
|
|
seenIndices := make(map[int]struct{}, len(resultItems))
|
|
response.Results = make([]schemas.RerankResult, 0, len(resultItems))
|
|
|
|
for _, item := range resultItems {
|
|
itemMap, ok := item.(map[string]interface{})
|
|
if !ok {
|
|
return nil, fmt.Errorf("invalid vllm rerank response: result item must be an object")
|
|
}
|
|
|
|
index, ok := schemas.SafeExtractInt(itemMap["index"])
|
|
if !ok {
|
|
return nil, fmt.Errorf("invalid vllm rerank response: result index is required")
|
|
}
|
|
if index < 0 || index >= len(documents) {
|
|
return nil, fmt.Errorf("invalid vllm rerank response: result index %d out of range", index)
|
|
}
|
|
if _, exists := seenIndices[index]; exists {
|
|
return nil, fmt.Errorf("invalid vllm rerank response: duplicate index %d", index)
|
|
}
|
|
seenIndices[index] = struct{}{}
|
|
|
|
relevanceScore, ok := schemas.SafeExtractFloat64(itemMap["relevance_score"])
|
|
if !ok {
|
|
relevanceScore, ok = schemas.SafeExtractFloat64(itemMap["score"])
|
|
}
|
|
if !ok {
|
|
return nil, fmt.Errorf("invalid vllm rerank response: relevance_score/score is required")
|
|
}
|
|
|
|
result := schemas.RerankResult{
|
|
Index: index,
|
|
RelevanceScore: relevanceScore,
|
|
}
|
|
|
|
if returnDocuments {
|
|
doc := documents[index]
|
|
result.Document = &doc
|
|
}
|
|
|
|
response.Results = append(response.Results, result)
|
|
}
|
|
|
|
sort.SliceStable(response.Results, func(i, j int) bool {
|
|
if response.Results[i].RelevanceScore == response.Results[j].RelevanceScore {
|
|
return response.Results[i].Index < response.Results[j].Index
|
|
}
|
|
return response.Results[i].RelevanceScore > response.Results[j].RelevanceScore
|
|
})
|
|
|
|
return response, nil
|
|
}
|
|
|
|
func parseVLLMUsage(rawUsage interface{}) (*schemas.BifrostLLMUsage, bool) {
|
|
usageMap, ok := rawUsage.(map[string]interface{})
|
|
if !ok {
|
|
return nil, false
|
|
}
|
|
|
|
promptTokens := 0
|
|
if _, hasPromptTokens := usageMap["prompt_tokens"]; hasPromptTokens {
|
|
promptTokens, _ = schemas.SafeExtractInt(usageMap["prompt_tokens"])
|
|
} else {
|
|
promptTokens, _ = schemas.SafeExtractInt(usageMap["input_tokens"])
|
|
}
|
|
|
|
completionTokens := 0
|
|
if _, hasCompletionTokens := usageMap["completion_tokens"]; hasCompletionTokens {
|
|
completionTokens, _ = schemas.SafeExtractInt(usageMap["completion_tokens"])
|
|
} else {
|
|
completionTokens, _ = schemas.SafeExtractInt(usageMap["output_tokens"])
|
|
}
|
|
|
|
totalTokens, ok := schemas.SafeExtractInt(usageMap["total_tokens"])
|
|
if !ok {
|
|
totalTokens = promptTokens + completionTokens
|
|
}
|
|
if promptTokens == 0 && completionTokens == 0 && totalTokens == 0 {
|
|
return nil, false
|
|
}
|
|
|
|
return &schemas.BifrostLLMUsage{
|
|
PromptTokens: promptTokens,
|
|
CompletionTokens: completionTokens,
|
|
TotalTokens: totalTokens,
|
|
}, true
|
|
}
|