1442 lines
46 KiB
Go
1442 lines
46 KiB
Go
package utils
|
|
|
|
import (
|
|
"bytes"
|
|
"compress/gzip"
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"strings"
|
|
"testing"
|
|
|
|
"github.com/bytedance/sonic"
|
|
"github.com/maximhq/bifrost/core/schemas"
|
|
"github.com/valyala/fasthttp"
|
|
)
|
|
|
|
func TestRewriteJSONModelValue(t *testing.T) {
|
|
in := []byte(`{"model":"openai/gpt-5","messages":[{"role":"user","content":"x"}]}`)
|
|
out, changed := rewriteJSONModelValue(in, "openai/gpt-5", "gpt-5")
|
|
if !changed {
|
|
t.Fatal("expected model rewrite to occur")
|
|
}
|
|
if strings.Contains(string(out), `"model":"openai/gpt-5"`) {
|
|
t.Fatalf("expected prefixed model to be removed, got: %s", string(out))
|
|
}
|
|
if !strings.Contains(string(out), `"model":"gpt-5"`) {
|
|
t.Fatalf("expected rewritten model, got: %s", string(out))
|
|
}
|
|
}
|
|
|
|
func TestApplyLargePayloadRequestBodyWithModelNormalization(t *testing.T) {
|
|
ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
|
|
payload := `{"model":"openai/gpt-5","messages":[{"role":"user","content":"hello"}]}`
|
|
ctx.SetValue(schemas.BifrostContextKeyLargePayloadMode, true)
|
|
ctx.SetValue(
|
|
schemas.BifrostContextKeyLargePayloadReader,
|
|
strings.NewReader(payload),
|
|
)
|
|
ctx.SetValue(schemas.BifrostContextKeyLargePayloadContentLength, len(payload))
|
|
ctx.SetValue(schemas.BifrostContextKeyLargePayloadContentType, "application/json")
|
|
ctx.SetValue(schemas.BifrostContextKeyLargePayloadMetadata, &schemas.LargePayloadMetadata{
|
|
Model: "openai/gpt-5",
|
|
})
|
|
|
|
req := &fasthttp.Request{}
|
|
if !ApplyLargePayloadRequestBodyWithModelNormalization(ctx, req, schemas.OpenAI) {
|
|
t.Fatal("expected large payload body to be applied")
|
|
}
|
|
|
|
body := string(req.Body())
|
|
if strings.Contains(body, "openai/gpt-5") {
|
|
t.Fatalf("expected rewritten model in body, got: %s", body)
|
|
}
|
|
if !strings.Contains(body, `"model":"gpt-5"`) {
|
|
t.Fatalf("expected normalized model in body, got: %s", body)
|
|
}
|
|
}
|
|
|
|
// TestHandleProviderAPIError_RawResponseIncluded verifies that HandleProviderAPIError
|
|
// always includes the raw response body in BifrostError.ExtraFields.RawResponse
|
|
func TestHandleProviderAPIError_RawResponseIncluded(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
statusCode int
|
|
body []byte
|
|
contentType string
|
|
description string
|
|
}{
|
|
{
|
|
name: "Decode failure",
|
|
statusCode: 500,
|
|
body: []byte{0xFF, 0xFE}, // Invalid gzip-compressed data
|
|
contentType: "application/json",
|
|
description: "Should include raw response when decode fails",
|
|
},
|
|
{
|
|
name: "Empty response",
|
|
statusCode: 502,
|
|
body: []byte(""),
|
|
contentType: "application/json",
|
|
description: "Should include empty raw response",
|
|
},
|
|
{
|
|
name: "Valid JSON error",
|
|
statusCode: 400,
|
|
body: []byte(`{"error": {"message": "Invalid API key"}}`),
|
|
contentType: "application/json",
|
|
description: "Should include raw response for valid JSON",
|
|
},
|
|
{
|
|
name: "HTML error response",
|
|
statusCode: 503,
|
|
body: []byte(`<html><body><h1>Service Unavailable</h1></body></html>`),
|
|
contentType: "text/html",
|
|
description: "Should include raw response for HTML errors",
|
|
},
|
|
{
|
|
name: "Unparseable non-HTML response",
|
|
statusCode: 400,
|
|
body: []byte(`This is not JSON or HTML`),
|
|
contentType: "text/plain",
|
|
description: "Should include raw response for unparseable content",
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
resp := &fasthttp.Response{}
|
|
resp.SetStatusCode(tt.statusCode)
|
|
resp.Header.Set("Content-Type", tt.contentType)
|
|
// Set Content-Encoding: gzip for decode failure test to trigger BodyGunzip() error
|
|
if tt.name == "Decode failure" {
|
|
resp.Header.Set("Content-Encoding", "gzip")
|
|
}
|
|
resp.SetBody(tt.body)
|
|
|
|
var errorResp map[string]interface{}
|
|
bifrostErr := HandleProviderAPIError(resp, &errorResp)
|
|
|
|
if bifrostErr == nil {
|
|
t.Fatal("HandleProviderAPIError() returned nil")
|
|
}
|
|
|
|
if bifrostErr.ExtraFields.RawResponse == nil {
|
|
t.Errorf("%s: RawResponse is nil, expected it to be set", tt.description)
|
|
}
|
|
|
|
// Verify the raw response matches the body (for non-decode-failure cases)
|
|
if tt.name != "Decode failure" {
|
|
rawResponseBytes, err := sonic.Marshal(bifrostErr.ExtraFields.RawResponse)
|
|
if err != nil {
|
|
t.Errorf("Failed to marshal RawResponse: %v", err)
|
|
}
|
|
|
|
// The RawResponse should contain the body content
|
|
if len(rawResponseBytes) == 0 {
|
|
t.Errorf("%s: RawResponse is empty", tt.description)
|
|
}
|
|
}
|
|
|
|
t.Logf("✓ %s: RawResponse is set", tt.name)
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestEnrichError_PreservesExistingRawResponse verifies that EnrichError preserves
|
|
// existing RawResponse from the error's ExtraFields when responseBody parameter is nil
|
|
func TestEnrichError_PreservesExistingRawResponse(t *testing.T) {
|
|
ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
|
|
|
|
existingRawResponse := map[string]interface{}{
|
|
"error": map[string]interface{}{
|
|
"message": "Original error from provider",
|
|
"code": "invalid_api_key",
|
|
},
|
|
}
|
|
|
|
bifrostErr := &schemas.BifrostError{
|
|
IsBifrostError: false,
|
|
StatusCode: schemas.Ptr(401),
|
|
Error: &schemas.ErrorField{
|
|
Message: "Authentication failed",
|
|
},
|
|
ExtraFields: schemas.BifrostErrorExtraFields{
|
|
RawResponse: existingRawResponse,
|
|
},
|
|
}
|
|
|
|
requestBody := []byte(`{"model": "gpt-4", "messages": []}`)
|
|
|
|
// Call EnrichError with nil responseBody - should preserve existing RawResponse
|
|
enrichedErr := EnrichError(ctx, bifrostErr, requestBody, nil, true, true)
|
|
|
|
if enrichedErr == nil {
|
|
t.Fatal("EnrichError() returned nil")
|
|
}
|
|
|
|
if enrichedErr.ExtraFields.RawResponse == nil {
|
|
t.Error("RawResponse was cleared when it should have been preserved")
|
|
} else {
|
|
// Verify it's still the original
|
|
if rawMap, ok := enrichedErr.ExtraFields.RawResponse.(map[string]interface{}); ok {
|
|
if errorMap, ok := rawMap["error"].(map[string]interface{}); ok {
|
|
if errorMap["code"] != "invalid_api_key" {
|
|
t.Error("RawResponse was modified, expected it to be preserved")
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
t.Log("✓ EnrichError preserves existing RawResponse when responseBody is nil")
|
|
}
|
|
|
|
// TestEnrichError_OverwritesWithProvidedResponse verifies that EnrichError sets
|
|
// RawResponse when a responseBody is provided
|
|
func TestEnrichError_OverwritesWithProvidedResponse(t *testing.T) {
|
|
ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
|
|
|
|
bifrostErr := &schemas.BifrostError{
|
|
IsBifrostError: false,
|
|
StatusCode: schemas.Ptr(400),
|
|
Error: &schemas.ErrorField{
|
|
Message: "Bad request",
|
|
},
|
|
ExtraFields: schemas.BifrostErrorExtraFields{},
|
|
}
|
|
|
|
requestBody := []byte(`{"model": "gpt-4"}`)
|
|
responseBody := []byte(`{"error": {"message": "Model not found"}}`)
|
|
|
|
enrichedErr := EnrichError(ctx, bifrostErr, requestBody, responseBody, true, true)
|
|
|
|
if enrichedErr == nil {
|
|
t.Fatal("EnrichError() returned nil")
|
|
}
|
|
|
|
if enrichedErr.ExtraFields.RawResponse == nil {
|
|
t.Error("RawResponse should be set from responseBody parameter")
|
|
}
|
|
|
|
if enrichedErr.ExtraFields.RawRequest == nil {
|
|
t.Error("RawRequest should be set from requestBody parameter")
|
|
}
|
|
|
|
t.Log("✓ EnrichError sets RawRequest and RawResponse from provided bodies")
|
|
}
|
|
|
|
// TestEnrichError_RespectsFlags verifies that EnrichError respects
|
|
// sendBackRawRequest and sendBackRawResponse flags
|
|
func TestEnrichError_RespectsFlags(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
sendBackRawRequest bool
|
|
sendBackRawResponse bool
|
|
expectRequest bool
|
|
expectResponse bool
|
|
}{
|
|
{
|
|
name: "Both enabled",
|
|
sendBackRawRequest: true,
|
|
sendBackRawResponse: true,
|
|
expectRequest: true,
|
|
expectResponse: true,
|
|
},
|
|
{
|
|
name: "Only request enabled",
|
|
sendBackRawRequest: true,
|
|
sendBackRawResponse: false,
|
|
expectRequest: true,
|
|
expectResponse: false,
|
|
},
|
|
{
|
|
name: "Only response enabled",
|
|
sendBackRawRequest: false,
|
|
sendBackRawResponse: true,
|
|
expectRequest: false,
|
|
expectResponse: true,
|
|
},
|
|
{
|
|
name: "Both disabled",
|
|
sendBackRawRequest: false,
|
|
sendBackRawResponse: false,
|
|
expectRequest: false,
|
|
expectResponse: false,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
|
|
|
|
bifrostErr := &schemas.BifrostError{
|
|
IsBifrostError: false,
|
|
StatusCode: schemas.Ptr(500),
|
|
Error: &schemas.ErrorField{Message: "Error"},
|
|
ExtraFields: schemas.BifrostErrorExtraFields{},
|
|
}
|
|
|
|
requestBody := []byte(`{"model": "test"}`)
|
|
responseBody := []byte(`{"error": "test error"}`)
|
|
|
|
enrichedErr := EnrichError(ctx, bifrostErr, requestBody, responseBody, tt.sendBackRawRequest, tt.sendBackRawResponse)
|
|
|
|
hasRequest := enrichedErr.ExtraFields.RawRequest != nil
|
|
hasResponse := enrichedErr.ExtraFields.RawResponse != nil
|
|
|
|
if hasRequest != tt.expectRequest {
|
|
t.Errorf("RawRequest: got %v, want %v", hasRequest, tt.expectRequest)
|
|
}
|
|
|
|
if hasResponse != tt.expectResponse {
|
|
t.Errorf("RawResponse: got %v, want %v", hasResponse, tt.expectResponse)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestProviderErrorFlow_EndToEnd simulates the full flow of a provider error
|
|
// being captured and enriched with raw request/response
|
|
func TestProviderErrorFlow_EndToEnd(t *testing.T) {
|
|
ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
|
|
|
|
// Simulate provider error response
|
|
errorBody := []byte(`{"error": {"message": "Rate limit exceeded", "type": "rate_limit_error", "code": "rate_limit"}}`)
|
|
|
|
resp := &fasthttp.Response{}
|
|
resp.SetStatusCode(429)
|
|
resp.Header.Set("Content-Type", "application/json")
|
|
resp.SetBody(errorBody)
|
|
|
|
// Step 1: Parse the error (like ParseOpenAIError does)
|
|
var errorResp map[string]interface{}
|
|
bifrostErr := HandleProviderAPIError(resp, &errorResp)
|
|
|
|
if bifrostErr == nil {
|
|
t.Fatal("HandleProviderAPIError returned nil")
|
|
}
|
|
|
|
// Verify raw response is captured by HandleProviderAPIError
|
|
if bifrostErr.ExtraFields.RawResponse == nil {
|
|
t.Error("HandleProviderAPIError should have set RawResponse")
|
|
}
|
|
|
|
// Step 2: Enrich with request (like providers do)
|
|
requestBody := []byte(`{"model": "gpt-4", "messages": [{"role": "user", "content": "Hello"}]}`)
|
|
|
|
enrichedErr := EnrichError(ctx, bifrostErr, requestBody, nil, true, true)
|
|
|
|
// Verify both raw request and raw response are present
|
|
if enrichedErr.ExtraFields.RawRequest == nil {
|
|
t.Error("EnrichError should have set RawRequest")
|
|
}
|
|
|
|
if enrichedErr.ExtraFields.RawResponse == nil {
|
|
t.Error("EnrichError should have preserved RawResponse from HandleProviderAPIError")
|
|
}
|
|
|
|
t.Log("✓ End-to-end: Raw request and error response captured successfully")
|
|
}
|
|
|
|
// TestHandleProviderAPIError_AllPathsSetRawResponse verifies that all error return
|
|
// paths in HandleProviderAPIError include RawResponse
|
|
func TestHandleProviderAPIError_AllPathsSetRawResponse(t *testing.T) {
|
|
testCases := []struct {
|
|
name string
|
|
statusCode int
|
|
body []byte
|
|
setupResp func(*fasthttp.Response)
|
|
errorType string
|
|
}{
|
|
{
|
|
name: "Path 1: Decode error",
|
|
statusCode: 500,
|
|
body: []byte{0xFF, 0xFE, 0xFD}, // Invalid gzip-compressed data
|
|
setupResp: func(r *fasthttp.Response) {
|
|
r.Header.Set("Content-Type", "application/json")
|
|
// Set Content-Encoding: gzip to trigger BodyGunzip() error on invalid gzip data
|
|
r.Header.Set("Content-Encoding", "gzip")
|
|
},
|
|
errorType: "decode_failure",
|
|
},
|
|
{
|
|
name: "Path 2: Empty response",
|
|
statusCode: 502,
|
|
body: []byte(" "), // Only whitespace
|
|
setupResp: func(r *fasthttp.Response) {
|
|
r.Header.Set("Content-Type", "application/json")
|
|
},
|
|
errorType: "empty_response",
|
|
},
|
|
{
|
|
name: "Path 3: Valid JSON",
|
|
statusCode: 400,
|
|
body: []byte(`{"error": {"message": "Bad request"}}`),
|
|
setupResp: func(r *fasthttp.Response) {
|
|
r.Header.Set("Content-Type", "application/json")
|
|
},
|
|
errorType: "valid_json",
|
|
},
|
|
{
|
|
name: "Path 4: HTML response",
|
|
statusCode: 503,
|
|
body: []byte(`<!DOCTYPE html><html><head><title>Error</title></head><body><h1>Service Error</h1></body></html>`),
|
|
setupResp: func(r *fasthttp.Response) {
|
|
r.Header.Set("Content-Type", "text/html")
|
|
},
|
|
errorType: "html",
|
|
},
|
|
{
|
|
name: "Path 5: Unparseable non-HTML",
|
|
statusCode: 500,
|
|
body: []byte(`This is plain text that's not JSON`),
|
|
setupResp: func(r *fasthttp.Response) {
|
|
r.Header.Set("Content-Type", "text/plain")
|
|
},
|
|
errorType: "unparseable",
|
|
},
|
|
}
|
|
|
|
for _, tc := range testCases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
resp := &fasthttp.Response{}
|
|
resp.SetStatusCode(tc.statusCode)
|
|
resp.SetBody(tc.body)
|
|
tc.setupResp(resp)
|
|
|
|
var errorResp map[string]interface{}
|
|
bifrostErr := HandleProviderAPIError(resp, &errorResp)
|
|
|
|
if bifrostErr == nil {
|
|
t.Fatalf("%s: HandleProviderAPIError returned nil", tc.name)
|
|
}
|
|
|
|
if bifrostErr.ExtraFields.RawResponse == nil {
|
|
t.Errorf("%s [%s]: RawResponse is nil - MISSING raw error body!", tc.name, tc.errorType)
|
|
} else {
|
|
t.Logf("✓ %s [%s]: RawResponse is set", tc.name, tc.errorType)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestGetRequestPath verifies GetRequestPath handles all path resolution scenarios correctly
|
|
func TestGetRequestPath(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
contextPath *string
|
|
customProviderConfig *schemas.CustomProviderConfig
|
|
defaultPath string
|
|
requestType schemas.RequestType
|
|
expectedPath string
|
|
expectedIsURL bool
|
|
}{
|
|
{
|
|
name: "Returns default path when nothing is set",
|
|
defaultPath: "/v1/chat/completions",
|
|
requestType: schemas.ChatCompletionRequest,
|
|
expectedPath: "/v1/chat/completions",
|
|
expectedIsURL: false,
|
|
},
|
|
{
|
|
name: "Returns path from context when present",
|
|
contextPath: schemas.Ptr("/custom/path"),
|
|
defaultPath: "/v1/chat/completions",
|
|
requestType: schemas.ChatCompletionRequest,
|
|
expectedPath: "/custom/path",
|
|
expectedIsURL: false,
|
|
},
|
|
{
|
|
name: "Returns full URL from config override",
|
|
customProviderConfig: &schemas.CustomProviderConfig{
|
|
RequestPathOverrides: map[schemas.RequestType]string{
|
|
schemas.ChatCompletionRequest: "https://custom.api.com/v1/completions",
|
|
},
|
|
},
|
|
defaultPath: "/v1/chat/completions",
|
|
requestType: schemas.ChatCompletionRequest,
|
|
expectedPath: "https://custom.api.com/v1/completions",
|
|
expectedIsURL: true,
|
|
},
|
|
{
|
|
name: "Returns path override with leading slash",
|
|
customProviderConfig: &schemas.CustomProviderConfig{
|
|
RequestPathOverrides: map[schemas.RequestType]string{
|
|
schemas.ChatCompletionRequest: "/custom/endpoint",
|
|
},
|
|
},
|
|
defaultPath: "/v1/chat/completions",
|
|
requestType: schemas.ChatCompletionRequest,
|
|
expectedPath: "/custom/endpoint",
|
|
expectedIsURL: false,
|
|
},
|
|
{
|
|
name: "Adds leading slash to path override without one",
|
|
customProviderConfig: &schemas.CustomProviderConfig{
|
|
RequestPathOverrides: map[schemas.RequestType]string{
|
|
schemas.ChatCompletionRequest: "custom/endpoint",
|
|
},
|
|
},
|
|
defaultPath: "/v1/chat/completions",
|
|
requestType: schemas.ChatCompletionRequest,
|
|
expectedPath: "/custom/endpoint",
|
|
expectedIsURL: false,
|
|
},
|
|
{
|
|
name: "Returns default path for empty override",
|
|
customProviderConfig: &schemas.CustomProviderConfig{
|
|
RequestPathOverrides: map[schemas.RequestType]string{
|
|
schemas.ChatCompletionRequest: " ",
|
|
},
|
|
},
|
|
defaultPath: "/v1/chat/completions",
|
|
requestType: schemas.ChatCompletionRequest,
|
|
expectedPath: "/v1/chat/completions",
|
|
expectedIsURL: false,
|
|
},
|
|
{
|
|
name: "Returns default when override exists for different request type",
|
|
customProviderConfig: &schemas.CustomProviderConfig{
|
|
RequestPathOverrides: map[schemas.RequestType]string{
|
|
schemas.EmbeddingRequest: "/custom/embeddings",
|
|
},
|
|
},
|
|
defaultPath: "/v1/chat/completions",
|
|
requestType: schemas.ChatCompletionRequest,
|
|
expectedPath: "/v1/chat/completions",
|
|
expectedIsURL: false,
|
|
},
|
|
{
|
|
name: "Handles URL with http scheme",
|
|
customProviderConfig: &schemas.CustomProviderConfig{
|
|
RequestPathOverrides: map[schemas.RequestType]string{
|
|
schemas.ChatCompletionRequest: "http://internal.api:8080/completions",
|
|
},
|
|
},
|
|
defaultPath: "/v1/chat/completions",
|
|
requestType: schemas.ChatCompletionRequest,
|
|
expectedPath: "http://internal.api:8080/completions",
|
|
expectedIsURL: true,
|
|
},
|
|
{
|
|
name: "Context path takes precedence over config override",
|
|
contextPath: schemas.Ptr("/context/path"),
|
|
customProviderConfig: &schemas.CustomProviderConfig{
|
|
RequestPathOverrides: map[schemas.RequestType]string{
|
|
schemas.ChatCompletionRequest: "/config/path",
|
|
},
|
|
},
|
|
defaultPath: "/v1/chat/completions",
|
|
requestType: schemas.ChatCompletionRequest,
|
|
expectedPath: "/context/path",
|
|
expectedIsURL: false,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
ctx := context.Background()
|
|
if tt.contextPath != nil {
|
|
ctx = context.WithValue(ctx, schemas.BifrostContextKeyURLPath, *tt.contextPath)
|
|
}
|
|
|
|
path, isURL := GetRequestPath(ctx, tt.defaultPath, tt.customProviderConfig, tt.requestType)
|
|
|
|
if path != tt.expectedPath {
|
|
t.Errorf("GetRequestPath() path = %q, want %q", path, tt.expectedPath)
|
|
}
|
|
|
|
if isURL != tt.expectedIsURL {
|
|
t.Errorf("GetRequestPath() isURL = %v, want %v", isURL, tt.expectedIsURL)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestMarshalSorted_Deterministic verifies that MarshalSorted produces identical
|
|
// output across multiple calls with the same map, despite Go's randomized map iteration.
|
|
func TestMarshalSorted_Deterministic(t *testing.T) {
|
|
// Build a map with enough keys to make random ordering statistically certain
|
|
m := map[string]interface{}{
|
|
"zulu": 1,
|
|
"alpha": 2,
|
|
"mike": 3,
|
|
"bravo": 4,
|
|
"yankee": 5,
|
|
"charlie": 6,
|
|
"nested": map[string]interface{}{
|
|
"zebra": "z",
|
|
"apple": "a",
|
|
"mango": "m",
|
|
"banana": "b",
|
|
"cherry": "c",
|
|
"date": "d",
|
|
"fig": "f",
|
|
"grape": "g",
|
|
"kiwi": "k",
|
|
"lemon": "l",
|
|
"orange": "o",
|
|
"papaya": "p",
|
|
"quince": "q",
|
|
"raisin": "r",
|
|
"satsuma": "s",
|
|
},
|
|
}
|
|
|
|
first, err := MarshalSorted(m)
|
|
if err != nil {
|
|
t.Fatalf("MarshalSorted() error: %v", err)
|
|
}
|
|
|
|
// Run 50 iterations to be confident about determinism
|
|
for i := 0; i < 50; i++ {
|
|
got, err := MarshalSorted(m)
|
|
if err != nil {
|
|
t.Fatalf("MarshalSorted() iteration %d error: %v", i, err)
|
|
}
|
|
if string(got) != string(first) {
|
|
t.Fatalf("MarshalSorted() produced different output on iteration %d:\nfirst: %s\ngot: %s", i, first, got)
|
|
}
|
|
}
|
|
|
|
// Also verify MarshalSortedIndent
|
|
firstIndent, err := MarshalSortedIndent(m, "", " ")
|
|
if err != nil {
|
|
t.Fatalf("MarshalSortedIndent() error: %v", err)
|
|
}
|
|
|
|
for i := 0; i < 50; i++ {
|
|
got, err := MarshalSortedIndent(m, "", " ")
|
|
if err != nil {
|
|
t.Fatalf("MarshalSortedIndent() iteration %d error: %v", i, err)
|
|
}
|
|
if string(got) != string(firstIndent) {
|
|
t.Fatalf("MarshalSortedIndent() produced different output on iteration %d:\nfirst: %s\ngot: %s", i, firstIndent, got)
|
|
}
|
|
}
|
|
}
|
|
|
|
// TestCheckAndDecodeBody_PooledGzip verifies that CheckAndDecodeBody correctly
|
|
// decompresses gzip-encoded responses using pooled gzip readers.
|
|
func TestCheckAndDecodeBody_PooledGzip(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
body []byte
|
|
contentEncoding string
|
|
wantBody string
|
|
wantErr bool
|
|
}{
|
|
{
|
|
name: "gzip encoded body",
|
|
body: gzipCompress([]byte(`{"message":"hello world"}`)),
|
|
contentEncoding: "gzip",
|
|
wantBody: `{"message":"hello world"}`,
|
|
wantErr: false,
|
|
},
|
|
{
|
|
name: "gzip with uppercase header",
|
|
body: gzipCompress([]byte(`test data`)),
|
|
contentEncoding: "GZIP",
|
|
wantBody: `test data`,
|
|
wantErr: false,
|
|
},
|
|
{
|
|
name: "gzip with whitespace in header",
|
|
body: gzipCompress([]byte(`trimmed`)),
|
|
contentEncoding: " gzip ",
|
|
wantBody: `trimmed`,
|
|
wantErr: false,
|
|
},
|
|
{
|
|
name: "no encoding - plain body",
|
|
body: []byte(`plain text`),
|
|
contentEncoding: "",
|
|
wantBody: `plain text`,
|
|
wantErr: false,
|
|
},
|
|
{
|
|
name: "empty gzip body",
|
|
body: []byte{},
|
|
contentEncoding: "gzip",
|
|
wantBody: "",
|
|
wantErr: false,
|
|
},
|
|
{
|
|
name: "invalid gzip data",
|
|
body: []byte{0xFF, 0xFE, 0xFD},
|
|
contentEncoding: "gzip",
|
|
wantErr: true,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
resp := fasthttp.AcquireResponse()
|
|
defer fasthttp.ReleaseResponse(resp)
|
|
resp.SetBody(tt.body)
|
|
if tt.contentEncoding != "" {
|
|
resp.Header.Set("Content-Encoding", tt.contentEncoding)
|
|
}
|
|
|
|
got, err := CheckAndDecodeBody(resp)
|
|
if tt.wantErr {
|
|
if err == nil {
|
|
t.Errorf("CheckAndDecodeBody() expected error, got nil")
|
|
}
|
|
return
|
|
}
|
|
if err != nil {
|
|
t.Errorf("CheckAndDecodeBody() unexpected error: %v", err)
|
|
return
|
|
}
|
|
if string(got) != tt.wantBody {
|
|
t.Errorf("CheckAndDecodeBody() = %q, want %q", string(got), tt.wantBody)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestCheckAndDecodeBody_Concurrent verifies no data races with concurrent access.
|
|
func TestCheckAndDecodeBody_Concurrent(t *testing.T) {
|
|
testData := []byte(`{"concurrent":"test"}`)
|
|
compressed := gzipCompress(testData)
|
|
|
|
done := make(chan bool)
|
|
for i := 0; i < 100; i++ {
|
|
go func() {
|
|
resp := fasthttp.AcquireResponse()
|
|
defer fasthttp.ReleaseResponse(resp)
|
|
resp.SetBody(compressed)
|
|
resp.Header.Set("Content-Encoding", "gzip")
|
|
|
|
got, err := CheckAndDecodeBody(resp)
|
|
if err != nil {
|
|
t.Errorf("CheckAndDecodeBody() error: %v", err)
|
|
}
|
|
if string(got) != string(testData) {
|
|
t.Errorf("CheckAndDecodeBody() = %q, want %q", string(got), string(testData))
|
|
}
|
|
done <- true
|
|
}()
|
|
}
|
|
|
|
for i := 0; i < 100; i++ {
|
|
<-done
|
|
}
|
|
}
|
|
|
|
func TestDrainNonSSEStreamResponse_SSEDoesNotDrain(t *testing.T) {
|
|
resp := fasthttp.AcquireResponse()
|
|
defer fasthttp.ReleaseResponse(resp)
|
|
|
|
body := []byte("data: hello\n\n")
|
|
resp.Header.SetContentType("text/event-stream")
|
|
resp.SetBodyStream(bytes.NewReader(body), len(body))
|
|
|
|
drained := DrainNonSSEStreamResponse(resp)
|
|
if drained {
|
|
t.Fatal("expected SSE response to remain readable")
|
|
}
|
|
|
|
remaining, err := io.ReadAll(resp.BodyStream())
|
|
if err != nil {
|
|
t.Fatalf("failed to read SSE body after guard: %v", err)
|
|
}
|
|
if string(remaining) != string(body) {
|
|
t.Fatalf("expected SSE body to remain intact, got %q", string(remaining))
|
|
}
|
|
}
|
|
|
|
func TestDrainNonSSEStreamResponse_NonSSEDrains(t *testing.T) {
|
|
resp := fasthttp.AcquireResponse()
|
|
defer fasthttp.ReleaseResponse(resp)
|
|
|
|
body := []byte(`{"error":"not stream"}`)
|
|
resp.Header.SetContentType("application/json")
|
|
resp.SetBodyStream(bytes.NewReader(body), len(body))
|
|
|
|
drained := DrainNonSSEStreamResponse(resp)
|
|
if !drained {
|
|
t.Fatal("expected non-SSE response to be drained")
|
|
}
|
|
|
|
remaining, err := io.ReadAll(resp.BodyStream())
|
|
if err != nil {
|
|
t.Fatalf("failed to read body after drain: %v", err)
|
|
}
|
|
if len(remaining) != 0 {
|
|
t.Fatalf("expected drained body to be empty, got %q", string(remaining))
|
|
}
|
|
}
|
|
|
|
func TestDrainNonSSEStreamResponse_GzipSSEStillReadable(t *testing.T) {
|
|
resp := fasthttp.AcquireResponse()
|
|
defer fasthttp.ReleaseResponse(resp)
|
|
|
|
body := []byte("data: hello\n\ndata: [DONE]\n\n")
|
|
compressed := gzipCompress(body)
|
|
resp.Header.SetContentType("text/event-stream")
|
|
resp.Header.Set("Content-Encoding", "gzip")
|
|
resp.SetBodyStream(bytes.NewReader(compressed), len(compressed))
|
|
|
|
drained := DrainNonSSEStreamResponse(resp)
|
|
if drained {
|
|
t.Fatal("expected gzip SSE response to remain readable")
|
|
}
|
|
|
|
reader, releaseGzip := DecompressStreamBody(resp)
|
|
defer releaseGzip()
|
|
|
|
remaining, err := io.ReadAll(reader)
|
|
if err != nil {
|
|
t.Fatalf("failed to read decompressed SSE body: %v", err)
|
|
}
|
|
if string(remaining) != string(body) {
|
|
t.Fatalf("expected decompressed SSE body %q, got %q", string(body), string(remaining))
|
|
}
|
|
}
|
|
|
|
// gzipCompress returns data compressed with gzip, for use as test fixtures.
// It panics on writer errors rather than taking a *testing.T, because it is
// called while test tables are being built, before any t is available.
func gzipCompress(data []byte) []byte {
	out := new(bytes.Buffer)
	zw := gzip.NewWriter(out)

	_, err := zw.Write(data)
	if err != nil {
		panic(fmt.Errorf("gzip write: %w", err))
	}
	// Close flushes the remaining compressed data and the gzip trailer.
	if err = zw.Close(); err != nil {
		panic(fmt.Errorf("gzip close: %w", err))
	}
	return out.Bytes()
}
|
|
|
|
func TestMergeExtraParamsIntoJSON_PreservesKeyOrder(t *testing.T) {
|
|
// JSON with a specific key order that must be preserved
|
|
jsonBody := []byte(`{
|
|
"model": "gpt-4",
|
|
"messages": [],
|
|
"tool_choice": {"type": "function", "function": {"name": "test"}},
|
|
"tools": []
|
|
}`)
|
|
|
|
extraParams := map[string]interface{}{
|
|
"custom_field": "value",
|
|
}
|
|
|
|
result, err := MergeExtraParamsIntoJSON(jsonBody, extraParams)
|
|
if err != nil {
|
|
t.Fatalf("MergeExtraParamsIntoJSON() error: %v", err)
|
|
}
|
|
|
|
// Verify original key order is preserved and custom_field is appended
|
|
resultStr := string(result)
|
|
modelIdx := bytes.Index(result, []byte(`"model"`))
|
|
messagesIdx := bytes.Index(result, []byte(`"messages"`))
|
|
toolChoiceIdx := bytes.Index(result, []byte(`"tool_choice"`))
|
|
toolsIdx := bytes.Index(result, []byte(`"tools"`))
|
|
customIdx := bytes.Index(result, []byte(`"custom_field"`))
|
|
|
|
if modelIdx >= messagesIdx || messagesIdx >= toolChoiceIdx || toolChoiceIdx >= toolsIdx || toolsIdx >= customIdx {
|
|
t.Fatalf("Key order not preserved. Result:\n%s", resultStr)
|
|
}
|
|
}
|
|
|
|
func TestMergeExtraParamsIntoJSON_OverwriteExistingKey(t *testing.T) {
|
|
jsonBody := []byte(`{"z_first": "original", "a_second": "original"}`)
|
|
|
|
extraParams := map[string]interface{}{
|
|
"z_first": "overwritten",
|
|
}
|
|
|
|
result, err := MergeExtraParamsIntoJSON(jsonBody, extraParams)
|
|
if err != nil {
|
|
t.Fatalf("MergeExtraParamsIntoJSON() error: %v", err)
|
|
}
|
|
|
|
// z_first should still come before a_second (preserving original position)
|
|
zIdx := bytes.Index(result, []byte(`"z_first"`))
|
|
aIdx := bytes.Index(result, []byte(`"a_second"`))
|
|
if zIdx >= aIdx {
|
|
t.Fatalf("Overwritten key should preserve its position. Result: %s", string(result))
|
|
}
|
|
|
|
// z_first should have the new value
|
|
if !bytes.Contains(result, []byte(`"overwritten"`)) {
|
|
t.Fatalf("Value should be overwritten. Result: %s", string(result))
|
|
}
|
|
}
|
|
|
|
func TestMergeExtraParamsIntoJSON_DeepMerge(t *testing.T) {
|
|
jsonBody := []byte(`{"outer": {"a": 1, "b": 2}}`)
|
|
|
|
extraParams := map[string]interface{}{
|
|
"outer": map[string]interface{}{
|
|
"c": 3,
|
|
},
|
|
}
|
|
|
|
result, err := MergeExtraParamsIntoJSON(jsonBody, extraParams)
|
|
if err != nil {
|
|
t.Fatalf("MergeExtraParamsIntoJSON() error: %v", err)
|
|
}
|
|
|
|
// Verify the merge happened
|
|
var parsed map[string]interface{}
|
|
if err := sonic.Unmarshal(result, &parsed); err != nil {
|
|
t.Fatalf("Failed to parse result: %v", err)
|
|
}
|
|
|
|
outer, ok := parsed["outer"].(map[string]interface{})
|
|
if !ok {
|
|
t.Fatal("outer should be a map")
|
|
}
|
|
if len(outer) != 3 {
|
|
t.Fatalf("outer should have 3 keys after merge, got %d: %v", len(outer), outer)
|
|
}
|
|
}
|
|
|
|
func TestMergeExtraParamsIntoJSON_EmptyExtraParams(t *testing.T) {
|
|
jsonBody := []byte(`{"a": 1, "b": 2}`)
|
|
result, err := MergeExtraParamsIntoJSON(jsonBody, map[string]interface{}{})
|
|
if err != nil {
|
|
t.Fatalf("MergeExtraParamsIntoJSON() error: %v", err)
|
|
}
|
|
|
|
// Should be valid JSON with same content
|
|
var parsed map[string]interface{}
|
|
if err := sonic.Unmarshal(result, &parsed); err != nil {
|
|
t.Fatalf("Failed to parse result: %v", err)
|
|
}
|
|
if len(parsed) != 2 {
|
|
t.Fatalf("Expected 2 keys, got %d", len(parsed))
|
|
}
|
|
}
|
|
|
|
// TestParseAndSetRawRequest_CompactsJSON verifies that indented JSON input
|
|
// (with literal newlines from MarshalIndent) is compacted to a single line.
|
|
// This is critical for SSE streaming where newlines break data-line framing.
|
|
func TestParseAndSetRawRequest_CompactsJSON(t *testing.T) {
|
|
indentedJSON := []byte(`{
|
|
"model": "gpt-4",
|
|
"messages": [
|
|
{
|
|
"role": "user",
|
|
"content": "Hello"
|
|
}
|
|
],
|
|
"temperature": 0.7
|
|
}`)
|
|
|
|
var extraFields schemas.BifrostResponseExtraFields
|
|
ParseAndSetRawRequest(&extraFields, indentedJSON)
|
|
|
|
if extraFields.RawRequest == nil {
|
|
t.Fatal("RawRequest should be set")
|
|
}
|
|
|
|
raw, ok := extraFields.RawRequest.(json.RawMessage)
|
|
if !ok {
|
|
t.Fatalf("RawRequest should be json.RawMessage, got %T", extraFields.RawRequest)
|
|
}
|
|
|
|
// The compacted output must not contain any literal newlines
|
|
if strings.Contains(string(raw), "\n") {
|
|
t.Errorf("Compacted RawRequest should not contain newlines, got:\n%s", string(raw))
|
|
}
|
|
|
|
// Verify it's still valid JSON with the same content
|
|
var parsed map[string]interface{}
|
|
if err := sonic.Unmarshal(raw, &parsed); err != nil {
|
|
t.Fatalf("Compacted RawRequest is not valid JSON: %v", err)
|
|
}
|
|
|
|
if parsed["model"] != "gpt-4" {
|
|
t.Errorf("Expected model=gpt-4, got %v", parsed["model"])
|
|
}
|
|
}
|
|
|
|
// TestParseAndSetRawRequest_PreservesKeyOrdering verifies that JSON key order
|
|
// is maintained after compaction. This is essential for LLM prompt caching
|
|
// where key ordering affects cache hit rates.
|
|
func TestParseAndSetRawRequest_PreservesKeyOrdering(t *testing.T) {
|
|
// Keys are intentionally not alphabetically sorted
|
|
jsonBody := []byte(`{"z_last":"z","a_first":"a","m_middle":"m"}`)
|
|
|
|
var extraFields schemas.BifrostResponseExtraFields
|
|
ParseAndSetRawRequest(&extraFields, jsonBody)
|
|
|
|
raw := extraFields.RawRequest.(json.RawMessage)
|
|
result := string(raw)
|
|
|
|
zIdx := strings.Index(result, `"z_last"`)
|
|
aIdx := strings.Index(result, `"a_first"`)
|
|
mIdx := strings.Index(result, `"m_middle"`)
|
|
|
|
if zIdx >= aIdx || aIdx >= mIdx {
|
|
t.Errorf("Key ordering not preserved. Got: %s", result)
|
|
}
|
|
}
|
|
|
|
// TestParseAndSetRawRequest_EmptyBody verifies that empty input is a no-op.
|
|
func TestParseAndSetRawRequest_EmptyBody(t *testing.T) {
|
|
var extraFields schemas.BifrostResponseExtraFields
|
|
ParseAndSetRawRequest(&extraFields, []byte{})
|
|
|
|
if extraFields.RawRequest != nil {
|
|
t.Error("RawRequest should be nil for empty body")
|
|
}
|
|
|
|
ParseAndSetRawRequest(&extraFields, nil)
|
|
|
|
if extraFields.RawRequest != nil {
|
|
t.Error("RawRequest should be nil for nil body")
|
|
}
|
|
}
|
|
|
|
// TestParseAndSetRawRequest_SSEStreamingChunks simulates the actual SSE streaming
|
|
// flow end-to-end: a response chunk with raw_request containing indented JSON is
|
|
// marshaled, framed as SSE "data: <json>\n\n", and then each SSE data line is
|
|
// parsed back. This is the exact scenario that caused issue #1905 — pretty-printed
|
|
// JSON in raw_request introduced literal newlines that broke SSE data-line framing.
|
|
func TestParseAndSetRawRequest_SSEStreamingChunks(t *testing.T) {
|
|
// Simulate indented request body (as produced by MarshalSortedIndent)
|
|
indentedRequest := []byte(`{
|
|
"model": "gpt-4",
|
|
"messages": [
|
|
{
|
|
"role": "user",
|
|
"content": "Hello"
|
|
}
|
|
],
|
|
"stream": true,
|
|
"temperature": 0.7
|
|
}`)
|
|
|
|
// Build a response chunk with raw_request set via ParseAndSetRawRequest.
|
|
// Uses BifrostChatResponse which is the actual type marshaled in the streaming path.
|
|
chunk := schemas.BifrostChatResponse{
|
|
ID: "chatcmpl-test",
|
|
Model: "gpt-4",
|
|
Object: "chat.completion.chunk",
|
|
Choices: []schemas.BifrostResponseChoice{
|
|
{
|
|
Index: 0,
|
|
ChatStreamResponseChoice: &schemas.ChatStreamResponseChoice{
|
|
Delta: &schemas.ChatStreamResponseChoiceDelta{
|
|
Content: schemas.Ptr("Hello"),
|
|
},
|
|
},
|
|
},
|
|
},
|
|
}
|
|
ParseAndSetRawRequest(&chunk.ExtraFields, indentedRequest)
|
|
|
|
// Marshal the chunk (exactly like the transport layer does: sonic.Marshal)
|
|
chunkJSON, err := sonic.Marshal(chunk)
|
|
if err != nil {
|
|
t.Fatalf("Failed to marshal chunk: %v", err)
|
|
}
|
|
|
|
// Frame as SSE: "data: <json>\n\n" (exactly as in inference.go:1591)
|
|
sseFrame := fmt.Sprintf("data: %s\n\n", chunkJSON)
|
|
|
|
// Parse the SSE frame line-by-line as a real SSE client would.
|
|
// Split on \n and check that there is exactly one "data:" line.
|
|
lines := strings.Split(strings.TrimRight(sseFrame, "\n"), "\n")
|
|
|
|
var dataLines []string
|
|
for _, line := range lines {
|
|
if strings.HasPrefix(line, "data: ") {
|
|
dataLines = append(dataLines, line)
|
|
} else if line != "" {
|
|
// Any non-empty, non-data line means SSE framing is broken —
|
|
// this is exactly what happened in #1905
|
|
t.Errorf("Unexpected non-data line in SSE frame (broken framing): %q", line)
|
|
}
|
|
}
|
|
|
|
if len(dataLines) != 1 {
|
|
t.Fatalf("Expected exactly 1 SSE data line, got %d:\n%s", len(dataLines), sseFrame)
|
|
}
|
|
|
|
// Parse the JSON payload from the single data line
|
|
jsonPayload := strings.TrimPrefix(dataLines[0], "data: ")
|
|
var parsed schemas.BifrostChatResponse
|
|
if err := sonic.Unmarshal([]byte(jsonPayload), &parsed); err != nil {
|
|
t.Fatalf("Failed to parse SSE data line as JSON (this is the #1905 bug): %v\nPayload: %s", err, jsonPayload)
|
|
}
|
|
|
|
// Verify the parsed response has the correct content
|
|
if parsed.ID != "chatcmpl-test" {
|
|
t.Errorf("Expected ID=chatcmpl-test, got %s", parsed.ID)
|
|
}
|
|
if parsed.ExtraFields.RawRequest == nil {
|
|
t.Error("RawRequest should be present in parsed chunk")
|
|
}
|
|
|
|
// Verify raw_request round-trips correctly — the client should be able
|
|
// to parse it back into the original request structure
|
|
rawBytes, err := sonic.Marshal(parsed.ExtraFields.RawRequest)
|
|
if err != nil {
|
|
t.Fatalf("Failed to marshal raw_request: %v", err)
|
|
}
|
|
var rawParsed map[string]interface{}
|
|
if err := sonic.Unmarshal(rawBytes, &rawParsed); err != nil {
|
|
t.Fatalf("raw_request is not valid JSON after round-trip: %v", err)
|
|
}
|
|
if rawParsed["model"] != "gpt-4" {
|
|
t.Errorf("Expected raw_request.model=gpt-4, got %v", rawParsed["model"])
|
|
}
|
|
}
|
|
|
|
// TestBuildClientStreamChunk_ImageGenerationStripping verifies that
|
|
// BuildClientStreamChunk correctly handles BifrostImageGenerationStreamResponse:
|
|
// strips raw fields when in logging-only mode and never mutates the original.
|
|
func TestBuildClientStreamChunk_ImageGenerationStripping(t *testing.T) {
|
|
rawReq := json.RawMessage(`{"model":"dall-e-3"}`)
|
|
rawResp := json.RawMessage(`{"data":[{"url":"https://example.com/img.png"}]}`)
|
|
|
|
imgResp := &schemas.BifrostImageGenerationStreamResponse{
|
|
ExtraFields: schemas.BifrostResponseExtraFields{
|
|
RawRequest: rawReq,
|
|
RawResponse: rawResp,
|
|
},
|
|
}
|
|
|
|
response := &schemas.BifrostResponse{ImageGenerationStreamResponse: imgResp}
|
|
|
|
t.Run("logging-only: raw fields stripped from image gen chunk, original preserved", func(t *testing.T) {
|
|
ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
|
|
ctx.SetValue(schemas.BifrostContextKeyDropRawRequestFromClient, true)
|
|
ctx.SetValue(schemas.BifrostContextKeyDropRawResponseFromClient, true)
|
|
|
|
chunk := BuildClientStreamChunk(ctx, response, nil)
|
|
if chunk.BifrostImageGenerationStreamResponse == nil {
|
|
t.Fatal("expected BifrostImageGenerationStreamResponse in chunk")
|
|
}
|
|
if chunk.BifrostImageGenerationStreamResponse.ExtraFields.RawRequest != nil {
|
|
t.Error("expected RawRequest stripped from chunk, but it was present")
|
|
}
|
|
if chunk.BifrostImageGenerationStreamResponse.ExtraFields.RawResponse != nil {
|
|
t.Error("expected RawResponse stripped from chunk, but it was present")
|
|
}
|
|
// Original must not be mutated.
|
|
if imgResp.ExtraFields.RawRequest == nil {
|
|
t.Error("original BifrostImageGenerationStreamResponse.ExtraFields.RawRequest was mutated")
|
|
}
|
|
if imgResp.ExtraFields.RawResponse == nil {
|
|
t.Error("original BifrostImageGenerationStreamResponse.ExtraFields.RawResponse was mutated")
|
|
}
|
|
if chunk.BifrostImageGenerationStreamResponse == imgResp {
|
|
t.Error("chunk contains same pointer as original; it must be a copy")
|
|
}
|
|
})
|
|
|
|
t.Run("no logging flag: raw fields preserved in image gen chunk", func(t *testing.T) {
|
|
ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
|
|
|
|
chunk := BuildClientStreamChunk(ctx, response, nil)
|
|
if chunk.BifrostImageGenerationStreamResponse == nil {
|
|
t.Fatal("expected BifrostImageGenerationStreamResponse in chunk")
|
|
}
|
|
if chunk.BifrostImageGenerationStreamResponse.ExtraFields.RawRequest == nil {
|
|
t.Error("expected RawRequest present in chunk, but it was nil")
|
|
}
|
|
if chunk.BifrostImageGenerationStreamResponse.ExtraFields.RawResponse == nil {
|
|
t.Error("expected RawResponse present in chunk, but it was nil")
|
|
}
|
|
})
|
|
}
|
|
|
|
// TestProcessAndSendResponse_StoreRawLoggingOnly_StripsRawDataFromResponseChunk verifies
|
|
// that when drop-raw context flags are set, ProcessAndSendResponse strips RawRequest and
|
|
// RawResponse from the outgoing stream chunk, while leaving other ExtraFields intact.
|
|
// It also verifies that the original BifrostResponse is not mutated
|
|
// (shared object safety for PostLLMHook goroutines).
|
|
func TestProcessAndSendResponse_StoreRawLoggingOnly_StripsRawDataFromResponseChunk(t *testing.T) {
|
|
rawReq := json.RawMessage(`{"model":"gpt-4","messages":[]}`)
|
|
rawResp := json.RawMessage(`{"id":"chatcmpl-001"}`)
|
|
|
|
tests := []struct {
|
|
name string
|
|
loggingOnly bool
|
|
expectStripped bool
|
|
}{
|
|
{
|
|
name: "logging-only flag set: raw data stripped from chunk",
|
|
loggingOnly: true,
|
|
expectStripped: true,
|
|
},
|
|
{
|
|
name: "logging-only flag not set: raw data preserved in chunk",
|
|
loggingOnly: false,
|
|
expectStripped: false,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
|
|
if tt.loggingOnly {
|
|
ctx.SetValue(schemas.BifrostContextKeyDropRawRequestFromClient, true)
|
|
ctx.SetValue(schemas.BifrostContextKeyDropRawResponseFromClient, true)
|
|
}
|
|
|
|
response := &schemas.BifrostResponse{
|
|
ChatResponse: &schemas.BifrostChatResponse{
|
|
ID: "chatcmpl-001",
|
|
Model: "gpt-4",
|
|
ExtraFields: schemas.BifrostResponseExtraFields{
|
|
RawRequest: rawReq,
|
|
RawResponse: rawResp,
|
|
},
|
|
},
|
|
}
|
|
|
|
passThrough := func(ctx *schemas.BifrostContext, resp *schemas.BifrostResponse, err *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError) {
|
|
return resp, err
|
|
}
|
|
|
|
responseChan := make(chan *schemas.BifrostStreamChunk, 1)
|
|
ProcessAndSendResponse(ctx, passThrough, response, responseChan, nil)
|
|
|
|
chunk := <-responseChan
|
|
if chunk.BifrostChatResponse == nil {
|
|
t.Fatal("expected non-nil BifrostChatResponse in stream chunk")
|
|
}
|
|
|
|
hasRawReq := chunk.BifrostChatResponse.ExtraFields.RawRequest != nil
|
|
hasRawResp := chunk.BifrostChatResponse.ExtraFields.RawResponse != nil
|
|
|
|
if tt.expectStripped {
|
|
if hasRawReq {
|
|
t.Error("expected RawRequest to be nil (stripped) in chunk, but it was present")
|
|
}
|
|
if hasRawResp {
|
|
t.Error("expected RawResponse to be nil (stripped) in chunk, but it was present")
|
|
}
|
|
// Critical: the original shared object must NOT have been mutated.
|
|
if response.ChatResponse.ExtraFields.RawRequest == nil {
|
|
t.Error("original BifrostResponse.ChatResponse.ExtraFields.RawRequest was mutated (nil); shared object must be preserved")
|
|
}
|
|
if response.ChatResponse.ExtraFields.RawResponse == nil {
|
|
t.Error("original BifrostResponse.ChatResponse.ExtraFields.RawResponse was mutated (nil); shared object must be preserved")
|
|
}
|
|
// The chunk must be a copy, not the same pointer as the original.
|
|
if chunk.BifrostChatResponse == response.ChatResponse {
|
|
t.Error("chunk.BifrostChatResponse is the same pointer as the original; it must be a copy to avoid data races")
|
|
}
|
|
} else {
|
|
if !hasRawReq {
|
|
t.Error("expected RawRequest to be present in chunk, but it was nil")
|
|
}
|
|
if !hasRawResp {
|
|
t.Error("expected RawResponse to be present in chunk, but it was nil")
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestProcessAndSendResponse_StoreRawLoggingOnly_StripsRawDataFromErrorChunk verifies
|
|
// that when drop-raw context flags are set, raw data is stripped from BifrostError
|
|
// payloads embedded in stream chunks, without mutating the shared BifrostError object
|
|
// (shared object safety for PostLLMHook goroutines).
|
|
func TestProcessAndSendResponse_StoreRawLoggingOnly_StripsRawDataFromErrorChunk(t *testing.T) {
|
|
rawReq := json.RawMessage(`{"model":"gpt-4"}`)
|
|
rawResp := json.RawMessage(`{"error":"rate limit exceeded"}`)
|
|
|
|
tests := []struct {
|
|
name string
|
|
loggingOnly bool
|
|
expectStripped bool
|
|
}{
|
|
{
|
|
name: "logging-only flag set: raw data stripped from error chunk",
|
|
loggingOnly: true,
|
|
expectStripped: true,
|
|
},
|
|
{
|
|
name: "logging-only flag not set: raw data preserved in error chunk",
|
|
loggingOnly: false,
|
|
expectStripped: false,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
|
|
if tt.loggingOnly {
|
|
ctx.SetValue(schemas.BifrostContextKeyDropRawRequestFromClient, true)
|
|
ctx.SetValue(schemas.BifrostContextKeyDropRawResponseFromClient, true)
|
|
}
|
|
|
|
// Use a postHookRunner that converts the response to a BifrostError with raw data
|
|
bifrostErr := &schemas.BifrostError{
|
|
IsBifrostError: false,
|
|
StatusCode: schemas.Ptr(429),
|
|
Error: &schemas.ErrorField{Message: "rate limit exceeded"},
|
|
ExtraFields: schemas.BifrostErrorExtraFields{
|
|
RawRequest: rawReq,
|
|
RawResponse: rawResp,
|
|
},
|
|
}
|
|
|
|
errorRunner := func(ctx *schemas.BifrostContext, resp *schemas.BifrostResponse, err *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError) {
|
|
return nil, bifrostErr
|
|
}
|
|
|
|
responseChan := make(chan *schemas.BifrostStreamChunk, 1)
|
|
ProcessAndSendResponse(ctx, errorRunner, &schemas.BifrostResponse{
|
|
ChatResponse: &schemas.BifrostChatResponse{ID: "chatcmpl-001"},
|
|
}, responseChan, nil)
|
|
|
|
chunk := <-responseChan
|
|
if chunk.BifrostError == nil {
|
|
t.Fatal("expected non-nil BifrostError in stream chunk")
|
|
}
|
|
|
|
hasRawReq := chunk.BifrostError.ExtraFields.RawRequest != nil
|
|
hasRawResp := chunk.BifrostError.ExtraFields.RawResponse != nil
|
|
|
|
if tt.expectStripped {
|
|
if hasRawReq {
|
|
t.Error("expected RawRequest to be nil (stripped) in error chunk, but it was present")
|
|
}
|
|
if hasRawResp {
|
|
t.Error("expected RawResponse to be nil (stripped) in error chunk, but it was present")
|
|
}
|
|
// Critical: the original shared BifrostError must NOT have been mutated.
|
|
if bifrostErr.ExtraFields.RawRequest == nil {
|
|
t.Error("original BifrostError.ExtraFields.RawRequest was mutated (nil); shared object must be preserved")
|
|
}
|
|
if bifrostErr.ExtraFields.RawResponse == nil {
|
|
t.Error("original BifrostError.ExtraFields.RawResponse was mutated (nil); shared object must be preserved")
|
|
}
|
|
// The chunk must hold a copy, not the same pointer as the original.
|
|
if chunk.BifrostError == bifrostErr {
|
|
t.Error("chunk.BifrostError is the same pointer as the original; it must be a copy to avoid data races")
|
|
}
|
|
} else {
|
|
if !hasRawReq {
|
|
t.Error("expected RawRequest to be present in error chunk, but it was nil")
|
|
}
|
|
if !hasRawResp {
|
|
t.Error("expected RawResponse to be present in error chunk, but it was nil")
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestShouldSendBackRawRequest verifies that ShouldSendBackRawRequest correctly resolves
|
|
// whether providers should capture the raw request body. It covers:
|
|
// - Default (no context flags): returns the provider default
|
|
// - BifrostContextKeyCaptureRawRequest=true in context: always returns true
|
|
// - Logging-only mode: requestWorker sets BifrostContextKeyCaptureRawRequest=true,
|
|
// so the function sees a single flag (no second check needed).
|
|
func TestShouldSendBackRawRequest(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
contextSendBack bool
|
|
providerDefault bool
|
|
want bool
|
|
}{
|
|
{
|
|
name: "provider default false, no context flag",
|
|
want: false,
|
|
},
|
|
{
|
|
name: "provider default true, no context flag",
|
|
providerDefault: true,
|
|
want: true,
|
|
},
|
|
{
|
|
name: "context SendBack=true overrides provider default false",
|
|
contextSendBack: true,
|
|
want: true,
|
|
},
|
|
{
|
|
name: "context SendBack=true with provider default true",
|
|
contextSendBack: true,
|
|
providerDefault: true,
|
|
want: true,
|
|
},
|
|
{
|
|
// requestWorker sets BifrostContextKeyCaptureRawRequest=true in logging-only
|
|
// mode so a single flag covers both full send-back and logging-only cases.
|
|
name: "logging-only: context SendBack=true set by requestWorker",
|
|
contextSendBack: true,
|
|
providerDefault: false,
|
|
want: true,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
|
|
if tt.contextSendBack {
|
|
ctx.SetValue(schemas.BifrostContextKeyCaptureRawRequest, true)
|
|
}
|
|
|
|
got := ShouldSendBackRawRequest(ctx, tt.providerDefault)
|
|
if got != tt.want {
|
|
t.Errorf("ShouldSendBackRawRequest() = %v, want %v", got, tt.want)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestShouldSendBackRawResponse mirrors TestShouldSendBackRawRequest for the response side.
|
|
func TestShouldSendBackRawResponse(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
contextSendBack bool
|
|
providerDefault bool
|
|
want bool
|
|
}{
|
|
{
|
|
name: "provider default false, no context flag",
|
|
want: false,
|
|
},
|
|
{
|
|
name: "provider default true, no context flag",
|
|
providerDefault: true,
|
|
want: true,
|
|
},
|
|
{
|
|
name: "context SendBack=true overrides provider default false",
|
|
contextSendBack: true,
|
|
want: true,
|
|
},
|
|
{
|
|
name: "context SendBack=true with provider default true",
|
|
contextSendBack: true,
|
|
providerDefault: true,
|
|
want: true,
|
|
},
|
|
{
|
|
// requestWorker sets BifrostContextKeyCaptureRawResponse=true in logging-only
|
|
// mode so a single flag covers both full send-back and logging-only cases.
|
|
name: "logging-only: context SendBack=true set by requestWorker",
|
|
contextSendBack: true,
|
|
providerDefault: false,
|
|
want: true,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
|
|
if tt.contextSendBack {
|
|
ctx.SetValue(schemas.BifrostContextKeyCaptureRawResponse, true)
|
|
}
|
|
|
|
got := ShouldSendBackRawResponse(ctx, tt.providerDefault)
|
|
if got != tt.want {
|
|
t.Errorf("ShouldSendBackRawResponse() = %v, want %v", got, tt.want)
|
|
}
|
|
})
|
|
}
|
|
}
|