first commit

This commit is contained in:
Beyhan Oğur
2026-04-26 21:52:23 +03:00
commit 880f412e2c
2662 changed files with 866266 additions and 0 deletions

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,87 @@
package integrations
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestFilterVertexUnsupportedBetaHeaders(t *testing.T) {
t.Run("filters known exact header values", func(t *testing.T) {
headers := map[string][]string{
"anthropic-beta": {"advanced-tool-use-2025-11-20,structured-outputs-2025-11-13,mcp-client-2025-04-04,prompt-caching-scope-2026-01-05"},
}
result := filterVertexUnsupportedBetaHeaders(headers)
_, ok := result["anthropic-beta"]
assert.False(t, ok, "all unsupported beta headers should be removed, leaving no anthropic-beta key")
})
t.Run("filters bumped date variants", func(t *testing.T) {
// Simulate Anthropic bumping version dates in the future
headers := map[string][]string{
"anthropic-beta": {"structured-outputs-2025-12-15,advanced-tool-use-2026-03-01,mcp-client-2026-01-01,prompt-caching-scope-2027-06-30"},
}
result := filterVertexUnsupportedBetaHeaders(headers)
_, ok := result["anthropic-beta"]
assert.False(t, ok, "bumped-date variants of unsupported headers should also be filtered")
})
t.Run("passes through unrelated beta headers", func(t *testing.T) {
headers := map[string][]string{
"anthropic-beta": {"interleaved-thinking-2025-05-14,files-api-2025-04-14"},
}
result := filterVertexUnsupportedBetaHeaders(headers)
vals, ok := result["anthropic-beta"]
assert.True(t, ok, "unrelated beta headers should be preserved")
assert.Equal(t, []string{"interleaved-thinking-2025-05-14,files-api-2025-04-14"}, vals)
})
t.Run("filters unsupported and keeps supported in mixed list", func(t *testing.T) {
headers := map[string][]string{
"anthropic-beta": {"interleaved-thinking-2025-05-14,structured-outputs-2025-11-13,files-api-2025-04-14,mcp-client-2025-04-04"},
}
result := filterVertexUnsupportedBetaHeaders(headers)
vals, ok := result["anthropic-beta"]
assert.True(t, ok, "supported beta headers should be preserved")
assert.Equal(t, []string{"interleaved-thinking-2025-05-14,files-api-2025-04-14"}, vals)
})
t.Run("filters bumped unsupported mixed with supported", func(t *testing.T) {
// Future-proof: bumped dates should still be filtered
headers := map[string][]string{
"anthropic-beta": {"structured-outputs-2026-01-01,interleaved-thinking-2025-05-14,advanced-tool-use-2026-06-15"},
}
result := filterVertexUnsupportedBetaHeaders(headers)
vals, ok := result["anthropic-beta"]
assert.True(t, ok, "supported beta headers should be preserved even when mixed with bumped unsupported ones")
assert.Equal(t, []string{"interleaved-thinking-2025-05-14"}, vals)
})
t.Run("returns headers unchanged when no anthropic-beta key present", func(t *testing.T) {
headers := map[string][]string{
"content-type": {"application/json"},
}
result := filterVertexUnsupportedBetaHeaders(headers)
assert.Equal(t, headers, result)
})
t.Run("handles empty anthropic-beta value gracefully", func(t *testing.T) {
headers := map[string][]string{
"anthropic-beta": {""},
}
result := filterVertexUnsupportedBetaHeaders(headers)
// Empty string after trimming is not an unsupported header, but it is also empty — key should be removed
_, ok := result["anthropic-beta"]
assert.False(t, ok, "empty beta header list should result in key removal")
})
t.Run("case-insensitive key matching for Anthropic-Beta header", func(t *testing.T) {
headers := map[string][]string{
"Anthropic-Beta": {"structured-outputs-2025-11-13,interleaved-thinking-2025-05-14"},
}
result := filterVertexUnsupportedBetaHeaders(headers)
vals, ok := result["Anthropic-Beta"]
assert.True(t, ok, "header key casing should be preserved and matching should be case-insensitive")
assert.Equal(t, []string{"interleaved-thinking-2025-05-14"}, vals)
})
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,921 @@
package integrations
import (
"context"
"testing"
"github.com/maximhq/bifrost/core/providers/bedrock"
"github.com/maximhq/bifrost/core/schemas"
"github.com/maximhq/bifrost/framework/kvstore"
"github.com/maximhq/bifrost/framework/logstore"
"github.com/maximhq/bifrost/transports/bifrost-http/lib"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/valyala/fasthttp"
)
// mockHandlerStore implements lib.HandlerStore for testing
type mockHandlerStore struct {
allowDirectKeys bool
headerMatcher *lib.HeaderMatcher
availableProviders []schemas.ModelProvider
mcpHeaderCombinedAllowlist schemas.WhiteList
}
func (m *mockHandlerStore) ShouldAllowDirectKeys() bool {
return m.allowDirectKeys
}
func (m *mockHandlerStore) GetHeaderMatcher() *lib.HeaderMatcher {
return m.headerMatcher
}
func (m *mockHandlerStore) GetAvailableProviders() []schemas.ModelProvider {
return m.availableProviders
}
func (m *mockHandlerStore) GetStreamChunkInterceptor() lib.StreamChunkInterceptor {
return nil
}
func (m *mockHandlerStore) GetAsyncJobExecutor() *logstore.AsyncJobExecutor {
return nil
}
func (m *mockHandlerStore) GetAsyncJobResultTTL() int {
return 3600
}
func (m *mockHandlerStore) GetKVStore() *kvstore.Store {
return nil
}
func (m *mockHandlerStore) GetMCPHeaderCombinedAllowlist() schemas.WhiteList {
return m.mcpHeaderCombinedAllowlist
}
// Ensure mockHandlerStore implements lib.HandlerStore
var _ lib.HandlerStore = (*mockHandlerStore)(nil)
func Test_parseS3URI(t *testing.T) {
tests := []struct {
name string
uri string
wantBucket string
wantKey string
}{
{
name: "full S3 URI with key",
uri: "s3://my-bucket/path/to/file.jsonl",
wantBucket: "my-bucket",
wantKey: "path/to/file.jsonl",
},
{
name: "S3 URI with bucket only",
uri: "s3://my-bucket/",
wantBucket: "my-bucket",
wantKey: "",
},
{
name: "S3 URI with bucket no trailing slash",
uri: "s3://my-bucket",
wantBucket: "my-bucket",
wantKey: "",
},
{
name: "plain bucket name",
uri: "my-bucket",
wantBucket: "my-bucket",
wantKey: "",
},
{
name: "S3 URI with nested key",
uri: "s3://bucket-name/folder1/folder2/file.txt",
wantBucket: "bucket-name",
wantKey: "folder1/folder2/file.txt",
},
{
name: "empty string",
uri: "",
wantBucket: "",
wantKey: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotBucket, gotKey := parseS3URI(tt.uri)
assert.Equal(t, tt.wantBucket, gotBucket, "bucket mismatch")
assert.Equal(t, tt.wantKey, gotKey, "key mismatch")
})
}
}
func Test_createBedrockRouteConfigs(t *testing.T) {
handlerStore := &mockHandlerStore{allowDirectKeys: true}
routes := CreateBedrockRouteConfigs("/bedrock", handlerStore)
assert.Len(t, routes, 6, "should have 6 bedrock routes")
expectedRoutes := []struct {
path string
method string
}{
{"/bedrock/model/{modelId}/converse", "POST"},
{"/bedrock/model/{modelId}/converse-stream", "POST"},
{"/bedrock/model/{modelId}/invoke-with-response-stream", "POST"},
{"/bedrock/model/{modelId}/invoke", "POST"},
{"/bedrock/rerank", "POST"},
{"/bedrock/model/{modelId}/count-tokens", "POST"},
}
for i, expected := range expectedRoutes {
assert.Equal(t, expected.path, routes[i].Path, "route %d path mismatch", i)
assert.Equal(t, expected.method, routes[i].Method, "route %d method mismatch", i)
assert.Equal(t, RouteConfigTypeBedrock, routes[i].Type, "route %d type mismatch", i)
assert.NotNil(t, routes[i].GetRequestTypeInstance, "route %d GetRequestTypeInstance should not be nil", i)
assert.NotNil(t, routes[i].ErrorConverter, "route %d ErrorConverter should not be nil", i)
}
}
func Test_createBedrockConverseRouteConfig(t *testing.T) {
handlerStore := &mockHandlerStore{allowDirectKeys: true}
route := createBedrockConverseRouteConfig("/bedrock", handlerStore)
assert.Equal(t, "/bedrock/model/{modelId}/converse", route.Path)
assert.Equal(t, "POST", route.Method)
assert.Equal(t, RouteConfigTypeBedrock, route.Type)
assert.NotNil(t, route.GetRequestTypeInstance)
assert.NotNil(t, route.RequestConverter)
assert.NotNil(t, route.ResponsesResponseConverter)
assert.NotNil(t, route.ErrorConverter)
assert.NotNil(t, route.PreCallback)
// Verify request instance type
reqInstance := route.GetRequestTypeInstance(context.Background())
_, ok := reqInstance.(*bedrock.BedrockConverseRequest)
assert.True(t, ok, "GetRequestTypeInstance should return *bedrock.BedrockConverseRequest")
}
func Test_createBedrockConverseStreamRouteConfig(t *testing.T) {
handlerStore := &mockHandlerStore{allowDirectKeys: true}
route := createBedrockConverseStreamRouteConfig("/bedrock", handlerStore)
assert.Equal(t, "/bedrock/model/{modelId}/converse-stream", route.Path)
assert.Equal(t, "POST", route.Method)
assert.Equal(t, RouteConfigTypeBedrock, route.Type)
assert.NotNil(t, route.StreamConfig)
assert.NotNil(t, route.StreamConfig.ResponsesStreamResponseConverter)
// Verify request instance type
reqInstance := route.GetRequestTypeInstance(context.Background())
_, ok := reqInstance.(*bedrock.BedrockConverseRequest)
assert.True(t, ok, "GetRequestTypeInstance should return *bedrock.BedrockConverseRequest")
}
func Test_createBedrockInvokeRouteConfig(t *testing.T) {
handlerStore := &mockHandlerStore{allowDirectKeys: true}
route := createBedrockInvokeRouteConfig("/bedrock", handlerStore)
assert.Equal(t, "/bedrock/model/{modelId}/invoke", route.Path)
assert.Equal(t, "POST", route.Method)
assert.Equal(t, RouteConfigTypeBedrock, route.Type)
assert.NotNil(t, route.TextResponseConverter)
assert.NotNil(t, route.ResponsesResponseConverter)
// Verify request instance type
reqInstance := route.GetRequestTypeInstance(context.Background())
_, ok := reqInstance.(*bedrock.BedrockInvokeRequest)
assert.True(t, ok, "GetRequestTypeInstance should return *bedrock.BedrockInvokeRequest")
}
func Test_createBedrockInvokeWithResponseStreamRouteConfig(t *testing.T) {
handlerStore := &mockHandlerStore{allowDirectKeys: true}
route := createBedrockInvokeWithResponseStreamRouteConfig("/bedrock", handlerStore)
assert.Equal(t, "/bedrock/model/{modelId}/invoke-with-response-stream", route.Path)
assert.Equal(t, "POST", route.Method)
assert.Equal(t, RouteConfigTypeBedrock, route.Type)
assert.NotNil(t, route.StreamConfig)
assert.NotNil(t, route.StreamConfig.TextStreamResponseConverter)
assert.NotNil(t, route.StreamConfig.ResponsesStreamResponseConverter)
// Verify request instance type
reqInstance := route.GetRequestTypeInstance(context.Background())
_, ok := reqInstance.(*bedrock.BedrockInvokeRequest)
assert.True(t, ok, "GetRequestTypeInstance should return *bedrock.BedrockInvokeRequest")
}
func Test_createBedrockRerankRouteConfig(t *testing.T) {
handlerStore := &mockHandlerStore{allowDirectKeys: true}
route := createBedrockRerankRouteConfig("/bedrock", handlerStore)
assert.Equal(t, "/bedrock/rerank", route.Path)
assert.Equal(t, "POST", route.Method)
assert.Equal(t, RouteConfigTypeBedrock, route.Type)
assert.NotNil(t, route.GetHTTPRequestType)
assert.Equal(t, schemas.RerankRequest, route.GetHTTPRequestType(nil))
assert.NotNil(t, route.GetRequestTypeInstance)
assert.NotNil(t, route.RequestConverter)
assert.NotNil(t, route.RerankResponseConverter)
assert.NotNil(t, route.ErrorConverter)
assert.NotNil(t, route.PreCallback)
// Verify request instance type
reqInstance := route.GetRequestTypeInstance(context.Background())
_, ok := reqInstance.(*bedrock.BedrockRerankRequest)
assert.True(t, ok, "GetRequestTypeInstance should return *bedrock.BedrockRerankRequest")
}
func Test_createBedrockRerankResponseConverterUsesRawResponse(t *testing.T) {
handlerStore := &mockHandlerStore{allowDirectKeys: true}
route := createBedrockRerankRouteConfig("/bedrock", handlerStore)
require.NotNil(t, route.RerankResponseConverter)
raw := map[string]interface{}{"results": []interface{}{}}
resp := &schemas.BifrostRerankResponse{
ExtraFields: schemas.BifrostResponseExtraFields{
Provider: schemas.Bedrock,
RawResponse: raw,
},
}
converted, err := route.RerankResponseConverter(nil, resp)
require.NoError(t, err)
assert.Equal(t, raw, converted)
}
func Test_createBedrockRerankRouteRequestConverter(t *testing.T) {
handlerStore := &mockHandlerStore{allowDirectKeys: true}
route := createBedrockRerankRouteConfig("/bedrock", handlerStore)
require.NotNil(t, route.RequestConverter)
topN := 1
req := &bedrock.BedrockRerankRequest{
Queries: []bedrock.BedrockRerankQuery{
{
Type: "TEXT",
TextQuery: bedrock.BedrockRerankTextRef{Text: "capital of france"},
},
},
Sources: []bedrock.BedrockRerankSource{
{
Type: "INLINE",
InlineDocumentSource: bedrock.BedrockRerankInlineSource{
Type: "TEXT",
TextDocument: bedrock.BedrockRerankTextValue{Text: "Paris is capital of France"},
},
},
},
RerankingConfiguration: bedrock.BedrockRerankingConfiguration{
Type: "BEDROCK_RERANKING_MODEL",
BedrockRerankingConfiguration: bedrock.BedrockRerankingModelConfiguration{
NumberOfResults: &topN,
ModelConfiguration: bedrock.BedrockRerankModelConfiguration{
ModelARN: "arn:aws:bedrock:us-east-1::foundation-model/cohere.rerank-v3-5:0",
},
},
},
}
bifrostCtx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
bifrostReq, err := route.RequestConverter(bifrostCtx, req)
require.NoError(t, err)
require.NotNil(t, bifrostReq)
require.NotNil(t, bifrostReq.RerankRequest)
assert.Equal(t, schemas.Bedrock, bifrostReq.RerankRequest.Provider)
assert.Equal(t, "capital of france", bifrostReq.RerankRequest.Query)
require.Len(t, bifrostReq.RerankRequest.Documents, 1)
assert.Equal(t, "Paris is capital of France", bifrostReq.RerankRequest.Documents[0].Text)
require.NotNil(t, bifrostReq.RerankRequest.Params)
require.NotNil(t, bifrostReq.RerankRequest.Params.TopN)
assert.Equal(t, 1, *bifrostReq.RerankRequest.Params.TopN)
}
func Test_createBedrockRouteConfigsIncludesRerankForCompositePrefixes(t *testing.T) {
handlerStore := &mockHandlerStore{allowDirectKeys: true}
prefixes := []string{"/litellm", "/langchain", "/pydanticai"}
for _, prefix := range prefixes {
routes := CreateBedrockRouteConfigs(prefix, handlerStore)
found := false
for _, route := range routes {
if route.Path == prefix+"/rerank" && route.Method == "POST" {
found = true
break
}
}
assert.Truef(t, found, "expected rerank route for prefix %s", prefix)
}
}
func Test_createBedrockBatchRouteConfigs(t *testing.T) {
handlerStore := &mockHandlerStore{allowDirectKeys: true}
routes := createBedrockBatchRouteConfigs("/bedrock", handlerStore)
assert.Len(t, routes, 4, "should have 4 batch routes")
expectedRoutes := []struct {
path string
method string
}{
{"/bedrock/model-invocation-job", "POST"},
{"/bedrock/model-invocation-jobs", "GET"},
{"/bedrock/model-invocation-job/{job_arn}", "GET"},
{"/bedrock/model-invocation-job/{job_arn}/stop", "POST"},
}
for i, expected := range expectedRoutes {
assert.Equal(t, expected.path, routes[i].Path, "batch route %d path mismatch", i)
assert.Equal(t, expected.method, routes[i].Method, "batch route %d method mismatch", i)
assert.Equal(t, RouteConfigTypeBedrock, routes[i].Type, "batch route %d type mismatch", i)
assert.NotNil(t, routes[i].GetRequestTypeInstance, "batch route %d GetRequestTypeInstance should not be nil", i)
assert.NotNil(t, routes[i].BatchRequestConverter, "batch route %d BatchCreateRequestConverter should not be nil", i)
assert.NotNil(t, routes[i].ErrorConverter, "batch route %d ErrorConverter should not be nil", i)
assert.NotNil(t, routes[i].PreCallback, "batch route %d PreCallback should not be nil", i)
}
}
func Test_createBedrockFilesRouteConfigs(t *testing.T) {
handlerStore := &mockHandlerStore{allowDirectKeys: true}
routes := createBedrockFilesRouteConfigs("/bedrock/files", handlerStore)
assert.Len(t, routes, 5, "should have 5 file routes")
expectedRoutes := []struct {
path string
method string
}{
{"/bedrock/files/{bucket}/{key:*}", "PUT"},
{"/bedrock/files/{bucket}/{key:*}", "GET"},
{"/bedrock/files/{bucket}/{key:*}", "HEAD"},
{"/bedrock/files/{bucket}/{key:*}", "DELETE"},
{"/bedrock/files/{bucket}", "GET"},
}
for i, expected := range expectedRoutes {
assert.Equal(t, expected.path, routes[i].Path, "file route %d path mismatch", i)
assert.Equal(t, expected.method, routes[i].Method, "file route %d method mismatch", i)
assert.Equal(t, RouteConfigTypeBedrock, routes[i].Type, "file route %d type mismatch", i)
assert.NotNil(t, routes[i].GetRequestTypeInstance, "file route %d GetRequestTypeInstance should not be nil", i)
assert.NotNil(t, routes[i].ErrorConverter, "file route %d ErrorConverter should not be nil", i)
}
}
func Test_parseS3PutObjectRequest(t *testing.T) {
tests := []struct {
name string
bucket string
key string
body []byte
wantErr bool
wantBucket string
wantKey string
wantFilename string
}{
{
name: "valid request",
bucket: "my-bucket",
key: "folder/file.jsonl",
body: []byte(`{"test": "data"}`),
wantErr: false,
wantBucket: "my-bucket",
wantKey: "folder/file.jsonl",
wantFilename: "file.jsonl",
},
{
name: "simple key without folder",
bucket: "bucket",
key: "file.txt",
body: []byte("content"),
wantErr: false,
wantBucket: "bucket",
wantKey: "file.txt",
wantFilename: "file.txt",
},
{
name: "missing bucket",
bucket: "",
key: "file.txt",
body: []byte("content"),
wantErr: true,
},
{
name: "missing key",
bucket: "bucket",
key: "",
body: []byte("content"),
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := &fasthttp.RequestCtx{}
ctx.Request.SetBody(tt.body)
if tt.bucket != "" {
ctx.SetUserValue("bucket", tt.bucket)
}
if tt.key != "" {
ctx.SetUserValue("key", tt.key)
}
req := &bedrock.BedrockFileUploadRequest{}
err := parseS3PutObjectRequest(ctx, req)
if tt.wantErr {
assert.Error(t, err)
return
}
assert.NoError(t, err)
assert.Equal(t, tt.wantBucket, req.Bucket)
assert.Equal(t, tt.wantKey, req.Key)
assert.Equal(t, tt.wantFilename, req.Filename)
assert.Equal(t, "batch", req.Purpose)
assert.Equal(t, tt.body, req.Body)
})
}
}
func Test_parseS3PutObjectRequest_invalidType(t *testing.T) {
ctx := &fasthttp.RequestCtx{}
ctx.SetUserValue("bucket", "bucket")
ctx.SetUserValue("key", "key")
// Pass wrong type
err := parseS3PutObjectRequest(ctx, "invalid type")
assert.Error(t, err)
assert.Contains(t, err.Error(), "invalid request type")
}
func Test_s3PutObjectPostCallback(t *testing.T) {
tests := []struct {
name string
response interface{}
wantStatus int
wantETag string
}{
{
name: "valid response with ID",
response: &schemas.BifrostFileUploadResponse{
ID: "file-123",
},
wantStatus: 200,
wantETag: "\"file-123\"",
},
{
name: "nil response",
response: nil,
wantStatus: 200,
wantETag: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := &fasthttp.RequestCtx{}
err := s3PutObjectPostCallback(ctx, nil, tt.response)
assert.NoError(t, err)
assert.Equal(t, tt.wantStatus, ctx.Response.StatusCode())
assert.Equal(t, "application/xml", string(ctx.Response.Header.ContentType()))
assert.Equal(t, "bifrost", string(ctx.Response.Header.Peek("x-amz-request-id")))
if tt.wantETag != "" {
assert.Equal(t, tt.wantETag, string(ctx.Response.Header.Peek("ETag")))
}
})
}
}
func Test_s3GetObjectPostCallback(t *testing.T) {
tests := []struct {
name string
response interface{}
wantContentType string
wantLength string
wantETag string
}{
{
name: "valid response",
response: &schemas.BifrostFileContentResponse{
Content: []byte("test content"),
ContentType: "application/json",
FileID: "file-456",
},
wantContentType: "application/json",
wantLength: "12",
wantETag: "\"file-456\"",
},
{
name: "nil response",
response: nil,
wantContentType: "",
wantLength: "",
wantETag: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := &fasthttp.RequestCtx{}
err := s3GetObjectPostCallback(ctx, nil, tt.response)
assert.NoError(t, err)
if tt.wantContentType != "" {
assert.Equal(t, tt.wantContentType, string(ctx.Response.Header.Peek("Content-Type")))
assert.Equal(t, tt.wantLength, string(ctx.Response.Header.Peek("Content-Length")))
assert.Equal(t, "bifrost", string(ctx.Response.Header.Peek("x-amz-request-id")))
}
if tt.wantETag != "" {
assert.Equal(t, tt.wantETag, string(ctx.Response.Header.Peek("ETag")))
}
})
}
}
func Test_s3HeadObjectPostCallback(t *testing.T) {
tests := []struct {
name string
response interface{}
wantStatus int
wantLength string
wantETag string
}{
{
name: "valid response",
response: &schemas.BifrostFileRetrieveResponse{
ID: "file-789",
Bytes: 1024,
},
wantStatus: 200,
wantLength: "1024",
wantETag: "\"file-789\"",
},
{
name: "nil response",
response: nil,
wantStatus: 200,
wantLength: "",
wantETag: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := &fasthttp.RequestCtx{}
err := s3HeadObjectPostCallback(ctx, nil, tt.response)
assert.NoError(t, err)
assert.Equal(t, tt.wantStatus, ctx.Response.StatusCode())
if tt.wantLength != "" {
assert.Equal(t, "application/octet-stream", string(ctx.Response.Header.Peek("Content-Type")))
assert.Equal(t, tt.wantLength, string(ctx.Response.Header.Peek("Content-Length")))
assert.Equal(t, "bifrost", string(ctx.Response.Header.Peek("x-amz-request-id")))
assert.Equal(t, tt.wantETag, string(ctx.Response.Header.Peek("ETag")))
}
})
}
}
func Test_s3DeleteObjectPostCallback(t *testing.T) {
ctx := &fasthttp.RequestCtx{}
err := s3DeleteObjectPostCallback(ctx, nil, nil)
assert.NoError(t, err)
assert.Equal(t, 204, ctx.Response.StatusCode())
assert.Equal(t, "bifrost", string(ctx.Response.Header.Peek("x-amz-request-id")))
}
func Test_s3ListObjectsV2PostCallback(t *testing.T) {
ctx := &fasthttp.RequestCtx{}
err := s3ListObjectsV2PostCallback(ctx, nil, nil)
assert.NoError(t, err)
assert.Equal(t, "application/xml", string(ctx.Response.Header.ContentType()))
assert.Equal(t, "bifrost", string(ctx.Response.Header.Peek("x-amz-request-id")))
}
func Test_extractBedrockBatchListQueryParams(t *testing.T) {
handlerStore := &mockHandlerStore{allowDirectKeys: false}
tests := []struct {
name string
queryParams map[string]string
wantMaxResults int
wantNextToken string
wantStatus string
wantName string
}{
{
name: "all params",
queryParams: map[string]string{
"maxResults": "50",
"nextToken": "token123",
"statusEquals": "InProgress",
"nameContains": "test-job",
},
wantMaxResults: 50,
wantNextToken: "token123",
wantStatus: "InProgress",
wantName: "test-job",
},
{
name: "no params",
queryParams: map[string]string{},
wantMaxResults: 0,
wantNextToken: "",
wantStatus: "",
wantName: "",
},
{
name: "invalid maxResults",
queryParams: map[string]string{
"maxResults": "invalid",
},
wantMaxResults: 0,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := &fasthttp.RequestCtx{}
for k, v := range tt.queryParams {
ctx.QueryArgs().Add(k, v)
}
req := &bedrock.BedrockBatchListRequest{}
callback := extractBedrockBatchListQueryParams(handlerStore)
bifrostCtx := createTestBifrostContext()
err := callback(ctx, bifrostCtx, req)
assert.NoError(t, err)
assert.Equal(t, tt.wantMaxResults, req.MaxResults)
assert.Equal(t, tt.wantStatus, req.StatusEquals)
assert.Equal(t, tt.wantName, req.NameContains)
if tt.wantNextToken != "" {
assert.NotNil(t, req.NextToken)
assert.Equal(t, tt.wantNextToken, *req.NextToken)
}
})
}
}
func Test_extractBedrockJobArnFromPath(t *testing.T) {
handlerStore := &mockHandlerStore{allowDirectKeys: false}
tests := []struct {
name string
jobArn interface{}
provider schemas.ModelProvider
wantErr bool
wantJobArn string
errContains string
}{
{
name: "valid job ARN for Bedrock",
jobArn: "arn:aws:bedrock:us-east-1:123456789012:batch:job-123",
provider: schemas.Bedrock,
wantErr: false,
wantJobArn: "arn:aws:bedrock:us-east-1:123456789012:batch:job-123",
},
{
name: "URL encoded job ARN",
jobArn: "arn%3Aaws%3Abedrock%3Aus-east-1%3A123456789012%3Abatch%3Ajob-123",
provider: schemas.Bedrock,
wantErr: false,
wantJobArn: "arn:aws:bedrock:us-east-1:123456789012:batch:job-123",
},
{
name: "non-Bedrock provider strips ARN prefix",
jobArn: "arn:aws:bedrock:us-east-1:444444444444:batch:job-456",
provider: schemas.OpenAI,
wantErr: false,
wantJobArn: "job-456",
},
{
name: "missing job_arn",
jobArn: nil,
provider: schemas.Bedrock,
wantErr: true,
errContains: "job_arn is required",
},
{
name: "empty job_arn",
jobArn: "",
provider: schemas.Bedrock,
wantErr: true,
errContains: "job_arn must be a non-empty string",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := &fasthttp.RequestCtx{}
if tt.jobArn != nil {
ctx.SetUserValue("job_arn", tt.jobArn)
}
req := &bedrock.BedrockBatchRetrieveRequest{}
callback := extractBedrockJobArnFromPath(handlerStore)
bifrostCtx := createTestBifrostContextWithProvider(tt.provider)
err := callback(ctx, bifrostCtx, req)
if tt.wantErr {
assert.Error(t, err)
if tt.errContains != "" {
assert.Contains(t, err.Error(), tt.errContains)
}
return
}
assert.NoError(t, err)
assert.Equal(t, tt.wantJobArn, req.JobIdentifier)
})
}
}
func Test_extractS3ListObjectsV2Params(t *testing.T) {
handlerStore := &mockHandlerStore{allowDirectKeys: false}
tests := []struct {
name string
bucket string
queryParams map[string]string
wantErr bool
wantBucket string
wantPrefix string
wantMaxKeys int
wantContinuationToken string
}{
{
name: "all params",
bucket: "my-bucket",
queryParams: map[string]string{
"prefix": "folder/",
"max-keys": "100",
"continuation-token": "token-abc",
},
wantErr: false,
wantBucket: "my-bucket",
wantPrefix: "folder/",
wantMaxKeys: 100,
wantContinuationToken: "token-abc",
},
{
name: "bucket only",
bucket: "simple-bucket",
queryParams: map[string]string{},
wantErr: false,
wantBucket: "simple-bucket",
wantPrefix: "",
wantMaxKeys: 1000,
},
{
name: "missing bucket",
bucket: "",
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := &fasthttp.RequestCtx{}
if tt.bucket != "" {
ctx.SetUserValue("bucket", tt.bucket)
}
for k, v := range tt.queryParams {
ctx.QueryArgs().Add(k, v)
}
req := &bedrock.BedrockFileListRequest{}
callback := extractS3ListObjectsV2Params(handlerStore)
bifrostCtx := createTestBifrostContext()
err := callback(ctx, bifrostCtx, req)
if tt.wantErr {
assert.Error(t, err)
return
}
assert.NoError(t, err)
assert.Equal(t, tt.wantBucket, req.Bucket)
assert.Equal(t, tt.wantPrefix, req.Prefix)
assert.Equal(t, tt.wantMaxKeys, req.MaxKeys)
assert.Equal(t, tt.wantContinuationToken, req.ContinuationToken)
// Verify context values
assert.Equal(t, tt.wantBucket, bifrostCtx.Value(s3ContextKeyBucket))
assert.Equal(t, tt.wantPrefix, bifrostCtx.Value(s3ContextKeyPrefix))
assert.Equal(t, tt.wantMaxKeys, bifrostCtx.Value(s3ContextKeyMaxKeys))
})
}
}
func Test_extractS3BucketKeyFromPath(t *testing.T) {
handlerStore := &mockHandlerStore{allowDirectKeys: false}
tests := []struct {
name string
bucket string
key string
fileID string
opType string
wantErr bool
wantBucket string
wantKey string
wantS3URI string
}{
{
name: "content operation",
bucket: "my-bucket",
key: "path/to/file.txt",
fileID: "file-123",
opType: "content",
wantErr: false,
wantBucket: "my-bucket",
wantKey: "path/to/file.txt",
wantS3URI: "s3://my-bucket/path/to/file.txt",
},
{
name: "missing bucket",
bucket: "",
key: "file.txt",
opType: "content",
wantErr: true,
},
{
name: "missing key",
bucket: "bucket",
key: "",
opType: "content",
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := &fasthttp.RequestCtx{}
if tt.bucket != "" {
ctx.SetUserValue("bucket", tt.bucket)
}
if tt.key != "" {
ctx.SetUserValue("key", tt.key)
}
if tt.fileID != "" {
ctx.Request.Header.Set("If-Match", tt.fileID)
}
callback := extractS3BucketKeyFromPath(handlerStore, tt.opType)
bifrostCtx := createTestBifrostContext()
var req interface{}
switch tt.opType {
case "content":
req = &bedrock.BedrockFileContentRequest{}
case "retrieve":
req = &bedrock.BedrockFileRetrieveRequest{}
case "delete":
req = &bedrock.BedrockFileDeleteRequest{}
}
err := callback(ctx, bifrostCtx, req)
if tt.wantErr {
assert.Error(t, err)
return
}
assert.NoError(t, err)
switch r := req.(type) {
case *bedrock.BedrockFileContentRequest:
assert.Equal(t, tt.wantBucket, r.Bucket)
assert.Equal(t, tt.wantKey, r.Prefix)
assert.Equal(t, tt.wantS3URI, r.S3Uri)
assert.Equal(t, tt.fileID, r.ETag)
}
})
}
}
// Helper functions for creating test contexts
func createTestBifrostContext() *schemas.BifrostContext {
bifrostCtx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
bifrostCtx.SetValue(bifrostContextKeyProvider, schemas.Bedrock)
return bifrostCtx
}
func createTestBifrostContextWithProvider(provider schemas.ModelProvider) *schemas.BifrostContext {
bifrostCtx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
bifrostCtx.SetValue(bifrostContextKeyProvider, provider)
return bifrostCtx
}

View File

@@ -0,0 +1,222 @@
package integrations
import (
"context"
"errors"
bifrost "github.com/maximhq/bifrost/core"
"github.com/maximhq/bifrost/core/providers/cohere"
"github.com/maximhq/bifrost/core/schemas"
"github.com/maximhq/bifrost/transports/bifrost-http/lib"
"github.com/valyala/fasthttp"
)
// hydrateCohereRequestFromLargePayloadMetadata populates model + stream from
// LargePayloadMetadata when body parsing is skipped under large payload mode.
func hydrateCohereRequestFromLargePayloadMetadata(bifrostCtx *schemas.BifrostContext, req interface{}) {
if bifrostCtx == nil {
return
}
isLargePayload, _ := bifrostCtx.Value(schemas.BifrostContextKeyLargePayloadMode).(bool)
if !isLargePayload {
return
}
metadata := resolveLargePayloadMetadata(bifrostCtx)
if metadata == nil {
return
}
switch r := req.(type) {
case *cohere.CohereChatRequest:
if r.Model == "" {
r.Model = metadata.Model
}
if metadata.StreamRequested != nil && r.Stream == nil {
r.Stream = schemas.Ptr(*metadata.StreamRequested)
}
case *cohere.CohereEmbeddingRequest:
if r.Model == "" {
r.Model = metadata.Model
}
case *cohere.CohereRerankRequest:
if r.Model == "" {
r.Model = metadata.Model
}
case *cohere.CohereCountTokensRequest:
if r.Model == "" {
r.Model = metadata.Model
}
}
}
// cohereLargePayloadPreHook populates model + stream from LargePayloadMetadata
// when body parsing is skipped under large payload mode.
func cohereLargePayloadPreHook(_ *fasthttp.RequestCtx, bifrostCtx *schemas.BifrostContext, req interface{}) error {
hydrateCohereRequestFromLargePayloadMetadata(bifrostCtx, req)
return nil
}
// CohereRouter holds route registrations for Cohere endpoints.
// It supports Cohere's v2 chat, embeddings, and rerank APIs.
type CohereRouter struct {
*GenericRouter
}
// NewCohereRouter creates a new CohereRouter with the given bifrost client.
func NewCohereRouter(client *bifrost.Bifrost, handlerStore lib.HandlerStore, logger schemas.Logger) *CohereRouter {
return &CohereRouter{
GenericRouter: NewGenericRouter(client, handlerStore, CreateCohereRouteConfigs("/cohere"), nil, logger),
}
}
// CreateCohereRouteConfigs creates route configurations for Cohere API endpoints.
func CreateCohereRouteConfigs(pathPrefix string) []RouteConfig {
var routes []RouteConfig
// Chat completions endpoint (v2/chat)
routes = append(routes, RouteConfig{
Type: RouteConfigTypeCohere,
Path: pathPrefix + "/v2/chat",
Method: "POST",
PreCallback: cohereLargePayloadPreHook,
GetHTTPRequestType: func(ctx *fasthttp.RequestCtx) schemas.RequestType {
return schemas.ChatCompletionRequest
},
GetRequestTypeInstance: func(ctx context.Context) interface{} {
return &cohere.CohereChatRequest{}
},
RequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*schemas.BifrostRequest, error) {
if cohereReq, ok := req.(*cohere.CohereChatRequest); ok {
return &schemas.BifrostRequest{
ChatRequest: cohereReq.ToBifrostChatRequest(ctx),
}, nil
}
return nil, errors.New("invalid request type")
},
ChatResponseConverter: func(ctx *schemas.BifrostContext, resp *schemas.BifrostChatResponse) (interface{}, error) {
if resp.ExtraFields.Provider == schemas.Cohere {
if resp.ExtraFields.RawResponse != nil {
return resp.ExtraFields.RawResponse, nil
}
}
return resp, nil
},
ErrorConverter: func(ctx *schemas.BifrostContext, err *schemas.BifrostError) interface{} {
return err
},
StreamConfig: &StreamConfig{
ChatStreamResponseConverter: func(ctx *schemas.BifrostContext, resp *schemas.BifrostChatResponse) (string, interface{}, error) {
if resp.ExtraFields.Provider == schemas.Cohere {
if resp.ExtraFields.RawResponse != nil {
return "", resp.ExtraFields.RawResponse, nil
}
}
return "", resp, nil
},
ErrorConverter: func(ctx *schemas.BifrostContext, err *schemas.BifrostError) interface{} {
return err
},
},
})
// Embeddings endpoint (v2/embed)
routes = append(routes, RouteConfig{
Type: RouteConfigTypeCohere,
Path: pathPrefix + "/v2/embed",
Method: "POST",
PreCallback: cohereLargePayloadPreHook,
GetHTTPRequestType: func(ctx *fasthttp.RequestCtx) schemas.RequestType {
return schemas.EmbeddingRequest
},
GetRequestTypeInstance: func(ctx context.Context) interface{} {
return &cohere.CohereEmbeddingRequest{}
},
RequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*schemas.BifrostRequest, error) {
if cohereReq, ok := req.(*cohere.CohereEmbeddingRequest); ok {
return &schemas.BifrostRequest{
EmbeddingRequest: cohereReq.ToBifrostEmbeddingRequest(ctx),
}, nil
}
return nil, errors.New("invalid embedding request type")
},
EmbeddingResponseConverter: func(ctx *schemas.BifrostContext, resp *schemas.BifrostEmbeddingResponse) (interface{}, error) {
if resp.ExtraFields.Provider == schemas.Cohere {
if resp.ExtraFields.RawResponse != nil {
return resp.ExtraFields.RawResponse, nil
}
}
return resp, nil
},
ErrorConverter: func(ctx *schemas.BifrostContext, err *schemas.BifrostError) interface{} {
return err
},
})
// Rerank endpoint (v2/rerank)
routes = append(routes, RouteConfig{
Type: RouteConfigTypeCohere,
Path: pathPrefix + "/v2/rerank",
Method: "POST",
PreCallback: cohereLargePayloadPreHook,
GetHTTPRequestType: func(ctx *fasthttp.RequestCtx) schemas.RequestType {
return schemas.RerankRequest
},
GetRequestTypeInstance: func(ctx context.Context) interface{} {
return &cohere.CohereRerankRequest{}
},
RequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*schemas.BifrostRequest, error) {
if cohereReq, ok := req.(*cohere.CohereRerankRequest); ok {
return &schemas.BifrostRequest{
RerankRequest: cohereReq.ToBifrostRerankRequest(ctx),
}, nil
}
return nil, errors.New("invalid rerank request type")
},
RerankResponseConverter: func(ctx *schemas.BifrostContext, resp *schemas.BifrostRerankResponse) (interface{}, error) {
if resp.ExtraFields.Provider == schemas.Cohere {
if resp.ExtraFields.RawResponse != nil {
return resp.ExtraFields.RawResponse, nil
}
}
return resp, nil
},
ErrorConverter: func(ctx *schemas.BifrostContext, err *schemas.BifrostError) interface{} {
return err
},
})
// Tokenize endpoint (v1/tokenize)
routes = append(routes, RouteConfig{
Type: RouteConfigTypeCohere,
Path: pathPrefix + "/v1/tokenize",
Method: "POST",
PreCallback: cohereLargePayloadPreHook,
GetHTTPRequestType: func(ctx *fasthttp.RequestCtx) schemas.RequestType {
return schemas.CountTokensRequest
},
GetRequestTypeInstance: func(ctx context.Context) interface{} {
return &cohere.CohereCountTokensRequest{}
},
RequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*schemas.BifrostRequest, error) {
if cohereReq, ok := req.(*cohere.CohereCountTokensRequest); ok {
return &schemas.BifrostRequest{
CountTokensRequest: cohereReq.ToBifrostResponsesRequest(ctx),
}, nil
}
return nil, errors.New("invalid count tokens request type")
},
CountTokensResponseConverter: func(ctx *schemas.BifrostContext, resp *schemas.BifrostCountTokensResponse) (interface{}, error) {
if resp.ExtraFields.Provider == schemas.Cohere {
if resp.ExtraFields.RawResponse != nil {
return resp.ExtraFields.RawResponse, nil
}
}
return resp, nil
},
ErrorConverter: func(ctx *schemas.BifrostContext, err *schemas.BifrostError) interface{} {
return err
},
})
return routes
}

View File

@@ -0,0 +1,102 @@
package integrations
import (
"context"
"testing"
"github.com/maximhq/bifrost/core/providers/cohere"
"github.com/maximhq/bifrost/core/schemas"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestCreateCohereRouteConfigsIncludesRerank(t *testing.T) {
routes := CreateCohereRouteConfigs("/cohere")
assert.Len(t, routes, 4, "should have 4 cohere routes")
var rerankRoute *RouteConfig
for i := range routes {
if routes[i].Path == "/cohere/v2/rerank" && routes[i].Method == "POST" {
rerankRoute = &routes[i]
break
}
}
require.NotNil(t, rerankRoute, "rerank route should exist")
assert.Equal(t, RouteConfigTypeCohere, rerankRoute.Type)
assert.NotNil(t, rerankRoute.GetHTTPRequestType)
assert.Equal(t, schemas.RerankRequest, rerankRoute.GetHTTPRequestType(nil))
assert.NotNil(t, rerankRoute.GetRequestTypeInstance)
assert.NotNil(t, rerankRoute.RequestConverter)
assert.NotNil(t, rerankRoute.RerankResponseConverter)
assert.NotNil(t, rerankRoute.ErrorConverter)
reqInstance := rerankRoute.GetRequestTypeInstance(context.Background())
_, ok := reqInstance.(*cohere.CohereRerankRequest)
assert.True(t, ok, "rerank request instance should be CohereRerankRequest")
}
func TestCohereRerankRouteRequestConverter(t *testing.T) {
routes := CreateCohereRouteConfigs("/cohere")
var rerankRoute *RouteConfig
for i := range routes {
if routes[i].Path == "/cohere/v2/rerank" {
rerankRoute = &routes[i]
break
}
}
require.NotNil(t, rerankRoute)
require.NotNil(t, rerankRoute.RequestConverter)
topN := 1
req := &cohere.CohereRerankRequest{
Model: "rerank-v3.5",
Query: "what is bifrost?",
Documents: []string{"doc1", "doc2"},
TopN: &topN,
}
bifrostCtx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
bifrostReq, err := rerankRoute.RequestConverter(bifrostCtx, req)
require.NoError(t, err)
require.NotNil(t, bifrostReq)
require.NotNil(t, bifrostReq.RerankRequest)
assert.Equal(t, schemas.Cohere, bifrostReq.RerankRequest.Provider)
assert.Equal(t, "rerank-v3.5", bifrostReq.RerankRequest.Model)
assert.Equal(t, "what is bifrost?", bifrostReq.RerankRequest.Query)
require.Len(t, bifrostReq.RerankRequest.Documents, 2)
assert.Equal(t, "doc1", bifrostReq.RerankRequest.Documents[0].Text)
assert.Equal(t, "doc2", bifrostReq.RerankRequest.Documents[1].Text)
require.NotNil(t, bifrostReq.RerankRequest.Params)
require.NotNil(t, bifrostReq.RerankRequest.Params.TopN)
assert.Equal(t, 1, *bifrostReq.RerankRequest.Params.TopN)
}
func TestCohereRerankResponseConverterUsesRawResponse(t *testing.T) {
routes := CreateCohereRouteConfigs("/cohere")
var rerankRoute *RouteConfig
for i := range routes {
if routes[i].Path == "/cohere/v2/rerank" {
rerankRoute = &routes[i]
break
}
}
require.NotNil(t, rerankRoute)
require.NotNil(t, rerankRoute.RerankResponseConverter)
raw := map[string]interface{}{"id": "r-123", "results": []interface{}{}}
resp := &schemas.BifrostRerankResponse{
ExtraFields: schemas.BifrostResponseExtraFields{
Provider: schemas.Cohere,
RawResponse: raw,
},
}
converted, err := rerankRoute.RerankResponseConverter(nil, resp)
require.NoError(t, err)
assert.Equal(t, raw, converted)
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,47 @@
package integrations
import (
"context"
"strings"
"testing"
"github.com/maximhq/bifrost/core/schemas"
"github.com/maximhq/bifrost/transports/bifrost-http/lib"
"github.com/valyala/fasthttp"
)
func TestExtractModelAndRequestType_LargePayloadUsesMetadataWithoutBodyParse(t *testing.T) {
ctx := &fasthttp.RequestCtx{}
ctx.SetUserValue("model", "gemini-2.5-pro:generateContent")
// Intentionally invalid JSON: detection must rely on large-payload metadata, not body parse.
ctx.Request.SetBodyString(`{"contents":[INVALID`)
bifrostCtx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
bifrostCtx.SetValue(schemas.BifrostContextKeyLargePayloadMode, true)
bifrostCtx.SetValue(schemas.BifrostContextKeyLargePayloadMetadata, &schemas.LargePayloadMetadata{
ResponseModalities: []string{"AUDIO"},
})
ctx.SetUserValue(lib.FastHTTPUserValueBifrostContext, bifrostCtx)
model, reqType := extractModelAndRequestType(ctx)
if model != "gemini-2.5-pro" {
t.Fatalf("expected normalized model gemini-2.5-pro, got %q", model)
}
if reqType != schemas.SpeechRequest {
t.Fatalf("expected speech request type from metadata, got %q", reqType)
}
}
func TestExtractModelAndRequestType_LargeBodyHeuristicSkipsParse(t *testing.T) {
ctx := &fasthttp.RequestCtx{}
ctx.SetUserValue("model", "gemini-2.5-pro:generateContent")
ctx.Request.SetBodyStream(strings.NewReader(`{"contents":[INVALID`), schemas.DefaultLargePayloadRequestThresholdBytes+1)
model, reqType := extractModelAndRequestType(ctx)
if model != "gemini-2.5-pro" {
t.Fatalf("expected normalized model gemini-2.5-pro, got %q", model)
}
if reqType != schemas.ResponsesRequest {
t.Fatalf("expected responses request type from large-body heuristic, got %q", reqType)
}
}

View File

@@ -0,0 +1,208 @@
package integrations
import (
"context"
"testing"
"github.com/maximhq/bifrost/core/providers/gemini"
"github.com/maximhq/bifrost/core/providers/vertex"
"github.com/maximhq/bifrost/core/schemas"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/valyala/fasthttp"
)
func TestCreateGenAIRerankRouteConfig(t *testing.T) {
route := createGenAIRerankRouteConfig("/genai")
assert.Equal(t, "/genai/v1/rank", route.Path)
assert.Equal(t, "POST", route.Method)
assert.Equal(t, RouteConfigTypeGenAI, route.Type)
assert.NotNil(t, route.GetHTTPRequestType)
assert.Equal(t, schemas.RerankRequest, route.GetHTTPRequestType(nil))
assert.NotNil(t, route.GetRequestTypeInstance)
assert.NotNil(t, route.RequestConverter)
assert.NotNil(t, route.RerankResponseConverter)
assert.NotNil(t, route.ErrorConverter)
assert.Nil(t, route.PreCallback)
// Verify request instance type
reqInstance := route.GetRequestTypeInstance(context.Background())
_, ok := reqInstance.(*vertex.VertexRankRequest)
assert.True(t, ok, "GetRequestTypeInstance should return *vertex.VertexRankRequest")
}
func TestCreateGenAIRouteConfigsIncludesRerank(t *testing.T) {
routes := CreateGenAIRouteConfigs("/genai")
found := false
for _, route := range routes {
if route.Path == "/genai/v1/rank" && route.Method == "POST" {
found = true
break
}
}
assert.True(t, found, "expected rerank route in genai route configs")
}
func TestCreateGenAIRouteConfigsIncludesRerankForCompositePrefixes(t *testing.T) {
prefixes := []string{"/litellm", "/langchain", "/pydanticai"}
for _, prefix := range prefixes {
routes := CreateGenAIRouteConfigs(prefix)
found := false
for _, route := range routes {
if route.Path == prefix+"/v1/rank" && route.Method == "POST" {
found = true
break
}
}
assert.Truef(t, found, "expected rerank route for prefix %s", prefix)
}
}
func TestGenAIRerankRequestConverter(t *testing.T) {
route := createGenAIRerankRouteConfig("/genai")
require.NotNil(t, route.RequestConverter)
model := "semantic-ranker-default@latest"
topN := 2
content1 := "Paris is capital of France"
content2 := "Berlin is capital of Germany"
req := &vertex.VertexRankRequest{
Model: &model,
Query: "capital of france",
Records: []vertex.VertexRankRecord{
{ID: "rec-1", Content: &content1},
{ID: "rec-2", Content: &content2},
},
TopN: &topN,
}
bifrostCtx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
bifrostReq, err := route.RequestConverter(bifrostCtx, req)
require.NoError(t, err)
require.NotNil(t, bifrostReq)
require.NotNil(t, bifrostReq.RerankRequest)
assert.Equal(t, schemas.Vertex, bifrostReq.RerankRequest.Provider)
assert.Equal(t, "semantic-ranker-default@latest", bifrostReq.RerankRequest.Model)
assert.Equal(t, "capital of france", bifrostReq.RerankRequest.Query)
require.Len(t, bifrostReq.RerankRequest.Documents, 2)
assert.Equal(t, "Paris is capital of France", bifrostReq.RerankRequest.Documents[0].Text)
assert.Equal(t, "Berlin is capital of Germany", bifrostReq.RerankRequest.Documents[1].Text)
require.NotNil(t, bifrostReq.RerankRequest.Params)
require.NotNil(t, bifrostReq.RerankRequest.Params.TopN)
assert.Equal(t, 2, *bifrostReq.RerankRequest.Params.TopN)
}
func TestGenAIRerankResponseConverterUsesRawResponse(t *testing.T) {
route := createGenAIRerankRouteConfig("/genai")
require.NotNil(t, route.RerankResponseConverter)
raw := map[string]interface{}{"records": []interface{}{}}
resp := &schemas.BifrostRerankResponse{
ExtraFields: schemas.BifrostResponseExtraFields{
Provider: schemas.Vertex,
RawResponse: raw,
},
}
converted, err := route.RerankResponseConverter(nil, resp)
require.NoError(t, err)
assert.Equal(t, raw, converted)
}
func TestGenAIRerankResponseConverterFallsBackWhenNotVertex(t *testing.T) {
route := createGenAIRerankRouteConfig("/genai")
require.NotNil(t, route.RerankResponseConverter)
resp := &schemas.BifrostRerankResponse{
Results: []schemas.RerankResult{
{Index: 0, RelevanceScore: 0.9},
},
ExtraFields: schemas.BifrostResponseExtraFields{
Provider: schemas.Cohere,
},
}
converted, err := route.RerankResponseConverter(nil, resp)
require.NoError(t, err)
assert.Equal(t, resp, converted)
}
func TestCreateGenAIRouteConfigsIncludesModelMetadataRoute(t *testing.T) {
routes := CreateGenAIRouteConfigs("/genai")
found := false
for _, route := range routes {
if route.Path == "/genai/v1beta/models/{model}" && route.Method == "GET" {
found = true
assert.Equal(t, schemas.ListModelsRequest, route.GetHTTPRequestType(nil))
require.NotNil(t, route.PreCallback)
require.NotNil(t, route.ListModelsResponseConverter)
break
}
}
assert.True(t, found, "expected model metadata route in genai route configs")
}
func TestExtractGeminiModelMetadataParams(t *testing.T) {
ctx := &fasthttp.RequestCtx{}
ctx.SetUserValue("model", "models/gemini-3-pro-preview")
listReq := &schemas.BifrostListModelsRequest{}
bifrostCtx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
err := extractGeminiModelMetadataParams(ctx, bifrostCtx, listReq)
require.NoError(t, err)
assert.Equal(t, schemas.Gemini, listReq.Provider)
assert.Equal(t, "/models/gemini-3-pro-preview", bifrostCtx.Value(schemas.BifrostContextKeyURLPath))
assert.Equal(t, "gemini-3-pro-preview", bifrostCtx.Value(requestedGeminiModelMetadataContextKey))
}
func TestConvertGeminiModelMetadataResponse(t *testing.T) {
bifrostCtx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
bifrostCtx.SetValue(requestedGeminiModelMetadataContextKey, "gemini-2.5-pro")
resp := &schemas.BifrostListModelsResponse{
Data: []schemas.Model{{ID: "gemini/gemini-2.5-pro", Name: schemas.Ptr("Gemini 2.5 Pro")}},
}
converted, err := convertGeminiModelMetadataResponse(bifrostCtx, resp)
require.NoError(t, err)
model, ok := converted.(gemini.GeminiModel)
require.True(t, ok, "expected gemini.GeminiModel")
assert.Equal(t, "models/gemini-2.5-pro", model.Name)
assert.Equal(t, "Gemini 2.5 Pro", model.DisplayName)
}
func TestConvertGeminiModelMetadataResponse_MatchesRequestedModelNotFirst(t *testing.T) {
bifrostCtx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
bifrostCtx.SetValue(requestedGeminiModelMetadataContextKey, "gemini-3-pro-preview")
resp := &schemas.BifrostListModelsResponse{
Data: []schemas.Model{
{ID: "gemini/gemini-1.5-pro", Name: schemas.Ptr("Gemini 1.5 Pro")},
{ID: "gemini/gemini-3-pro-preview", Name: schemas.Ptr("Gemini 3 Pro Preview")},
},
}
converted, err := convertGeminiModelMetadataResponse(bifrostCtx, resp)
require.NoError(t, err)
model, ok := converted.(gemini.GeminiModel)
require.True(t, ok, "expected gemini.GeminiModel")
assert.Equal(t, "models/gemini-3-pro-preview", model.Name)
assert.Equal(t, "Gemini 3 Pro Preview", model.DisplayName)
}
func TestConvertGeminiModelMetadataResponse_EmptyReturnsMinimalModel(t *testing.T) {
bifrostCtx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
bifrostCtx.SetValue(requestedGeminiModelMetadataContextKey, "gemini-3-pro-preview")
converted, err := convertGeminiModelMetadataResponse(bifrostCtx, &schemas.BifrostListModelsResponse{Data: []schemas.Model{}})
require.NoError(t, err)
model, ok := converted.(gemini.GeminiModel)
require.True(t, ok, "expected gemini.GeminiModel")
assert.Equal(t, "models/gemini-3-pro-preview", model.Name)
}

View File

@@ -0,0 +1,42 @@
package integrations
import (
bifrost "github.com/maximhq/bifrost/core"
"github.com/maximhq/bifrost/core/schemas"
"github.com/maximhq/bifrost/transports/bifrost-http/lib"
)
// LangChainRouter holds route registrations for LangChain endpoints.
// It supports standard chat completions and image-enabled vision capabilities.
// LangChain is fully OpenAI-compatible, so we reuse OpenAI types
// with aliases for clarity and minimal LangChain-specific extensions
type LangChainRouter struct {
*GenericRouter
}
// NewLangChainRouter creates a new LangChainRouter with the given bifrost client.
func NewLangChainRouter(client *bifrost.Bifrost, handlerStore lib.HandlerStore, logger schemas.Logger) *LangChainRouter {
routes := []RouteConfig{}
// Add OpenAI routes to LangChain for OpenAI API compatibility
routes = append(routes, CreateOpenAIRouteConfigs("/langchain", handlerStore)...)
// Add Anthropic routes to LangChain for Anthropic API compatibility
routes = append(routes, CreateAnthropicRouteConfigs("/langchain", logger)...)
// Add Anthropic count tokens route for LangChain to ensure token counting uses the dedicated endpoint
routes = append(routes, CreateAnthropicCountTokensRouteConfigs("/langchain", handlerStore)...)
// Add GenAI routes to LangChain for Vertex AI compatibility
routes = append(routes, CreateGenAIRouteConfigs("/langchain")...)
// Add Bedrock routes to LangChain for AWS Bedrock API compatibility
routes = append(routes, CreateBedrockRouteConfigs("/langchain", handlerStore)...)
// Add Cohere routes to LangChain for Cohere API compatibility
routes = append(routes, CreateCohereRouteConfigs("/langchain")...)
return &LangChainRouter{
GenericRouter: NewGenericRouter(client, handlerStore, routes, nil, logger),
}
}

View File

@@ -0,0 +1,39 @@
package integrations
import (
bifrost "github.com/maximhq/bifrost/core"
"github.com/maximhq/bifrost/core/schemas"
"github.com/maximhq/bifrost/transports/bifrost-http/lib"
)
// LiteLLMRouter holds route registrations for LiteLLM endpoints.
// It supports standard chat completions and image-enabled vision capabilities.
// LiteLLM is fully OpenAI-compatible, so we reuse OpenAI types
// with aliases for clarity and minimal LiteLLM-specific extensions
type LiteLLMRouter struct {
*GenericRouter
}
// NewLiteLLMRouter creates a new LiteLLMRouter with the given bifrost client.
func NewLiteLLMRouter(client *bifrost.Bifrost, handlerStore lib.HandlerStore, logger schemas.Logger) *LiteLLMRouter {
routes := []RouteConfig{}
// Add OpenAI routes to LiteLLM for OpenAI API compatibility
routes = append(routes, CreateOpenAIRouteConfigs("/litellm", handlerStore)...)
// Add Anthropic routes to LiteLLM for Anthropic API compatibility
routes = append(routes, CreateAnthropicRouteConfigs("/litellm", logger)...)
// Add GenAI routes to LiteLLM for Vertex AI compatibility
routes = append(routes, CreateGenAIRouteConfigs("/litellm")...)
// Add Bedrock routes to LiteLLM for AWS Bedrock API compatibility
routes = append(routes, CreateBedrockRouteConfigs("/litellm", handlerStore)...)
// Add Cohere routes to LiteLLM for Cohere API compatibility
routes = append(routes, CreateCohereRouteConfigs("/litellm")...)
return &LiteLLMRouter{
GenericRouter: NewGenericRouter(client, handlerStore, routes, nil, logger),
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,72 @@
package integrations
import (
bifrost "github.com/maximhq/bifrost/core"
"github.com/maximhq/bifrost/core/schemas"
"github.com/maximhq/bifrost/transports/bifrost-http/lib"
)
// PassthroughRouter is a catch-all router that forwards all requests directly
// to the provider without matching against known route patterns.
type PassthroughRouter struct {
*GenericRouter
}
// NewPassthroughRouter creates a passthrough-only router for any prefix/provider combo.
func NewPassthroughRouter(
client *bifrost.Bifrost,
handlerStore lib.HandlerStore,
logger schemas.Logger,
cfg *PassthroughConfig,
) *PassthroughRouter {
if cfg == nil {
cfg = &PassthroughConfig{}
}
return &PassthroughRouter{
GenericRouter: NewGenericRouter(client, handlerStore, nil, cfg, logger),
}
}
// NewAnthropicPassthroughRouter creates a passthrough router for /anthropic_passthrough.
func NewAnthropicPassthroughRouter(client *bifrost.Bifrost, handlerStore lib.HandlerStore, logger schemas.Logger) *PassthroughRouter {
return NewPassthroughRouter(client, handlerStore, logger, &PassthroughConfig{
Provider: schemas.Anthropic,
StripPrefix: []string{
"/anthropic_passthrough",
},
})
}
// NewOpenAIPassthroughRouter creates a passthrough router for /openai_passthrough.
func NewOpenAIPassthroughRouter(client *bifrost.Bifrost, handlerStore lib.HandlerStore, logger schemas.Logger) *PassthroughRouter {
return NewPassthroughRouter(client, handlerStore, logger, &PassthroughConfig{
Provider: schemas.OpenAI,
StripPrefix: []string{
"/openai_passthrough",
},
})
}
// NewAzurePassthroughRouter creates a passthrough router for /azure_passthrough.
func NewAzurePassthroughRouter(client *bifrost.Bifrost, handlerStore lib.HandlerStore, logger schemas.Logger) *PassthroughRouter {
return NewPassthroughRouter(client, handlerStore, logger, &PassthroughConfig{
Provider: schemas.Azure,
StripPrefix: []string{
"/azure_passthrough",
},
})
}
// NewGenAIPassthroughRouter creates a passthrough router for /genai_passthrough.
func NewGenAIPassthroughRouter(client *bifrost.Bifrost, handlerStore lib.HandlerStore, logger schemas.Logger) *PassthroughRouter {
return NewPassthroughRouter(client, handlerStore, logger, &PassthroughConfig{
Provider: schemas.Gemini,
ProviderDetector: detectProviderFromGenAIRequest,
StripPrefix: []string{
"/genai_passthrough/v1beta1",
"/genai_passthrough/v1beta",
"/genai_passthrough/v1",
"/genai_passthrough",
},
})
}

View File

@@ -0,0 +1,135 @@
package integrations
import (
"strings"
"github.com/bytedance/sonic"
bifrost "github.com/maximhq/bifrost/core"
"github.com/maximhq/bifrost/core/schemas"
"github.com/maximhq/bifrost/transports/bifrost-http/lib"
)
// PydanticAIRouter holds route registrations for Pydantic AI endpoints.
// It supports standard chat completions, tool calling, streaming, and multi-provider capabilities.
// Pydantic AI uses standard provider SDKs (OpenAI, Anthropic, Google GenAI), so we reuse
// existing route configurations with aliases for clarity and Pydantic AI-specific extensions.
type PydanticAIRouter struct {
*GenericRouter
}
// NewPydanticAIRouter creates a new PydanticAIRouter with the given bifrost client.
func NewPydanticAIRouter(client *bifrost.Bifrost, handlerStore lib.HandlerStore, logger schemas.Logger) *PydanticAIRouter {
routes := []RouteConfig{}
// Add OpenAI routes to Pydantic AI for OpenAI API compatibility
// Supports: chat completions, embeddings, speech, transcriptions, responses
routes = append(routes, withPydanticResponsesNullNormalization(CreateOpenAIRouteConfigs("/pydanticai", handlerStore))...)
// Add Anthropic routes to Pydantic AI for Anthropic API compatibility
// Supports: messages API (Claude models)
routes = append(routes, CreateAnthropicRouteConfigs("/pydanticai", logger)...)
// Add GenAI routes to Pydantic AI for Google Gemini API compatibility
// Supports: generateContent, streamGenerateContent, embedContent
routes = append(routes, CreateGenAIRouteConfigs("/pydanticai")...)
// Add Cohere routes to Pydantic AI for Cohere API compatibility
// Supports: v2/chat (chat completions with streaming), v2/embed (embeddings)
routes = append(routes, CreateCohereRouteConfigs("/pydanticai")...)
// Add Bedrock routes to Pydantic AI for AWS Bedrock API compatibility
// Supports: converse, converse-stream, invoke, invoke-with-response-stream
routes = append(routes, CreateBedrockRouteConfigs("/pydanticai", handlerStore)...)
return &PydanticAIRouter{
GenericRouter: NewGenericRouter(client, handlerStore, routes, nil, logger),
}
}
func withPydanticResponsesNullNormalization(routes []RouteConfig) []RouteConfig {
for i := range routes {
if !strings.Contains(routes[i].Path, "/responses") {
continue
}
if routes[i].ResponsesResponseConverter != nil {
routes[i].ResponsesResponseConverter = func(ctx *schemas.BifrostContext, resp *schemas.BifrostResponsesResponse) (interface{}, error) {
// For pydantic responses endpoint, prefer normalized bifrost output
// instead of raw passthrough, to keep null handling consistent.
return resp.WithDefaults(), nil
}
}
if routes[i].StreamConfig != nil && routes[i].StreamConfig.ResponsesStreamResponseConverter != nil {
// Match non-stream behavior: prefer normalized output (raw->normalizePydanticResponsesRawStreamChunk, typed->resp.WithDefaults()+ensurePydanticResponsesStreamTextFields).
routes[i].StreamConfig.ResponsesStreamResponseConverter = func(ctx *schemas.BifrostContext, resp *schemas.BifrostResponsesStreamResponse) (string, interface{}, error) {
if resp == nil {
return "", nil, nil
}
if resp.ExtraFields.RawResponse != nil {
normalizedRaw := normalizePydanticResponsesRawStreamChunk(resp.ExtraFields.RawResponse)
if normalizedRawString, ok := normalizedRaw.(string); ok {
return string(resp.Type), normalizedRawString, nil
}
}
normalized := resp.WithDefaults()
ensurePydanticResponsesStreamTextFields(normalized)
return string(resp.Type), normalized, nil
}
}
}
return routes
}
func ensurePydanticResponsesStreamTextFields(resp *schemas.BifrostResponsesStreamResponse) {
if resp == nil {
return
}
switch resp.Type {
case schemas.ResponsesStreamResponseTypeOutputTextDelta:
if resp.Delta == nil {
resp.Delta = bifrost.Ptr("")
}
case schemas.ResponsesStreamResponseTypeOutputTextDone:
if resp.Text == nil {
resp.Text = bifrost.Ptr("")
}
}
}
func normalizePydanticResponsesRawStreamChunk(raw interface{}) interface{} {
rawString, ok := raw.(string)
if !ok {
return raw
}
var chunk map[string]interface{}
if err := sonic.UnmarshalString(rawString, &chunk); err != nil {
return raw
}
changed := false
if chunkType, ok := chunk["type"].(string); ok {
switch schemas.ResponsesStreamResponseType(chunkType) {
case schemas.ResponsesStreamResponseTypeOutputTextDelta:
if value, exists := chunk["delta"]; exists && value == nil {
chunk["delta"] = ""
changed = true
}
case schemas.ResponsesStreamResponseTypeOutputTextDone:
if value, exists := chunk["text"]; exists && value == nil {
chunk["text"] = ""
changed = true
}
}
}
if !changed {
return raw
}
normalized, err := sonic.MarshalString(chunk)
if err != nil {
return raw
}
return normalized
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,152 @@
package integrations
import (
"context"
"errors"
"testing"
"time"
"github.com/maximhq/bifrost/core/schemas"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/valyala/fasthttp"
)
func TestCreateHandler_SkipsRequestParserInLargePayloadMode(t *testing.T) {
handlerStore := &mockHandlerStore{allowDirectKeys: true}
parserCalls := 0
route := RouteConfig{
Type: RouteConfigTypeOpenAI,
Path: "/openai/v1/chat/completions",
Method: "POST",
GetHTTPRequestType: func(ctx *fasthttp.RequestCtx) schemas.RequestType {
return schemas.ChatCompletionRequest
},
GetRequestTypeInstance: func(ctx context.Context) interface{} {
return &struct{}{}
},
RequestParser: func(ctx *fasthttp.RequestCtx, req interface{}) error {
parserCalls++
return nil
},
RequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*schemas.BifrostRequest, error) {
return nil, errors.New("stop after parse phase")
},
ErrorConverter: func(ctx *schemas.BifrostContext, err *schemas.BifrostError) interface{} {
return err
},
}
router := NewGenericRouter(nil, handlerStore, nil, nil, nil)
router.SetLargePayloadHook(func(ctx *fasthttp.RequestCtx, bifrostCtx *schemas.BifrostContext, routeType RouteConfigType) (bool, error) {
return true, nil
})
ctx := &fasthttp.RequestCtx{}
ctx.Request.Header.SetMethod(fasthttp.MethodPost)
ctx.Request.SetBodyString(`{"model":"openai/gpt-4o","messages":[]}`)
ctx.SetUserValue(schemas.BifrostContextKeyHTTPRequestType, schemas.ChatCompletionRequest)
handler := router.createHandler(route)
handler(ctx)
assert.Equal(t, 0, parserCalls)
}
func TestCreateHandler_UsesRequestParserWhenNotInLargePayloadMode(t *testing.T) {
handlerStore := &mockHandlerStore{allowDirectKeys: true}
parserCalls := 0
route := RouteConfig{
Type: RouteConfigTypeOpenAI,
Path: "/openai/v1/chat/completions",
Method: "POST",
GetHTTPRequestType: func(ctx *fasthttp.RequestCtx) schemas.RequestType {
return schemas.ChatCompletionRequest
},
GetRequestTypeInstance: func(ctx context.Context) interface{} {
return &struct{}{}
},
RequestParser: func(ctx *fasthttp.RequestCtx, req interface{}) error {
parserCalls++
return nil
},
RequestConverter: func(ctx *schemas.BifrostContext, req interface{}) (*schemas.BifrostRequest, error) {
return nil, errors.New("stop after parse phase")
},
ErrorConverter: func(ctx *schemas.BifrostContext, err *schemas.BifrostError) interface{} {
return err
},
}
router := NewGenericRouter(nil, handlerStore, nil, nil, nil)
ctx := &fasthttp.RequestCtx{}
ctx.Request.Header.SetMethod(fasthttp.MethodPost)
ctx.Request.SetBodyString(`{"model":"openai/gpt-4o","messages":[]}`)
ctx.SetUserValue(schemas.BifrostContextKeyHTTPRequestType, schemas.ChatCompletionRequest)
handler := router.createHandler(route)
handler(ctx)
assert.Equal(t, 1, parserCalls)
}
// ============================================================================
// resolveLargePayloadMetadata tests
// ============================================================================
func TestResolveLargePayloadMetadata_NilContext(t *testing.T) {
assert.Nil(t, resolveLargePayloadMetadata(nil))
}
func TestResolveLargePayloadMetadata_SyncPath(t *testing.T) {
ctx := schemas.NewBifrostContext(nil, time.Time{})
meta := &schemas.LargePayloadMetadata{Model: "gpt-4o"}
ctx.SetValue(schemas.BifrostContextKeyLargePayloadMetadata, meta)
result := resolveLargePayloadMetadata(ctx)
require.NotNil(t, result)
assert.Equal(t, "gpt-4o", result.Model)
}
func TestResolveLargePayloadMetadata_DeferredReady(t *testing.T) {
ctx := schemas.NewBifrostContext(nil, time.Time{})
ch := make(chan *schemas.LargePayloadMetadata, 1)
ch <- &schemas.LargePayloadMetadata{Model: "claude-4"}
ctx.SetValue(schemas.BifrostContextKeyDeferredLargePayloadMetadata, (<-chan *schemas.LargePayloadMetadata)(ch))
result := resolveLargePayloadMetadata(ctx)
require.NotNil(t, result)
assert.Equal(t, "claude-4", result.Model)
// Verify it was cached in the sync key.
cached, ok := ctx.Value(schemas.BifrostContextKeyLargePayloadMetadata).(*schemas.LargePayloadMetadata)
require.True(t, ok)
assert.Equal(t, "claude-4", cached.Model)
}
func TestResolveLargePayloadMetadata_DeferredNotReady(t *testing.T) {
ctx := schemas.NewBifrostContext(nil, time.Time{})
ch := make(chan *schemas.LargePayloadMetadata, 1) // empty, not ready
ctx.SetValue(schemas.BifrostContextKeyDeferredLargePayloadMetadata, (<-chan *schemas.LargePayloadMetadata)(ch))
// Non-blocking: should return nil when channel has no value yet.
result := resolveLargePayloadMetadata(ctx)
assert.Nil(t, result)
}
func TestResolveLargePayloadMetadata_SyncTakesPrecedence(t *testing.T) {
ctx := schemas.NewBifrostContext(nil, time.Time{})
syncMeta := &schemas.LargePayloadMetadata{Model: "sync-model"}
ctx.SetValue(schemas.BifrostContextKeyLargePayloadMetadata, syncMeta)
ch := make(chan *schemas.LargePayloadMetadata, 1)
ch <- &schemas.LargePayloadMetadata{Model: "deferred-model"}
ctx.SetValue(schemas.BifrostContextKeyDeferredLargePayloadMetadata, (<-chan *schemas.LargePayloadMetadata)(ch))
result := resolveLargePayloadMetadata(ctx)
require.NotNil(t, result)
assert.Equal(t, "sync-model", result.Model)
}

View File

@@ -0,0 +1,373 @@
package integrations
import (
"bytes"
"context"
"mime/multipart"
"testing"
"github.com/bytedance/sonic"
"github.com/maximhq/bifrost/core/providers/openai"
"github.com/maximhq/bifrost/core/schemas"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestParsePassthroughBody_MultipartExtractsModelAfterFilePart(t *testing.T) {
var body bytes.Buffer
writer := multipart.NewWriter(&body)
fileWriter, err := writer.CreateFormFile("file", "sample.mp3")
require.NoError(t, err)
_, err = fileWriter.Write([]byte("audio-bytes"))
require.NoError(t, err)
require.NoError(t, writer.WriteField("model", "openai/whisper-1"))
require.NoError(t, writer.WriteField("stream", "true"))
require.NoError(t, writer.Close())
model, stream := parsePassthroughBody(writer.FormDataContentType(), body.Bytes())
assert.Equal(t, "openai/whisper-1", model)
assert.True(t, stream)
}
func TestRequestWithSettableExtraParams_OpenAIChatRequest(t *testing.T) {
t.Run("SetExtraParams populates both standalone and embedded ExtraParams", func(t *testing.T) {
req := &openai.OpenAIChatRequest{}
extra := map[string]interface{}{
"guardrailConfig": map[string]interface{}{
"guardrailIdentifier": "xxx",
"guardrailVersion": "1",
},
}
rws, ok := interface{}(req).(RequestWithSettableExtraParams)
require.True(t, ok, "OpenAIChatRequest should implement RequestWithSettableExtraParams")
rws.SetExtraParams(extra)
assert.Equal(t, extra, req.GetExtraParams())
assert.Equal(t, extra, req.ChatParameters.ExtraParams, "embedded ChatParameters.ExtraParams should also be set")
})
t.Run("extra params propagate through ToBifrostChatRequest", func(t *testing.T) {
req := &openai.OpenAIChatRequest{
Model: "bedrock/claude-4-5-sonnet-global",
Messages: []openai.OpenAIMessage{},
}
extra := map[string]interface{}{
"guardrailConfig": map[string]interface{}{
"guardrailIdentifier": "test-id",
"guardrailVersion": "1",
},
}
rws := interface{}(req).(RequestWithSettableExtraParams)
rws.SetExtraParams(extra)
ctx := schemas.NewBifrostContext(nil, schemas.NoDeadline)
bifrostReq := req.ToBifrostChatRequest(ctx)
require.NotNil(t, bifrostReq)
require.NotNil(t, bifrostReq.Params)
assert.Contains(t, bifrostReq.Params.ExtraParams, "guardrailConfig")
})
}
func TestRequestWithSettableExtraParams_AllOpenAIRequestTypes(t *testing.T) {
tests := []struct {
name string
req interface{}
}{
{"OpenAIChatRequest", &openai.OpenAIChatRequest{}},
{"OpenAITextCompletionRequest", &openai.OpenAITextCompletionRequest{}},
{"OpenAIResponsesRequest", &openai.OpenAIResponsesRequest{}},
{"OpenAIEmbeddingRequest", &openai.OpenAIEmbeddingRequest{}},
{"OpenAISpeechRequest", &openai.OpenAISpeechRequest{}},
{"OpenAIImageGenerationRequest", &openai.OpenAIImageGenerationRequest{}},
{"OpenAIImageEditRequest", &openai.OpenAIImageEditRequest{}},
{"OpenAIImageVariationRequest", &openai.OpenAIImageVariationRequest{}},
}
for _, tt := range tests {
t.Run(tt.name+" implements RequestWithSettableExtraParams", func(t *testing.T) {
rws, ok := tt.req.(RequestWithSettableExtraParams)
require.True(t, ok, "%s should implement RequestWithSettableExtraParams", tt.name)
extra := map[string]interface{}{"test_key": "test_value"}
rws.SetExtraParams(extra)
getter, ok := tt.req.(interface{ GetExtraParams() map[string]interface{} })
require.True(t, ok, "%s should implement GetExtraParams", tt.name)
assert.Equal(t, extra, getter.GetExtraParams())
})
}
}
func TestExtraParamsRequiresPassthroughHeader(t *testing.T) {
handlerStore := &mockHandlerStore{allowDirectKeys: true}
routes := CreateOpenAIRouteConfigs("/openai", handlerStore)
var chatRoute *RouteConfig
for i := range routes {
if routes[i].Path == "/openai/v1/chat/completions" {
chatRoute = &routes[i]
break
}
}
require.NotNil(t, chatRoute, "should find /openai/v1/chat/completions route")
rawBody := []byte(`{
"model": "bedrock/claude-4-5-sonnet-global",
"messages": [{"role": "user", "content": [{"type": "text", "text": "hello"}]}],
"extra_params": {
"guardrailConfig": {
"guardrailIdentifier": "my-guardrail",
"guardrailVersion": "1",
"trace": "disabled"
}
}
}`)
t.Run("extra_params NOT extracted without passthrough header", func(t *testing.T) {
req := chatRoute.GetRequestTypeInstance(context.Background())
err := sonic.Unmarshal(rawBody, req)
require.NoError(t, err)
bifrostCtx := schemas.NewBifrostContext(nil, schemas.NoDeadline)
// Header not set -- simulate router logic
if bifrostCtx.Value(schemas.BifrostContextKeyPassthroughExtraParams) == true {
if rws, ok := req.(RequestWithSettableExtraParams); ok {
var wrapper struct {
ExtraParams map[string]interface{} `json:"extra_params"`
}
if err := sonic.Unmarshal(rawBody, &wrapper); err == nil && len(wrapper.ExtraParams) > 0 {
rws.SetExtraParams(wrapper.ExtraParams)
}
_ = rws
}
}
openaiReq, ok := req.(*openai.OpenAIChatRequest)
require.True(t, ok)
assert.Empty(t, openaiReq.ChatParameters.ExtraParams,
"ExtraParams should be empty when passthrough header is not set")
})
t.Run("extra_params extracted with passthrough header", func(t *testing.T) {
req := chatRoute.GetRequestTypeInstance(context.Background())
err := sonic.Unmarshal(rawBody, req)
require.NoError(t, err)
bifrostCtx := schemas.NewBifrostContext(nil, schemas.NoDeadline)
bifrostCtx.SetValue(schemas.BifrostContextKeyPassthroughExtraParams, true)
if bifrostCtx.Value(schemas.BifrostContextKeyPassthroughExtraParams) == true {
if rws, ok := req.(RequestWithSettableExtraParams); ok {
var wrapper struct {
ExtraParams map[string]interface{} `json:"extra_params"`
}
if err := sonic.Unmarshal(rawBody, &wrapper); err == nil && len(wrapper.ExtraParams) > 0 {
rws.SetExtraParams(wrapper.ExtraParams)
}
}
}
openaiReq, ok := req.(*openai.OpenAIChatRequest)
require.True(t, ok)
require.Contains(t, openaiReq.ChatParameters.ExtraParams, "guardrailConfig",
"guardrailConfig should be in ExtraParams when passthrough header is set")
gc, ok := openaiReq.ChatParameters.ExtraParams["guardrailConfig"].(map[string]interface{})
require.True(t, ok)
assert.Equal(t, "my-guardrail", gc["guardrailIdentifier"])
assert.Equal(t, "1", gc["guardrailVersion"])
assert.Equal(t, "disabled", gc["trace"])
})
}
func TestExtraParamsPassthrough_NestedStructures(t *testing.T) {
rawBody := []byte(`{
"model": "openai/gpt-4o-mini",
"messages": [{"role": "user", "content": [{"type": "text", "text": "hello"}]}],
"extra_params": {
"custom_param": "value",
"another_param": 123,
"nested": {
"deep_field": "deep_value",
"deeper": {"level": 3}
}
}
}`)
req := &openai.OpenAIChatRequest{}
err := sonic.Unmarshal(rawBody, req)
require.NoError(t, err)
bifrostCtx := schemas.NewBifrostContext(nil, schemas.NoDeadline)
bifrostCtx.SetValue(schemas.BifrostContextKeyPassthroughExtraParams, true)
if bifrostCtx.Value(schemas.BifrostContextKeyPassthroughExtraParams) == true {
if rws, ok := interface{}(req).(RequestWithSettableExtraParams); ok {
var wrapper struct {
ExtraParams map[string]interface{} `json:"extra_params"`
}
if err := sonic.Unmarshal(rawBody, &wrapper); err == nil && len(wrapper.ExtraParams) > 0 {
rws.SetExtraParams(wrapper.ExtraParams)
}
}
}
require.Len(t, req.ChatParameters.ExtraParams, 3)
assert.Equal(t, "value", req.ChatParameters.ExtraParams["custom_param"])
assert.Equal(t, float64(123), req.ChatParameters.ExtraParams["another_param"])
nested, ok := req.ChatParameters.ExtraParams["nested"].(map[string]interface{})
require.True(t, ok)
assert.Equal(t, "deep_value", nested["deep_field"])
}
func TestExtraParamsPassthrough_EndToEnd(t *testing.T) {
rawJSON := []byte(`{
"model": "bedrock/claude-4-5-sonnet-global",
"messages": [{"role": "user", "content": [{"type": "text", "text": "hello"}]}],
"stream": false,
"temperature": 0.7,
"extra_params": {
"guardrailConfig": {
"guardrailIdentifier": "my-guardrail",
"guardrailVersion": "1",
"trace": "disabled"
}
}
}`)
req := &openai.OpenAIChatRequest{}
err := sonic.Unmarshal(rawJSON, req)
require.NoError(t, err)
assert.Equal(t, "bedrock/claude-4-5-sonnet-global", req.Model)
bifrostCtx := schemas.NewBifrostContext(nil, schemas.NoDeadline)
bifrostCtx.SetValue(schemas.BifrostContextKeyPassthroughExtraParams, true)
if bifrostCtx.Value(schemas.BifrostContextKeyPassthroughExtraParams) == true {
if rws, ok := interface{}(req).(RequestWithSettableExtraParams); ok {
var wrapper struct {
ExtraParams map[string]interface{} `json:"extra_params"`
}
if err := sonic.Unmarshal(rawJSON, &wrapper); err == nil && len(wrapper.ExtraParams) > 0 {
rws.SetExtraParams(wrapper.ExtraParams)
}
}
}
bifrostReq := req.ToBifrostChatRequest(bifrostCtx)
require.NotNil(t, bifrostReq)
require.NotNil(t, bifrostReq.Params)
require.Contains(t, bifrostReq.Params.ExtraParams, "guardrailConfig")
gc, ok := bifrostReq.Params.ExtraParams["guardrailConfig"].(map[string]interface{})
require.True(t, ok)
assert.Equal(t, "my-guardrail", gc["guardrailIdentifier"])
assert.Equal(t, "1", gc["guardrailVersion"])
assert.Equal(t, "disabled", gc["trace"])
assert.NotContains(t, bifrostReq.Params.ExtraParams, "model")
assert.NotContains(t, bifrostReq.Params.ExtraParams, "messages")
assert.NotContains(t, bifrostReq.Params.ExtraParams, "stream")
assert.NotContains(t, bifrostReq.Params.ExtraParams, "temperature")
}
func TestExtraParamsPassthrough_NoExtraParamsKey(t *testing.T) {
rawBody := []byte(`{
"model": "openai/gpt-4o-mini",
"messages": [{"role": "user", "content": [{"type": "text", "text": "hello"}]}]
}`)
req := &openai.OpenAIChatRequest{}
err := sonic.Unmarshal(rawBody, req)
require.NoError(t, err)
bifrostCtx := schemas.NewBifrostContext(nil, schemas.NoDeadline)
bifrostCtx.SetValue(schemas.BifrostContextKeyPassthroughExtraParams, true)
if bifrostCtx.Value(schemas.BifrostContextKeyPassthroughExtraParams) == true {
if rws, ok := interface{}(req).(RequestWithSettableExtraParams); ok {
var wrapper struct {
ExtraParams map[string]interface{} `json:"extra_params"`
}
if err := sonic.Unmarshal(rawBody, &wrapper); err == nil && len(wrapper.ExtraParams) > 0 {
rws.SetExtraParams(wrapper.ExtraParams)
}
_ = rws
}
}
assert.Empty(t, req.ChatParameters.ExtraParams,
"ExtraParams should be empty when extra_params key is absent from JSON")
}
// TestExtraParamsSetViaInterfaceMutatesOriginalReq verifies that setting extra
// params through the RequestWithSettableExtraParams interface assertion mutates
// the original req (interface{}) value. This matters because createHandler
// passes req to config.RequestConverter after the extra params block -- both
// variables must reference the same underlying struct via pointer semantics.
func TestExtraParamsSetViaInterfaceMutatesOriginalReq(t *testing.T) {
handlerStore := &mockHandlerStore{allowDirectKeys: true}
routes := CreateOpenAIRouteConfigs("/openai", handlerStore)
var chatRoute *RouteConfig
for i := range routes {
if routes[i].Path == "/openai/v1/chat/completions" {
chatRoute = &routes[i]
break
}
}
require.NotNil(t, chatRoute)
rawBody := []byte(`{
"model": "bedrock/claude-4-5-sonnet-global",
"messages": [{"role": "user", "content": [{"type": "text", "text": "hello"}]}],
"extra_params": {
"guardrailConfig": {
"guardrailIdentifier": "my-guardrail",
"guardrailVersion": "1"
}
}
}`)
// Simulate the exact flow in createHandler:
// 1. req is created via GetRequestTypeInstance (returns interface{})
// 2. JSON is unmarshalled into req
// 3. rws type assertion is used to call SetExtraParams
// 4. req (not rws) is passed to RequestConverter downstream
req := chatRoute.GetRequestTypeInstance(context.Background()) // returns interface{}
err := sonic.Unmarshal(rawBody, req)
require.NoError(t, err)
// Type-assert and set extra params (same as router code)
if rws, ok := req.(RequestWithSettableExtraParams); ok {
var wrapper struct {
ExtraParams map[string]interface{} `json:"extra_params"`
}
if err := sonic.Unmarshal(rawBody, &wrapper); err == nil && len(wrapper.ExtraParams) > 0 {
rws.SetExtraParams(wrapper.ExtraParams)
}
}
// Verify that req (the original interface{} variable) was mutated
openaiReq, ok := req.(*openai.OpenAIChatRequest)
require.True(t, ok)
require.Contains(t, openaiReq.ChatParameters.ExtraParams, "guardrailConfig",
"original req should be mutated via pointer semantics")
// Verify the full downstream path: RequestConverter uses req
bifrostCtx := schemas.NewBifrostContext(nil, schemas.NoDeadline)
bifrostReq, err := chatRoute.RequestConverter(bifrostCtx, req)
require.NoError(t, err)
require.NotNil(t, bifrostReq)
require.NotNil(t, bifrostReq.ChatRequest)
require.NotNil(t, bifrostReq.ChatRequest.Params)
assert.Contains(t, bifrostReq.ChatRequest.Params.ExtraParams, "guardrailConfig",
"extra params should propagate through RequestConverter to BifrostChatRequest")
}

View File

@@ -0,0 +1,502 @@
package integrations
import (
"bytes"
"fmt"
"net/url"
"reflect"
"strconv"
"strings"
"github.com/bytedance/sonic"
bifrost "github.com/maximhq/bifrost/core"
"github.com/maximhq/bifrost/core/providers/gemini"
"github.com/maximhq/bifrost/core/schemas"
"github.com/maximhq/bifrost/framework/kvstore"
"github.com/maximhq/bifrost/transports/bifrost-http/lib"
"github.com/valyala/fasthttp"
)
var bifrostContextKeyProvider = schemas.BifrostContextKey("provider")
var availableIntegrations = []string{
"openai",
"anthropic",
"genai",
"litellm",
"langchain",
"bedrock",
"pydantic",
"cohere",
}
// newBifrostErrorWithCode is like newBifrostError but sets an explicit HTTP status code.
func newBifrostErrorWithCode(err error, message string, statusCode int) *schemas.BifrostError {
e := newBifrostError(err, message)
e.StatusCode = &statusCode
return e
}
// newBifrostError wraps a standard error into a BifrostError with IsBifrostError set to false.
// This helper function reduces code duplication when handling non-Bifrost errors.
func newBifrostError(err error, message string) *schemas.BifrostError {
if err == nil {
return &schemas.BifrostError{
IsBifrostError: false,
Error: &schemas.ErrorField{
Message: message,
},
}
}
return &schemas.BifrostError{
IsBifrostError: false,
Error: &schemas.ErrorField{
Message: message,
Error: err,
},
}
}
// safeGetRequestType safely obtains the request type from a BifrostStreamChunk chunk.
// It checks multiple sources in order of preference:
// 1. Response ExtraFields if any response is available
// 2. BifrostError ExtraFields if error is available and not nil
// 3. Falls back to "unknown" if no source is available
func safeGetRequestType(chunk *schemas.BifrostStreamChunk) string {
if chunk == nil {
return "unknown"
}
// Try to get RequestType from response ExtraFields (preferred source)
switch {
case chunk.BifrostTextCompletionResponse != nil:
return string(chunk.BifrostTextCompletionResponse.ExtraFields.RequestType)
case chunk.BifrostChatResponse != nil:
return string(chunk.BifrostChatResponse.ExtraFields.RequestType)
case chunk.BifrostResponsesStreamResponse != nil:
return string(chunk.BifrostResponsesStreamResponse.ExtraFields.RequestType)
case chunk.BifrostSpeechStreamResponse != nil:
return string(chunk.BifrostSpeechStreamResponse.ExtraFields.RequestType)
case chunk.BifrostTranscriptionStreamResponse != nil:
return string(chunk.BifrostTranscriptionStreamResponse.ExtraFields.RequestType)
}
// Try to get RequestType from error ExtraFields (fallback)
if chunk.BifrostError != nil && chunk.BifrostError.ExtraFields.RequestType != "" {
return string(chunk.BifrostError.ExtraFields.RequestType)
}
// Final fallback
return "unknown"
}
// extractHeadersFromRequest extracts headers from the request and returns them as a map.
// It uses the fasthttp.RequestCtx.Header.All() method to iterate over all headers.
func extractHeadersFromRequest(ctx *fasthttp.RequestCtx) map[string][]string {
headers := make(map[string][]string)
for key, value := range ctx.Request.Header.All() {
keyStr := string(key)
headers[keyStr] = append(headers[keyStr], string(value))
}
return headers
}
// extractExactPath returns the request path *after* the integration prefix,
// preserving the original query string exactly as sent by the client.
//
// Example:
//
// /openai/v1/chat/completions?model=gpt-4o -> v1/chat/completions?model=gpt-4o
func extractExactPath(ctx *fasthttp.RequestCtx) string {
// ctx.Path() returns only the path (no query) as a []byte backed by fasthttps internal buffers.
// Treat it as read-only; dont append to it directly.
path := ctx.Path() // e.g. "/openai/v1/chat/completions"
// Strip the integration prefix only if its at the start.
for _, integration := range availableIntegrations {
if bytes.HasPrefix(path, []byte("/"+integration+"/")) {
path = path[len("/"+integration+"/"):]
break
}
}
// Raw query string as sent by client (unparsed, preserves ordering/duplicates/encoding).
q := ctx.URI().QueryString() // e.g. "model=gpt-4o&stream=true"
if len(q) == 0 {
// No query → just return the (possibly trimmed) path.
return string(path)
}
// --- Build "<path>?<query>" efficiently and safely ---
//
// Why not do: return string(path) + "?" + string(q) ?
// - That allocates multiple temporary strings and may copy data more than necessary.
//
// Why not append into 'path' directly?
// - 'path' may alias fasthttps internal buffers; mutating/expanding it could corrupt request state.
//
// We instead allocate a new buffer with exact capacity and copy into it,
// staying in []byte until the final string conversion (1 allocation for the new slice).
out := make([]byte, 0, len(path)+1+len(q)) // pre-size: path + "?" + query
out = append(out, path...) // copy path bytes
out = append(out, '?') // separator
out = append(out, q...) // copy raw query bytes
return string(out)
}
// sendStreamError sends an error response for a streaming request that failed before streaming started.
// It propagates the provider's HTTP status code and returns a JSON error body (not SSE format),
// since no streaming has begun and clients should receive a standard error response.
func (g *GenericRouter) sendStreamError(ctx *fasthttp.RequestCtx, bifrostCtx *schemas.BifrostContext, config RouteConfig, bifrostErr *schemas.BifrostError) {
// Forward provider response headers from context so streaming error responses include them
if bifrostCtx != nil {
if headers, ok := bifrostCtx.Value(schemas.BifrostContextKeyProviderResponseHeaders).(map[string]string); ok {
for key, value := range headers {
ctx.Response.Header.Set(key, value)
}
}
}
// Set the HTTP status code from the provider error
if bifrostErr.StatusCode != nil {
ctx.SetStatusCode(*bifrostErr.StatusCode)
} else {
ctx.SetStatusCode(fasthttp.StatusInternalServerError)
}
ctx.SetContentType("application/json")
// Always use the route-level ErrorConverter (not StreamConfig.ErrorConverter) because
// sendStreamError returns JSON, not SSE. StreamConfig.ErrorConverter is designed for
// in-stream SSE errors (e.g., Anthropic's returns a raw SSE string that would be
// double-escaped by JSON marshaling).
errorResponse := config.ErrorConverter(bifrostCtx, bifrostErr)
errorJSON, err := sonic.Marshal(errorResponse)
if err != nil {
g.logger.Error("failed to marshal error response", "err", err, "path", extractExactPath(ctx))
ctx.SetStatusCode(fasthttp.StatusInternalServerError)
ctx.SetContentType("text/plain; charset=utf-8")
ctx.SetBodyString(fmt.Sprintf("failed to encode error response: %v", err))
return
}
ctx.SetBody(errorJSON)
}
// sendError sends an error response with the appropriate status code and JSON body.
// It handles different error types (string, error interface, or arbitrary objects).
func (g *GenericRouter) sendError(ctx *fasthttp.RequestCtx, bifrostCtx *schemas.BifrostContext, errorConverter ErrorConverter, bifrostErr *schemas.BifrostError) {
// Forward provider response headers from context so error responses include them
if bifrostCtx != nil {
if headers, ok := bifrostCtx.Value(schemas.BifrostContextKeyProviderResponseHeaders).(map[string]string); ok {
for key, value := range headers {
ctx.Response.Header.Set(key, value)
}
}
}
if bifrostErr.StatusCode != nil {
ctx.SetStatusCode(*bifrostErr.StatusCode)
} else {
ctx.SetStatusCode(fasthttp.StatusInternalServerError)
}
ctx.SetContentType("application/json")
// Marshal the error for response and log the error for diagnostics
responseObj := errorConverter(bifrostCtx, bifrostErr)
errorBody, err := sonic.Marshal(responseObj)
if err != nil {
// Log the marshal failure and return a plain text error
g.logger.Error("failed to marshal error response", "err", err, "path", extractExactPath(ctx))
ctx.SetStatusCode(fasthttp.StatusInternalServerError)
ctx.SetContentType("text/plain; charset=utf-8")
ctx.SetBodyString(fmt.Sprintf("failed to encode error response: %v", err))
return
}
ctx.SetBody(errorBody)
}
// sendSuccess sends a successful response with HTTP 200 status and JSON body.
func (g *GenericRouter) sendSuccess(ctx *fasthttp.RequestCtx, bifrostCtx *schemas.BifrostContext, errorConverter ErrorConverter, response interface{}, extraHeaders map[string]string) {
ctx.SetStatusCode(fasthttp.StatusOK)
ctx.SetContentType("application/json")
if extraHeaders != nil {
for key, value := range extraHeaders {
ctx.Response.Header.Set(key, value)
}
}
responseBody, err := sonic.Marshal(response)
if err != nil {
g.sendError(ctx, bifrostCtx, errorConverter, newBifrostError(err, "failed to encode response"))
return
}
ctx.SetBody(responseBody)
}
// tryStreamLargeResponse checks if large response mode was activated by the provider,
// sets the transport marker, and streams the response directly to the client.
// Returns true if the response was handled (caller should return).
func (g *GenericRouter) tryStreamLargeResponse(ctx *fasthttp.RequestCtx, bifrostCtx *schemas.BifrostContext) bool {
isLargeResponse, ok := bifrostCtx.Value(schemas.BifrostContextKeyLargeResponseMode).(bool)
if !ok || !isLargeResponse {
return false
}
// Forward provider response headers before streaming — providers store them in
// context via BifrostContextKeyProviderResponseHeaders, but some early-return
// branches in the router skip the common footer that normally forwards them.
if headers, ok := bifrostCtx.Value(schemas.BifrostContextKeyProviderResponseHeaders).(map[string]string); ok {
for key, value := range headers {
ctx.Response.Header.Set(key, value)
}
}
if g.streamLargeResponse(ctx, bifrostCtx) {
ctx.SetUserValue(lib.FastHTTPUserValueLargeResponseMode, true)
}
return true
}
// streamLargeResponse streams the large response body directly from the upstream provider to the client.
// This bypasses the normal serialize → set body path, piping the response bytes unchanged.
func (g *GenericRouter) streamLargeResponse(ctx *fasthttp.RequestCtx, bifrostCtx *schemas.BifrostContext) bool {
// Enterprise hook: wrap the reader with Phase B scanning (e.g., usage extraction
// from the full response stream) before streaming to client.
if g.largeResponseHook != nil {
g.largeResponseHook(ctx, bifrostCtx)
}
if !lib.StreamLargeResponseBody(ctx, bifrostCtx) {
ctx.SetStatusCode(fasthttp.StatusInternalServerError)
ctx.SetBodyString("large response reader not available")
return false
}
return true
}
// extractAndParseFallbacks extracts fallbacks from the integration request and adds them to the BifrostRequest
func (g *GenericRouter) extractAndParseFallbacks(req interface{}, bifrostReq *schemas.BifrostRequest) error {
// Check if the request has a fallbacks field ([]string)
fallbacks, err := g.extractFallbacksFromRequest(req)
if err != nil {
return fmt.Errorf("failed to extract fallbacks: %w", err)
}
if len(fallbacks) == 0 {
return nil // No fallbacks to process
}
provider, _, _ := bifrostReq.GetRequestFields()
// Parse fallbacks from strings to Fallback structs
parsedFallbacks := make([]schemas.Fallback, 0, len(fallbacks))
for _, fallbackStr := range fallbacks {
if fallbackStr == "" {
continue // Skip empty strings
}
// Use ParseModelString to extract provider and model
provider, model := schemas.ParseModelString(fallbackStr, provider)
parsedFallback := schemas.Fallback{
Provider: provider,
Model: model,
}
parsedFallbacks = append(parsedFallbacks, parsedFallback)
}
if len(parsedFallbacks) == 0 {
return nil // No valid fallbacks found
}
// Add fallbacks to the main BifrostRequest
bifrostReq.SetFallbacks(parsedFallbacks)
// Also add fallbacks to the specific request type if it exists
switch bifrostReq.RequestType {
case schemas.TextCompletionRequest, schemas.TextCompletionStreamRequest:
if bifrostReq.TextCompletionRequest != nil {
bifrostReq.TextCompletionRequest.Fallbacks = parsedFallbacks
}
case schemas.ChatCompletionRequest, schemas.ChatCompletionStreamRequest:
if bifrostReq.ChatRequest != nil {
bifrostReq.ChatRequest.Fallbacks = parsedFallbacks
}
case schemas.ResponsesRequest, schemas.ResponsesStreamRequest:
if bifrostReq.ResponsesRequest != nil {
bifrostReq.ResponsesRequest.Fallbacks = parsedFallbacks
}
case schemas.EmbeddingRequest:
if bifrostReq.EmbeddingRequest != nil {
bifrostReq.EmbeddingRequest.Fallbacks = parsedFallbacks
}
case schemas.RerankRequest:
if bifrostReq.RerankRequest != nil {
bifrostReq.RerankRequest.Fallbacks = parsedFallbacks
}
case schemas.SpeechRequest, schemas.SpeechStreamRequest:
if bifrostReq.SpeechRequest != nil {
bifrostReq.SpeechRequest.Fallbacks = parsedFallbacks
}
case schemas.TranscriptionRequest, schemas.TranscriptionStreamRequest:
if bifrostReq.TranscriptionRequest != nil {
bifrostReq.TranscriptionRequest.Fallbacks = parsedFallbacks
}
case schemas.ImageGenerationRequest, schemas.ImageGenerationStreamRequest:
if bifrostReq.ImageGenerationRequest != nil {
bifrostReq.ImageGenerationRequest.Fallbacks = parsedFallbacks
}
}
return nil
}
// extractFallbacksFromRequest uses reflection to extract fallbacks field from any request type
func (g *GenericRouter) extractFallbacksFromRequest(req interface{}) ([]string, error) {
if req == nil {
return nil, nil
}
// Try to use reflection to find a "fallbacks" field
reqValue := reflect.ValueOf(req)
if reqValue.Kind() == reflect.Ptr {
reqValue = reqValue.Elem()
}
if reqValue.Kind() != reflect.Struct {
return nil, nil // Not a struct, no fallbacks
}
// Look for the "fallbacks" field
fallbacksField := reqValue.FieldByName("fallbacks")
if !fallbacksField.IsValid() {
return nil, nil // No fallbacks field found
}
// Handle different types of fallbacks field
switch fallbacksField.Kind() {
case reflect.Slice:
if fallbacksField.Type().Elem().Kind() == reflect.String {
// []string case
fallbacks := make([]string, fallbacksField.Len())
for i := 0; i < fallbacksField.Len(); i++ {
fallbacks[i] = fallbacksField.Index(i).String()
}
return fallbacks, nil
}
case reflect.String:
// Single string case - treat as one fallback
return []string{fallbacksField.String()}, nil
}
return nil, nil
}
// getVirtualKeyFromBifrostContext extracts the virtual key value from bifrost context.
// Returns nil if no VK is present (e.g., direct key mode or no governance).
func getVirtualKeyFromBifrostContext(ctx *schemas.BifrostContext) *string {
vkValue := bifrost.GetStringFromContext(ctx, schemas.BifrostContextKeyVirtualKey)
if vkValue == "" {
return nil
}
return &vkValue
}
// getResultTTLFromHeaderWithDefault extracts the result TTL from the x-bf-async-job-result-ttl header.
// Returns the default TTL if the header is not present or invalid.
func getResultTTLFromHeaderWithDefault(ctx *fasthttp.RequestCtx, defaultTTL int) int {
resultTTL := string(ctx.Request.Header.Peek(schemas.AsyncHeaderResultTTL))
if resultTTL == "" {
return defaultTTL
}
resultTTLInt, err := strconv.Atoi(resultTTL)
if err != nil || resultTTLInt < 0 {
return defaultTTL
}
return resultTTLInt
}
// isAnthropicAPIKeyAuth checks if the request uses standard API key authentication.
// Returns true for API key auth (x-api-key header), false for OAuth (Bearer sk-ant-oat*).
// This is required for Claude Code specifically, which may use OAuth authentication.
// Default behavior is to assume API mode when neither x-api-key nor OAuth token is present.
func isAnthropicAPIKeyAuth(ctx *fasthttp.RequestCtx) bool {
// If x-api-key header is present - this is definitely API mode
if apiKey := string(ctx.Request.Header.Peek("x-api-key")); apiKey != "" {
return true
}
// Check for OAuth token in Authorization header
if authHeader := string(ctx.Request.Header.Peek("Authorization")); authHeader != "" {
if strings.HasPrefix(strings.ToLower(authHeader), "bearer sk-ant-oat") {
return false // OAuth mode, NOT API
}
}
// Default to API mode
return true
}
// resolveLargePayloadMetadata returns metadata from the sync context key,
// falling back to a non-blocking read from the deferred channel.
// If deferred metadata is resolved, it is cached in the sync key for later readers.
func resolveLargePayloadMetadata(bifrostCtx *schemas.BifrostContext) *schemas.LargePayloadMetadata {
if bifrostCtx == nil {
return nil
}
if metadata, ok := bifrostCtx.Value(schemas.BifrostContextKeyLargePayloadMetadata).(*schemas.LargePayloadMetadata); ok && metadata != nil {
return metadata
}
ch, ok := bifrostCtx.Value(schemas.BifrostContextKeyDeferredLargePayloadMetadata).(<-chan *schemas.LargePayloadMetadata)
if !ok || ch == nil {
return nil
}
select {
case metadata := <-ch:
if metadata != nil {
bifrostCtx.SetValue(schemas.BifrostContextKeyLargePayloadMetadata, metadata)
}
return metadata
default:
return nil
}
}
// ParseProviderScopedVideoID parses a provider-scoped video ID in the form "id:provider".
// The ID portion is automatically URL-decoded to restore the original ID.
func ParseProviderScopedVideoID(videoID string) (schemas.ModelProvider, string, error) {
parts := strings.SplitN(videoID, ":", 2)
if len(parts) != 2 || parts[0] == "" || parts[1] == "" {
return "", "", fmt.Errorf("video_id must be in id:provider format")
}
provider := schemas.ModelProvider(parts[1])
rawID := parts[0]
// URL decode the ID to restore original characters (e.g., %2F -> /)
// This handles IDs from all providers that may contain special characters
if decoded, err := url.PathUnescape(rawID); err == nil {
rawID = decoded
}
return provider, rawID, nil
}
func getProviderFromHeader(ctx *fasthttp.RequestCtx, defaultProvider schemas.ModelProvider) schemas.ModelProvider {
providerHeader := string(ctx.Request.Header.Peek("x-model-provider"))
if providerHeader == "" {
return defaultProvider
}
return schemas.ModelProvider(providerHeader)
}
func RegisterKVDecoders(store *kvstore.Store) {
store.RegisterDecoder("genai_upload_session:", func(data []byte) (any, error) {
var v gemini.GeminiResumableUploadSession
return &v, sonic.Unmarshal(data, &v)
})
}

View File

@@ -0,0 +1,277 @@
package integrations
import (
"context"
"strings"
"testing"
"github.com/bytedance/sonic"
"github.com/maximhq/bifrost/core/providers/anthropic"
"github.com/maximhq/bifrost/core/providers/bedrock"
"github.com/maximhq/bifrost/core/schemas"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/valyala/fasthttp"
)
// testLogger implements schemas.Logger for testing (all no-ops)
type testLogger struct{}
func (t *testLogger) Debug(msg string, args ...any) {}
func (t *testLogger) Info(msg string, args ...any) {}
func (t *testLogger) Warn(msg string, args ...any) {}
func (t *testLogger) Error(msg string, args ...any) {}
func (t *testLogger) Fatal(msg string, args ...any) {}
func (t *testLogger) SetLevel(level schemas.LogLevel) {}
func (t *testLogger) SetOutputType(outputType schemas.LoggerOutputType) {}
func (t *testLogger) LogHTTPRequest(level schemas.LogLevel, msg string) schemas.LogEventBuilder {
return schemas.NoopLogEvent
}
var _ schemas.Logger = (*testLogger)(nil)
func ptr(i int) *int {
return &i
}
func strPtr(s string) *string {
return &s
}
func newTestGenericRouter() *GenericRouter {
return NewGenericRouter(nil, &mockHandlerStore{}, nil, nil, &testLogger{})
}
func newTestBifrostContext() *schemas.BifrostContext {
return schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
}
// TestSendStreamError_PropagatesProviderStatusCode verifies that sendStreamError
// sets the HTTP status code from the provider's BifrostError.StatusCode field.
// All three providers (OpenAI, Anthropic, Bedrock) return actual HTTP error codes
// for pre-stream errors, so Bifrost must propagate them faithfully.
func TestSendStreamError_PropagatesProviderStatusCode(t *testing.T) {
tests := []struct {
name string
statusCode *int
expectedStatusCode int
}{
{
name: "provider 400 - Bedrock ValidationException / OpenAI invalid_request_error",
statusCode: ptr(400),
expectedStatusCode: 400,
},
{
name: "provider 429 - rate limiting (all providers)",
statusCode: ptr(429),
expectedStatusCode: 429,
},
{
name: "provider 503 - Bedrock ServiceUnavailableException",
statusCode: ptr(503),
expectedStatusCode: 503,
},
{
name: "provider 529 - Anthropic overloaded_error",
statusCode: ptr(529),
expectedStatusCode: 529,
},
{
name: "nil StatusCode defaults to 500",
statusCode: nil,
expectedStatusCode: 500,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
router := newTestGenericRouter()
ctx := &fasthttp.RequestCtx{}
bifrostCtx := newTestBifrostContext()
bifrostErr := &schemas.BifrostError{
StatusCode: tt.statusCode,
Error: &schemas.ErrorField{
Message: "test error",
},
}
config := RouteConfig{
ErrorConverter: func(ctx *schemas.BifrostContext, err *schemas.BifrostError) interface{} {
return err
},
}
router.sendStreamError(ctx, bifrostCtx, config, bifrostErr)
assert.Equal(t, tt.expectedStatusCode, ctx.Response.StatusCode())
assert.Equal(t, "application/json", string(ctx.Response.Header.ContentType()))
body := string(ctx.Response.Body())
assert.True(t, sonic.Valid(ctx.Response.Body()), "response body should be valid JSON, got: %s", body)
assert.False(t, strings.HasPrefix(body, "data: "), "response should not be SSE format")
})
}
}
// TestSendStreamError_OpenAIErrorFormat verifies the response body matches the
// OpenAI error format. OpenAI's ErrorConverter returns *schemas.BifrostError directly,
// which serializes to {"is_bifrost_error":false,"status_code":400,"error":{...}}.
func TestSendStreamError_OpenAIErrorFormat(t *testing.T) {
router := newTestGenericRouter()
ctx := &fasthttp.RequestCtx{}
bifrostCtx := newTestBifrostContext()
bifrostErr := &schemas.BifrostError{
IsBifrostError: false,
StatusCode: ptr(400),
Error: &schemas.ErrorField{
Type: strPtr("invalid_request_error"),
Message: "content is empty",
},
}
config := RouteConfig{
ErrorConverter: func(ctx *schemas.BifrostContext, err *schemas.BifrostError) interface{} {
return err
},
}
router.sendStreamError(ctx, bifrostCtx, config, bifrostErr)
assert.Equal(t, 400, ctx.Response.StatusCode())
// Unmarshal and verify the structure
var result map[string]interface{}
err := sonic.Unmarshal(ctx.Response.Body(), &result)
require.NoError(t, err)
assert.Contains(t, result, "is_bifrost_error")
assert.Contains(t, result, "status_code")
assert.Contains(t, result, "error")
assert.Equal(t, false, result["is_bifrost_error"])
errorObj, ok := result["error"].(map[string]interface{})
require.True(t, ok)
assert.Equal(t, "invalid_request_error", errorObj["type"])
assert.Equal(t, "content is empty", errorObj["message"])
}
// TestSendStreamError_AnthropicErrorFormat verifies the response body matches the
// Anthropic error format: {"type":"error","error":{"type":"...","message":"..."}}.
// Critically, it also verifies that the StreamConfig.ErrorConverter (which returns
// raw SSE strings) is NOT used — sendStreamError must use the route-level ErrorConverter.
func TestSendStreamError_AnthropicErrorFormat(t *testing.T) {
router := newTestGenericRouter()
ctx := &fasthttp.RequestCtx{}
bifrostCtx := newTestBifrostContext()
bifrostErr := &schemas.BifrostError{
StatusCode: ptr(429),
Error: &schemas.ErrorField{
Type: strPtr("rate_limit_error"),
Message: "rate limited",
},
}
config := RouteConfig{
// Route-level: returns JSON-marshallable *AnthropicMessageError
ErrorConverter: func(ctx *schemas.BifrostContext, err *schemas.BifrostError) interface{} {
return anthropic.ToAnthropicChatCompletionError(err)
},
// Stream-level: returns raw SSE string — should NOT be used by sendStreamError
StreamConfig: &StreamConfig{
ErrorConverter: func(ctx *schemas.BifrostContext, err *schemas.BifrostError) interface{} {
return anthropic.ToAnthropicResponsesStreamError(err)
},
},
}
router.sendStreamError(ctx, bifrostCtx, config, bifrostErr)
assert.Equal(t, 429, ctx.Response.StatusCode())
assert.Equal(t, "application/json", string(ctx.Response.Header.ContentType()))
body := string(ctx.Response.Body())
// Must NOT contain SSE markers — that would mean StreamConfig.ErrorConverter was used
assert.NotContains(t, body, "event: error", "response should not contain SSE event markers")
// Unmarshal and verify Anthropic error structure
var result anthropic.AnthropicMessageError
err := sonic.Unmarshal(ctx.Response.Body(), &result)
require.NoError(t, err)
assert.Equal(t, "error", result.Type)
assert.Equal(t, "rate_limit_error", result.Error.Type)
assert.Equal(t, "rate limited", result.Error.Message)
}
// TestSendStreamError_BedrockErrorFormat verifies the response body matches the
// Bedrock error format: {"__type":"ValidationException","message":"..."}.
func TestSendStreamError_BedrockErrorFormat(t *testing.T) {
router := newTestGenericRouter()
ctx := &fasthttp.RequestCtx{}
bifrostCtx := newTestBifrostContext()
bifrostErr := &schemas.BifrostError{
StatusCode: ptr(400),
Error: &schemas.ErrorField{
Code: strPtr("ValidationException"),
Message: "validation error",
},
}
config := RouteConfig{
ErrorConverter: func(ctx *schemas.BifrostContext, err *schemas.BifrostError) interface{} {
return bedrock.ToBedrockError(err)
},
}
router.sendStreamError(ctx, bifrostCtx, config, bifrostErr)
assert.Equal(t, 400, ctx.Response.StatusCode())
// Unmarshal and verify Bedrock error structure
var result bedrock.BedrockError
err := sonic.Unmarshal(ctx.Response.Body(), &result)
require.NoError(t, err)
assert.Equal(t, "ValidationException", result.Type)
assert.Equal(t, "validation error", result.Message)
}
// TestSendStreamError_ForwardsProviderHeaders verifies that provider response headers
// stored in the BifrostContext are forwarded to the HTTP response. This ensures
// clients receive provider-specific headers (e.g., x-amzn-requestid for Bedrock,
// x-request-id for Anthropic) even in error scenarios.
func TestSendStreamError_ForwardsProviderHeaders(t *testing.T) {
router := newTestGenericRouter()
ctx := &fasthttp.RequestCtx{}
bifrostCtx := newTestBifrostContext()
// Set provider response headers on the context
bifrostCtx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, map[string]string{
"x-amzn-requestid": "req-123",
"x-amzn-errortype": "ValidationException",
})
bifrostErr := &schemas.BifrostError{
StatusCode: ptr(400),
Error: &schemas.ErrorField{
Message: "validation error",
},
}
config := RouteConfig{
ErrorConverter: func(ctx *schemas.BifrostContext, err *schemas.BifrostError) interface{} {
return err
},
}
router.sendStreamError(ctx, bifrostCtx, config, bifrostErr)
assert.Equal(t, 400, ctx.Response.StatusCode())
assert.Equal(t, "req-123", string(ctx.Response.Header.Peek("x-amzn-requestid")))
assert.Equal(t, "ValidationException", string(ctx.Response.Header.Peek("x-amzn-errortype")))
}