Files
bifrost/transports/bifrost-http/handlers/middlewares_test.go
Beyhan Oğur 880f412e2c first commit
2026-04-26 21:52:23 +03:00

2023 lines
60 KiB
Go

package handlers
import (
"bytes"
"compress/gzip"
"compress/zlib"
cryptoRand "crypto/rand"
"encoding/json"
"io"
"strings"
"testing"
"github.com/andybalholm/brotli"
"github.com/klauspost/compress/zstd"
"github.com/maximhq/bifrost/core/schemas"
"github.com/maximhq/bifrost/framework/configstore"
"github.com/maximhq/bifrost/transports/bifrost-http/lib"
"github.com/valyala/fasthttp"
)
// mockLogger is a mock implementation of schemas.Logger for testing.
// Every logging method is a no-op, so code under test can log freely
// without producing any output.
type mockLogger struct{}

// Debug discards the message.
func (*mockLogger) Debug(format string, args ...any) {}

// Info discards the message.
func (*mockLogger) Info(format string, args ...any) {}

// Warn discards the message.
func (*mockLogger) Warn(format string, args ...any) {}

// Error discards the message.
func (*mockLogger) Error(format string, args ...any) {}

// Fatal discards the message; it does not terminate the process.
func (*mockLogger) Fatal(format string, args ...any) {}

// SetLevel ignores the requested level.
func (*mockLogger) SetLevel(level schemas.LogLevel) {}

// SetOutputType ignores the requested output type.
func (*mockLogger) SetOutputType(outputType schemas.LoggerOutputType) {}

// LogHTTPRequest always returns schemas.NoopLogEvent.
func (*mockLogger) LogHTTPRequest(level schemas.LogLevel, msg string) schemas.LogEventBuilder {
	return schemas.NoopLogEvent
}
// TestCorsMiddleware_LocalhostOrigins verifies that loopback-style origins
// receive the full set of CORS response headers even though the configured
// allow-list is empty, and that the request is still forwarded to the next
// handler.
func TestCorsMiddleware_LocalhostOrigins(t *testing.T) {
	cfg := &lib.Config{
		ClientConfig: &configstore.ClientConfig{AllowedOrigins: []string{}},
	}
	SetLogger(&mockLogger{})

	// Response headers whose expected value does not depend on the origin.
	fixedHeaders := []struct{ name, want string }{
		{"Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, PATCH, OPTIONS, HEAD"},
		{"Access-Control-Allow-Headers", "Content-Type, Authorization, X-Requested-With, X-Stainless-Timeout, X-Api-Key, X-OpenAI-Agents-SDK"},
		{"Access-Control-Allow-Credentials", "true"},
		{"Access-Control-Max-Age", "86400"},
	}

	origins := []string{
		"http://localhost:3000",
		"https://localhost:3000",
		"http://127.0.0.1:8080",
		"http://0.0.0.0:5000",
		"https://127.0.0.1:3000",
	}
	for _, origin := range origins {
		t.Run(origin, func(t *testing.T) {
			reqCtx := &fasthttp.RequestCtx{}
			reqCtx.Request.Header.Set("Origin", origin)

			reached := false
			wrapped := CorsMiddleware(cfg)(func(*fasthttp.RequestCtx) { reached = true })
			wrapped(reqCtx)

			// The allowed origin must be echoed back verbatim.
			if got := string(reqCtx.Response.Header.Peek("Access-Control-Allow-Origin")); got != origin {
				t.Errorf("Expected Access-Control-Allow-Origin to be %s, got %s", origin, got)
			}
			for _, h := range fixedHeaders {
				if string(reqCtx.Response.Header.Peek(h.name)) != h.want {
					t.Errorf("%s header not set correctly", h.name)
				}
			}

			// A plain (non-OPTIONS) request must continue down the chain.
			if !reached {
				t.Error("Next handler was not called")
			}
		})
	}
}
// TestCorsMiddleware_ConfiguredOrigins verifies that an origin listed in
// AllowedOrigins is echoed in Access-Control-Allow-Origin and that the request
// is forwarded to the next handler.
func TestCorsMiddleware_ConfiguredOrigins(t *testing.T) {
	const origin = "https://example.com"

	cfg := &lib.Config{
		ClientConfig: &configstore.ClientConfig{AllowedOrigins: []string{origin}},
	}
	reqCtx := &fasthttp.RequestCtx{}
	reqCtx.Request.Header.Set("Origin", origin)

	forwarded := false
	wrapped := CorsMiddleware(cfg)(func(*fasthttp.RequestCtx) { forwarded = true })
	wrapped(reqCtx)

	// The configured origin must be reflected in the response.
	if got := string(reqCtx.Response.Header.Peek("Access-Control-Allow-Origin")); got != origin {
		t.Errorf("Expected Access-Control-Allow-Origin to be %s, got %s", origin, got)
	}
	// The request must still reach the wrapped handler.
	if !forwarded {
		t.Error("Next handler was not called")
	}
}
// TestCorsMiddleware_NonAllowedOrigins verifies that a request from an origin
// outside the allow-list gets no Access-Control-Allow-Origin header, while a
// non-OPTIONS request is still passed on to the next handler.
func TestCorsMiddleware_NonAllowedOrigins(t *testing.T) {
	cfg := &lib.Config{
		ClientConfig: &configstore.ClientConfig{
			AllowedOrigins: []string{"https://allowed.com"},
		},
	}

	reqCtx := &fasthttp.RequestCtx{}
	reqCtx.Request.Header.Set("Origin", "https://malicious.com")

	forwarded := false
	wrapped := CorsMiddleware(cfg)(func(*fasthttp.RequestCtx) { forwarded = true })
	wrapped(reqCtx)

	// No CORS grant may be emitted for an unknown origin.
	if allowOrigin := reqCtx.Response.Header.Peek("Access-Control-Allow-Origin"); len(allowOrigin) != 0 {
		t.Error("Access-Control-Allow-Origin header should not be set for non-allowed origin")
	}
	// Non-OPTIONS requests are still forwarded.
	if !forwarded {
		t.Error("Next handler was not called")
	}
}
// TestCorsMiddleware_PreflightAllowedOrigin verifies that an OPTIONS preflight
// from an allowed origin is answered with 200 and the matching
// Access-Control-Allow-Origin header, without invoking the next handler.
func TestCorsMiddleware_PreflightAllowedOrigin(t *testing.T) {
	const origin = "https://example.com"

	cfg := &lib.Config{
		ClientConfig: &configstore.ClientConfig{AllowedOrigins: []string{origin}},
	}
	reqCtx := &fasthttp.RequestCtx{}
	reqCtx.Request.Header.SetMethod("OPTIONS")
	reqCtx.Request.Header.Set("Origin", origin)

	forwarded := false
	wrapped := CorsMiddleware(cfg)(func(*fasthttp.RequestCtx) { forwarded = true })
	wrapped(reqCtx)

	if status := reqCtx.Response.StatusCode(); status != fasthttp.StatusOK {
		t.Errorf("Expected status code %d for allowed origin preflight, got %d", fasthttp.StatusOK, status)
	}
	if string(reqCtx.Response.Header.Peek("Access-Control-Allow-Origin")) != origin {
		t.Error("Access-Control-Allow-Origin header not set correctly for allowed origin preflight")
	}
	// Preflight requests are answered by the middleware itself.
	if forwarded {
		t.Error("Next handler should not be called for OPTIONS preflight requests")
	}
}
// TestCorsMiddleware_PreflightNonAllowedOrigin verifies that an OPTIONS
// preflight from an origin outside the allow-list is rejected with 403, carries
// no Access-Control-Allow-Origin header, and never reaches the next handler.
func TestCorsMiddleware_PreflightNonAllowedOrigin(t *testing.T) {
	cfg := &lib.Config{
		ClientConfig: &configstore.ClientConfig{
			AllowedOrigins: []string{"https://allowed.com"},
		},
	}
	reqCtx := &fasthttp.RequestCtx{}
	reqCtx.Request.Header.SetMethod("OPTIONS")
	reqCtx.Request.Header.Set("Origin", "https://malicious.com")

	forwarded := false
	wrapped := CorsMiddleware(cfg)(func(*fasthttp.RequestCtx) { forwarded = true })
	wrapped(reqCtx)

	if status := reqCtx.Response.StatusCode(); status != fasthttp.StatusForbidden {
		t.Errorf("Expected status code %d for non-allowed origin preflight, got %d", fasthttp.StatusForbidden, status)
	}
	if allowOrigin := reqCtx.Response.Header.Peek("Access-Control-Allow-Origin"); len(allowOrigin) != 0 {
		t.Error("Access-Control-Allow-Origin header should not be set for non-allowed origin preflight")
	}
	// Preflight requests are answered by the middleware itself.
	if forwarded {
		t.Error("Next handler should not be called for OPTIONS preflight requests")
	}
}
// TestCorsMiddleware_PreflightLocalhost verifies that an OPTIONS preflight from
// a localhost origin succeeds with 200 and the matching
// Access-Control-Allow-Origin header even with an empty allow-list, and that
// the next handler is not invoked.
func TestCorsMiddleware_PreflightLocalhost(t *testing.T) {
	const origin = "http://localhost:3000"

	cfg := &lib.Config{
		ClientConfig: &configstore.ClientConfig{AllowedOrigins: []string{}},
	}
	reqCtx := &fasthttp.RequestCtx{}
	reqCtx.Request.Header.SetMethod("OPTIONS")
	reqCtx.Request.Header.Set("Origin", origin)

	forwarded := false
	wrapped := CorsMiddleware(cfg)(func(*fasthttp.RequestCtx) { forwarded = true })
	wrapped(reqCtx)

	if status := reqCtx.Response.StatusCode(); status != fasthttp.StatusOK {
		t.Errorf("Expected status code %d for localhost preflight, got %d", fasthttp.StatusOK, status)
	}
	if string(reqCtx.Response.Header.Peek("Access-Control-Allow-Origin")) != origin {
		t.Error("Access-Control-Allow-Origin header not set correctly for localhost preflight")
	}
	// Preflight requests are answered by the middleware itself.
	if forwarded {
		t.Error("Next handler should not be called for OPTIONS preflight requests")
	}
}
// TestCorsMiddleware_NoOriginHeader verifies that a request without an Origin
// header gets no Access-Control-Allow-Origin header and is still forwarded to
// the next handler.
func TestCorsMiddleware_NoOriginHeader(t *testing.T) {
	cfg := &lib.Config{
		ClientConfig: &configstore.ClientConfig{AllowedOrigins: []string{}},
	}

	// The request deliberately carries no Origin header.
	reqCtx := &fasthttp.RequestCtx{}

	forwarded := false
	wrapped := CorsMiddleware(cfg)(func(*fasthttp.RequestCtx) { forwarded = true })
	wrapped(reqCtx)

	if allowOrigin := reqCtx.Response.Header.Peek("Access-Control-Allow-Origin"); len(allowOrigin) != 0 {
		t.Error("Access-Control-Allow-Origin header should not be set when no Origin header is present")
	}
	if !forwarded {
		t.Error("Next handler was not called")
	}
}
// TestChainMiddlewares_NoMiddlewares verifies that lib.ChainMiddlewares, given
// only a handler and no middlewares, returns something that still invokes that
// handler.
func TestChainMiddlewares_NoMiddlewares(t *testing.T) {
	invoked := false
	final := func(*fasthttp.RequestCtx) { invoked = true }

	chain := lib.ChainMiddlewares(final)
	chain(&fasthttp.RequestCtx{})

	if !invoked {
		t.Error("Handler was not called when no middlewares are present")
	}
}
// Testlib.ChainMiddlewares_SingleMiddleware tests chaining with a single middleware
func TestChainMiddlewares_SingleMiddleware(t *testing.T) {
ctx := &fasthttp.RequestCtx{}
middlewareCalled := false
handlerCalled := false
middleware := schemas.BifrostHTTPMiddleware(func(next fasthttp.RequestHandler) fasthttp.RequestHandler {
return func(ctx *fasthttp.RequestCtx) {
middlewareCalled = true
next(ctx)
}
})
handler := func(ctx *fasthttp.RequestCtx) {
handlerCalled = true
}
chained := lib.ChainMiddlewares(handler, middleware)
chained(ctx)
if !middlewareCalled {
t.Error("Middleware was not called")
}
if !handlerCalled {
t.Error("Handler was not called")
}
}
// Testlib.ChainMiddlewares_MultipleMiddlewares tests chaining with multiple middlewares
func TestChainMiddlewares_MultipleMiddlewares(t *testing.T) {
ctx := &fasthttp.RequestCtx{}
executionOrder := []int{}
middleware1 := schemas.BifrostHTTPMiddleware(func(next fasthttp.RequestHandler) fasthttp.RequestHandler {
return func(ctx *fasthttp.RequestCtx) {
executionOrder = append(executionOrder, 1)
next(ctx)
}
})
middleware2 := schemas.BifrostHTTPMiddleware(func(next fasthttp.RequestHandler) fasthttp.RequestHandler {
return func(ctx *fasthttp.RequestCtx) {
executionOrder = append(executionOrder, 2)
next(ctx)
}
})
middleware3 := schemas.BifrostHTTPMiddleware(func(next fasthttp.RequestHandler) fasthttp.RequestHandler {
return func(ctx *fasthttp.RequestCtx) {
executionOrder = append(executionOrder, 3)
next(ctx)
}
})
handler := func(ctx *fasthttp.RequestCtx) {
executionOrder = append(executionOrder, 4)
}
chained := lib.ChainMiddlewares(handler, middleware1, middleware2, middleware3)
chained(ctx)
// Check execution order: middlewares should execute in order, then handler
expectedOrder := []int{1, 2, 3, 4}
if len(executionOrder) != len(expectedOrder) {
t.Errorf("Expected %d function calls, got %d", len(expectedOrder), len(executionOrder))
}
for i, expected := range expectedOrder {
if i >= len(executionOrder) || executionOrder[i] != expected {
t.Errorf("Expected execution order %v, got %v", expectedOrder, executionOrder)
break
}
}
}
// TestChainMiddlewares_MiddlewareCanModifyContext tests that a value stored on
// the request context by a middleware is visible to the final handler.
//
// The test also asserts that the handler actually ran: all value checks live
// inside the handler, so without that guard a chain that never reached the
// handler would pass silently.
func TestChainMiddlewares_MiddlewareCanModifyContext(t *testing.T) {
	ctx := &fasthttp.RequestCtx{}
	middleware := schemas.BifrostHTTPMiddleware(func(next fasthttp.RequestHandler) fasthttp.RequestHandler {
		return func(ctx *fasthttp.RequestCtx) {
			ctx.SetUserValue("test-key", "test-value")
			next(ctx)
		}
	})
	handlerCalled := false
	handler := func(ctx *fasthttp.RequestCtx) {
		handlerCalled = true
		value := ctx.UserValue("test-key")
		if value == nil {
			t.Error("Handler did not receive modified context from middleware")
			return
		}
		// Use the comma-ok form so an unexpected type is reported as a test
		// failure instead of panicking.
		got, ok := value.(string)
		if !ok {
			t.Errorf("Expected user value to be a string, got %T", value)
			return
		}
		if got != "test-value" {
			t.Errorf("Expected user value to be 'test-value', got '%s'", got)
		}
	}
	chained := lib.ChainMiddlewares(handler, middleware)
	chained(ctx)
	if !handlerCalled {
		t.Error("Handler was not called")
	}
}
// TestIsInferenceWSEndpoint verifies that isInferenceWSEndpoint accepts the
// /responses and /realtime path variants listed below and rejects the dashboard
// websocket path and a non-websocket OpenAI path.
func TestIsInferenceWSEndpoint(t *testing.T) {
	accepted := []string{
		"/v1/responses",
		"/v1/realtime",
		"/responses",
		"/realtime",
		"/openai/v1/responses",
		"/openai/responses",
		"/openai/openai/responses",
		"/openai/v1/realtime",
		"/openai/realtime",
		"/openai/openai/realtime",
	}
	for _, p := range accepted {
		if isInferenceWSEndpoint(p) {
			continue
		}
		t.Fatalf("expected inference websocket path %s to be recognized", p)
	}

	rejected := []struct{ path, failure string }{
		{"/api/ws", "dashboard websocket path should not be treated as inference websocket"},
		{"/openai/chat/completions", "non-websocket OpenAI path should not be treated as inference websocket"},
	}
	for _, r := range rejected {
		if isInferenceWSEndpoint(r.path) {
			t.Fatal(r.failure)
		}
	}
}
// TestIsRealtimeTransportEndpoint verifies that isRealtimeTransportEndpoint
// accepts the /realtime and /realtime/calls path variants and rejects the
// client_secrets, sessions, and chat-completions paths.
func TestIsRealtimeTransportEndpoint(t *testing.T) {
	transport := []string{
		"/v1/realtime",
		"/realtime",
		"/openai/realtime",
		"/openai/v1/realtime",
		"/openai/openai/realtime",
		"/v1/realtime/calls",
		"/realtime/calls",
		"/openai/realtime/calls",
		"/openai/v1/realtime/calls",
		"/openai/openai/realtime/calls",
	}
	for _, p := range transport {
		if isRealtimeTransportEndpoint(p) {
			continue
		}
		t.Fatalf("expected realtime transport path %s to be recognized", p)
	}

	other := []string{
		"/v1/realtime/client_secrets",
		"/v1/realtime/sessions",
		"/openai/v1/realtime/client_secrets",
		"/openai/v1/realtime/sessions",
		"/v1/chat/completions",
	}
	for _, p := range other {
		if !isRealtimeTransportEndpoint(p) {
			continue
		}
		t.Fatalf("did not expect non-transport path %s to be recognized", p)
	}
}
// Testlib.ChainMiddlewares_ShortCircuit tests that when a middleware writes a response
// and does not call next, subsequent middlewares and handler do not execute.
func TestChainMiddlewares_ShortCircuit(t *testing.T) {
ctx := &fasthttp.RequestCtx{}
executionOrder := []int{}
// First middleware - writes response and short-circuits by not calling next
middleware1 := schemas.BifrostHTTPMiddleware(func(next fasthttp.RequestHandler) fasthttp.RequestHandler {
return func(ctx *fasthttp.RequestCtx) {
executionOrder = append(executionOrder, 1)
ctx.SetStatusCode(fasthttp.StatusUnauthorized)
ctx.SetBodyString("Unauthorized")
// Not calling next(ctx) to short-circuit
}
})
// Second middleware - should NOT execute when middleware1 short-circuits
middleware2 := schemas.BifrostHTTPMiddleware(func(next fasthttp.RequestHandler) fasthttp.RequestHandler {
return func(ctx *fasthttp.RequestCtx) {
executionOrder = append(executionOrder, 2)
next(ctx)
}
})
// Third middleware - should NOT execute when middleware1 short-circuits
middleware3 := schemas.BifrostHTTPMiddleware(func(next fasthttp.RequestHandler) fasthttp.RequestHandler {
return func(ctx *fasthttp.RequestCtx) {
executionOrder = append(executionOrder, 3)
next(ctx)
}
})
// Handler - should NOT execute when middleware1 short-circuits
handler := func(ctx *fasthttp.RequestCtx) {
executionOrder = append(executionOrder, 4)
ctx.SetStatusCode(fasthttp.StatusOK)
ctx.SetBodyString("Success")
}
chained := lib.ChainMiddlewares(handler, middleware1, middleware2, middleware3)
chained(ctx)
// Verify only middleware1 executed
expectedOrder := []int{1}
if len(executionOrder) != len(expectedOrder) {
t.Errorf("Expected %d function calls, got %d", len(expectedOrder), len(executionOrder))
}
for i, expected := range expectedOrder {
if i >= len(executionOrder) || executionOrder[i] != expected {
t.Errorf("Expected execution order %v, got %v", expectedOrder, executionOrder)
break
}
}
// The middleware's response should be preserved (not overwritten)
if ctx.Response.StatusCode() != fasthttp.StatusUnauthorized {
t.Errorf("Expected status code %d, got %d", fasthttp.StatusUnauthorized, ctx.Response.StatusCode())
}
if string(ctx.Response.Body()) != "Unauthorized" {
t.Errorf("Expected body 'Unauthorized', got '%s'", string(ctx.Response.Body()))
}
}
// Testlib.ChainMiddlewares_ShortCircuitMiddlePosition tests that middleware in the middle
// can short-circuit, preventing later middlewares and handler from executing.
func TestChainMiddlewares_ShortCircuitMiddlePosition(t *testing.T) {
ctx := &fasthttp.RequestCtx{}
executionOrder := []int{}
// First middleware - executes and calls next
middleware1 := schemas.BifrostHTTPMiddleware(func(next fasthttp.RequestHandler) fasthttp.RequestHandler {
return func(ctx *fasthttp.RequestCtx) {
executionOrder = append(executionOrder, 1)
next(ctx)
}
})
// Second middleware - writes response and short-circuits
middleware2 := schemas.BifrostHTTPMiddleware(func(next fasthttp.RequestHandler) fasthttp.RequestHandler {
return func(ctx *fasthttp.RequestCtx) {
executionOrder = append(executionOrder, 2)
ctx.SetStatusCode(fasthttp.StatusUnauthorized)
ctx.SetBodyString("Unauthorized")
// Not calling next(ctx) to short-circuit
}
})
// Third middleware - should NOT execute
middleware3 := schemas.BifrostHTTPMiddleware(func(next fasthttp.RequestHandler) fasthttp.RequestHandler {
return func(ctx *fasthttp.RequestCtx) {
executionOrder = append(executionOrder, 3)
next(ctx)
}
})
// Handler - should NOT execute
handler := func(ctx *fasthttp.RequestCtx) {
executionOrder = append(executionOrder, 4)
ctx.SetStatusCode(fasthttp.StatusOK)
ctx.SetBodyString("Success")
}
chained := lib.ChainMiddlewares(handler, middleware1, middleware2, middleware3)
chained(ctx)
// Verify only middleware1 and middleware2 executed
expectedOrder := []int{1, 2}
if len(executionOrder) != len(expectedOrder) {
t.Errorf("Expected %d function calls, got %d", len(expectedOrder), len(executionOrder))
}
for i, expected := range expectedOrder {
if i >= len(executionOrder) || executionOrder[i] != expected {
t.Errorf("Expected execution order %v, got %v", expectedOrder, executionOrder)
break
}
}
// The middleware2's response should be preserved
if ctx.Response.StatusCode() != fasthttp.StatusUnauthorized {
t.Errorf("Expected status code %d, got %d", fasthttp.StatusUnauthorized, ctx.Response.StatusCode())
}
if string(ctx.Response.Body()) != "Unauthorized" {
t.Errorf("Expected body 'Unauthorized', got '%s'", string(ctx.Response.Body()))
}
}
// TestAuthMiddleware_NilAuthConfig verifies that the API middleware of an
// AuthMiddleware that never received an auth config lets requests through.
func TestAuthMiddleware_NilAuthConfig(t *testing.T) {
	SetLogger(&mockLogger{})

	// UpdateAuthConfig is never called here, so no auth config is set
	// (simulates app start with no auth config).
	auth := &AuthMiddleware{}

	reqCtx := &fasthttp.RequestCtx{}
	reqCtx.Request.SetRequestURI("/api/some-endpoint")

	allowed := false
	wrapped := auth.APIMiddleware()(func(*fasthttp.RequestCtx) { allowed = true })
	wrapped(reqCtx)

	if !allowed {
		t.Error("Next handler should be called when auth config is nil")
	}
}
// TestAuthMiddleware_DisabledAuthConfig verifies that the API middleware lets
// requests through when an auth config is present but IsEnabled is false.
func TestAuthMiddleware_DisabledAuthConfig(t *testing.T) {
	SetLogger(&mockLogger{})

	auth := &AuthMiddleware{}
	disabled := &configstore.AuthConfig{
		AdminUserName: schemas.NewEnvVar("admin"),
		AdminPassword: schemas.NewEnvVar("password"),
		IsEnabled:     false,
	}
	auth.UpdateAuthConfig(disabled)

	reqCtx := &fasthttp.RequestCtx{}
	reqCtx.Request.SetRequestURI("/api/some-endpoint")

	allowed := false
	wrapped := auth.APIMiddleware()(func(*fasthttp.RequestCtx) { allowed = true })
	wrapped(reqCtx)

	if !allowed {
		t.Error("Next handler should be called when auth is disabled")
	}
}
// TestAuthMiddleware_EnabledAuthConfig_NoAuth verifies that, with auth enabled,
// a request carrying no credentials is stopped by the API middleware and
// answered with 401 Unauthorized.
func TestAuthMiddleware_EnabledAuthConfig_NoAuth(t *testing.T) {
	SetLogger(&mockLogger{})

	auth := &AuthMiddleware{}
	enabled := &configstore.AuthConfig{
		AdminUserName: schemas.NewEnvVar("admin"),
		AdminPassword: schemas.NewEnvVar("hashedpassword"),
		IsEnabled:     true,
	}
	auth.UpdateAuthConfig(enabled)

	// No auth-related headers are set on the request.
	reqCtx := &fasthttp.RequestCtx{}
	reqCtx.Request.SetRequestURI("/api/some-endpoint")

	allowed := false
	wrapped := auth.APIMiddleware()(func(*fasthttp.RequestCtx) { allowed = true })
	wrapped(reqCtx)

	if allowed {
		t.Error("Next handler should NOT be called when auth is enabled and no credentials provided")
	}
	if status := reqCtx.Response.StatusCode(); status != fasthttp.StatusUnauthorized {
		t.Errorf("Expected status code %d, got %d", fasthttp.StatusUnauthorized, status)
	}
}
// TestAuthMiddleware_WhitelistedRoutes verifies that, with auth enabled, the
// listed routes pass through the API middleware without credentials.
func TestAuthMiddleware_WhitelistedRoutes(t *testing.T) {
	SetLogger(&mockLogger{})

	auth := &AuthMiddleware{}
	auth.UpdateAuthConfig(&configstore.AuthConfig{
		AdminUserName: schemas.NewEnvVar("admin"),
		AdminPassword: schemas.NewEnvVar("hashedpassword"),
		IsEnabled:     true,
	})

	open := []string{
		"/api/session/is-auth-enabled",
		"/api/session/login",
		"/api/oauth/callback",
		"/health",
	}
	for _, route := range open {
		t.Run(route, func(t *testing.T) {
			reqCtx := &fasthttp.RequestCtx{}
			reqCtx.Request.SetRequestURI(route)

			allowed := false
			wrapped := auth.APIMiddleware()(func(*fasthttp.RequestCtx) { allowed = true })
			wrapped(reqCtx)

			if !allowed {
				t.Errorf("Next handler should be called for whitelisted route %s", route)
			}
		})
	}
}
// TestAuthMiddleware_InferenceMiddleware_RealtimeTransportBypassesAuth verifies
// that, with auth enabled, the realtime transport routes (including
// /realtime/calls with a query string) pass through the inference middleware
// without credentials.
func TestAuthMiddleware_InferenceMiddleware_RealtimeTransportBypassesAuth(t *testing.T) {
	SetLogger(&mockLogger{})

	auth := &AuthMiddleware{}
	auth.UpdateAuthConfig(&configstore.AuthConfig{
		AdminUserName: schemas.NewEnvVar("admin"),
		AdminPassword: schemas.NewEnvVar("hashedpassword"),
		IsEnabled:     true,
	})

	transportRoutes := []string{
		"/v1/realtime",
		"/openai/v1/realtime",
		"/v1/realtime/calls?model=gpt-realtime",
		"/openai/v1/realtime/calls?model=gpt-realtime",
	}
	for _, route := range transportRoutes {
		t.Run(route, func(t *testing.T) {
			reqCtx := &fasthttp.RequestCtx{}
			reqCtx.Request.SetRequestURI(route)

			allowed := false
			wrapped := auth.InferenceMiddleware()(func(*fasthttp.RequestCtx) { allowed = true })
			wrapped(reqCtx)

			if !allowed {
				t.Fatalf("expected realtime transport route %s to bypass auth", route)
			}
		})
	}
}
// TestAuthMiddleware_InferenceMiddleware_RealtimeMintingStillRequiresAuth
// verifies that, with auth enabled, the realtime client_secrets and sessions
// routes are blocked by the inference middleware with 401 when no credentials
// are supplied.
func TestAuthMiddleware_InferenceMiddleware_RealtimeMintingStillRequiresAuth(t *testing.T) {
	SetLogger(&mockLogger{})

	auth := &AuthMiddleware{}
	auth.UpdateAuthConfig(&configstore.AuthConfig{
		AdminUserName: schemas.NewEnvVar("admin"),
		AdminPassword: schemas.NewEnvVar("hashedpassword"),
		IsEnabled:     true,
	})

	mintingRoutes := []string{
		"/v1/realtime/client_secrets",
		"/v1/realtime/sessions",
		"/openai/v1/realtime/client_secrets",
		"/openai/v1/realtime/sessions",
	}
	for _, route := range mintingRoutes {
		t.Run(route, func(t *testing.T) {
			reqCtx := &fasthttp.RequestCtx{}
			reqCtx.Request.SetRequestURI(route)

			allowed := false
			wrapped := auth.InferenceMiddleware()(func(*fasthttp.RequestCtx) { allowed = true })
			wrapped(reqCtx)

			if allowed {
				t.Fatalf("expected realtime minting route %s to still require auth", route)
			}
			if status := reqCtx.Response.StatusCode(); status != fasthttp.StatusUnauthorized {
				t.Fatalf("expected %d for route %s, got %d", fasthttp.StatusUnauthorized, route, status)
			}
		})
	}
}
// TestAuthMiddleware_UpdateAuthConfig_NilToEnabled verifies that an already
// built API handler picks up a config change: the first request passes with no
// auth config set, and after UpdateAuthConfig enables auth the same handler
// rejects an unauthenticated request with 401.
func TestAuthMiddleware_UpdateAuthConfig_NilToEnabled(t *testing.T) {
	SetLogger(&mockLogger{})
	auth := &AuthMiddleware{}

	// newRequest builds a fresh, credential-free request for the protected route.
	newRequest := func() *fasthttp.RequestCtx {
		c := &fasthttp.RequestCtx{}
		c.Request.SetRequestURI("/api/some-endpoint")
		return c
	}

	// The handler is built once and reused across the config change.
	allowed := false
	wrapped := auth.APIMiddleware()(func(*fasthttp.RequestCtx) { allowed = true })

	// Phase 1: no auth config has been set, so the request goes through.
	wrapped(newRequest())
	if !allowed {
		t.Error("First request should pass when auth config is nil")
	}

	// Phase 2: enable auth, then retry without credentials.
	auth.UpdateAuthConfig(&configstore.AuthConfig{
		AdminUserName: schemas.NewEnvVar("admin"),
		AdminPassword: schemas.NewEnvVar("hashedpassword"),
		IsEnabled:     true,
	})
	second := newRequest()
	allowed = false
	wrapped(second)
	if allowed {
		t.Error("Second request should be blocked after auth is enabled")
	}
	if status := second.Response.StatusCode(); status != fasthttp.StatusUnauthorized {
		t.Errorf("Expected status code %d, got %d", fasthttp.StatusUnauthorized, status)
	}
}
// TestAuthMiddleware_UpdateAuthConfig_EnabledToDisabled tests disabling auth
// after it was enabled, using a single API handler built before the change.
//
// The first, unauthenticated request must be blocked with 401 Unauthorized.
// After UpdateAuthConfig sets IsEnabled to false, the same handler must let a
// second request through.
func TestAuthMiddleware_UpdateAuthConfig_EnabledToDisabled(t *testing.T) {
	SetLogger(&mockLogger{})
	am := &AuthMiddleware{}
	// Start with auth enabled
	am.UpdateAuthConfig(&configstore.AuthConfig{
		AdminUserName: schemas.NewEnvVar("admin"),
		AdminPassword: schemas.NewEnvVar("hashedpassword"),
		IsEnabled:     true,
	})
	ctx := &fasthttp.RequestCtx{}
	ctx.Request.SetRequestURI("/api/some-endpoint")
	// First request should be blocked (auth enabled, no credentials)
	nextCalled := false
	next := func(ctx *fasthttp.RequestCtx) {
		nextCalled = true
	}
	middleware := am.APIMiddleware()
	handler := middleware(next)
	handler(ctx)
	if nextCalled {
		t.Error("First request should be blocked when auth is enabled")
	}
	// Assert the rejection status as well, matching the other auth-enabled
	// tests in this file, so a middleware that merely drops the request
	// without answering 401 is caught here too.
	if ctx.Response.StatusCode() != fasthttp.StatusUnauthorized {
		t.Errorf("Expected status code %d, got %d", fasthttp.StatusUnauthorized, ctx.Response.StatusCode())
	}
	// Now disable auth
	am.UpdateAuthConfig(&configstore.AuthConfig{
		AdminUserName: schemas.NewEnvVar("admin"),
		AdminPassword: schemas.NewEnvVar("hashedpassword"),
		IsEnabled:     false,
	})
	// Second request should pass (auth disabled)
	ctx2 := &fasthttp.RequestCtx{}
	ctx2.Request.SetRequestURI("/api/some-endpoint")
	nextCalled = false
	handler(ctx2)
	if !nextCalled {
		t.Error("Second request should pass after auth is disabled")
	}
}
// TestFasthttpToHTTPRequest verifies that fasthttpToHTTPRequest copies the
// method, path (without the query string), headers, decoded query parameters
// and body of a fasthttp request into a pooled schemas HTTP request, and that
// the copied body does not alias the original request body.
func TestFasthttpToHTTPRequest(t *testing.T) {
	const payload = `{"key": "value", "number": 42, "nested": {"bool": true}}`

	type pair struct{ key, want string }
	wantHeaders := []pair{
		{"Content-Type", "application/json"},
		{"Authorization", "Bearer token123"},
		{"X-Request-Id", "12345"},
		{"X-Custom-Header", "value-with-dashes"},
	}
	wantQuery := []pair{
		{"limit", "100"},                      // integer
		{"offset", "50"},                      // integer
		{"min_cost", "12.50"},                 // float
		{"max_latency", "1500.75"},            // float
		{"missing_cost_only", "true"},         // boolean
		{"start_time", "2023-01-15T10:30:00Z"}, // timestamp
		{"content_search", "test query"},      // '+' decoded to a space
		{"special", "+&=?"},                   // percent-escapes decoded
	}

	// Build the incoming request. The query string mixes integers, floats,
	// booleans, timestamps and strings with special characters.
	reqCtx := &fasthttp.RequestCtx{}
	reqCtx.Request.Header.SetMethod("POST")
	reqCtx.Request.SetRequestURI("/api/v1/test?limit=100&offset=50&min_cost=12.50&max_latency=1500.75&missing_cost_only=true&start_time=2023-01-15T10:30:00Z&content_search=test+query&special=%2B%26%3D%3F")
	for _, h := range wantHeaders {
		reqCtx.Request.Header.Set(h.key, h.want)
	}
	reqCtx.Request.SetBodyString(payload)

	// Convert into a pooled request and hand it back when the test ends.
	converted := schemas.AcquireHTTPRequest()
	defer schemas.ReleaseHTTPRequest(converted)
	fasthttpToHTTPRequest(reqCtx, converted)

	if converted.Method != "POST" {
		t.Errorf("Expected Method to be 'POST', got '%s'", converted.Method)
	}
	// The path must not include the query string.
	if converted.Path != "/api/v1/test" {
		t.Errorf("Expected Path to be '/api/v1/test', got '%s'", converted.Path)
	}

	for _, h := range wantHeaders {
		switch got, ok := converted.Headers[h.key]; {
		case !ok:
			t.Errorf("Expected header '%s' to exist", h.key)
		case got != h.want:
			t.Errorf("Expected header '%s' to be '%s', got '%s'", h.key, h.want, got)
		}
	}
	for _, q := range wantQuery {
		switch got, ok := converted.Query[q.key]; {
		case !ok:
			t.Errorf("Expected query param '%s' to exist", q.key)
		case got != q.want:
			t.Errorf("Expected query param '%s' to be '%s', got '%s'", q.key, q.want, got)
		}
	}

	if body := string(converted.Body); body != payload {
		t.Errorf("Expected Body to be '%s', got '%s'", payload, body)
	}

	// Mutating the converted body must not leak into the source request,
	// i.e. the body has to be a copy rather than a shared slice.
	source := reqCtx.Request.Body()
	if len(converted.Body) == 0 || len(source) == 0 {
		return
	}
	converted.Body[0] = 'X'
	if source[0] == 'X' {
		t.Error("Body should be a copy, not a reference to the original")
	}
}
// TestCorsMiddleware_DefaultHeaders verifies that, with no custom allowed
// headers configured, Access-Control-Allow-Headers carries exactly the default
// header list and the request is forwarded to the next handler.
func TestCorsMiddleware_DefaultHeaders(t *testing.T) {
	SetLogger(&mockLogger{})

	const origin = "https://example.com"
	const wantAllowHeaders = "Content-Type, Authorization, X-Requested-With, X-Stainless-Timeout, X-Api-Key, X-OpenAI-Agents-SDK"

	cfg := &lib.Config{
		ClientConfig: &configstore.ClientConfig{
			AllowedOrigins: []string{origin},
			AllowedHeaders: []string{}, // no custom headers
		},
	}
	reqCtx := &fasthttp.RequestCtx{}
	reqCtx.Request.Header.Set("Origin", origin)

	forwarded := false
	wrapped := CorsMiddleware(cfg)(func(*fasthttp.RequestCtx) { forwarded = true })
	wrapped(reqCtx)

	if got := string(reqCtx.Response.Header.Peek("Access-Control-Allow-Headers")); got != wantAllowHeaders {
		t.Errorf("Expected Access-Control-Allow-Headers to be %s, got %s", wantAllowHeaders, got)
	}
	if !forwarded {
		t.Error("Next handler was not called")
	}
}
// TestCorsMiddleware_WildcardHeaders_NonCredentialed verifies that when both
// AllowedOrigins and AllowedHeaders are the wildcard "*", the middleware answers
// with Access-Control-Allow-Headers set to the literal "*" and forwards the
// request to the next handler.
func TestCorsMiddleware_WildcardHeaders_NonCredentialed(t *testing.T) {
	SetLogger(&mockLogger{})

	cfg := &lib.Config{
		ClientConfig: &configstore.ClientConfig{
			AllowedOrigins: []string{"*"},
			AllowedHeaders: []string{"*"},
		},
	}
	reqCtx := &fasthttp.RequestCtx{}
	reqCtx.Request.Header.Set("Origin", "https://example.com")

	forwarded := false
	wrapped := CorsMiddleware(cfg)(func(*fasthttp.RequestCtx) { forwarded = true })
	wrapped(reqCtx)

	// Without credentials the wildcard is valid per spec.
	if got := string(reqCtx.Response.Header.Peek("Access-Control-Allow-Headers")); got != "*" {
		t.Errorf("Expected Access-Control-Allow-Headers to be *, got %s", got)
	}
	if !forwarded {
		t.Error("Next handler was not called")
	}
}
// TestCorsMiddleware_WildcardHeaders_CredentialedPreflight verifies that, for
// an OPTIONS preflight from an explicitly allowed origin with wildcard
// AllowedHeaders, the middleware echoes Access-Control-Request-Headers instead
// of sending the literal "*" (which browsers do not treat as a wildcard when
// credentials are present), and also sets Access-Control-Allow-Credentials.
func TestCorsMiddleware_WildcardHeaders_CredentialedPreflight(t *testing.T) {
	SetLogger(&mockLogger{})

	const origin = "https://example.com"
	const requested = "Authorization, X-Custom-Header"

	cfg := &lib.Config{
		ClientConfig: &configstore.ClientConfig{
			AllowedOrigins: []string{origin},
			AllowedHeaders: []string{"*"},
		},
	}
	reqCtx := &fasthttp.RequestCtx{}
	reqCtx.Request.Header.SetMethod("OPTIONS")
	reqCtx.Request.Header.Set("Origin", origin)
	reqCtx.Request.Header.Set("Access-Control-Request-Headers", requested)

	// Reaching the next handler is itself a failure for a preflight.
	wrapped := CorsMiddleware(cfg)(func(*fasthttp.RequestCtx) {
		t.Error("Next handler should not be called for preflight")
	})
	wrapped(reqCtx)

	// The requested headers must be reflected rather than answered with "*".
	if got := string(reqCtx.Response.Header.Peek("Access-Control-Allow-Headers")); got != requested {
		t.Errorf("Expected Access-Control-Allow-Headers to reflect requested headers, got %s", got)
	}
	if got := string(reqCtx.Response.Header.Peek("Access-Control-Allow-Credentials")); got != "true" {
		t.Errorf("Expected Access-Control-Allow-Credentials to be true, got %s", got)
	}
}
// TestCorsMiddleware_WildcardHeaders_CredentialedNonPreflight verifies that,
// for a non-preflight request (no Access-Control-Request-Headers) from an
// explicitly allowed origin with wildcard AllowedHeaders, the response lists
// the default allowed headers rather than the literal "*", and the request is
// forwarded to the next handler.
func TestCorsMiddleware_WildcardHeaders_CredentialedNonPreflight(t *testing.T) {
	SetLogger(&mockLogger{})

	const origin = "https://example.com"
	cfg := &lib.Config{
		ClientConfig: &configstore.ClientConfig{
			AllowedOrigins: []string{origin},
			AllowedHeaders: []string{"*"},
		},
	}
	reqCtx := &fasthttp.RequestCtx{}
	reqCtx.Request.Header.Set("Origin", origin)

	forwarded := false
	wrapped := CorsMiddleware(cfg)(func(*fasthttp.RequestCtx) { forwarded = true })
	wrapped(reqCtx)

	// Each of these default headers must appear in the response header value.
	allowHeaders := string(reqCtx.Response.Header.Peek("Access-Control-Allow-Headers"))
	mustContain := []string{
		"Content-Type",
		"Authorization",
		"X-Requested-With",
		"X-Stainless-Timeout",
		"X-Api-Key",
	}
	for _, name := range mustContain {
		if containsHeader(allowHeaders, name) {
			continue
		}
		t.Errorf("Expected Access-Control-Allow-Headers to contain %s, got %s", name, allowHeaders)
	}
	if allowHeaders == "*" {
		t.Error("Expected Access-Control-Allow-Headers to NOT be * for credentialed requests")
	}
	if !forwarded {
		t.Error("Next handler was not called")
	}
}
// TestCorsMiddleware_CustomHeaders verifies that configured custom allowed headers show up in
// Access-Control-Allow-Headers together with the default ones.
func TestCorsMiddleware_CustomHeaders(t *testing.T) {
	SetLogger(&mockLogger{})

	cfg := &lib.Config{
		ClientConfig: &configstore.ClientConfig{
			AllowedOrigins: []string{"https://example.com"},
			AllowedHeaders: []string{"X-Custom-Header", "X-Another-Header"},
		},
	}

	reqCtx := &fasthttp.RequestCtx{}
	reqCtx.Request.Header.Set("Origin", "https://example.com")

	var reached bool
	downstream := func(*fasthttp.RequestCtx) { reached = true }
	CorsMiddleware(cfg)(downstream)(reqCtx)

	// Defaults first, then the two custom headers from the config.
	wanted := []string{
		"Content-Type",
		"Authorization",
		"X-Requested-With",
		"X-Stainless-Timeout",
		"X-Custom-Header",
		"X-Another-Header",
	}
	allowHeaders := string(reqCtx.Response.Header.Peek("Access-Control-Allow-Headers"))
	for i := range wanted {
		if containsHeader(allowHeaders, wanted[i]) {
			continue
		}
		t.Errorf("Expected Access-Control-Allow-Headers to contain %s, got %s", wanted[i], allowHeaders)
	}
	if !reached {
		t.Error("Next handler was not called")
	}
}
// TestCorsMiddleware_DuplicateHeaders verifies that a configured header which is already one
// of the defaults is emitted only once in Access-Control-Allow-Headers.
func TestCorsMiddleware_DuplicateHeaders(t *testing.T) {
	SetLogger(&mockLogger{})

	cfg := &lib.Config{
		ClientConfig: &configstore.ClientConfig{
			AllowedOrigins: []string{"https://example.com"},
			// "Content-Type" overlaps with the defaults on purpose.
			AllowedHeaders: []string{"Content-Type", "X-Custom-Header"},
		},
	}

	reqCtx := &fasthttp.RequestCtx{}
	reqCtx.Request.Header.Set("Origin", "https://example.com")

	var reached bool
	downstream := func(*fasthttp.RequestCtx) { reached = true }
	CorsMiddleware(cfg)(downstream)(reqCtx)

	allowHeaders := string(reqCtx.Response.Header.Peek("Access-Control-Allow-Headers"))

	// The overlapping header must appear exactly once.
	if n := countHeaderOccurrences(allowHeaders, "Content-Type"); n != 1 {
		t.Errorf("Expected Content-Type to appear once, but appeared %d times in: %s", n, allowHeaders)
	}
	// The non-overlapping custom header must still be present.
	if !containsHeader(allowHeaders, "X-Custom-Header") {
		t.Errorf("Expected Access-Control-Allow-Headers to contain X-Custom-Header, got %s", allowHeaders)
	}
	if !reached {
		t.Error("Next handler was not called")
	}
}
// TestCorsMiddleware_CustomHeadersWithLocalhost verifies that custom allowed headers are also
// sent for a localhost origin when the allowed-origins list is empty.
func TestCorsMiddleware_CustomHeadersWithLocalhost(t *testing.T) {
	SetLogger(&mockLogger{})

	cfg := &lib.Config{
		ClientConfig: &configstore.ClientConfig{
			AllowedOrigins: []string{},
			AllowedHeaders: []string{"X-Development-Header"},
		},
	}

	reqCtx := &fasthttp.RequestCtx{}
	reqCtx.Request.Header.Set("Origin", "http://localhost:3000")

	var reached bool
	downstream := func(*fasthttp.RequestCtx) { reached = true }
	CorsMiddleware(cfg)(downstream)(reqCtx)

	// The custom header must be listed for the localhost origin.
	allowHeaders := string(reqCtx.Response.Header.Peek("Access-Control-Allow-Headers"))
	if !containsHeader(allowHeaders, "X-Development-Header") {
		t.Errorf("Expected Access-Control-Allow-Headers to contain X-Development-Header for localhost, got %s", allowHeaders)
	}
	if !reached {
		t.Error("Next handler was not called")
	}
}
// TestCorsMiddleware_CustomHeadersNotSetForNonAllowedOrigin verifies that no
// Access-Control-Allow-Headers value (custom or otherwise) is emitted for an origin that is
// not in the allowed list, while the request is still passed to the next handler.
func TestCorsMiddleware_CustomHeadersNotSetForNonAllowedOrigin(t *testing.T) {
	SetLogger(&mockLogger{})

	cfg := &lib.Config{
		ClientConfig: &configstore.ClientConfig{
			AllowedOrigins: []string{"https://allowed.com"},
			AllowedHeaders: []string{"X-Custom-Header"},
		},
	}

	reqCtx := &fasthttp.RequestCtx{}
	reqCtx.Request.Header.Set("Origin", "https://malicious.com")

	var reached bool
	downstream := func(*fasthttp.RequestCtx) { reached = true }
	CorsMiddleware(cfg)(downstream)(reqCtx)

	// No Allow-Headers value may leak to a non-allowed origin.
	if allowHeaders := reqCtx.Response.Header.Peek("Access-Control-Allow-Headers"); len(allowHeaders) != 0 {
		t.Error("Access-Control-Allow-Headers header should not be set for non-allowed origin")
	}
	// Non-OPTIONS requests are still forwarded.
	if !reached {
		t.Error("Next handler was not called")
	}
}
// containsHeader reports whether header is one of the entries of the comma-separated
// headerList. Entries are obtained from splitHeaders and compared exactly (case-sensitive).
func containsHeader(headerList, header string) bool {
	entries := splitHeaders(headerList)
	for i := 0; i < len(entries); i++ {
		if entries[i] != header {
			continue
		}
		return true
	}
	return false
}
// splitHeaders splits a comma-separated header list into its individual entries.
//
// Leading and trailing space characters (' ' only, not other whitespace) are removed from
// each entry, and entries that are empty after trimming are dropped. The result is nil when
// headerList yields no entries.
func splitHeaders(headerList string) []string {
	var headers []string
	for _, entry := range strings.Split(headerList, ",") {
		// Trim only plain spaces so that any other whitespace stays part of the entry.
		if entry = strings.Trim(entry, " "); entry != "" {
			headers = append(headers, entry)
		}
	}
	return headers
}
// countHeaderOccurrences returns how many entries of the comma-separated headerList are
// exactly equal to header. Entries are obtained from splitHeaders.
func countHeaderOccurrences(headerList, header string) int {
	var n int
	for _, entry := range splitHeaders(headerList) {
		if entry == header {
			n++
		}
	}
	return n
}
// TestFasthttpToHTTPRequest_PathParams verifies that fasthttpToHTTPRequest copies the router's
// path parameters (stored as user values on the context) into req.PathParams, leaves the
// system keys out, and that CaseInsensitivePathParamLookup finds a parameter regardless of case.
func TestFasthttpToHTTPRequest_PathParams(t *testing.T) {
	reqCtx := &fasthttp.RequestCtx{}
	reqCtx.Request.Header.SetMethod("GET")
	reqCtx.Request.SetRequestURI("/v1beta/files/file-abc123")

	// Path params, stored the way the fasthttp router stores them: as user values.
	reqCtx.SetUserValue("file_id", "file-abc123")
	reqCtx.SetUserValue("model", "gemini-pro")
	// System values that must not be treated as path params.
	reqCtx.SetUserValue("BifrostContextKeyRequestID", "req-123")
	reqCtx.SetUserValue("trace_id", "trace-456")
	reqCtx.SetUserValue("span_id", "span-789")

	// The request object comes from a pool and goes back to it when the test ends.
	req := schemas.AcquireHTTPRequest()
	defer schemas.ReleaseHTTPRequest(req)

	fasthttpToHTTPRequest(reqCtx, req)

	// Exactly the two router params are expected.
	wantParams := map[string]string{
		"file_id": "file-abc123",
		"model":   "gemini-pro",
	}
	if got, want := len(req.PathParams), len(wantParams); got != want {
		t.Errorf("Expected %d path params, got %d", want, got)
	}
	for name, want := range wantParams {
		got, ok := req.PathParams[name]
		switch {
		case !ok:
			t.Errorf("Expected path param '%s' to exist", name)
		case got != want:
			t.Errorf("Expected path param '%s' to be '%s', got '%s'", name, want, got)
		}
	}

	// System keys must have been filtered out.
	for _, name := range []string{"BifrostContextKeyRequestID", "trace_id", "span_id"} {
		if _, leaked := req.PathParams[name]; leaked {
			t.Errorf("System key '%s' should not be in path params", name)
		}
	}

	// Lookup helper: exact-case key first, then an upper-cased key.
	if got := req.CaseInsensitivePathParamLookup("file_id"); got != "file-abc123" {
		t.Errorf("CaseInsensitivePathParamLookup failed: expected 'file-abc123', got '%s'", got)
	}
	if got := req.CaseInsensitivePathParamLookup("FILE_ID"); got != "file-abc123" {
		t.Errorf("CaseInsensitivePathParamLookup should be case-insensitive: expected 'file-abc123', got '%s'", got)
	}
}
// TestRequestDecompressionMiddleware_SupportedEncodings verifies that gzip, deflate, brotli and
// zstd request bodies reach the next handler decompressed, and that the Content-Encoding and
// Content-Length request headers are empty afterwards.
func TestRequestDecompressionMiddleware_SupportedEncodings(t *testing.T) {
	cfg := &lib.Config{
		ClientConfig: &configstore.ClientConfig{
			MaxRequestBodySizeMB: 10,
		},
	}
	payload := []byte(`{"model":"openai/gpt-4o-mini","messages":[{"role":"user","content":"hello"}]}`)

	type encodingCase struct {
		label           string
		contentEncoding string
		compress        func([]byte) ([]byte, error)
	}
	cases := []encodingCase{
		{label: "gzip", contentEncoding: "gzip", compress: gzipCompress},
		{label: "deflate", contentEncoding: "deflate", compress: deflateCompress},
		{label: "brotli", contentEncoding: "br", compress: brotliCompress},
		{label: "zstd", contentEncoding: "zstd", compress: zstdCompress},
	}

	for _, c := range cases {
		t.Run(c.label, func(t *testing.T) {
			compressed, err := c.compress(payload)
			if err != nil {
				t.Fatalf("failed to encode body: %v", err)
			}

			reqCtx := &fasthttp.RequestCtx{}
			reqCtx.Request.Header.SetMethod("POST")
			reqCtx.Request.Header.SetContentType("application/json")
			reqCtx.Request.Header.Set("Content-Encoding", c.contentEncoding)
			reqCtx.Request.SetBodyRaw(compressed)

			var reached bool
			downstream := func(inner *fasthttp.RequestCtx) {
				reached = true
				// The handler must see the original, uncompressed payload.
				if got := string(inner.Request.Body()); got != string(payload) {
					t.Fatalf("expected decompressed body, got %q", got)
				}
			}
			RequestDecompressionMiddleware(cfg)(downstream)(reqCtx)

			if !reached {
				t.Fatal("next handler was not called")
			}
			if got := string(reqCtx.Request.Header.Peek("Content-Encoding")); got != "" {
				t.Fatalf("expected content-encoding to be cleared, got %q", got)
			}
			if got := string(reqCtx.Request.Header.Peek("Content-Length")); got != "" {
				t.Fatalf("expected content-length to be cleared, got %q", got)
			}
		})
	}
}
// TestRequestDecompressionMiddleware_InvalidCompressedBody verifies that a body declared as gzip
// but containing garbage is rejected with a 400 and a BifrostError whose message mentions
// "invalid compressed request body", without invoking the next handler.
func TestRequestDecompressionMiddleware_InvalidCompressedBody(t *testing.T) {
	cfg := &lib.Config{
		ClientConfig: &configstore.ClientConfig{
			MaxRequestBodySizeMB: 10,
		},
	}

	reqCtx := &fasthttp.RequestCtx{}
	reqCtx.Request.Header.SetMethod("POST")
	reqCtx.Request.Header.Set("Content-Encoding", "gzip")
	reqCtx.Request.SetBodyRaw([]byte("not-a-valid-gzip-payload"))

	var reached bool
	downstream := func(*fasthttp.RequestCtx) { reached = true }
	RequestDecompressionMiddleware(cfg)(downstream)(reqCtx)

	if reached {
		t.Fatal("next handler should not be called for invalid compressed payload")
	}
	if status := reqCtx.Response.StatusCode(); status != fasthttp.StatusBadRequest {
		t.Fatalf("expected status 400, got %d", status)
	}

	// The error payload must be a decodable BifrostError carrying the expected message.
	var decoded schemas.BifrostError
	if err := json.Unmarshal(reqCtx.Response.Body(), &decoded); err != nil {
		t.Fatalf("failed to decode error response: %v", err)
	}
	if decoded.Error == nil || !strings.Contains(decoded.Error.Message, "invalid compressed request body") {
		t.Fatalf("unexpected error message: %#v", decoded.Error)
	}
}
// TestRequestDecompressionMiddleware_UnsupportedEncoding verifies that a Content-Encoding the
// middleware does not handle ("snappy") is rejected with a 400 and a BifrostError whose
// message mentions "unsupported Content-Encoding", without invoking the next handler.
func TestRequestDecompressionMiddleware_UnsupportedEncoding(t *testing.T) {
	cfg := &lib.Config{
		ClientConfig: &configstore.ClientConfig{
			MaxRequestBodySizeMB: 10,
		},
	}

	reqCtx := &fasthttp.RequestCtx{}
	reqCtx.Request.Header.SetMethod("POST")
	reqCtx.Request.Header.Set("Content-Encoding", "snappy")
	reqCtx.Request.SetBodyRaw([]byte("whatever"))

	var reached bool
	downstream := func(*fasthttp.RequestCtx) { reached = true }
	RequestDecompressionMiddleware(cfg)(downstream)(reqCtx)

	if reached {
		t.Fatal("next handler should not be called for unsupported content-encoding")
	}
	if status := reqCtx.Response.StatusCode(); status != fasthttp.StatusBadRequest {
		t.Fatalf("expected status 400, got %d", status)
	}

	// The error payload must be a decodable BifrostError carrying the expected message.
	var decoded schemas.BifrostError
	if err := json.Unmarshal(reqCtx.Response.Body(), &decoded); err != nil {
		t.Fatalf("failed to decode error response: %v", err)
	}
	if decoded.Error == nil || !strings.Contains(decoded.Error.Message, "unsupported Content-Encoding") {
		t.Fatalf("unexpected error message: %#v", decoded.Error)
	}
}
// TestRequestDecompressionMiddleware_DecompressedSizeLimit verifies that a gzip body which
// inflates to 10 bytes more than the configured 1 MB limit is rejected with a 413 and a
// BifrostError mentioning the exceeded size, without invoking the next handler.
func TestRequestDecompressionMiddleware_DecompressedSizeLimit(t *testing.T) {
	cfg := &lib.Config{
		ClientConfig: &configstore.ClientConfig{
			MaxRequestBodySizeMB: 1,
		},
	}

	// 1 MiB + 10 bytes once decompressed.
	oversized := bytes.Repeat([]byte("a"), (1024*1024)+10)
	compressed, err := gzipCompress(oversized)
	if err != nil {
		t.Fatalf("failed to gzip test payload: %v", err)
	}

	reqCtx := &fasthttp.RequestCtx{}
	reqCtx.Request.Header.SetMethod("POST")
	reqCtx.Request.Header.Set("Content-Encoding", "gzip")
	reqCtx.Request.SetBodyRaw(compressed)

	var reached bool
	downstream := func(*fasthttp.RequestCtx) { reached = true }
	RequestDecompressionMiddleware(cfg)(downstream)(reqCtx)

	if reached {
		t.Fatal("next handler should not be called when decompressed body exceeds limit")
	}
	if status := reqCtx.Response.StatusCode(); status != fasthttp.StatusRequestEntityTooLarge {
		t.Fatalf("expected status 413, got %d", status)
	}

	// The error payload must be a decodable BifrostError carrying the expected message.
	var decoded schemas.BifrostError
	if err := json.Unmarshal(reqCtx.Response.Body(), &decoded); err != nil {
		t.Fatalf("failed to decode error response: %v", err)
	}
	if decoded.Error == nil || !strings.Contains(decoded.Error.Message, "decompressed request body exceeds max allowed size") {
		t.Fatalf("unexpected error message: %#v", decoded.Error)
	}
}
// TestRequestDecompressionMiddleware_EmptyBodyWithContentEncoding sends an empty body with each
// supported Content-Encoding. Two outcomes are accepted: the next handler is called (a
// decoder may produce empty output), or the request is rejected with status 400.
func TestRequestDecompressionMiddleware_EmptyBodyWithContentEncoding(t *testing.T) {
	cfg := &lib.Config{
		ClientConfig: &configstore.ClientConfig{
			MaxRequestBodySizeMB: 10,
		},
	}

	for _, encoding := range []string{"gzip", "deflate", "br", "zstd"} {
		t.Run(encoding, func(t *testing.T) {
			reqCtx := &fasthttp.RequestCtx{}
			reqCtx.Request.Header.SetMethod("POST")
			reqCtx.Request.Header.Set("Content-Encoding", encoding)
			reqCtx.Request.SetBodyRaw([]byte{})

			var reached bool
			downstream := func(*fasthttp.RequestCtx) { reached = true }
			RequestDecompressionMiddleware(cfg)(downstream)(reqCtx)

			// Only a rejection needs checking: a rejected empty body must be a 400.
			// Reaching the next handler is acceptable too.
			if !reached && reqCtx.Response.StatusCode() != fasthttp.StatusBadRequest {
				t.Fatalf("expected status 400, got %d", reqCtx.Response.StatusCode())
			}
		})
	}
}
// TestRequestDecompressionMiddleware_NoContentEncoding verifies that a request without a
// Content-Encoding header is handed to the next handler with its body untouched.
func TestRequestDecompressionMiddleware_NoContentEncoding(t *testing.T) {
	cfg := &lib.Config{
		ClientConfig: &configstore.ClientConfig{
			MaxRequestBodySizeMB: 10,
		},
	}
	payload := []byte(`{"model":"openai/gpt-4o-mini","messages":[{"role":"user","content":"hello"}]}`)

	reqCtx := &fasthttp.RequestCtx{}
	reqCtx.Request.Header.SetMethod("POST")
	reqCtx.Request.Header.SetContentType("application/json")
	reqCtx.Request.SetBodyRaw(payload)

	var reached bool
	downstream := func(inner *fasthttp.RequestCtx) {
		reached = true
		if got := string(inner.Request.Body()); got != string(payload) {
			t.Fatalf("expected body to be unchanged, got %q", got)
		}
	}
	RequestDecompressionMiddleware(cfg)(downstream)(reqCtx)

	if !reached {
		t.Fatal("next handler was not called")
	}
}
// TestRequestDecompressionMiddleware_ExactSizeLimit verifies that a gzip body which inflates to
// exactly the configured 1 MB limit is accepted and delivered in full to the next handler.
func TestRequestDecompressionMiddleware_ExactSizeLimit(t *testing.T) {
	const limitBytes = 1024 * 1024 // exactly 1 MB

	cfg := &lib.Config{
		ClientConfig: &configstore.ClientConfig{
			MaxRequestBodySizeMB: 1,
		},
	}

	compressed, err := gzipCompress(bytes.Repeat([]byte("a"), limitBytes))
	if err != nil {
		t.Fatalf("failed to gzip test payload: %v", err)
	}

	reqCtx := &fasthttp.RequestCtx{}
	reqCtx.Request.Header.SetMethod("POST")
	reqCtx.Request.Header.Set("Content-Encoding", "gzip")
	reqCtx.Request.SetBodyRaw(compressed)

	var reached bool
	downstream := func(inner *fasthttp.RequestCtx) {
		reached = true
		if got := len(inner.Request.Body()); got != limitBytes {
			t.Fatalf("expected body length %d, got %d", limitBytes, got)
		}
	}
	RequestDecompressionMiddleware(cfg)(downstream)(reqCtx)

	if !reached {
		t.Fatal("next handler was not called — body exactly at limit should pass")
	}
}
// --- Streaming decompression path tests ---

// TestShouldStreamDecompress is a table-driven check of shouldStreamDecompress. Per the
// expectations below, chunked bodies (unknown length) always stream, and bodies with a known
// length stream only when that length is strictly above the effective threshold: the
// default, or cfg.StreamingDecompressThreshold when one is set.
func TestShouldStreamDecompress(t *testing.T) {
	defaultLimit := int(schemas.DefaultLargePayloadRequestThresholdBytes)

	type scenario struct {
		name          string
		contentLength int   // -1 simulates a chunked request
		threshold     int64 // 0 means no custom threshold (use default)
		want          bool
	}
	scenarios := []scenario{
		{name: "chunked (CL=-1)", contentLength: -1, threshold: 0, want: true},
		{name: "empty body (CL=0)", contentLength: 0, threshold: 0, want: false},
		{name: "small body", contentLength: 100, threshold: 0, want: false},
		{name: "at default threshold", contentLength: defaultLimit, threshold: 0, want: false},
		{name: "above default threshold", contentLength: defaultLimit + 1, threshold: 0, want: true},
		// Custom enterprise threshold (1MB) — body at 2MB should stream.
		{name: "above custom threshold", contentLength: 2 * 1024 * 1024, threshold: 1 * 1024 * 1024, want: true},
		// Custom enterprise threshold (20MB) — body at default 10MB+1 should NOT stream.
		{name: "below custom threshold", contentLength: defaultLimit + 1, threshold: 20 * 1024 * 1024, want: false},
		// Chunked always streams regardless of custom threshold.
		{name: "chunked with custom threshold", contentLength: -1, threshold: 50 * 1024 * 1024, want: true},
	}

	for _, sc := range scenarios {
		t.Run(sc.name, func(t *testing.T) {
			cfg := &lib.Config{}
			if sc.threshold > 0 {
				cfg.StreamingDecompressThreshold = sc.threshold
			}

			reqCtx := &fasthttp.RequestCtx{}
			reqCtx.Request.Header.Set("Content-Encoding", "gzip")
			if sc.contentLength < 0 {
				// Simulate chunked: set body stream with unknown size
				reqCtx.Request.SetBodyStream(bytes.NewReader(nil), -1)
			} else {
				reqCtx.Request.Header.SetContentLength(sc.contentLength)
			}

			got := shouldStreamDecompress(cfg, reqCtx)
			if got != sc.want {
				t.Errorf("shouldStreamDecompress() = %v, want %v (CL=%d, threshold=%d)", got, sc.want, sc.contentLength, sc.threshold)
			}
		})
	}
}
// TestRequestDecompressionMiddleware_StreamingPath_ChunkedGzip verifies that a gzip body
// supplied as a stream of unknown size is decompressed correctly and that Content-Encoding
// is already cleared when the next handler runs.
func TestRequestDecompressionMiddleware_StreamingPath_ChunkedGzip(t *testing.T) {
	cfg := &lib.Config{
		ClientConfig: &configstore.ClientConfig{
			MaxRequestBodySizeMB: 100,
		},
	}
	payload := []byte(`{"model":"openai/gpt-4o-mini","messages":[{"role":"user","content":"hello"}]}`)

	compressed, err := gzipCompress(payload)
	if err != nil {
		t.Fatalf("failed to gzip: %v", err)
	}

	reqCtx := &fasthttp.RequestCtx{}
	reqCtx.Request.Header.SetMethod("POST")
	reqCtx.Request.Header.Set("Content-Encoding", "gzip")
	// Chunked: SetBodyStream with size -1 triggers the streaming path
	reqCtx.Request.SetBodyStream(bytes.NewReader(compressed), -1)

	var reached bool
	downstream := func(inner *fasthttp.RequestCtx) {
		reached = true
		// The encoding header must be gone by the time the handler runs.
		if enc := string(inner.Request.Header.Peek("Content-Encoding")); enc != "" {
			t.Errorf("expected Content-Encoding to be cleared, got %q", enc)
		}
		// The handler must read back the original payload.
		got := inner.Request.Body()
		if string(got) != string(payload) {
			t.Errorf("decompressed body mismatch: got %d bytes, want %d bytes", len(got), len(payload))
		}
	}
	RequestDecompressionMiddleware(cfg)(downstream)(reqCtx)

	if !reached {
		t.Fatal("next handler was not called")
	}
}
// TestRequestDecompressionMiddleware_StreamingPath_AllEncodings verifies that gzip, deflate,
// brotli and zstd bodies supplied as a stream of unknown size are decompressed for the next
// handler and that Content-Encoding is empty afterwards.
func TestRequestDecompressionMiddleware_StreamingPath_AllEncodings(t *testing.T) {
	cfg := &lib.Config{
		ClientConfig: &configstore.ClientConfig{
			MaxRequestBodySizeMB: 100,
		},
	}
	payload := []byte(`{"model":"openai/gpt-4o-mini","messages":[{"role":"user","content":"hello"}]}`)

	type encodingCase struct {
		label           string
		contentEncoding string
		compress        func([]byte) ([]byte, error)
	}
	cases := []encodingCase{
		{label: "gzip", contentEncoding: "gzip", compress: gzipCompress},
		{label: "deflate", contentEncoding: "deflate", compress: deflateCompress},
		{label: "brotli", contentEncoding: "br", compress: brotliCompress},
		{label: "zstd", contentEncoding: "zstd", compress: zstdCompress},
	}

	for _, c := range cases {
		t.Run(c.label, func(t *testing.T) {
			compressed, err := c.compress(payload)
			if err != nil {
				t.Fatalf("failed to encode body: %v", err)
			}

			reqCtx := &fasthttp.RequestCtx{}
			reqCtx.Request.Header.SetMethod("POST")
			reqCtx.Request.Header.Set("Content-Encoding", c.contentEncoding)
			// Use chunked (-1) to trigger streaming path regardless of compressed size
			reqCtx.Request.SetBodyStream(bytes.NewReader(compressed), -1)

			var reached bool
			downstream := func(inner *fasthttp.RequestCtx) {
				reached = true
				got := inner.Request.Body()
				if string(got) != string(payload) {
					t.Fatalf("expected decompressed body, got %q", string(got))
				}
			}
			RequestDecompressionMiddleware(cfg)(downstream)(reqCtx)

			if !reached {
				t.Fatal("next handler was not called")
			}
			if enc := string(reqCtx.Request.Header.Peek("Content-Encoding")); enc != "" {
				t.Fatalf("expected content-encoding to be cleared, got %q", enc)
			}
		})
	}
}
// TestRequestDecompressionMiddleware_StreamingPath_InvalidBody verifies that invalid gzip data
// supplied as a stream of unknown size is rejected with a 400 without invoking the next
// handler.
func TestRequestDecompressionMiddleware_StreamingPath_InvalidBody(t *testing.T) {
	cfg := &lib.Config{
		ClientConfig: &configstore.ClientConfig{
			MaxRequestBodySizeMB: 100,
		},
	}

	reqCtx := &fasthttp.RequestCtx{}
	reqCtx.Request.Header.SetMethod("POST")
	reqCtx.Request.Header.Set("Content-Encoding", "gzip")
	// Chunked with invalid gzip data → streaming path → error
	garbage := bytes.NewReader([]byte("not-a-valid-gzip-payload"))
	reqCtx.Request.SetBodyStream(garbage, -1)

	var reached bool
	downstream := func(*fasthttp.RequestCtx) { reached = true }
	RequestDecompressionMiddleware(cfg)(downstream)(reqCtx)

	if reached {
		t.Fatal("next handler should not be called for invalid compressed payload")
	}
	if status := reqCtx.Response.StatusCode(); status != fasthttp.StatusBadRequest {
		t.Fatalf("expected status 400, got %d", status)
	}
}
// TestRequestDecompressionMiddleware_StreamingPath_UnsupportedEncoding verifies that a body
// supplied as a stream of unknown size with an unhandled Content-Encoding ("snappy") is
// rejected with a 400 without invoking the next handler.
func TestRequestDecompressionMiddleware_StreamingPath_UnsupportedEncoding(t *testing.T) {
	cfg := &lib.Config{
		ClientConfig: &configstore.ClientConfig{
			MaxRequestBodySizeMB: 100,
		},
	}

	reqCtx := &fasthttp.RequestCtx{}
	reqCtx.Request.Header.SetMethod("POST")
	reqCtx.Request.Header.Set("Content-Encoding", "snappy")
	body := bytes.NewReader([]byte("whatever"))
	reqCtx.Request.SetBodyStream(body, -1)

	var reached bool
	downstream := func(*fasthttp.RequestCtx) { reached = true }
	RequestDecompressionMiddleware(cfg)(downstream)(reqCtx)

	if reached {
		t.Fatal("next handler should not be called for unsupported encoding")
	}
	if status := reqCtx.Response.StatusCode(); status != fasthttp.StatusBadRequest {
		t.Fatalf("expected status 400, got %d", status)
	}
}
// TestRequestDecompressionMiddleware_BufferedPath_SmallGzip verifies that a small gzip body set
// via SetBodyRaw (known, below-threshold length) is decompressed for the next handler. The
// test is skipped if the compressed payload unexpectedly exceeds the default threshold.
func TestRequestDecompressionMiddleware_BufferedPath_SmallGzip(t *testing.T) {
	cfg := &lib.Config{
		ClientConfig: &configstore.ClientConfig{
			MaxRequestBodySizeMB: 10,
		},
	}
	payload := []byte(`{"model":"openai/gpt-4o-mini","messages":[{"role":"user","content":"hello"}]}`)

	compressed, err := gzipCompress(payload)
	if err != nil {
		t.Fatalf("failed to gzip: %v", err)
	}
	// Guard the precondition: the compressed body has to stay under the threshold for the
	// buffered path to be the one exercised.
	if size := int64(len(compressed)); size > schemas.DefaultLargePayloadRequestThresholdBytes {
		t.Skip("compressed body unexpectedly exceeds threshold")
	}

	reqCtx := &fasthttp.RequestCtx{}
	reqCtx.Request.Header.SetMethod("POST")
	reqCtx.Request.Header.Set("Content-Encoding", "gzip")
	// SetBodyRaw with known small Content-Length → buffered path
	reqCtx.Request.SetBodyRaw(compressed)

	var reached bool
	downstream := func(inner *fasthttp.RequestCtx) {
		reached = true
		if got := string(inner.Request.Body()); got != string(payload) {
			t.Fatalf("expected decompressed body, got %q", got)
		}
	}
	RequestDecompressionMiddleware(cfg)(downstream)(reqCtx)

	if !reached {
		t.Fatal("next handler was not called")
	}
}
// TestRequestDecompressionMiddleware_StreamingPath_LargeGzip verifies that a gzip body whose
// compressed size is above the default large-payload threshold, sent as a stream with a known
// Content-Length, is decompressed byte-for-byte and has Content-Encoding cleared. The test is
// skipped if the compressed payload does not end up above the threshold.
func TestRequestDecompressionMiddleware_StreamingPath_LargeGzip(t *testing.T) {
	cfg := &lib.Config{
		ClientConfig: &configstore.ClientConfig{
			MaxRequestBodySizeMB: 100,
		},
	}

	// Random bytes are incompressible — compressed size ≈ input size + gzip overhead.
	payload := make([]byte, int(schemas.DefaultLargePayloadRequestThresholdBytes)+1024*1024)
	if _, err := cryptoRand.Read(payload); err != nil {
		t.Fatalf("failed to generate random data: %v", err)
	}
	compressed, err := gzipCompress(payload)
	if err != nil {
		t.Fatalf("failed to gzip: %v", err)
	}
	compressedLen := len(compressed)
	if int64(compressedLen) <= schemas.DefaultLargePayloadRequestThresholdBytes {
		t.Skipf("compressed body %d bytes is below threshold %d",
			compressedLen, schemas.DefaultLargePayloadRequestThresholdBytes)
	}

	reqCtx := &fasthttp.RequestCtx{}
	reqCtx.Request.Header.SetMethod("POST")
	reqCtx.Request.Header.Set("Content-Encoding", "gzip")
	reqCtx.Request.Header.SetContentLength(compressedLen)
	reqCtx.Request.SetBodyStream(bytes.NewReader(compressed), compressedLen)

	var reached bool
	downstream := func(inner *fasthttp.RequestCtx) {
		reached = true
		if enc := string(inner.Request.Header.Peek("Content-Encoding")); enc != "" {
			t.Errorf("expected Content-Encoding to be cleared, got %q", enc)
		}
		got := inner.Request.Body()
		if len(got) != len(payload) {
			t.Errorf("decompressed body length: got %d, want %d", len(got), len(payload))
		}
		if !bytes.Equal(got, payload) {
			t.Error("decompressed body content does not match original")
		}
	}
	RequestDecompressionMiddleware(cfg)(downstream)(reqCtx)

	if !reached {
		t.Fatal("next handler was not called")
	}
}
// gzipCompress returns data compressed in the gzip format using gzip.NewWriter's default
// settings. It returns a nil slice and the error if writing or closing the writer fails.
func gzipCompress(data []byte) ([]byte, error) {
	var buf bytes.Buffer
	gz := gzip.NewWriter(&buf)
	if _, err := gz.Write(data); err != nil {
		// Close on the failure path too, as the other compress helpers in this file do.
		// The write error is the one reported, so a Close error is deliberately ignored.
		_ = gz.Close()
		return nil, err
	}
	// Close flushes the remaining data and writes the gzip trailer; without it the
	// output would be truncated.
	if err := gz.Close(); err != nil {
		return nil, err
	}
	return buf.Bytes(), nil
}
// deflateCompress returns data as zlib-wrapped DEFLATE (RFC 1950), which is the correct
// format for HTTP Content-Encoding "deflate" per RFC 9110 §8.4.1.2. The default compression
// level is used.
func deflateCompress(data []byte) ([]byte, error) {
	var out bytes.Buffer
	// NewWriter compresses at the default level and has no error return.
	zw := zlib.NewWriter(&out)
	_, err := zw.Write(data)
	if err != nil {
		// The write error is the one reported; a Close error here is deliberately ignored.
		_ = zw.Close()
		return nil, err
	}
	// Close flushes pending data and appends the zlib checksum.
	if err = zw.Close(); err != nil {
		return nil, err
	}
	return out.Bytes(), nil
}
// brotliCompress returns data compressed with brotli.NewWriter. It returns a nil slice and
// the error if writing or closing the writer fails.
func brotliCompress(data []byte) ([]byte, error) {
	var out bytes.Buffer
	bw := brotli.NewWriter(&out)
	_, err := bw.Write(data)
	if err != nil {
		// The write error is the one reported; a Close error here is deliberately ignored.
		_ = bw.Close()
		return nil, err
	}
	if err = bw.Close(); err != nil {
		return nil, err
	}
	return out.Bytes(), nil
}
// zstdCompress returns data compressed with a zstd.NewWriter encoder created with no options.
// It returns a nil slice and the error if creating, feeding or closing the encoder fails.
func zstdCompress(data []byte) ([]byte, error) {
	var out bytes.Buffer
	zw, err := zstd.NewWriter(&out)
	if err != nil {
		return nil, err
	}
	_, err = io.Copy(zw, bytes.NewReader(data))
	if err != nil {
		// The copy error is the one reported; a Close error here is deliberately ignored.
		_ = zw.Close()
		return nil, err
	}
	if err = zw.Close(); err != nil {
		return nil, err
	}
	return out.Bytes(), nil
}