package anthropic import ( "context" "encoding/json" "fmt" "reflect" "slices" "sort" "testing" "time" "github.com/bytedance/sonic" providerUtils "github.com/maximhq/bifrost/core/providers/utils" "github.com/maximhq/bifrost/core/schemas" ) func TestExtractTypesFromValue(t *testing.T) { tests := []struct { name string input interface{} expected []string }{ { name: "string type", input: "string", expected: []string{"string"}, }, { name: "[]string array", input: []string{"string", "null"}, expected: []string{"string", "null"}, }, { name: "[]interface{} array", input: []interface{}{"string", "integer", "null"}, expected: []string{"string", "integer", "null"}, }, { name: "[]interface{} with non-string items (filtered out)", input: []interface{}{"string", 123, "null"}, expected: []string{"string", "null"}, }, { name: "unsupported type returns nil", input: 123, expected: nil, }, { name: "nil returns nil", input: nil, expected: nil, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { result := extractTypesFromValue(tt.input) if !reflect.DeepEqual(result, tt.expected) { t.Errorf("extractTypesFromValue() mismatch:\ngot: %+v\nwant: %+v", result, tt.expected) } }) } } func TestNormalizeSchemaForAnthropic(t *testing.T) { tests := []struct { name string input map[string]interface{} expected map[string]interface{} }{ { name: "type array with string and null - converts to anyOf", input: map[string]interface{}{ "type": []interface{}{"string", "null"}, "description": "A nullable string field", "enum": []string{"value1", "value2", ""}, }, expected: map[string]interface{}{ "description": "A nullable string field", "anyOf": []interface{}{ map[string]interface{}{ "type": "string", "enum": []string{"value1", "value2", ""}, }, map[string]interface{}{"type": "null"}, }, }, }, { name: "type array with null and string - converts to anyOf", input: map[string]interface{}{ "type": []interface{}{"null", "string"}, "description": "A nullable string field", "enum": []string{"NODE-0", "NODE-1", ""}, }, expected: map[string]interface{}{ "description": "A nullable string field", "anyOf": []interface{}{ map[string]interface{}{ "type": "string", "enum": []string{"NODE-0", "NODE-1", ""}, }, map[string]interface{}{"type": "null"}, }, }, }, { name: "type array as []string format with null - converts to anyOf", input: map[string]interface{}{ "type": []string{"string", "null"}, "enum": []string{"option1", "option2"}, }, expected: map[string]interface{}{ "anyOf": []interface{}{ map[string]interface{}{ "type": "string", "enum": []string{"option1", "option2"}, }, map[string]interface{}{"type": "null"}, }, }, }, { name: "type array with single type (no null) - keeps as simple type", input: map[string]interface{}{ "type": []string{"string"}, "enum": []string{"option1", "option2"}, }, expected: map[string]interface{}{ "type": "string", "enum": []string{"option1", "option2"}, }, }, { name: "regular string type - no change", input: map[string]interface{}{ "type": "string", "description": "A regular string field", }, expected: map[string]interface{}{ "type": "string", "description": "A regular string field", }, }, { name: "nested properties with nullable type arrays - converts to anyOf", input: map[string]interface{}{ "type": "object", "properties": map[string]interface{}{ "field1": map[string]interface{}{ "type": []interface{}{"string", "null"}, "enum": []string{"a", "b"}, }, "field2": map[string]interface{}{ "type": "number", }, }, }, expected: map[string]interface{}{ "type": "object", "properties": map[string]interface{}{ "field1": map[string]interface{}{ "anyOf": []interface{}{ map[string]interface{}{ "type": "string", "enum": []string{"a", "b"}, }, map[string]interface{}{"type": "null"}, }, }, "field2": map[string]interface{}{ "type": "number", }, }, }, }, { name: "array items with nullable type array - converts to anyOf", input: map[string]interface{}{ "type": "array", "items": map[string]interface{}{ "type": []interface{}{"string", "null"}, "enum": []string{"x", "y", "z"}, }, }, expected: map[string]interface{}{ "type": "array", "items": map[string]interface{}{ "anyOf": []interface{}{ map[string]interface{}{ "type": "string", "enum": []string{"x", "y", "z"}, }, map[string]interface{}{"type": "null"}, }, }, }, }, { name: "anyOf with type arrays - nested anyOf gets flattened conceptually", input: map[string]interface{}{ "anyOf": []interface{}{ map[string]interface{}{ "type": []interface{}{"string", "null"}, }, map[string]interface{}{ "type": "number", }, }, }, expected: map[string]interface{}{ "anyOf": []interface{}{ map[string]interface{}{ "anyOf": []interface{}{ map[string]interface{}{"type": "string"}, map[string]interface{}{"type": "null"}, }, }, map[string]interface{}{ "type": "number", }, }, }, }, { name: "oneOf with nullable type arrays", input: map[string]interface{}{ "oneOf": []interface{}{ map[string]interface{}{ "type": []interface{}{"string", "null"}, }, }, }, expected: map[string]interface{}{ "oneOf": []interface{}{ map[string]interface{}{ "anyOf": []interface{}{ map[string]interface{}{"type": "string"}, map[string]interface{}{"type": "null"}, }, }, }, }, }, { name: "allOf with nullable type arrays", input: map[string]interface{}{ "allOf": []interface{}{ map[string]interface{}{ "type": []interface{}{"string", "null"}, }, }, }, expected: map[string]interface{}{ "allOf": []interface{}{ map[string]interface{}{ "anyOf": []interface{}{ map[string]interface{}{"type": "string"}, map[string]interface{}{"type": "null"}, }, }, }, }, }, { name: "definitions with nullable type arrays", input: map[string]interface{}{ "definitions": map[string]interface{}{ "myDef": map[string]interface{}{ "type": []interface{}{"string", "null"}, }, }, }, expected: map[string]interface{}{ "definitions": map[string]interface{}{ "myDef": map[string]interface{}{ "anyOf": []interface{}{ map[string]interface{}{"type": "string"}, map[string]interface{}{"type": "null"}, }, }, }, }, }, { name: "$defs with nullable type arrays", input: map[string]interface{}{ "$defs": map[string]interface{}{ "myDef": map[string]interface{}{ "type": []interface{}{"string", "null"}, }, }, }, expected: map[string]interface{}{ "$defs": map[string]interface{}{ "myDef": map[string]interface{}{ "anyOf": []interface{}{ map[string]interface{}{"type": "string"}, map[string]interface{}{"type": "null"}, }, }, }, }, }, { name: "complex nested schema - real world example with nullable enum", input: map[string]interface{}{ "type": "object", "properties": map[string]interface{}{ "action": map[string]interface{}{ "type": "string", "enum": []string{"continue", "transition"}, }, "target_node_id": map[string]interface{}{ "type": []interface{}{"string", "null"}, "description": "The ID of the node to transition to. Required when action is 'transition', null when action is 'continue'", "enum": []string{"NODE-0", "NODE-1", "NODE-2", ""}, }, }, "required": []string{"action"}, }, expected: map[string]interface{}{ "type": "object", "properties": map[string]interface{}{ "action": map[string]interface{}{ "type": "string", "enum": []string{"continue", "transition"}, }, "target_node_id": map[string]interface{}{ "description": "The ID of the node to transition to. Required when action is 'transition', null when action is 'continue'", "anyOf": []interface{}{ map[string]interface{}{ "type": "string", "enum": []string{"NODE-0", "NODE-1", "NODE-2", ""}, }, map[string]interface{}{"type": "null"}, }, }, }, "required": []string{"action"}, }, }, { name: "nil schema - returns nil", input: nil, expected: nil, }, { name: "empty schema - returns empty", input: map[string]interface{}{}, expected: map[string]interface{}{}, }, { name: "type array with multiple non-null types - converts to anyOf", input: map[string]interface{}{ "type": []interface{}{"string", "integer"}, "description": "A field that can be string or integer", }, expected: map[string]interface{}{ "description": "A field that can be string or integer", "anyOf": []interface{}{ map[string]interface{}{"type": "string"}, map[string]interface{}{"type": "integer"}, }, }, }, { name: "type array with multiple types including null - converts to anyOf with null", input: map[string]interface{}{ "type": []interface{}{"string", "integer", "null"}, "description": "A nullable field that can be string or integer", }, expected: map[string]interface{}{ "description": "A nullable field that can be string or integer", "anyOf": []interface{}{ map[string]interface{}{"type": "string"}, map[string]interface{}{"type": "integer"}, map[string]interface{}{"type": "null"}, }, }, }, { name: "type array with multiple types and enum - filters enum values by type in anyOf branches", input: map[string]interface{}{ "type": []interface{}{"string", "integer"}, "enum": []interface{}{"value1", 123}, }, expected: map[string]interface{}{ "anyOf": []interface{}{ map[string]interface{}{ "type": "string", "enum": []interface{}{"value1"}, }, map[string]interface{}{ "type": "integer", "enum": []interface{}{123}, }, }, }, }, { name: "nested properties with multi-type arrays - all convert to anyOf", input: map[string]interface{}{ "type": "object", "properties": map[string]interface{}{ "field1": map[string]interface{}{ "type": []interface{}{"string", "number"}, }, "field2": map[string]interface{}{ "type": []interface{}{"boolean", "null"}, }, }, }, expected: map[string]interface{}{ "type": "object", "properties": map[string]interface{}{ "field1": map[string]interface{}{ "anyOf": []interface{}{ map[string]interface{}{"type": "string"}, map[string]interface{}{"type": "number"}, }, }, "field2": map[string]interface{}{ "anyOf": []interface{}{ map[string]interface{}{"type": "boolean"}, map[string]interface{}{"type": "null"}, }, }, }, }, }, { name: "real world priority field with mixed string and integer enum - filters correctly", input: map[string]interface{}{ "type": []interface{}{"string", "integer"}, "description": "Priority level - can be a number (1-10) or a string label (low/medium/high)", "enum": []interface{}{"low", "medium", "high", 1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, }, expected: map[string]interface{}{ "description": "Priority level - can be a number (1-10) or a string label (low/medium/high)", "anyOf": []interface{}{ map[string]interface{}{ "type": "string", "enum": []interface{}{"low", "medium", "high"}, }, map[string]interface{}{ "type": "integer", "enum": []interface{}{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, }, }, }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { result := normalizeSchemaForAnthropic(tt.input) // Compare using JSON marshaling to handle []string vs []interface{} differences // Marshal both to JSON, then unmarshal back to normalized form for comparison // This ensures we compare actual structure, not field ordering gotJSON, err1 := sonic.Marshal(result) wantJSON, err2 := sonic.Marshal(tt.expected) if err1 != nil || err2 != nil { t.Fatalf("Failed to marshal for comparison: got err=%v, want err=%v", err1, err2) } // Unmarshal both back to interface{} to normalize the comparison // This handles both field ordering and []string vs []interface{} differences var gotNormalized, wantNormalized interface{} if err := sonic.Unmarshal(gotJSON, &gotNormalized); err != nil { t.Fatalf("Failed to unmarshal got JSON: %v", err) } if err := sonic.Unmarshal(wantJSON, &wantNormalized); err != nil { t.Fatalf("Failed to unmarshal want JSON: %v", err) } // Now compare the unmarshaled structures if !reflect.DeepEqual(gotNormalized, wantNormalized) { // Pretty print for error message gotJSONPretty, _ := sonic.MarshalIndent(result, "", " ") wantJSONPretty, _ := sonic.MarshalIndent(tt.expected, "", " ") t.Errorf("normalizeSchemaForAnthropic() mismatch:\ngot: %s\nwant: %s", gotJSONPretty, wantJSONPretty) } }) } } func TestConvertChatResponseFormatToAnthropicOutputFormat(t *testing.T) { tests := []struct { name string input *interface{} expected interface{} }{ { name: "chat format with nullable enum gets normalized to anyOf", input: func() *interface{} { val := interface{}(map[string]interface{}{ "type": "json_schema", "json_schema": map[string]interface{}{ "name": "TestSchema", "schema": map[string]interface{}{ "type": "object", "properties": map[string]interface{}{ "field": map[string]interface{}{ "type": []interface{}{"string", "null"}, "enum": []string{"value1", "value2"}, }, }, }, }, }) return &val }(), expected: map[string]interface{}{ "type": "json_schema", "schema": map[string]interface{}{ "type": "object", "properties": map[string]interface{}{ "field": map[string]interface{}{ "anyOf": []interface{}{ map[string]interface{}{ "type": "string", "enum": []string{"value1", "value2"}, }, map[string]interface{}{"type": "null"}, }, }, }, }, }, }, { name: "nil input returns nil", input: nil, expected: nil, }, { name: "non-json_schema type returns nil", input: func() *interface{} { val := interface{}(map[string]interface{}{ "type": "json", }) return &val }(), expected: nil, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { result := convertChatResponseFormatToAnthropicOutputFormat(tt.input) // Compare using JSON marshaling to handle field ordering differences resultJSON, err1 := sonic.Marshal(result) expectedJSON, err2 := sonic.Marshal(tt.expected) if err1 != nil || err2 != nil { t.Fatalf("Failed to marshal for comparison: result err=%v, expected err=%v", err1, err2) } // Unmarshal both back to interface{} to normalize the comparison var resultNormalized, expectedNormalized interface{} if err := sonic.Unmarshal(resultJSON, &resultNormalized); err != nil { t.Fatalf("Failed to unmarshal result JSON: %v", err) } if err := sonic.Unmarshal(expectedJSON, &expectedNormalized); err != nil { t.Fatalf("Failed to unmarshal expected JSON: %v", err) } if !reflect.DeepEqual(resultNormalized, expectedNormalized) { t.Errorf("convertChatResponseFormatToAnthropicOutputFormat() mismatch:\ngot: %+v\nwant: %+v", result, tt.expected) } }) } } func TestValidateToolsForProvider(t *testing.T) { tests := []struct { name string tools []schemas.ResponsesTool provider schemas.ModelProvider expectErr bool }{ { name: "Anthropic allows web_search", tools: []schemas.ResponsesTool{{Type: schemas.ResponsesToolTypeWebSearch}}, provider: schemas.Anthropic, expectErr: false, }, { name: "Anthropic allows web_fetch", tools: []schemas.ResponsesTool{{Type: schemas.ResponsesToolTypeWebFetch}}, provider: schemas.Anthropic, expectErr: false, }, { name: "Vertex allows web_search", tools: []schemas.ResponsesTool{{Type: schemas.ResponsesToolTypeWebSearch}}, provider: schemas.Vertex, expectErr: false, }, { name: "Vertex rejects web_fetch", tools: []schemas.ResponsesTool{{Type: schemas.ResponsesToolTypeWebFetch}}, provider: schemas.Vertex, expectErr: true, }, { name: "Vertex rejects code_interpreter", tools: []schemas.ResponsesTool{{Type: schemas.ResponsesToolTypeCodeInterpreter}}, provider: schemas.Vertex, expectErr: true, }, { name: "Vertex rejects MCP", tools: []schemas.ResponsesTool{{Type: schemas.ResponsesToolTypeMCP}}, provider: schemas.Vertex, expectErr: true, }, { name: "Bedrock rejects web_search", tools: []schemas.ResponsesTool{{Type: schemas.ResponsesToolTypeWebSearch}}, provider: schemas.Bedrock, expectErr: true, }, { name: "Bedrock rejects web_fetch", tools: []schemas.ResponsesTool{{Type: schemas.ResponsesToolTypeWebFetch}}, provider: schemas.Bedrock, expectErr: true, }, { name: "Bedrock allows computer_use", tools: []schemas.ResponsesTool{{Type: schemas.ResponsesToolTypeComputerUsePreview}}, provider: schemas.Bedrock, expectErr: false, }, { name: "Azure allows everything", tools: []schemas.ResponsesTool{{Type: schemas.ResponsesToolTypeWebFetch}, {Type: schemas.ResponsesToolTypeCodeInterpreter}, {Type: schemas.ResponsesToolTypeMCP}}, provider: schemas.Azure, expectErr: false, }, { name: "Unknown provider allows all", tools: []schemas.ResponsesTool{{Type: schemas.ResponsesToolTypeWebFetch}}, provider: "custom_provider", expectErr: false, }, { name: "Function tools always allowed", tools: []schemas.ResponsesTool{{Type: schemas.ResponsesToolTypeFunction}}, provider: schemas.Bedrock, expectErr: false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { err := ValidateToolsForProvider(tt.tools, tt.provider) if tt.expectErr && err == nil { t.Errorf("expected error but got nil") } if !tt.expectErr && err != nil { t.Errorf("unexpected error: %v", err) } }) } } func TestAddMissingBetaHeadersToContext_PerProvider(t *testing.T) { tests := []struct { name string provider schemas.ModelProvider req *AnthropicMessageRequest expectHeaders []string unexpectHeaders []string }{ { name: "Anthropic gets structured outputs header", provider: schemas.Anthropic, req: &AnthropicMessageRequest{ OutputFormat: json.RawMessage(`{"type":"json_schema"}`), }, expectHeaders: []string{AnthropicStructuredOutputsBetaHeader}, }, { name: "Vertex skips structured outputs header", provider: schemas.Vertex, req: &AnthropicMessageRequest{ OutputFormat: json.RawMessage(`{"type":"json_schema"}`), }, unexpectHeaders: []string{AnthropicStructuredOutputsBetaHeader}, }, { name: "Vertex skips MCP header", provider: schemas.Vertex, req: &AnthropicMessageRequest{ MCPServers: []AnthropicMCPServerV2{{URL: "http://example.com"}}, }, unexpectHeaders: []string{AnthropicMCPClientBetaHeader}, }, { name: "Anthropic gets MCP header", provider: schemas.Anthropic, req: &AnthropicMessageRequest{ MCPServers: []AnthropicMCPServerV2{{URL: "http://example.com"}}, }, expectHeaders: []string{AnthropicMCPClientBetaHeader}, }, { name: "Vertex gets compaction header", provider: schemas.Vertex, req: &AnthropicMessageRequest{ ContextManagement: &ContextManagement{ Edits: []ContextManagementEdit{{Type: ContextManagementEditTypeCompact}}, }, }, expectHeaders: []string{AnthropicCompactionBetaHeader}, }, { name: "Bedrock gets compaction header", provider: schemas.Bedrock, req: &AnthropicMessageRequest{ ContextManagement: &ContextManagement{ Edits: []ContextManagementEdit{{Type: ContextManagementEditTypeCompact}}, }, }, expectHeaders: []string{AnthropicCompactionBetaHeader}, }, // Interleaved thinking tests { name: "Anthropic gets interleaved thinking header for enabled", provider: schemas.Anthropic, req: &AnthropicMessageRequest{ Thinking: &AnthropicThinking{Type: "enabled", BudgetTokens: schemas.Ptr(2048)}, }, expectHeaders: []string{AnthropicInterleavedThinkingBetaHeader}, }, { name: "Anthropic does not get interleaved thinking header for adaptive", provider: schemas.Anthropic, req: &AnthropicMessageRequest{ Thinking: &AnthropicThinking{Type: "adaptive"}, }, unexpectHeaders: []string{AnthropicInterleavedThinkingBetaHeader}, }, { name: "Vertex gets interleaved thinking header", provider: schemas.Vertex, req: &AnthropicMessageRequest{ Thinking: &AnthropicThinking{Type: "enabled", BudgetTokens: schemas.Ptr(2048)}, }, expectHeaders: []string{AnthropicInterleavedThinkingBetaHeader}, }, { name: "Bedrock gets interleaved thinking header", provider: schemas.Bedrock, req: &AnthropicMessageRequest{ Thinking: &AnthropicThinking{Type: "enabled", BudgetTokens: schemas.Ptr(2048)}, }, expectHeaders: []string{AnthropicInterleavedThinkingBetaHeader}, }, { name: "Disabled thinking does not get interleaved thinking header", provider: schemas.Anthropic, req: &AnthropicMessageRequest{ Thinking: &AnthropicThinking{Type: "disabled"}, }, unexpectHeaders: []string{AnthropicInterleavedThinkingBetaHeader}, }, // Fast mode tests — fast mode is Opus 4.6 only (research preview), // so tests must set Model to exercise the path. Non-Opus-4.6 models // are model-gated out regardless of provider flag. { name: "Anthropic gets fast mode header", provider: schemas.Anthropic, req: &AnthropicMessageRequest{ Model: "claude-opus-4-6", Speed: schemas.Ptr("fast"), }, expectHeaders: []string{AnthropicFastModeBetaHeader}, }, { name: "Anthropic skips fast mode header on non-Opus-4.6 model", provider: schemas.Anthropic, req: &AnthropicMessageRequest{ Model: "claude-sonnet-4-6", Speed: schemas.Ptr("fast"), }, unexpectHeaders: []string{AnthropicFastModeBetaHeader}, }, { name: "Bedrock skips fast mode header", provider: schemas.Bedrock, req: &AnthropicMessageRequest{ Model: "claude-opus-4-6", // fast mode is model-gated; set a supporting model so the test actually exercises provider suppression Speed: schemas.Ptr("fast"), }, unexpectHeaders: []string{AnthropicFastModeBetaHeader}, }, { name: "Azure skips fast mode header", provider: schemas.Azure, req: &AnthropicMessageRequest{ Model: "claude-opus-4-6", // fast mode is model-gated; set a supporting model so the test actually exercises provider suppression Speed: schemas.Ptr("fast"), }, unexpectHeaders: []string{AnthropicFastModeBetaHeader}, }, // Fine-grained tool streaming (eager_input_streaming) — per Table 20: // GA on Anthropic / Bedrock / Vertex, Beta on Azure. All four should // auto-inject fine-grained-tool-streaming-2025-05-14 when a tool has // eager_input_streaming: true. { name: "Anthropic gets eager_input_streaming header", provider: schemas.Anthropic, req: &AnthropicMessageRequest{ Tools: []AnthropicTool{{Name: "t1", EagerInputStreaming: schemas.Ptr(true)}}, }, expectHeaders: []string{AnthropicEagerInputStreamingBetaHeader}, }, { name: "Bedrock gets eager_input_streaming header", provider: schemas.Bedrock, req: &AnthropicMessageRequest{ Tools: []AnthropicTool{{Name: "t1", EagerInputStreaming: schemas.Ptr(true)}}, }, expectHeaders: []string{AnthropicEagerInputStreamingBetaHeader}, }, { name: "Vertex gets eager_input_streaming header", provider: schemas.Vertex, req: &AnthropicMessageRequest{ Tools: []AnthropicTool{{Name: "t1", EagerInputStreaming: schemas.Ptr(true)}}, }, expectHeaders: []string{AnthropicEagerInputStreamingBetaHeader}, }, { name: "Azure gets eager_input_streaming header", provider: schemas.Azure, req: &AnthropicMessageRequest{ Tools: []AnthropicTool{{Name: "t1", EagerInputStreaming: schemas.Ptr(true)}}, }, expectHeaders: []string{AnthropicEagerInputStreamingBetaHeader}, }, { name: "eager_input_streaming header absent when flag is false", provider: schemas.Anthropic, req: &AnthropicMessageRequest{ Tools: []AnthropicTool{{Name: "t1", EagerInputStreaming: schemas.Ptr(false)}}, }, unexpectHeaders: []string{AnthropicEagerInputStreamingBetaHeader}, }, { name: "eager_input_streaming header absent when unset", provider: schemas.Anthropic, req: &AnthropicMessageRequest{ Tools: []AnthropicTool{{Name: "t1"}}, }, unexpectHeaders: []string{AnthropicEagerInputStreamingBetaHeader}, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { ctx := schemas.NewBifrostContext(nil, time.Time{}) AddMissingBetaHeadersToContext(ctx, tt.req, tt.provider) var headers []string if extraHeaders, ok := ctx.Value(schemas.BifrostContextKeyExtraHeaders).(map[string][]string); ok { headers = extraHeaders[AnthropicBetaHeader] } for _, expected := range tt.expectHeaders { found := false for _, h := range headers { if h == expected { found = true break } } if !found { t.Errorf("expected header %q not found in %v", expected, headers) } } for _, unexpected := range tt.unexpectHeaders { for _, h := range headers { if h == unexpected { t.Errorf("unexpected header %q found in %v", unexpected, headers) } } } }) } } func TestAddMissingBetaHeadersToContext_PassthroughWins(t *testing.T) { // When a same-prefix header is already set from passthrough, auto-injection should NOT add a second version. t.Run("passthrough_mcp_header_prevents_auto_inject", func(t *testing.T) { ctx := schemas.NewBifrostContext(nil, time.Time{}) // Simulate passthrough setting an old MCP header ctx.SetValue(schemas.BifrostContextKeyExtraHeaders, map[string][]string{ "anthropic-beta": {AnthropicMCPClientBetaHeaderDeprecated}, }) // Request has MCP servers, which would normally auto-inject the new header req := &AnthropicMessageRequest{ MCPServers: []AnthropicMCPServerV2{{URL: "http://example.com"}}, } AddMissingBetaHeadersToContext(ctx, req, schemas.Anthropic) extraHeaders := ctx.Value(schemas.BifrostContextKeyExtraHeaders).(map[string][]string) betaHeaders := extraHeaders[AnthropicBetaHeader] // Should only have the old header, not both if len(betaHeaders) != 1 { t.Errorf("expected 1 header, got %d: %v", len(betaHeaders), betaHeaders) } if betaHeaders[0] != AnthropicMCPClientBetaHeaderDeprecated { t.Errorf("expected passthrough header %q, got %q", AnthropicMCPClientBetaHeaderDeprecated, betaHeaders[0]) } }) t.Run("passthrough_computer_use_header_prevents_auto_inject", func(t *testing.T) { ctx := schemas.NewBifrostContext(nil, time.Time{}) // Simulate passthrough setting an older computer-use header ctx.SetValue(schemas.BifrostContextKeyExtraHeaders, map[string][]string{ "anthropic-beta": {AnthropicComputerUseBetaHeader20250124}, }) req := &AnthropicMessageRequest{ Tools: []AnthropicTool{{ Type: schemas.Ptr(AnthropicToolTypeComputer20251124), Name: string(AnthropicToolNameComputer), }}, } AddMissingBetaHeadersToContext(ctx, req, schemas.Anthropic) extraHeaders := ctx.Value(schemas.BifrostContextKeyExtraHeaders).(map[string][]string) betaHeaders := extraHeaders[AnthropicBetaHeader] if len(betaHeaders) != 1 { t.Errorf("expected 1 header, got %d: %v", len(betaHeaders), betaHeaders) } if betaHeaders[0] != AnthropicComputerUseBetaHeader20250124 { t.Errorf("expected passthrough header %q, got %q", AnthropicComputerUseBetaHeader20250124, betaHeaders[0]) } }) t.Run("no_passthrough_allows_auto_inject", func(t *testing.T) { ctx := schemas.NewBifrostContext(nil, time.Time{}) req := &AnthropicMessageRequest{ MCPServers: []AnthropicMCPServerV2{{URL: "http://example.com"}}, } AddMissingBetaHeadersToContext(ctx, req, schemas.Anthropic) extraHeaders := ctx.Value(schemas.BifrostContextKeyExtraHeaders).(map[string][]string) betaHeaders := extraHeaders[AnthropicBetaHeader] if len(betaHeaders) != 1 || betaHeaders[0] != AnthropicMCPClientBetaHeader { t.Errorf("expected [%q], got %v", AnthropicMCPClientBetaHeader, betaHeaders) } }) } func TestMergeBetaHeaders(t *testing.T) { t.Run("context_extra_headers_case_insensitive_key", func(t *testing.T) { ctx := schemas.NewBifrostContext(context.Background(), time.Time{}) ctx.SetValue(schemas.BifrostContextKeyExtraHeaders, map[string][]string{ "Anthropic-Beta": {"structured-outputs-2025-11-13"}, }) got := MergeBetaHeaders(nil, ctx) want := []string{"structured-outputs-2025-11-13"} if !slices.Equal(got, want) { t.Fatalf("got %v, want %v", got, want) } }) t.Run("provider_extra_headers_case_insensitive_key", func(t *testing.T) { ctx := schemas.NewBifrostContext(context.Background(), time.Time{}) got := MergeBetaHeaders(map[string]string{ "Anthropic-Beta": "mcp-client-2025-04-04", }, ctx) want := []string{"mcp-client-2025-04-04"} if !slices.Equal(got, want) { t.Fatalf("got %v, want %v", got, want) } }) t.Run("merges_provider_then_context_deduping_tokens", func(t *testing.T) { ctx := schemas.NewBifrostContext(context.Background(), time.Time{}) ctx.SetValue(schemas.BifrostContextKeyExtraHeaders, map[string][]string{ "ANTHROPIC-BETA": {"foo,bar", "bar,baz"}, }) got := MergeBetaHeaders(map[string]string{ "anthropic-beta": "foo", }, ctx) sort.Strings(got) wantSorted := []string{"bar", "baz", "foo"} if !slices.Equal(got, wantSorted) { t.Fatalf("got %v, want %v", got, wantSorted) } }) } func TestFilterBetaHeadersForProvider(t *testing.T) { allHeaders := []string{ AnthropicComputerUseBetaHeader20251124, AnthropicStructuredOutputsBetaHeader, AnthropicMCPClientBetaHeader, AnthropicPromptCachingScopeBetaHeader, AnthropicCompactionBetaHeader, AnthropicContextManagementBetaHeader, AnthropicAdvancedToolUseBetaHeader, AnthropicFilesAPIBetaHeader, AnthropicInterleavedThinkingBetaHeader, AnthropicSkillsBetaHeader, AnthropicContext1MBetaHeader, AnthropicFastModeBetaHeader, AnthropicRedactThinkingBetaHeader, } containsHeader := func(result []string, h string) bool { for _, r := range result { if r == h { return true } } return false } t.Run("Anthropic/keeps_all_headers", func(t *testing.T) { result := FilterBetaHeadersForProvider(allHeaders, schemas.Anthropic) for _, h := range allHeaders { if !containsHeader(result, h) { t.Errorf("expected header %q to be kept for Anthropic, got %v", h, result) } } }) t.Run("Vertex/drops_unsupported_headers", func(t *testing.T) { unsupported := []string{ AnthropicStructuredOutputsBetaHeader, AnthropicMCPClientBetaHeader, AnthropicPromptCachingScopeBetaHeader, AnthropicAdvancedToolUseBetaHeader, AnthropicFilesAPIBetaHeader, AnthropicSkillsBetaHeader, AnthropicFastModeBetaHeader, AnthropicRedactThinkingBetaHeader, } for _, h := range unsupported { result := FilterBetaHeadersForProvider([]string{h}, schemas.Vertex) if len(result) != 0 { t.Errorf("expected header %q to be dropped for Vertex, got %v", h, result) } } }) t.Run("Vertex/keeps_supported_headers", func(t *testing.T) { supported := []string{ AnthropicComputerUseBetaHeader20251124, AnthropicCompactionBetaHeader, AnthropicContextManagementBetaHeader, AnthropicInterleavedThinkingBetaHeader, AnthropicContext1MBetaHeader, AnthropicEagerInputStreamingBetaHeader, } result := FilterBetaHeadersForProvider(supported, schemas.Vertex) if len(result) != len(supported) { t.Errorf("expected %d headers, got %d: %v", len(supported), len(result), result) } }) t.Run("Bedrock/drops_unsupported_headers", func(t *testing.T) { unsupported := []string{ AnthropicMCPClientBetaHeader, AnthropicPromptCachingScopeBetaHeader, AnthropicAdvancedToolUseBetaHeader, AnthropicFilesAPIBetaHeader, AnthropicSkillsBetaHeader, AnthropicFastModeBetaHeader, AnthropicRedactThinkingBetaHeader, } for _, h := range unsupported { result := FilterBetaHeadersForProvider([]string{h}, schemas.Bedrock) if len(result) != 0 { t.Errorf("expected header %q to be dropped for Bedrock, got %v", h, result) } } }) t.Run("Azure/drops_unsupported_headers", func(t *testing.T) { unsupported := []string{ AnthropicFastModeBetaHeader, } for _, h := range unsupported { result := FilterBetaHeadersForProvider([]string{h}, schemas.Azure) if len(result) != 0 { t.Errorf("expected header %q to be dropped for Azure, got %v", h, result) } } }) t.Run("Azure/keeps_supported_headers", func(t *testing.T) { supported := []string{ AnthropicComputerUseBetaHeader20251124, AnthropicStructuredOutputsBetaHeader, AnthropicMCPClientBetaHeader, AnthropicPromptCachingScopeBetaHeader, AnthropicCompactionBetaHeader, AnthropicContextManagementBetaHeader, AnthropicAdvancedToolUseBetaHeader, AnthropicFilesAPIBetaHeader, AnthropicInterleavedThinkingBetaHeader, AnthropicSkillsBetaHeader, AnthropicContext1MBetaHeader, AnthropicRedactThinkingBetaHeader, AnthropicEagerInputStreamingBetaHeader, } result := FilterBetaHeadersForProvider(supported, schemas.Azure) if len(result) != len(supported) { t.Errorf("expected %d headers, got %d: %v", len(supported), len(result), result) } }) t.Run("Bedrock/keeps_supported_headers", func(t *testing.T) { supported := []string{ AnthropicComputerUseBetaHeader20251124, AnthropicStructuredOutputsBetaHeader, AnthropicCompactionBetaHeader, AnthropicContextManagementBetaHeader, AnthropicInterleavedThinkingBetaHeader, AnthropicContext1MBetaHeader, AnthropicEagerInputStreamingBetaHeader, } result := FilterBetaHeadersForProvider(supported, schemas.Bedrock) if len(result) != len(supported) { t.Errorf("expected %d headers, got %d: %v", len(supported), len(result), result) } }) t.Run("unknown_headers_dropped_for_non_anthropic", func(t *testing.T) { result := FilterBetaHeadersForProvider([]string{"some-future-beta-2025"}, schemas.Vertex) if len(result) != 0 { t.Errorf("expected unknown header to be dropped for Vertex, got %v", result) } }) t.Run("unknown_headers_forwarded_for_anthropic", func(t *testing.T) { headers := []string{"some-future-beta-2025"} result := FilterBetaHeadersForProvider(headers, schemas.Anthropic) if len(result) != len(headers) { t.Errorf("expected unknown header to be forwarded for Anthropic, got %v", result) } }) t.Run("unknown_provider_allows_all", func(t *testing.T) { result := FilterBetaHeadersForProvider(allHeaders, schemas.ModelProvider("custom-provider")) if len(result) != len(allHeaders) { t.Errorf("expected all headers for unknown provider, got %v", result) } }) t.Run("override_enables_unsupported_header", func(t *testing.T) { // redact-thinking is not supported on Vertex by default overrides := map[string]bool{AnthropicRedactThinkingBetaHeaderPrefix: true} result := FilterBetaHeadersForProvider([]string{AnthropicRedactThinkingBetaHeader}, schemas.Vertex, overrides) if len(result) != 1 || result[0] != AnthropicRedactThinkingBetaHeader { t.Errorf("expected override to allow header, got %v", result) } }) t.Run("override_disables_supported_header", func(t *testing.T) { // compaction is supported on Vertex by default; override to false should drop it silently overrides := map[string]bool{"compact-": false} result := FilterBetaHeadersForProvider([]string{AnthropicCompactionBetaHeader}, schemas.Vertex, overrides) if len(result) != 0 { t.Errorf("expected override false to drop supported header, got %v", result) } }) t.Run("override_nil_uses_defaults", func(t *testing.T) { // Passing nil overrides should behave identically to no overrides result := FilterBetaHeadersForProvider([]string{AnthropicCompactionBetaHeader}, schemas.Vertex, nil) if len(result) != 1 { t.Errorf("expected default behavior with nil overrides, got %v", result) } }) // Custom override tests for all providers customOverrideProviders := []struct { provider schemas.ModelProvider expectForwardNoOverride bool // unknown headers forwarded without override? }{ {schemas.Anthropic, true}, {schemas.Vertex, false}, {schemas.Bedrock, false}, {schemas.Azure, false}, } for _, tc := range customOverrideProviders { tc := tc t.Run(fmt.Sprintf("%s/custom_override_enables_unknown_header", tc.provider), func(t *testing.T) { overrides := map[string]bool{"new-feature-": true} result := FilterBetaHeadersForProvider([]string{"new-feature-2026-01-01"}, tc.provider, overrides) if len(result) != 1 || result[0] != "new-feature-2026-01-01" { t.Errorf("expected custom override to allow header on %s, got %v", tc.provider, result) } }) t.Run(fmt.Sprintf("%s/custom_override_disables_unknown_header", tc.provider), func(t *testing.T) { overrides := map[string]bool{"new-feature-": false} result := FilterBetaHeadersForProvider([]string{"new-feature-2026-01-01"}, tc.provider, overrides) if len(result) != 0 { t.Errorf("expected custom override false to drop header on %s, got %v", tc.provider, result) } }) t.Run(fmt.Sprintf("%s/custom_override_no_match_still_handled_correctly", tc.provider), func(t *testing.T) { overrides := map[string]bool{"new-feature-": true} result := FilterBetaHeadersForProvider([]string{"other-thing-2026"}, tc.provider, overrides) if tc.expectForwardNoOverride { if len(result) != 1 { t.Errorf("expected unknown header forwarded to %s, got %v", tc.provider, result) } } else { if len(result) != 0 { t.Errorf("expected unknown header dropped for %s, got %v", tc.provider, result) } } }) t.Run(fmt.Sprintf("%s/custom_override_with_multiple_prefixes", tc.provider), func(t *testing.T) { overrides := map[string]bool{ "alpha-": true, "beta-": false, "gamma-": true, } result := FilterBetaHeadersForProvider([]string{"alpha-2026-01"}, tc.provider, overrides) if len(result) != 1 { t.Errorf("expected alpha- allowed on %s, got %v", tc.provider, result) } result = FilterBetaHeadersForProvider([]string{"beta-2026-01"}, tc.provider, overrides) if len(result) != 0 { t.Errorf("expected beta- dropped on %s, got %v", tc.provider, result) } result = FilterBetaHeadersForProvider([]string{"gamma-2026-01"}, tc.provider, overrides) if len(result) != 1 { t.Errorf("expected gamma- allowed on %s, got %v", tc.provider, result) } }) } } func TestStripUnsupportedFieldsFromRawBody(t *testing.T) { t.Run("bedrock_strips_new_request_level_fields", func(t *testing.T) { // Raw body with every new typed field. Targeting Bedrock: speed (no FastMode), // inference_geo (no InferenceGeo), mcp_servers (no MCP), container.skills // (no Skills), top-level cache_control.scope (no PromptCachingScope), // output_config.task_budget (no TaskBudgets). All should be stripped. input := []byte(`{ "model":"claude-opus-4-6", "speed":"fast", "inference_geo":"us-east-1", "mcp_servers":[{"type":"url","url":"https://example.com","name":"x"}], "container":{"id":"c-1","skills":[{"skill_id":"s","type":"anthropic"}]}, "cache_control":{"type":"ephemeral","ttl":"5m","scope":"user"}, "output_config":{"task_budget":{"type":"tokens","total":20000}} }`) result, err := stripUnsupportedFieldsFromRawBody(input, schemas.Bedrock, "claude-opus-4-6") if err != nil { t.Fatalf("unexpected error: %v", err) } for _, path := range []string{"speed", "inference_geo", "mcp_servers", "container", "cache_control.scope", "output_config.task_budget"} { if providerUtils.JSONFieldExists(result, path) { t.Errorf("expected %q to be stripped for Bedrock, got: %s", path, string(result)) } } // Confirm non-scope cache_control fields are retained. if !providerUtils.JSONFieldExists(result, "cache_control.ttl") { t.Errorf("expected cache_control.ttl to survive, got: %s", string(result)) } }) t.Run("vertex_strips_mcp_strict_and_input_examples_via_feature_check", func(t *testing.T) { // Vertex: no MCP, no InputExamples, no StructuredOutputs. // tool.strict stripped; tool.input_examples stripped; mcp_servers stripped. // tool.cache_control.scope stripped (Vertex has no PromptCachingScope). input := []byte(`{ "model":"claude-sonnet-4-6", "mcp_servers":[{"type":"url","url":"u","name":"n"}], "tools":[{"name":"t1","strict":true,"input_examples":[{"input":{"a":1}}],"cache_control":{"type":"ephemeral","scope":"user"}}] }`) result, err := stripUnsupportedFieldsFromRawBody(input, schemas.Vertex, "claude-sonnet-4-6") if err != nil { t.Fatalf("unexpected error: %v", err) } for _, path := range []string{"mcp_servers", "tools.0.strict", "tools.0.input_examples", "tools.0.cache_control.scope"} { if providerUtils.JSONFieldExists(result, path) { t.Errorf("expected %q to be stripped for Vertex, got: %s", path, string(result)) } } if !providerUtils.JSONFieldExists(result, "tools.0.name") { t.Errorf("expected tool name to survive") } }) t.Run("bedrock_keeps_input_examples_via_standalone_flag", func(t *testing.T) { // Bedrock has InputExamples=true via tool-examples-2025-10-29 but // AdvancedToolUse=false. input_examples should be KEPT; defer_loading // and allowed_callers (bundle-only) should be STRIPPED. input := []byte(`{ "model":"claude-opus-4-6", "tools":[{"name":"t1","input_examples":[{"input":{"a":1}}],"defer_loading":true,"allowed_callers":["direct"]}] }`) result, err := stripUnsupportedFieldsFromRawBody(input, schemas.Bedrock, "claude-opus-4-6") if err != nil { t.Fatalf("unexpected error: %v", err) } if !providerUtils.JSONFieldExists(result, "tools.0.input_examples") { t.Errorf("expected tools[0].input_examples to survive on Bedrock, got: %s", string(result)) } for _, path := range []string{"tools.0.defer_loading", "tools.0.allowed_callers"} { if providerUtils.JSONFieldExists(result, path) { t.Errorf("expected %q to be stripped for Bedrock (AdvancedToolUse bundle unsupported), got: %s", path, string(result)) } } }) t.Run("speed_stripped_on_non_opus_46_even_on_anthropic", func(t *testing.T) { // Model gate: fast-mode is Opus 4.6 only per docs. Even on Anthropic // direct where FastMode=true, targeting a different model must strip. input := []byte(`{"model":"claude-sonnet-4-6","speed":"fast"}`) result, err := stripUnsupportedFieldsFromRawBody(input, schemas.Anthropic, "claude-sonnet-4-6") if err != nil { t.Fatalf("unexpected error: %v", err) } if providerUtils.JSONFieldExists(result, "speed") { t.Errorf("expected speed stripped for non-Opus-4.6 model on Anthropic, got: %s", string(result)) } }) t.Run("anthropic_direct_is_noop", func(t *testing.T) { // Anthropic supports everything — body should survive untouched. input := []byte(`{"model":"claude-opus-4-6","speed":"fast","mcp_servers":[{"type":"url","url":"u","name":"n"}],"container":{"id":"c"},"tools":[{"name":"t","defer_loading":true,"input_examples":[{"input":{"a":1}}]}]}`) result, err := stripUnsupportedFieldsFromRawBody(input, schemas.Anthropic, "claude-opus-4-6") if err != nil { t.Fatalf("unexpected error: %v", err) } for _, path := range []string{"speed", "mcp_servers", "container", "tools.0.defer_loading", "tools.0.input_examples"} { if !providerUtils.JSONFieldExists(result, path) { t.Errorf("expected %q preserved on Anthropic direct, got: %s", path, string(result)) } } }) t.Run("nested_scope_stripped_on_messages_and_system", func(t *testing.T) { // Nested scope on system blocks and message blocks must also be stripped // when the provider lacks PromptCachingScope. input := []byte(`{ "model":"claude-opus-4-6", "system":[{"type":"text","text":"hi","cache_control":{"type":"ephemeral","scope":"user"}}], "messages":[{"role":"user","content":[{"type":"text","text":"q","cache_control":{"type":"ephemeral","scope":"global"}}]}] }`) result, err := stripUnsupportedFieldsFromRawBody(input, schemas.Bedrock, "claude-opus-4-6") if err != nil { t.Fatalf("unexpected error: %v", err) } for _, path := range []string{"system.0.cache_control.scope", "messages.0.content.0.cache_control.scope"} { if providerUtils.JSONFieldExists(result, path) { t.Errorf("expected nested %q stripped, got: %s", path, string(result)) } } }) t.Run("unknown_provider_is_safe_noop", func(t *testing.T) { input := []byte(`{"model":"claude-opus-4-6","speed":"fast"}`) result, err := stripUnsupportedFieldsFromRawBody(input, schemas.ModelProvider("custom"), "claude-opus-4-6") if err != nil { t.Fatalf("unexpected error: %v", err) } if !providerUtils.JSONFieldExists(result, "speed") { t.Errorf("expected speed preserved for unknown provider (safe default), got: %s", string(result)) } }) t.Run("container_empty_skills_stripped_but_container_preserved", func(t *testing.T) { // Skills=false provider (Bedrock), ContainerBasic=true. // skills:[] is a caller oversight — strip the empty key, preserve container. input := []byte(`{"model":"claude-opus-4-6","container":{"id":"c-1","skills":[]}}`) result, err := stripUnsupportedFieldsFromRawBody(input, schemas.Bedrock, "claude-opus-4-6") if err != nil { t.Fatalf("unexpected error: %v", err) } if providerUtils.JSONFieldExists(result, "container.skills") { t.Errorf("expected empty container.skills stripped on Skills=false provider, got: %s", string(result)) } if !providerUtils.JSONFieldExists(result, "container.id") { t.Errorf("expected container.id preserved (bare form still valid), got: %s", string(result)) } }) t.Run("container_nonempty_skills_drops_whole_container", func(t *testing.T) { // Non-empty skills signals caller intent; provider doesn't support — drop container. input := []byte(`{"model":"claude-opus-4-6","container":{"id":"c-1","skills":[{"skill_id":"s","type":"anthropic"}]}}`) result, err := stripUnsupportedFieldsFromRawBody(input, schemas.Bedrock, "claude-opus-4-6") if err != nil { t.Fatalf("unexpected error: %v", err) } if providerUtils.JSONFieldExists(result, "container") { t.Errorf("expected whole container dropped for non-empty skills on Skills=false, got: %s", string(result)) } }) t.Run("container_empty_skills_on_skills_capable_provider_preserved", func(t *testing.T) { // On Anthropic direct (Skills=true), the empty skills array must be preserved // as-is — our strip logic only fires when !features.Skills. input := []byte(`{"model":"claude-opus-4-6","container":{"id":"c-1","skills":[]}}`) result, err := stripUnsupportedFieldsFromRawBody(input, schemas.Anthropic, "claude-opus-4-6") if err != nil { t.Fatalf("unexpected error: %v", err) } if !providerUtils.JSONFieldExists(result, "container.skills") { t.Errorf("expected container.skills preserved on Skills=true provider, got: %s", string(result)) } }) } // TestStripUnsupportedAnthropicFields_ContainerSkillsGating mirrors the raw-path // tests above on the typed path — ensures the typed sanitizer treats explicit // empty skills arrays as a stripable (not drop-triggering) signal. func TestStripUnsupportedAnthropicFields_ContainerSkillsGating(t *testing.T) { t.Run("empty_skills_on_skills_false_provider_strips_skills_keeps_container", func(t *testing.T) { req := &AnthropicMessageRequest{ Model: "claude-opus-4-6", Container: &AnthropicContainer{ ContainerObject: &AnthropicContainerObject{ ID: schemas.Ptr("c-1"), Skills: []AnthropicContainerSkill{}, // explicit empty }, }, } stripUnsupportedAnthropicFields(req, schemas.Bedrock, "claude-opus-4-6") if req.Container == nil { t.Fatalf("expected container preserved (bare form valid with empty skills), got nil") } if req.Container.ContainerObject == nil || req.Container.ContainerObject.Skills != nil { t.Errorf("expected skills cleared on Skills=false, got %v", req.Container.ContainerObject) } }) t.Run("nonempty_skills_on_skills_false_provider_drops_container", func(t *testing.T) { req := &AnthropicMessageRequest{ Model: "claude-opus-4-6", Container: &AnthropicContainer{ ContainerObject: &AnthropicContainerObject{ ID: schemas.Ptr("c-1"), Skills: []AnthropicContainerSkill{{SkillID: "s", Type: "anthropic"}}, }, }, } stripUnsupportedAnthropicFields(req, schemas.Bedrock, "claude-opus-4-6") if req.Container != nil { t.Errorf("expected whole container dropped for non-empty skills on Skills=false, got %v", req.Container) } }) t.Run("empty_skills_on_skills_true_provider_preserved", func(t *testing.T) { req := &AnthropicMessageRequest{ Model: "claude-opus-4-6", Container: &AnthropicContainer{ ContainerObject: &AnthropicContainerObject{ ID: schemas.Ptr("c-1"), Skills: []AnthropicContainerSkill{}, }, }, } stripUnsupportedAnthropicFields(req, schemas.Anthropic, "claude-opus-4-6") if req.Container == nil || req.Container.ContainerObject == nil { t.Fatalf("expected container preserved on Skills=true provider, got %v", req.Container) } if req.Container.ContainerObject.Skills == nil { t.Errorf("expected empty skills preserved on Skills=true provider (not nilled)") } }) } func TestStripAutoInjectableTools(t *testing.T) { t.Run("code_execution_without_web_search_preserved", func(t *testing.T) { // code_execution alone should NOT be stripped (no web_search/web_fetch to trigger auto-injection) input := []byte(`{"model":"claude-opus-4-6","tools":[{"type":"custom","name":"my_tool"},{"type":"code_execution_20250825","name":"code_execution"}]}`) result, err := StripAutoInjectableTools(input) if err != nil { t.Fatalf("unexpected error: %v", err) } tools := providerUtils.GetJSONField(result, "tools") arr := tools.Array() if len(arr) != 2 { t.Fatalf("expected 2 tools (preserved), got %d", len(arr)) } }) t.Run("code_execution_with_web_search_stripped", func(t *testing.T) { // code_execution should be stripped when web_search is present (auto-injection conflict) input := []byte(`{"tools":[{"type":"code_execution_20250825","name":"code_execution"},{"type":"web_search_20260209","name":"web_search"},{"type":"custom","name":"my_tool"}]}`) result, err := StripAutoInjectableTools(input) if err != nil { t.Fatalf("unexpected error: %v", err) } tools := providerUtils.GetJSONField(result, "tools") arr := tools.Array() if len(arr) != 2 { t.Fatalf("expected 2 tools, got %d", len(arr)) } if arr[0].Get("name").String() != "web_search" { t.Errorf("expected first tool to be 'web_search', got '%s'", arr[0].Get("name").String()) } if arr[1].Get("name").String() != "my_tool" { t.Errorf("expected second tool to be 'my_tool', got '%s'", arr[1].Get("name").String()) } }) t.Run("code_execution_with_web_fetch_stripped", func(t *testing.T) { // code_execution should be stripped when web_fetch is present input := []byte(`{"tools":[{"type":"code_execution_20250825","name":"code_execution"},{"type":"web_fetch_20250305","name":"web_fetch"},{"type":"custom","name":"my_tool"}]}`) result, err := StripAutoInjectableTools(input) if err != nil { t.Fatalf("unexpected error: %v", err) } tools := providerUtils.GetJSONField(result, "tools") arr := tools.Array() if len(arr) != 2 { t.Fatalf("expected 2 tools, got %d", len(arr)) } if arr[0].Get("name").String() != "web_fetch" { t.Errorf("expected first tool to be 'web_fetch', got '%s'", arr[0].Get("name").String()) } if arr[1].Get("name").String() != "my_tool" { t.Errorf("expected second tool to be 'my_tool', got '%s'", arr[1].Get("name").String()) } }) t.Run("web_search_alone_preserved", func(t *testing.T) { // web_search without code_execution should be preserved entirely input := []byte(`{"tools":[{"type":"web_search_20250305","name":"web_search"},{"type":"custom","name":"search"}]}`) result, err := StripAutoInjectableTools(input) if err != nil { t.Fatalf("unexpected error: %v", err) } tools := providerUtils.GetJSONField(result, "tools") arr := tools.Array() if len(arr) != 2 { t.Fatalf("expected 2 tools (preserved), got %d", len(arr)) } }) t.Run("web_fetch_alone_preserved", func(t *testing.T) { // web_fetch without code_execution should be preserved entirely input := []byte(`{"tools":[{"type":"web_fetch_20250305","name":"web_fetch"},{"type":"custom","name":"fetch"}]}`) result, err := StripAutoInjectableTools(input) if err != nil { t.Fatalf("unexpected error: %v", err) } tools := providerUtils.GetJSONField(result, "tools") arr := tools.Array() if len(arr) != 2 { t.Fatalf("expected 2 tools (preserved), got %d", len(arr)) } }) t.Run("preserves_custom_tools_only", func(t *testing.T) { input := []byte(`{"tools":[{"type":"custom","name":"tool_a"},{"type":"custom","name":"tool_b"}]}`) result, err := StripAutoInjectableTools(input) if err != nil { t.Fatalf("unexpected error: %v", err) } tools := providerUtils.GetJSONField(result, "tools") arr := tools.Array() if len(arr) != 2 { t.Fatalf("expected 2 tools, got %d", len(arr)) } }) t.Run("no_tools_key", func(t *testing.T) { input := []byte(`{"model":"claude-opus-4-6","messages":[]}`) result, err := StripAutoInjectableTools(input) if err != nil { t.Fatalf("unexpected error: %v", err) } if string(result) != string(input) { t.Errorf("expected body unchanged, got %s", string(result)) } }) t.Run("empty_tools_array", func(t *testing.T) { input := []byte(`{"tools":[]}`) result, err := StripAutoInjectableTools(input) if err != nil { t.Fatalf("unexpected error: %v", err) } if string(result) != string(input) { t.Errorf("expected body unchanged, got %s", string(result)) } }) t.Run("code_execution_and_web_search_only_strips_code_execution", func(t *testing.T) { // When only code_execution + web_search, strip code_execution, keep web_search input := []byte(`{"model":"test","tools":[{"type":"code_execution_20250825","name":"code_execution"},{"type":"web_search_20250305","name":"web_search"}]}`) result, err := StripAutoInjectableTools(input) if err != nil { t.Fatalf("unexpected error: %v", err) } tools := providerUtils.GetJSONField(result, "tools") arr := tools.Array() if len(arr) != 1 { t.Fatalf("expected 1 tool, got %d", len(arr)) } if arr[0].Get("name").String() != "web_search" { t.Errorf("expected remaining tool to be 'web_search', got '%s'", arr[0].Get("name").String()) } }) t.Run("strips_code_execution_keeps_web_search_and_custom", func(t *testing.T) { input := []byte(`{"tools":[{"type":"code_execution_20250825","name":"code_execution"},{"type":"custom","name":"my_tool"},{"type":"web_search_20260209","name":"web_search"},{"type":"custom","name":"other_tool"}]}`) result, err := StripAutoInjectableTools(input) if err != nil { t.Fatalf("unexpected error: %v", err) } tools := providerUtils.GetJSONField(result, "tools") arr := tools.Array() if len(arr) != 3 { t.Fatalf("expected 3 tools, got %d", len(arr)) } if arr[0].Get("name").String() != "my_tool" { t.Errorf("expected first tool to be 'my_tool', got '%s'", arr[0].Get("name").String()) } if arr[1].Get("name").String() != "web_search" { t.Errorf("expected second tool to be 'web_search', got '%s'", arr[1].Get("name").String()) } if arr[2].Get("name").String() != "other_tool" { t.Errorf("expected third tool to be 'other_tool', got '%s'", arr[2].Get("name").String()) } }) } func TestAnthropicToolUnmarshalJSON_MCPToolset(t *testing.T) { t.Run("mcp_toolset is properly unmarshaled", func(t *testing.T) { data := []byte(`{ "type": "mcp_toolset", "mcp_server_name": "example-mcp", "default_config": {"enabled": false}, "configs": { "search_events": {"enabled": true}, "create_event": {"enabled": true, "defer_loading": true} } }`) var tool AnthropicTool if err := sonic.Unmarshal(data, &tool); err != nil { t.Fatalf("unexpected unmarshal error: %v", err) } if tool.MCPToolset == nil { t.Fatal("expected MCPToolset to be populated, got nil") } if tool.MCPToolset.Type != "mcp_toolset" { t.Errorf("expected type 'mcp_toolset', got %q", tool.MCPToolset.Type) } if tool.MCPToolset.MCPServerName != "example-mcp" { t.Errorf("expected mcp_server_name 'example-mcp', got %q", tool.MCPToolset.MCPServerName) } if tool.MCPToolset.DefaultConfig == nil || tool.MCPToolset.DefaultConfig.Enabled == nil || *tool.MCPToolset.DefaultConfig.Enabled != false { t.Error("expected default_config.enabled to be false") } if len(tool.MCPToolset.Configs) != 2 { t.Fatalf("expected 2 configs, got %d", len(tool.MCPToolset.Configs)) } if tool.MCPToolset.Configs["search_events"] == nil || *tool.MCPToolset.Configs["search_events"].Enabled != true { t.Error("expected search_events to be enabled") } if tool.MCPToolset.Configs["create_event"] == nil || tool.MCPToolset.Configs["create_event"].DeferLoading == nil || *tool.MCPToolset.Configs["create_event"].DeferLoading != true { t.Error("expected create_event defer_loading to be true") } }) t.Run("regular tool is not affected by mcp_toolset unmarshal", func(t *testing.T) { data := []byte(`{ "name": "get_weather", "description": "Get weather info", "input_schema": {"type": "object", "properties": {}} }`) var tool AnthropicTool if err := sonic.Unmarshal(data, &tool); err != nil { t.Fatalf("unexpected unmarshal error: %v", err) } if tool.MCPToolset != nil { t.Error("expected MCPToolset to be nil for regular tool") } if tool.Name != "get_weather" { t.Errorf("expected name 'get_weather', got %q", tool.Name) } }) t.Run("mcp_toolset round-trips through marshal/unmarshal", func(t *testing.T) { original := AnthropicTool{ MCPToolset: &AnthropicMCPToolsetTool{ Type: "mcp_toolset", MCPServerName: "test-server", DefaultConfig: &AnthropicMCPToolsetConfig{Enabled: new(false)}, Configs: map[string]*AnthropicMCPToolsetConfig{ "tool_a": {Enabled: new(true)}, }, }, } marshaled, err := json.Marshal(original) if err != nil { t.Fatalf("unexpected marshal error: %v", err) } var restored AnthropicTool if err := sonic.Unmarshal(marshaled, &restored); err != nil { t.Fatalf("unexpected unmarshal error: %v", err) } if restored.MCPToolset == nil { t.Fatal("expected MCPToolset to be populated after round-trip") } if restored.MCPToolset.MCPServerName != "test-server" { t.Errorf("expected mcp_server_name 'test-server', got %q", restored.MCPToolset.MCPServerName) } if len(restored.MCPToolset.Configs) != 1 { t.Fatalf("expected 1 config, got %d", len(restored.MCPToolset.Configs)) } }) t.Run("tools array with mixed regular and mcp_toolset tools", func(t *testing.T) { data := []byte(`[ {"name": "get_weather", "description": "Get weather"}, {"type": "mcp_toolset", "mcp_server_name": "my-mcp"}, {"type": "computer_20251124", "name": "computer"} ]`) var tools []AnthropicTool if err := sonic.Unmarshal(data, &tools); err != nil { t.Fatalf("unexpected unmarshal error: %v", err) } if len(tools) != 3 { t.Fatalf("expected 3 tools, got %d", len(tools)) } // First: regular tool if tools[0].Name != "get_weather" { t.Errorf("expected first tool name 'get_weather', got %q", tools[0].Name) } if tools[0].MCPToolset != nil { t.Error("expected first tool MCPToolset to be nil") } // Second: mcp_toolset if tools[1].MCPToolset == nil { t.Fatal("expected second tool MCPToolset to be populated") } if tools[1].MCPToolset.MCPServerName != "my-mcp" { t.Errorf("expected mcp_server_name 'my-mcp', got %q", tools[1].MCPToolset.MCPServerName) } // Third: typed tool (computer) if tools[2].MCPToolset != nil { t.Error("expected third tool MCPToolset to be nil") } }) } func TestGetRequestBodyForResponses_RawBodyStripsFallbacks(t *testing.T) { rawBody := []byte(`{"model":"claude-sonnet-4-5","max_tokens":1024,"messages":[{"role":"user","content":"hello"}],"fallbacks":["claude-haiku-4-5"],"temperature":0.7}`) ctx := schemas.NewBifrostContext(nil, time.Time{}) ctx.SetValue(schemas.BifrostContextKeyUseRawRequestBody, true) request := &schemas.BifrostResponsesRequest{ Provider: schemas.Anthropic, Model: "claude-sonnet-4-5", RawRequestBody: rawBody, } result, bifrostErr := getRequestBodyForResponses(ctx, request, false, nil) if bifrostErr != nil { t.Fatalf("unexpected error: %v", bifrostErr) } if providerUtils.GetJSONField(result, "fallbacks").Exists() { t.Error("expected 'fallbacks' to be absent from raw-body output") } // Other fields must survive the round-trip if !providerUtils.GetJSONField(result, "model").Exists() { t.Error("expected 'model' to be present") } if !providerUtils.GetJSONField(result, "max_tokens").Exists() { t.Error("expected 'max_tokens' to be present") } if !providerUtils.GetJSONField(result, "temperature").Exists() { t.Error("expected 'temperature' to be present") } } func TestApplyMCPToolsetConfigToBifrostTool(t *testing.T) { t.Run("allowlist pattern merges correctly", func(t *testing.T) { bifrostTool := &schemas.ResponsesTool{ Type: schemas.ResponsesToolTypeMCP, ResponsesToolMCP: &schemas.ResponsesToolMCP{ ServerLabel: "test-server", ServerURL: schemas.Ptr("https://example.com/mcp"), }, } toolset := &AnthropicMCPToolsetTool{ Type: "mcp_toolset", MCPServerName: "test-server", DefaultConfig: &AnthropicMCPToolsetConfig{Enabled: schemas.Ptr(false)}, Configs: map[string]*AnthropicMCPToolsetConfig{ "search": {Enabled: new(true)}, "create": {Enabled: schemas.Ptr(true)}, "delete": {Enabled: schemas.Ptr(false)}, }, } applyMCPToolsetConfigToBifrostTool(bifrostTool, toolset) if bifrostTool.ResponsesToolMCP.AllowedTools == nil { t.Fatal("expected AllowedTools to be set") } allowedNames := bifrostTool.ResponsesToolMCP.AllowedTools.ToolNames if len(allowedNames) != 2 { t.Fatalf("expected 2 allowed tools, got %d: %v", len(allowedNames), allowedNames) } // Check that both "search" and "create" are present (order may vary due to map iteration) found := map[string]bool{} for _, name := range allowedNames { found[name] = true } if !found["search"] || !found["create"] { t.Errorf("expected allowed tools to contain 'search' and 'create', got %v", allowedNames) } }) t.Run("all enabled by default does not set allowlist", func(t *testing.T) { bifrostTool := &schemas.ResponsesTool{ Type: schemas.ResponsesToolTypeMCP, ResponsesToolMCP: &schemas.ResponsesToolMCP{ ServerLabel: "test-server", }, } toolset := &AnthropicMCPToolsetTool{ Type: "mcp_toolset", MCPServerName: "test-server", // No default_config (defaults to enabled=true) } applyMCPToolsetConfigToBifrostTool(bifrostTool, toolset) if bifrostTool.ResponsesToolMCP.AllowedTools != nil { t.Error("expected AllowedTools to be nil when all tools are enabled by default") } }) t.Run("nil inputs are handled safely", func(t *testing.T) { // Should not panic applyMCPToolsetConfigToBifrostTool(nil, nil) applyMCPToolsetConfigToBifrostTool(&schemas.ResponsesTool{}, nil) }) } func TestSupportsAdaptiveThinking(t *testing.T) { tests := []struct { model string expected bool }{ {"claude-opus-4-7-20260401", true}, {"claude-opus-4.7-20260401", true}, {"claude-opus-4-6-20250514", true}, {"claude-opus-4.6-20250514", true}, {"claude-sonnet-4-6-20250514", true}, {"claude-sonnet-4.6-20250514", true}, {"claude-opus-4-5-20241022", false}, {"claude-sonnet-4-5-20241022", false}, {"claude-haiku-4-6-20250514", false}, // haiku does not support adaptive {"claude-haiku-4-7-20260401", false}, // haiku, not opus {"", false}, } for _, tt := range tests { t.Run(tt.model, func(t *testing.T) { got := SupportsAdaptiveThinking(tt.model) if got != tt.expected { t.Errorf("SupportsAdaptiveThinking(%q) = %v, want %v", tt.model, got, tt.expected) } }) } } func TestAddMissingBetaHeadersToContext_TaskBudgets(t *testing.T) { tests := []struct { name string provider schemas.ModelProvider req *AnthropicMessageRequest expectHeaders []string unexpectHeaders []string }{ { name: "Anthropic gets task-budgets header when task_budget set", provider: schemas.Anthropic, req: &AnthropicMessageRequest{ OutputConfig: &AnthropicOutputConfig{ TaskBudget: &AnthropicTaskBudget{Type: "tokens", Total: 50000}, }, }, expectHeaders: []string{AnthropicTaskBudgetsBetaHeader}, }, { name: "Vertex does not get task-budgets header when task_budget set", provider: schemas.Vertex, req: &AnthropicMessageRequest{ OutputConfig: &AnthropicOutputConfig{ TaskBudget: &AnthropicTaskBudget{Type: "tokens", Total: 50000}, }, }, unexpectHeaders: []string{AnthropicTaskBudgetsBetaHeader}, }, { name: "no task-budgets header when task_budget is nil", provider: schemas.Anthropic, req: &AnthropicMessageRequest{ OutputConfig: &AnthropicOutputConfig{}, }, unexpectHeaders: []string{AnthropicTaskBudgetsBetaHeader}, }, { name: "no task-budgets header when output_config is nil", provider: schemas.Anthropic, req: &AnthropicMessageRequest{}, unexpectHeaders: []string{AnthropicTaskBudgetsBetaHeader}, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { ctx := schemas.NewBifrostContext(nil, time.Time{}) AddMissingBetaHeadersToContext(ctx, tt.req, tt.provider) var headers []string if extraHeaders, ok := ctx.Value(schemas.BifrostContextKeyExtraHeaders).(map[string][]string); ok { headers = extraHeaders[AnthropicBetaHeader] } for _, expected := range tt.expectHeaders { found := false for _, h := range headers { if h == expected { found = true break } } if !found { t.Errorf("expected header %q not found in %v", expected, headers) } } for _, unexpected := range tt.unexpectHeaders { for _, h := range headers { if h == unexpected { t.Errorf("unexpected header %q found in %v", unexpected, headers) } } } }) } }