package utils
import (
"bytes"
"compress/gzip"
"context"
"encoding/json"
"fmt"
"io"
"strings"
"testing"
"github.com/bytedance/sonic"
"github.com/maximhq/bifrost/core/schemas"
"github.com/valyala/fasthttp"
)
func TestRewriteJSONModelValue(t *testing.T) {
in := []byte(`{"model":"openai/gpt-5","messages":[{"role":"user","content":"x"}]}`)
out, changed := rewriteJSONModelValue(in, "openai/gpt-5", "gpt-5")
if !changed {
t.Fatal("expected model rewrite to occur")
}
if strings.Contains(string(out), `"model":"openai/gpt-5"`) {
t.Fatalf("expected prefixed model to be removed, got: %s", string(out))
}
if !strings.Contains(string(out), `"model":"gpt-5"`) {
t.Fatalf("expected rewritten model, got: %s", string(out))
}
}
func TestApplyLargePayloadRequestBodyWithModelNormalization(t *testing.T) {
ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
payload := `{"model":"openai/gpt-5","messages":[{"role":"user","content":"hello"}]}`
ctx.SetValue(schemas.BifrostContextKeyLargePayloadMode, true)
ctx.SetValue(
schemas.BifrostContextKeyLargePayloadReader,
strings.NewReader(payload),
)
ctx.SetValue(schemas.BifrostContextKeyLargePayloadContentLength, len(payload))
ctx.SetValue(schemas.BifrostContextKeyLargePayloadContentType, "application/json")
ctx.SetValue(schemas.BifrostContextKeyLargePayloadMetadata, &schemas.LargePayloadMetadata{
Model: "openai/gpt-5",
})
req := &fasthttp.Request{}
if !ApplyLargePayloadRequestBodyWithModelNormalization(ctx, req, schemas.OpenAI) {
t.Fatal("expected large payload body to be applied")
}
body := string(req.Body())
if strings.Contains(body, "openai/gpt-5") {
t.Fatalf("expected rewritten model in body, got: %s", body)
}
if !strings.Contains(body, `"model":"gpt-5"`) {
t.Fatalf("expected normalized model in body, got: %s", body)
}
}
// TestHandleProviderAPIError_RawResponseIncluded verifies that HandleProviderAPIError
// always includes the raw response body in BifrostError.ExtraFields.RawResponse
func TestHandleProviderAPIError_RawResponseIncluded(t *testing.T) {
tests := []struct {
name string
statusCode int
body []byte
contentType string
description string
}{
{
name: "Decode failure",
statusCode: 500,
body: []byte{0xFF, 0xFE}, // Invalid gzip-compressed data
contentType: "application/json",
description: "Should include raw response when decode fails",
},
{
name: "Empty response",
statusCode: 502,
body: []byte(""),
contentType: "application/json",
description: "Should include empty raw response",
},
{
name: "Valid JSON error",
statusCode: 400,
body: []byte(`{"error": {"message": "Invalid API key"}}`),
contentType: "application/json",
description: "Should include raw response for valid JSON",
},
{
name: "HTML error response",
statusCode: 503,
body: []byte(`
Service Unavailable
`),
contentType: "text/html",
description: "Should include raw response for HTML errors",
},
{
name: "Unparseable non-HTML response",
statusCode: 400,
body: []byte(`This is not JSON or HTML`),
contentType: "text/plain",
description: "Should include raw response for unparseable content",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
resp := &fasthttp.Response{}
resp.SetStatusCode(tt.statusCode)
resp.Header.Set("Content-Type", tt.contentType)
// Set Content-Encoding: gzip for decode failure test to trigger BodyGunzip() error
if tt.name == "Decode failure" {
resp.Header.Set("Content-Encoding", "gzip")
}
resp.SetBody(tt.body)
var errorResp map[string]interface{}
bifrostErr := HandleProviderAPIError(resp, &errorResp)
if bifrostErr == nil {
t.Fatal("HandleProviderAPIError() returned nil")
}
if bifrostErr.ExtraFields.RawResponse == nil {
t.Errorf("%s: RawResponse is nil, expected it to be set", tt.description)
}
// Verify the raw response matches the body (for non-decode-failure cases)
if tt.name != "Decode failure" {
rawResponseBytes, err := sonic.Marshal(bifrostErr.ExtraFields.RawResponse)
if err != nil {
t.Errorf("Failed to marshal RawResponse: %v", err)
}
// The RawResponse should contain the body content
if len(rawResponseBytes) == 0 {
t.Errorf("%s: RawResponse is empty", tt.description)
}
}
t.Logf("✓ %s: RawResponse is set", tt.name)
})
}
}
// TestEnrichError_PreservesExistingRawResponse verifies that EnrichError preserves
// existing RawResponse from the error's ExtraFields when responseBody parameter is nil
func TestEnrichError_PreservesExistingRawResponse(t *testing.T) {
ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
existingRawResponse := map[string]interface{}{
"error": map[string]interface{}{
"message": "Original error from provider",
"code": "invalid_api_key",
},
}
bifrostErr := &schemas.BifrostError{
IsBifrostError: false,
StatusCode: schemas.Ptr(401),
Error: &schemas.ErrorField{
Message: "Authentication failed",
},
ExtraFields: schemas.BifrostErrorExtraFields{
RawResponse: existingRawResponse,
},
}
requestBody := []byte(`{"model": "gpt-4", "messages": []}`)
// Call EnrichError with nil responseBody - should preserve existing RawResponse
enrichedErr := EnrichError(ctx, bifrostErr, requestBody, nil, true, true)
if enrichedErr == nil {
t.Fatal("EnrichError() returned nil")
}
if enrichedErr.ExtraFields.RawResponse == nil {
t.Error("RawResponse was cleared when it should have been preserved")
} else {
// Verify it's still the original
if rawMap, ok := enrichedErr.ExtraFields.RawResponse.(map[string]interface{}); ok {
if errorMap, ok := rawMap["error"].(map[string]interface{}); ok {
if errorMap["code"] != "invalid_api_key" {
t.Error("RawResponse was modified, expected it to be preserved")
}
}
}
}
t.Log("✓ EnrichError preserves existing RawResponse when responseBody is nil")
}
// TestEnrichError_OverwritesWithProvidedResponse verifies that EnrichError sets
// RawResponse when a responseBody is provided
func TestEnrichError_OverwritesWithProvidedResponse(t *testing.T) {
ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
bifrostErr := &schemas.BifrostError{
IsBifrostError: false,
StatusCode: schemas.Ptr(400),
Error: &schemas.ErrorField{
Message: "Bad request",
},
ExtraFields: schemas.BifrostErrorExtraFields{},
}
requestBody := []byte(`{"model": "gpt-4"}`)
responseBody := []byte(`{"error": {"message": "Model not found"}}`)
enrichedErr := EnrichError(ctx, bifrostErr, requestBody, responseBody, true, true)
if enrichedErr == nil {
t.Fatal("EnrichError() returned nil")
}
if enrichedErr.ExtraFields.RawResponse == nil {
t.Error("RawResponse should be set from responseBody parameter")
}
if enrichedErr.ExtraFields.RawRequest == nil {
t.Error("RawRequest should be set from requestBody parameter")
}
t.Log("✓ EnrichError sets RawRequest and RawResponse from provided bodies")
}
// TestEnrichError_RespectsFlags verifies that EnrichError respects
// sendBackRawRequest and sendBackRawResponse flags
func TestEnrichError_RespectsFlags(t *testing.T) {
tests := []struct {
name string
sendBackRawRequest bool
sendBackRawResponse bool
expectRequest bool
expectResponse bool
}{
{
name: "Both enabled",
sendBackRawRequest: true,
sendBackRawResponse: true,
expectRequest: true,
expectResponse: true,
},
{
name: "Only request enabled",
sendBackRawRequest: true,
sendBackRawResponse: false,
expectRequest: true,
expectResponse: false,
},
{
name: "Only response enabled",
sendBackRawRequest: false,
sendBackRawResponse: true,
expectRequest: false,
expectResponse: true,
},
{
name: "Both disabled",
sendBackRawRequest: false,
sendBackRawResponse: false,
expectRequest: false,
expectResponse: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
bifrostErr := &schemas.BifrostError{
IsBifrostError: false,
StatusCode: schemas.Ptr(500),
Error: &schemas.ErrorField{Message: "Error"},
ExtraFields: schemas.BifrostErrorExtraFields{},
}
requestBody := []byte(`{"model": "test"}`)
responseBody := []byte(`{"error": "test error"}`)
enrichedErr := EnrichError(ctx, bifrostErr, requestBody, responseBody, tt.sendBackRawRequest, tt.sendBackRawResponse)
hasRequest := enrichedErr.ExtraFields.RawRequest != nil
hasResponse := enrichedErr.ExtraFields.RawResponse != nil
if hasRequest != tt.expectRequest {
t.Errorf("RawRequest: got %v, want %v", hasRequest, tt.expectRequest)
}
if hasResponse != tt.expectResponse {
t.Errorf("RawResponse: got %v, want %v", hasResponse, tt.expectResponse)
}
})
}
}
// TestProviderErrorFlow_EndToEnd simulates the full flow of a provider error
// being captured and enriched with raw request/response
func TestProviderErrorFlow_EndToEnd(t *testing.T) {
ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
// Simulate provider error response
errorBody := []byte(`{"error": {"message": "Rate limit exceeded", "type": "rate_limit_error", "code": "rate_limit"}}`)
resp := &fasthttp.Response{}
resp.SetStatusCode(429)
resp.Header.Set("Content-Type", "application/json")
resp.SetBody(errorBody)
// Step 1: Parse the error (like ParseOpenAIError does)
var errorResp map[string]interface{}
bifrostErr := HandleProviderAPIError(resp, &errorResp)
if bifrostErr == nil {
t.Fatal("HandleProviderAPIError returned nil")
}
// Verify raw response is captured by HandleProviderAPIError
if bifrostErr.ExtraFields.RawResponse == nil {
t.Error("HandleProviderAPIError should have set RawResponse")
}
// Step 2: Enrich with request (like providers do)
requestBody := []byte(`{"model": "gpt-4", "messages": [{"role": "user", "content": "Hello"}]}`)
enrichedErr := EnrichError(ctx, bifrostErr, requestBody, nil, true, true)
// Verify both raw request and raw response are present
if enrichedErr.ExtraFields.RawRequest == nil {
t.Error("EnrichError should have set RawRequest")
}
if enrichedErr.ExtraFields.RawResponse == nil {
t.Error("EnrichError should have preserved RawResponse from HandleProviderAPIError")
}
t.Log("✓ End-to-end: Raw request and error response captured successfully")
}
// TestHandleProviderAPIError_AllPathsSetRawResponse verifies that all error return
// paths in HandleProviderAPIError include RawResponse
func TestHandleProviderAPIError_AllPathsSetRawResponse(t *testing.T) {
testCases := []struct {
name string
statusCode int
body []byte
setupResp func(*fasthttp.Response)
errorType string
}{
{
name: "Path 1: Decode error",
statusCode: 500,
body: []byte{0xFF, 0xFE, 0xFD}, // Invalid gzip-compressed data
setupResp: func(r *fasthttp.Response) {
r.Header.Set("Content-Type", "application/json")
// Set Content-Encoding: gzip to trigger BodyGunzip() error on invalid gzip data
r.Header.Set("Content-Encoding", "gzip")
},
errorType: "decode_failure",
},
{
name: "Path 2: Empty response",
statusCode: 502,
body: []byte(" "), // Only whitespace
setupResp: func(r *fasthttp.Response) {
r.Header.Set("Content-Type", "application/json")
},
errorType: "empty_response",
},
{
name: "Path 3: Valid JSON",
statusCode: 400,
body: []byte(`{"error": {"message": "Bad request"}}`),
setupResp: func(r *fasthttp.Response) {
r.Header.Set("Content-Type", "application/json")
},
errorType: "valid_json",
},
{
name: "Path 4: HTML response",
statusCode: 503,
body: []byte(`ErrorService Error
`),
setupResp: func(r *fasthttp.Response) {
r.Header.Set("Content-Type", "text/html")
},
errorType: "html",
},
{
name: "Path 5: Unparseable non-HTML",
statusCode: 500,
body: []byte(`This is plain text that's not JSON`),
setupResp: func(r *fasthttp.Response) {
r.Header.Set("Content-Type", "text/plain")
},
errorType: "unparseable",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
resp := &fasthttp.Response{}
resp.SetStatusCode(tc.statusCode)
resp.SetBody(tc.body)
tc.setupResp(resp)
var errorResp map[string]interface{}
bifrostErr := HandleProviderAPIError(resp, &errorResp)
if bifrostErr == nil {
t.Fatalf("%s: HandleProviderAPIError returned nil", tc.name)
}
if bifrostErr.ExtraFields.RawResponse == nil {
t.Errorf("%s [%s]: RawResponse is nil - MISSING raw error body!", tc.name, tc.errorType)
} else {
t.Logf("✓ %s [%s]: RawResponse is set", tc.name, tc.errorType)
}
})
}
}
// TestGetRequestPath verifies GetRequestPath handles all path resolution scenarios correctly
func TestGetRequestPath(t *testing.T) {
tests := []struct {
name string
contextPath *string
customProviderConfig *schemas.CustomProviderConfig
defaultPath string
requestType schemas.RequestType
expectedPath string
expectedIsURL bool
}{
{
name: "Returns default path when nothing is set",
defaultPath: "/v1/chat/completions",
requestType: schemas.ChatCompletionRequest,
expectedPath: "/v1/chat/completions",
expectedIsURL: false,
},
{
name: "Returns path from context when present",
contextPath: schemas.Ptr("/custom/path"),
defaultPath: "/v1/chat/completions",
requestType: schemas.ChatCompletionRequest,
expectedPath: "/custom/path",
expectedIsURL: false,
},
{
name: "Returns full URL from config override",
customProviderConfig: &schemas.CustomProviderConfig{
RequestPathOverrides: map[schemas.RequestType]string{
schemas.ChatCompletionRequest: "https://custom.api.com/v1/completions",
},
},
defaultPath: "/v1/chat/completions",
requestType: schemas.ChatCompletionRequest,
expectedPath: "https://custom.api.com/v1/completions",
expectedIsURL: true,
},
{
name: "Returns path override with leading slash",
customProviderConfig: &schemas.CustomProviderConfig{
RequestPathOverrides: map[schemas.RequestType]string{
schemas.ChatCompletionRequest: "/custom/endpoint",
},
},
defaultPath: "/v1/chat/completions",
requestType: schemas.ChatCompletionRequest,
expectedPath: "/custom/endpoint",
expectedIsURL: false,
},
{
name: "Adds leading slash to path override without one",
customProviderConfig: &schemas.CustomProviderConfig{
RequestPathOverrides: map[schemas.RequestType]string{
schemas.ChatCompletionRequest: "custom/endpoint",
},
},
defaultPath: "/v1/chat/completions",
requestType: schemas.ChatCompletionRequest,
expectedPath: "/custom/endpoint",
expectedIsURL: false,
},
{
name: "Returns default path for empty override",
customProviderConfig: &schemas.CustomProviderConfig{
RequestPathOverrides: map[schemas.RequestType]string{
schemas.ChatCompletionRequest: " ",
},
},
defaultPath: "/v1/chat/completions",
requestType: schemas.ChatCompletionRequest,
expectedPath: "/v1/chat/completions",
expectedIsURL: false,
},
{
name: "Returns default when override exists for different request type",
customProviderConfig: &schemas.CustomProviderConfig{
RequestPathOverrides: map[schemas.RequestType]string{
schemas.EmbeddingRequest: "/custom/embeddings",
},
},
defaultPath: "/v1/chat/completions",
requestType: schemas.ChatCompletionRequest,
expectedPath: "/v1/chat/completions",
expectedIsURL: false,
},
{
name: "Handles URL with http scheme",
customProviderConfig: &schemas.CustomProviderConfig{
RequestPathOverrides: map[schemas.RequestType]string{
schemas.ChatCompletionRequest: "http://internal.api:8080/completions",
},
},
defaultPath: "/v1/chat/completions",
requestType: schemas.ChatCompletionRequest,
expectedPath: "http://internal.api:8080/completions",
expectedIsURL: true,
},
{
name: "Context path takes precedence over config override",
contextPath: schemas.Ptr("/context/path"),
customProviderConfig: &schemas.CustomProviderConfig{
RequestPathOverrides: map[schemas.RequestType]string{
schemas.ChatCompletionRequest: "/config/path",
},
},
defaultPath: "/v1/chat/completions",
requestType: schemas.ChatCompletionRequest,
expectedPath: "/context/path",
expectedIsURL: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := context.Background()
if tt.contextPath != nil {
ctx = context.WithValue(ctx, schemas.BifrostContextKeyURLPath, *tt.contextPath)
}
path, isURL := GetRequestPath(ctx, tt.defaultPath, tt.customProviderConfig, tt.requestType)
if path != tt.expectedPath {
t.Errorf("GetRequestPath() path = %q, want %q", path, tt.expectedPath)
}
if isURL != tt.expectedIsURL {
t.Errorf("GetRequestPath() isURL = %v, want %v", isURL, tt.expectedIsURL)
}
})
}
}
// TestMarshalSorted_Deterministic verifies that MarshalSorted produces identical
// output across multiple calls with the same map, despite Go's randomized map iteration.
func TestMarshalSorted_Deterministic(t *testing.T) {
// Build a map with enough keys to make random ordering statistically certain
m := map[string]interface{}{
"zulu": 1,
"alpha": 2,
"mike": 3,
"bravo": 4,
"yankee": 5,
"charlie": 6,
"nested": map[string]interface{}{
"zebra": "z",
"apple": "a",
"mango": "m",
"banana": "b",
"cherry": "c",
"date": "d",
"fig": "f",
"grape": "g",
"kiwi": "k",
"lemon": "l",
"orange": "o",
"papaya": "p",
"quince": "q",
"raisin": "r",
"satsuma": "s",
},
}
first, err := MarshalSorted(m)
if err != nil {
t.Fatalf("MarshalSorted() error: %v", err)
}
// Run 50 iterations to be confident about determinism
for i := 0; i < 50; i++ {
got, err := MarshalSorted(m)
if err != nil {
t.Fatalf("MarshalSorted() iteration %d error: %v", i, err)
}
if string(got) != string(first) {
t.Fatalf("MarshalSorted() produced different output on iteration %d:\nfirst: %s\ngot: %s", i, first, got)
}
}
// Also verify MarshalSortedIndent
firstIndent, err := MarshalSortedIndent(m, "", " ")
if err != nil {
t.Fatalf("MarshalSortedIndent() error: %v", err)
}
for i := 0; i < 50; i++ {
got, err := MarshalSortedIndent(m, "", " ")
if err != nil {
t.Fatalf("MarshalSortedIndent() iteration %d error: %v", i, err)
}
if string(got) != string(firstIndent) {
t.Fatalf("MarshalSortedIndent() produced different output on iteration %d:\nfirst: %s\ngot: %s", i, firstIndent, got)
}
}
}
// TestCheckAndDecodeBody_PooledGzip verifies that CheckAndDecodeBody correctly
// decompresses gzip-encoded responses using pooled gzip readers.
func TestCheckAndDecodeBody_PooledGzip(t *testing.T) {
tests := []struct {
name string
body []byte
contentEncoding string
wantBody string
wantErr bool
}{
{
name: "gzip encoded body",
body: gzipCompress([]byte(`{"message":"hello world"}`)),
contentEncoding: "gzip",
wantBody: `{"message":"hello world"}`,
wantErr: false,
},
{
name: "gzip with uppercase header",
body: gzipCompress([]byte(`test data`)),
contentEncoding: "GZIP",
wantBody: `test data`,
wantErr: false,
},
{
name: "gzip with whitespace in header",
body: gzipCompress([]byte(`trimmed`)),
contentEncoding: " gzip ",
wantBody: `trimmed`,
wantErr: false,
},
{
name: "no encoding - plain body",
body: []byte(`plain text`),
contentEncoding: "",
wantBody: `plain text`,
wantErr: false,
},
{
name: "empty gzip body",
body: []byte{},
contentEncoding: "gzip",
wantBody: "",
wantErr: false,
},
{
name: "invalid gzip data",
body: []byte{0xFF, 0xFE, 0xFD},
contentEncoding: "gzip",
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
resp := fasthttp.AcquireResponse()
defer fasthttp.ReleaseResponse(resp)
resp.SetBody(tt.body)
if tt.contentEncoding != "" {
resp.Header.Set("Content-Encoding", tt.contentEncoding)
}
got, err := CheckAndDecodeBody(resp)
if tt.wantErr {
if err == nil {
t.Errorf("CheckAndDecodeBody() expected error, got nil")
}
return
}
if err != nil {
t.Errorf("CheckAndDecodeBody() unexpected error: %v", err)
return
}
if string(got) != tt.wantBody {
t.Errorf("CheckAndDecodeBody() = %q, want %q", string(got), tt.wantBody)
}
})
}
}
// TestCheckAndDecodeBody_Concurrent verifies no data races with concurrent access.
func TestCheckAndDecodeBody_Concurrent(t *testing.T) {
testData := []byte(`{"concurrent":"test"}`)
compressed := gzipCompress(testData)
done := make(chan bool)
for i := 0; i < 100; i++ {
go func() {
resp := fasthttp.AcquireResponse()
defer fasthttp.ReleaseResponse(resp)
resp.SetBody(compressed)
resp.Header.Set("Content-Encoding", "gzip")
got, err := CheckAndDecodeBody(resp)
if err != nil {
t.Errorf("CheckAndDecodeBody() error: %v", err)
}
if string(got) != string(testData) {
t.Errorf("CheckAndDecodeBody() = %q, want %q", string(got), string(testData))
}
done <- true
}()
}
for i := 0; i < 100; i++ {
<-done
}
}
func TestDrainNonSSEStreamResponse_SSEDoesNotDrain(t *testing.T) {
resp := fasthttp.AcquireResponse()
defer fasthttp.ReleaseResponse(resp)
body := []byte("data: hello\n\n")
resp.Header.SetContentType("text/event-stream")
resp.SetBodyStream(bytes.NewReader(body), len(body))
drained := DrainNonSSEStreamResponse(resp)
if drained {
t.Fatal("expected SSE response to remain readable")
}
remaining, err := io.ReadAll(resp.BodyStream())
if err != nil {
t.Fatalf("failed to read SSE body after guard: %v", err)
}
if string(remaining) != string(body) {
t.Fatalf("expected SSE body to remain intact, got %q", string(remaining))
}
}
func TestDrainNonSSEStreamResponse_NonSSEDrains(t *testing.T) {
resp := fasthttp.AcquireResponse()
defer fasthttp.ReleaseResponse(resp)
body := []byte(`{"error":"not stream"}`)
resp.Header.SetContentType("application/json")
resp.SetBodyStream(bytes.NewReader(body), len(body))
drained := DrainNonSSEStreamResponse(resp)
if !drained {
t.Fatal("expected non-SSE response to be drained")
}
remaining, err := io.ReadAll(resp.BodyStream())
if err != nil {
t.Fatalf("failed to read body after drain: %v", err)
}
if len(remaining) != 0 {
t.Fatalf("expected drained body to be empty, got %q", string(remaining))
}
}
func TestDrainNonSSEStreamResponse_GzipSSEStillReadable(t *testing.T) {
resp := fasthttp.AcquireResponse()
defer fasthttp.ReleaseResponse(resp)
body := []byte("data: hello\n\ndata: [DONE]\n\n")
compressed := gzipCompress(body)
resp.Header.SetContentType("text/event-stream")
resp.Header.Set("Content-Encoding", "gzip")
resp.SetBodyStream(bytes.NewReader(compressed), len(compressed))
drained := DrainNonSSEStreamResponse(resp)
if drained {
t.Fatal("expected gzip SSE response to remain readable")
}
reader, releaseGzip := DecompressStreamBody(resp)
defer releaseGzip()
remaining, err := io.ReadAll(reader)
if err != nil {
t.Fatalf("failed to read decompressed SSE body: %v", err)
}
if string(remaining) != string(body) {
t.Fatalf("expected decompressed SSE body %q, got %q", string(body), string(remaining))
}
}
// gzipCompress compresses data using gzip for testing.
func gzipCompress(data []byte) []byte {
var buf bytes.Buffer
gz := gzip.NewWriter(&buf)
if _, err := gz.Write(data); err != nil {
panic(fmt.Errorf("gzip write: %w", err))
}
if err := gz.Close(); err != nil {
panic(fmt.Errorf("gzip close: %w", err))
}
return buf.Bytes()
}
func TestMergeExtraParamsIntoJSON_PreservesKeyOrder(t *testing.T) {
// JSON with a specific key order that must be preserved
jsonBody := []byte(`{
"model": "gpt-4",
"messages": [],
"tool_choice": {"type": "function", "function": {"name": "test"}},
"tools": []
}`)
extraParams := map[string]interface{}{
"custom_field": "value",
}
result, err := MergeExtraParamsIntoJSON(jsonBody, extraParams)
if err != nil {
t.Fatalf("MergeExtraParamsIntoJSON() error: %v", err)
}
// Verify original key order is preserved and custom_field is appended
resultStr := string(result)
modelIdx := bytes.Index(result, []byte(`"model"`))
messagesIdx := bytes.Index(result, []byte(`"messages"`))
toolChoiceIdx := bytes.Index(result, []byte(`"tool_choice"`))
toolsIdx := bytes.Index(result, []byte(`"tools"`))
customIdx := bytes.Index(result, []byte(`"custom_field"`))
if modelIdx >= messagesIdx || messagesIdx >= toolChoiceIdx || toolChoiceIdx >= toolsIdx || toolsIdx >= customIdx {
t.Fatalf("Key order not preserved. Result:\n%s", resultStr)
}
}
func TestMergeExtraParamsIntoJSON_OverwriteExistingKey(t *testing.T) {
jsonBody := []byte(`{"z_first": "original", "a_second": "original"}`)
extraParams := map[string]interface{}{
"z_first": "overwritten",
}
result, err := MergeExtraParamsIntoJSON(jsonBody, extraParams)
if err != nil {
t.Fatalf("MergeExtraParamsIntoJSON() error: %v", err)
}
// z_first should still come before a_second (preserving original position)
zIdx := bytes.Index(result, []byte(`"z_first"`))
aIdx := bytes.Index(result, []byte(`"a_second"`))
if zIdx >= aIdx {
t.Fatalf("Overwritten key should preserve its position. Result: %s", string(result))
}
// z_first should have the new value
if !bytes.Contains(result, []byte(`"overwritten"`)) {
t.Fatalf("Value should be overwritten. Result: %s", string(result))
}
}
func TestMergeExtraParamsIntoJSON_DeepMerge(t *testing.T) {
jsonBody := []byte(`{"outer": {"a": 1, "b": 2}}`)
extraParams := map[string]interface{}{
"outer": map[string]interface{}{
"c": 3,
},
}
result, err := MergeExtraParamsIntoJSON(jsonBody, extraParams)
if err != nil {
t.Fatalf("MergeExtraParamsIntoJSON() error: %v", err)
}
// Verify the merge happened
var parsed map[string]interface{}
if err := sonic.Unmarshal(result, &parsed); err != nil {
t.Fatalf("Failed to parse result: %v", err)
}
outer, ok := parsed["outer"].(map[string]interface{})
if !ok {
t.Fatal("outer should be a map")
}
if len(outer) != 3 {
t.Fatalf("outer should have 3 keys after merge, got %d: %v", len(outer), outer)
}
}
func TestMergeExtraParamsIntoJSON_EmptyExtraParams(t *testing.T) {
jsonBody := []byte(`{"a": 1, "b": 2}`)
result, err := MergeExtraParamsIntoJSON(jsonBody, map[string]interface{}{})
if err != nil {
t.Fatalf("MergeExtraParamsIntoJSON() error: %v", err)
}
// Should be valid JSON with same content
var parsed map[string]interface{}
if err := sonic.Unmarshal(result, &parsed); err != nil {
t.Fatalf("Failed to parse result: %v", err)
}
if len(parsed) != 2 {
t.Fatalf("Expected 2 keys, got %d", len(parsed))
}
}
// TestParseAndSetRawRequest_CompactsJSON verifies that indented JSON input
// (with literal newlines from MarshalIndent) is compacted to a single line.
// This is critical for SSE streaming where newlines break data-line framing.
func TestParseAndSetRawRequest_CompactsJSON(t *testing.T) {
indentedJSON := []byte(`{
"model": "gpt-4",
"messages": [
{
"role": "user",
"content": "Hello"
}
],
"temperature": 0.7
}`)
var extraFields schemas.BifrostResponseExtraFields
ParseAndSetRawRequest(&extraFields, indentedJSON)
if extraFields.RawRequest == nil {
t.Fatal("RawRequest should be set")
}
raw, ok := extraFields.RawRequest.(json.RawMessage)
if !ok {
t.Fatalf("RawRequest should be json.RawMessage, got %T", extraFields.RawRequest)
}
// The compacted output must not contain any literal newlines
if strings.Contains(string(raw), "\n") {
t.Errorf("Compacted RawRequest should not contain newlines, got:\n%s", string(raw))
}
// Verify it's still valid JSON with the same content
var parsed map[string]interface{}
if err := sonic.Unmarshal(raw, &parsed); err != nil {
t.Fatalf("Compacted RawRequest is not valid JSON: %v", err)
}
if parsed["model"] != "gpt-4" {
t.Errorf("Expected model=gpt-4, got %v", parsed["model"])
}
}
// TestParseAndSetRawRequest_PreservesKeyOrdering verifies that JSON key order
// is maintained after compaction. This is essential for LLM prompt caching
// where key ordering affects cache hit rates.
func TestParseAndSetRawRequest_PreservesKeyOrdering(t *testing.T) {
// Keys are intentionally not alphabetically sorted
jsonBody := []byte(`{"z_last":"z","a_first":"a","m_middle":"m"}`)
var extraFields schemas.BifrostResponseExtraFields
ParseAndSetRawRequest(&extraFields, jsonBody)
raw := extraFields.RawRequest.(json.RawMessage)
result := string(raw)
zIdx := strings.Index(result, `"z_last"`)
aIdx := strings.Index(result, `"a_first"`)
mIdx := strings.Index(result, `"m_middle"`)
if zIdx >= aIdx || aIdx >= mIdx {
t.Errorf("Key ordering not preserved. Got: %s", result)
}
}
// TestParseAndSetRawRequest_EmptyBody verifies that empty input is a no-op.
func TestParseAndSetRawRequest_EmptyBody(t *testing.T) {
var extraFields schemas.BifrostResponseExtraFields
ParseAndSetRawRequest(&extraFields, []byte{})
if extraFields.RawRequest != nil {
t.Error("RawRequest should be nil for empty body")
}
ParseAndSetRawRequest(&extraFields, nil)
if extraFields.RawRequest != nil {
t.Error("RawRequest should be nil for nil body")
}
}
// TestParseAndSetRawRequest_SSEStreamingChunks simulates the actual SSE streaming
// flow end-to-end: a response chunk with raw_request containing indented JSON is
// marshaled, framed as SSE "data: \n\n", and then each SSE data line is
// parsed back. This is the exact scenario that caused issue #1905 — pretty-printed
// JSON in raw_request introduced literal newlines that broke SSE data-line framing.
func TestParseAndSetRawRequest_SSEStreamingChunks(t *testing.T) {
// Simulate indented request body (as produced by MarshalSortedIndent)
indentedRequest := []byte(`{
"model": "gpt-4",
"messages": [
{
"role": "user",
"content": "Hello"
}
],
"stream": true,
"temperature": 0.7
}`)
// Build a response chunk with raw_request set via ParseAndSetRawRequest.
// Uses BifrostChatResponse which is the actual type marshaled in the streaming path.
chunk := schemas.BifrostChatResponse{
ID: "chatcmpl-test",
Model: "gpt-4",
Object: "chat.completion.chunk",
Choices: []schemas.BifrostResponseChoice{
{
Index: 0,
ChatStreamResponseChoice: &schemas.ChatStreamResponseChoice{
Delta: &schemas.ChatStreamResponseChoiceDelta{
Content: schemas.Ptr("Hello"),
},
},
},
},
}
ParseAndSetRawRequest(&chunk.ExtraFields, indentedRequest)
// Marshal the chunk (exactly like the transport layer does: sonic.Marshal)
chunkJSON, err := sonic.Marshal(chunk)
if err != nil {
t.Fatalf("Failed to marshal chunk: %v", err)
}
// Frame as SSE: "data: \n\n" (exactly as in inference.go:1591)
sseFrame := fmt.Sprintf("data: %s\n\n", chunkJSON)
// Parse the SSE frame line-by-line as a real SSE client would.
// Split on \n and check that there is exactly one "data:" line.
lines := strings.Split(strings.TrimRight(sseFrame, "\n"), "\n")
var dataLines []string
for _, line := range lines {
if strings.HasPrefix(line, "data: ") {
dataLines = append(dataLines, line)
} else if line != "" {
// Any non-empty, non-data line means SSE framing is broken —
// this is exactly what happened in #1905
t.Errorf("Unexpected non-data line in SSE frame (broken framing): %q", line)
}
}
if len(dataLines) != 1 {
t.Fatalf("Expected exactly 1 SSE data line, got %d:\n%s", len(dataLines), sseFrame)
}
// Parse the JSON payload from the single data line
jsonPayload := strings.TrimPrefix(dataLines[0], "data: ")
var parsed schemas.BifrostChatResponse
if err := sonic.Unmarshal([]byte(jsonPayload), &parsed); err != nil {
t.Fatalf("Failed to parse SSE data line as JSON (this is the #1905 bug): %v\nPayload: %s", err, jsonPayload)
}
// Verify the parsed response has the correct content
if parsed.ID != "chatcmpl-test" {
t.Errorf("Expected ID=chatcmpl-test, got %s", parsed.ID)
}
if parsed.ExtraFields.RawRequest == nil {
t.Error("RawRequest should be present in parsed chunk")
}
// Verify raw_request round-trips correctly — the client should be able
// to parse it back into the original request structure
rawBytes, err := sonic.Marshal(parsed.ExtraFields.RawRequest)
if err != nil {
t.Fatalf("Failed to marshal raw_request: %v", err)
}
var rawParsed map[string]interface{}
if err := sonic.Unmarshal(rawBytes, &rawParsed); err != nil {
t.Fatalf("raw_request is not valid JSON after round-trip: %v", err)
}
if rawParsed["model"] != "gpt-4" {
t.Errorf("Expected raw_request.model=gpt-4, got %v", rawParsed["model"])
}
}
// TestBuildClientStreamChunk_ImageGenerationStripping verifies that
// BuildClientStreamChunk correctly handles BifrostImageGenerationStreamResponse:
// strips raw fields when in logging-only mode and never mutates the original.
func TestBuildClientStreamChunk_ImageGenerationStripping(t *testing.T) {
rawReq := json.RawMessage(`{"model":"dall-e-3"}`)
rawResp := json.RawMessage(`{"data":[{"url":"https://example.com/img.png"}]}`)
imgResp := &schemas.BifrostImageGenerationStreamResponse{
ExtraFields: schemas.BifrostResponseExtraFields{
RawRequest: rawReq,
RawResponse: rawResp,
},
}
response := &schemas.BifrostResponse{ImageGenerationStreamResponse: imgResp}
t.Run("logging-only: raw fields stripped from image gen chunk, original preserved", func(t *testing.T) {
ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
ctx.SetValue(schemas.BifrostContextKeyDropRawRequestFromClient, true)
ctx.SetValue(schemas.BifrostContextKeyDropRawResponseFromClient, true)
chunk := BuildClientStreamChunk(ctx, response, nil)
if chunk.BifrostImageGenerationStreamResponse == nil {
t.Fatal("expected BifrostImageGenerationStreamResponse in chunk")
}
if chunk.BifrostImageGenerationStreamResponse.ExtraFields.RawRequest != nil {
t.Error("expected RawRequest stripped from chunk, but it was present")
}
if chunk.BifrostImageGenerationStreamResponse.ExtraFields.RawResponse != nil {
t.Error("expected RawResponse stripped from chunk, but it was present")
}
// Original must not be mutated.
if imgResp.ExtraFields.RawRequest == nil {
t.Error("original BifrostImageGenerationStreamResponse.ExtraFields.RawRequest was mutated")
}
if imgResp.ExtraFields.RawResponse == nil {
t.Error("original BifrostImageGenerationStreamResponse.ExtraFields.RawResponse was mutated")
}
if chunk.BifrostImageGenerationStreamResponse == imgResp {
t.Error("chunk contains same pointer as original; it must be a copy")
}
})
t.Run("no logging flag: raw fields preserved in image gen chunk", func(t *testing.T) {
ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
chunk := BuildClientStreamChunk(ctx, response, nil)
if chunk.BifrostImageGenerationStreamResponse == nil {
t.Fatal("expected BifrostImageGenerationStreamResponse in chunk")
}
if chunk.BifrostImageGenerationStreamResponse.ExtraFields.RawRequest == nil {
t.Error("expected RawRequest present in chunk, but it was nil")
}
if chunk.BifrostImageGenerationStreamResponse.ExtraFields.RawResponse == nil {
t.Error("expected RawResponse present in chunk, but it was nil")
}
})
}
// TestProcessAndSendResponse_StoreRawLoggingOnly_StripsRawDataFromResponseChunk verifies
// that when drop-raw context flags are set, ProcessAndSendResponse strips RawRequest and
// RawResponse from the outgoing stream chunk, while leaving other ExtraFields intact.
// It also verifies that the original BifrostResponse is not mutated
// (shared object safety for PostLLMHook goroutines).
func TestProcessAndSendResponse_StoreRawLoggingOnly_StripsRawDataFromResponseChunk(t *testing.T) {
rawReq := json.RawMessage(`{"model":"gpt-4","messages":[]}`)
rawResp := json.RawMessage(`{"id":"chatcmpl-001"}`)
tests := []struct {
name string
loggingOnly bool
expectStripped bool
}{
{
name: "logging-only flag set: raw data stripped from chunk",
loggingOnly: true,
expectStripped: true,
},
{
name: "logging-only flag not set: raw data preserved in chunk",
loggingOnly: false,
expectStripped: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
if tt.loggingOnly {
ctx.SetValue(schemas.BifrostContextKeyDropRawRequestFromClient, true)
ctx.SetValue(schemas.BifrostContextKeyDropRawResponseFromClient, true)
}
response := &schemas.BifrostResponse{
ChatResponse: &schemas.BifrostChatResponse{
ID: "chatcmpl-001",
Model: "gpt-4",
ExtraFields: schemas.BifrostResponseExtraFields{
RawRequest: rawReq,
RawResponse: rawResp,
},
},
}
passThrough := func(ctx *schemas.BifrostContext, resp *schemas.BifrostResponse, err *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError) {
return resp, err
}
responseChan := make(chan *schemas.BifrostStreamChunk, 1)
ProcessAndSendResponse(ctx, passThrough, response, responseChan, nil)
chunk := <-responseChan
if chunk.BifrostChatResponse == nil {
t.Fatal("expected non-nil BifrostChatResponse in stream chunk")
}
hasRawReq := chunk.BifrostChatResponse.ExtraFields.RawRequest != nil
hasRawResp := chunk.BifrostChatResponse.ExtraFields.RawResponse != nil
if tt.expectStripped {
if hasRawReq {
t.Error("expected RawRequest to be nil (stripped) in chunk, but it was present")
}
if hasRawResp {
t.Error("expected RawResponse to be nil (stripped) in chunk, but it was present")
}
// Critical: the original shared object must NOT have been mutated.
if response.ChatResponse.ExtraFields.RawRequest == nil {
t.Error("original BifrostResponse.ChatResponse.ExtraFields.RawRequest was mutated (nil); shared object must be preserved")
}
if response.ChatResponse.ExtraFields.RawResponse == nil {
t.Error("original BifrostResponse.ChatResponse.ExtraFields.RawResponse was mutated (nil); shared object must be preserved")
}
// The chunk must be a copy, not the same pointer as the original.
if chunk.BifrostChatResponse == response.ChatResponse {
t.Error("chunk.BifrostChatResponse is the same pointer as the original; it must be a copy to avoid data races")
}
} else {
if !hasRawReq {
t.Error("expected RawRequest to be present in chunk, but it was nil")
}
if !hasRawResp {
t.Error("expected RawResponse to be present in chunk, but it was nil")
}
}
})
}
}
// TestProcessAndSendResponse_StoreRawLoggingOnly_StripsRawDataFromErrorChunk verifies
// that when drop-raw context flags are set, raw data is stripped from BifrostError
// payloads embedded in stream chunks, without mutating the shared BifrostError object
// (shared object safety for PostLLMHook goroutines).
func TestProcessAndSendResponse_StoreRawLoggingOnly_StripsRawDataFromErrorChunk(t *testing.T) {
rawReq := json.RawMessage(`{"model":"gpt-4"}`)
rawResp := json.RawMessage(`{"error":"rate limit exceeded"}`)
tests := []struct {
name string
loggingOnly bool
expectStripped bool
}{
{
name: "logging-only flag set: raw data stripped from error chunk",
loggingOnly: true,
expectStripped: true,
},
{
name: "logging-only flag not set: raw data preserved in error chunk",
loggingOnly: false,
expectStripped: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
if tt.loggingOnly {
ctx.SetValue(schemas.BifrostContextKeyDropRawRequestFromClient, true)
ctx.SetValue(schemas.BifrostContextKeyDropRawResponseFromClient, true)
}
// Use a postHookRunner that converts the response to a BifrostError with raw data
bifrostErr := &schemas.BifrostError{
IsBifrostError: false,
StatusCode: schemas.Ptr(429),
Error: &schemas.ErrorField{Message: "rate limit exceeded"},
ExtraFields: schemas.BifrostErrorExtraFields{
RawRequest: rawReq,
RawResponse: rawResp,
},
}
errorRunner := func(ctx *schemas.BifrostContext, resp *schemas.BifrostResponse, err *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError) {
return nil, bifrostErr
}
responseChan := make(chan *schemas.BifrostStreamChunk, 1)
ProcessAndSendResponse(ctx, errorRunner, &schemas.BifrostResponse{
ChatResponse: &schemas.BifrostChatResponse{ID: "chatcmpl-001"},
}, responseChan, nil)
chunk := <-responseChan
if chunk.BifrostError == nil {
t.Fatal("expected non-nil BifrostError in stream chunk")
}
hasRawReq := chunk.BifrostError.ExtraFields.RawRequest != nil
hasRawResp := chunk.BifrostError.ExtraFields.RawResponse != nil
if tt.expectStripped {
if hasRawReq {
t.Error("expected RawRequest to be nil (stripped) in error chunk, but it was present")
}
if hasRawResp {
t.Error("expected RawResponse to be nil (stripped) in error chunk, but it was present")
}
// Critical: the original shared BifrostError must NOT have been mutated.
if bifrostErr.ExtraFields.RawRequest == nil {
t.Error("original BifrostError.ExtraFields.RawRequest was mutated (nil); shared object must be preserved")
}
if bifrostErr.ExtraFields.RawResponse == nil {
t.Error("original BifrostError.ExtraFields.RawResponse was mutated (nil); shared object must be preserved")
}
// The chunk must hold a copy, not the same pointer as the original.
if chunk.BifrostError == bifrostErr {
t.Error("chunk.BifrostError is the same pointer as the original; it must be a copy to avoid data races")
}
} else {
if !hasRawReq {
t.Error("expected RawRequest to be present in error chunk, but it was nil")
}
if !hasRawResp {
t.Error("expected RawResponse to be present in error chunk, but it was nil")
}
}
})
}
}
// TestShouldSendBackRawRequest verifies that ShouldSendBackRawRequest correctly resolves
// whether providers should capture the raw request body. It covers:
// - Default (no context flags): returns the provider default
// - BifrostContextKeyCaptureRawRequest=true in context: always returns true
// - Logging-only mode: requestWorker sets BifrostContextKeyCaptureRawRequest=true,
// so the function sees a single flag (no second check needed).
func TestShouldSendBackRawRequest(t *testing.T) {
tests := []struct {
name string
contextSendBack bool
providerDefault bool
want bool
}{
{
name: "provider default false, no context flag",
want: false,
},
{
name: "provider default true, no context flag",
providerDefault: true,
want: true,
},
{
name: "context SendBack=true overrides provider default false",
contextSendBack: true,
want: true,
},
{
name: "context SendBack=true with provider default true",
contextSendBack: true,
providerDefault: true,
want: true,
},
{
// requestWorker sets BifrostContextKeyCaptureRawRequest=true in logging-only
// mode so a single flag covers both full send-back and logging-only cases.
name: "logging-only: context SendBack=true set by requestWorker",
contextSendBack: true,
providerDefault: false,
want: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
if tt.contextSendBack {
ctx.SetValue(schemas.BifrostContextKeyCaptureRawRequest, true)
}
got := ShouldSendBackRawRequest(ctx, tt.providerDefault)
if got != tt.want {
t.Errorf("ShouldSendBackRawRequest() = %v, want %v", got, tt.want)
}
})
}
}
// TestShouldSendBackRawResponse mirrors TestShouldSendBackRawRequest for the response side.
func TestShouldSendBackRawResponse(t *testing.T) {
tests := []struct {
name string
contextSendBack bool
providerDefault bool
want bool
}{
{
name: "provider default false, no context flag",
want: false,
},
{
name: "provider default true, no context flag",
providerDefault: true,
want: true,
},
{
name: "context SendBack=true overrides provider default false",
contextSendBack: true,
want: true,
},
{
name: "context SendBack=true with provider default true",
contextSendBack: true,
providerDefault: true,
want: true,
},
{
// requestWorker sets BifrostContextKeyCaptureRawResponse=true in logging-only
// mode so a single flag covers both full send-back and logging-only cases.
name: "logging-only: context SendBack=true set by requestWorker",
contextSendBack: true,
providerDefault: false,
want: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
if tt.contextSendBack {
ctx.SetValue(schemas.BifrostContextKeyCaptureRawResponse, true)
}
got := ShouldSendBackRawResponse(ctx, tt.providerDefault)
if got != tt.want {
t.Errorf("ShouldSendBackRawResponse() = %v, want %v", got, tt.want)
}
})
}
}