Files
bifrost/core/providers/anthropic/utils_test.go
Beyhan Oğur 880f412e2c first commit
2026-04-26 21:52:23 +03:00

1985 lines
66 KiB
Go

package anthropic
import (
"context"
"encoding/json"
"fmt"
"reflect"
"slices"
"sort"
"testing"
"time"
"github.com/bytedance/sonic"
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
"github.com/maximhq/bifrost/core/schemas"
)
// TestExtractTypesFromValue checks how extractTypesFromValue flattens the
// different shapes a schema "type" value can take (a single string, a
// []string, or a []any) into a list of type names, and that it yields nil for
// nil or unsupported inputs.
func TestExtractTypesFromValue(t *testing.T) {
	cases := []struct {
		name string
		in   any
		want []string
	}{
		{
			name: "string type",
			in:   "string",
			want: []string{"string"},
		},
		{
			name: "[]string array",
			in:   []string{"string", "null"},
			want: []string{"string", "null"},
		},
		{
			name: "[]interface{} array",
			in:   []any{"string", "integer", "null"},
			want: []string{"string", "integer", "null"},
		},
		{
			name: "[]interface{} with non-string items (filtered out)",
			in:   []any{"string", 123, "null"},
			want: []string{"string", "null"},
		},
		{
			name: "unsupported type returns nil",
			in:   123,
			want: nil,
		},
		{
			name: "nil returns nil",
			in:   nil,
			want: nil,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got := extractTypesFromValue(tc.in)
			if reflect.DeepEqual(got, tc.want) {
				return
			}
			t.Errorf("extractTypesFromValue() mismatch:\ngot: %+v\nwant: %+v", got, tc.want)
		})
	}
}
// TestNormalizeSchemaForAnthropic feeds JSON-schema-shaped maps through
// normalizeSchemaForAnthropic and checks the rewritten schema. The cases cover
// "type" arrays at the top level and inside properties, items, anyOf/oneOf/
// allOf, definitions and $defs, as well as enum handling and nil/empty input.
func TestNormalizeSchemaForAnthropic(t *testing.T) {
	cases := []struct {
		name   string
		schema map[string]any
		want   map[string]any
	}{
		{
			name: "type array with string and null - converts to anyOf",
			schema: map[string]any{
				"type":        []any{"string", "null"},
				"description": "A nullable string field",
				"enum":        []string{"value1", "value2", ""},
			},
			want: map[string]any{
				"description": "A nullable string field",
				"anyOf": []any{
					map[string]any{
						"type": "string",
						"enum": []string{"value1", "value2", ""},
					},
					map[string]any{"type": "null"},
				},
			},
		},
		{
			name: "type array with null and string - converts to anyOf",
			schema: map[string]any{
				"type":        []any{"null", "string"},
				"description": "A nullable string field",
				"enum":        []string{"NODE-0", "NODE-1", ""},
			},
			want: map[string]any{
				"description": "A nullable string field",
				"anyOf": []any{
					map[string]any{
						"type": "string",
						"enum": []string{"NODE-0", "NODE-1", ""},
					},
					map[string]any{"type": "null"},
				},
			},
		},
		{
			name: "type array as []string format with null - converts to anyOf",
			schema: map[string]any{
				"type": []string{"string", "null"},
				"enum": []string{"option1", "option2"},
			},
			want: map[string]any{
				"anyOf": []any{
					map[string]any{
						"type": "string",
						"enum": []string{"option1", "option2"},
					},
					map[string]any{"type": "null"},
				},
			},
		},
		{
			name: "type array with single type (no null) - keeps as simple type",
			schema: map[string]any{
				"type": []string{"string"},
				"enum": []string{"option1", "option2"},
			},
			want: map[string]any{
				"type": "string",
				"enum": []string{"option1", "option2"},
			},
		},
		{
			name: "regular string type - no change",
			schema: map[string]any{
				"type":        "string",
				"description": "A regular string field",
			},
			want: map[string]any{
				"type":        "string",
				"description": "A regular string field",
			},
		},
		{
			name: "nested properties with nullable type arrays - converts to anyOf",
			schema: map[string]any{
				"type": "object",
				"properties": map[string]any{
					"field1": map[string]any{
						"type": []any{"string", "null"},
						"enum": []string{"a", "b"},
					},
					"field2": map[string]any{
						"type": "number",
					},
				},
			},
			want: map[string]any{
				"type": "object",
				"properties": map[string]any{
					"field1": map[string]any{
						"anyOf": []any{
							map[string]any{
								"type": "string",
								"enum": []string{"a", "b"},
							},
							map[string]any{"type": "null"},
						},
					},
					"field2": map[string]any{
						"type": "number",
					},
				},
			},
		},
		{
			name: "array items with nullable type array - converts to anyOf",
			schema: map[string]any{
				"type": "array",
				"items": map[string]any{
					"type": []any{"string", "null"},
					"enum": []string{"x", "y", "z"},
				},
			},
			want: map[string]any{
				"type": "array",
				"items": map[string]any{
					"anyOf": []any{
						map[string]any{
							"type": "string",
							"enum": []string{"x", "y", "z"},
						},
						map[string]any{"type": "null"},
					},
				},
			},
		},
		{
			name: "anyOf with type arrays - nested anyOf gets flattened conceptually",
			schema: map[string]any{
				"anyOf": []any{
					map[string]any{
						"type": []any{"string", "null"},
					},
					map[string]any{
						"type": "number",
					},
				},
			},
			want: map[string]any{
				"anyOf": []any{
					map[string]any{
						"anyOf": []any{
							map[string]any{"type": "string"},
							map[string]any{"type": "null"},
						},
					},
					map[string]any{
						"type": "number",
					},
				},
			},
		},
		{
			name: "oneOf with nullable type arrays",
			schema: map[string]any{
				"oneOf": []any{
					map[string]any{
						"type": []any{"string", "null"},
					},
				},
			},
			want: map[string]any{
				"oneOf": []any{
					map[string]any{
						"anyOf": []any{
							map[string]any{"type": "string"},
							map[string]any{"type": "null"},
						},
					},
				},
			},
		},
		{
			name: "allOf with nullable type arrays",
			schema: map[string]any{
				"allOf": []any{
					map[string]any{
						"type": []any{"string", "null"},
					},
				},
			},
			want: map[string]any{
				"allOf": []any{
					map[string]any{
						"anyOf": []any{
							map[string]any{"type": "string"},
							map[string]any{"type": "null"},
						},
					},
				},
			},
		},
		{
			name: "definitions with nullable type arrays",
			schema: map[string]any{
				"definitions": map[string]any{
					"myDef": map[string]any{
						"type": []any{"string", "null"},
					},
				},
			},
			want: map[string]any{
				"definitions": map[string]any{
					"myDef": map[string]any{
						"anyOf": []any{
							map[string]any{"type": "string"},
							map[string]any{"type": "null"},
						},
					},
				},
			},
		},
		{
			name: "$defs with nullable type arrays",
			schema: map[string]any{
				"$defs": map[string]any{
					"myDef": map[string]any{
						"type": []any{"string", "null"},
					},
				},
			},
			want: map[string]any{
				"$defs": map[string]any{
					"myDef": map[string]any{
						"anyOf": []any{
							map[string]any{"type": "string"},
							map[string]any{"type": "null"},
						},
					},
				},
			},
		},
		{
			name: "complex nested schema - real world example with nullable enum",
			schema: map[string]any{
				"type": "object",
				"properties": map[string]any{
					"action": map[string]any{
						"type": "string",
						"enum": []string{"continue", "transition"},
					},
					"target_node_id": map[string]any{
						"type":        []any{"string", "null"},
						"description": "The ID of the node to transition to. Required when action is 'transition', null when action is 'continue'",
						"enum":        []string{"NODE-0", "NODE-1", "NODE-2", ""},
					},
				},
				"required": []string{"action"},
			},
			want: map[string]any{
				"type": "object",
				"properties": map[string]any{
					"action": map[string]any{
						"type": "string",
						"enum": []string{"continue", "transition"},
					},
					"target_node_id": map[string]any{
						"description": "The ID of the node to transition to. Required when action is 'transition', null when action is 'continue'",
						"anyOf": []any{
							map[string]any{
								"type": "string",
								"enum": []string{"NODE-0", "NODE-1", "NODE-2", ""},
							},
							map[string]any{"type": "null"},
						},
					},
				},
				"required": []string{"action"},
			},
		},
		{
			name:   "nil schema - returns nil",
			schema: nil,
			want:   nil,
		},
		{
			name:   "empty schema - returns empty",
			schema: map[string]any{},
			want:   map[string]any{},
		},
		{
			name: "type array with multiple non-null types - converts to anyOf",
			schema: map[string]any{
				"type":        []any{"string", "integer"},
				"description": "A field that can be string or integer",
			},
			want: map[string]any{
				"description": "A field that can be string or integer",
				"anyOf": []any{
					map[string]any{"type": "string"},
					map[string]any{"type": "integer"},
				},
			},
		},
		{
			name: "type array with multiple types including null - converts to anyOf with null",
			schema: map[string]any{
				"type":        []any{"string", "integer", "null"},
				"description": "A nullable field that can be string or integer",
			},
			want: map[string]any{
				"description": "A nullable field that can be string or integer",
				"anyOf": []any{
					map[string]any{"type": "string"},
					map[string]any{"type": "integer"},
					map[string]any{"type": "null"},
				},
			},
		},
		{
			name: "type array with multiple types and enum - filters enum values by type in anyOf branches",
			schema: map[string]any{
				"type": []any{"string", "integer"},
				"enum": []any{"value1", 123},
			},
			want: map[string]any{
				"anyOf": []any{
					map[string]any{
						"type": "string",
						"enum": []any{"value1"},
					},
					map[string]any{
						"type": "integer",
						"enum": []any{123},
					},
				},
			},
		},
		{
			name: "nested properties with multi-type arrays - all convert to anyOf",
			schema: map[string]any{
				"type": "object",
				"properties": map[string]any{
					"field1": map[string]any{
						"type": []any{"string", "number"},
					},
					"field2": map[string]any{
						"type": []any{"boolean", "null"},
					},
				},
			},
			want: map[string]any{
				"type": "object",
				"properties": map[string]any{
					"field1": map[string]any{
						"anyOf": []any{
							map[string]any{"type": "string"},
							map[string]any{"type": "number"},
						},
					},
					"field2": map[string]any{
						"anyOf": []any{
							map[string]any{"type": "boolean"},
							map[string]any{"type": "null"},
						},
					},
				},
			},
		},
		{
			name: "real world priority field with mixed string and integer enum - filters correctly",
			schema: map[string]any{
				"type":        []any{"string", "integer"},
				"description": "Priority level - can be a number (1-10) or a string label (low/medium/high)",
				"enum":        []any{"low", "medium", "high", 1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
			},
			want: map[string]any{
				"description": "Priority level - can be a number (1-10) or a string label (low/medium/high)",
				"anyOf": []any{
					map[string]any{
						"type": "string",
						"enum": []any{"low", "medium", "high"},
					},
					map[string]any{
						"type": "integer",
						"enum": []any{1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
					},
				},
			},
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got := normalizeSchemaForAnthropic(tc.schema)

			// Both sides are round-tripped through JSON before comparing, so
			// that Go-level differences with no JSON meaning (such as []string
			// versus []any) cannot cause a false mismatch.
			gotRaw, gotErr := sonic.Marshal(got)
			wantRaw, wantErr := sonic.Marshal(tc.want)
			if gotErr != nil || wantErr != nil {
				t.Fatalf("Failed to marshal for comparison: got err=%v, want err=%v", gotErr, wantErr)
			}

			var gotDecoded, wantDecoded any
			if err := sonic.Unmarshal(gotRaw, &gotDecoded); err != nil {
				t.Fatalf("Failed to unmarshal got JSON: %v", err)
			}
			if err := sonic.Unmarshal(wantRaw, &wantDecoded); err != nil {
				t.Fatalf("Failed to unmarshal want JSON: %v", err)
			}

			if reflect.DeepEqual(gotDecoded, wantDecoded) {
				return
			}

			// Indented JSON only makes the failure output readable; the same
			// values were marshalled successfully above, so errors are ignored.
			gotPretty, _ := sonic.MarshalIndent(got, "", " ")
			wantPretty, _ := sonic.MarshalIndent(tc.want, "", " ")
			t.Errorf("normalizeSchemaForAnthropic() mismatch:\ngot: %s\nwant: %s", gotPretty, wantPretty)
		})
	}
}
// TestConvertChatResponseFormatToAnthropicOutputFormat checks the conversion of
// a chat-style response_format value into the output-format value: a
// "json_schema" format is unwrapped and its schema normalized, while nil input
// and any other format type produce nil.
func TestConvertChatResponseFormatToAnthropicOutputFormat(t *testing.T) {
	// boxed stores v in an interface value and returns that value's address,
	// which is the pointer-to-interface shape the test passes to the converter.
	boxed := func(v any) *any { return &v }

	cases := []struct {
		name   string
		format *any
		want   any
	}{
		{
			name: "chat format with nullable enum gets normalized to anyOf",
			format: boxed(map[string]any{
				"type": "json_schema",
				"json_schema": map[string]any{
					"name": "TestSchema",
					"schema": map[string]any{
						"type": "object",
						"properties": map[string]any{
							"field": map[string]any{
								"type": []any{"string", "null"},
								"enum": []string{"value1", "value2"},
							},
						},
					},
				},
			}),
			want: map[string]any{
				"type": "json_schema",
				"schema": map[string]any{
					"type": "object",
					"properties": map[string]any{
						"field": map[string]any{
							"anyOf": []any{
								map[string]any{
									"type": "string",
									"enum": []string{"value1", "value2"},
								},
								map[string]any{"type": "null"},
							},
						},
					},
				},
			},
		},
		{
			name:   "nil input returns nil",
			format: nil,
			want:   nil,
		},
		{
			name: "non-json_schema type returns nil",
			format: boxed(map[string]any{
				"type": "json",
			}),
			want: nil,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got := convertChatResponseFormatToAnthropicOutputFormat(tc.format)

			// Compare the JSON forms rather than the Go values so that key
			// ordering and concrete slice/map types do not affect the result.
			gotRaw, gotErr := sonic.Marshal(got)
			wantRaw, wantErr := sonic.Marshal(tc.want)
			if gotErr != nil || wantErr != nil {
				t.Fatalf("Failed to marshal for comparison: result err=%v, expected err=%v", gotErr, wantErr)
			}

			var gotDecoded, wantDecoded any
			if err := sonic.Unmarshal(gotRaw, &gotDecoded); err != nil {
				t.Fatalf("Failed to unmarshal result JSON: %v", err)
			}
			if err := sonic.Unmarshal(wantRaw, &wantDecoded); err != nil {
				t.Fatalf("Failed to unmarshal expected JSON: %v", err)
			}

			if reflect.DeepEqual(gotDecoded, wantDecoded) {
				return
			}
			t.Errorf("convertChatResponseFormatToAnthropicOutputFormat() mismatch:\ngot: %+v\nwant: %+v", got, tc.want)
		})
	}
}
// TestValidateToolsForProvider checks, provider by provider, which tool types
// ValidateToolsForProvider accepts (nil error) and which it rejects (non-nil
// error).
func TestValidateToolsForProvider(t *testing.T) {
	cases := []struct {
		name     string
		tools    []schemas.ResponsesTool
		provider schemas.ModelProvider
		wantErr  bool
	}{
		{
			name:     "Anthropic allows web_search",
			tools:    []schemas.ResponsesTool{{Type: schemas.ResponsesToolTypeWebSearch}},
			provider: schemas.Anthropic,
			wantErr:  false,
		},
		{
			name:     "Anthropic allows web_fetch",
			tools:    []schemas.ResponsesTool{{Type: schemas.ResponsesToolTypeWebFetch}},
			provider: schemas.Anthropic,
			wantErr:  false,
		},
		{
			name:     "Vertex allows web_search",
			tools:    []schemas.ResponsesTool{{Type: schemas.ResponsesToolTypeWebSearch}},
			provider: schemas.Vertex,
			wantErr:  false,
		},
		{
			name:     "Vertex rejects web_fetch",
			tools:    []schemas.ResponsesTool{{Type: schemas.ResponsesToolTypeWebFetch}},
			provider: schemas.Vertex,
			wantErr:  true,
		},
		{
			name:     "Vertex rejects code_interpreter",
			tools:    []schemas.ResponsesTool{{Type: schemas.ResponsesToolTypeCodeInterpreter}},
			provider: schemas.Vertex,
			wantErr:  true,
		},
		{
			name:     "Vertex rejects MCP",
			tools:    []schemas.ResponsesTool{{Type: schemas.ResponsesToolTypeMCP}},
			provider: schemas.Vertex,
			wantErr:  true,
		},
		{
			name:     "Bedrock rejects web_search",
			tools:    []schemas.ResponsesTool{{Type: schemas.ResponsesToolTypeWebSearch}},
			provider: schemas.Bedrock,
			wantErr:  true,
		},
		{
			name:     "Bedrock rejects web_fetch",
			tools:    []schemas.ResponsesTool{{Type: schemas.ResponsesToolTypeWebFetch}},
			provider: schemas.Bedrock,
			wantErr:  true,
		},
		{
			name:     "Bedrock allows computer_use",
			tools:    []schemas.ResponsesTool{{Type: schemas.ResponsesToolTypeComputerUsePreview}},
			provider: schemas.Bedrock,
			wantErr:  false,
		},
		{
			name:     "Azure allows everything",
			tools:    []schemas.ResponsesTool{{Type: schemas.ResponsesToolTypeWebFetch}, {Type: schemas.ResponsesToolTypeCodeInterpreter}, {Type: schemas.ResponsesToolTypeMCP}},
			provider: schemas.Azure,
			wantErr:  false,
		},
		{
			name:     "Unknown provider allows all",
			tools:    []schemas.ResponsesTool{{Type: schemas.ResponsesToolTypeWebFetch}},
			provider: "custom_provider",
			wantErr:  false,
		},
		{
			name:     "Function tools always allowed",
			tools:    []schemas.ResponsesTool{{Type: schemas.ResponsesToolTypeFunction}},
			provider: schemas.Bedrock,
			wantErr:  false,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			err := ValidateToolsForProvider(tc.tools, tc.provider)

			// The two failure modes are mutually exclusive: either an error
			// was required and is missing, or one appeared where none should.
			switch {
			case tc.wantErr && err == nil:
				t.Errorf("expected error but got nil")
			case !tc.wantErr && err != nil:
				t.Errorf("unexpected error: %v", err)
			}
		})
	}
}
// TestAddMissingBetaHeadersToContext_PerProvider runs
// AddMissingBetaHeadersToContext for a request/provider pair and then inspects
// the AnthropicBetaHeader entry of the context's extra headers. Each case lists
// the beta values that must be present (expectHeaders) and the ones that must
// be absent (unexpectHeaders).
func TestAddMissingBetaHeadersToContext_PerProvider(t *testing.T) {
	tests := []struct {
		name            string
		provider        schemas.ModelProvider
		req             *AnthropicMessageRequest
		expectHeaders   []string
		unexpectHeaders []string
	}{
		{
			name:     "Anthropic gets structured outputs header",
			provider: schemas.Anthropic,
			req: &AnthropicMessageRequest{
				OutputFormat: json.RawMessage(`{"type":"json_schema"}`),
			},
			expectHeaders: []string{AnthropicStructuredOutputsBetaHeader},
		},
		{
			name:     "Vertex skips structured outputs header",
			provider: schemas.Vertex,
			req: &AnthropicMessageRequest{
				OutputFormat: json.RawMessage(`{"type":"json_schema"}`),
			},
			unexpectHeaders: []string{AnthropicStructuredOutputsBetaHeader},
		},
		{
			name:     "Vertex skips MCP header",
			provider: schemas.Vertex,
			req: &AnthropicMessageRequest{
				MCPServers: []AnthropicMCPServerV2{{URL: "http://example.com"}},
			},
			unexpectHeaders: []string{AnthropicMCPClientBetaHeader},
		},
		{
			name:     "Anthropic gets MCP header",
			provider: schemas.Anthropic,
			req: &AnthropicMessageRequest{
				MCPServers: []AnthropicMCPServerV2{{URL: "http://example.com"}},
			},
			expectHeaders: []string{AnthropicMCPClientBetaHeader},
		},
		{
			name:     "Vertex gets compaction header",
			provider: schemas.Vertex,
			req: &AnthropicMessageRequest{
				ContextManagement: &ContextManagement{
					Edits: []ContextManagementEdit{{Type: ContextManagementEditTypeCompact}},
				},
			},
			expectHeaders: []string{AnthropicCompactionBetaHeader},
		},
		{
			name:     "Bedrock gets compaction header",
			provider: schemas.Bedrock,
			req: &AnthropicMessageRequest{
				ContextManagement: &ContextManagement{
					Edits: []ContextManagementEdit{{Type: ContextManagementEditTypeCompact}},
				},
			},
			expectHeaders: []string{AnthropicCompactionBetaHeader},
		},
		// Interleaved thinking tests
		{
			name:     "Anthropic gets interleaved thinking header for enabled",
			provider: schemas.Anthropic,
			req: &AnthropicMessageRequest{
				Thinking: &AnthropicThinking{Type: "enabled", BudgetTokens: schemas.Ptr(2048)},
			},
			expectHeaders: []string{AnthropicInterleavedThinkingBetaHeader},
		},
		{
			name:     "Anthropic does not get interleaved thinking header for adaptive",
			provider: schemas.Anthropic,
			req: &AnthropicMessageRequest{
				Thinking: &AnthropicThinking{Type: "adaptive"},
			},
			unexpectHeaders: []string{AnthropicInterleavedThinkingBetaHeader},
		},
		{
			name:     "Vertex gets interleaved thinking header",
			provider: schemas.Vertex,
			req: &AnthropicMessageRequest{
				Thinking: &AnthropicThinking{Type: "enabled", BudgetTokens: schemas.Ptr(2048)},
			},
			expectHeaders: []string{AnthropicInterleavedThinkingBetaHeader},
		},
		{
			name:     "Bedrock gets interleaved thinking header",
			provider: schemas.Bedrock,
			req: &AnthropicMessageRequest{
				Thinking: &AnthropicThinking{Type: "enabled", BudgetTokens: schemas.Ptr(2048)},
			},
			expectHeaders: []string{AnthropicInterleavedThinkingBetaHeader},
		},
		{
			name:     "Disabled thinking does not get interleaved thinking header",
			provider: schemas.Anthropic,
			req: &AnthropicMessageRequest{
				Thinking: &AnthropicThinking{Type: "disabled"},
			},
			unexpectHeaders: []string{AnthropicInterleavedThinkingBetaHeader},
		},
		// Fast mode tests — fast mode is Opus 4.6 only (research preview),
		// so tests must set Model to exercise the path. Non-Opus-4.6 models
		// are model-gated out regardless of provider flag.
		{
			name:     "Anthropic gets fast mode header",
			provider: schemas.Anthropic,
			req: &AnthropicMessageRequest{
				Model: "claude-opus-4-6",
				Speed: schemas.Ptr("fast"),
			},
			expectHeaders: []string{AnthropicFastModeBetaHeader},
		},
		{
			name:     "Anthropic skips fast mode header on non-Opus-4.6 model",
			provider: schemas.Anthropic,
			req: &AnthropicMessageRequest{
				Model: "claude-sonnet-4-6",
				Speed: schemas.Ptr("fast"),
			},
			unexpectHeaders: []string{AnthropicFastModeBetaHeader},
		},
		{
			name:     "Bedrock skips fast mode header",
			provider: schemas.Bedrock,
			req: &AnthropicMessageRequest{
				Model: "claude-opus-4-6", // fast mode is model-gated; set a supporting model so the test actually exercises provider suppression
				Speed: schemas.Ptr("fast"),
			},
			unexpectHeaders: []string{AnthropicFastModeBetaHeader},
		},
		{
			name:     "Azure skips fast mode header",
			provider: schemas.Azure,
			req: &AnthropicMessageRequest{
				Model: "claude-opus-4-6", // fast mode is model-gated; set a supporting model so the test actually exercises provider suppression
				Speed: schemas.Ptr("fast"),
			},
			unexpectHeaders: []string{AnthropicFastModeBetaHeader},
		},
		// Fine-grained tool streaming (eager_input_streaming) — per Table 20:
		// GA on Anthropic / Bedrock / Vertex, Beta on Azure. All four should
		// auto-inject fine-grained-tool-streaming-2025-05-14 when a tool has
		// eager_input_streaming: true.
		{
			name:     "Anthropic gets eager_input_streaming header",
			provider: schemas.Anthropic,
			req: &AnthropicMessageRequest{
				Tools: []AnthropicTool{{Name: "t1", EagerInputStreaming: schemas.Ptr(true)}},
			},
			expectHeaders: []string{AnthropicEagerInputStreamingBetaHeader},
		},
		{
			name:     "Bedrock gets eager_input_streaming header",
			provider: schemas.Bedrock,
			req: &AnthropicMessageRequest{
				Tools: []AnthropicTool{{Name: "t1", EagerInputStreaming: schemas.Ptr(true)}},
			},
			expectHeaders: []string{AnthropicEagerInputStreamingBetaHeader},
		},
		{
			name:     "Vertex gets eager_input_streaming header",
			provider: schemas.Vertex,
			req: &AnthropicMessageRequest{
				Tools: []AnthropicTool{{Name: "t1", EagerInputStreaming: schemas.Ptr(true)}},
			},
			expectHeaders: []string{AnthropicEagerInputStreamingBetaHeader},
		},
		{
			name:     "Azure gets eager_input_streaming header",
			provider: schemas.Azure,
			req: &AnthropicMessageRequest{
				Tools: []AnthropicTool{{Name: "t1", EagerInputStreaming: schemas.Ptr(true)}},
			},
			expectHeaders: []string{AnthropicEagerInputStreamingBetaHeader},
		},
		{
			name:     "eager_input_streaming header absent when flag is false",
			provider: schemas.Anthropic,
			req: &AnthropicMessageRequest{
				Tools: []AnthropicTool{{Name: "t1", EagerInputStreaming: schemas.Ptr(false)}},
			},
			unexpectHeaders: []string{AnthropicEagerInputStreamingBetaHeader},
		},
		{
			name:     "eager_input_streaming header absent when unset",
			provider: schemas.Anthropic,
			req: &AnthropicMessageRequest{
				Tools: []AnthropicTool{{Name: "t1"}},
			},
			unexpectHeaders: []string{AnthropicEagerInputStreamingBetaHeader},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Use a real parent context rather than nil: the context package
			// forbids passing a nil Context, and the other tests in this file
			// already build the Bifrost context from context.Background().
			ctx := schemas.NewBifrostContext(context.Background(), time.Time{})
			AddMissingBetaHeadersToContext(ctx, tt.req, tt.provider)

			// A missing or differently typed extra-headers value simply leaves
			// headers empty, so "absent" expectations still hold in that case.
			var headers []string
			if extraHeaders, ok := ctx.Value(schemas.BifrostContextKeyExtraHeaders).(map[string][]string); ok {
				headers = extraHeaders[AnthropicBetaHeader]
			}

			for _, expected := range tt.expectHeaders {
				if !slices.Contains(headers, expected) {
					t.Errorf("expected header %q not found in %v", expected, headers)
				}
			}
			for _, unexpected := range tt.unexpectHeaders {
				if slices.Contains(headers, unexpected) {
					t.Errorf("unexpected header %q found in %v", unexpected, headers)
				}
			}
		})
	}
}
func TestAddMissingBetaHeadersToContext_PassthroughWins(t *testing.T) {
// When a same-prefix header is already set from passthrough, auto-injection should NOT add a second version.
t.Run("passthrough_mcp_header_prevents_auto_inject", func(t *testing.T) {
ctx := schemas.NewBifrostContext(nil, time.Time{})
// Simulate passthrough setting an old MCP header
ctx.SetValue(schemas.BifrostContextKeyExtraHeaders, map[string][]string{
"anthropic-beta": {AnthropicMCPClientBetaHeaderDeprecated},
})
// Request has MCP servers, which would normally auto-inject the new header
req := &AnthropicMessageRequest{
MCPServers: []AnthropicMCPServerV2{{URL: "http://example.com"}},
}
AddMissingBetaHeadersToContext(ctx, req, schemas.Anthropic)
extraHeaders := ctx.Value(schemas.BifrostContextKeyExtraHeaders).(map[string][]string)
betaHeaders := extraHeaders[AnthropicBetaHeader]
// Should only have the old header, not both
if len(betaHeaders) != 1 {
t.Errorf("expected 1 header, got %d: %v", len(betaHeaders), betaHeaders)
}
if betaHeaders[0] != AnthropicMCPClientBetaHeaderDeprecated {
t.Errorf("expected passthrough header %q, got %q", AnthropicMCPClientBetaHeaderDeprecated, betaHeaders[0])
}
})
t.Run("passthrough_computer_use_header_prevents_auto_inject", func(t *testing.T) {
ctx := schemas.NewBifrostContext(nil, time.Time{})
// Simulate passthrough setting an older computer-use header
ctx.SetValue(schemas.BifrostContextKeyExtraHeaders, map[string][]string{
"anthropic-beta": {AnthropicComputerUseBetaHeader20250124},
})
req := &AnthropicMessageRequest{
Tools: []AnthropicTool{{
Type: schemas.Ptr(AnthropicToolTypeComputer20251124),
Name: string(AnthropicToolNameComputer),
}},
}
AddMissingBetaHeadersToContext(ctx, req, schemas.Anthropic)
extraHeaders := ctx.Value(schemas.BifrostContextKeyExtraHeaders).(map[string][]string)
betaHeaders := extraHeaders[AnthropicBetaHeader]
if len(betaHeaders) != 1 {
t.Errorf("expected 1 header, got %d: %v", len(betaHeaders), betaHeaders)
}
if betaHeaders[0] != AnthropicComputerUseBetaHeader20250124 {
t.Errorf("expected passthrough header %q, got %q", AnthropicComputerUseBetaHeader20250124, betaHeaders[0])
}
})
t.Run("no_passthrough_allows_auto_inject", func(t *testing.T) {
ctx := schemas.NewBifrostContext(nil, time.Time{})
req := &AnthropicMessageRequest{
MCPServers: []AnthropicMCPServerV2{{URL: "http://example.com"}},
}
AddMissingBetaHeadersToContext(ctx, req, schemas.Anthropic)
extraHeaders := ctx.Value(schemas.BifrostContextKeyExtraHeaders).(map[string][]string)
betaHeaders := extraHeaders[AnthropicBetaHeader]
if len(betaHeaders) != 1 || betaHeaders[0] != AnthropicMCPClientBetaHeader {
t.Errorf("expected [%q], got %v", AnthropicMCPClientBetaHeader, betaHeaders)
}
})
}
// TestMergeBetaHeaders exercises MergeBetaHeaders with beta values supplied
// through the provider's extra headers (first argument), through the context's
// extra headers, or both. The header key is given in varying letter case, and
// the last subtest expects comma-separated values to be split and de-duplicated.
func TestMergeBetaHeaders(t *testing.T) {
	t.Run("context_extra_headers_case_insensitive_key", func(t *testing.T) {
		bctx := schemas.NewBifrostContext(context.Background(), time.Time{})
		bctx.SetValue(schemas.BifrostContextKeyExtraHeaders, map[string][]string{
			"Anthropic-Beta": {"structured-outputs-2025-11-13"},
		})

		merged := MergeBetaHeaders(nil, bctx)
		if expected := []string{"structured-outputs-2025-11-13"}; !slices.Equal(merged, expected) {
			t.Fatalf("got %v, want %v", merged, expected)
		}
	})

	t.Run("provider_extra_headers_case_insensitive_key", func(t *testing.T) {
		bctx := schemas.NewBifrostContext(context.Background(), time.Time{})

		merged := MergeBetaHeaders(map[string]string{
			"Anthropic-Beta": "mcp-client-2025-04-04",
		}, bctx)
		if expected := []string{"mcp-client-2025-04-04"}; !slices.Equal(merged, expected) {
			t.Fatalf("got %v, want %v", merged, expected)
		}
	})

	t.Run("merges_provider_then_context_deduping_tokens", func(t *testing.T) {
		bctx := schemas.NewBifrostContext(context.Background(), time.Time{})
		bctx.SetValue(schemas.BifrostContextKeyExtraHeaders, map[string][]string{
			"ANTHROPIC-BETA": {"foo,bar", "bar,baz"},
		})

		merged := MergeBetaHeaders(map[string]string{
			"anthropic-beta": "foo",
		}, bctx)

		// Only the set of tokens is asserted, so sort before comparing.
		sort.Strings(merged)
		if expectedSorted := []string{"bar", "baz", "foo"}; !slices.Equal(merged, expectedSorted) {
			t.Fatalf("got %v, want %v", merged, expectedSorted)
		}
	})
}
func TestFilterBetaHeadersForProvider(t *testing.T) {
allHeaders := []string{
AnthropicComputerUseBetaHeader20251124,
AnthropicStructuredOutputsBetaHeader,
AnthropicMCPClientBetaHeader,
AnthropicPromptCachingScopeBetaHeader,
AnthropicCompactionBetaHeader,
AnthropicContextManagementBetaHeader,
AnthropicAdvancedToolUseBetaHeader,
AnthropicFilesAPIBetaHeader,
AnthropicInterleavedThinkingBetaHeader,
AnthropicSkillsBetaHeader,
AnthropicContext1MBetaHeader,
AnthropicFastModeBetaHeader,
AnthropicRedactThinkingBetaHeader,
}
containsHeader := func(result []string, h string) bool {
for _, r := range result {
if r == h {
return true
}
}
return false
}
t.Run("Anthropic/keeps_all_headers", func(t *testing.T) {
result := FilterBetaHeadersForProvider(allHeaders, schemas.Anthropic)
for _, h := range allHeaders {
if !containsHeader(result, h) {
t.Errorf("expected header %q to be kept for Anthropic, got %v", h, result)
}
}
})
t.Run("Vertex/drops_unsupported_headers", func(t *testing.T) {
unsupported := []string{
AnthropicStructuredOutputsBetaHeader,
AnthropicMCPClientBetaHeader,
AnthropicPromptCachingScopeBetaHeader,
AnthropicAdvancedToolUseBetaHeader,
AnthropicFilesAPIBetaHeader,
AnthropicSkillsBetaHeader,
AnthropicFastModeBetaHeader,
AnthropicRedactThinkingBetaHeader,
}
for _, h := range unsupported {
result := FilterBetaHeadersForProvider([]string{h}, schemas.Vertex)
if len(result) != 0 {
t.Errorf("expected header %q to be dropped for Vertex, got %v", h, result)
}
}
})
t.Run("Vertex/keeps_supported_headers", func(t *testing.T) {
supported := []string{
AnthropicComputerUseBetaHeader20251124,
AnthropicCompactionBetaHeader,
AnthropicContextManagementBetaHeader,
AnthropicInterleavedThinkingBetaHeader,
AnthropicContext1MBetaHeader,
AnthropicEagerInputStreamingBetaHeader,
}
result := FilterBetaHeadersForProvider(supported, schemas.Vertex)
if len(result) != len(supported) {
t.Errorf("expected %d headers, got %d: %v", len(supported), len(result), result)
}
})
t.Run("Bedrock/drops_unsupported_headers", func(t *testing.T) {
unsupported := []string{
AnthropicMCPClientBetaHeader,
AnthropicPromptCachingScopeBetaHeader,
AnthropicAdvancedToolUseBetaHeader,
AnthropicFilesAPIBetaHeader,
AnthropicSkillsBetaHeader,
AnthropicFastModeBetaHeader,
AnthropicRedactThinkingBetaHeader,
}
for _, h := range unsupported {
result := FilterBetaHeadersForProvider([]string{h}, schemas.Bedrock)
if len(result) != 0 {
t.Errorf("expected header %q to be dropped for Bedrock, got %v", h, result)
}
}
})
t.Run("Azure/drops_unsupported_headers", func(t *testing.T) {
unsupported := []string{
AnthropicFastModeBetaHeader,
}
for _, h := range unsupported {
result := FilterBetaHeadersForProvider([]string{h}, schemas.Azure)
if len(result) != 0 {
t.Errorf("expected header %q to be dropped for Azure, got %v", h, result)
}
}
})
t.Run("Azure/keeps_supported_headers", func(t *testing.T) {
supported := []string{
AnthropicComputerUseBetaHeader20251124,
AnthropicStructuredOutputsBetaHeader,
AnthropicMCPClientBetaHeader,
AnthropicPromptCachingScopeBetaHeader,
AnthropicCompactionBetaHeader,
AnthropicContextManagementBetaHeader,
AnthropicAdvancedToolUseBetaHeader,
AnthropicFilesAPIBetaHeader,
AnthropicInterleavedThinkingBetaHeader,
AnthropicSkillsBetaHeader,
AnthropicContext1MBetaHeader,
AnthropicRedactThinkingBetaHeader,
AnthropicEagerInputStreamingBetaHeader,
}
result := FilterBetaHeadersForProvider(supported, schemas.Azure)
if len(result) != len(supported) {
t.Errorf("expected %d headers, got %d: %v", len(supported), len(result), result)
}
})
t.Run("Bedrock/keeps_supported_headers", func(t *testing.T) {
supported := []string{
AnthropicComputerUseBetaHeader20251124,
AnthropicStructuredOutputsBetaHeader,
AnthropicCompactionBetaHeader,
AnthropicContextManagementBetaHeader,
AnthropicInterleavedThinkingBetaHeader,
AnthropicContext1MBetaHeader,
AnthropicEagerInputStreamingBetaHeader,
}
result := FilterBetaHeadersForProvider(supported, schemas.Bedrock)
if len(result) != len(supported) {
t.Errorf("expected %d headers, got %d: %v", len(supported), len(result), result)
}
})
t.Run("unknown_headers_dropped_for_non_anthropic", func(t *testing.T) {
result := FilterBetaHeadersForProvider([]string{"some-future-beta-2025"}, schemas.Vertex)
if len(result) != 0 {
t.Errorf("expected unknown header to be dropped for Vertex, got %v", result)
}
})
t.Run("unknown_headers_forwarded_for_anthropic", func(t *testing.T) {
headers := []string{"some-future-beta-2025"}
result := FilterBetaHeadersForProvider(headers, schemas.Anthropic)
if len(result) != len(headers) {
t.Errorf("expected unknown header to be forwarded for Anthropic, got %v", result)
}
})
t.Run("unknown_provider_allows_all", func(t *testing.T) {
result := FilterBetaHeadersForProvider(allHeaders, schemas.ModelProvider("custom-provider"))
if len(result) != len(allHeaders) {
t.Errorf("expected all headers for unknown provider, got %v", result)
}
})
t.Run("override_enables_unsupported_header", func(t *testing.T) {
// redact-thinking is not supported on Vertex by default
overrides := map[string]bool{AnthropicRedactThinkingBetaHeaderPrefix: true}
result := FilterBetaHeadersForProvider([]string{AnthropicRedactThinkingBetaHeader}, schemas.Vertex, overrides)
if len(result) != 1 || result[0] != AnthropicRedactThinkingBetaHeader {
t.Errorf("expected override to allow header, got %v", result)
}
})
t.Run("override_disables_supported_header", func(t *testing.T) {
// compaction is supported on Vertex by default; override to false should drop it silently
overrides := map[string]bool{"compact-": false}
result := FilterBetaHeadersForProvider([]string{AnthropicCompactionBetaHeader}, schemas.Vertex, overrides)
if len(result) != 0 {
t.Errorf("expected override false to drop supported header, got %v", result)
}
})
t.Run("override_nil_uses_defaults", func(t *testing.T) {
// Passing nil overrides should behave identically to no overrides
result := FilterBetaHeadersForProvider([]string{AnthropicCompactionBetaHeader}, schemas.Vertex, nil)
if len(result) != 1 {
t.Errorf("expected default behavior with nil overrides, got %v", result)
}
})
// Custom override tests for all providers
customOverrideProviders := []struct {
provider schemas.ModelProvider
expectForwardNoOverride bool // unknown headers forwarded without override?
}{
{schemas.Anthropic, true},
{schemas.Vertex, false},
{schemas.Bedrock, false},
{schemas.Azure, false},
}
for _, tc := range customOverrideProviders {
tc := tc
t.Run(fmt.Sprintf("%s/custom_override_enables_unknown_header", tc.provider), func(t *testing.T) {
overrides := map[string]bool{"new-feature-": true}
result := FilterBetaHeadersForProvider([]string{"new-feature-2026-01-01"}, tc.provider, overrides)
if len(result) != 1 || result[0] != "new-feature-2026-01-01" {
t.Errorf("expected custom override to allow header on %s, got %v", tc.provider, result)
}
})
t.Run(fmt.Sprintf("%s/custom_override_disables_unknown_header", tc.provider), func(t *testing.T) {
overrides := map[string]bool{"new-feature-": false}
result := FilterBetaHeadersForProvider([]string{"new-feature-2026-01-01"}, tc.provider, overrides)
if len(result) != 0 {
t.Errorf("expected custom override false to drop header on %s, got %v", tc.provider, result)
}
})
t.Run(fmt.Sprintf("%s/custom_override_no_match_still_handled_correctly", tc.provider), func(t *testing.T) {
overrides := map[string]bool{"new-feature-": true}
result := FilterBetaHeadersForProvider([]string{"other-thing-2026"}, tc.provider, overrides)
if tc.expectForwardNoOverride {
if len(result) != 1 {
t.Errorf("expected unknown header forwarded to %s, got %v", tc.provider, result)
}
} else {
if len(result) != 0 {
t.Errorf("expected unknown header dropped for %s, got %v", tc.provider, result)
}
}
})
t.Run(fmt.Sprintf("%s/custom_override_with_multiple_prefixes", tc.provider), func(t *testing.T) {
overrides := map[string]bool{
"alpha-": true,
"beta-": false,
"gamma-": true,
}
result := FilterBetaHeadersForProvider([]string{"alpha-2026-01"}, tc.provider, overrides)
if len(result) != 1 {
t.Errorf("expected alpha- allowed on %s, got %v", tc.provider, result)
}
result = FilterBetaHeadersForProvider([]string{"beta-2026-01"}, tc.provider, overrides)
if len(result) != 0 {
t.Errorf("expected beta- dropped on %s, got %v", tc.provider, result)
}
result = FilterBetaHeadersForProvider([]string{"gamma-2026-01"}, tc.provider, overrides)
if len(result) != 1 {
t.Errorf("expected gamma- allowed on %s, got %v", tc.provider, result)
}
})
}
}
// TestStripUnsupportedFieldsFromRawBody feeds raw JSON request bodies through
// stripUnsupportedFieldsFromRawBody for different provider/model pairs and
// checks, path by path, which fields were removed and which were kept.
func TestStripUnsupportedFieldsFromRawBody(t *testing.T) {
	// Untyped constants: they stay assignable to whatever string-kinded
	// parameter the function under test declares for the model.
	const (
		opus46   = "claude-opus-4-6"
		sonnet46 = "claude-sonnet-4-6"
	)

	t.Run("bedrock_strips_new_request_level_fields", func(t *testing.T) {
		// The body carries every new typed request-level field. Targeting
		// Bedrock, the test expects all of them gone: speed (no FastMode),
		// inference_geo (no InferenceGeo), mcp_servers (no MCP),
		// container.skills (no Skills), top-level cache_control.scope
		// (no PromptCachingScope), output_config.task_budget (no TaskBudgets).
		body := []byte(`{
"model":"claude-opus-4-6",
"speed":"fast",
"inference_geo":"us-east-1",
"mcp_servers":[{"type":"url","url":"https://example.com","name":"x"}],
"container":{"id":"c-1","skills":[{"skill_id":"s","type":"anthropic"}]},
"cache_control":{"type":"ephemeral","ttl":"5m","scope":"user"},
"output_config":{"task_budget":{"type":"tokens","total":20000}}
}`)
		out, err := stripUnsupportedFieldsFromRawBody(body, schemas.Bedrock, opus46)
		if err != nil {
			t.Fatalf("unexpected error: %v", err)
		}
		mustBeGone := []string{
			"speed",
			"inference_geo",
			"mcp_servers",
			"container",
			"cache_control.scope",
			"output_config.task_budget",
		}
		for _, field := range mustBeGone {
			if providerUtils.JSONFieldExists(out, field) {
				t.Errorf("expected %q to be stripped for Bedrock, got: %s", field, string(out))
			}
		}
		// Only the scope key is removed from cache_control; ttl has to remain.
		if ttlKept := providerUtils.JSONFieldExists(out, "cache_control.ttl"); !ttlKept {
			t.Errorf("expected cache_control.ttl to survive, got: %s", string(out))
		}
	})

	t.Run("vertex_strips_mcp_strict_and_input_examples_via_feature_check", func(t *testing.T) {
		// Vertex: no MCP, no InputExamples, no StructuredOutputs, so
		// mcp_servers, tool.strict and tool.input_examples are expected to be
		// stripped. tool.cache_control.scope goes too (no PromptCachingScope).
		body := []byte(`{
"model":"claude-sonnet-4-6",
"mcp_servers":[{"type":"url","url":"u","name":"n"}],
"tools":[{"name":"t1","strict":true,"input_examples":[{"input":{"a":1}}],"cache_control":{"type":"ephemeral","scope":"user"}}]
}`)
		out, err := stripUnsupportedFieldsFromRawBody(body, schemas.Vertex, sonnet46)
		if err != nil {
			t.Fatalf("unexpected error: %v", err)
		}
		mustBeGone := []string{
			"mcp_servers",
			"tools.0.strict",
			"tools.0.input_examples",
			"tools.0.cache_control.scope",
		}
		for _, field := range mustBeGone {
			if providerUtils.JSONFieldExists(out, field) {
				t.Errorf("expected %q to be stripped for Vertex, got: %s", field, string(out))
			}
		}
		if nameKept := providerUtils.JSONFieldExists(out, "tools.0.name"); !nameKept {
			t.Errorf("expected tool name to survive")
		}
	})

	t.Run("bedrock_keeps_input_examples_via_standalone_flag", func(t *testing.T) {
		// Bedrock has InputExamples=true (via tool-examples-2025-10-29) but
		// AdvancedToolUse=false: input_examples is expected to be KEPT, while
		// the bundle-only defer_loading and allowed_callers are STRIPPED.
		body := []byte(`{
"model":"claude-opus-4-6",
"tools":[{"name":"t1","input_examples":[{"input":{"a":1}}],"defer_loading":true,"allowed_callers":["direct"]}]
}`)
		out, err := stripUnsupportedFieldsFromRawBody(body, schemas.Bedrock, opus46)
		if err != nil {
			t.Fatalf("unexpected error: %v", err)
		}
		if examplesKept := providerUtils.JSONFieldExists(out, "tools.0.input_examples"); !examplesKept {
			t.Errorf("expected tools[0].input_examples to survive on Bedrock, got: %s", string(out))
		}
		bundleOnly := []string{"tools.0.defer_loading", "tools.0.allowed_callers"}
		for _, field := range bundleOnly {
			if providerUtils.JSONFieldExists(out, field) {
				t.Errorf("expected %q to be stripped for Bedrock (AdvancedToolUse bundle unsupported), got: %s", field, string(out))
			}
		}
	})

	t.Run("speed_stripped_on_non_opus_46_even_on_anthropic", func(t *testing.T) {
		// Model gate: per docs fast-mode is Opus 4.6 only, so even on Anthropic
		// direct (FastMode=true) a different target model must lose "speed".
		body := []byte(`{"model":"claude-sonnet-4-6","speed":"fast"}`)
		out, err := stripUnsupportedFieldsFromRawBody(body, schemas.Anthropic, sonnet46)
		if err != nil {
			t.Fatalf("unexpected error: %v", err)
		}
		if speedPresent := providerUtils.JSONFieldExists(out, "speed"); speedPresent {
			t.Errorf("expected speed stripped for non-Opus-4.6 model on Anthropic, got: %s", string(out))
		}
	})

	t.Run("anthropic_direct_is_noop", func(t *testing.T) {
		// Anthropic supports everything, so each listed field should survive.
		body := []byte(`{"model":"claude-opus-4-6","speed":"fast","mcp_servers":[{"type":"url","url":"u","name":"n"}],"container":{"id":"c"},"tools":[{"name":"t","defer_loading":true,"input_examples":[{"input":{"a":1}}]}]}`)
		out, err := stripUnsupportedFieldsFromRawBody(body, schemas.Anthropic, opus46)
		if err != nil {
			t.Fatalf("unexpected error: %v", err)
		}
		mustRemain := []string{
			"speed",
			"mcp_servers",
			"container",
			"tools.0.defer_loading",
			"tools.0.input_examples",
		}
		for _, field := range mustRemain {
			if !providerUtils.JSONFieldExists(out, field) {
				t.Errorf("expected %q preserved on Anthropic direct, got: %s", field, string(out))
			}
		}
	})

	t.Run("nested_scope_stripped_on_messages_and_system", func(t *testing.T) {
		// cache_control.scope nested inside system blocks and message content
		// blocks must also disappear when the provider lacks PromptCachingScope.
		body := []byte(`{
"model":"claude-opus-4-6",
"system":[{"type":"text","text":"hi","cache_control":{"type":"ephemeral","scope":"user"}}],
"messages":[{"role":"user","content":[{"type":"text","text":"q","cache_control":{"type":"ephemeral","scope":"global"}}]}]
}`)
		out, err := stripUnsupportedFieldsFromRawBody(body, schemas.Bedrock, opus46)
		if err != nil {
			t.Fatalf("unexpected error: %v", err)
		}
		nestedScopes := []string{
			"system.0.cache_control.scope",
			"messages.0.content.0.cache_control.scope",
		}
		for _, field := range nestedScopes {
			if providerUtils.JSONFieldExists(out, field) {
				t.Errorf("expected nested %q stripped, got: %s", field, string(out))
			}
		}
	})

	t.Run("unknown_provider_is_safe_noop", func(t *testing.T) {
		body := []byte(`{"model":"claude-opus-4-6","speed":"fast"}`)
		out, err := stripUnsupportedFieldsFromRawBody(body, schemas.ModelProvider("custom"), opus46)
		if err != nil {
			t.Fatalf("unexpected error: %v", err)
		}
		if speedKept := providerUtils.JSONFieldExists(out, "speed"); !speedKept {
			t.Errorf("expected speed preserved for unknown provider (safe default), got: %s", string(out))
		}
	})

	t.Run("container_empty_skills_stripped_but_container_preserved", func(t *testing.T) {
		// Bedrock: Skills=false, ContainerBasic=true. An empty skills array is
		// treated as a caller oversight: drop the key, keep the container.
		body := []byte(`{"model":"claude-opus-4-6","container":{"id":"c-1","skills":[]}}`)
		out, err := stripUnsupportedFieldsFromRawBody(body, schemas.Bedrock, opus46)
		if err != nil {
			t.Fatalf("unexpected error: %v", err)
		}
		if skillsPresent := providerUtils.JSONFieldExists(out, "container.skills"); skillsPresent {
			t.Errorf("expected empty container.skills stripped on Skills=false provider, got: %s", string(out))
		}
		if idKept := providerUtils.JSONFieldExists(out, "container.id"); !idKept {
			t.Errorf("expected container.id preserved (bare form still valid), got: %s", string(out))
		}
	})

	t.Run("container_nonempty_skills_drops_whole_container", func(t *testing.T) {
		// A non-empty skills array signals caller intent the provider cannot
		// honor, so the whole container is expected to be dropped.
		body := []byte(`{"model":"claude-opus-4-6","container":{"id":"c-1","skills":[{"skill_id":"s","type":"anthropic"}]}}`)
		out, err := stripUnsupportedFieldsFromRawBody(body, schemas.Bedrock, opus46)
		if err != nil {
			t.Fatalf("unexpected error: %v", err)
		}
		if containerPresent := providerUtils.JSONFieldExists(out, "container"); containerPresent {
			t.Errorf("expected whole container dropped for non-empty skills on Skills=false, got: %s", string(out))
		}
	})

	t.Run("container_empty_skills_on_skills_capable_provider_preserved", func(t *testing.T) {
		// On Anthropic direct (Skills=true) the empty skills array must pass
		// through as-is; stripping is only expected when Skills is unsupported.
		body := []byte(`{"model":"claude-opus-4-6","container":{"id":"c-1","skills":[]}}`)
		out, err := stripUnsupportedFieldsFromRawBody(body, schemas.Anthropic, opus46)
		if err != nil {
			t.Fatalf("unexpected error: %v", err)
		}
		if skillsKept := providerUtils.JSONFieldExists(out, "container.skills"); !skillsKept {
			t.Errorf("expected container.skills preserved on Skills=true provider, got: %s", string(out))
		}
	})
}
// TestStripUnsupportedAnthropicFields_ContainerSkillsGating mirrors the raw-path
// tests above on the typed path — ensures the typed sanitizer treats explicit
// empty skills arrays as a stripable (not drop-triggering) signal.
func TestStripUnsupportedAnthropicFields_ContainerSkillsGating(t *testing.T) {
t.Run("empty_skills_on_skills_false_provider_strips_skills_keeps_container", func(t *testing.T) {
req := &AnthropicMessageRequest{
Model: "claude-opus-4-6",
Container: &AnthropicContainer{
ContainerObject: &AnthropicContainerObject{
ID: schemas.Ptr("c-1"),
Skills: []AnthropicContainerSkill{}, // explicit empty
},
},
}
stripUnsupportedAnthropicFields(req, schemas.Bedrock, "claude-opus-4-6")
if req.Container == nil {
t.Fatalf("expected container preserved (bare form valid with empty skills), got nil")
}
if req.Container.ContainerObject == nil || req.Container.ContainerObject.Skills != nil {
t.Errorf("expected skills cleared on Skills=false, got %v", req.Container.ContainerObject)
}
})
t.Run("nonempty_skills_on_skills_false_provider_drops_container", func(t *testing.T) {
req := &AnthropicMessageRequest{
Model: "claude-opus-4-6",
Container: &AnthropicContainer{
ContainerObject: &AnthropicContainerObject{
ID: schemas.Ptr("c-1"),
Skills: []AnthropicContainerSkill{{SkillID: "s", Type: "anthropic"}},
},
},
}
stripUnsupportedAnthropicFields(req, schemas.Bedrock, "claude-opus-4-6")
if req.Container != nil {
t.Errorf("expected whole container dropped for non-empty skills on Skills=false, got %v", req.Container)
}
})
t.Run("empty_skills_on_skills_true_provider_preserved", func(t *testing.T) {
req := &AnthropicMessageRequest{
Model: "claude-opus-4-6",
Container: &AnthropicContainer{
ContainerObject: &AnthropicContainerObject{
ID: schemas.Ptr("c-1"),
Skills: []AnthropicContainerSkill{},
},
},
}
stripUnsupportedAnthropicFields(req, schemas.Anthropic, "claude-opus-4-6")
if req.Container == nil || req.Container.ContainerObject == nil {
t.Fatalf("expected container preserved on Skills=true provider, got %v", req.Container)
}
if req.Container.ContainerObject.Skills == nil {
t.Errorf("expected empty skills preserved on Skills=true provider (not nilled)")
}
})
}
// TestStripAutoInjectableTools checks StripAutoInjectableTools on raw request
// bodies: code_execution is expected to be removed from "tools" only when a
// web_search or web_fetch tool is also present, and the remaining tools must
// keep their relative order.
func TestStripAutoInjectableTools(t *testing.T) {
	t.Run("code_execution_without_web_search_preserved", func(t *testing.T) {
		// code_execution on its own must NOT be stripped: there is no
		// web_search/web_fetch to trigger auto-injection.
		body := []byte(`{"model":"claude-opus-4-6","tools":[{"type":"custom","name":"my_tool"},{"type":"code_execution_20250825","name":"code_execution"}]}`)
		out, err := StripAutoInjectableTools(body)
		if err != nil {
			t.Fatalf("unexpected error: %v", err)
		}
		remaining := providerUtils.GetJSONField(out, "tools").Array()
		if n := len(remaining); n != 2 {
			t.Fatalf("expected 2 tools (preserved), got %d", n)
		}
	})

	t.Run("code_execution_with_web_search_stripped", func(t *testing.T) {
		// With web_search present, code_execution conflicts with
		// auto-injection and should be removed.
		body := []byte(`{"tools":[{"type":"code_execution_20250825","name":"code_execution"},{"type":"web_search_20260209","name":"web_search"},{"type":"custom","name":"my_tool"}]}`)
		out, err := StripAutoInjectableTools(body)
		if err != nil {
			t.Fatalf("unexpected error: %v", err)
		}
		remaining := providerUtils.GetJSONField(out, "tools").Array()
		if n := len(remaining); n != 2 {
			t.Fatalf("expected 2 tools, got %d", n)
		}
		if got := remaining[0].Get("name").String(); got != "web_search" {
			t.Errorf("expected first tool to be 'web_search', got '%s'", got)
		}
		if got := remaining[1].Get("name").String(); got != "my_tool" {
			t.Errorf("expected second tool to be 'my_tool', got '%s'", got)
		}
	})

	t.Run("code_execution_with_web_fetch_stripped", func(t *testing.T) {
		// Same rule when the companion tool is web_fetch.
		body := []byte(`{"tools":[{"type":"code_execution_20250825","name":"code_execution"},{"type":"web_fetch_20250305","name":"web_fetch"},{"type":"custom","name":"my_tool"}]}`)
		out, err := StripAutoInjectableTools(body)
		if err != nil {
			t.Fatalf("unexpected error: %v", err)
		}
		remaining := providerUtils.GetJSONField(out, "tools").Array()
		if n := len(remaining); n != 2 {
			t.Fatalf("expected 2 tools, got %d", n)
		}
		if got := remaining[0].Get("name").String(); got != "web_fetch" {
			t.Errorf("expected first tool to be 'web_fetch', got '%s'", got)
		}
		if got := remaining[1].Get("name").String(); got != "my_tool" {
			t.Errorf("expected second tool to be 'my_tool', got '%s'", got)
		}
	})

	t.Run("web_search_alone_preserved", func(t *testing.T) {
		// web_search without code_execution: nothing to strip.
		body := []byte(`{"tools":[{"type":"web_search_20250305","name":"web_search"},{"type":"custom","name":"search"}]}`)
		out, err := StripAutoInjectableTools(body)
		if err != nil {
			t.Fatalf("unexpected error: %v", err)
		}
		remaining := providerUtils.GetJSONField(out, "tools").Array()
		if n := len(remaining); n != 2 {
			t.Fatalf("expected 2 tools (preserved), got %d", n)
		}
	})

	t.Run("web_fetch_alone_preserved", func(t *testing.T) {
		// web_fetch without code_execution: nothing to strip.
		body := []byte(`{"tools":[{"type":"web_fetch_20250305","name":"web_fetch"},{"type":"custom","name":"fetch"}]}`)
		out, err := StripAutoInjectableTools(body)
		if err != nil {
			t.Fatalf("unexpected error: %v", err)
		}
		remaining := providerUtils.GetJSONField(out, "tools").Array()
		if n := len(remaining); n != 2 {
			t.Fatalf("expected 2 tools (preserved), got %d", n)
		}
	})

	t.Run("preserves_custom_tools_only", func(t *testing.T) {
		body := []byte(`{"tools":[{"type":"custom","name":"tool_a"},{"type":"custom","name":"tool_b"}]}`)
		out, err := StripAutoInjectableTools(body)
		if err != nil {
			t.Fatalf("unexpected error: %v", err)
		}
		remaining := providerUtils.GetJSONField(out, "tools").Array()
		if n := len(remaining); n != 2 {
			t.Fatalf("expected 2 tools, got %d", n)
		}
	})

	t.Run("no_tools_key", func(t *testing.T) {
		// Without a "tools" key the body must come back byte-for-byte.
		body := []byte(`{"model":"claude-opus-4-6","messages":[]}`)
		out, err := StripAutoInjectableTools(body)
		if err != nil {
			t.Fatalf("unexpected error: %v", err)
		}
		if got, want := string(out), string(body); got != want {
			t.Errorf("expected body unchanged, got %s", got)
		}
	})

	t.Run("empty_tools_array", func(t *testing.T) {
		// An empty tools array must come back byte-for-byte as well.
		body := []byte(`{"tools":[]}`)
		out, err := StripAutoInjectableTools(body)
		if err != nil {
			t.Fatalf("unexpected error: %v", err)
		}
		if got, want := string(out), string(body); got != want {
			t.Errorf("expected body unchanged, got %s", got)
		}
	})

	t.Run("code_execution_and_web_search_only_strips_code_execution", func(t *testing.T) {
		// Only code_execution + web_search: drop the former, keep the latter.
		body := []byte(`{"model":"test","tools":[{"type":"code_execution_20250825","name":"code_execution"},{"type":"web_search_20250305","name":"web_search"}]}`)
		out, err := StripAutoInjectableTools(body)
		if err != nil {
			t.Fatalf("unexpected error: %v", err)
		}
		remaining := providerUtils.GetJSONField(out, "tools").Array()
		if n := len(remaining); n != 1 {
			t.Fatalf("expected 1 tool, got %d", n)
		}
		if got := remaining[0].Get("name").String(); got != "web_search" {
			t.Errorf("expected remaining tool to be 'web_search', got '%s'", got)
		}
	})

	t.Run("strips_code_execution_keeps_web_search_and_custom", func(t *testing.T) {
		body := []byte(`{"tools":[{"type":"code_execution_20250825","name":"code_execution"},{"type":"custom","name":"my_tool"},{"type":"web_search_20260209","name":"web_search"},{"type":"custom","name":"other_tool"}]}`)
		out, err := StripAutoInjectableTools(body)
		if err != nil {
			t.Fatalf("unexpected error: %v", err)
		}
		remaining := providerUtils.GetJSONField(out, "tools").Array()
		if n := len(remaining); n != 3 {
			t.Fatalf("expected 3 tools, got %d", n)
		}
		if got := remaining[0].Get("name").String(); got != "my_tool" {
			t.Errorf("expected first tool to be 'my_tool', got '%s'", got)
		}
		if got := remaining[1].Get("name").String(); got != "web_search" {
			t.Errorf("expected second tool to be 'web_search', got '%s'", got)
		}
		if got := remaining[2].Get("name").String(); got != "other_tool" {
			t.Errorf("expected third tool to be 'other_tool', got '%s'", got)
		}
	})
}
// TestAnthropicToolUnmarshalJSON_MCPToolset verifies how AnthropicTool decodes
// JSON entries of type "mcp_toolset": the MCPToolset field must be populated
// for such entries, left nil for other tools, and survive a marshal/unmarshal
// round trip.
func TestAnthropicToolUnmarshalJSON_MCPToolset(t *testing.T) {
	t.Run("mcp_toolset is properly unmarshaled", func(t *testing.T) {
		data := []byte(`{
"type": "mcp_toolset",
"mcp_server_name": "example-mcp",
"default_config": {"enabled": false},
"configs": {
"search_events": {"enabled": true},
"create_event": {"enabled": true, "defer_loading": true}
}
}`)
		var tool AnthropicTool
		if err := sonic.Unmarshal(data, &tool); err != nil {
			t.Fatalf("unexpected unmarshal error: %v", err)
		}
		if tool.MCPToolset == nil {
			t.Fatal("expected MCPToolset to be populated, got nil")
		}
		toolset := tool.MCPToolset
		if toolset.Type != "mcp_toolset" {
			t.Errorf("expected type 'mcp_toolset', got %q", toolset.Type)
		}
		if toolset.MCPServerName != "example-mcp" {
			t.Errorf("expected mcp_server_name 'example-mcp', got %q", toolset.MCPServerName)
		}
		if def := toolset.DefaultConfig; def == nil || def.Enabled == nil || *def.Enabled {
			t.Error("expected default_config.enabled to be false")
		}
		if len(toolset.Configs) != 2 {
			t.Fatalf("expected 2 configs, got %d", len(toolset.Configs))
		}
		// Guard every pointer before dereferencing so that a decoding
		// regression is reported as a test failure rather than a nil-pointer
		// panic that aborts the whole test binary.
		if search := toolset.Configs["search_events"]; search == nil || search.Enabled == nil || !*search.Enabled {
			t.Error("expected search_events to be enabled")
		}
		if create := toolset.Configs["create_event"]; create == nil || create.DeferLoading == nil || !*create.DeferLoading {
			t.Error("expected create_event defer_loading to be true")
		}
	})

	t.Run("regular tool is not affected by mcp_toolset unmarshal", func(t *testing.T) {
		data := []byte(`{
"name": "get_weather",
"description": "Get weather info",
"input_schema": {"type": "object", "properties": {}}
}`)
		var tool AnthropicTool
		if err := sonic.Unmarshal(data, &tool); err != nil {
			t.Fatalf("unexpected unmarshal error: %v", err)
		}
		if tool.MCPToolset != nil {
			t.Error("expected MCPToolset to be nil for regular tool")
		}
		if tool.Name != "get_weather" {
			t.Errorf("expected name 'get_weather', got %q", tool.Name)
		}
	})

	t.Run("mcp_toolset round-trips through marshal/unmarshal", func(t *testing.T) {
		original := AnthropicTool{
			MCPToolset: &AnthropicMCPToolsetTool{
				Type:          "mcp_toolset",
				MCPServerName: "test-server",
				DefaultConfig: &AnthropicMCPToolsetConfig{Enabled: new(false)},
				Configs: map[string]*AnthropicMCPToolsetConfig{
					"tool_a": {Enabled: new(true)},
				},
			},
		}
		// Encode with encoding/json and decode with sonic, as the original
		// test does, so both codecs are exercised against the same payload.
		marshaled, err := json.Marshal(original)
		if err != nil {
			t.Fatalf("unexpected marshal error: %v", err)
		}
		var restored AnthropicTool
		if err := sonic.Unmarshal(marshaled, &restored); err != nil {
			t.Fatalf("unexpected unmarshal error: %v", err)
		}
		if restored.MCPToolset == nil {
			t.Fatal("expected MCPToolset to be populated after round-trip")
		}
		if restored.MCPToolset.MCPServerName != "test-server" {
			t.Errorf("expected mcp_server_name 'test-server', got %q", restored.MCPToolset.MCPServerName)
		}
		if len(restored.MCPToolset.Configs) != 1 {
			t.Fatalf("expected 1 config, got %d", len(restored.MCPToolset.Configs))
		}
	})

	t.Run("tools array with mixed regular and mcp_toolset tools", func(t *testing.T) {
		data := []byte(`[
{"name": "get_weather", "description": "Get weather"},
{"type": "mcp_toolset", "mcp_server_name": "my-mcp"},
{"type": "computer_20251124", "name": "computer"}
]`)
		var tools []AnthropicTool
		if err := sonic.Unmarshal(data, &tools); err != nil {
			t.Fatalf("unexpected unmarshal error: %v", err)
		}
		if len(tools) != 3 {
			t.Fatalf("expected 3 tools, got %d", len(tools))
		}
		// First: regular tool.
		if tools[0].Name != "get_weather" {
			t.Errorf("expected first tool name 'get_weather', got %q", tools[0].Name)
		}
		if tools[0].MCPToolset != nil {
			t.Error("expected first tool MCPToolset to be nil")
		}
		// Second: mcp_toolset.
		if tools[1].MCPToolset == nil {
			t.Fatal("expected second tool MCPToolset to be populated")
		}
		if tools[1].MCPToolset.MCPServerName != "my-mcp" {
			t.Errorf("expected mcp_server_name 'my-mcp', got %q", tools[1].MCPToolset.MCPServerName)
		}
		// Third: typed tool (computer).
		if tools[2].MCPToolset != nil {
			t.Error("expected third tool MCPToolset to be nil")
		}
	})
}
// TestGetRequestBodyForResponses_RawBodyStripsFallbacks checks that, with the
// raw-request-body context flag set, getRequestBodyForResponses returns a body
// without the "fallbacks" key while model, max_tokens and temperature remain.
func TestGetRequestBodyForResponses_RawBodyStripsFallbacks(t *testing.T) {
	bfCtx := schemas.NewBifrostContext(nil, time.Time{})
	bfCtx.SetValue(schemas.BifrostContextKeyUseRawRequestBody, true)

	req := &schemas.BifrostResponsesRequest{
		Provider:       schemas.Anthropic,
		Model:          "claude-sonnet-4-5",
		RawRequestBody: []byte(`{"model":"claude-sonnet-4-5","max_tokens":1024,"messages":[{"role":"user","content":"hello"}],"fallbacks":["claude-haiku-4-5"],"temperature":0.7}`),
	}

	body, bifrostErr := getRequestBodyForResponses(bfCtx, req, false, nil)
	if bifrostErr != nil {
		t.Fatalf("unexpected error: %v", bifrostErr)
	}

	if fallbacks := providerUtils.GetJSONField(body, "fallbacks"); fallbacks.Exists() {
		t.Error("expected 'fallbacks' to be absent from raw-body output")
	}

	// Every other field has to survive the round-trip.
	survivors := []struct {
		field   string
		failure string
	}{
		{field: "model", failure: "expected 'model' to be present"},
		{field: "max_tokens", failure: "expected 'max_tokens' to be present"},
		{field: "temperature", failure: "expected 'temperature' to be present"},
	}
	for _, s := range survivors {
		if !providerUtils.GetJSONField(body, s.field).Exists() {
			t.Error(s.failure)
		}
	}
}
func TestApplyMCPToolsetConfigToBifrostTool(t *testing.T) {
t.Run("allowlist pattern merges correctly", func(t *testing.T) {
bifrostTool := &schemas.ResponsesTool{
Type: schemas.ResponsesToolTypeMCP,
ResponsesToolMCP: &schemas.ResponsesToolMCP{
ServerLabel: "test-server",
ServerURL: schemas.Ptr("https://example.com/mcp"),
},
}
toolset := &AnthropicMCPToolsetTool{
Type: "mcp_toolset",
MCPServerName: "test-server",
DefaultConfig: &AnthropicMCPToolsetConfig{Enabled: schemas.Ptr(false)},
Configs: map[string]*AnthropicMCPToolsetConfig{
"search": {Enabled: new(true)},
"create": {Enabled: schemas.Ptr(true)},
"delete": {Enabled: schemas.Ptr(false)},
},
}
applyMCPToolsetConfigToBifrostTool(bifrostTool, toolset)
if bifrostTool.ResponsesToolMCP.AllowedTools == nil {
t.Fatal("expected AllowedTools to be set")
}
allowedNames := bifrostTool.ResponsesToolMCP.AllowedTools.ToolNames
if len(allowedNames) != 2 {
t.Fatalf("expected 2 allowed tools, got %d: %v", len(allowedNames), allowedNames)
}
// Check that both "search" and "create" are present (order may vary due to map iteration)
found := map[string]bool{}
for _, name := range allowedNames {
found[name] = true
}
if !found["search"] || !found["create"] {
t.Errorf("expected allowed tools to contain 'search' and 'create', got %v", allowedNames)
}
})
t.Run("all enabled by default does not set allowlist", func(t *testing.T) {
bifrostTool := &schemas.ResponsesTool{
Type: schemas.ResponsesToolTypeMCP,
ResponsesToolMCP: &schemas.ResponsesToolMCP{
ServerLabel: "test-server",
},
}
toolset := &AnthropicMCPToolsetTool{
Type: "mcp_toolset",
MCPServerName: "test-server",
// No default_config (defaults to enabled=true)
}
applyMCPToolsetConfigToBifrostTool(bifrostTool, toolset)
if bifrostTool.ResponsesToolMCP.AllowedTools != nil {
t.Error("expected AllowedTools to be nil when all tools are enabled by default")
}
})
t.Run("nil inputs are handled safely", func(t *testing.T) {
// Should not panic
applyMCPToolsetConfigToBifrostTool(nil, nil)
applyMCPToolsetConfigToBifrostTool(&schemas.ResponsesTool{}, nil)
})
}
// TestSupportsAdaptiveThinking is a table-driven check of
// SupportsAdaptiveThinking across model identifiers written with both "-" and
// "." version separators, plus the empty string.
func TestSupportsAdaptiveThinking(t *testing.T) {
	type adaptiveCase struct {
		model string
		want  bool
	}
	cases := []adaptiveCase{
		{model: "claude-opus-4-7-20260401", want: true},
		{model: "claude-opus-4.7-20260401", want: true},
		{model: "claude-opus-4-6-20250514", want: true},
		{model: "claude-opus-4.6-20250514", want: true},
		{model: "claude-sonnet-4-6-20250514", want: true},
		{model: "claude-sonnet-4.6-20250514", want: true},
		{model: "claude-opus-4-5-20241022", want: false},
		{model: "claude-sonnet-4-5-20241022", want: false},
		{model: "claude-haiku-4-6-20250514", want: false}, // haiku does not support adaptive
		{model: "claude-haiku-4-7-20260401", want: false}, // haiku, not opus
		{model: "", want: false},
	}
	for _, tc := range cases {
		t.Run(tc.model, func(t *testing.T) {
			if got := SupportsAdaptiveThinking(tc.model); got != tc.want {
				t.Errorf("SupportsAdaptiveThinking(%q) = %v, want %v", tc.model, got, tc.want)
			}
		})
	}
}
// TestAddMissingBetaHeadersToContext_TaskBudgets checks whether
// AddMissingBetaHeadersToContext places AnthropicTaskBudgetsBetaHeader under
// the AnthropicBetaHeader key of the context's extra headers, depending on the
// provider and on whether the request sets OutputConfig.TaskBudget.
func TestAddMissingBetaHeadersToContext_TaskBudgets(t *testing.T) {
	tests := []struct {
		name            string
		provider        schemas.ModelProvider
		req             *AnthropicMessageRequest
		expectHeaders   []string // headers that must be present afterwards
		unexpectHeaders []string // headers that must be absent afterwards
	}{
		{
			name:     "Anthropic gets task-budgets header when task_budget set",
			provider: schemas.Anthropic,
			req: &AnthropicMessageRequest{
				OutputConfig: &AnthropicOutputConfig{
					TaskBudget: &AnthropicTaskBudget{Type: "tokens", Total: 50000},
				},
			},
			expectHeaders: []string{AnthropicTaskBudgetsBetaHeader},
		},
		{
			name:     "Vertex does not get task-budgets header when task_budget set",
			provider: schemas.Vertex,
			req: &AnthropicMessageRequest{
				OutputConfig: &AnthropicOutputConfig{
					TaskBudget: &AnthropicTaskBudget{Type: "tokens", Total: 50000},
				},
			},
			unexpectHeaders: []string{AnthropicTaskBudgetsBetaHeader},
		},
		{
			name:     "no task-budgets header when task_budget is nil",
			provider: schemas.Anthropic,
			req: &AnthropicMessageRequest{
				OutputConfig: &AnthropicOutputConfig{},
			},
			unexpectHeaders: []string{AnthropicTaskBudgetsBetaHeader},
		},
		{
			name:            "no task-budgets header when output_config is nil",
			provider:        schemas.Anthropic,
			req:             &AnthropicMessageRequest{},
			unexpectHeaders: []string{AnthropicTaskBudgetsBetaHeader},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			ctx := schemas.NewBifrostContext(nil, time.Time{})
			AddMissingBetaHeadersToContext(ctx, tt.req, tt.provider)

			// headers stays nil when no extra-headers map was stored on the
			// context; the membership checks below then see an empty list.
			var headers []string
			if extraHeaders, ok := ctx.Value(schemas.BifrostContextKeyExtraHeaders).(map[string][]string); ok {
				headers = extraHeaders[AnthropicBetaHeader]
			}

			for _, expected := range tt.expectHeaders {
				if !slices.Contains(headers, expected) {
					t.Errorf("expected header %q not found in %v", expected, headers)
				}
			}
			for _, unexpected := range tt.unexpectHeaders {
				if slices.Contains(headers, unexpected) {
					t.Errorf("unexpected header %q found in %v", unexpected, headers)
				}
			}
		})
	}
}