231 lines
7.3 KiB
Go
231 lines
7.3 KiB
Go
package bedrock
|
|
|
|
import (
|
|
"context"
|
|
"testing"
|
|
|
|
"github.com/maximhq/bifrost/core/schemas"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
func TestToBedrockRerankRequest(t *testing.T) {
|
|
topN := 10
|
|
maxTokensPerDoc := 512
|
|
priority := 3
|
|
|
|
req, err := ToBedrockRerankRequest(&schemas.BifrostRerankRequest{
|
|
Model: "arn:aws:bedrock:us-east-1::foundation-model/cohere.rerank-v3-5:0",
|
|
Query: "capital of france",
|
|
Documents: []schemas.RerankDocument{
|
|
{Text: "Paris is the capital of France."},
|
|
{Text: "Berlin is the capital of Germany."},
|
|
},
|
|
Params: &schemas.RerankParameters{
|
|
TopN: schemas.Ptr(topN),
|
|
MaxTokensPerDoc: schemas.Ptr(maxTokensPerDoc),
|
|
Priority: schemas.Ptr(priority),
|
|
ExtraParams: map[string]interface{}{
|
|
"truncate": "END",
|
|
},
|
|
},
|
|
}, "arn:aws:bedrock:us-east-1::foundation-model/cohere.rerank-v3-5:0")
|
|
require.NoError(t, err)
|
|
require.NotNil(t, req)
|
|
|
|
require.Len(t, req.Queries, 1)
|
|
assert.Equal(t, "TEXT", req.Queries[0].Type)
|
|
assert.Equal(t, "capital of france", req.Queries[0].TextQuery.Text)
|
|
require.Len(t, req.Sources, 2)
|
|
|
|
require.NotNil(t, req.RerankingConfiguration.BedrockRerankingConfiguration.NumberOfResults)
|
|
assert.Equal(t, 2, *req.RerankingConfiguration.BedrockRerankingConfiguration.NumberOfResults, "top_n must be clamped to source count")
|
|
|
|
fields := req.RerankingConfiguration.BedrockRerankingConfiguration.ModelConfiguration.AdditionalModelRequestFields
|
|
require.NotNil(t, fields)
|
|
assert.Equal(t, maxTokensPerDoc, fields["max_tokens_per_doc"])
|
|
assert.Equal(t, priority, fields["priority"])
|
|
assert.Equal(t, "END", fields["truncate"])
|
|
}
|
|
|
|
// TestBedrockRerankResponseToBifrostRerankResponse checks the Bedrock -> Bifrost
// rerank response conversion when no request documents are supplied
// (returnDocuments=false). The expected order puts the two 0.95 scores first,
// with the tie resolved as index 0 before index 1, followed by the 0.21 score
// at index 2. Each result must keep the document text the provider returned.
func TestBedrockRerankResponseToBifrostRerankResponse(t *testing.T) {
	response := (&BedrockRerankResponse{
		Results: []BedrockRerankResult{
			// Deliberately unsorted; indices 1 and 0 tie on score so the
			// tie-break is exercised.
			{
				Index:          2,
				RelevanceScore: 0.21,
				Document: &BedrockRerankResponseDocument{
					TextDocument: &BedrockRerankTextValue{Text: "doc-2"},
				},
			},
			{
				Index:          1,
				RelevanceScore: 0.95,
				Document: &BedrockRerankResponseDocument{
					TextDocument: &BedrockRerankTextValue{Text: "doc-1"},
				},
			},
			{
				Index:          0,
				RelevanceScore: 0.95,
				Document: &BedrockRerankResponseDocument{
					TextDocument: &BedrockRerankTextValue{Text: "doc-0"},
				},
			},
		},
	}).ToBifrostRerankResponse(nil, false)

	require.NotNil(t, response)
	require.Len(t, response.Results, 3)

	want := []struct {
		index int
		text  string
	}{
		{0, "doc-0"},
		{1, "doc-1"},
		{2, "doc-2"},
	}
	for i, w := range want {
		got := response.Results[i]
		// Guard before dereferencing so a missing document fails the test
		// cleanly instead of panicking.
		require.NotNil(t, got.Document, "result %d has no document", i)
		assert.Equal(t, w.index, got.Index, "result %d index", i)
		assert.Equal(t, w.text, got.Document.Text, "result %d document text", i)
	}
}
|
|
|
|
// TestBedrockRerankResponseToBifrostRerankResponseReturnDocuments checks that,
// with returnDocuments=true and the original request documents supplied, each
// reranked result carries the request document matching its index rather than
// the text echoed back by the provider. Result ordering is the same as in the
// non-document case.
func TestBedrockRerankResponseToBifrostRerankResponseReturnDocuments(t *testing.T) {
	requestDocs := []schemas.RerankDocument{
		{Text: "request-doc-0"},
		{Text: "request-doc-1"},
		{Text: "request-doc-2"},
	}

	// Provider text intentionally differs from the request text so the test
	// can tell which source the converter used.
	bedrockResp := &BedrockRerankResponse{
		Results: []BedrockRerankResult{
			{Index: 2, RelevanceScore: 0.21, Document: &BedrockRerankResponseDocument{TextDocument: &BedrockRerankTextValue{Text: "provider-doc-2"}}},
			{Index: 1, RelevanceScore: 0.95, Document: &BedrockRerankResponseDocument{TextDocument: &BedrockRerankTextValue{Text: "provider-doc-1"}}},
			{Index: 0, RelevanceScore: 0.95, Document: &BedrockRerankResponseDocument{TextDocument: &BedrockRerankTextValue{Text: "provider-doc-0"}}},
		},
	}
	response := bedrockResp.ToBifrostRerankResponse(requestDocs, true)

	require.NotNil(t, response)
	require.Len(t, response.Results, 3)
	for i := range response.Results {
		require.NotNil(t, response.Results[i].Document)
	}

	for i, text := range []string{"request-doc-0", "request-doc-1", "request-doc-2"} {
		assert.Equal(t, i, response.Results[i].Index)
		assert.Equal(t, text, response.Results[i].Document.Text)
	}
}
|
|
|
|
func TestBedrockRerankRequestToBifrostRerankRequest(t *testing.T) {
|
|
topN := 3
|
|
bedrockReq := &BedrockRerankRequest{
|
|
Queries: []BedrockRerankQuery{
|
|
{
|
|
Type: bedrockRerankQueryTypeText,
|
|
TextQuery: BedrockRerankTextRef{Text: "capital of france"},
|
|
},
|
|
},
|
|
Sources: []BedrockRerankSource{
|
|
{
|
|
Type: bedrockRerankSourceTypeInline,
|
|
InlineDocumentSource: BedrockRerankInlineSource{
|
|
Type: bedrockRerankInlineDocumentTypeText,
|
|
TextDocument: BedrockRerankTextValue{Text: "Paris is the capital of France."},
|
|
},
|
|
},
|
|
{
|
|
Type: bedrockRerankSourceTypeInline,
|
|
InlineDocumentSource: BedrockRerankInlineSource{
|
|
Type: bedrockRerankInlineDocumentTypeText,
|
|
TextDocument: BedrockRerankTextValue{Text: "Berlin is the capital of Germany."},
|
|
},
|
|
},
|
|
},
|
|
RerankingConfiguration: BedrockRerankingConfiguration{
|
|
Type: bedrockRerankConfigurationTypeBedrock,
|
|
BedrockRerankingConfiguration: BedrockRerankingModelConfiguration{
|
|
NumberOfResults: &topN,
|
|
ModelConfiguration: BedrockRerankModelConfiguration{
|
|
ModelARN: "arn:aws:bedrock:us-east-1::foundation-model/cohere.rerank-v3-5:0",
|
|
AdditionalModelRequestFields: map[string]interface{}{
|
|
"truncate": "END",
|
|
},
|
|
},
|
|
},
|
|
},
|
|
}
|
|
|
|
bifrostCtx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
|
|
result := bedrockReq.ToBifrostRerankRequest(bifrostCtx)
|
|
|
|
require.NotNil(t, result)
|
|
assert.Equal(t, schemas.Bedrock, result.Provider)
|
|
assert.Equal(t, "arn:aws:bedrock:us-east-1::foundation-model/cohere.rerank-v3-5:0", result.Model)
|
|
assert.Equal(t, "capital of france", result.Query)
|
|
require.Len(t, result.Documents, 2)
|
|
assert.Equal(t, "Paris is the capital of France.", result.Documents[0].Text)
|
|
assert.Equal(t, "Berlin is the capital of Germany.", result.Documents[1].Text)
|
|
require.NotNil(t, result.Params)
|
|
require.NotNil(t, result.Params.TopN)
|
|
assert.Equal(t, 3, *result.Params.TopN)
|
|
require.NotNil(t, result.Params.ExtraParams)
|
|
assert.Equal(t, "END", result.Params.ExtraParams["truncate"])
|
|
}
|
|
|
|
// TestBedrockRerankRequestToBifrostRerankRequestNil checks that calling the
// converter on a nil *BedrockRerankRequest (with a nil context) returns nil
// rather than panicking.
func TestBedrockRerankRequestToBifrostRerankRequestNil(t *testing.T) {
	var nilRequest *BedrockRerankRequest
	got := nilRequest.ToBifrostRerankRequest(nil)
	assert.Nil(t, got)
}
|
|
|
|
func TestResolveBedrockDeployment(t *testing.T) {
|
|
key := schemas.Key{
|
|
Aliases: schemas.KeyAliases{
|
|
"cohere-rerank": "arn:aws:bedrock:us-east-1::foundation-model/cohere.rerank-v3-5:0",
|
|
},
|
|
}
|
|
|
|
deployment := key.Aliases.Resolve("cohere-rerank")
|
|
assert.Equal(t, "arn:aws:bedrock:us-east-1::foundation-model/cohere.rerank-v3-5:0", deployment)
|
|
assert.Equal(t, "cohere.rerank-v3-5:0", key.Aliases.Resolve("cohere.rerank-v3-5:0"))
|
|
assert.Equal(t, "", key.Aliases.Resolve(""))
|
|
}
|
|
|
|
// TestBedrockRerankRequiresARNModelIdentifier checks that Rerank rejects a
// model whose alias resolves to a bare model ID rather than an ARN. It expects
// a nil response and an error whose message contains "requires an ARN".
func TestBedrockRerankRequiresARNModelIdentifier(t *testing.T) {
	ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)

	key := schemas.Key{
		Aliases: schemas.KeyAliases{
			// Alias target is a plain foundation-model ID, not an ARN.
			"cohere-rerank": "cohere.rerank-v3-5:0",
		},
	}

	request := &schemas.BifrostRerankRequest{
		Model: "cohere-rerank",
		Query: "capital of france",
		Documents: []schemas.RerankDocument{
			{Text: "Paris is the capital of France."},
		},
	}

	// A zero-value provider is used; presumably the ARN validation rejects
	// the request before any AWS client state is needed. TODO confirm.
	p := &BedrockProvider{}
	response, bifrostErr := p.Rerank(ctx, key, request)

	require.Nil(t, response)
	require.NotNil(t, bifrostErr)
	require.NotNil(t, bifrostErr.Error)
	assert.Contains(t, bifrostErr.Error.Message, "requires an ARN")
}
|