first commit
This commit is contained in:
1081
transports/bifrost-http/integrations/anthropic.go
Normal file
1081
transports/bifrost-http/integrations/anthropic.go
Normal file
File diff suppressed because it is too large
Load Diff
87
transports/bifrost-http/integrations/anthropic_test.go
Normal file
87
transports/bifrost-http/integrations/anthropic_test.go
Normal file
@@ -0,0 +1,87 @@
|
||||
package integrations
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestFilterVertexUnsupportedBetaHeaders(t *testing.T) {
|
||||
t.Run("filters known exact header values", func(t *testing.T) {
|
||||
headers := map[string][]string{
|
||||
"anthropic-beta": {"advanced-tool-use-2025-11-20,structured-outputs-2025-11-13,mcp-client-2025-04-04,prompt-caching-scope-2026-01-05"},
|
||||
}
|
||||
result := filterVertexUnsupportedBetaHeaders(headers)
|
||||
_, ok := result["anthropic-beta"]
|
||||
assert.False(t, ok, "all unsupported beta headers should be removed, leaving no anthropic-beta key")
|
||||
})
|
||||
|
||||
t.Run("filters bumped date variants", func(t *testing.T) {
|
||||
// Simulate Anthropic bumping version dates in the future
|
||||
headers := map[string][]string{
|
||||
"anthropic-beta": {"structured-outputs-2025-12-15,advanced-tool-use-2026-03-01,mcp-client-2026-01-01,prompt-caching-scope-2027-06-30"},
|
||||
}
|
||||
result := filterVertexUnsupportedBetaHeaders(headers)
|
||||
_, ok := result["anthropic-beta"]
|
||||
assert.False(t, ok, "bumped-date variants of unsupported headers should also be filtered")
|
||||
})
|
||||
|
||||
t.Run("passes through unrelated beta headers", func(t *testing.T) {
|
||||
headers := map[string][]string{
|
||||
"anthropic-beta": {"interleaved-thinking-2025-05-14,files-api-2025-04-14"},
|
||||
}
|
||||
result := filterVertexUnsupportedBetaHeaders(headers)
|
||||
vals, ok := result["anthropic-beta"]
|
||||
assert.True(t, ok, "unrelated beta headers should be preserved")
|
||||
assert.Equal(t, []string{"interleaved-thinking-2025-05-14,files-api-2025-04-14"}, vals)
|
||||
})
|
||||
|
||||
t.Run("filters unsupported and keeps supported in mixed list", func(t *testing.T) {
|
||||
headers := map[string][]string{
|
||||
"anthropic-beta": {"interleaved-thinking-2025-05-14,structured-outputs-2025-11-13,files-api-2025-04-14,mcp-client-2025-04-04"},
|
||||
}
|
||||
result := filterVertexUnsupportedBetaHeaders(headers)
|
||||
vals, ok := result["anthropic-beta"]
|
||||
assert.True(t, ok, "supported beta headers should be preserved")
|
||||
assert.Equal(t, []string{"interleaved-thinking-2025-05-14,files-api-2025-04-14"}, vals)
|
||||
})
|
||||
|
||||
t.Run("filters bumped unsupported mixed with supported", func(t *testing.T) {
|
||||
// Future-proof: bumped dates should still be filtered
|
||||
headers := map[string][]string{
|
||||
"anthropic-beta": {"structured-outputs-2026-01-01,interleaved-thinking-2025-05-14,advanced-tool-use-2026-06-15"},
|
||||
}
|
||||
result := filterVertexUnsupportedBetaHeaders(headers)
|
||||
vals, ok := result["anthropic-beta"]
|
||||
assert.True(t, ok, "supported beta headers should be preserved even when mixed with bumped unsupported ones")
|
||||
assert.Equal(t, []string{"interleaved-thinking-2025-05-14"}, vals)
|
||||
})
|
||||
|
||||
t.Run("returns headers unchanged when no anthropic-beta key present", func(t *testing.T) {
|
||||
headers := map[string][]string{
|
||||
"content-type": {"application/json"},
|
||||
}
|
||||
result := filterVertexUnsupportedBetaHeaders(headers)
|
||||
assert.Equal(t, headers, result)
|
||||
})
|
||||
|
||||
t.Run("handles empty anthropic-beta value gracefully", func(t *testing.T) {
|
||||
headers := map[string][]string{
|
||||
"anthropic-beta": {""},
|
||||
}
|
||||
result := filterVertexUnsupportedBetaHeaders(headers)
|
||||
// Empty string after trimming is not an unsupported header, but it is also empty — key should be removed
|
||||
_, ok := result["anthropic-beta"]
|
||||
assert.False(t, ok, "empty beta header list should result in key removal")
|
||||
})
|
||||
|
||||
t.Run("case-insensitive key matching for Anthropic-Beta header", func(t *testing.T) {
|
||||
headers := map[string][]string{
|
||||
"Anthropic-Beta": {"structured-outputs-2025-11-13,interleaved-thinking-2025-05-14"},
|
||||
}
|
||||
result := filterVertexUnsupportedBetaHeaders(headers)
|
||||
vals, ok := result["Anthropic-Beta"]
|
||||
assert.True(t, ok, "header key casing should be preserved and matching should be case-insensitive")
|
||||
assert.Equal(t, []string{"interleaved-thinking-2025-05-14"}, vals)
|
||||
})
|
||||
}
|
||||
1315
transports/bifrost-http/integrations/bedrock.go
Normal file
1315
transports/bifrost-http/integrations/bedrock.go
Normal file
File diff suppressed because it is too large
Load Diff
921
transports/bifrost-http/integrations/bedrock_test.go
Normal file
921
transports/bifrost-http/integrations/bedrock_test.go
Normal file
@@ -0,0 +1,921 @@
|
||||
package integrations
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/maximhq/bifrost/core/providers/bedrock"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/maximhq/bifrost/framework/kvstore"
|
||||
"github.com/maximhq/bifrost/framework/logstore"
|
||||
"github.com/maximhq/bifrost/transports/bifrost-http/lib"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// mockHandlerStore implements lib.HandlerStore for testing
|
||||
type mockHandlerStore struct {
|
||||
allowDirectKeys bool
|
||||
headerMatcher *lib.HeaderMatcher
|
||||
availableProviders []schemas.ModelProvider
|
||||
mcpHeaderCombinedAllowlist schemas.WhiteList
|
||||
}
|
||||
|
||||
func (m *mockHandlerStore) ShouldAllowDirectKeys() bool {
|
||||
return m.allowDirectKeys
|
||||
}
|
||||
|
||||
func (m *mockHandlerStore) GetHeaderMatcher() *lib.HeaderMatcher {
|
||||
return m.headerMatcher
|
||||
}
|
||||
|
||||
func (m *mockHandlerStore) GetAvailableProviders() []schemas.ModelProvider {
|
||||
return m.availableProviders
|
||||
}
|
||||
|
||||
func (m *mockHandlerStore) GetStreamChunkInterceptor() lib.StreamChunkInterceptor {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockHandlerStore) GetAsyncJobExecutor() *logstore.AsyncJobExecutor {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockHandlerStore) GetAsyncJobResultTTL() int {
|
||||
return 3600
|
||||
}
|
||||
|
||||
func (m *mockHandlerStore) GetKVStore() *kvstore.Store {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockHandlerStore) GetMCPHeaderCombinedAllowlist() schemas.WhiteList {
|
||||
return m.mcpHeaderCombinedAllowlist
|
||||
}
|
||||
|
||||
// Ensure mockHandlerStore implements lib.HandlerStore
|
||||
var _ lib.HandlerStore = (*mockHandlerStore)(nil)
|
||||
|
||||
func Test_parseS3URI(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
uri string
|
||||
wantBucket string
|
||||
wantKey string
|
||||
}{
|
||||
{
|
||||
name: "full S3 URI with key",
|
||||
uri: "s3://my-bucket/path/to/file.jsonl",
|
||||
wantBucket: "my-bucket",
|
||||
wantKey: "path/to/file.jsonl",
|
||||
},
|
||||
{
|
||||
name: "S3 URI with bucket only",
|
||||
uri: "s3://my-bucket/",
|
||||
wantBucket: "my-bucket",
|
||||
wantKey: "",
|
||||
},
|
||||
{
|
||||
name: "S3 URI with bucket no trailing slash",
|
||||
uri: "s3://my-bucket",
|
||||
wantBucket: "my-bucket",
|
||||
wantKey: "",
|
||||
},
|
||||
{
|
||||
name: "plain bucket name",
|
||||
uri: "my-bucket",
|
||||
wantBucket: "my-bucket",
|
||||
wantKey: "",
|
||||
},
|
||||
{
|
||||
name: "S3 URI with nested key",
|
||||
uri: "s3://bucket-name/folder1/folder2/file.txt",
|
||||
wantBucket: "bucket-name",
|
||||
wantKey: "folder1/folder2/file.txt",
|
||||
},
|
||||
{
|
||||
name: "empty string",
|
||||
uri: "",
|
||||
wantBucket: "",
|
||||
wantKey: "",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gotBucket, gotKey := parseS3URI(tt.uri)
|
||||
assert.Equal(t, tt.wantBucket, gotBucket, "bucket mismatch")
|
||||
assert.Equal(t, tt.wantKey, gotKey, "key mismatch")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_createBedrockRouteConfigs(t *testing.T) {
|
||||
handlerStore := &mockHandlerStore{allowDirectKeys: true}
|
||||
routes := CreateBedrockRouteConfigs("/bedrock", handlerStore)
|
||||
|
||||
assert.Len(t, routes, 6, "should have 6 bedrock routes")
|
||||
|
||||
expectedRoutes := []struct {
|
||||
path string
|
||||
method string
|
||||
}{
|
||||
{"/bedrock/model/{modelId}/converse", "POST"},
|
||||
{"/bedrock/model/{modelId}/converse-stream", "POST"},
|
||||
{"/bedrock/model/{modelId}/invoke-with-response-stream", "POST"},
|
||||
{"/bedrock/model/{modelId}/invoke", "POST"},
|
||||
{"/bedrock/rerank", "POST"},
|
||||
{"/bedrock/model/{modelId}/count-tokens", "POST"},
|
||||
}
|
||||
|
||||
for i, expected := range expectedRoutes {
|
||||
assert.Equal(t, expected.path, routes[i].Path, "route %d path mismatch", i)
|
||||
assert.Equal(t, expected.method, routes[i].Method, "route %d method mismatch", i)
|
||||
assert.Equal(t, RouteConfigTypeBedrock, routes[i].Type, "route %d type mismatch", i)
|
||||
assert.NotNil(t, routes[i].GetRequestTypeInstance, "route %d GetRequestTypeInstance should not be nil", i)
|
||||
assert.NotNil(t, routes[i].ErrorConverter, "route %d ErrorConverter should not be nil", i)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_createBedrockConverseRouteConfig(t *testing.T) {
|
||||
handlerStore := &mockHandlerStore{allowDirectKeys: true}
|
||||
route := createBedrockConverseRouteConfig("/bedrock", handlerStore)
|
||||
|
||||
assert.Equal(t, "/bedrock/model/{modelId}/converse", route.Path)
|
||||
assert.Equal(t, "POST", route.Method)
|
||||
assert.Equal(t, RouteConfigTypeBedrock, route.Type)
|
||||
assert.NotNil(t, route.GetRequestTypeInstance)
|
||||
assert.NotNil(t, route.RequestConverter)
|
||||
assert.NotNil(t, route.ResponsesResponseConverter)
|
||||
assert.NotNil(t, route.ErrorConverter)
|
||||
assert.NotNil(t, route.PreCallback)
|
||||
|
||||
// Verify request instance type
|
||||
reqInstance := route.GetRequestTypeInstance(context.Background())
|
||||
_, ok := reqInstance.(*bedrock.BedrockConverseRequest)
|
||||
assert.True(t, ok, "GetRequestTypeInstance should return *bedrock.BedrockConverseRequest")
|
||||
}
|
||||
|
||||
func Test_createBedrockConverseStreamRouteConfig(t *testing.T) {
|
||||
handlerStore := &mockHandlerStore{allowDirectKeys: true}
|
||||
route := createBedrockConverseStreamRouteConfig("/bedrock", handlerStore)
|
||||
|
||||
assert.Equal(t, "/bedrock/model/{modelId}/converse-stream", route.Path)
|
||||
assert.Equal(t, "POST", route.Method)
|
||||
assert.Equal(t, RouteConfigTypeBedrock, route.Type)
|
||||
assert.NotNil(t, route.StreamConfig)
|
||||
assert.NotNil(t, route.StreamConfig.ResponsesStreamResponseConverter)
|
||||
|
||||
// Verify request instance type
|
||||
reqInstance := route.GetRequestTypeInstance(context.Background())
|
||||
_, ok := reqInstance.(*bedrock.BedrockConverseRequest)
|
||||
assert.True(t, ok, "GetRequestTypeInstance should return *bedrock.BedrockConverseRequest")
|
||||
}
|
||||
|
||||
func Test_createBedrockInvokeRouteConfig(t *testing.T) {
|
||||
handlerStore := &mockHandlerStore{allowDirectKeys: true}
|
||||
route := createBedrockInvokeRouteConfig("/bedrock", handlerStore)
|
||||
|
||||
assert.Equal(t, "/bedrock/model/{modelId}/invoke", route.Path)
|
||||
assert.Equal(t, "POST", route.Method)
|
||||
assert.Equal(t, RouteConfigTypeBedrock, route.Type)
|
||||
assert.NotNil(t, route.TextResponseConverter)
|
||||
assert.NotNil(t, route.ResponsesResponseConverter)
|
||||
|
||||
// Verify request instance type
|
||||
reqInstance := route.GetRequestTypeInstance(context.Background())
|
||||
_, ok := reqInstance.(*bedrock.BedrockInvokeRequest)
|
||||
assert.True(t, ok, "GetRequestTypeInstance should return *bedrock.BedrockInvokeRequest")
|
||||
}
|
||||
|
||||
func Test_createBedrockInvokeWithResponseStreamRouteConfig(t *testing.T) {
|
||||
handlerStore := &mockHandlerStore{allowDirectKeys: true}
|
||||
route := createBedrockInvokeWithResponseStreamRouteConfig("/bedrock", handlerStore)
|
||||
|
||||
assert.Equal(t, "/bedrock/model/{modelId}/invoke-with-response-stream", route.Path)
|
||||
assert.Equal(t, "POST", route.Method)
|
||||
assert.Equal(t, RouteConfigTypeBedrock, route.Type)
|
||||
assert.NotNil(t, route.StreamConfig)
|
||||
assert.NotNil(t, route.StreamConfig.TextStreamResponseConverter)
|
||||
assert.NotNil(t, route.StreamConfig.ResponsesStreamResponseConverter)
|
||||
|
||||
// Verify request instance type
|
||||
reqInstance := route.GetRequestTypeInstance(context.Background())
|
||||
_, ok := reqInstance.(*bedrock.BedrockInvokeRequest)
|
||||
assert.True(t, ok, "GetRequestTypeInstance should return *bedrock.BedrockInvokeRequest")
|
||||
}
|
||||
|
||||
func Test_createBedrockRerankRouteConfig(t *testing.T) {
|
||||
handlerStore := &mockHandlerStore{allowDirectKeys: true}
|
||||
route := createBedrockRerankRouteConfig("/bedrock", handlerStore)
|
||||
|
||||
assert.Equal(t, "/bedrock/rerank", route.Path)
|
||||
assert.Equal(t, "POST", route.Method)
|
||||
assert.Equal(t, RouteConfigTypeBedrock, route.Type)
|
||||
assert.NotNil(t, route.GetHTTPRequestType)
|
||||
assert.Equal(t, schemas.RerankRequest, route.GetHTTPRequestType(nil))
|
||||
assert.NotNil(t, route.GetRequestTypeInstance)
|
||||
assert.NotNil(t, route.RequestConverter)
|
||||
assert.NotNil(t, route.RerankResponseConverter)
|
||||
assert.NotNil(t, route.ErrorConverter)
|
||||
assert.NotNil(t, route.PreCallback)
|
||||
|
||||
// Verify request instance type
|
||||
reqInstance := route.GetRequestTypeInstance(context.Background())
|
||||
_, ok := reqInstance.(*bedrock.BedrockRerankRequest)
|
||||
assert.True(t, ok, "GetRequestTypeInstance should return *bedrock.BedrockRerankRequest")
|
||||
}
|
||||
|
||||
func Test_createBedrockRerankResponseConverterUsesRawResponse(t *testing.T) {
|
||||
handlerStore := &mockHandlerStore{allowDirectKeys: true}
|
||||
route := createBedrockRerankRouteConfig("/bedrock", handlerStore)
|
||||
require.NotNil(t, route.RerankResponseConverter)
|
||||
|
||||
raw := map[string]interface{}{"results": []interface{}{}}
|
||||
resp := &schemas.BifrostRerankResponse{
|
||||
ExtraFields: schemas.BifrostResponseExtraFields{
|
||||
Provider: schemas.Bedrock,
|
||||
RawResponse: raw,
|
||||
},
|
||||
}
|
||||
converted, err := route.RerankResponseConverter(nil, resp)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, raw, converted)
|
||||
}
|
||||
|
||||
func Test_createBedrockRerankRouteRequestConverter(t *testing.T) {
|
||||
handlerStore := &mockHandlerStore{allowDirectKeys: true}
|
||||
route := createBedrockRerankRouteConfig("/bedrock", handlerStore)
|
||||
require.NotNil(t, route.RequestConverter)
|
||||
|
||||
topN := 1
|
||||
req := &bedrock.BedrockRerankRequest{
|
||||
Queries: []bedrock.BedrockRerankQuery{
|
||||
{
|
||||
Type: "TEXT",
|
||||
TextQuery: bedrock.BedrockRerankTextRef{Text: "capital of france"},
|
||||
},
|
||||
},
|
||||
Sources: []bedrock.BedrockRerankSource{
|
||||
{
|
||||
Type: "INLINE",
|
||||
InlineDocumentSource: bedrock.BedrockRerankInlineSource{
|
||||
Type: "TEXT",
|
||||
TextDocument: bedrock.BedrockRerankTextValue{Text: "Paris is capital of France"},
|
||||
},
|
||||
},
|
||||
},
|
||||
RerankingConfiguration: bedrock.BedrockRerankingConfiguration{
|
||||
Type: "BEDROCK_RERANKING_MODEL",
|
||||
BedrockRerankingConfiguration: bedrock.BedrockRerankingModelConfiguration{
|
||||
NumberOfResults: &topN,
|
||||
ModelConfiguration: bedrock.BedrockRerankModelConfiguration{
|
||||
ModelARN: "arn:aws:bedrock:us-east-1::foundation-model/cohere.rerank-v3-5:0",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
bifrostCtx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
|
||||
bifrostReq, err := route.RequestConverter(bifrostCtx, req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, bifrostReq)
|
||||
require.NotNil(t, bifrostReq.RerankRequest)
|
||||
assert.Equal(t, schemas.Bedrock, bifrostReq.RerankRequest.Provider)
|
||||
assert.Equal(t, "capital of france", bifrostReq.RerankRequest.Query)
|
||||
require.Len(t, bifrostReq.RerankRequest.Documents, 1)
|
||||
assert.Equal(t, "Paris is capital of France", bifrostReq.RerankRequest.Documents[0].Text)
|
||||
require.NotNil(t, bifrostReq.RerankRequest.Params)
|
||||
require.NotNil(t, bifrostReq.RerankRequest.Params.TopN)
|
||||
assert.Equal(t, 1, *bifrostReq.RerankRequest.Params.TopN)
|
||||
}
|
||||
|
||||
func Test_createBedrockRouteConfigsIncludesRerankForCompositePrefixes(t *testing.T) {
|
||||
handlerStore := &mockHandlerStore{allowDirectKeys: true}
|
||||
prefixes := []string{"/litellm", "/langchain", "/pydanticai"}
|
||||
|
||||
for _, prefix := range prefixes {
|
||||
routes := CreateBedrockRouteConfigs(prefix, handlerStore)
|
||||
found := false
|
||||
for _, route := range routes {
|
||||
if route.Path == prefix+"/rerank" && route.Method == "POST" {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.Truef(t, found, "expected rerank route for prefix %s", prefix)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_createBedrockBatchRouteConfigs(t *testing.T) {
|
||||
handlerStore := &mockHandlerStore{allowDirectKeys: true}
|
||||
routes := createBedrockBatchRouteConfigs("/bedrock", handlerStore)
|
||||
|
||||
assert.Len(t, routes, 4, "should have 4 batch routes")
|
||||
|
||||
expectedRoutes := []struct {
|
||||
path string
|
||||
method string
|
||||
}{
|
||||
{"/bedrock/model-invocation-job", "POST"},
|
||||
{"/bedrock/model-invocation-jobs", "GET"},
|
||||
{"/bedrock/model-invocation-job/{job_arn}", "GET"},
|
||||
{"/bedrock/model-invocation-job/{job_arn}/stop", "POST"},
|
||||
}
|
||||
|
||||
for i, expected := range expectedRoutes {
|
||||
assert.Equal(t, expected.path, routes[i].Path, "batch route %d path mismatch", i)
|
||||
assert.Equal(t, expected.method, routes[i].Method, "batch route %d method mismatch", i)
|
||||
assert.Equal(t, RouteConfigTypeBedrock, routes[i].Type, "batch route %d type mismatch", i)
|
||||
assert.NotNil(t, routes[i].GetRequestTypeInstance, "batch route %d GetRequestTypeInstance should not be nil", i)
|
||||
assert.NotNil(t, routes[i].BatchRequestConverter, "batch route %d BatchCreateRequestConverter should not be nil", i)
|
||||
assert.NotNil(t, routes[i].ErrorConverter, "batch route %d ErrorConverter should not be nil", i)
|
||||
assert.NotNil(t, routes[i].PreCallback, "batch route %d PreCallback should not be nil", i)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_createBedrockFilesRouteConfigs(t *testing.T) {
|
||||
handlerStore := &mockHandlerStore{allowDirectKeys: true}
|
||||
routes := createBedrockFilesRouteConfigs("/bedrock/files", handlerStore)
|
||||
|
||||
assert.Len(t, routes, 5, "should have 5 file routes")
|
||||
|
||||
expectedRoutes := []struct {
|
||||
path string
|
||||
method string
|
||||
}{
|
||||
{"/bedrock/files/{bucket}/{key:*}", "PUT"},
|
||||
{"/bedrock/files/{bucket}/{key:*}", "GET"},
|
||||
{"/bedrock/files/{bucket}/{key:*}", "HEAD"},
|
||||
{"/bedrock/files/{bucket}/{key:*}", "DELETE"},
|
||||
{"/bedrock/files/{bucket}", "GET"},
|
||||
}
|
||||
|
||||
for i, expected := range expectedRoutes {
|
||||
assert.Equal(t, expected.path, routes[i].Path, "file route %d path mismatch", i)
|
||||
assert.Equal(t, expected.method, routes[i].Method, "file route %d method mismatch", i)
|
||||
assert.Equal(t, RouteConfigTypeBedrock, routes[i].Type, "file route %d type mismatch", i)
|
||||
assert.NotNil(t, routes[i].GetRequestTypeInstance, "file route %d GetRequestTypeInstance should not be nil", i)
|
||||
assert.NotNil(t, routes[i].ErrorConverter, "file route %d ErrorConverter should not be nil", i)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_parseS3PutObjectRequest(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
bucket string
|
||||
key string
|
||||
body []byte
|
||||
wantErr bool
|
||||
wantBucket string
|
||||
wantKey string
|
||||
wantFilename string
|
||||
}{
|
||||
{
|
||||
name: "valid request",
|
||||
bucket: "my-bucket",
|
||||
key: "folder/file.jsonl",
|
||||
body: []byte(`{"test": "data"}`),
|
||||
wantErr: false,
|
||||
wantBucket: "my-bucket",
|
||||
wantKey: "folder/file.jsonl",
|
||||
wantFilename: "file.jsonl",
|
||||
},
|
||||
{
|
||||
name: "simple key without folder",
|
||||
bucket: "bucket",
|
||||
key: "file.txt",
|
||||
body: []byte("content"),
|
||||
wantErr: false,
|
||||
wantBucket: "bucket",
|
||||
wantKey: "file.txt",
|
||||
wantFilename: "file.txt",
|
||||
},
|
||||
{
|
||||
name: "missing bucket",
|
||||
bucket: "",
|
||||
key: "file.txt",
|
||||
body: []byte("content"),
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "missing key",
|
||||
bucket: "bucket",
|
||||
key: "",
|
||||
body: []byte("content"),
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Request.SetBody(tt.body)
|
||||
|
||||
if tt.bucket != "" {
|
||||
ctx.SetUserValue("bucket", tt.bucket)
|
||||
}
|
||||
if tt.key != "" {
|
||||
ctx.SetUserValue("key", tt.key)
|
||||
}
|
||||
|
||||
req := &bedrock.BedrockFileUploadRequest{}
|
||||
err := parseS3PutObjectRequest(ctx, req)
|
||||
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
return
|
||||
}
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tt.wantBucket, req.Bucket)
|
||||
assert.Equal(t, tt.wantKey, req.Key)
|
||||
assert.Equal(t, tt.wantFilename, req.Filename)
|
||||
assert.Equal(t, "batch", req.Purpose)
|
||||
assert.Equal(t, tt.body, req.Body)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_parseS3PutObjectRequest_invalidType(t *testing.T) {
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.SetUserValue("bucket", "bucket")
|
||||
ctx.SetUserValue("key", "key")
|
||||
|
||||
// Pass wrong type
|
||||
err := parseS3PutObjectRequest(ctx, "invalid type")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "invalid request type")
|
||||
}
|
||||
|
||||
func Test_s3PutObjectPostCallback(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
response interface{}
|
||||
wantStatus int
|
||||
wantETag string
|
||||
}{
|
||||
{
|
||||
name: "valid response with ID",
|
||||
response: &schemas.BifrostFileUploadResponse{
|
||||
ID: "file-123",
|
||||
},
|
||||
wantStatus: 200,
|
||||
wantETag: "\"file-123\"",
|
||||
},
|
||||
{
|
||||
name: "nil response",
|
||||
response: nil,
|
||||
wantStatus: 200,
|
||||
wantETag: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
err := s3PutObjectPostCallback(ctx, nil, tt.response)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tt.wantStatus, ctx.Response.StatusCode())
|
||||
assert.Equal(t, "application/xml", string(ctx.Response.Header.ContentType()))
|
||||
assert.Equal(t, "bifrost", string(ctx.Response.Header.Peek("x-amz-request-id")))
|
||||
|
||||
if tt.wantETag != "" {
|
||||
assert.Equal(t, tt.wantETag, string(ctx.Response.Header.Peek("ETag")))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_s3GetObjectPostCallback(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
response interface{}
|
||||
wantContentType string
|
||||
wantLength string
|
||||
wantETag string
|
||||
}{
|
||||
{
|
||||
name: "valid response",
|
||||
response: &schemas.BifrostFileContentResponse{
|
||||
Content: []byte("test content"),
|
||||
ContentType: "application/json",
|
||||
FileID: "file-456",
|
||||
},
|
||||
wantContentType: "application/json",
|
||||
wantLength: "12",
|
||||
wantETag: "\"file-456\"",
|
||||
},
|
||||
{
|
||||
name: "nil response",
|
||||
response: nil,
|
||||
wantContentType: "",
|
||||
wantLength: "",
|
||||
wantETag: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
err := s3GetObjectPostCallback(ctx, nil, tt.response)
|
||||
|
||||
assert.NoError(t, err)
|
||||
|
||||
if tt.wantContentType != "" {
|
||||
assert.Equal(t, tt.wantContentType, string(ctx.Response.Header.Peek("Content-Type")))
|
||||
assert.Equal(t, tt.wantLength, string(ctx.Response.Header.Peek("Content-Length")))
|
||||
assert.Equal(t, "bifrost", string(ctx.Response.Header.Peek("x-amz-request-id")))
|
||||
}
|
||||
|
||||
if tt.wantETag != "" {
|
||||
assert.Equal(t, tt.wantETag, string(ctx.Response.Header.Peek("ETag")))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_s3HeadObjectPostCallback(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
response interface{}
|
||||
wantStatus int
|
||||
wantLength string
|
||||
wantETag string
|
||||
}{
|
||||
{
|
||||
name: "valid response",
|
||||
response: &schemas.BifrostFileRetrieveResponse{
|
||||
ID: "file-789",
|
||||
Bytes: 1024,
|
||||
},
|
||||
wantStatus: 200,
|
||||
wantLength: "1024",
|
||||
wantETag: "\"file-789\"",
|
||||
},
|
||||
{
|
||||
name: "nil response",
|
||||
response: nil,
|
||||
wantStatus: 200,
|
||||
wantLength: "",
|
||||
wantETag: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
err := s3HeadObjectPostCallback(ctx, nil, tt.response)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tt.wantStatus, ctx.Response.StatusCode())
|
||||
|
||||
if tt.wantLength != "" {
|
||||
assert.Equal(t, "application/octet-stream", string(ctx.Response.Header.Peek("Content-Type")))
|
||||
assert.Equal(t, tt.wantLength, string(ctx.Response.Header.Peek("Content-Length")))
|
||||
assert.Equal(t, "bifrost", string(ctx.Response.Header.Peek("x-amz-request-id")))
|
||||
assert.Equal(t, tt.wantETag, string(ctx.Response.Header.Peek("ETag")))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_s3DeleteObjectPostCallback(t *testing.T) {
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
err := s3DeleteObjectPostCallback(ctx, nil, nil)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 204, ctx.Response.StatusCode())
|
||||
assert.Equal(t, "bifrost", string(ctx.Response.Header.Peek("x-amz-request-id")))
|
||||
}
|
||||
|
||||
func Test_s3ListObjectsV2PostCallback(t *testing.T) {
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
err := s3ListObjectsV2PostCallback(ctx, nil, nil)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "application/xml", string(ctx.Response.Header.ContentType()))
|
||||
assert.Equal(t, "bifrost", string(ctx.Response.Header.Peek("x-amz-request-id")))
|
||||
}
|
||||
|
||||
func Test_extractBedrockBatchListQueryParams(t *testing.T) {
|
||||
handlerStore := &mockHandlerStore{allowDirectKeys: false}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
queryParams map[string]string
|
||||
wantMaxResults int
|
||||
wantNextToken string
|
||||
wantStatus string
|
||||
wantName string
|
||||
}{
|
||||
{
|
||||
name: "all params",
|
||||
queryParams: map[string]string{
|
||||
"maxResults": "50",
|
||||
"nextToken": "token123",
|
||||
"statusEquals": "InProgress",
|
||||
"nameContains": "test-job",
|
||||
},
|
||||
wantMaxResults: 50,
|
||||
wantNextToken: "token123",
|
||||
wantStatus: "InProgress",
|
||||
wantName: "test-job",
|
||||
},
|
||||
{
|
||||
name: "no params",
|
||||
queryParams: map[string]string{},
|
||||
wantMaxResults: 0,
|
||||
wantNextToken: "",
|
||||
wantStatus: "",
|
||||
wantName: "",
|
||||
},
|
||||
{
|
||||
name: "invalid maxResults",
|
||||
queryParams: map[string]string{
|
||||
"maxResults": "invalid",
|
||||
},
|
||||
wantMaxResults: 0,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
for k, v := range tt.queryParams {
|
||||
ctx.QueryArgs().Add(k, v)
|
||||
}
|
||||
|
||||
req := &bedrock.BedrockBatchListRequest{}
|
||||
callback := extractBedrockBatchListQueryParams(handlerStore)
|
||||
|
||||
bifrostCtx := createTestBifrostContext()
|
||||
err := callback(ctx, bifrostCtx, req)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tt.wantMaxResults, req.MaxResults)
|
||||
assert.Equal(t, tt.wantStatus, req.StatusEquals)
|
||||
assert.Equal(t, tt.wantName, req.NameContains)
|
||||
|
||||
if tt.wantNextToken != "" {
|
||||
assert.NotNil(t, req.NextToken)
|
||||
assert.Equal(t, tt.wantNextToken, *req.NextToken)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_extractBedrockJobArnFromPath(t *testing.T) {
|
||||
handlerStore := &mockHandlerStore{allowDirectKeys: false}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
jobArn interface{}
|
||||
provider schemas.ModelProvider
|
||||
wantErr bool
|
||||
wantJobArn string
|
||||
errContains string
|
||||
}{
|
||||
{
|
||||
name: "valid job ARN for Bedrock",
|
||||
jobArn: "arn:aws:bedrock:us-east-1:123456789012:batch:job-123",
|
||||
provider: schemas.Bedrock,
|
||||
wantErr: false,
|
||||
wantJobArn: "arn:aws:bedrock:us-east-1:123456789012:batch:job-123",
|
||||
},
|
||||
{
|
||||
name: "URL encoded job ARN",
|
||||
jobArn: "arn%3Aaws%3Abedrock%3Aus-east-1%3A123456789012%3Abatch%3Ajob-123",
|
||||
provider: schemas.Bedrock,
|
||||
wantErr: false,
|
||||
wantJobArn: "arn:aws:bedrock:us-east-1:123456789012:batch:job-123",
|
||||
},
|
||||
{
|
||||
name: "non-Bedrock provider strips ARN prefix",
|
||||
jobArn: "arn:aws:bedrock:us-east-1:444444444444:batch:job-456",
|
||||
provider: schemas.OpenAI,
|
||||
wantErr: false,
|
||||
wantJobArn: "job-456",
|
||||
},
|
||||
{
|
||||
name: "missing job_arn",
|
||||
jobArn: nil,
|
||||
provider: schemas.Bedrock,
|
||||
wantErr: true,
|
||||
errContains: "job_arn is required",
|
||||
},
|
||||
{
|
||||
name: "empty job_arn",
|
||||
jobArn: "",
|
||||
provider: schemas.Bedrock,
|
||||
wantErr: true,
|
||||
errContains: "job_arn must be a non-empty string",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
if tt.jobArn != nil {
|
||||
ctx.SetUserValue("job_arn", tt.jobArn)
|
||||
}
|
||||
|
||||
req := &bedrock.BedrockBatchRetrieveRequest{}
|
||||
callback := extractBedrockJobArnFromPath(handlerStore)
|
||||
|
||||
bifrostCtx := createTestBifrostContextWithProvider(tt.provider)
|
||||
err := callback(ctx, bifrostCtx, req)
|
||||
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
if tt.errContains != "" {
|
||||
assert.Contains(t, err.Error(), tt.errContains)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tt.wantJobArn, req.JobIdentifier)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_extractS3ListObjectsV2Params(t *testing.T) {
|
||||
handlerStore := &mockHandlerStore{allowDirectKeys: false}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
bucket string
|
||||
queryParams map[string]string
|
||||
wantErr bool
|
||||
wantBucket string
|
||||
wantPrefix string
|
||||
wantMaxKeys int
|
||||
wantContinuationToken string
|
||||
}{
|
||||
{
|
||||
name: "all params",
|
||||
bucket: "my-bucket",
|
||||
queryParams: map[string]string{
|
||||
"prefix": "folder/",
|
||||
"max-keys": "100",
|
||||
"continuation-token": "token-abc",
|
||||
},
|
||||
wantErr: false,
|
||||
wantBucket: "my-bucket",
|
||||
wantPrefix: "folder/",
|
||||
wantMaxKeys: 100,
|
||||
wantContinuationToken: "token-abc",
|
||||
},
|
||||
{
|
||||
name: "bucket only",
|
||||
bucket: "simple-bucket",
|
||||
queryParams: map[string]string{},
|
||||
wantErr: false,
|
||||
wantBucket: "simple-bucket",
|
||||
wantPrefix: "",
|
||||
wantMaxKeys: 1000,
|
||||
},
|
||||
{
|
||||
name: "missing bucket",
|
||||
bucket: "",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
if tt.bucket != "" {
|
||||
ctx.SetUserValue("bucket", tt.bucket)
|
||||
}
|
||||
for k, v := range tt.queryParams {
|
||||
ctx.QueryArgs().Add(k, v)
|
||||
}
|
||||
|
||||
req := &bedrock.BedrockFileListRequest{}
|
||||
callback := extractS3ListObjectsV2Params(handlerStore)
|
||||
|
||||
bifrostCtx := createTestBifrostContext()
|
||||
err := callback(ctx, bifrostCtx, req)
|
||||
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
return
|
||||
}
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tt.wantBucket, req.Bucket)
|
||||
assert.Equal(t, tt.wantPrefix, req.Prefix)
|
||||
assert.Equal(t, tt.wantMaxKeys, req.MaxKeys)
|
||||
assert.Equal(t, tt.wantContinuationToken, req.ContinuationToken)
|
||||
|
||||
// Verify context values
|
||||
assert.Equal(t, tt.wantBucket, bifrostCtx.Value(s3ContextKeyBucket))
|
||||
assert.Equal(t, tt.wantPrefix, bifrostCtx.Value(s3ContextKeyPrefix))
|
||||
assert.Equal(t, tt.wantMaxKeys, bifrostCtx.Value(s3ContextKeyMaxKeys))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_extractS3BucketKeyFromPath(t *testing.T) {
|
||||
handlerStore := &mockHandlerStore{allowDirectKeys: false}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
bucket string
|
||||
key string
|
||||
fileID string
|
||||
opType string
|
||||
wantErr bool
|
||||
wantBucket string
|
||||
wantKey string
|
||||
wantS3URI string
|
||||
}{
|
||||
{
|
||||
name: "content operation",
|
||||
bucket: "my-bucket",
|
||||
key: "path/to/file.txt",
|
||||
fileID: "file-123",
|
||||
opType: "content",
|
||||
wantErr: false,
|
||||
wantBucket: "my-bucket",
|
||||
wantKey: "path/to/file.txt",
|
||||
wantS3URI: "s3://my-bucket/path/to/file.txt",
|
||||
},
|
||||
{
|
||||
name: "missing bucket",
|
||||
bucket: "",
|
||||
key: "file.txt",
|
||||
opType: "content",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "missing key",
|
||||
bucket: "bucket",
|
||||
key: "",
|
||||
opType: "content",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
if tt.bucket != "" {
|
||||
ctx.SetUserValue("bucket", tt.bucket)
|
||||
}
|
||||
if tt.key != "" {
|
||||
ctx.SetUserValue("key", tt.key)
|
||||
}
|
||||
if tt.fileID != "" {
|
||||
ctx.Request.Header.Set("If-Match", tt.fileID)
|
||||
}
|
||||
|
||||
callback := extractS3BucketKeyFromPath(handlerStore, tt.opType)
|
||||
bifrostCtx := createTestBifrostContext()
|
||||
|
||||
var req interface{}
|
||||
switch tt.opType {
|
||||
case "content":
|
||||
req = &bedrock.BedrockFileContentRequest{}
|
||||
case "retrieve":
|
||||
req = &bedrock.BedrockFileRetrieveRequest{}
|
||||
case "delete":
|
||||
req = &bedrock.BedrockFileDeleteRequest{}
|
||||
}
|
||||
|
||||
err := callback(ctx, bifrostCtx, req)
|
||||
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
return
|
||||
}
|
||||
|
||||
assert.NoError(t, err)
|
||||
|
||||
switch r := req.(type) {
|
||||
case *bedrock.BedrockFileContentRequest:
|
||||
assert.Equal(t, tt.wantBucket, r.Bucket)
|
||||
assert.Equal(t, tt.wantKey, r.Prefix)
|
||||
assert.Equal(t, tt.wantS3URI, r.S3Uri)
|
||||
assert.Equal(t, tt.fileID, r.ETag)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Helper functions for creating test contexts
|
||||
|
||||
func createTestBifrostContext() *schemas.BifrostContext {
|
||||
bifrostCtx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
|
||||
bifrostCtx.SetValue(bifrostContextKeyProvider, schemas.Bedrock)
|
||||
return bifrostCtx
|
||||
}
|
||||
|
||||
func createTestBifrostContextWithProvider(provider schemas.ModelProvider) *schemas.BifrostContext {
|
||||
bifrostCtx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
|
||||
bifrostCtx.SetValue(bifrostContextKeyProvider, provider)
|
||||
return bifrostCtx
|
||||
}
|
||||
222
transports/bifrost-http/integrations/cohere.go
Normal file
222
transports/bifrost-http/integrations/cohere.go
Normal file
@@ -0,0 +1,222 @@
|
||||
package integrations
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
bifrost "github.com/maximhq/bifrost/core"
|
||||
"github.com/maximhq/bifrost/core/providers/cohere"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/maximhq/bifrost/transports/bifrost-http/lib"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// hydrateCohereRequestFromLargePayloadMetadata populates model + stream from
|
||||
// LargePayloadMetadata when body parsing is skipped under large payload mode.
|
||||
func hydrateCohereRequestFromLargePayloadMetadata(bifrostCtx *schemas.BifrostContext, req interface{}) {
|
||||
if bifrostCtx == nil {
|
||||
return
|
||||
}
|
||||
isLargePayload, _ := bifrostCtx.Value(schemas.BifrostContextKeyLargePayloadMode).(bool)
|
||||
if !isLargePayload {
|
||||
return
|
||||
}
|
||||
metadata := resolveLargePayloadMetadata(bifrostCtx)
|
||||
if metadata == nil {
|
||||
return
|
||||
}
|
||||
|
||||
switch r := req.(type) {
|
||||
case *cohere.CohereChatRequest:
|
||||
if r.Model == "" {
|
||||
r.Model = metadata.Model
|
||||
}
|
||||
if metadata.StreamRequested != nil && r.Stream == nil {
|
||||
r.Stream = schemas.Ptr(*metadata.StreamRequested)
|
||||
}
|
||||
case *cohere.CohereEmbeddingRequest:
|
||||
if r.Model == "" {
|
||||
r.Model = metadata.Model
|
||||
}
|
||||
case *cohere.CohereRerankRequest:
|
||||
if r.Model == "" {
|
||||
r.Model = metadata.Model
|
||||
}
|
||||
case *cohere.CohereCountTokensRequest:
|
||||
if r.Model == "" {
|
||||
r.Model = metadata.Model
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// cohereLargePayloadPreHook populates model + stream from LargePayloadMetadata
|
||||
// when body parsing is skipped under large payload mode.
|
||||
func cohereLargePayloadPreHook(_ *fasthttp.RequestCtx, bifrostCtx *schemas.BifrostContext, req interface{}) error {
|
||||
hydrateCohereRequestFromLargePayloadMetadata(bifrostCtx, req)
|
||||
return nil
|
||||
}
|
||||
|
||||
// CohereRouter holds route registrations for Cohere endpoints.
|
||||
// It supports Cohere's v2 chat, embeddings, and rerank APIs.
|
||||
type CohereRouter struct {
|
||||
*GenericRouter
|
||||
}
|
||||
|
||||
// NewCohereRouter creates a new CohereRouter with the given bifrost client.
|
||||
func NewCohereRouter(client *bifrost.Bifrost, handlerStore lib.HandlerStore, logger schemas.Logger) *CohereRouter {
|
||||
return &CohereRouter{
|
||||
GenericRouter: NewGenericRouter(client, handlerStore, CreateCohereRouteConfigs("/cohere"), nil, logger),
|
||||
}
|
||||
}
|
||||
|
||||
// CreateCohereRouteConfigs creates route configurations for Cohere API endpoints.
|
||||
func CreateCohereRouteConfigs(pathPrefix string) []RouteConfig {
|
||||
var routes []RouteConfig
|
||||
|
||||
// Chat completions endpoint (v2/chat)
|
||||
routes = append(routes, RouteConfig{
|
||||
Type: RouteConfigTypeCohere,
|
||||
Path: pathPrefix + "/v2/chat",
|
||||
Method: "POST",
|
||||
PreCallback: cohereLargePayloadPreHook,
|
||||
GetHTTPRequestType: func(ctx *fasthttp.RequestCtx) schemas.RequestType {
|
||||
return schemas.ChatCompletionRequest
|
||||
},
|
||||
GetRequestTypeInstance: func(ctx context.Context) interface{} {
|
||||
return &cohere.CohereChatRequest{}
|
||||
},
|
||||
RequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*schemas.BifrostRequest, error) {
|
||||
if cohereReq, ok := req.(*cohere.CohereChatRequest); ok {
|
||||
return &schemas.BifrostRequest{
|
||||
ChatRequest: cohereReq.ToBifrostChatRequest(ctx),
|
||||
}, nil
|
||||
}
|
||||
return nil, errors.New("invalid request type")
|
||||
},
|
||||
ChatResponseConverter: func(ctx *schemas.BifrostContext, resp *schemas.BifrostChatResponse) (interface{}, error) {
|
||||
if resp.ExtraFields.Provider == schemas.Cohere {
|
||||
if resp.ExtraFields.RawResponse != nil {
|
||||
return resp.ExtraFields.RawResponse, nil
|
||||
}
|
||||
}
|
||||
return resp, nil
|
||||
},
|
||||
ErrorConverter: func(ctx *schemas.BifrostContext, err *schemas.BifrostError) interface{} {
|
||||
return err
|
||||
},
|
||||
StreamConfig: &StreamConfig{
|
||||
ChatStreamResponseConverter: func(ctx *schemas.BifrostContext, resp *schemas.BifrostChatResponse) (string, interface{}, error) {
|
||||
if resp.ExtraFields.Provider == schemas.Cohere {
|
||||
if resp.ExtraFields.RawResponse != nil {
|
||||
return "", resp.ExtraFields.RawResponse, nil
|
||||
}
|
||||
}
|
||||
return "", resp, nil
|
||||
},
|
||||
ErrorConverter: func(ctx *schemas.BifrostContext, err *schemas.BifrostError) interface{} {
|
||||
return err
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
// Embeddings endpoint (v2/embed)
|
||||
routes = append(routes, RouteConfig{
|
||||
Type: RouteConfigTypeCohere,
|
||||
Path: pathPrefix + "/v2/embed",
|
||||
Method: "POST",
|
||||
PreCallback: cohereLargePayloadPreHook,
|
||||
GetHTTPRequestType: func(ctx *fasthttp.RequestCtx) schemas.RequestType {
|
||||
return schemas.EmbeddingRequest
|
||||
},
|
||||
GetRequestTypeInstance: func(ctx context.Context) interface{} {
|
||||
return &cohere.CohereEmbeddingRequest{}
|
||||
},
|
||||
RequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*schemas.BifrostRequest, error) {
|
||||
if cohereReq, ok := req.(*cohere.CohereEmbeddingRequest); ok {
|
||||
return &schemas.BifrostRequest{
|
||||
EmbeddingRequest: cohereReq.ToBifrostEmbeddingRequest(ctx),
|
||||
}, nil
|
||||
}
|
||||
return nil, errors.New("invalid embedding request type")
|
||||
},
|
||||
EmbeddingResponseConverter: func(ctx *schemas.BifrostContext, resp *schemas.BifrostEmbeddingResponse) (interface{}, error) {
|
||||
if resp.ExtraFields.Provider == schemas.Cohere {
|
||||
if resp.ExtraFields.RawResponse != nil {
|
||||
return resp.ExtraFields.RawResponse, nil
|
||||
}
|
||||
}
|
||||
return resp, nil
|
||||
},
|
||||
ErrorConverter: func(ctx *schemas.BifrostContext, err *schemas.BifrostError) interface{} {
|
||||
return err
|
||||
},
|
||||
})
|
||||
|
||||
// Rerank endpoint (v2/rerank)
|
||||
routes = append(routes, RouteConfig{
|
||||
Type: RouteConfigTypeCohere,
|
||||
Path: pathPrefix + "/v2/rerank",
|
||||
Method: "POST",
|
||||
PreCallback: cohereLargePayloadPreHook,
|
||||
GetHTTPRequestType: func(ctx *fasthttp.RequestCtx) schemas.RequestType {
|
||||
return schemas.RerankRequest
|
||||
},
|
||||
GetRequestTypeInstance: func(ctx context.Context) interface{} {
|
||||
return &cohere.CohereRerankRequest{}
|
||||
},
|
||||
RequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*schemas.BifrostRequest, error) {
|
||||
if cohereReq, ok := req.(*cohere.CohereRerankRequest); ok {
|
||||
return &schemas.BifrostRequest{
|
||||
RerankRequest: cohereReq.ToBifrostRerankRequest(ctx),
|
||||
}, nil
|
||||
}
|
||||
return nil, errors.New("invalid rerank request type")
|
||||
},
|
||||
RerankResponseConverter: func(ctx *schemas.BifrostContext, resp *schemas.BifrostRerankResponse) (interface{}, error) {
|
||||
if resp.ExtraFields.Provider == schemas.Cohere {
|
||||
if resp.ExtraFields.RawResponse != nil {
|
||||
return resp.ExtraFields.RawResponse, nil
|
||||
}
|
||||
}
|
||||
return resp, nil
|
||||
},
|
||||
ErrorConverter: func(ctx *schemas.BifrostContext, err *schemas.BifrostError) interface{} {
|
||||
return err
|
||||
},
|
||||
})
|
||||
|
||||
// Tokenize endpoint (v1/tokenize)
|
||||
routes = append(routes, RouteConfig{
|
||||
Type: RouteConfigTypeCohere,
|
||||
Path: pathPrefix + "/v1/tokenize",
|
||||
Method: "POST",
|
||||
PreCallback: cohereLargePayloadPreHook,
|
||||
GetHTTPRequestType: func(ctx *fasthttp.RequestCtx) schemas.RequestType {
|
||||
return schemas.CountTokensRequest
|
||||
},
|
||||
GetRequestTypeInstance: func(ctx context.Context) interface{} {
|
||||
return &cohere.CohereCountTokensRequest{}
|
||||
},
|
||||
RequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*schemas.BifrostRequest, error) {
|
||||
if cohereReq, ok := req.(*cohere.CohereCountTokensRequest); ok {
|
||||
return &schemas.BifrostRequest{
|
||||
CountTokensRequest: cohereReq.ToBifrostResponsesRequest(ctx),
|
||||
}, nil
|
||||
}
|
||||
return nil, errors.New("invalid count tokens request type")
|
||||
},
|
||||
CountTokensResponseConverter: func(ctx *schemas.BifrostContext, resp *schemas.BifrostCountTokensResponse) (interface{}, error) {
|
||||
if resp.ExtraFields.Provider == schemas.Cohere {
|
||||
if resp.ExtraFields.RawResponse != nil {
|
||||
return resp.ExtraFields.RawResponse, nil
|
||||
}
|
||||
}
|
||||
return resp, nil
|
||||
},
|
||||
ErrorConverter: func(ctx *schemas.BifrostContext, err *schemas.BifrostError) interface{} {
|
||||
return err
|
||||
},
|
||||
})
|
||||
|
||||
return routes
|
||||
}
|
||||
102
transports/bifrost-http/integrations/cohere_test.go
Normal file
102
transports/bifrost-http/integrations/cohere_test.go
Normal file
@@ -0,0 +1,102 @@
|
||||
package integrations
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/maximhq/bifrost/core/providers/cohere"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestCreateCohereRouteConfigsIncludesRerank(t *testing.T) {
|
||||
routes := CreateCohereRouteConfigs("/cohere")
|
||||
|
||||
assert.Len(t, routes, 4, "should have 4 cohere routes")
|
||||
|
||||
var rerankRoute *RouteConfig
|
||||
for i := range routes {
|
||||
if routes[i].Path == "/cohere/v2/rerank" && routes[i].Method == "POST" {
|
||||
rerankRoute = &routes[i]
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
require.NotNil(t, rerankRoute, "rerank route should exist")
|
||||
assert.Equal(t, RouteConfigTypeCohere, rerankRoute.Type)
|
||||
assert.NotNil(t, rerankRoute.GetHTTPRequestType)
|
||||
assert.Equal(t, schemas.RerankRequest, rerankRoute.GetHTTPRequestType(nil))
|
||||
assert.NotNil(t, rerankRoute.GetRequestTypeInstance)
|
||||
assert.NotNil(t, rerankRoute.RequestConverter)
|
||||
assert.NotNil(t, rerankRoute.RerankResponseConverter)
|
||||
assert.NotNil(t, rerankRoute.ErrorConverter)
|
||||
|
||||
reqInstance := rerankRoute.GetRequestTypeInstance(context.Background())
|
||||
_, ok := reqInstance.(*cohere.CohereRerankRequest)
|
||||
assert.True(t, ok, "rerank request instance should be CohereRerankRequest")
|
||||
}
|
||||
|
||||
func TestCohereRerankRouteRequestConverter(t *testing.T) {
|
||||
routes := CreateCohereRouteConfigs("/cohere")
|
||||
|
||||
var rerankRoute *RouteConfig
|
||||
for i := range routes {
|
||||
if routes[i].Path == "/cohere/v2/rerank" {
|
||||
rerankRoute = &routes[i]
|
||||
break
|
||||
}
|
||||
}
|
||||
require.NotNil(t, rerankRoute)
|
||||
require.NotNil(t, rerankRoute.RequestConverter)
|
||||
|
||||
topN := 1
|
||||
req := &cohere.CohereRerankRequest{
|
||||
Model: "rerank-v3.5",
|
||||
Query: "what is bifrost?",
|
||||
Documents: []string{"doc1", "doc2"},
|
||||
TopN: &topN,
|
||||
}
|
||||
|
||||
bifrostCtx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
|
||||
bifrostReq, err := rerankRoute.RequestConverter(bifrostCtx, req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, bifrostReq)
|
||||
require.NotNil(t, bifrostReq.RerankRequest)
|
||||
|
||||
assert.Equal(t, schemas.Cohere, bifrostReq.RerankRequest.Provider)
|
||||
assert.Equal(t, "rerank-v3.5", bifrostReq.RerankRequest.Model)
|
||||
assert.Equal(t, "what is bifrost?", bifrostReq.RerankRequest.Query)
|
||||
require.Len(t, bifrostReq.RerankRequest.Documents, 2)
|
||||
assert.Equal(t, "doc1", bifrostReq.RerankRequest.Documents[0].Text)
|
||||
assert.Equal(t, "doc2", bifrostReq.RerankRequest.Documents[1].Text)
|
||||
require.NotNil(t, bifrostReq.RerankRequest.Params)
|
||||
require.NotNil(t, bifrostReq.RerankRequest.Params.TopN)
|
||||
assert.Equal(t, 1, *bifrostReq.RerankRequest.Params.TopN)
|
||||
}
|
||||
|
||||
func TestCohereRerankResponseConverterUsesRawResponse(t *testing.T) {
|
||||
routes := CreateCohereRouteConfigs("/cohere")
|
||||
|
||||
var rerankRoute *RouteConfig
|
||||
for i := range routes {
|
||||
if routes[i].Path == "/cohere/v2/rerank" {
|
||||
rerankRoute = &routes[i]
|
||||
break
|
||||
}
|
||||
}
|
||||
require.NotNil(t, rerankRoute)
|
||||
require.NotNil(t, rerankRoute.RerankResponseConverter)
|
||||
|
||||
raw := map[string]interface{}{"id": "r-123", "results": []interface{}{}}
|
||||
resp := &schemas.BifrostRerankResponse{
|
||||
ExtraFields: schemas.BifrostResponseExtraFields{
|
||||
Provider: schemas.Cohere,
|
||||
RawResponse: raw,
|
||||
},
|
||||
}
|
||||
|
||||
converted, err := rerankRoute.RerankResponseConverter(nil, resp)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, raw, converted)
|
||||
}
|
||||
1067
transports/bifrost-http/integrations/cursor.go
Normal file
1067
transports/bifrost-http/integrations/cursor.go
Normal file
File diff suppressed because it is too large
Load Diff
1347
transports/bifrost-http/integrations/genai.go
Normal file
1347
transports/bifrost-http/integrations/genai.go
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,47 @@
|
||||
package integrations
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/maximhq/bifrost/transports/bifrost-http/lib"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
func TestExtractModelAndRequestType_LargePayloadUsesMetadataWithoutBodyParse(t *testing.T) {
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.SetUserValue("model", "gemini-2.5-pro:generateContent")
|
||||
// Intentionally invalid JSON: detection must rely on large-payload metadata, not body parse.
|
||||
ctx.Request.SetBodyString(`{"contents":[INVALID`)
|
||||
|
||||
bifrostCtx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKeyLargePayloadMode, true)
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKeyLargePayloadMetadata, &schemas.LargePayloadMetadata{
|
||||
ResponseModalities: []string{"AUDIO"},
|
||||
})
|
||||
ctx.SetUserValue(lib.FastHTTPUserValueBifrostContext, bifrostCtx)
|
||||
|
||||
model, reqType := extractModelAndRequestType(ctx)
|
||||
if model != "gemini-2.5-pro" {
|
||||
t.Fatalf("expected normalized model gemini-2.5-pro, got %q", model)
|
||||
}
|
||||
if reqType != schemas.SpeechRequest {
|
||||
t.Fatalf("expected speech request type from metadata, got %q", reqType)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractModelAndRequestType_LargeBodyHeuristicSkipsParse(t *testing.T) {
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.SetUserValue("model", "gemini-2.5-pro:generateContent")
|
||||
ctx.Request.SetBodyStream(strings.NewReader(`{"contents":[INVALID`), schemas.DefaultLargePayloadRequestThresholdBytes+1)
|
||||
|
||||
model, reqType := extractModelAndRequestType(ctx)
|
||||
if model != "gemini-2.5-pro" {
|
||||
t.Fatalf("expected normalized model gemini-2.5-pro, got %q", model)
|
||||
}
|
||||
if reqType != schemas.ResponsesRequest {
|
||||
t.Fatalf("expected responses request type from large-body heuristic, got %q", reqType)
|
||||
}
|
||||
}
|
||||
208
transports/bifrost-http/integrations/genai_test.go
Normal file
208
transports/bifrost-http/integrations/genai_test.go
Normal file
@@ -0,0 +1,208 @@
|
||||
package integrations
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/maximhq/bifrost/core/providers/gemini"
|
||||
"github.com/maximhq/bifrost/core/providers/vertex"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
func TestCreateGenAIRerankRouteConfig(t *testing.T) {
|
||||
route := createGenAIRerankRouteConfig("/genai")
|
||||
|
||||
assert.Equal(t, "/genai/v1/rank", route.Path)
|
||||
assert.Equal(t, "POST", route.Method)
|
||||
assert.Equal(t, RouteConfigTypeGenAI, route.Type)
|
||||
assert.NotNil(t, route.GetHTTPRequestType)
|
||||
assert.Equal(t, schemas.RerankRequest, route.GetHTTPRequestType(nil))
|
||||
assert.NotNil(t, route.GetRequestTypeInstance)
|
||||
assert.NotNil(t, route.RequestConverter)
|
||||
assert.NotNil(t, route.RerankResponseConverter)
|
||||
assert.NotNil(t, route.ErrorConverter)
|
||||
assert.Nil(t, route.PreCallback)
|
||||
|
||||
// Verify request instance type
|
||||
reqInstance := route.GetRequestTypeInstance(context.Background())
|
||||
_, ok := reqInstance.(*vertex.VertexRankRequest)
|
||||
assert.True(t, ok, "GetRequestTypeInstance should return *vertex.VertexRankRequest")
|
||||
}
|
||||
|
||||
func TestCreateGenAIRouteConfigsIncludesRerank(t *testing.T) {
|
||||
routes := CreateGenAIRouteConfigs("/genai")
|
||||
|
||||
found := false
|
||||
for _, route := range routes {
|
||||
if route.Path == "/genai/v1/rank" && route.Method == "POST" {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.True(t, found, "expected rerank route in genai route configs")
|
||||
}
|
||||
|
||||
func TestCreateGenAIRouteConfigsIncludesRerankForCompositePrefixes(t *testing.T) {
|
||||
prefixes := []string{"/litellm", "/langchain", "/pydanticai"}
|
||||
|
||||
for _, prefix := range prefixes {
|
||||
routes := CreateGenAIRouteConfigs(prefix)
|
||||
found := false
|
||||
for _, route := range routes {
|
||||
if route.Path == prefix+"/v1/rank" && route.Method == "POST" {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.Truef(t, found, "expected rerank route for prefix %s", prefix)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenAIRerankRequestConverter(t *testing.T) {
|
||||
route := createGenAIRerankRouteConfig("/genai")
|
||||
require.NotNil(t, route.RequestConverter)
|
||||
|
||||
model := "semantic-ranker-default@latest"
|
||||
topN := 2
|
||||
content1 := "Paris is capital of France"
|
||||
content2 := "Berlin is capital of Germany"
|
||||
req := &vertex.VertexRankRequest{
|
||||
Model: &model,
|
||||
Query: "capital of france",
|
||||
Records: []vertex.VertexRankRecord{
|
||||
{ID: "rec-1", Content: &content1},
|
||||
{ID: "rec-2", Content: &content2},
|
||||
},
|
||||
TopN: &topN,
|
||||
}
|
||||
|
||||
bifrostCtx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
|
||||
bifrostReq, err := route.RequestConverter(bifrostCtx, req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, bifrostReq)
|
||||
require.NotNil(t, bifrostReq.RerankRequest)
|
||||
assert.Equal(t, schemas.Vertex, bifrostReq.RerankRequest.Provider)
|
||||
assert.Equal(t, "semantic-ranker-default@latest", bifrostReq.RerankRequest.Model)
|
||||
assert.Equal(t, "capital of france", bifrostReq.RerankRequest.Query)
|
||||
require.Len(t, bifrostReq.RerankRequest.Documents, 2)
|
||||
assert.Equal(t, "Paris is capital of France", bifrostReq.RerankRequest.Documents[0].Text)
|
||||
assert.Equal(t, "Berlin is capital of Germany", bifrostReq.RerankRequest.Documents[1].Text)
|
||||
require.NotNil(t, bifrostReq.RerankRequest.Params)
|
||||
require.NotNil(t, bifrostReq.RerankRequest.Params.TopN)
|
||||
assert.Equal(t, 2, *bifrostReq.RerankRequest.Params.TopN)
|
||||
}
|
||||
|
||||
func TestGenAIRerankResponseConverterUsesRawResponse(t *testing.T) {
|
||||
route := createGenAIRerankRouteConfig("/genai")
|
||||
require.NotNil(t, route.RerankResponseConverter)
|
||||
|
||||
raw := map[string]interface{}{"records": []interface{}{}}
|
||||
resp := &schemas.BifrostRerankResponse{
|
||||
ExtraFields: schemas.BifrostResponseExtraFields{
|
||||
Provider: schemas.Vertex,
|
||||
RawResponse: raw,
|
||||
},
|
||||
}
|
||||
converted, err := route.RerankResponseConverter(nil, resp)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, raw, converted)
|
||||
}
|
||||
|
||||
func TestGenAIRerankResponseConverterFallsBackWhenNotVertex(t *testing.T) {
|
||||
route := createGenAIRerankRouteConfig("/genai")
|
||||
require.NotNil(t, route.RerankResponseConverter)
|
||||
|
||||
resp := &schemas.BifrostRerankResponse{
|
||||
Results: []schemas.RerankResult{
|
||||
{Index: 0, RelevanceScore: 0.9},
|
||||
},
|
||||
ExtraFields: schemas.BifrostResponseExtraFields{
|
||||
Provider: schemas.Cohere,
|
||||
},
|
||||
}
|
||||
converted, err := route.RerankResponseConverter(nil, resp)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, resp, converted)
|
||||
}
|
||||
|
||||
func TestCreateGenAIRouteConfigsIncludesModelMetadataRoute(t *testing.T) {
|
||||
routes := CreateGenAIRouteConfigs("/genai")
|
||||
|
||||
found := false
|
||||
for _, route := range routes {
|
||||
if route.Path == "/genai/v1beta/models/{model}" && route.Method == "GET" {
|
||||
found = true
|
||||
assert.Equal(t, schemas.ListModelsRequest, route.GetHTTPRequestType(nil))
|
||||
require.NotNil(t, route.PreCallback)
|
||||
require.NotNil(t, route.ListModelsResponseConverter)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
assert.True(t, found, "expected model metadata route in genai route configs")
|
||||
}
|
||||
|
||||
func TestExtractGeminiModelMetadataParams(t *testing.T) {
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.SetUserValue("model", "models/gemini-3-pro-preview")
|
||||
|
||||
listReq := &schemas.BifrostListModelsRequest{}
|
||||
bifrostCtx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
|
||||
|
||||
err := extractGeminiModelMetadataParams(ctx, bifrostCtx, listReq)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, schemas.Gemini, listReq.Provider)
|
||||
assert.Equal(t, "/models/gemini-3-pro-preview", bifrostCtx.Value(schemas.BifrostContextKeyURLPath))
|
||||
assert.Equal(t, "gemini-3-pro-preview", bifrostCtx.Value(requestedGeminiModelMetadataContextKey))
|
||||
}
|
||||
|
||||
func TestConvertGeminiModelMetadataResponse(t *testing.T) {
|
||||
bifrostCtx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
|
||||
bifrostCtx.SetValue(requestedGeminiModelMetadataContextKey, "gemini-2.5-pro")
|
||||
|
||||
resp := &schemas.BifrostListModelsResponse{
|
||||
Data: []schemas.Model{{ID: "gemini/gemini-2.5-pro", Name: schemas.Ptr("Gemini 2.5 Pro")}},
|
||||
}
|
||||
|
||||
converted, err := convertGeminiModelMetadataResponse(bifrostCtx, resp)
|
||||
require.NoError(t, err)
|
||||
|
||||
model, ok := converted.(gemini.GeminiModel)
|
||||
require.True(t, ok, "expected gemini.GeminiModel")
|
||||
assert.Equal(t, "models/gemini-2.5-pro", model.Name)
|
||||
assert.Equal(t, "Gemini 2.5 Pro", model.DisplayName)
|
||||
}
|
||||
|
||||
func TestConvertGeminiModelMetadataResponse_MatchesRequestedModelNotFirst(t *testing.T) {
|
||||
bifrostCtx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
|
||||
bifrostCtx.SetValue(requestedGeminiModelMetadataContextKey, "gemini-3-pro-preview")
|
||||
|
||||
resp := &schemas.BifrostListModelsResponse{
|
||||
Data: []schemas.Model{
|
||||
{ID: "gemini/gemini-1.5-pro", Name: schemas.Ptr("Gemini 1.5 Pro")},
|
||||
{ID: "gemini/gemini-3-pro-preview", Name: schemas.Ptr("Gemini 3 Pro Preview")},
|
||||
},
|
||||
}
|
||||
|
||||
converted, err := convertGeminiModelMetadataResponse(bifrostCtx, resp)
|
||||
require.NoError(t, err)
|
||||
|
||||
model, ok := converted.(gemini.GeminiModel)
|
||||
require.True(t, ok, "expected gemini.GeminiModel")
|
||||
assert.Equal(t, "models/gemini-3-pro-preview", model.Name)
|
||||
assert.Equal(t, "Gemini 3 Pro Preview", model.DisplayName)
|
||||
}
|
||||
|
||||
func TestConvertGeminiModelMetadataResponse_EmptyReturnsMinimalModel(t *testing.T) {
|
||||
bifrostCtx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
|
||||
bifrostCtx.SetValue(requestedGeminiModelMetadataContextKey, "gemini-3-pro-preview")
|
||||
|
||||
converted, err := convertGeminiModelMetadataResponse(bifrostCtx, &schemas.BifrostListModelsResponse{Data: []schemas.Model{}})
|
||||
require.NoError(t, err)
|
||||
model, ok := converted.(gemini.GeminiModel)
|
||||
require.True(t, ok, "expected gemini.GeminiModel")
|
||||
assert.Equal(t, "models/gemini-3-pro-preview", model.Name)
|
||||
}
|
||||
42
transports/bifrost-http/integrations/langchain.go
Normal file
42
transports/bifrost-http/integrations/langchain.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package integrations
|
||||
|
||||
import (
|
||||
bifrost "github.com/maximhq/bifrost/core"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/maximhq/bifrost/transports/bifrost-http/lib"
|
||||
)
|
||||
|
||||
// LangChainRouter holds route registrations for LangChain endpoints.
|
||||
// It supports standard chat completions and image-enabled vision capabilities.
|
||||
// LangChain is fully OpenAI-compatible, so we reuse OpenAI types
|
||||
// with aliases for clarity and minimal LangChain-specific extensions
|
||||
type LangChainRouter struct {
|
||||
*GenericRouter
|
||||
}
|
||||
|
||||
// NewLangChainRouter creates a new LangChainRouter with the given bifrost client.
|
||||
func NewLangChainRouter(client *bifrost.Bifrost, handlerStore lib.HandlerStore, logger schemas.Logger) *LangChainRouter {
|
||||
routes := []RouteConfig{}
|
||||
|
||||
// Add OpenAI routes to LangChain for OpenAI API compatibility
|
||||
routes = append(routes, CreateOpenAIRouteConfigs("/langchain", handlerStore)...)
|
||||
|
||||
// Add Anthropic routes to LangChain for Anthropic API compatibility
|
||||
routes = append(routes, CreateAnthropicRouteConfigs("/langchain", logger)...)
|
||||
|
||||
// Add Anthropic count tokens route for LangChain to ensure token counting uses the dedicated endpoint
|
||||
routes = append(routes, CreateAnthropicCountTokensRouteConfigs("/langchain", handlerStore)...)
|
||||
|
||||
// Add GenAI routes to LangChain for Vertex AI compatibility
|
||||
routes = append(routes, CreateGenAIRouteConfigs("/langchain")...)
|
||||
|
||||
// Add Bedrock routes to LangChain for AWS Bedrock API compatibility
|
||||
routes = append(routes, CreateBedrockRouteConfigs("/langchain", handlerStore)...)
|
||||
|
||||
// Add Cohere routes to LangChain for Cohere API compatibility
|
||||
routes = append(routes, CreateCohereRouteConfigs("/langchain")...)
|
||||
|
||||
return &LangChainRouter{
|
||||
GenericRouter: NewGenericRouter(client, handlerStore, routes, nil, logger),
|
||||
}
|
||||
}
|
||||
39
transports/bifrost-http/integrations/litellm.go
Normal file
39
transports/bifrost-http/integrations/litellm.go
Normal file
@@ -0,0 +1,39 @@
|
||||
package integrations
|
||||
|
||||
import (
|
||||
bifrost "github.com/maximhq/bifrost/core"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/maximhq/bifrost/transports/bifrost-http/lib"
|
||||
)
|
||||
|
||||
// LiteLLMRouter holds route registrations for LiteLLM endpoints.
|
||||
// It supports standard chat completions and image-enabled vision capabilities.
|
||||
// LiteLLM is fully OpenAI-compatible, so we reuse OpenAI types
|
||||
// with aliases for clarity and minimal LiteLLM-specific extensions
|
||||
type LiteLLMRouter struct {
|
||||
*GenericRouter
|
||||
}
|
||||
|
||||
// NewLiteLLMRouter creates a new LiteLLMRouter with the given bifrost client.
|
||||
func NewLiteLLMRouter(client *bifrost.Bifrost, handlerStore lib.HandlerStore, logger schemas.Logger) *LiteLLMRouter {
|
||||
routes := []RouteConfig{}
|
||||
|
||||
// Add OpenAI routes to LiteLLM for OpenAI API compatibility
|
||||
routes = append(routes, CreateOpenAIRouteConfigs("/litellm", handlerStore)...)
|
||||
|
||||
// Add Anthropic routes to LiteLLM for Anthropic API compatibility
|
||||
routes = append(routes, CreateAnthropicRouteConfigs("/litellm", logger)...)
|
||||
|
||||
// Add GenAI routes to LiteLLM for Vertex AI compatibility
|
||||
routes = append(routes, CreateGenAIRouteConfigs("/litellm")...)
|
||||
|
||||
// Add Bedrock routes to LiteLLM for AWS Bedrock API compatibility
|
||||
routes = append(routes, CreateBedrockRouteConfigs("/litellm", handlerStore)...)
|
||||
|
||||
// Add Cohere routes to LiteLLM for Cohere API compatibility
|
||||
routes = append(routes, CreateCohereRouteConfigs("/litellm")...)
|
||||
|
||||
return &LiteLLMRouter{
|
||||
GenericRouter: NewGenericRouter(client, handlerStore, routes, nil, logger),
|
||||
}
|
||||
}
|
||||
3349
transports/bifrost-http/integrations/openai.go
Normal file
3349
transports/bifrost-http/integrations/openai.go
Normal file
File diff suppressed because it is too large
Load Diff
72
transports/bifrost-http/integrations/passthrough.go
Normal file
72
transports/bifrost-http/integrations/passthrough.go
Normal file
@@ -0,0 +1,72 @@
|
||||
package integrations
|
||||
|
||||
import (
|
||||
bifrost "github.com/maximhq/bifrost/core"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/maximhq/bifrost/transports/bifrost-http/lib"
|
||||
)
|
||||
|
||||
// PassthroughRouter is a catch-all router that forwards all requests directly
|
||||
// to the provider without matching against known route patterns.
|
||||
type PassthroughRouter struct {
|
||||
*GenericRouter // embedded; NewPassthroughRouter builds it with a nil route list and a PassthroughConfig
|
||||
}
|
||||
|
||||
// NewPassthroughRouter creates a passthrough-only router for any prefix/provider combo.
|
||||
func NewPassthroughRouter(
|
||||
client *bifrost.Bifrost,
|
||||
handlerStore lib.HandlerStore,
|
||||
logger schemas.Logger,
|
||||
cfg *PassthroughConfig,
|
||||
) *PassthroughRouter {
|
||||
if cfg == nil {
|
||||
cfg = &PassthroughConfig{}
|
||||
}
|
||||
return &PassthroughRouter{
|
||||
GenericRouter: NewGenericRouter(client, handlerStore, nil, cfg, logger),
|
||||
}
|
||||
}
|
||||
|
||||
// NewAnthropicPassthroughRouter creates a passthrough router for /anthropic_passthrough.
|
||||
func NewAnthropicPassthroughRouter(client *bifrost.Bifrost, handlerStore lib.HandlerStore, logger schemas.Logger) *PassthroughRouter {
|
||||
return NewPassthroughRouter(client, handlerStore, logger, &PassthroughConfig{
|
||||
Provider: schemas.Anthropic,
|
||||
StripPrefix: []string{
|
||||
"/anthropic_passthrough",
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// NewOpenAIPassthroughRouter creates a passthrough router for /openai_passthrough.
|
||||
func NewOpenAIPassthroughRouter(client *bifrost.Bifrost, handlerStore lib.HandlerStore, logger schemas.Logger) *PassthroughRouter {
|
||||
return NewPassthroughRouter(client, handlerStore, logger, &PassthroughConfig{
|
||||
Provider: schemas.OpenAI,
|
||||
StripPrefix: []string{
|
||||
"/openai_passthrough",
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// NewAzurePassthroughRouter creates a passthrough router for /azure_passthrough.
|
||||
func NewAzurePassthroughRouter(client *bifrost.Bifrost, handlerStore lib.HandlerStore, logger schemas.Logger) *PassthroughRouter {
|
||||
return NewPassthroughRouter(client, handlerStore, logger, &PassthroughConfig{
|
||||
Provider: schemas.Azure,
|
||||
StripPrefix: []string{
|
||||
"/azure_passthrough",
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// NewGenAIPassthroughRouter creates a passthrough router for /genai_passthrough.
|
||||
func NewGenAIPassthroughRouter(client *bifrost.Bifrost, handlerStore lib.HandlerStore, logger schemas.Logger) *PassthroughRouter {
|
||||
return NewPassthroughRouter(client, handlerStore, logger, &PassthroughConfig{
|
||||
Provider: schemas.Gemini,
|
||||
ProviderDetector: detectProviderFromGenAIRequest,
|
||||
StripPrefix: []string{
|
||||
"/genai_passthrough/v1beta1",
|
||||
"/genai_passthrough/v1beta",
|
||||
"/genai_passthrough/v1",
|
||||
"/genai_passthrough",
|
||||
},
|
||||
})
|
||||
}
|
||||
135
transports/bifrost-http/integrations/pydanticai.go
Normal file
135
transports/bifrost-http/integrations/pydanticai.go
Normal file
@@ -0,0 +1,135 @@
|
||||
package integrations
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/bytedance/sonic"
|
||||
bifrost "github.com/maximhq/bifrost/core"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/maximhq/bifrost/transports/bifrost-http/lib"
|
||||
)
|
||||
|
||||
// PydanticAIRouter holds route registrations for Pydantic AI endpoints.
|
||||
// It supports standard chat completions, tool calling, streaming, and multi-provider capabilities.
|
||||
// Pydantic AI uses standard provider SDKs (OpenAI, Anthropic, Google GenAI), so we reuse
|
||||
// existing route configurations with aliases for clarity and Pydantic AI-specific extensions.
|
||||
type PydanticAIRouter struct {
|
||||
*GenericRouter // embedded; NewPydanticAIRouter builds it from the combined provider route configs
|
||||
}
|
||||
|
||||
// NewPydanticAIRouter creates a new PydanticAIRouter with the given bifrost client.
|
||||
func NewPydanticAIRouter(client *bifrost.Bifrost, handlerStore lib.HandlerStore, logger schemas.Logger) *PydanticAIRouter {
|
||||
routes := []RouteConfig{}
|
||||
// Add OpenAI routes to Pydantic AI for OpenAI API compatibility
|
||||
// Supports: chat completions, embeddings, speech, transcriptions, responses
|
||||
routes = append(routes, withPydanticResponsesNullNormalization(CreateOpenAIRouteConfigs("/pydanticai", handlerStore))...)
|
||||
// Add Anthropic routes to Pydantic AI for Anthropic API compatibility
|
||||
// Supports: messages API (Claude models)
|
||||
routes = append(routes, CreateAnthropicRouteConfigs("/pydanticai", logger)...)
|
||||
// Add GenAI routes to Pydantic AI for Google Gemini API compatibility
|
||||
// Supports: generateContent, streamGenerateContent, embedContent
|
||||
routes = append(routes, CreateGenAIRouteConfigs("/pydanticai")...)
|
||||
// Add Cohere routes to Pydantic AI for Cohere API compatibility
|
||||
// Supports: v2/chat (chat completions with streaming), v2/embed (embeddings)
|
||||
routes = append(routes, CreateCohereRouteConfigs("/pydanticai")...)
|
||||
// Add Bedrock routes to Pydantic AI for AWS Bedrock API compatibility
|
||||
// Supports: converse, converse-stream, invoke, invoke-with-response-stream
|
||||
routes = append(routes, CreateBedrockRouteConfigs("/pydanticai", handlerStore)...)
|
||||
return &PydanticAIRouter{
|
||||
GenericRouter: NewGenericRouter(client, handlerStore, routes, nil, logger),
|
||||
}
|
||||
}
|
||||
|
||||
func withPydanticResponsesNullNormalization(routes []RouteConfig) []RouteConfig {
|
||||
for i := range routes {
|
||||
if !strings.Contains(routes[i].Path, "/responses") {
|
||||
continue
|
||||
}
|
||||
|
||||
if routes[i].ResponsesResponseConverter != nil {
|
||||
routes[i].ResponsesResponseConverter = func(ctx *schemas.BifrostContext, resp *schemas.BifrostResponsesResponse) (interface{}, error) {
|
||||
// For pydantic responses endpoint, prefer normalized bifrost output
|
||||
// instead of raw passthrough, to keep null handling consistent.
|
||||
return resp.WithDefaults(), nil
|
||||
}
|
||||
}
|
||||
|
||||
if routes[i].StreamConfig != nil && routes[i].StreamConfig.ResponsesStreamResponseConverter != nil {
|
||||
// Match non-stream behavior: prefer normalized output (raw->normalizePydanticResponsesRawStreamChunk, typed->resp.WithDefaults()+ensurePydanticResponsesStreamTextFields).
|
||||
routes[i].StreamConfig.ResponsesStreamResponseConverter = func(ctx *schemas.BifrostContext, resp *schemas.BifrostResponsesStreamResponse) (string, interface{}, error) {
|
||||
if resp == nil {
|
||||
return "", nil, nil
|
||||
}
|
||||
|
||||
if resp.ExtraFields.RawResponse != nil {
|
||||
normalizedRaw := normalizePydanticResponsesRawStreamChunk(resp.ExtraFields.RawResponse)
|
||||
if normalizedRawString, ok := normalizedRaw.(string); ok {
|
||||
return string(resp.Type), normalizedRawString, nil
|
||||
}
|
||||
}
|
||||
|
||||
normalized := resp.WithDefaults()
|
||||
ensurePydanticResponsesStreamTextFields(normalized)
|
||||
return string(resp.Type), normalized, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return routes
|
||||
}
|
||||
|
||||
func ensurePydanticResponsesStreamTextFields(resp *schemas.BifrostResponsesStreamResponse) {
|
||||
if resp == nil {
|
||||
return
|
||||
}
|
||||
|
||||
switch resp.Type {
|
||||
case schemas.ResponsesStreamResponseTypeOutputTextDelta:
|
||||
if resp.Delta == nil {
|
||||
resp.Delta = bifrost.Ptr("")
|
||||
}
|
||||
case schemas.ResponsesStreamResponseTypeOutputTextDone:
|
||||
if resp.Text == nil {
|
||||
resp.Text = bifrost.Ptr("")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func normalizePydanticResponsesRawStreamChunk(raw interface{}) interface{} {
|
||||
rawString, ok := raw.(string)
|
||||
if !ok {
|
||||
return raw
|
||||
}
|
||||
|
||||
var chunk map[string]interface{}
|
||||
if err := sonic.UnmarshalString(rawString, &chunk); err != nil {
|
||||
return raw
|
||||
}
|
||||
|
||||
changed := false
|
||||
if chunkType, ok := chunk["type"].(string); ok {
|
||||
switch schemas.ResponsesStreamResponseType(chunkType) {
|
||||
case schemas.ResponsesStreamResponseTypeOutputTextDelta:
|
||||
if value, exists := chunk["delta"]; exists && value == nil {
|
||||
chunk["delta"] = ""
|
||||
changed = true
|
||||
}
|
||||
case schemas.ResponsesStreamResponseTypeOutputTextDone:
|
||||
if value, exists := chunk["text"]; exists && value == nil {
|
||||
chunk["text"] = ""
|
||||
changed = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !changed {
|
||||
return raw
|
||||
}
|
||||
|
||||
normalized, err := sonic.MarshalString(chunk)
|
||||
if err != nil {
|
||||
return raw
|
||||
}
|
||||
|
||||
return normalized
|
||||
}
|
||||
2862
transports/bifrost-http/integrations/router.go
Normal file
2862
transports/bifrost-http/integrations/router.go
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,152 @@
|
||||
package integrations
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
func TestCreateHandler_SkipsRequestParserInLargePayloadMode(t *testing.T) {
|
||||
handlerStore := &mockHandlerStore{allowDirectKeys: true}
|
||||
parserCalls := 0
|
||||
|
||||
route := RouteConfig{
|
||||
Type: RouteConfigTypeOpenAI,
|
||||
Path: "/openai/v1/chat/completions",
|
||||
Method: "POST",
|
||||
GetHTTPRequestType: func(ctx *fasthttp.RequestCtx) schemas.RequestType {
|
||||
return schemas.ChatCompletionRequest
|
||||
},
|
||||
GetRequestTypeInstance: func(ctx context.Context) interface{} {
|
||||
return &struct{}{}
|
||||
},
|
||||
RequestParser: func(ctx *fasthttp.RequestCtx, req interface{}) error {
|
||||
parserCalls++
|
||||
return nil
|
||||
},
|
||||
RequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*schemas.BifrostRequest, error) {
|
||||
return nil, errors.New("stop after parse phase")
|
||||
},
|
||||
ErrorConverter: func(ctx *schemas.BifrostContext, err *schemas.BifrostError) interface{} {
|
||||
return err
|
||||
},
|
||||
}
|
||||
|
||||
router := NewGenericRouter(nil, handlerStore, nil, nil, nil)
|
||||
router.SetLargePayloadHook(func(ctx *fasthttp.RequestCtx, bifrostCtx *schemas.BifrostContext, routeType RouteConfigType) (bool, error) {
|
||||
return true, nil
|
||||
})
|
||||
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Request.Header.SetMethod(fasthttp.MethodPost)
|
||||
ctx.Request.SetBodyString(`{"model":"openai/gpt-4o","messages":[]}`)
|
||||
ctx.SetUserValue(schemas.BifrostContextKeyHTTPRequestType, schemas.ChatCompletionRequest)
|
||||
|
||||
handler := router.createHandler(route)
|
||||
handler(ctx)
|
||||
|
||||
assert.Equal(t, 0, parserCalls)
|
||||
}
|
||||
|
||||
func TestCreateHandler_UsesRequestParserWhenNotInLargePayloadMode(t *testing.T) {
|
||||
handlerStore := &mockHandlerStore{allowDirectKeys: true}
|
||||
parserCalls := 0
|
||||
|
||||
route := RouteConfig{
|
||||
Type: RouteConfigTypeOpenAI,
|
||||
Path: "/openai/v1/chat/completions",
|
||||
Method: "POST",
|
||||
GetHTTPRequestType: func(ctx *fasthttp.RequestCtx) schemas.RequestType {
|
||||
return schemas.ChatCompletionRequest
|
||||
},
|
||||
GetRequestTypeInstance: func(ctx context.Context) interface{} {
|
||||
return &struct{}{}
|
||||
},
|
||||
RequestParser: func(ctx *fasthttp.RequestCtx, req interface{}) error {
|
||||
parserCalls++
|
||||
return nil
|
||||
},
|
||||
RequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*schemas.BifrostRequest, error) {
|
||||
return nil, errors.New("stop after parse phase")
|
||||
},
|
||||
ErrorConverter: func(ctx *schemas.BifrostContext, err *schemas.BifrostError) interface{} {
|
||||
return err
|
||||
},
|
||||
}
|
||||
|
||||
router := NewGenericRouter(nil, handlerStore, nil, nil, nil)
|
||||
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Request.Header.SetMethod(fasthttp.MethodPost)
|
||||
ctx.Request.SetBodyString(`{"model":"openai/gpt-4o","messages":[]}`)
|
||||
ctx.SetUserValue(schemas.BifrostContextKeyHTTPRequestType, schemas.ChatCompletionRequest)
|
||||
|
||||
handler := router.createHandler(route)
|
||||
handler(ctx)
|
||||
|
||||
assert.Equal(t, 1, parserCalls)
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// resolveLargePayloadMetadata tests
|
||||
// ============================================================================
|
||||
|
||||
func TestResolveLargePayloadMetadata_NilContext(t *testing.T) {
|
||||
assert.Nil(t, resolveLargePayloadMetadata(nil))
|
||||
}
|
||||
|
||||
func TestResolveLargePayloadMetadata_SyncPath(t *testing.T) {
|
||||
ctx := schemas.NewBifrostContext(nil, time.Time{})
|
||||
meta := &schemas.LargePayloadMetadata{Model: "gpt-4o"}
|
||||
ctx.SetValue(schemas.BifrostContextKeyLargePayloadMetadata, meta)
|
||||
|
||||
result := resolveLargePayloadMetadata(ctx)
|
||||
require.NotNil(t, result)
|
||||
assert.Equal(t, "gpt-4o", result.Model)
|
||||
}
|
||||
|
||||
func TestResolveLargePayloadMetadata_DeferredReady(t *testing.T) {
|
||||
ctx := schemas.NewBifrostContext(nil, time.Time{})
|
||||
ch := make(chan *schemas.LargePayloadMetadata, 1)
|
||||
ch <- &schemas.LargePayloadMetadata{Model: "claude-4"}
|
||||
ctx.SetValue(schemas.BifrostContextKeyDeferredLargePayloadMetadata, (<-chan *schemas.LargePayloadMetadata)(ch))
|
||||
|
||||
result := resolveLargePayloadMetadata(ctx)
|
||||
require.NotNil(t, result)
|
||||
assert.Equal(t, "claude-4", result.Model)
|
||||
|
||||
// Verify it was cached in the sync key.
|
||||
cached, ok := ctx.Value(schemas.BifrostContextKeyLargePayloadMetadata).(*schemas.LargePayloadMetadata)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "claude-4", cached.Model)
|
||||
}
|
||||
|
||||
func TestResolveLargePayloadMetadata_DeferredNotReady(t *testing.T) {
|
||||
ctx := schemas.NewBifrostContext(nil, time.Time{})
|
||||
ch := make(chan *schemas.LargePayloadMetadata, 1) // empty, not ready
|
||||
ctx.SetValue(schemas.BifrostContextKeyDeferredLargePayloadMetadata, (<-chan *schemas.LargePayloadMetadata)(ch))
|
||||
|
||||
// Non-blocking: should return nil when channel has no value yet.
|
||||
result := resolveLargePayloadMetadata(ctx)
|
||||
assert.Nil(t, result)
|
||||
}
|
||||
|
||||
func TestResolveLargePayloadMetadata_SyncTakesPrecedence(t *testing.T) {
|
||||
ctx := schemas.NewBifrostContext(nil, time.Time{})
|
||||
syncMeta := &schemas.LargePayloadMetadata{Model: "sync-model"}
|
||||
ctx.SetValue(schemas.BifrostContextKeyLargePayloadMetadata, syncMeta)
|
||||
|
||||
ch := make(chan *schemas.LargePayloadMetadata, 1)
|
||||
ch <- &schemas.LargePayloadMetadata{Model: "deferred-model"}
|
||||
ctx.SetValue(schemas.BifrostContextKeyDeferredLargePayloadMetadata, (<-chan *schemas.LargePayloadMetadata)(ch))
|
||||
|
||||
result := resolveLargePayloadMetadata(ctx)
|
||||
require.NotNil(t, result)
|
||||
assert.Equal(t, "sync-model", result.Model)
|
||||
}
|
||||
373
transports/bifrost-http/integrations/router_test.go
Normal file
373
transports/bifrost-http/integrations/router_test.go
Normal file
@@ -0,0 +1,373 @@
|
||||
package integrations
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"mime/multipart"
|
||||
"testing"
|
||||
|
||||
"github.com/bytedance/sonic"
|
||||
"github.com/maximhq/bifrost/core/providers/openai"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestParsePassthroughBody_MultipartExtractsModelAfterFilePart(t *testing.T) {
|
||||
var body bytes.Buffer
|
||||
writer := multipart.NewWriter(&body)
|
||||
|
||||
fileWriter, err := writer.CreateFormFile("file", "sample.mp3")
|
||||
require.NoError(t, err)
|
||||
_, err = fileWriter.Write([]byte("audio-bytes"))
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, writer.WriteField("model", "openai/whisper-1"))
|
||||
require.NoError(t, writer.WriteField("stream", "true"))
|
||||
require.NoError(t, writer.Close())
|
||||
|
||||
model, stream := parsePassthroughBody(writer.FormDataContentType(), body.Bytes())
|
||||
assert.Equal(t, "openai/whisper-1", model)
|
||||
assert.True(t, stream)
|
||||
}
|
||||
|
||||
func TestRequestWithSettableExtraParams_OpenAIChatRequest(t *testing.T) {
|
||||
t.Run("SetExtraParams populates both standalone and embedded ExtraParams", func(t *testing.T) {
|
||||
req := &openai.OpenAIChatRequest{}
|
||||
extra := map[string]interface{}{
|
||||
"guardrailConfig": map[string]interface{}{
|
||||
"guardrailIdentifier": "xxx",
|
||||
"guardrailVersion": "1",
|
||||
},
|
||||
}
|
||||
|
||||
rws, ok := interface{}(req).(RequestWithSettableExtraParams)
|
||||
require.True(t, ok, "OpenAIChatRequest should implement RequestWithSettableExtraParams")
|
||||
|
||||
rws.SetExtraParams(extra)
|
||||
|
||||
assert.Equal(t, extra, req.GetExtraParams())
|
||||
assert.Equal(t, extra, req.ChatParameters.ExtraParams, "embedded ChatParameters.ExtraParams should also be set")
|
||||
})
|
||||
|
||||
t.Run("extra params propagate through ToBifrostChatRequest", func(t *testing.T) {
|
||||
req := &openai.OpenAIChatRequest{
|
||||
Model: "bedrock/claude-4-5-sonnet-global",
|
||||
Messages: []openai.OpenAIMessage{},
|
||||
}
|
||||
extra := map[string]interface{}{
|
||||
"guardrailConfig": map[string]interface{}{
|
||||
"guardrailIdentifier": "test-id",
|
||||
"guardrailVersion": "1",
|
||||
},
|
||||
}
|
||||
|
||||
rws := interface{}(req).(RequestWithSettableExtraParams)
|
||||
rws.SetExtraParams(extra)
|
||||
|
||||
ctx := schemas.NewBifrostContext(nil, schemas.NoDeadline)
|
||||
bifrostReq := req.ToBifrostChatRequest(ctx)
|
||||
|
||||
require.NotNil(t, bifrostReq)
|
||||
require.NotNil(t, bifrostReq.Params)
|
||||
assert.Contains(t, bifrostReq.Params.ExtraParams, "guardrailConfig")
|
||||
})
|
||||
}
|
||||
|
||||
func TestRequestWithSettableExtraParams_AllOpenAIRequestTypes(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
req interface{}
|
||||
}{
|
||||
{"OpenAIChatRequest", &openai.OpenAIChatRequest{}},
|
||||
{"OpenAITextCompletionRequest", &openai.OpenAITextCompletionRequest{}},
|
||||
{"OpenAIResponsesRequest", &openai.OpenAIResponsesRequest{}},
|
||||
{"OpenAIEmbeddingRequest", &openai.OpenAIEmbeddingRequest{}},
|
||||
{"OpenAISpeechRequest", &openai.OpenAISpeechRequest{}},
|
||||
{"OpenAIImageGenerationRequest", &openai.OpenAIImageGenerationRequest{}},
|
||||
{"OpenAIImageEditRequest", &openai.OpenAIImageEditRequest{}},
|
||||
{"OpenAIImageVariationRequest", &openai.OpenAIImageVariationRequest{}},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name+" implements RequestWithSettableExtraParams", func(t *testing.T) {
|
||||
rws, ok := tt.req.(RequestWithSettableExtraParams)
|
||||
require.True(t, ok, "%s should implement RequestWithSettableExtraParams", tt.name)
|
||||
|
||||
extra := map[string]interface{}{"test_key": "test_value"}
|
||||
rws.SetExtraParams(extra)
|
||||
|
||||
getter, ok := tt.req.(interface{ GetExtraParams() map[string]interface{} })
|
||||
require.True(t, ok, "%s should implement GetExtraParams", tt.name)
|
||||
assert.Equal(t, extra, getter.GetExtraParams())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtraParamsRequiresPassthroughHeader(t *testing.T) {
|
||||
handlerStore := &mockHandlerStore{allowDirectKeys: true}
|
||||
routes := CreateOpenAIRouteConfigs("/openai", handlerStore)
|
||||
|
||||
var chatRoute *RouteConfig
|
||||
for i := range routes {
|
||||
if routes[i].Path == "/openai/v1/chat/completions" {
|
||||
chatRoute = &routes[i]
|
||||
break
|
||||
}
|
||||
}
|
||||
require.NotNil(t, chatRoute, "should find /openai/v1/chat/completions route")
|
||||
|
||||
rawBody := []byte(`{
|
||||
"model": "bedrock/claude-4-5-sonnet-global",
|
||||
"messages": [{"role": "user", "content": [{"type": "text", "text": "hello"}]}],
|
||||
"extra_params": {
|
||||
"guardrailConfig": {
|
||||
"guardrailIdentifier": "my-guardrail",
|
||||
"guardrailVersion": "1",
|
||||
"trace": "disabled"
|
||||
}
|
||||
}
|
||||
}`)
|
||||
|
||||
t.Run("extra_params NOT extracted without passthrough header", func(t *testing.T) {
|
||||
req := chatRoute.GetRequestTypeInstance(context.Background())
|
||||
err := sonic.Unmarshal(rawBody, req)
|
||||
require.NoError(t, err)
|
||||
|
||||
bifrostCtx := schemas.NewBifrostContext(nil, schemas.NoDeadline)
|
||||
// Header not set -- simulate router logic
|
||||
if bifrostCtx.Value(schemas.BifrostContextKeyPassthroughExtraParams) == true {
|
||||
if rws, ok := req.(RequestWithSettableExtraParams); ok {
|
||||
var wrapper struct {
|
||||
ExtraParams map[string]interface{} `json:"extra_params"`
|
||||
}
|
||||
if err := sonic.Unmarshal(rawBody, &wrapper); err == nil && len(wrapper.ExtraParams) > 0 {
|
||||
rws.SetExtraParams(wrapper.ExtraParams)
|
||||
}
|
||||
_ = rws
|
||||
}
|
||||
}
|
||||
|
||||
openaiReq, ok := req.(*openai.OpenAIChatRequest)
|
||||
require.True(t, ok)
|
||||
assert.Empty(t, openaiReq.ChatParameters.ExtraParams,
|
||||
"ExtraParams should be empty when passthrough header is not set")
|
||||
})
|
||||
|
||||
t.Run("extra_params extracted with passthrough header", func(t *testing.T) {
|
||||
req := chatRoute.GetRequestTypeInstance(context.Background())
|
||||
err := sonic.Unmarshal(rawBody, req)
|
||||
require.NoError(t, err)
|
||||
|
||||
bifrostCtx := schemas.NewBifrostContext(nil, schemas.NoDeadline)
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKeyPassthroughExtraParams, true)
|
||||
|
||||
if bifrostCtx.Value(schemas.BifrostContextKeyPassthroughExtraParams) == true {
|
||||
if rws, ok := req.(RequestWithSettableExtraParams); ok {
|
||||
var wrapper struct {
|
||||
ExtraParams map[string]interface{} `json:"extra_params"`
|
||||
}
|
||||
if err := sonic.Unmarshal(rawBody, &wrapper); err == nil && len(wrapper.ExtraParams) > 0 {
|
||||
rws.SetExtraParams(wrapper.ExtraParams)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
openaiReq, ok := req.(*openai.OpenAIChatRequest)
|
||||
require.True(t, ok)
|
||||
require.Contains(t, openaiReq.ChatParameters.ExtraParams, "guardrailConfig",
|
||||
"guardrailConfig should be in ExtraParams when passthrough header is set")
|
||||
|
||||
gc, ok := openaiReq.ChatParameters.ExtraParams["guardrailConfig"].(map[string]interface{})
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "my-guardrail", gc["guardrailIdentifier"])
|
||||
assert.Equal(t, "1", gc["guardrailVersion"])
|
||||
assert.Equal(t, "disabled", gc["trace"])
|
||||
})
|
||||
}
|
||||
|
||||
func TestExtraParamsPassthrough_NestedStructures(t *testing.T) {
|
||||
rawBody := []byte(`{
|
||||
"model": "openai/gpt-4o-mini",
|
||||
"messages": [{"role": "user", "content": [{"type": "text", "text": "hello"}]}],
|
||||
"extra_params": {
|
||||
"custom_param": "value",
|
||||
"another_param": 123,
|
||||
"nested": {
|
||||
"deep_field": "deep_value",
|
||||
"deeper": {"level": 3}
|
||||
}
|
||||
}
|
||||
}`)
|
||||
|
||||
req := &openai.OpenAIChatRequest{}
|
||||
err := sonic.Unmarshal(rawBody, req)
|
||||
require.NoError(t, err)
|
||||
|
||||
bifrostCtx := schemas.NewBifrostContext(nil, schemas.NoDeadline)
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKeyPassthroughExtraParams, true)
|
||||
|
||||
if bifrostCtx.Value(schemas.BifrostContextKeyPassthroughExtraParams) == true {
|
||||
if rws, ok := interface{}(req).(RequestWithSettableExtraParams); ok {
|
||||
var wrapper struct {
|
||||
ExtraParams map[string]interface{} `json:"extra_params"`
|
||||
}
|
||||
if err := sonic.Unmarshal(rawBody, &wrapper); err == nil && len(wrapper.ExtraParams) > 0 {
|
||||
rws.SetExtraParams(wrapper.ExtraParams)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
require.Len(t, req.ChatParameters.ExtraParams, 3)
|
||||
assert.Equal(t, "value", req.ChatParameters.ExtraParams["custom_param"])
|
||||
assert.Equal(t, float64(123), req.ChatParameters.ExtraParams["another_param"])
|
||||
|
||||
nested, ok := req.ChatParameters.ExtraParams["nested"].(map[string]interface{})
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "deep_value", nested["deep_field"])
|
||||
}
|
||||
|
||||
func TestExtraParamsPassthrough_EndToEnd(t *testing.T) {
|
||||
rawJSON := []byte(`{
|
||||
"model": "bedrock/claude-4-5-sonnet-global",
|
||||
"messages": [{"role": "user", "content": [{"type": "text", "text": "hello"}]}],
|
||||
"stream": false,
|
||||
"temperature": 0.7,
|
||||
"extra_params": {
|
||||
"guardrailConfig": {
|
||||
"guardrailIdentifier": "my-guardrail",
|
||||
"guardrailVersion": "1",
|
||||
"trace": "disabled"
|
||||
}
|
||||
}
|
||||
}`)
|
||||
|
||||
req := &openai.OpenAIChatRequest{}
|
||||
err := sonic.Unmarshal(rawJSON, req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "bedrock/claude-4-5-sonnet-global", req.Model)
|
||||
|
||||
bifrostCtx := schemas.NewBifrostContext(nil, schemas.NoDeadline)
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKeyPassthroughExtraParams, true)
|
||||
|
||||
if bifrostCtx.Value(schemas.BifrostContextKeyPassthroughExtraParams) == true {
|
||||
if rws, ok := interface{}(req).(RequestWithSettableExtraParams); ok {
|
||||
var wrapper struct {
|
||||
ExtraParams map[string]interface{} `json:"extra_params"`
|
||||
}
|
||||
if err := sonic.Unmarshal(rawJSON, &wrapper); err == nil && len(wrapper.ExtraParams) > 0 {
|
||||
rws.SetExtraParams(wrapper.ExtraParams)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bifrostReq := req.ToBifrostChatRequest(bifrostCtx)
|
||||
|
||||
require.NotNil(t, bifrostReq)
|
||||
require.NotNil(t, bifrostReq.Params)
|
||||
require.Contains(t, bifrostReq.Params.ExtraParams, "guardrailConfig")
|
||||
|
||||
gc, ok := bifrostReq.Params.ExtraParams["guardrailConfig"].(map[string]interface{})
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "my-guardrail", gc["guardrailIdentifier"])
|
||||
assert.Equal(t, "1", gc["guardrailVersion"])
|
||||
assert.Equal(t, "disabled", gc["trace"])
|
||||
|
||||
assert.NotContains(t, bifrostReq.Params.ExtraParams, "model")
|
||||
assert.NotContains(t, bifrostReq.Params.ExtraParams, "messages")
|
||||
assert.NotContains(t, bifrostReq.Params.ExtraParams, "stream")
|
||||
assert.NotContains(t, bifrostReq.Params.ExtraParams, "temperature")
|
||||
}
|
||||
|
||||
func TestExtraParamsPassthrough_NoExtraParamsKey(t *testing.T) {
|
||||
rawBody := []byte(`{
|
||||
"model": "openai/gpt-4o-mini",
|
||||
"messages": [{"role": "user", "content": [{"type": "text", "text": "hello"}]}]
|
||||
}`)
|
||||
|
||||
req := &openai.OpenAIChatRequest{}
|
||||
err := sonic.Unmarshal(rawBody, req)
|
||||
require.NoError(t, err)
|
||||
|
||||
bifrostCtx := schemas.NewBifrostContext(nil, schemas.NoDeadline)
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKeyPassthroughExtraParams, true)
|
||||
|
||||
if bifrostCtx.Value(schemas.BifrostContextKeyPassthroughExtraParams) == true {
|
||||
if rws, ok := interface{}(req).(RequestWithSettableExtraParams); ok {
|
||||
var wrapper struct {
|
||||
ExtraParams map[string]interface{} `json:"extra_params"`
|
||||
}
|
||||
if err := sonic.Unmarshal(rawBody, &wrapper); err == nil && len(wrapper.ExtraParams) > 0 {
|
||||
rws.SetExtraParams(wrapper.ExtraParams)
|
||||
}
|
||||
_ = rws
|
||||
}
|
||||
}
|
||||
|
||||
assert.Empty(t, req.ChatParameters.ExtraParams,
|
||||
"ExtraParams should be empty when extra_params key is absent from JSON")
|
||||
}
|
||||
|
||||
// TestExtraParamsSetViaInterfaceMutatesOriginalReq verifies that setting extra
|
||||
// params through the RequestWithSettableExtraParams interface assertion mutates
|
||||
// the original req (interface{}) value. This matters because createHandler
|
||||
// passes req to config.RequestConverter after the extra params block -- both
|
||||
// variables must reference the same underlying struct via pointer semantics.
|
||||
func TestExtraParamsSetViaInterfaceMutatesOriginalReq(t *testing.T) {
|
||||
handlerStore := &mockHandlerStore{allowDirectKeys: true}
|
||||
routes := CreateOpenAIRouteConfigs("/openai", handlerStore)
|
||||
|
||||
var chatRoute *RouteConfig
|
||||
for i := range routes {
|
||||
if routes[i].Path == "/openai/v1/chat/completions" {
|
||||
chatRoute = &routes[i]
|
||||
break
|
||||
}
|
||||
}
|
||||
require.NotNil(t, chatRoute)
|
||||
|
||||
rawBody := []byte(`{
|
||||
"model": "bedrock/claude-4-5-sonnet-global",
|
||||
"messages": [{"role": "user", "content": [{"type": "text", "text": "hello"}]}],
|
||||
"extra_params": {
|
||||
"guardrailConfig": {
|
||||
"guardrailIdentifier": "my-guardrail",
|
||||
"guardrailVersion": "1"
|
||||
}
|
||||
}
|
||||
}`)
|
||||
|
||||
// Simulate the exact flow in createHandler:
|
||||
// 1. req is created via GetRequestTypeInstance (returns interface{})
|
||||
// 2. JSON is unmarshalled into req
|
||||
// 3. rws type assertion is used to call SetExtraParams
|
||||
// 4. req (not rws) is passed to RequestConverter downstream
|
||||
req := chatRoute.GetRequestTypeInstance(context.Background()) // returns interface{}
|
||||
err := sonic.Unmarshal(rawBody, req)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Type-assert and set extra params (same as router code)
|
||||
if rws, ok := req.(RequestWithSettableExtraParams); ok {
|
||||
var wrapper struct {
|
||||
ExtraParams map[string]interface{} `json:"extra_params"`
|
||||
}
|
||||
if err := sonic.Unmarshal(rawBody, &wrapper); err == nil && len(wrapper.ExtraParams) > 0 {
|
||||
rws.SetExtraParams(wrapper.ExtraParams)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify that req (the original interface{} variable) was mutated
|
||||
openaiReq, ok := req.(*openai.OpenAIChatRequest)
|
||||
require.True(t, ok)
|
||||
require.Contains(t, openaiReq.ChatParameters.ExtraParams, "guardrailConfig",
|
||||
"original req should be mutated via pointer semantics")
|
||||
|
||||
// Verify the full downstream path: RequestConverter uses req
|
||||
bifrostCtx := schemas.NewBifrostContext(nil, schemas.NoDeadline)
|
||||
bifrostReq, err := chatRoute.RequestConverter(bifrostCtx, req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, bifrostReq)
|
||||
require.NotNil(t, bifrostReq.ChatRequest)
|
||||
require.NotNil(t, bifrostReq.ChatRequest.Params)
|
||||
assert.Contains(t, bifrostReq.ChatRequest.Params.ExtraParams, "guardrailConfig",
|
||||
"extra params should propagate through RequestConverter to BifrostChatRequest")
|
||||
}
|
||||
502
transports/bifrost-http/integrations/utils.go
Normal file
502
transports/bifrost-http/integrations/utils.go
Normal file
@@ -0,0 +1,502 @@
|
||||
package integrations
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/bytedance/sonic"
|
||||
bifrost "github.com/maximhq/bifrost/core"
|
||||
"github.com/maximhq/bifrost/core/providers/gemini"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/maximhq/bifrost/framework/kvstore"
|
||||
"github.com/maximhq/bifrost/transports/bifrost-http/lib"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
var bifrostContextKeyProvider = schemas.BifrostContextKey("provider")
|
||||
|
||||
var availableIntegrations = []string{
|
||||
"openai",
|
||||
"anthropic",
|
||||
"genai",
|
||||
"litellm",
|
||||
"langchain",
|
||||
"bedrock",
|
||||
"pydantic",
|
||||
"cohere",
|
||||
}
|
||||
|
||||
// newBifrostErrorWithCode is like newBifrostError but sets an explicit HTTP status code.
|
||||
func newBifrostErrorWithCode(err error, message string, statusCode int) *schemas.BifrostError {
|
||||
e := newBifrostError(err, message)
|
||||
e.StatusCode = &statusCode
|
||||
return e
|
||||
}
|
||||
|
||||
// newBifrostError wraps a standard error into a BifrostError with IsBifrostError set to false.
|
||||
// This helper function reduces code duplication when handling non-Bifrost errors.
|
||||
func newBifrostError(err error, message string) *schemas.BifrostError {
|
||||
if err == nil {
|
||||
return &schemas.BifrostError{
|
||||
IsBifrostError: false,
|
||||
Error: &schemas.ErrorField{
|
||||
Message: message,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
return &schemas.BifrostError{
|
||||
IsBifrostError: false,
|
||||
Error: &schemas.ErrorField{
|
||||
Message: message,
|
||||
Error: err,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// safeGetRequestType safely obtains the request type from a BifrostStreamChunk chunk.
|
||||
// It checks multiple sources in order of preference:
|
||||
// 1. Response ExtraFields if any response is available
|
||||
// 2. BifrostError ExtraFields if error is available and not nil
|
||||
// 3. Falls back to "unknown" if no source is available
|
||||
func safeGetRequestType(chunk *schemas.BifrostStreamChunk) string {
|
||||
if chunk == nil {
|
||||
return "unknown"
|
||||
}
|
||||
|
||||
// Try to get RequestType from response ExtraFields (preferred source)
|
||||
switch {
|
||||
case chunk.BifrostTextCompletionResponse != nil:
|
||||
return string(chunk.BifrostTextCompletionResponse.ExtraFields.RequestType)
|
||||
case chunk.BifrostChatResponse != nil:
|
||||
return string(chunk.BifrostChatResponse.ExtraFields.RequestType)
|
||||
case chunk.BifrostResponsesStreamResponse != nil:
|
||||
return string(chunk.BifrostResponsesStreamResponse.ExtraFields.RequestType)
|
||||
case chunk.BifrostSpeechStreamResponse != nil:
|
||||
return string(chunk.BifrostSpeechStreamResponse.ExtraFields.RequestType)
|
||||
case chunk.BifrostTranscriptionStreamResponse != nil:
|
||||
return string(chunk.BifrostTranscriptionStreamResponse.ExtraFields.RequestType)
|
||||
}
|
||||
|
||||
// Try to get RequestType from error ExtraFields (fallback)
|
||||
if chunk.BifrostError != nil && chunk.BifrostError.ExtraFields.RequestType != "" {
|
||||
return string(chunk.BifrostError.ExtraFields.RequestType)
|
||||
}
|
||||
|
||||
// Final fallback
|
||||
return "unknown"
|
||||
}
|
||||
|
||||
// extractHeadersFromRequest extracts headers from the request and returns them as a map.
|
||||
// It uses the fasthttp.RequestCtx.Header.All() method to iterate over all headers.
|
||||
func extractHeadersFromRequest(ctx *fasthttp.RequestCtx) map[string][]string {
|
||||
headers := make(map[string][]string)
|
||||
|
||||
for key, value := range ctx.Request.Header.All() {
|
||||
keyStr := string(key)
|
||||
headers[keyStr] = append(headers[keyStr], string(value))
|
||||
}
|
||||
|
||||
return headers
|
||||
}
|
||||
|
||||
// extractExactPath returns the request path *after* the integration prefix,
|
||||
// preserving the original query string exactly as sent by the client.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// /openai/v1/chat/completions?model=gpt-4o -> v1/chat/completions?model=gpt-4o
|
||||
func extractExactPath(ctx *fasthttp.RequestCtx) string {
|
||||
// ctx.Path() returns only the path (no query) as a []byte backed by fasthttp’s internal buffers.
|
||||
// Treat it as read-only; don’t append to it directly.
|
||||
path := ctx.Path() // e.g. "/openai/v1/chat/completions"
|
||||
|
||||
// Strip the integration prefix only if it’s at the start.
|
||||
for _, integration := range availableIntegrations {
|
||||
if bytes.HasPrefix(path, []byte("/"+integration+"/")) {
|
||||
path = path[len("/"+integration+"/"):]
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Raw query string as sent by client (unparsed, preserves ordering/duplicates/encoding).
|
||||
q := ctx.URI().QueryString() // e.g. "model=gpt-4o&stream=true"
|
||||
|
||||
if len(q) == 0 {
|
||||
// No query → just return the (possibly trimmed) path.
|
||||
return string(path)
|
||||
}
|
||||
|
||||
// --- Build "<path>?<query>" efficiently and safely ---
|
||||
//
|
||||
// Why not do: return string(path) + "?" + string(q) ?
|
||||
// - That allocates multiple temporary strings and may copy data more than necessary.
|
||||
//
|
||||
// Why not append into 'path' directly?
|
||||
// - 'path' may alias fasthttp’s internal buffers; mutating/expanding it could corrupt request state.
|
||||
//
|
||||
// We instead allocate a new buffer with exact capacity and copy into it,
|
||||
// staying in []byte until the final string conversion (1 allocation for the new slice).
|
||||
out := make([]byte, 0, len(path)+1+len(q)) // pre-size: path + "?" + query
|
||||
out = append(out, path...) // copy path bytes
|
||||
out = append(out, '?') // separator
|
||||
out = append(out, q...) // copy raw query bytes
|
||||
|
||||
return string(out)
|
||||
}
|
||||
|
||||
// sendStreamError sends an error response for a streaming request that failed before streaming started.
|
||||
// It propagates the provider's HTTP status code and returns a JSON error body (not SSE format),
|
||||
// since no streaming has begun and clients should receive a standard error response.
|
||||
func (g *GenericRouter) sendStreamError(ctx *fasthttp.RequestCtx, bifrostCtx *schemas.BifrostContext, config RouteConfig, bifrostErr *schemas.BifrostError) {
|
||||
// Forward provider response headers from context so streaming error responses include them
|
||||
if bifrostCtx != nil {
|
||||
if headers, ok := bifrostCtx.Value(schemas.BifrostContextKeyProviderResponseHeaders).(map[string]string); ok {
|
||||
for key, value := range headers {
|
||||
ctx.Response.Header.Set(key, value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Set the HTTP status code from the provider error
|
||||
if bifrostErr.StatusCode != nil {
|
||||
ctx.SetStatusCode(*bifrostErr.StatusCode)
|
||||
} else {
|
||||
ctx.SetStatusCode(fasthttp.StatusInternalServerError)
|
||||
}
|
||||
ctx.SetContentType("application/json")
|
||||
|
||||
// Always use the route-level ErrorConverter (not StreamConfig.ErrorConverter) because
|
||||
// sendStreamError returns JSON, not SSE. StreamConfig.ErrorConverter is designed for
|
||||
// in-stream SSE errors (e.g., Anthropic's returns a raw SSE string that would be
|
||||
// double-escaped by JSON marshaling).
|
||||
errorResponse := config.ErrorConverter(bifrostCtx, bifrostErr)
|
||||
|
||||
errorJSON, err := sonic.Marshal(errorResponse)
|
||||
if err != nil {
|
||||
g.logger.Error("failed to marshal error response", "err", err, "path", extractExactPath(ctx))
|
||||
ctx.SetStatusCode(fasthttp.StatusInternalServerError)
|
||||
ctx.SetContentType("text/plain; charset=utf-8")
|
||||
ctx.SetBodyString(fmt.Sprintf("failed to encode error response: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
ctx.SetBody(errorJSON)
|
||||
}
|
||||
|
||||
// sendError sends an error response with the appropriate status code and JSON body.
|
||||
// It handles different error types (string, error interface, or arbitrary objects).
|
||||
func (g *GenericRouter) sendError(ctx *fasthttp.RequestCtx, bifrostCtx *schemas.BifrostContext, errorConverter ErrorConverter, bifrostErr *schemas.BifrostError) {
|
||||
// Forward provider response headers from context so error responses include them
|
||||
if bifrostCtx != nil {
|
||||
if headers, ok := bifrostCtx.Value(schemas.BifrostContextKeyProviderResponseHeaders).(map[string]string); ok {
|
||||
for key, value := range headers {
|
||||
ctx.Response.Header.Set(key, value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if bifrostErr.StatusCode != nil {
|
||||
ctx.SetStatusCode(*bifrostErr.StatusCode)
|
||||
} else {
|
||||
ctx.SetStatusCode(fasthttp.StatusInternalServerError)
|
||||
}
|
||||
ctx.SetContentType("application/json")
|
||||
|
||||
// Marshal the error for response and log the error for diagnostics
|
||||
responseObj := errorConverter(bifrostCtx, bifrostErr)
|
||||
errorBody, err := sonic.Marshal(responseObj)
|
||||
if err != nil {
|
||||
// Log the marshal failure and return a plain text error
|
||||
g.logger.Error("failed to marshal error response", "err", err, "path", extractExactPath(ctx))
|
||||
ctx.SetStatusCode(fasthttp.StatusInternalServerError)
|
||||
ctx.SetContentType("text/plain; charset=utf-8")
|
||||
ctx.SetBodyString(fmt.Sprintf("failed to encode error response: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
ctx.SetBody(errorBody)
|
||||
}
|
||||
|
||||
// sendSuccess sends a successful response with HTTP 200 status and JSON body.
|
||||
func (g *GenericRouter) sendSuccess(ctx *fasthttp.RequestCtx, bifrostCtx *schemas.BifrostContext, errorConverter ErrorConverter, response interface{}, extraHeaders map[string]string) {
|
||||
ctx.SetStatusCode(fasthttp.StatusOK)
|
||||
ctx.SetContentType("application/json")
|
||||
|
||||
if extraHeaders != nil {
|
||||
for key, value := range extraHeaders {
|
||||
ctx.Response.Header.Set(key, value)
|
||||
}
|
||||
}
|
||||
|
||||
responseBody, err := sonic.Marshal(response)
|
||||
if err != nil {
|
||||
g.sendError(ctx, bifrostCtx, errorConverter, newBifrostError(err, "failed to encode response"))
|
||||
return
|
||||
}
|
||||
|
||||
ctx.SetBody(responseBody)
|
||||
}
|
||||
|
||||
// tryStreamLargeResponse checks if large response mode was activated by the provider,
|
||||
// sets the transport marker, and streams the response directly to the client.
|
||||
// Returns true if the response was handled (caller should return).
|
||||
func (g *GenericRouter) tryStreamLargeResponse(ctx *fasthttp.RequestCtx, bifrostCtx *schemas.BifrostContext) bool {
|
||||
isLargeResponse, ok := bifrostCtx.Value(schemas.BifrostContextKeyLargeResponseMode).(bool)
|
||||
if !ok || !isLargeResponse {
|
||||
return false
|
||||
}
|
||||
// Forward provider response headers before streaming — providers store them in
|
||||
// context via BifrostContextKeyProviderResponseHeaders, but some early-return
|
||||
// branches in the router skip the common footer that normally forwards them.
|
||||
if headers, ok := bifrostCtx.Value(schemas.BifrostContextKeyProviderResponseHeaders).(map[string]string); ok {
|
||||
for key, value := range headers {
|
||||
ctx.Response.Header.Set(key, value)
|
||||
}
|
||||
}
|
||||
if g.streamLargeResponse(ctx, bifrostCtx) {
|
||||
ctx.SetUserValue(lib.FastHTTPUserValueLargeResponseMode, true)
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// streamLargeResponse streams the large response body directly from the upstream provider to the client.
|
||||
// This bypasses the normal serialize → set body path, piping the response bytes unchanged.
|
||||
func (g *GenericRouter) streamLargeResponse(ctx *fasthttp.RequestCtx, bifrostCtx *schemas.BifrostContext) bool {
|
||||
// Enterprise hook: wrap the reader with Phase B scanning (e.g., usage extraction
|
||||
// from the full response stream) before streaming to client.
|
||||
if g.largeResponseHook != nil {
|
||||
g.largeResponseHook(ctx, bifrostCtx)
|
||||
}
|
||||
|
||||
if !lib.StreamLargeResponseBody(ctx, bifrostCtx) {
|
||||
ctx.SetStatusCode(fasthttp.StatusInternalServerError)
|
||||
ctx.SetBodyString("large response reader not available")
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// extractAndParseFallbacks extracts fallbacks from the integration request and adds them to the BifrostRequest
|
||||
func (g *GenericRouter) extractAndParseFallbacks(req interface{}, bifrostReq *schemas.BifrostRequest) error {
|
||||
// Check if the request has a fallbacks field ([]string)
|
||||
fallbacks, err := g.extractFallbacksFromRequest(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to extract fallbacks: %w", err)
|
||||
}
|
||||
|
||||
if len(fallbacks) == 0 {
|
||||
return nil // No fallbacks to process
|
||||
}
|
||||
|
||||
provider, _, _ := bifrostReq.GetRequestFields()
|
||||
|
||||
// Parse fallbacks from strings to Fallback structs
|
||||
parsedFallbacks := make([]schemas.Fallback, 0, len(fallbacks))
|
||||
for _, fallbackStr := range fallbacks {
|
||||
if fallbackStr == "" {
|
||||
continue // Skip empty strings
|
||||
}
|
||||
|
||||
// Use ParseModelString to extract provider and model
|
||||
provider, model := schemas.ParseModelString(fallbackStr, provider)
|
||||
|
||||
parsedFallback := schemas.Fallback{
|
||||
Provider: provider,
|
||||
Model: model,
|
||||
}
|
||||
parsedFallbacks = append(parsedFallbacks, parsedFallback)
|
||||
}
|
||||
|
||||
if len(parsedFallbacks) == 0 {
|
||||
return nil // No valid fallbacks found
|
||||
}
|
||||
|
||||
// Add fallbacks to the main BifrostRequest
|
||||
bifrostReq.SetFallbacks(parsedFallbacks)
|
||||
|
||||
// Also add fallbacks to the specific request type if it exists
|
||||
switch bifrostReq.RequestType {
|
||||
case schemas.TextCompletionRequest, schemas.TextCompletionStreamRequest:
|
||||
if bifrostReq.TextCompletionRequest != nil {
|
||||
bifrostReq.TextCompletionRequest.Fallbacks = parsedFallbacks
|
||||
}
|
||||
case schemas.ChatCompletionRequest, schemas.ChatCompletionStreamRequest:
|
||||
if bifrostReq.ChatRequest != nil {
|
||||
bifrostReq.ChatRequest.Fallbacks = parsedFallbacks
|
||||
}
|
||||
case schemas.ResponsesRequest, schemas.ResponsesStreamRequest:
|
||||
if bifrostReq.ResponsesRequest != nil {
|
||||
bifrostReq.ResponsesRequest.Fallbacks = parsedFallbacks
|
||||
}
|
||||
case schemas.EmbeddingRequest:
|
||||
if bifrostReq.EmbeddingRequest != nil {
|
||||
bifrostReq.EmbeddingRequest.Fallbacks = parsedFallbacks
|
||||
}
|
||||
case schemas.RerankRequest:
|
||||
if bifrostReq.RerankRequest != nil {
|
||||
bifrostReq.RerankRequest.Fallbacks = parsedFallbacks
|
||||
}
|
||||
case schemas.SpeechRequest, schemas.SpeechStreamRequest:
|
||||
if bifrostReq.SpeechRequest != nil {
|
||||
bifrostReq.SpeechRequest.Fallbacks = parsedFallbacks
|
||||
}
|
||||
case schemas.TranscriptionRequest, schemas.TranscriptionStreamRequest:
|
||||
if bifrostReq.TranscriptionRequest != nil {
|
||||
bifrostReq.TranscriptionRequest.Fallbacks = parsedFallbacks
|
||||
}
|
||||
case schemas.ImageGenerationRequest, schemas.ImageGenerationStreamRequest:
|
||||
if bifrostReq.ImageGenerationRequest != nil {
|
||||
bifrostReq.ImageGenerationRequest.Fallbacks = parsedFallbacks
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// extractFallbacksFromRequest uses reflection to extract fallbacks field from any request type
|
||||
func (g *GenericRouter) extractFallbacksFromRequest(req interface{}) ([]string, error) {
|
||||
if req == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Try to use reflection to find a "fallbacks" field
|
||||
reqValue := reflect.ValueOf(req)
|
||||
if reqValue.Kind() == reflect.Ptr {
|
||||
reqValue = reqValue.Elem()
|
||||
}
|
||||
|
||||
if reqValue.Kind() != reflect.Struct {
|
||||
return nil, nil // Not a struct, no fallbacks
|
||||
}
|
||||
|
||||
// Look for the "fallbacks" field
|
||||
fallbacksField := reqValue.FieldByName("fallbacks")
|
||||
if !fallbacksField.IsValid() {
|
||||
return nil, nil // No fallbacks field found
|
||||
}
|
||||
|
||||
// Handle different types of fallbacks field
|
||||
switch fallbacksField.Kind() {
|
||||
case reflect.Slice:
|
||||
if fallbacksField.Type().Elem().Kind() == reflect.String {
|
||||
// []string case
|
||||
fallbacks := make([]string, fallbacksField.Len())
|
||||
for i := 0; i < fallbacksField.Len(); i++ {
|
||||
fallbacks[i] = fallbacksField.Index(i).String()
|
||||
}
|
||||
return fallbacks, nil
|
||||
}
|
||||
case reflect.String:
|
||||
// Single string case - treat as one fallback
|
||||
return []string{fallbacksField.String()}, nil
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// getVirtualKeyFromBifrostContext extracts the virtual key value from bifrost context.
|
||||
// Returns nil if no VK is present (e.g., direct key mode or no governance).
|
||||
func getVirtualKeyFromBifrostContext(ctx *schemas.BifrostContext) *string {
|
||||
vkValue := bifrost.GetStringFromContext(ctx, schemas.BifrostContextKeyVirtualKey)
|
||||
if vkValue == "" {
|
||||
return nil
|
||||
}
|
||||
return &vkValue
|
||||
}
|
||||
|
||||
// getResultTTLFromHeaderWithDefault extracts the result TTL from the x-bf-async-job-result-ttl header.
|
||||
// Returns the default TTL if the header is not present or invalid.
|
||||
func getResultTTLFromHeaderWithDefault(ctx *fasthttp.RequestCtx, defaultTTL int) int {
|
||||
resultTTL := string(ctx.Request.Header.Peek(schemas.AsyncHeaderResultTTL))
|
||||
if resultTTL == "" {
|
||||
return defaultTTL
|
||||
}
|
||||
resultTTLInt, err := strconv.Atoi(resultTTL)
|
||||
if err != nil || resultTTLInt < 0 {
|
||||
return defaultTTL
|
||||
}
|
||||
return resultTTLInt
|
||||
}
|
||||
|
||||
// isAnthropicAPIKeyAuth checks if the request uses standard API key authentication.
|
||||
// Returns true for API key auth (x-api-key header), false for OAuth (Bearer sk-ant-oat*).
|
||||
// This is required for Claude Code specifically, which may use OAuth authentication.
|
||||
// Default behavior is to assume API mode when neither x-api-key nor OAuth token is present.
|
||||
func isAnthropicAPIKeyAuth(ctx *fasthttp.RequestCtx) bool {
|
||||
// If x-api-key header is present - this is definitely API mode
|
||||
if apiKey := string(ctx.Request.Header.Peek("x-api-key")); apiKey != "" {
|
||||
return true
|
||||
}
|
||||
// Check for OAuth token in Authorization header
|
||||
if authHeader := string(ctx.Request.Header.Peek("Authorization")); authHeader != "" {
|
||||
if strings.HasPrefix(strings.ToLower(authHeader), "bearer sk-ant-oat") {
|
||||
return false // OAuth mode, NOT API
|
||||
}
|
||||
}
|
||||
// Default to API mode
|
||||
return true
|
||||
}
|
||||
|
||||
// resolveLargePayloadMetadata returns metadata from the sync context key,
|
||||
// falling back to a non-blocking read from the deferred channel.
|
||||
// If deferred metadata is resolved, it is cached in the sync key for later readers.
|
||||
func resolveLargePayloadMetadata(bifrostCtx *schemas.BifrostContext) *schemas.LargePayloadMetadata {
|
||||
if bifrostCtx == nil {
|
||||
return nil
|
||||
}
|
||||
if metadata, ok := bifrostCtx.Value(schemas.BifrostContextKeyLargePayloadMetadata).(*schemas.LargePayloadMetadata); ok && metadata != nil {
|
||||
return metadata
|
||||
}
|
||||
ch, ok := bifrostCtx.Value(schemas.BifrostContextKeyDeferredLargePayloadMetadata).(<-chan *schemas.LargePayloadMetadata)
|
||||
if !ok || ch == nil {
|
||||
return nil
|
||||
}
|
||||
select {
|
||||
case metadata := <-ch:
|
||||
if metadata != nil {
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKeyLargePayloadMetadata, metadata)
|
||||
}
|
||||
return metadata
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// ParseProviderScopedVideoID parses a provider-scoped video ID in the form "id:provider".
|
||||
// The ID portion is automatically URL-decoded to restore the original ID.
|
||||
func ParseProviderScopedVideoID(videoID string) (schemas.ModelProvider, string, error) {
|
||||
parts := strings.SplitN(videoID, ":", 2)
|
||||
if len(parts) != 2 || parts[0] == "" || parts[1] == "" {
|
||||
return "", "", fmt.Errorf("video_id must be in id:provider format")
|
||||
}
|
||||
provider := schemas.ModelProvider(parts[1])
|
||||
rawID := parts[0]
|
||||
|
||||
// URL decode the ID to restore original characters (e.g., %2F -> /)
|
||||
// This handles IDs from all providers that may contain special characters
|
||||
if decoded, err := url.PathUnescape(rawID); err == nil {
|
||||
rawID = decoded
|
||||
}
|
||||
|
||||
return provider, rawID, nil
|
||||
}
|
||||
|
||||
func getProviderFromHeader(ctx *fasthttp.RequestCtx, defaultProvider schemas.ModelProvider) schemas.ModelProvider {
|
||||
providerHeader := string(ctx.Request.Header.Peek("x-model-provider"))
|
||||
if providerHeader == "" {
|
||||
return defaultProvider
|
||||
}
|
||||
return schemas.ModelProvider(providerHeader)
|
||||
}
|
||||
|
||||
func RegisterKVDecoders(store *kvstore.Store) {
|
||||
store.RegisterDecoder("genai_upload_session:", func(data []byte) (any, error) {
|
||||
var v gemini.GeminiResumableUploadSession
|
||||
return &v, sonic.Unmarshal(data, &v)
|
||||
})
|
||||
}
|
||||
277
transports/bifrost-http/integrations/utils_test.go
Normal file
277
transports/bifrost-http/integrations/utils_test.go
Normal file
@@ -0,0 +1,277 @@
|
||||
package integrations
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/bytedance/sonic"
|
||||
"github.com/maximhq/bifrost/core/providers/anthropic"
|
||||
"github.com/maximhq/bifrost/core/providers/bedrock"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// testLogger implements schemas.Logger for testing (all no-ops)
|
||||
type testLogger struct{}
|
||||
|
||||
func (t *testLogger) Debug(msg string, args ...any) {}
|
||||
func (t *testLogger) Info(msg string, args ...any) {}
|
||||
func (t *testLogger) Warn(msg string, args ...any) {}
|
||||
func (t *testLogger) Error(msg string, args ...any) {}
|
||||
func (t *testLogger) Fatal(msg string, args ...any) {}
|
||||
func (t *testLogger) SetLevel(level schemas.LogLevel) {}
|
||||
func (t *testLogger) SetOutputType(outputType schemas.LoggerOutputType) {}
|
||||
func (t *testLogger) LogHTTPRequest(level schemas.LogLevel, msg string) schemas.LogEventBuilder {
|
||||
return schemas.NoopLogEvent
|
||||
}
|
||||
|
||||
var _ schemas.Logger = (*testLogger)(nil)
|
||||
|
||||
// ptr returns a pointer to a fresh copy of i.
func ptr(i int) *int {
	p := new(int)
	*p = i
	return p
}
|
||||
|
||||
// strPtr returns a pointer to a fresh copy of s.
func strPtr(s string) *string {
	p := new(string)
	*p = s
	return p
}
|
||||
|
||||
func newTestGenericRouter() *GenericRouter {
|
||||
return NewGenericRouter(nil, &mockHandlerStore{}, nil, nil, &testLogger{})
|
||||
}
|
||||
|
||||
func newTestBifrostContext() *schemas.BifrostContext {
|
||||
return schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
|
||||
}
|
||||
|
||||
// TestSendStreamError_PropagatesProviderStatusCode verifies that sendStreamError
|
||||
// sets the HTTP status code from the provider's BifrostError.StatusCode field.
|
||||
// All three providers (OpenAI, Anthropic, Bedrock) return actual HTTP error codes
|
||||
// for pre-stream errors, so Bifrost must propagate them faithfully.
|
||||
func TestSendStreamError_PropagatesProviderStatusCode(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
statusCode *int
|
||||
expectedStatusCode int
|
||||
}{
|
||||
{
|
||||
name: "provider 400 - Bedrock ValidationException / OpenAI invalid_request_error",
|
||||
statusCode: ptr(400),
|
||||
expectedStatusCode: 400,
|
||||
},
|
||||
{
|
||||
name: "provider 429 - rate limiting (all providers)",
|
||||
statusCode: ptr(429),
|
||||
expectedStatusCode: 429,
|
||||
},
|
||||
{
|
||||
name: "provider 503 - Bedrock ServiceUnavailableException",
|
||||
statusCode: ptr(503),
|
||||
expectedStatusCode: 503,
|
||||
},
|
||||
{
|
||||
name: "provider 529 - Anthropic overloaded_error",
|
||||
statusCode: ptr(529),
|
||||
expectedStatusCode: 529,
|
||||
},
|
||||
{
|
||||
name: "nil StatusCode defaults to 500",
|
||||
statusCode: nil,
|
||||
expectedStatusCode: 500,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
router := newTestGenericRouter()
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
bifrostCtx := newTestBifrostContext()
|
||||
|
||||
bifrostErr := &schemas.BifrostError{
|
||||
StatusCode: tt.statusCode,
|
||||
Error: &schemas.ErrorField{
|
||||
Message: "test error",
|
||||
},
|
||||
}
|
||||
|
||||
config := RouteConfig{
|
||||
ErrorConverter: func(ctx *schemas.BifrostContext, err *schemas.BifrostError) interface{} {
|
||||
return err
|
||||
},
|
||||
}
|
||||
|
||||
router.sendStreamError(ctx, bifrostCtx, config, bifrostErr)
|
||||
|
||||
assert.Equal(t, tt.expectedStatusCode, ctx.Response.StatusCode())
|
||||
assert.Equal(t, "application/json", string(ctx.Response.Header.ContentType()))
|
||||
|
||||
body := string(ctx.Response.Body())
|
||||
assert.True(t, sonic.Valid(ctx.Response.Body()), "response body should be valid JSON, got: %s", body)
|
||||
assert.False(t, strings.HasPrefix(body, "data: "), "response should not be SSE format")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestSendStreamError_OpenAIErrorFormat verifies the response body matches the
|
||||
// OpenAI error format. OpenAI's ErrorConverter returns *schemas.BifrostError directly,
|
||||
// which serializes to {"is_bifrost_error":false,"status_code":400,"error":{...}}.
|
||||
func TestSendStreamError_OpenAIErrorFormat(t *testing.T) {
|
||||
router := newTestGenericRouter()
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
bifrostCtx := newTestBifrostContext()
|
||||
|
||||
bifrostErr := &schemas.BifrostError{
|
||||
IsBifrostError: false,
|
||||
StatusCode: ptr(400),
|
||||
Error: &schemas.ErrorField{
|
||||
Type: strPtr("invalid_request_error"),
|
||||
Message: "content is empty",
|
||||
},
|
||||
}
|
||||
|
||||
config := RouteConfig{
|
||||
ErrorConverter: func(ctx *schemas.BifrostContext, err *schemas.BifrostError) interface{} {
|
||||
return err
|
||||
},
|
||||
}
|
||||
|
||||
router.sendStreamError(ctx, bifrostCtx, config, bifrostErr)
|
||||
|
||||
assert.Equal(t, 400, ctx.Response.StatusCode())
|
||||
|
||||
// Unmarshal and verify the structure
|
||||
var result map[string]interface{}
|
||||
err := sonic.Unmarshal(ctx.Response.Body(), &result)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Contains(t, result, "is_bifrost_error")
|
||||
assert.Contains(t, result, "status_code")
|
||||
assert.Contains(t, result, "error")
|
||||
assert.Equal(t, false, result["is_bifrost_error"])
|
||||
|
||||
errorObj, ok := result["error"].(map[string]interface{})
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "invalid_request_error", errorObj["type"])
|
||||
assert.Equal(t, "content is empty", errorObj["message"])
|
||||
}
|
||||
|
||||
// TestSendStreamError_AnthropicErrorFormat verifies the response body matches the
|
||||
// Anthropic error format: {"type":"error","error":{"type":"...","message":"..."}}.
|
||||
// Critically, it also verifies that the StreamConfig.ErrorConverter (which returns
|
||||
// raw SSE strings) is NOT used — sendStreamError must use the route-level ErrorConverter.
|
||||
func TestSendStreamError_AnthropicErrorFormat(t *testing.T) {
|
||||
router := newTestGenericRouter()
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
bifrostCtx := newTestBifrostContext()
|
||||
|
||||
bifrostErr := &schemas.BifrostError{
|
||||
StatusCode: ptr(429),
|
||||
Error: &schemas.ErrorField{
|
||||
Type: strPtr("rate_limit_error"),
|
||||
Message: "rate limited",
|
||||
},
|
||||
}
|
||||
|
||||
config := RouteConfig{
|
||||
// Route-level: returns JSON-marshallable *AnthropicMessageError
|
||||
ErrorConverter: func(ctx *schemas.BifrostContext, err *schemas.BifrostError) interface{} {
|
||||
return anthropic.ToAnthropicChatCompletionError(err)
|
||||
},
|
||||
// Stream-level: returns raw SSE string — should NOT be used by sendStreamError
|
||||
StreamConfig: &StreamConfig{
|
||||
ErrorConverter: func(ctx *schemas.BifrostContext, err *schemas.BifrostError) interface{} {
|
||||
return anthropic.ToAnthropicResponsesStreamError(err)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
router.sendStreamError(ctx, bifrostCtx, config, bifrostErr)
|
||||
|
||||
assert.Equal(t, 429, ctx.Response.StatusCode())
|
||||
assert.Equal(t, "application/json", string(ctx.Response.Header.ContentType()))
|
||||
|
||||
body := string(ctx.Response.Body())
|
||||
|
||||
// Must NOT contain SSE markers — that would mean StreamConfig.ErrorConverter was used
|
||||
assert.NotContains(t, body, "event: error", "response should not contain SSE event markers")
|
||||
|
||||
// Unmarshal and verify Anthropic error structure
|
||||
var result anthropic.AnthropicMessageError
|
||||
err := sonic.Unmarshal(ctx.Response.Body(), &result)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "error", result.Type)
|
||||
assert.Equal(t, "rate_limit_error", result.Error.Type)
|
||||
assert.Equal(t, "rate limited", result.Error.Message)
|
||||
}
|
||||
|
||||
// TestSendStreamError_BedrockErrorFormat verifies the response body matches the
|
||||
// Bedrock error format: {"__type":"ValidationException","message":"..."}.
|
||||
func TestSendStreamError_BedrockErrorFormat(t *testing.T) {
|
||||
router := newTestGenericRouter()
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
bifrostCtx := newTestBifrostContext()
|
||||
|
||||
bifrostErr := &schemas.BifrostError{
|
||||
StatusCode: ptr(400),
|
||||
Error: &schemas.ErrorField{
|
||||
Code: strPtr("ValidationException"),
|
||||
Message: "validation error",
|
||||
},
|
||||
}
|
||||
|
||||
config := RouteConfig{
|
||||
ErrorConverter: func(ctx *schemas.BifrostContext, err *schemas.BifrostError) interface{} {
|
||||
return bedrock.ToBedrockError(err)
|
||||
},
|
||||
}
|
||||
|
||||
router.sendStreamError(ctx, bifrostCtx, config, bifrostErr)
|
||||
|
||||
assert.Equal(t, 400, ctx.Response.StatusCode())
|
||||
|
||||
// Unmarshal and verify Bedrock error structure
|
||||
var result bedrock.BedrockError
|
||||
err := sonic.Unmarshal(ctx.Response.Body(), &result)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "ValidationException", result.Type)
|
||||
assert.Equal(t, "validation error", result.Message)
|
||||
}
|
||||
|
||||
// TestSendStreamError_ForwardsProviderHeaders verifies that provider response headers
|
||||
// stored in the BifrostContext are forwarded to the HTTP response. This ensures
|
||||
// clients receive provider-specific headers (e.g., x-amzn-requestid for Bedrock,
|
||||
// x-request-id for Anthropic) even in error scenarios.
|
||||
func TestSendStreamError_ForwardsProviderHeaders(t *testing.T) {
|
||||
router := newTestGenericRouter()
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
bifrostCtx := newTestBifrostContext()
|
||||
|
||||
// Set provider response headers on the context
|
||||
bifrostCtx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, map[string]string{
|
||||
"x-amzn-requestid": "req-123",
|
||||
"x-amzn-errortype": "ValidationException",
|
||||
})
|
||||
|
||||
bifrostErr := &schemas.BifrostError{
|
||||
StatusCode: ptr(400),
|
||||
Error: &schemas.ErrorField{
|
||||
Message: "validation error",
|
||||
},
|
||||
}
|
||||
|
||||
config := RouteConfig{
|
||||
ErrorConverter: func(ctx *schemas.BifrostContext, err *schemas.BifrostError) interface{} {
|
||||
return err
|
||||
},
|
||||
}
|
||||
|
||||
router.sendStreamError(ctx, bifrostCtx, config, bifrostErr)
|
||||
|
||||
assert.Equal(t, 400, ctx.Response.StatusCode())
|
||||
assert.Equal(t, "req-123", string(ctx.Response.Header.Peek("x-amzn-requestid")))
|
||||
assert.Equal(t, "ValidationException", string(ctx.Response.Header.Peek("x-amzn-errortype")))
|
||||
}
|
||||
Reference in New Issue
Block a user