Files
bifrost/core/providers/vertex/rerank_test.go
Beyhan Oğur 880f412e2c first commit
2026-04-26 21:52:23 +03:00

206 lines
6.4 KiB
Go

package vertex
import (
"context"
"testing"
"github.com/maximhq/bifrost/core/schemas"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestBuildVertexRankingConfig(t *testing.T) {
t.Parallel()
config, err := buildVertexRankingConfig("demo-project", "")
require.NoError(t, err)
assert.Equal(t, "projects/demo-project/locations/global/rankingConfigs/default_ranking_config", config)
config, err = buildVertexRankingConfig("demo-project", "custom_rank")
require.NoError(t, err)
assert.Equal(t, "projects/demo-project/locations/global/rankingConfigs/custom_rank", config)
config, err = buildVertexRankingConfig("demo-project", "projects/other/locations/global/rankingConfigs/custom_rank:rank")
require.NoError(t, err)
assert.Equal(t, "projects/other/locations/global/rankingConfigs/custom_rank", config)
_, err = buildVertexRankingConfig("demo-project", "locations/global/rankingConfigs/custom_rank")
require.Error(t, err)
}
func TestToVertexRankRequest(t *testing.T) {
t.Parallel()
req, err := ToVertexRankRequest(
&schemas.BifrostRerankRequest{
Query: "capital of france",
Documents: []schemas.RerankDocument{
{Text: "Paris is the capital of France.", Meta: map[string]interface{}{"title": "Doc A"}},
{Text: "Berlin is the capital of Germany."},
},
Params: &schemas.RerankParameters{
TopN: schemas.Ptr(10),
},
},
&vertexRerankOptions{
RankingConfig: "projects/p/locations/global/rankingConfigs/default_ranking_config",
IgnoreRecordDetailsInResponse: true,
},
)
require.NoError(t, err)
require.NotNil(t, req)
require.NotNil(t, req.Model)
assert.Equal(t, "semantic-ranker-default@latest", *req.Model)
require.Len(t, req.Records, 2)
assert.Equal(t, "idx:0", req.Records[0].ID)
assert.Equal(t, "idx:1", req.Records[1].ID)
require.NotNil(t, req.Records[0].Title)
assert.Equal(t, "Doc A", *req.Records[0].Title)
require.NotNil(t, req.TopN)
assert.Equal(t, 2, *req.TopN, "topN should be clamped to document count")
require.NotNil(t, req.IgnoreRecordDetailsInResponse)
assert.True(t, *req.IgnoreRecordDetailsInResponse)
}
func TestToVertexRankRequestTooManyRecords(t *testing.T) {
t.Parallel()
docs := make([]schemas.RerankDocument, 201)
for i := range docs {
docs[i] = schemas.RerankDocument{Text: "doc"}
}
_, err := ToVertexRankRequest(
&schemas.BifrostRerankRequest{
Query: "q",
Documents: docs,
},
&vertexRerankOptions{
RankingConfig: "projects/p/locations/global/rankingConfigs/default_ranking_config",
IgnoreRecordDetailsInResponse: true,
},
)
require.Error(t, err)
assert.Contains(t, err.Error(), "supports up to")
}
func TestGetVertexRerankOptions(t *testing.T) {
t.Parallel()
options, err := getVertexRerankOptions("project-x", &schemas.RerankParameters{
ExtraParams: map[string]interface{}{
"ranking_config": "custom_rank",
"ignore_record_details_in_response": false,
"user_labels": map[string]interface{}{
"env": "test",
},
},
})
require.NoError(t, err)
assert.Equal(t, "projects/project-x/locations/global/rankingConfigs/custom_rank", options.RankingConfig)
assert.False(t, options.IgnoreRecordDetailsInResponse)
assert.Equal(t, map[string]string{"env": "test"}, options.UserLabels)
}
func TestVertexRankResponseToBifrostRerankResponse(t *testing.T) {
t.Parallel()
docs := []schemas.RerankDocument{
{Text: "doc-0"},
{Text: "doc-1"},
{Text: "doc-2"},
}
response, err := (&VertexRankResponse{
Records: []VertexRankedRecord{
{ID: "idx:2", Score: 0.12},
{ID: "idx:1", Score: 0.91},
{ID: "idx:0", Score: 0.91},
},
}).ToBifrostRerankResponse(docs, true)
require.NoError(t, err)
require.NotNil(t, response)
require.Len(t, response.Results, 3)
assert.Equal(t, 0, response.Results[0].Index)
assert.Equal(t, 1, response.Results[1].Index)
assert.Equal(t, 2, response.Results[2].Index)
require.NotNil(t, response.Results[0].Document)
assert.Equal(t, "doc-0", response.Results[0].Document.Text)
}
func TestVertexRankRequestToBifrostRerankRequest(t *testing.T) {
t.Parallel()
topN := 5
model := "semantic-ranker-default@latest"
ignoreDetails := true
title := "Doc A"
content1 := "Paris is the capital of France."
content2 := "Berlin is the capital of Germany."
req := &VertexRankRequest{
Model: &model,
Query: "capital of france",
Records: []VertexRankRecord{
{ID: "rec-1", Content: &content1, Title: &title},
{ID: "rec-2", Content: &content2},
},
TopN: &topN,
IgnoreRecordDetailsInResponse: &ignoreDetails,
UserLabels: map[string]string{"env": "test"},
}
bifrostCtx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
result := req.ToBifrostRerankRequest(bifrostCtx)
require.NotNil(t, result)
assert.Equal(t, schemas.Vertex, result.Provider)
assert.Equal(t, "semantic-ranker-default@latest", result.Model)
assert.Equal(t, "capital of france", result.Query)
require.Len(t, result.Documents, 2)
// First document has ID, content, and title in meta
require.NotNil(t, result.Documents[0].ID)
assert.Equal(t, "rec-1", *result.Documents[0].ID)
assert.Equal(t, "Paris is the capital of France.", result.Documents[0].Text)
require.NotNil(t, result.Documents[0].Meta)
assert.Equal(t, "Doc A", result.Documents[0].Meta["title"])
// Second document has no title
require.NotNil(t, result.Documents[1].ID)
assert.Equal(t, "rec-2", *result.Documents[1].ID)
assert.Equal(t, "Berlin is the capital of Germany.", result.Documents[1].Text)
assert.Nil(t, result.Documents[1].Meta)
// TopN
require.NotNil(t, result.Params)
require.NotNil(t, result.Params.TopN)
assert.Equal(t, 5, *result.Params.TopN)
// ExtraParams
require.NotNil(t, result.Params.ExtraParams)
assert.Equal(t, true, result.Params.ExtraParams["ignore_record_details_in_response"])
assert.Equal(t, map[string]string{"env": "test"}, result.Params.ExtraParams["user_labels"])
}
func TestVertexRankRequestToBifrostRerankRequestNil(t *testing.T) {
t.Parallel()
var req *VertexRankRequest
assert.Nil(t, req.ToBifrostRerankRequest(nil))
}
func TestVertexRankResponseToBifrostRerankResponseInvalidID(t *testing.T) {
t.Parallel()
_, err := (&VertexRankResponse{
Records: []VertexRankedRecord{
{ID: "bad-id", Score: 0.9},
},
}).ToBifrostRerankResponse([]schemas.RerankDocument{{Text: "doc"}}, false)
require.Error(t, err)
assert.Contains(t, err.Error(), "invalid record id")
}