first commit
This commit is contained in:
290
core/providers/vertex/rerank.go
Normal file
290
core/providers/vertex/rerank.go
Normal file
@@ -0,0 +1,290 @@
|
||||
package vertex
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/bytedance/sonic"
|
||||
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
func buildVertexRankingConfig(projectID, rankingConfigOverride string) (string, error) {
|
||||
projectID = strings.TrimSpace(projectID)
|
||||
if projectID == "" {
|
||||
return "", fmt.Errorf("project ID is required for ranking config")
|
||||
}
|
||||
|
||||
override := strings.TrimSpace(rankingConfigOverride)
|
||||
if override == "" {
|
||||
return fmt.Sprintf("projects/%s/locations/global/rankingConfigs/%s", projectID, vertexDefaultRankingConfigID), nil
|
||||
}
|
||||
|
||||
override = strings.TrimSuffix(override, ":rank")
|
||||
if strings.HasPrefix(override, "projects/") {
|
||||
return override, nil
|
||||
}
|
||||
if strings.Contains(override, "/") {
|
||||
return "", fmt.Errorf("invalid ranking_config %q: must be resource name or config ID", rankingConfigOverride)
|
||||
}
|
||||
return fmt.Sprintf("projects/%s/locations/global/rankingConfigs/%s", projectID, override), nil
|
||||
}
|
||||
|
||||
func getVertexRerankOptions(projectID string, params *schemas.RerankParameters) (*vertexRerankOptions, error) {
|
||||
options := &vertexRerankOptions{
|
||||
IgnoreRecordDetailsInResponse: true,
|
||||
}
|
||||
|
||||
if params == nil || params.ExtraParams == nil {
|
||||
rankingConfig, err := buildVertexRankingConfig(projectID, "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
options.RankingConfig = rankingConfig
|
||||
return options, nil
|
||||
}
|
||||
|
||||
extraParams := params.ExtraParams
|
||||
|
||||
rankingConfigOverride := ""
|
||||
if rawRankingConfig, exists := extraParams["ranking_config"]; exists {
|
||||
rankingConfig, ok := schemas.SafeExtractString(rawRankingConfig)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid ranking_config: expected string")
|
||||
}
|
||||
rankingConfigOverride = rankingConfig
|
||||
}
|
||||
|
||||
rankingConfig, err := buildVertexRankingConfig(projectID, rankingConfigOverride)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
options.RankingConfig = rankingConfig
|
||||
|
||||
if rawIgnoreRecordDetails, exists := extraParams["ignore_record_details_in_response"]; exists {
|
||||
ignoreRecordDetailsInResponse, ok := schemas.SafeExtractBool(rawIgnoreRecordDetails)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid ignore_record_details_in_response: expected bool")
|
||||
}
|
||||
options.IgnoreRecordDetailsInResponse = ignoreRecordDetailsInResponse
|
||||
}
|
||||
|
||||
if rawUserLabels, exists := extraParams["user_labels"]; exists {
|
||||
userLabels, ok := schemas.SafeExtractStringMap(rawUserLabels)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid user_labels: expected map[string]string")
|
||||
}
|
||||
options.UserLabels = userLabels
|
||||
}
|
||||
|
||||
return options, nil
|
||||
}
|
||||
|
||||
// ToVertexRankRequest converts a Bifrost rerank request to Discovery Engine rank API format.
|
||||
func ToVertexRankRequest(bifrostReq *schemas.BifrostRerankRequest, options *vertexRerankOptions) (*VertexRankRequest, error) {
|
||||
if bifrostReq == nil {
|
||||
return nil, fmt.Errorf("bifrost rerank request is nil")
|
||||
}
|
||||
if options == nil {
|
||||
return nil, fmt.Errorf("vertex rerank options are nil")
|
||||
}
|
||||
if len(bifrostReq.Documents) == 0 {
|
||||
return nil, fmt.Errorf("documents are required for rerank request")
|
||||
}
|
||||
if len(bifrostReq.Documents) > vertexMaxRerankRecordsPerQuery {
|
||||
return nil, fmt.Errorf("vertex rerank supports up to %d records per request", vertexMaxRerankRecordsPerQuery)
|
||||
}
|
||||
|
||||
rankRequest := &VertexRankRequest{
|
||||
Query: bifrostReq.Query,
|
||||
Records: make([]VertexRankRecord, len(bifrostReq.Documents)),
|
||||
}
|
||||
|
||||
for i, doc := range bifrostReq.Documents {
|
||||
recordID := fmt.Sprintf("%s%d", vertexSyntheticRecordPrefix, i)
|
||||
content := doc.Text
|
||||
record := VertexRankRecord{
|
||||
ID: recordID,
|
||||
Content: &content,
|
||||
}
|
||||
|
||||
if doc.Meta != nil {
|
||||
if rawTitle, exists := doc.Meta["title"]; exists {
|
||||
if title, ok := schemas.SafeExtractString(rawTitle); ok && strings.TrimSpace(title) != "" {
|
||||
record.Title = &title
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
rankRequest.Records[i] = record
|
||||
}
|
||||
|
||||
if bifrostReq.Params != nil && bifrostReq.Params.TopN != nil {
|
||||
topN := *bifrostReq.Params.TopN
|
||||
if topN < 1 {
|
||||
return nil, fmt.Errorf("top_n must be at least 1")
|
||||
}
|
||||
if topN > len(bifrostReq.Documents) {
|
||||
topN = len(bifrostReq.Documents)
|
||||
}
|
||||
rankRequest.TopN = &topN
|
||||
}
|
||||
|
||||
trimmedModel := strings.TrimSpace(bifrostReq.Model)
|
||||
if trimmedModel == "" {
|
||||
trimmedModel = vertexDefaultRerankModel
|
||||
}
|
||||
rankRequest.Model = &trimmedModel
|
||||
|
||||
ignoreRecordDetailsInResponse := options.IgnoreRecordDetailsInResponse
|
||||
rankRequest.IgnoreRecordDetailsInResponse = &ignoreRecordDetailsInResponse
|
||||
|
||||
if len(options.UserLabels) > 0 {
|
||||
rankRequest.UserLabels = options.UserLabels
|
||||
}
|
||||
|
||||
return rankRequest, nil
|
||||
}
|
||||
|
||||
// ToBifrostRerankRequest converts a Discovery Engine rank request to Bifrost format.
|
||||
func (req *VertexRankRequest) ToBifrostRerankRequest(ctx *schemas.BifrostContext) *schemas.BifrostRerankRequest {
|
||||
if req == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var provider schemas.ModelProvider
|
||||
var model string
|
||||
if req.Model != nil {
|
||||
provider, model = schemas.ParseModelString(*req.Model, providerUtils.CheckAndSetDefaultProvider(ctx, schemas.Vertex))
|
||||
} else {
|
||||
provider = providerUtils.CheckAndSetDefaultProvider(ctx, schemas.Vertex)
|
||||
}
|
||||
|
||||
bifrostReq := &schemas.BifrostRerankRequest{
|
||||
Provider: provider,
|
||||
Model: model,
|
||||
Query: req.Query,
|
||||
Params: &schemas.RerankParameters{},
|
||||
}
|
||||
|
||||
// Convert records to documents
|
||||
for _, record := range req.Records {
|
||||
doc := schemas.RerankDocument{
|
||||
ID: &record.ID,
|
||||
}
|
||||
if record.Content != nil {
|
||||
doc.Text = *record.Content
|
||||
}
|
||||
if record.Title != nil {
|
||||
doc.Meta = map[string]interface{}{
|
||||
"title": *record.Title,
|
||||
}
|
||||
}
|
||||
bifrostReq.Documents = append(bifrostReq.Documents, doc)
|
||||
}
|
||||
|
||||
// Extract TopN
|
||||
if req.TopN != nil {
|
||||
bifrostReq.Params.TopN = req.TopN
|
||||
}
|
||||
|
||||
// Pass extra fields as ExtraParams
|
||||
extraParams := make(map[string]interface{})
|
||||
if req.IgnoreRecordDetailsInResponse != nil {
|
||||
extraParams["ignore_record_details_in_response"] = *req.IgnoreRecordDetailsInResponse
|
||||
}
|
||||
if len(req.UserLabels) > 0 {
|
||||
extraParams["user_labels"] = req.UserLabels
|
||||
}
|
||||
if len(extraParams) > 0 {
|
||||
bifrostReq.Params.ExtraParams = extraParams
|
||||
}
|
||||
|
||||
return bifrostReq
|
||||
}
|
||||
|
||||
func parseVertexSyntheticRecordIndex(recordID string, maxDocs int) (int, error) {
|
||||
if !strings.HasPrefix(recordID, vertexSyntheticRecordPrefix) {
|
||||
return 0, fmt.Errorf("invalid record id %q: expected prefix %q", recordID, vertexSyntheticRecordPrefix)
|
||||
}
|
||||
indexStr := strings.TrimPrefix(recordID, vertexSyntheticRecordPrefix)
|
||||
index, err := strconv.Atoi(indexStr)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("invalid record id %q: %w", recordID, err)
|
||||
}
|
||||
if index < 0 || index >= maxDocs {
|
||||
return 0, fmt.Errorf("record id %q maps to out-of-range index %d", recordID, index)
|
||||
}
|
||||
return index, nil
|
||||
}
|
||||
|
||||
// ToBifrostRerankResponse converts a Discovery Engine rank response to Bifrost format.
|
||||
func (response *VertexRankResponse) ToBifrostRerankResponse(documents []schemas.RerankDocument, returnDocuments bool) (*schemas.BifrostRerankResponse, error) {
|
||||
if response == nil {
|
||||
return nil, fmt.Errorf("vertex rerank response is nil")
|
||||
}
|
||||
|
||||
results := make([]schemas.RerankResult, 0, len(response.Records))
|
||||
seenIndices := make(map[int]struct{}, len(response.Records))
|
||||
|
||||
for _, record := range response.Records {
|
||||
index, err := parseVertexSyntheticRecordIndex(record.ID, len(documents))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if _, seen := seenIndices[index]; seen {
|
||||
return nil, fmt.Errorf("duplicate record id mapping for index %d", index)
|
||||
}
|
||||
seenIndices[index] = struct{}{}
|
||||
|
||||
result := schemas.RerankResult{
|
||||
Index: index,
|
||||
RelevanceScore: record.Score,
|
||||
}
|
||||
|
||||
if returnDocuments {
|
||||
doc := documents[index]
|
||||
result.Document = &doc
|
||||
}
|
||||
|
||||
results = append(results, result)
|
||||
}
|
||||
|
||||
sort.SliceStable(results, func(i, j int) bool {
|
||||
if results[i].RelevanceScore == results[j].RelevanceScore {
|
||||
return results[i].Index < results[j].Index
|
||||
}
|
||||
return results[i].RelevanceScore > results[j].RelevanceScore
|
||||
})
|
||||
|
||||
return &schemas.BifrostRerankResponse{
|
||||
Results: results,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func parseDiscoveryEngineErrorMessage(responseBody []byte) string {
|
||||
if len(responseBody) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
var errorResponse map[string]interface{}
|
||||
if err := sonic.Unmarshal(responseBody, &errorResponse); err == nil {
|
||||
if rawError, exists := errorResponse["error"]; exists {
|
||||
if errorMap, ok := rawError.(map[string]interface{}); ok {
|
||||
if message, ok := schemas.SafeExtractString(errorMap["message"]); ok && strings.TrimSpace(message) != "" {
|
||||
return message
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
rawString := strings.TrimSpace(string(responseBody))
|
||||
if rawString == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
return rawString
|
||||
}
|
||||
Reference in New Issue
Block a user