first commit
This commit is contained in:
230
core/providers/bedrock/rerank_test.go
Normal file
230
core/providers/bedrock/rerank_test.go
Normal file
@@ -0,0 +1,230 @@
|
||||
package bedrock
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestToBedrockRerankRequest(t *testing.T) {
|
||||
topN := 10
|
||||
maxTokensPerDoc := 512
|
||||
priority := 3
|
||||
|
||||
req, err := ToBedrockRerankRequest(&schemas.BifrostRerankRequest{
|
||||
Model: "arn:aws:bedrock:us-east-1::foundation-model/cohere.rerank-v3-5:0",
|
||||
Query: "capital of france",
|
||||
Documents: []schemas.RerankDocument{
|
||||
{Text: "Paris is the capital of France."},
|
||||
{Text: "Berlin is the capital of Germany."},
|
||||
},
|
||||
Params: &schemas.RerankParameters{
|
||||
TopN: schemas.Ptr(topN),
|
||||
MaxTokensPerDoc: schemas.Ptr(maxTokensPerDoc),
|
||||
Priority: schemas.Ptr(priority),
|
||||
ExtraParams: map[string]interface{}{
|
||||
"truncate": "END",
|
||||
},
|
||||
},
|
||||
}, "arn:aws:bedrock:us-east-1::foundation-model/cohere.rerank-v3-5:0")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, req)
|
||||
|
||||
require.Len(t, req.Queries, 1)
|
||||
assert.Equal(t, "TEXT", req.Queries[0].Type)
|
||||
assert.Equal(t, "capital of france", req.Queries[0].TextQuery.Text)
|
||||
require.Len(t, req.Sources, 2)
|
||||
|
||||
require.NotNil(t, req.RerankingConfiguration.BedrockRerankingConfiguration.NumberOfResults)
|
||||
assert.Equal(t, 2, *req.RerankingConfiguration.BedrockRerankingConfiguration.NumberOfResults, "top_n must be clamped to source count")
|
||||
|
||||
fields := req.RerankingConfiguration.BedrockRerankingConfiguration.ModelConfiguration.AdditionalModelRequestFields
|
||||
require.NotNil(t, fields)
|
||||
assert.Equal(t, maxTokensPerDoc, fields["max_tokens_per_doc"])
|
||||
assert.Equal(t, priority, fields["priority"])
|
||||
assert.Equal(t, "END", fields["truncate"])
|
||||
}
|
||||
|
||||
func TestBedrockRerankResponseToBifrostRerankResponse(t *testing.T) {
|
||||
response := (&BedrockRerankResponse{
|
||||
Results: []BedrockRerankResult{
|
||||
{
|
||||
Index: 2,
|
||||
RelevanceScore: 0.21,
|
||||
Document: &BedrockRerankResponseDocument{
|
||||
TextDocument: &BedrockRerankTextValue{Text: "doc-2"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Index: 1,
|
||||
RelevanceScore: 0.95,
|
||||
Document: &BedrockRerankResponseDocument{
|
||||
TextDocument: &BedrockRerankTextValue{Text: "doc-1"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Index: 0,
|
||||
RelevanceScore: 0.95,
|
||||
Document: &BedrockRerankResponseDocument{
|
||||
TextDocument: &BedrockRerankTextValue{Text: "doc-0"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}).ToBifrostRerankResponse(nil, false)
|
||||
|
||||
require.NotNil(t, response)
|
||||
require.Len(t, response.Results, 3)
|
||||
|
||||
assert.Equal(t, 0, response.Results[0].Index)
|
||||
assert.Equal(t, 1, response.Results[1].Index)
|
||||
assert.Equal(t, 2, response.Results[2].Index)
|
||||
assert.Equal(t, "doc-0", response.Results[0].Document.Text)
|
||||
assert.Equal(t, "doc-1", response.Results[1].Document.Text)
|
||||
}
|
||||
|
||||
func TestBedrockRerankResponseToBifrostRerankResponseReturnDocuments(t *testing.T) {
|
||||
requestDocs := []schemas.RerankDocument{
|
||||
{Text: "request-doc-0"},
|
||||
{Text: "request-doc-1"},
|
||||
{Text: "request-doc-2"},
|
||||
}
|
||||
|
||||
response := (&BedrockRerankResponse{
|
||||
Results: []BedrockRerankResult{
|
||||
{
|
||||
Index: 2,
|
||||
RelevanceScore: 0.21,
|
||||
Document: &BedrockRerankResponseDocument{
|
||||
TextDocument: &BedrockRerankTextValue{Text: "provider-doc-2"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Index: 1,
|
||||
RelevanceScore: 0.95,
|
||||
Document: &BedrockRerankResponseDocument{
|
||||
TextDocument: &BedrockRerankTextValue{Text: "provider-doc-1"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Index: 0,
|
||||
RelevanceScore: 0.95,
|
||||
Document: &BedrockRerankResponseDocument{
|
||||
TextDocument: &BedrockRerankTextValue{Text: "provider-doc-0"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}).ToBifrostRerankResponse(requestDocs, true)
|
||||
|
||||
require.NotNil(t, response)
|
||||
require.Len(t, response.Results, 3)
|
||||
require.NotNil(t, response.Results[0].Document)
|
||||
require.NotNil(t, response.Results[1].Document)
|
||||
require.NotNil(t, response.Results[2].Document)
|
||||
|
||||
assert.Equal(t, 0, response.Results[0].Index)
|
||||
assert.Equal(t, 1, response.Results[1].Index)
|
||||
assert.Equal(t, 2, response.Results[2].Index)
|
||||
assert.Equal(t, "request-doc-0", response.Results[0].Document.Text)
|
||||
assert.Equal(t, "request-doc-1", response.Results[1].Document.Text)
|
||||
assert.Equal(t, "request-doc-2", response.Results[2].Document.Text)
|
||||
}
|
||||
|
||||
func TestBedrockRerankRequestToBifrostRerankRequest(t *testing.T) {
|
||||
topN := 3
|
||||
bedrockReq := &BedrockRerankRequest{
|
||||
Queries: []BedrockRerankQuery{
|
||||
{
|
||||
Type: bedrockRerankQueryTypeText,
|
||||
TextQuery: BedrockRerankTextRef{Text: "capital of france"},
|
||||
},
|
||||
},
|
||||
Sources: []BedrockRerankSource{
|
||||
{
|
||||
Type: bedrockRerankSourceTypeInline,
|
||||
InlineDocumentSource: BedrockRerankInlineSource{
|
||||
Type: bedrockRerankInlineDocumentTypeText,
|
||||
TextDocument: BedrockRerankTextValue{Text: "Paris is the capital of France."},
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: bedrockRerankSourceTypeInline,
|
||||
InlineDocumentSource: BedrockRerankInlineSource{
|
||||
Type: bedrockRerankInlineDocumentTypeText,
|
||||
TextDocument: BedrockRerankTextValue{Text: "Berlin is the capital of Germany."},
|
||||
},
|
||||
},
|
||||
},
|
||||
RerankingConfiguration: BedrockRerankingConfiguration{
|
||||
Type: bedrockRerankConfigurationTypeBedrock,
|
||||
BedrockRerankingConfiguration: BedrockRerankingModelConfiguration{
|
||||
NumberOfResults: &topN,
|
||||
ModelConfiguration: BedrockRerankModelConfiguration{
|
||||
ModelARN: "arn:aws:bedrock:us-east-1::foundation-model/cohere.rerank-v3-5:0",
|
||||
AdditionalModelRequestFields: map[string]interface{}{
|
||||
"truncate": "END",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
bifrostCtx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
|
||||
result := bedrockReq.ToBifrostRerankRequest(bifrostCtx)
|
||||
|
||||
require.NotNil(t, result)
|
||||
assert.Equal(t, schemas.Bedrock, result.Provider)
|
||||
assert.Equal(t, "arn:aws:bedrock:us-east-1::foundation-model/cohere.rerank-v3-5:0", result.Model)
|
||||
assert.Equal(t, "capital of france", result.Query)
|
||||
require.Len(t, result.Documents, 2)
|
||||
assert.Equal(t, "Paris is the capital of France.", result.Documents[0].Text)
|
||||
assert.Equal(t, "Berlin is the capital of Germany.", result.Documents[1].Text)
|
||||
require.NotNil(t, result.Params)
|
||||
require.NotNil(t, result.Params.TopN)
|
||||
assert.Equal(t, 3, *result.Params.TopN)
|
||||
require.NotNil(t, result.Params.ExtraParams)
|
||||
assert.Equal(t, "END", result.Params.ExtraParams["truncate"])
|
||||
}
|
||||
|
||||
func TestBedrockRerankRequestToBifrostRerankRequestNil(t *testing.T) {
|
||||
var req *BedrockRerankRequest
|
||||
assert.Nil(t, req.ToBifrostRerankRequest(nil))
|
||||
}
|
||||
|
||||
func TestResolveBedrockDeployment(t *testing.T) {
|
||||
key := schemas.Key{
|
||||
Aliases: schemas.KeyAliases{
|
||||
"cohere-rerank": "arn:aws:bedrock:us-east-1::foundation-model/cohere.rerank-v3-5:0",
|
||||
},
|
||||
}
|
||||
|
||||
deployment := key.Aliases.Resolve("cohere-rerank")
|
||||
assert.Equal(t, "arn:aws:bedrock:us-east-1::foundation-model/cohere.rerank-v3-5:0", deployment)
|
||||
assert.Equal(t, "cohere.rerank-v3-5:0", key.Aliases.Resolve("cohere.rerank-v3-5:0"))
|
||||
assert.Equal(t, "", key.Aliases.Resolve(""))
|
||||
}
|
||||
|
||||
func TestBedrockRerankRequiresARNModelIdentifier(t *testing.T) {
|
||||
provider := &BedrockProvider{}
|
||||
ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
|
||||
key := schemas.Key{
|
||||
Aliases: schemas.KeyAliases{
|
||||
"cohere-rerank": "cohere.rerank-v3-5:0",
|
||||
},
|
||||
}
|
||||
|
||||
response, bifrostErr := provider.Rerank(ctx, key, &schemas.BifrostRerankRequest{
|
||||
Model: "cohere-rerank",
|
||||
Query: "capital of france",
|
||||
Documents: []schemas.RerankDocument{
|
||||
{Text: "Paris is the capital of France."},
|
||||
},
|
||||
})
|
||||
|
||||
require.Nil(t, response)
|
||||
require.NotNil(t, bifrostErr)
|
||||
require.NotNil(t, bifrostErr.Error)
|
||||
assert.Contains(t, bifrostErr.Error.Message, "requires an ARN")
|
||||
}
|
||||
Reference in New Issue
Block a user