Files
bifrost/core/providers/bedrock/rerank_test.go
Beyhan Oğur 880f412e2c first commit
2026-04-26 21:52:23 +03:00

231 lines
7.3 KiB
Go

package bedrock
import (
"context"
"testing"
"github.com/maximhq/bifrost/core/schemas"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestToBedrockRerankRequest(t *testing.T) {
topN := 10
maxTokensPerDoc := 512
priority := 3
req, err := ToBedrockRerankRequest(&schemas.BifrostRerankRequest{
Model: "arn:aws:bedrock:us-east-1::foundation-model/cohere.rerank-v3-5:0",
Query: "capital of france",
Documents: []schemas.RerankDocument{
{Text: "Paris is the capital of France."},
{Text: "Berlin is the capital of Germany."},
},
Params: &schemas.RerankParameters{
TopN: schemas.Ptr(topN),
MaxTokensPerDoc: schemas.Ptr(maxTokensPerDoc),
Priority: schemas.Ptr(priority),
ExtraParams: map[string]interface{}{
"truncate": "END",
},
},
}, "arn:aws:bedrock:us-east-1::foundation-model/cohere.rerank-v3-5:0")
require.NoError(t, err)
require.NotNil(t, req)
require.Len(t, req.Queries, 1)
assert.Equal(t, "TEXT", req.Queries[0].Type)
assert.Equal(t, "capital of france", req.Queries[0].TextQuery.Text)
require.Len(t, req.Sources, 2)
require.NotNil(t, req.RerankingConfiguration.BedrockRerankingConfiguration.NumberOfResults)
assert.Equal(t, 2, *req.RerankingConfiguration.BedrockRerankingConfiguration.NumberOfResults, "top_n must be clamped to source count")
fields := req.RerankingConfiguration.BedrockRerankingConfiguration.ModelConfiguration.AdditionalModelRequestFields
require.NotNil(t, fields)
assert.Equal(t, maxTokensPerDoc, fields["max_tokens_per_doc"])
assert.Equal(t, priority, fields["priority"])
assert.Equal(t, "END", fields["truncate"])
}
// TestBedrockRerankResponseToBifrostRerankResponse checks two behaviors:
//   - Bedrock rerank results are reordered by descending relevance score, with
//     equal scores ordered by ascending document index.
//   - The document text echoed back by Bedrock is kept when no request
//     documents are supplied.
func TestBedrockRerankResponseToBifrostRerankResponse(t *testing.T) {
	response := (&BedrockRerankResponse{
		Results: []BedrockRerankResult{
			{
				Index:          2,
				RelevanceScore: 0.21,
				Document: &BedrockRerankResponseDocument{
					TextDocument: &BedrockRerankTextValue{Text: "doc-2"},
				},
			},
			{
				Index:          1,
				RelevanceScore: 0.95,
				Document: &BedrockRerankResponseDocument{
					TextDocument: &BedrockRerankTextValue{Text: "doc-1"},
				},
			},
			{
				Index:          0,
				RelevanceScore: 0.95,
				Document: &BedrockRerankResponseDocument{
					TextDocument: &BedrockRerankTextValue{Text: "doc-0"},
				},
			},
		},
	}).ToBifrostRerankResponse(nil, false)
	require.NotNil(t, response)
	require.Len(t, response.Results, 3)
	// Indices 1 and 0 share score 0.95, so the lower index is expected first,
	// followed by the lower-scoring index 2.
	assert.Equal(t, 0, response.Results[0].Index)
	assert.Equal(t, 1, response.Results[1].Index)
	assert.Equal(t, 2, response.Results[2].Index)
	// Guard each Document before dereferencing it. A missing document then fails
	// the test with a clear message instead of a nil-pointer panic.
	for i, want := range []string{"doc-0", "doc-1", "doc-2"} {
		require.NotNil(t, response.Results[i].Document, "result %d has no document", i)
		assert.Equal(t, want, response.Results[i].Document.Text)
	}
}
// TestBedrockRerankResponseToBifrostRerankResponseReturnDocuments checks the
// case where returnDocuments is true and request documents are supplied. Each
// result's Document must then be the request document at the result's index,
// not the text Bedrock echoed back.
func TestBedrockRerankResponseToBifrostRerankResponseReturnDocuments(t *testing.T) {
	requestDocs := []schemas.RerankDocument{
		{Text: "request-doc-0"},
		{Text: "request-doc-1"},
		{Text: "request-doc-2"},
	}

	bedrockResp := &BedrockRerankResponse{
		Results: []BedrockRerankResult{
			{
				Index:          2,
				RelevanceScore: 0.21,
				Document: &BedrockRerankResponseDocument{
					TextDocument: &BedrockRerankTextValue{Text: "provider-doc-2"},
				},
			},
			{
				Index:          1,
				RelevanceScore: 0.95,
				Document: &BedrockRerankResponseDocument{
					TextDocument: &BedrockRerankTextValue{Text: "provider-doc-1"},
				},
			},
			{
				Index:          0,
				RelevanceScore: 0.95,
				Document: &BedrockRerankResponseDocument{
					TextDocument: &BedrockRerankTextValue{Text: "provider-doc-0"},
				},
			},
		},
	}

	response := bedrockResp.ToBifrostRerankResponse(requestDocs, true)
	require.NotNil(t, response)
	require.Len(t, response.Results, 3)

	// After reordering, position i is expected to hold document index i, whose
	// text must come from requestDocs rather than the provider payload.
	for i, want := range []string{"request-doc-0", "request-doc-1", "request-doc-2"} {
		got := response.Results[i]
		require.NotNil(t, got.Document)
		assert.Equal(t, i, got.Index)
		assert.Equal(t, want, got.Document.Text)
	}
}
func TestBedrockRerankRequestToBifrostRerankRequest(t *testing.T) {
topN := 3
bedrockReq := &BedrockRerankRequest{
Queries: []BedrockRerankQuery{
{
Type: bedrockRerankQueryTypeText,
TextQuery: BedrockRerankTextRef{Text: "capital of france"},
},
},
Sources: []BedrockRerankSource{
{
Type: bedrockRerankSourceTypeInline,
InlineDocumentSource: BedrockRerankInlineSource{
Type: bedrockRerankInlineDocumentTypeText,
TextDocument: BedrockRerankTextValue{Text: "Paris is the capital of France."},
},
},
{
Type: bedrockRerankSourceTypeInline,
InlineDocumentSource: BedrockRerankInlineSource{
Type: bedrockRerankInlineDocumentTypeText,
TextDocument: BedrockRerankTextValue{Text: "Berlin is the capital of Germany."},
},
},
},
RerankingConfiguration: BedrockRerankingConfiguration{
Type: bedrockRerankConfigurationTypeBedrock,
BedrockRerankingConfiguration: BedrockRerankingModelConfiguration{
NumberOfResults: &topN,
ModelConfiguration: BedrockRerankModelConfiguration{
ModelARN: "arn:aws:bedrock:us-east-1::foundation-model/cohere.rerank-v3-5:0",
AdditionalModelRequestFields: map[string]interface{}{
"truncate": "END",
},
},
},
},
}
bifrostCtx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
result := bedrockReq.ToBifrostRerankRequest(bifrostCtx)
require.NotNil(t, result)
assert.Equal(t, schemas.Bedrock, result.Provider)
assert.Equal(t, "arn:aws:bedrock:us-east-1::foundation-model/cohere.rerank-v3-5:0", result.Model)
assert.Equal(t, "capital of france", result.Query)
require.Len(t, result.Documents, 2)
assert.Equal(t, "Paris is the capital of France.", result.Documents[0].Text)
assert.Equal(t, "Berlin is the capital of Germany.", result.Documents[1].Text)
require.NotNil(t, result.Params)
require.NotNil(t, result.Params.TopN)
assert.Equal(t, 3, *result.Params.TopN)
require.NotNil(t, result.Params.ExtraParams)
assert.Equal(t, "END", result.Params.ExtraParams["truncate"])
}
// TestBedrockRerankRequestToBifrostRerankRequestNil checks that calling the
// conversion on a nil *BedrockRerankRequest returns nil rather than panicking.
func TestBedrockRerankRequestToBifrostRerankRequestNil(t *testing.T) {
	var nilReq *BedrockRerankRequest
	converted := nilReq.ToBifrostRerankRequest(nil)
	assert.Nil(t, converted)
}
// TestResolveBedrockDeployment checks key alias resolution. A configured alias
// maps to its target; a name with no alias and the empty string resolve to
// themselves.
func TestResolveBedrockDeployment(t *testing.T) {
	const rerankARN = "arn:aws:bedrock:us-east-1::foundation-model/cohere.rerank-v3-5:0"
	key := schemas.Key{
		Aliases: schemas.KeyAliases{"cohere-rerank": rerankARN},
	}

	cases := []struct {
		name string
		want string
	}{
		{name: "cohere-rerank", want: rerankARN},
		{name: "cohere.rerank-v3-5:0", want: "cohere.rerank-v3-5:0"},
		{name: "", want: ""},
	}
	for _, tc := range cases {
		assert.Equal(t, tc.want, key.Aliases.Resolve(tc.name))
	}
}
func TestBedrockRerankRequiresARNModelIdentifier(t *testing.T) {
provider := &BedrockProvider{}
ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
key := schemas.Key{
Aliases: schemas.KeyAliases{
"cohere-rerank": "cohere.rerank-v3-5:0",
},
}
response, bifrostErr := provider.Rerank(ctx, key, &schemas.BifrostRerankRequest{
Model: "cohere-rerank",
Query: "capital of france",
Documents: []schemas.RerankDocument{
{Text: "Paris is the capital of France."},
},
})
require.Nil(t, response)
require.NotNil(t, bifrostErr)
require.NotNil(t, bifrostErr.Error)
assert.Contains(t, bifrostErr.Error.Message, "requires an ARN")
}