2023 lines
60 KiB
Go
2023 lines
60 KiB
Go
package handlers
|
|
|
|
import (
|
|
"bytes"
|
|
"compress/gzip"
|
|
"compress/zlib"
|
|
cryptoRand "crypto/rand"
|
|
"encoding/json"
|
|
"io"
|
|
"strings"
|
|
"testing"
|
|
|
|
"github.com/andybalholm/brotli"
|
|
"github.com/klauspost/compress/zstd"
|
|
"github.com/maximhq/bifrost/core/schemas"
|
|
"github.com/maximhq/bifrost/framework/configstore"
|
|
"github.com/maximhq/bifrost/transports/bifrost-http/lib"
|
|
"github.com/valyala/fasthttp"
|
|
)
|
|
|
|
// mockLogger is a mock implementation of schemas.Logger for testing
|
|
type mockLogger struct{}
|
|
|
|
func (m *mockLogger) Debug(format string, args ...any) {}
|
|
func (m *mockLogger) Info(format string, args ...any) {}
|
|
func (m *mockLogger) Warn(format string, args ...any) {}
|
|
func (m *mockLogger) Error(format string, args ...any) {}
|
|
func (m *mockLogger) Fatal(format string, args ...any) {}
|
|
func (m *mockLogger) SetLevel(level schemas.LogLevel) {}
|
|
func (m *mockLogger) SetOutputType(outputType schemas.LoggerOutputType) {}
|
|
func (m *mockLogger) LogHTTPRequest(level schemas.LogLevel, msg string) schemas.LogEventBuilder {
|
|
return schemas.NoopLogEvent
|
|
}
|
|
|
|
// TestCorsMiddleware_LocalhostOrigins tests that localhost origins are always allowed
|
|
func TestCorsMiddleware_LocalhostOrigins(t *testing.T) {
|
|
config := &lib.Config{
|
|
ClientConfig: &configstore.ClientConfig{
|
|
AllowedOrigins: []string{},
|
|
},
|
|
}
|
|
|
|
SetLogger(&mockLogger{})
|
|
|
|
localhostOrigins := []string{
|
|
"http://localhost:3000",
|
|
"https://localhost:3000",
|
|
"http://127.0.0.1:8080",
|
|
"http://0.0.0.0:5000",
|
|
"https://127.0.0.1:3000",
|
|
}
|
|
|
|
for _, origin := range localhostOrigins {
|
|
t.Run(origin, func(t *testing.T) {
|
|
ctx := &fasthttp.RequestCtx{}
|
|
ctx.Request.Header.Set("Origin", origin)
|
|
|
|
nextCalled := false
|
|
next := func(ctx *fasthttp.RequestCtx) {
|
|
nextCalled = true
|
|
}
|
|
|
|
middleware := CorsMiddleware(config)
|
|
handler := middleware(next)
|
|
handler(ctx)
|
|
|
|
// Check CORS headers are set
|
|
if string(ctx.Response.Header.Peek("Access-Control-Allow-Origin")) != origin {
|
|
t.Errorf("Expected Access-Control-Allow-Origin to be %s, got %s", origin, string(ctx.Response.Header.Peek("Access-Control-Allow-Origin")))
|
|
}
|
|
if string(ctx.Response.Header.Peek("Access-Control-Allow-Methods")) != "GET, POST, PUT, DELETE, PATCH, OPTIONS, HEAD" {
|
|
t.Errorf("Access-Control-Allow-Methods header not set correctly")
|
|
}
|
|
if string(ctx.Response.Header.Peek("Access-Control-Allow-Headers")) != "Content-Type, Authorization, X-Requested-With, X-Stainless-Timeout, X-Api-Key, X-OpenAI-Agents-SDK" {
|
|
t.Errorf("Access-Control-Allow-Headers header not set correctly")
|
|
}
|
|
if string(ctx.Response.Header.Peek("Access-Control-Allow-Credentials")) != "true" {
|
|
t.Errorf("Access-Control-Allow-Credentials header not set correctly")
|
|
}
|
|
if string(ctx.Response.Header.Peek("Access-Control-Max-Age")) != "86400" {
|
|
t.Errorf("Access-Control-Max-Age header not set correctly")
|
|
}
|
|
|
|
// Check next handler was called
|
|
if !nextCalled {
|
|
t.Error("Next handler was not called")
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestCorsMiddleware_ConfiguredOrigins tests that configured allowed origins work
|
|
func TestCorsMiddleware_ConfiguredOrigins(t *testing.T) {
|
|
allowedOrigin := "https://example.com"
|
|
config := &lib.Config{
|
|
ClientConfig: &configstore.ClientConfig{
|
|
AllowedOrigins: []string{allowedOrigin},
|
|
},
|
|
}
|
|
|
|
ctx := &fasthttp.RequestCtx{}
|
|
ctx.Request.Header.Set("Origin", allowedOrigin)
|
|
|
|
nextCalled := false
|
|
next := func(ctx *fasthttp.RequestCtx) {
|
|
nextCalled = true
|
|
}
|
|
|
|
middleware := CorsMiddleware(config)
|
|
handler := middleware(next)
|
|
handler(ctx)
|
|
|
|
// Check CORS headers are set
|
|
if string(ctx.Response.Header.Peek("Access-Control-Allow-Origin")) != allowedOrigin {
|
|
t.Errorf("Expected Access-Control-Allow-Origin to be %s, got %s", allowedOrigin, string(ctx.Response.Header.Peek("Access-Control-Allow-Origin")))
|
|
}
|
|
|
|
// Check next handler was called
|
|
if !nextCalled {
|
|
t.Error("Next handler was not called")
|
|
}
|
|
}
|
|
|
|
// TestCorsMiddleware_NonAllowedOrigins tests that non-allowed origins don't get CORS headers
|
|
func TestCorsMiddleware_NonAllowedOrigins(t *testing.T) {
|
|
config := &lib.Config{
|
|
ClientConfig: &configstore.ClientConfig{
|
|
AllowedOrigins: []string{"https://allowed.com"},
|
|
},
|
|
}
|
|
|
|
ctx := &fasthttp.RequestCtx{}
|
|
ctx.Request.Header.Set("Origin", "https://malicious.com")
|
|
|
|
nextCalled := false
|
|
next := func(ctx *fasthttp.RequestCtx) {
|
|
nextCalled = true
|
|
}
|
|
|
|
middleware := CorsMiddleware(config)
|
|
handler := middleware(next)
|
|
handler(ctx)
|
|
|
|
// Check CORS headers are NOT set
|
|
if len(ctx.Response.Header.Peek("Access-Control-Allow-Origin")) != 0 {
|
|
t.Error("Access-Control-Allow-Origin header should not be set for non-allowed origin")
|
|
}
|
|
|
|
// Check next handler was still called for non-OPTIONS requests
|
|
if !nextCalled {
|
|
t.Error("Next handler was not called")
|
|
}
|
|
}
|
|
|
|
// TestCorsMiddleware_PreflightAllowedOrigin tests OPTIONS preflight requests for allowed origins
|
|
func TestCorsMiddleware_PreflightAllowedOrigin(t *testing.T) {
|
|
config := &lib.Config{
|
|
ClientConfig: &configstore.ClientConfig{
|
|
AllowedOrigins: []string{"https://example.com"},
|
|
},
|
|
}
|
|
|
|
ctx := &fasthttp.RequestCtx{}
|
|
ctx.Request.Header.SetMethod("OPTIONS")
|
|
ctx.Request.Header.Set("Origin", "https://example.com")
|
|
|
|
nextCalled := false
|
|
next := func(ctx *fasthttp.RequestCtx) {
|
|
nextCalled = true
|
|
}
|
|
|
|
middleware := CorsMiddleware(config)
|
|
handler := middleware(next)
|
|
handler(ctx)
|
|
|
|
// Check status code is 200 OK
|
|
if ctx.Response.StatusCode() != fasthttp.StatusOK {
|
|
t.Errorf("Expected status code %d for allowed origin preflight, got %d", fasthttp.StatusOK, ctx.Response.StatusCode())
|
|
}
|
|
|
|
// Check CORS headers are set
|
|
if string(ctx.Response.Header.Peek("Access-Control-Allow-Origin")) != "https://example.com" {
|
|
t.Error("Access-Control-Allow-Origin header not set correctly for allowed origin preflight")
|
|
}
|
|
|
|
// Check next handler was NOT called for OPTIONS requests
|
|
if nextCalled {
|
|
t.Error("Next handler should not be called for OPTIONS preflight requests")
|
|
}
|
|
}
|
|
|
|
// TestCorsMiddleware_PreflightNonAllowedOrigin tests OPTIONS preflight requests for non-allowed origins
|
|
func TestCorsMiddleware_PreflightNonAllowedOrigin(t *testing.T) {
|
|
config := &lib.Config{
|
|
ClientConfig: &configstore.ClientConfig{
|
|
AllowedOrigins: []string{"https://allowed.com"},
|
|
},
|
|
}
|
|
|
|
ctx := &fasthttp.RequestCtx{}
|
|
ctx.Request.Header.SetMethod("OPTIONS")
|
|
ctx.Request.Header.Set("Origin", "https://malicious.com")
|
|
|
|
nextCalled := false
|
|
next := func(ctx *fasthttp.RequestCtx) {
|
|
nextCalled = true
|
|
}
|
|
|
|
middleware := CorsMiddleware(config)
|
|
handler := middleware(next)
|
|
handler(ctx)
|
|
|
|
// Check status code is 403 Forbidden
|
|
if ctx.Response.StatusCode() != fasthttp.StatusForbidden {
|
|
t.Errorf("Expected status code %d for non-allowed origin preflight, got %d", fasthttp.StatusForbidden, ctx.Response.StatusCode())
|
|
}
|
|
|
|
// Check CORS headers are NOT set
|
|
if len(ctx.Response.Header.Peek("Access-Control-Allow-Origin")) != 0 {
|
|
t.Error("Access-Control-Allow-Origin header should not be set for non-allowed origin preflight")
|
|
}
|
|
|
|
// Check next handler was NOT called for OPTIONS requests
|
|
if nextCalled {
|
|
t.Error("Next handler should not be called for OPTIONS preflight requests")
|
|
}
|
|
}
|
|
|
|
// TestCorsMiddleware_PreflightLocalhost tests OPTIONS preflight requests for localhost
|
|
func TestCorsMiddleware_PreflightLocalhost(t *testing.T) {
|
|
config := &lib.Config{
|
|
ClientConfig: &configstore.ClientConfig{
|
|
AllowedOrigins: []string{},
|
|
},
|
|
}
|
|
|
|
ctx := &fasthttp.RequestCtx{}
|
|
ctx.Request.Header.SetMethod("OPTIONS")
|
|
ctx.Request.Header.Set("Origin", "http://localhost:3000")
|
|
|
|
nextCalled := false
|
|
next := func(ctx *fasthttp.RequestCtx) {
|
|
nextCalled = true
|
|
}
|
|
|
|
middleware := CorsMiddleware(config)
|
|
handler := middleware(next)
|
|
handler(ctx)
|
|
|
|
// Check status code is 200 OK
|
|
if ctx.Response.StatusCode() != fasthttp.StatusOK {
|
|
t.Errorf("Expected status code %d for localhost preflight, got %d", fasthttp.StatusOK, ctx.Response.StatusCode())
|
|
}
|
|
|
|
// Check CORS headers are set
|
|
if string(ctx.Response.Header.Peek("Access-Control-Allow-Origin")) != "http://localhost:3000" {
|
|
t.Error("Access-Control-Allow-Origin header not set correctly for localhost preflight")
|
|
}
|
|
|
|
// Check next handler was NOT called for OPTIONS requests
|
|
if nextCalled {
|
|
t.Error("Next handler should not be called for OPTIONS preflight requests")
|
|
}
|
|
}
|
|
|
|
// TestCorsMiddleware_NoOriginHeader tests behavior when no Origin header is present
|
|
func TestCorsMiddleware_NoOriginHeader(t *testing.T) {
|
|
config := &lib.Config{
|
|
ClientConfig: &configstore.ClientConfig{
|
|
AllowedOrigins: []string{},
|
|
},
|
|
}
|
|
|
|
ctx := &fasthttp.RequestCtx{}
|
|
// No Origin header set
|
|
|
|
nextCalled := false
|
|
next := func(ctx *fasthttp.RequestCtx) {
|
|
nextCalled = true
|
|
}
|
|
|
|
middleware := CorsMiddleware(config)
|
|
handler := middleware(next)
|
|
handler(ctx)
|
|
|
|
// Check CORS headers are NOT set when no origin is present
|
|
if len(ctx.Response.Header.Peek("Access-Control-Allow-Origin")) != 0 {
|
|
t.Error("Access-Control-Allow-Origin header should not be set when no Origin header is present")
|
|
}
|
|
|
|
// Check next handler was called
|
|
if !nextCalled {
|
|
t.Error("Next handler was not called")
|
|
}
|
|
}
|
|
|
|
// Testlib.ChainMiddlewares_NoMiddlewares tests chaining with no middlewares
|
|
func TestChainMiddlewares_NoMiddlewares(t *testing.T) {
|
|
ctx := &fasthttp.RequestCtx{}
|
|
handlerCalled := false
|
|
|
|
handler := func(ctx *fasthttp.RequestCtx) {
|
|
handlerCalled = true
|
|
}
|
|
|
|
chained := lib.ChainMiddlewares(handler)
|
|
chained(ctx)
|
|
|
|
if !handlerCalled {
|
|
t.Error("Handler was not called when no middlewares are present")
|
|
}
|
|
}
|
|
|
|
// Testlib.ChainMiddlewares_SingleMiddleware tests chaining with a single middleware
|
|
func TestChainMiddlewares_SingleMiddleware(t *testing.T) {
|
|
ctx := &fasthttp.RequestCtx{}
|
|
middlewareCalled := false
|
|
handlerCalled := false
|
|
|
|
middleware := schemas.BifrostHTTPMiddleware(func(next fasthttp.RequestHandler) fasthttp.RequestHandler {
|
|
return func(ctx *fasthttp.RequestCtx) {
|
|
middlewareCalled = true
|
|
next(ctx)
|
|
}
|
|
})
|
|
|
|
handler := func(ctx *fasthttp.RequestCtx) {
|
|
handlerCalled = true
|
|
}
|
|
|
|
chained := lib.ChainMiddlewares(handler, middleware)
|
|
chained(ctx)
|
|
|
|
if !middlewareCalled {
|
|
t.Error("Middleware was not called")
|
|
}
|
|
if !handlerCalled {
|
|
t.Error("Handler was not called")
|
|
}
|
|
}
|
|
|
|
// Testlib.ChainMiddlewares_MultipleMiddlewares tests chaining with multiple middlewares
|
|
func TestChainMiddlewares_MultipleMiddlewares(t *testing.T) {
|
|
ctx := &fasthttp.RequestCtx{}
|
|
executionOrder := []int{}
|
|
|
|
middleware1 := schemas.BifrostHTTPMiddleware(func(next fasthttp.RequestHandler) fasthttp.RequestHandler {
|
|
return func(ctx *fasthttp.RequestCtx) {
|
|
executionOrder = append(executionOrder, 1)
|
|
next(ctx)
|
|
}
|
|
})
|
|
|
|
middleware2 := schemas.BifrostHTTPMiddleware(func(next fasthttp.RequestHandler) fasthttp.RequestHandler {
|
|
return func(ctx *fasthttp.RequestCtx) {
|
|
executionOrder = append(executionOrder, 2)
|
|
next(ctx)
|
|
}
|
|
})
|
|
|
|
middleware3 := schemas.BifrostHTTPMiddleware(func(next fasthttp.RequestHandler) fasthttp.RequestHandler {
|
|
return func(ctx *fasthttp.RequestCtx) {
|
|
executionOrder = append(executionOrder, 3)
|
|
next(ctx)
|
|
}
|
|
})
|
|
|
|
handler := func(ctx *fasthttp.RequestCtx) {
|
|
executionOrder = append(executionOrder, 4)
|
|
}
|
|
|
|
chained := lib.ChainMiddlewares(handler, middleware1, middleware2, middleware3)
|
|
chained(ctx)
|
|
|
|
// Check execution order: middlewares should execute in order, then handler
|
|
expectedOrder := []int{1, 2, 3, 4}
|
|
if len(executionOrder) != len(expectedOrder) {
|
|
t.Errorf("Expected %d function calls, got %d", len(expectedOrder), len(executionOrder))
|
|
}
|
|
|
|
for i, expected := range expectedOrder {
|
|
if i >= len(executionOrder) || executionOrder[i] != expected {
|
|
t.Errorf("Expected execution order %v, got %v", expectedOrder, executionOrder)
|
|
break
|
|
}
|
|
}
|
|
}
|
|
|
|
// Testlib.ChainMiddlewares_MiddlewareCanModifyContext tests that middlewares can modify the context
|
|
func TestChainMiddlewares_MiddlewareCanModifyContext(t *testing.T) {
|
|
ctx := &fasthttp.RequestCtx{}
|
|
|
|
middleware := schemas.BifrostHTTPMiddleware(func(next fasthttp.RequestHandler) fasthttp.RequestHandler {
|
|
return func(ctx *fasthttp.RequestCtx) {
|
|
ctx.SetUserValue("test-key", "test-value")
|
|
next(ctx)
|
|
}
|
|
})
|
|
|
|
handler := func(ctx *fasthttp.RequestCtx) {
|
|
value := ctx.UserValue("test-key")
|
|
if value == nil {
|
|
t.Error("Handler did not receive modified context from middleware")
|
|
} else if value.(string) != "test-value" {
|
|
t.Errorf("Expected user value to be 'test-value', got '%s'", value.(string))
|
|
}
|
|
}
|
|
|
|
chained := lib.ChainMiddlewares(handler, middleware)
|
|
chained(ctx)
|
|
}
|
|
|
|
func TestIsInferenceWSEndpoint(t *testing.T) {
|
|
paths := []string{
|
|
"/v1/responses",
|
|
"/v1/realtime",
|
|
"/responses",
|
|
"/realtime",
|
|
"/openai/v1/responses",
|
|
"/openai/responses",
|
|
"/openai/openai/responses",
|
|
"/openai/v1/realtime",
|
|
"/openai/realtime",
|
|
"/openai/openai/realtime",
|
|
}
|
|
|
|
for _, path := range paths {
|
|
if !isInferenceWSEndpoint(path) {
|
|
t.Fatalf("expected inference websocket path %s to be recognized", path)
|
|
}
|
|
}
|
|
|
|
if isInferenceWSEndpoint("/api/ws") {
|
|
t.Fatal("dashboard websocket path should not be treated as inference websocket")
|
|
}
|
|
if isInferenceWSEndpoint("/openai/chat/completions") {
|
|
t.Fatal("non-websocket OpenAI path should not be treated as inference websocket")
|
|
}
|
|
}
|
|
|
|
func TestIsRealtimeTransportEndpoint(t *testing.T) {
|
|
paths := []string{
|
|
"/v1/realtime",
|
|
"/realtime",
|
|
"/openai/realtime",
|
|
"/openai/v1/realtime",
|
|
"/openai/openai/realtime",
|
|
"/v1/realtime/calls",
|
|
"/realtime/calls",
|
|
"/openai/realtime/calls",
|
|
"/openai/v1/realtime/calls",
|
|
"/openai/openai/realtime/calls",
|
|
}
|
|
|
|
for _, path := range paths {
|
|
if !isRealtimeTransportEndpoint(path) {
|
|
t.Fatalf("expected realtime transport path %s to be recognized", path)
|
|
}
|
|
}
|
|
|
|
nonTransportPaths := []string{
|
|
"/v1/realtime/client_secrets",
|
|
"/v1/realtime/sessions",
|
|
"/openai/v1/realtime/client_secrets",
|
|
"/openai/v1/realtime/sessions",
|
|
"/v1/chat/completions",
|
|
}
|
|
|
|
for _, path := range nonTransportPaths {
|
|
if isRealtimeTransportEndpoint(path) {
|
|
t.Fatalf("did not expect non-transport path %s to be recognized", path)
|
|
}
|
|
}
|
|
}
|
|
|
|
// Testlib.ChainMiddlewares_ShortCircuit tests that when a middleware writes a response
|
|
// and does not call next, subsequent middlewares and handler do not execute.
|
|
func TestChainMiddlewares_ShortCircuit(t *testing.T) {
|
|
ctx := &fasthttp.RequestCtx{}
|
|
executionOrder := []int{}
|
|
|
|
// First middleware - writes response and short-circuits by not calling next
|
|
middleware1 := schemas.BifrostHTTPMiddleware(func(next fasthttp.RequestHandler) fasthttp.RequestHandler {
|
|
return func(ctx *fasthttp.RequestCtx) {
|
|
executionOrder = append(executionOrder, 1)
|
|
ctx.SetStatusCode(fasthttp.StatusUnauthorized)
|
|
ctx.SetBodyString("Unauthorized")
|
|
// Not calling next(ctx) to short-circuit
|
|
}
|
|
})
|
|
|
|
// Second middleware - should NOT execute when middleware1 short-circuits
|
|
middleware2 := schemas.BifrostHTTPMiddleware(func(next fasthttp.RequestHandler) fasthttp.RequestHandler {
|
|
return func(ctx *fasthttp.RequestCtx) {
|
|
executionOrder = append(executionOrder, 2)
|
|
next(ctx)
|
|
}
|
|
})
|
|
|
|
// Third middleware - should NOT execute when middleware1 short-circuits
|
|
middleware3 := schemas.BifrostHTTPMiddleware(func(next fasthttp.RequestHandler) fasthttp.RequestHandler {
|
|
return func(ctx *fasthttp.RequestCtx) {
|
|
executionOrder = append(executionOrder, 3)
|
|
next(ctx)
|
|
}
|
|
})
|
|
|
|
// Handler - should NOT execute when middleware1 short-circuits
|
|
handler := func(ctx *fasthttp.RequestCtx) {
|
|
executionOrder = append(executionOrder, 4)
|
|
ctx.SetStatusCode(fasthttp.StatusOK)
|
|
ctx.SetBodyString("Success")
|
|
}
|
|
|
|
chained := lib.ChainMiddlewares(handler, middleware1, middleware2, middleware3)
|
|
chained(ctx)
|
|
|
|
// Verify only middleware1 executed
|
|
expectedOrder := []int{1}
|
|
if len(executionOrder) != len(expectedOrder) {
|
|
t.Errorf("Expected %d function calls, got %d", len(expectedOrder), len(executionOrder))
|
|
}
|
|
|
|
for i, expected := range expectedOrder {
|
|
if i >= len(executionOrder) || executionOrder[i] != expected {
|
|
t.Errorf("Expected execution order %v, got %v", expectedOrder, executionOrder)
|
|
break
|
|
}
|
|
}
|
|
|
|
// The middleware's response should be preserved (not overwritten)
|
|
if ctx.Response.StatusCode() != fasthttp.StatusUnauthorized {
|
|
t.Errorf("Expected status code %d, got %d", fasthttp.StatusUnauthorized, ctx.Response.StatusCode())
|
|
}
|
|
if string(ctx.Response.Body()) != "Unauthorized" {
|
|
t.Errorf("Expected body 'Unauthorized', got '%s'", string(ctx.Response.Body()))
|
|
}
|
|
}
|
|
|
|
// Testlib.ChainMiddlewares_ShortCircuitMiddlePosition tests that middleware in the middle
|
|
// can short-circuit, preventing later middlewares and handler from executing.
|
|
func TestChainMiddlewares_ShortCircuitMiddlePosition(t *testing.T) {
|
|
ctx := &fasthttp.RequestCtx{}
|
|
executionOrder := []int{}
|
|
|
|
// First middleware - executes and calls next
|
|
middleware1 := schemas.BifrostHTTPMiddleware(func(next fasthttp.RequestHandler) fasthttp.RequestHandler {
|
|
return func(ctx *fasthttp.RequestCtx) {
|
|
executionOrder = append(executionOrder, 1)
|
|
next(ctx)
|
|
}
|
|
})
|
|
|
|
// Second middleware - writes response and short-circuits
|
|
middleware2 := schemas.BifrostHTTPMiddleware(func(next fasthttp.RequestHandler) fasthttp.RequestHandler {
|
|
return func(ctx *fasthttp.RequestCtx) {
|
|
executionOrder = append(executionOrder, 2)
|
|
ctx.SetStatusCode(fasthttp.StatusUnauthorized)
|
|
ctx.SetBodyString("Unauthorized")
|
|
// Not calling next(ctx) to short-circuit
|
|
}
|
|
})
|
|
|
|
// Third middleware - should NOT execute
|
|
middleware3 := schemas.BifrostHTTPMiddleware(func(next fasthttp.RequestHandler) fasthttp.RequestHandler {
|
|
return func(ctx *fasthttp.RequestCtx) {
|
|
executionOrder = append(executionOrder, 3)
|
|
next(ctx)
|
|
}
|
|
})
|
|
|
|
// Handler - should NOT execute
|
|
handler := func(ctx *fasthttp.RequestCtx) {
|
|
executionOrder = append(executionOrder, 4)
|
|
ctx.SetStatusCode(fasthttp.StatusOK)
|
|
ctx.SetBodyString("Success")
|
|
}
|
|
|
|
chained := lib.ChainMiddlewares(handler, middleware1, middleware2, middleware3)
|
|
chained(ctx)
|
|
|
|
// Verify only middleware1 and middleware2 executed
|
|
expectedOrder := []int{1, 2}
|
|
if len(executionOrder) != len(expectedOrder) {
|
|
t.Errorf("Expected %d function calls, got %d", len(expectedOrder), len(executionOrder))
|
|
}
|
|
|
|
for i, expected := range expectedOrder {
|
|
if i >= len(executionOrder) || executionOrder[i] != expected {
|
|
t.Errorf("Expected execution order %v, got %v", expectedOrder, executionOrder)
|
|
break
|
|
}
|
|
}
|
|
|
|
// The middleware2's response should be preserved
|
|
if ctx.Response.StatusCode() != fasthttp.StatusUnauthorized {
|
|
t.Errorf("Expected status code %d, got %d", fasthttp.StatusUnauthorized, ctx.Response.StatusCode())
|
|
}
|
|
if string(ctx.Response.Body()) != "Unauthorized" {
|
|
t.Errorf("Expected body 'Unauthorized', got '%s'", string(ctx.Response.Body()))
|
|
}
|
|
}
|
|
|
|
// TestAuthMiddleware_NilAuthConfig tests that auth middleware allows requests when auth config is nil
|
|
func TestAuthMiddleware_NilAuthConfig(t *testing.T) {
|
|
SetLogger(&mockLogger{})
|
|
|
|
am := &AuthMiddleware{}
|
|
// authConfig is nil by default (simulates app start with no auth config)
|
|
|
|
ctx := &fasthttp.RequestCtx{}
|
|
ctx.Request.SetRequestURI("/api/some-endpoint")
|
|
|
|
nextCalled := false
|
|
next := func(ctx *fasthttp.RequestCtx) {
|
|
nextCalled = true
|
|
}
|
|
|
|
middleware := am.APIMiddleware()
|
|
handler := middleware(next)
|
|
handler(ctx)
|
|
|
|
// When auth config is nil, requests should be allowed through
|
|
if !nextCalled {
|
|
t.Error("Next handler should be called when auth config is nil")
|
|
}
|
|
}
|
|
|
|
// TestAuthMiddleware_DisabledAuthConfig tests that auth middleware allows requests when auth is disabled
|
|
func TestAuthMiddleware_DisabledAuthConfig(t *testing.T) {
|
|
SetLogger(&mockLogger{})
|
|
|
|
am := &AuthMiddleware{}
|
|
am.UpdateAuthConfig(&configstore.AuthConfig{
|
|
AdminUserName: schemas.NewEnvVar("admin"),
|
|
AdminPassword: schemas.NewEnvVar("password"),
|
|
IsEnabled: false,
|
|
})
|
|
|
|
ctx := &fasthttp.RequestCtx{}
|
|
ctx.Request.SetRequestURI("/api/some-endpoint")
|
|
|
|
nextCalled := false
|
|
next := func(ctx *fasthttp.RequestCtx) {
|
|
nextCalled = true
|
|
}
|
|
|
|
middleware := am.APIMiddleware()
|
|
handler := middleware(next)
|
|
handler(ctx)
|
|
|
|
// When auth is disabled, requests should be allowed through
|
|
if !nextCalled {
|
|
t.Error("Next handler should be called when auth is disabled")
|
|
}
|
|
}
|
|
|
|
// TestAuthMiddleware_EnabledAuthConfig_NoAuth tests that auth middleware blocks unauthenticated requests
|
|
func TestAuthMiddleware_EnabledAuthConfig_NoAuth(t *testing.T) {
|
|
SetLogger(&mockLogger{})
|
|
|
|
am := &AuthMiddleware{}
|
|
am.UpdateAuthConfig(&configstore.AuthConfig{
|
|
AdminUserName: schemas.NewEnvVar("admin"),
|
|
AdminPassword: schemas.NewEnvVar("hashedpassword"),
|
|
IsEnabled: true,
|
|
})
|
|
|
|
ctx := &fasthttp.RequestCtx{}
|
|
ctx.Request.SetRequestURI("/api/some-endpoint")
|
|
|
|
nextCalled := false
|
|
next := func(ctx *fasthttp.RequestCtx) {
|
|
nextCalled = true
|
|
}
|
|
|
|
middleware := am.APIMiddleware()
|
|
handler := middleware(next)
|
|
handler(ctx)
|
|
|
|
// When auth is enabled and no auth header is provided, request should be blocked
|
|
if nextCalled {
|
|
t.Error("Next handler should NOT be called when auth is enabled and no credentials provided")
|
|
}
|
|
if ctx.Response.StatusCode() != fasthttp.StatusUnauthorized {
|
|
t.Errorf("Expected status code %d, got %d", fasthttp.StatusUnauthorized, ctx.Response.StatusCode())
|
|
}
|
|
}
|
|
|
|
// TestAuthMiddleware_WhitelistedRoutes tests that whitelisted routes bypass auth
|
|
func TestAuthMiddleware_WhitelistedRoutes(t *testing.T) {
|
|
SetLogger(&mockLogger{})
|
|
|
|
am := &AuthMiddleware{}
|
|
am.UpdateAuthConfig(&configstore.AuthConfig{
|
|
AdminUserName: schemas.NewEnvVar("admin"),
|
|
AdminPassword: schemas.NewEnvVar("hashedpassword"),
|
|
IsEnabled: true,
|
|
})
|
|
|
|
whitelistedRoutes := []string{
|
|
"/api/session/is-auth-enabled",
|
|
"/api/session/login",
|
|
"/api/oauth/callback",
|
|
"/health",
|
|
}
|
|
|
|
for _, route := range whitelistedRoutes {
|
|
t.Run(route, func(t *testing.T) {
|
|
ctx := &fasthttp.RequestCtx{}
|
|
ctx.Request.SetRequestURI(route)
|
|
|
|
nextCalled := false
|
|
next := func(ctx *fasthttp.RequestCtx) {
|
|
nextCalled = true
|
|
}
|
|
|
|
middleware := am.APIMiddleware()
|
|
handler := middleware(next)
|
|
handler(ctx)
|
|
|
|
if !nextCalled {
|
|
t.Errorf("Next handler should be called for whitelisted route %s", route)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestAuthMiddleware_InferenceMiddleware_RealtimeTransportBypassesAuth(t *testing.T) {
|
|
SetLogger(&mockLogger{})
|
|
|
|
am := &AuthMiddleware{}
|
|
am.UpdateAuthConfig(&configstore.AuthConfig{
|
|
AdminUserName: schemas.NewEnvVar("admin"),
|
|
AdminPassword: schemas.NewEnvVar("hashedpassword"),
|
|
IsEnabled: true,
|
|
})
|
|
|
|
routes := []string{
|
|
"/v1/realtime",
|
|
"/openai/v1/realtime",
|
|
"/v1/realtime/calls?model=gpt-realtime",
|
|
"/openai/v1/realtime/calls?model=gpt-realtime",
|
|
}
|
|
|
|
for _, route := range routes {
|
|
t.Run(route, func(t *testing.T) {
|
|
ctx := &fasthttp.RequestCtx{}
|
|
ctx.Request.SetRequestURI(route)
|
|
|
|
nextCalled := false
|
|
next := func(ctx *fasthttp.RequestCtx) {
|
|
nextCalled = true
|
|
}
|
|
|
|
handler := am.InferenceMiddleware()(next)
|
|
handler(ctx)
|
|
|
|
if !nextCalled {
|
|
t.Fatalf("expected realtime transport route %s to bypass auth", route)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestAuthMiddleware_InferenceMiddleware_RealtimeMintingStillRequiresAuth(t *testing.T) {
|
|
SetLogger(&mockLogger{})
|
|
|
|
am := &AuthMiddleware{}
|
|
am.UpdateAuthConfig(&configstore.AuthConfig{
|
|
AdminUserName: schemas.NewEnvVar("admin"),
|
|
AdminPassword: schemas.NewEnvVar("hashedpassword"),
|
|
IsEnabled: true,
|
|
})
|
|
|
|
routes := []string{
|
|
"/v1/realtime/client_secrets",
|
|
"/v1/realtime/sessions",
|
|
"/openai/v1/realtime/client_secrets",
|
|
"/openai/v1/realtime/sessions",
|
|
}
|
|
|
|
for _, route := range routes {
|
|
t.Run(route, func(t *testing.T) {
|
|
ctx := &fasthttp.RequestCtx{}
|
|
ctx.Request.SetRequestURI(route)
|
|
|
|
nextCalled := false
|
|
next := func(ctx *fasthttp.RequestCtx) {
|
|
nextCalled = true
|
|
}
|
|
|
|
handler := am.InferenceMiddleware()(next)
|
|
handler(ctx)
|
|
|
|
if nextCalled {
|
|
t.Fatalf("expected realtime minting route %s to still require auth", route)
|
|
}
|
|
if ctx.Response.StatusCode() != fasthttp.StatusUnauthorized {
|
|
t.Fatalf("expected %d for route %s, got %d", fasthttp.StatusUnauthorized, route, ctx.Response.StatusCode())
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestAuthMiddleware_UpdateAuthConfig_NilToEnabled tests updating auth config from nil to enabled
|
|
func TestAuthMiddleware_UpdateAuthConfig_NilToEnabled(t *testing.T) {
|
|
SetLogger(&mockLogger{})
|
|
|
|
am := &AuthMiddleware{}
|
|
// Initially auth config is nil
|
|
|
|
ctx := &fasthttp.RequestCtx{}
|
|
ctx.Request.SetRequestURI("/api/some-endpoint")
|
|
|
|
// First request should pass (nil config)
|
|
nextCalled := false
|
|
next := func(ctx *fasthttp.RequestCtx) {
|
|
nextCalled = true
|
|
}
|
|
|
|
middleware := am.APIMiddleware()
|
|
handler := middleware(next)
|
|
handler(ctx)
|
|
|
|
if !nextCalled {
|
|
t.Error("First request should pass when auth config is nil")
|
|
}
|
|
|
|
// Now enable auth
|
|
am.UpdateAuthConfig(&configstore.AuthConfig{
|
|
AdminUserName: schemas.NewEnvVar("admin"),
|
|
AdminPassword: schemas.NewEnvVar("hashedpassword"),
|
|
IsEnabled: true,
|
|
})
|
|
|
|
// Second request should be blocked (auth enabled, no credentials)
|
|
ctx2 := &fasthttp.RequestCtx{}
|
|
ctx2.Request.SetRequestURI("/api/some-endpoint")
|
|
|
|
nextCalled = false
|
|
handler(ctx2)
|
|
|
|
if nextCalled {
|
|
t.Error("Second request should be blocked after auth is enabled")
|
|
}
|
|
if ctx2.Response.StatusCode() != fasthttp.StatusUnauthorized {
|
|
t.Errorf("Expected status code %d, got %d", fasthttp.StatusUnauthorized, ctx2.Response.StatusCode())
|
|
}
|
|
}
|
|
|
|
// TestAuthMiddleware_UpdateAuthConfig_EnabledToDisabled tests disabling auth after it was enabled
|
|
func TestAuthMiddleware_UpdateAuthConfig_EnabledToDisabled(t *testing.T) {
|
|
SetLogger(&mockLogger{})
|
|
|
|
am := &AuthMiddleware{}
|
|
// Start with auth enabled
|
|
am.UpdateAuthConfig(&configstore.AuthConfig{
|
|
AdminUserName: schemas.NewEnvVar("admin"),
|
|
AdminPassword: schemas.NewEnvVar("hashedpassword"),
|
|
IsEnabled: true,
|
|
})
|
|
|
|
ctx := &fasthttp.RequestCtx{}
|
|
ctx.Request.SetRequestURI("/api/some-endpoint")
|
|
|
|
// First request should be blocked (auth enabled, no credentials)
|
|
nextCalled := false
|
|
next := func(ctx *fasthttp.RequestCtx) {
|
|
nextCalled = true
|
|
}
|
|
|
|
middleware := am.APIMiddleware()
|
|
handler := middleware(next)
|
|
handler(ctx)
|
|
|
|
if nextCalled {
|
|
t.Error("First request should be blocked when auth is enabled")
|
|
}
|
|
|
|
// Now disable auth
|
|
am.UpdateAuthConfig(&configstore.AuthConfig{
|
|
AdminUserName: schemas.NewEnvVar("admin"),
|
|
AdminPassword: schemas.NewEnvVar("hashedpassword"),
|
|
IsEnabled: false,
|
|
})
|
|
|
|
// Second request should pass (auth disabled)
|
|
ctx2 := &fasthttp.RequestCtx{}
|
|
ctx2.Request.SetRequestURI("/api/some-endpoint")
|
|
|
|
nextCalled = false
|
|
handler(ctx2)
|
|
|
|
if !nextCalled {
|
|
t.Error("Second request should pass after auth is disabled")
|
|
}
|
|
}
|
|
|
|
// TestFasthttpToHTTPRequest tests the conversion from fasthttp context to HTTPRequest
|
|
func TestFasthttpToHTTPRequest(t *testing.T) {
|
|
ctx := &fasthttp.RequestCtx{}
|
|
|
|
// Set up test data
|
|
ctx.Request.Header.SetMethod("POST")
|
|
// Query params include: integers, floats, booleans, timestamps, and strings with special chars
|
|
ctx.Request.SetRequestURI("/api/v1/test?limit=100&offset=50&min_cost=12.50&max_latency=1500.75&missing_cost_only=true&start_time=2023-01-15T10:30:00Z&content_search=test+query&special=%2B%26%3D%3F")
|
|
ctx.Request.Header.Set("Content-Type", "application/json")
|
|
ctx.Request.Header.Set("Authorization", "Bearer token123")
|
|
ctx.Request.Header.Set("X-Request-Id", "12345")
|
|
ctx.Request.Header.Set("X-Custom-Header", "value-with-dashes")
|
|
ctx.Request.SetBodyString(`{"key": "value", "number": 42, "nested": {"bool": true}}`)
|
|
|
|
// Acquire HTTPRequest from pool
|
|
req := schemas.AcquireHTTPRequest()
|
|
defer schemas.ReleaseHTTPRequest(req)
|
|
|
|
// Call the function
|
|
fasthttpToHTTPRequest(ctx, req)
|
|
|
|
// Verify Method
|
|
if req.Method != "POST" {
|
|
t.Errorf("Expected Method to be 'POST', got '%s'", req.Method)
|
|
}
|
|
|
|
// Verify Path (without query params)
|
|
if req.Path != "/api/v1/test" {
|
|
t.Errorf("Expected Path to be '/api/v1/test', got '%s'", req.Path)
|
|
}
|
|
|
|
// Verify Headers
|
|
expectedHeaders := map[string]string{
|
|
"Content-Type": "application/json",
|
|
"Authorization": "Bearer token123",
|
|
"X-Request-Id": "12345",
|
|
"X-Custom-Header": "value-with-dashes",
|
|
}
|
|
for key, expectedValue := range expectedHeaders {
|
|
if actualValue, exists := req.Headers[key]; !exists {
|
|
t.Errorf("Expected header '%s' to exist", key)
|
|
} else if actualValue != expectedValue {
|
|
t.Errorf("Expected header '%s' to be '%s', got '%s'", key, expectedValue, actualValue)
|
|
}
|
|
}
|
|
|
|
// Verify Query params
|
|
expectedQuery := map[string]string{
|
|
"limit": "100", // integer
|
|
"offset": "50", // integer
|
|
"min_cost": "12.50", // float
|
|
"max_latency": "1500.75", // float
|
|
"missing_cost_only": "true", // boolean
|
|
"start_time": "2023-01-15T10:30:00Z", // timestamp
|
|
"content_search": "test query", // string with space (decoded)
|
|
"special": "+&=?", // special characters (decoded)
|
|
}
|
|
for key, expectedValue := range expectedQuery {
|
|
if actualValue, exists := req.Query[key]; !exists {
|
|
t.Errorf("Expected query param '%s' to exist", key)
|
|
} else if actualValue != expectedValue {
|
|
t.Errorf("Expected query param '%s' to be '%s', got '%s'", key, expectedValue, actualValue)
|
|
}
|
|
}
|
|
|
|
// Verify Body (JSON with various types)
|
|
expectedBody := `{"key": "value", "number": 42, "nested": {"bool": true}}`
|
|
if string(req.Body) != expectedBody {
|
|
t.Errorf("Expected Body to be '%s', got '%s'", expectedBody, string(req.Body))
|
|
}
|
|
|
|
// Verify body is a copy, not a reference
|
|
originalBody := ctx.Request.Body()
|
|
if len(req.Body) > 0 && len(originalBody) > 0 {
|
|
// Modify the HTTPRequest body
|
|
req.Body[0] = 'X'
|
|
// Original should remain unchanged
|
|
if originalBody[0] == 'X' {
|
|
t.Error("Body should be a copy, not a reference to the original")
|
|
}
|
|
}
|
|
}
|
|
|
|
// TestCorsMiddleware_DefaultHeaders tests that default CORS headers are set
|
|
func TestCorsMiddleware_DefaultHeaders(t *testing.T) {
|
|
SetLogger(&mockLogger{})
|
|
|
|
config := &lib.Config{
|
|
ClientConfig: &configstore.ClientConfig{
|
|
AllowedOrigins: []string{"https://example.com"},
|
|
AllowedHeaders: []string{}, // No custom headers
|
|
},
|
|
}
|
|
|
|
ctx := &fasthttp.RequestCtx{}
|
|
ctx.Request.Header.Set("Origin", "https://example.com")
|
|
|
|
nextCalled := false
|
|
next := func(ctx *fasthttp.RequestCtx) {
|
|
nextCalled = true
|
|
}
|
|
|
|
middleware := CorsMiddleware(config)
|
|
handler := middleware(next)
|
|
handler(ctx)
|
|
|
|
// Check default headers are set
|
|
expectedHeaders := "Content-Type, Authorization, X-Requested-With, X-Stainless-Timeout, X-Api-Key, X-OpenAI-Agents-SDK"
|
|
actualHeaders := string(ctx.Response.Header.Peek("Access-Control-Allow-Headers"))
|
|
if actualHeaders != expectedHeaders {
|
|
t.Errorf("Expected Access-Control-Allow-Headers to be %s, got %s", expectedHeaders, actualHeaders)
|
|
}
|
|
|
|
if !nextCalled {
|
|
t.Error("Next handler was not called")
|
|
}
|
|
}
|
|
|
|
// TestCorsMiddleware_WildcardHeaders_NonCredentialed checks the case where both
// AllowedOrigins and AllowedHeaders are "*". The middleware must answer with a
// literal "*" for Access-Control-Allow-Headers. This is the non-credentialed,
// wildcard-origin scenario, where the CORS spec lets "*" act as a wildcard.
func TestCorsMiddleware_WildcardHeaders_NonCredentialed(t *testing.T) {
	SetLogger(&mockLogger{})

	cfg := &lib.Config{
		ClientConfig: &configstore.ClientConfig{
			AllowedOrigins: []string{"*"},
			AllowedHeaders: []string{"*"},
		},
	}

	reqCtx := &fasthttp.RequestCtx{}
	reqCtx.Request.Header.Set("Origin", "https://example.com")

	forwarded := false
	handler := CorsMiddleware(cfg)(func(*fasthttp.RequestCtx) { forwarded = true })
	handler(reqCtx)

	allowHeaders := string(reqCtx.Response.Header.Peek("Access-Control-Allow-Headers"))
	if allowHeaders != "*" {
		t.Errorf("Expected Access-Control-Allow-Headers to be *, got %s", allowHeaders)
	}

	if !forwarded {
		t.Error("Next handler was not called")
	}
}
|
|
|
|
// TestCorsMiddleware_WildcardHeaders_CredentialedPreflight tests that wildcard allowed headers
|
|
// reflects Access-Control-Request-Headers for credentialed preflight requests instead of sending
|
|
// the literal *, which browsers don't treat as a wildcard when credentials are present.
|
|
func TestCorsMiddleware_WildcardHeaders_CredentialedPreflight(t *testing.T) {
|
|
SetLogger(&mockLogger{})
|
|
|
|
config := &lib.Config{
|
|
ClientConfig: &configstore.ClientConfig{
|
|
AllowedOrigins: []string{"https://example.com"},
|
|
AllowedHeaders: []string{"*"},
|
|
},
|
|
}
|
|
|
|
ctx := &fasthttp.RequestCtx{}
|
|
ctx.Request.Header.SetMethod("OPTIONS")
|
|
ctx.Request.Header.Set("Origin", "https://example.com")
|
|
ctx.Request.Header.Set("Access-Control-Request-Headers", "Authorization, X-Custom-Header")
|
|
|
|
next := func(ctx *fasthttp.RequestCtx) {
|
|
t.Error("Next handler should not be called for preflight")
|
|
}
|
|
|
|
middleware := CorsMiddleware(config)
|
|
handler := middleware(next)
|
|
handler(ctx)
|
|
|
|
// Credentialed preflight: should reflect requested headers, not *
|
|
actualHeaders := string(ctx.Response.Header.Peek("Access-Control-Allow-Headers"))
|
|
if actualHeaders != "Authorization, X-Custom-Header" {
|
|
t.Errorf("Expected Access-Control-Allow-Headers to reflect requested headers, got %s", actualHeaders)
|
|
}
|
|
|
|
// Should also have credentials
|
|
creds := string(ctx.Response.Header.Peek("Access-Control-Allow-Credentials"))
|
|
if creds != "true" {
|
|
t.Errorf("Expected Access-Control-Allow-Credentials to be true, got %s", creds)
|
|
}
|
|
}
|
|
|
|
// TestCorsMiddleware_WildcardHeaders_CredentialedNonPreflight covers a plain
// (non-OPTIONS) request from an explicitly allowed origin with
// AllowedHeaders = ["*"]. With no Access-Control-Request-Headers to reflect,
// the response must list the default headers rather than a literal "*", and
// the request must still reach next.
func TestCorsMiddleware_WildcardHeaders_CredentialedNonPreflight(t *testing.T) {
	SetLogger(&mockLogger{})

	cfg := &lib.Config{
		ClientConfig: &configstore.ClientConfig{
			AllowedOrigins: []string{"https://example.com"},
			AllowedHeaders: []string{"*"},
		},
	}

	reqCtx := &fasthttp.RequestCtx{}
	reqCtx.Request.Header.Set("Origin", "https://example.com")

	forwarded := false
	CorsMiddleware(cfg)(func(*fasthttp.RequestCtx) { forwarded = true })(reqCtx)

	allowHeaders := string(reqCtx.Response.Header.Peek("Access-Control-Allow-Headers"))
	for _, want := range [...]string{"Content-Type", "Authorization", "X-Requested-With", "X-Stainless-Timeout", "X-Api-Key"} {
		if containsHeader(allowHeaders, want) {
			continue
		}
		t.Errorf("Expected Access-Control-Allow-Headers to contain %s, got %s", want, allowHeaders)
	}
	if allowHeaders == "*" {
		t.Error("Expected Access-Control-Allow-Headers to NOT be * for credentialed requests")
	}

	if !forwarded {
		t.Error("Next handler was not called")
	}
}
|
|
|
|
// TestCorsMiddleware_CustomHeaders verifies that configured AllowedHeaders are
// added to Access-Control-Allow-Headers alongside the built-in defaults,
// rather than replacing them, and that next is still invoked.
func TestCorsMiddleware_CustomHeaders(t *testing.T) {
	SetLogger(&mockLogger{})

	cfg := &lib.Config{
		ClientConfig: &configstore.ClientConfig{
			AllowedOrigins: []string{"https://example.com"},
			AllowedHeaders: []string{"X-Custom-Header", "X-Another-Header"},
		},
	}

	reqCtx := &fasthttp.RequestCtx{}
	reqCtx.Request.Header.Set("Origin", "https://example.com")

	forwarded := false
	CorsMiddleware(cfg)(func(*fasthttp.RequestCtx) { forwarded = true })(reqCtx)

	allowHeaders := string(reqCtx.Response.Header.Peek("Access-Control-Allow-Headers"))

	defaults := []string{"Content-Type", "Authorization", "X-Requested-With", "X-Stainless-Timeout"}
	custom := []string{"X-Custom-Header", "X-Another-Header"}
	for _, want := range append(defaults, custom...) {
		if !containsHeader(allowHeaders, want) {
			t.Errorf("Expected Access-Control-Allow-Headers to contain %s, got %s", want, allowHeaders)
		}
	}

	if !forwarded {
		t.Error("Next handler was not called")
	}
}
|
|
|
|
// TestCorsMiddleware_DuplicateHeaders configures a custom header that is
// already a default (Content-Type) plus a genuinely new one. It verifies that
// the default is listed only once in Access-Control-Allow-Headers while the
// new header is still included.
func TestCorsMiddleware_DuplicateHeaders(t *testing.T) {
	SetLogger(&mockLogger{})

	cfg := &lib.Config{
		ClientConfig: &configstore.ClientConfig{
			AllowedOrigins: []string{"https://example.com"},
			// Content-Type is deliberately repeated from the defaults.
			AllowedHeaders: []string{"Content-Type", "X-Custom-Header"},
		},
	}

	reqCtx := &fasthttp.RequestCtx{}
	reqCtx.Request.Header.Set("Origin", "https://example.com")

	forwarded := false
	CorsMiddleware(cfg)(func(*fasthttp.RequestCtx) { forwarded = true })(reqCtx)

	allowHeaders := string(reqCtx.Response.Header.Peek("Access-Control-Allow-Headers"))

	if n := countHeaderOccurrences(allowHeaders, "Content-Type"); n != 1 {
		t.Errorf("Expected Content-Type to appear once, but appeared %d times in: %s", n, allowHeaders)
	}
	if !containsHeader(allowHeaders, "X-Custom-Header") {
		t.Errorf("Expected Access-Control-Allow-Headers to contain X-Custom-Header, got %s", allowHeaders)
	}

	if !forwarded {
		t.Error("Next handler was not called")
	}
}
|
|
|
|
// TestCorsMiddleware_CustomHeadersWithLocalhost confirms that a localhost
// origin is served even when AllowedOrigins is empty. Configured custom allowed
// headers must also be emitted for it, and the request must reach next.
func TestCorsMiddleware_CustomHeadersWithLocalhost(t *testing.T) {
	SetLogger(&mockLogger{})

	cfg := &lib.Config{
		ClientConfig: &configstore.ClientConfig{
			AllowedOrigins: []string{},
			AllowedHeaders: []string{"X-Development-Header"},
		},
	}

	reqCtx := &fasthttp.RequestCtx{}
	reqCtx.Request.Header.Set("Origin", "http://localhost:3000")

	forwarded := false
	handler := CorsMiddleware(cfg)(func(*fasthttp.RequestCtx) { forwarded = true })
	handler(reqCtx)

	if allowHeaders := string(reqCtx.Response.Header.Peek("Access-Control-Allow-Headers")); !containsHeader(allowHeaders, "X-Development-Header") {
		t.Errorf("Expected Access-Control-Allow-Headers to contain X-Development-Header for localhost, got %s", allowHeaders)
	}

	if !forwarded {
		t.Error("Next handler was not called")
	}
}
|
|
|
|
// TestCorsMiddleware_CustomHeadersNotSetForNonAllowedOrigin sends a request
// from an origin outside AllowedOrigins. It checks that no
// Access-Control-Allow-Headers (custom or default) is emitted, while the
// non-OPTIONS request is still passed on to next.
func TestCorsMiddleware_CustomHeadersNotSetForNonAllowedOrigin(t *testing.T) {
	SetLogger(&mockLogger{})

	cfg := &lib.Config{
		ClientConfig: &configstore.ClientConfig{
			AllowedOrigins: []string{"https://allowed.com"},
			AllowedHeaders: []string{"X-Custom-Header"},
		},
	}

	reqCtx := &fasthttp.RequestCtx{}
	reqCtx.Request.Header.Set("Origin", "https://malicious.com")

	forwarded := false
	CorsMiddleware(cfg)(func(*fasthttp.RequestCtx) { forwarded = true })(reqCtx)

	if allowHeaders := reqCtx.Response.Header.Peek("Access-Control-Allow-Headers"); len(allowHeaders) != 0 {
		t.Error("Access-Control-Allow-Headers header should not be set for non-allowed origin")
	}

	// Rejecting the origin only suppresses CORS headers; the request itself
	// still flows through for non-OPTIONS methods.
	if !forwarded {
		t.Error("Next handler was not called")
	}
}
|
|
|
|
// Helper function to check if a header is present in the comma-separated list
|
|
func containsHeader(headerList, header string) bool {
|
|
headers := splitHeaders(headerList)
|
|
for _, h := range headers {
|
|
if h == header {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
// splitHeaders splits a comma-separated header list, such as an
// Access-Control-Allow-Headers value, into its individual entries.
//
// Optional whitespace around each entry is trimmed. Per the HTTP list syntax
// (RFC 9110 OWS), that means both spaces and horizontal tabs. Empty entries,
// for example from ", ," or a trailing comma, are dropped. The result is nil
// when headerList contains no non-empty entries.
func splitHeaders(headerList string) []string {
	var headers []string
	for _, part := range strings.Split(headerList, ",") {
		if header := strings.Trim(part, " \t"); header != "" {
			headers = append(headers, header)
		}
	}
	return headers
}
|
|
|
|
// Helper function to count occurrences of a header
|
|
func countHeaderOccurrences(headerList, header string) int {
|
|
headers := splitHeaders(headerList)
|
|
count := 0
|
|
for _, h := range headers {
|
|
if h == header {
|
|
count++
|
|
}
|
|
}
|
|
return count
|
|
}
|
|
|
|
// TestFasthttpToHTTPRequest_PathParams verifies three things about
// fasthttpToHTTPRequest:
//   - router-style user values on the RequestCtx end up in
//     HTTPRequest.PathParams;
//   - internal bookkeeping keys (request ID, trace/span IDs) are excluded;
//   - CaseInsensitivePathParamLookup matches keys regardless of case.
func TestFasthttpToHTTPRequest_PathParams(t *testing.T) {
	reqCtx := &fasthttp.RequestCtx{}
	reqCtx.Request.Header.SetMethod("GET")
	reqCtx.Request.SetRequestURI("/v1beta/files/file-abc123")

	// Mimic the fasthttp router, which exposes path params as user values.
	reqCtx.SetUserValue("file_id", "file-abc123")
	reqCtx.SetUserValue("model", "gemini-pro")

	// Internal keys that share the user-value store and must be filtered out.
	reqCtx.SetUserValue("BifrostContextKeyRequestID", "req-123")
	reqCtx.SetUserValue("trace_id", "trace-456")
	reqCtx.SetUserValue("span_id", "span-789")
	internalKeys := []string{"BifrostContextKeyRequestID", "trace_id", "span_id"}

	httpReq := schemas.AcquireHTTPRequest()
	defer schemas.ReleaseHTTPRequest(httpReq)

	fasthttpToHTTPRequest(reqCtx, httpReq)

	wantParams := map[string]string{
		"file_id": "file-abc123",
		"model":   "gemini-pro",
	}
	if got, want := len(httpReq.PathParams), len(wantParams); got != want {
		t.Errorf("Expected %d path params, got %d", want, got)
	}
	for name, want := range wantParams {
		got, ok := httpReq.PathParams[name]
		switch {
		case !ok:
			t.Errorf("Expected path param '%s' to exist", name)
		case got != want:
			t.Errorf("Expected path param '%s' to be '%s', got '%s'", name, want, got)
		}
	}

	for _, name := range internalKeys {
		if _, leaked := httpReq.PathParams[name]; leaked {
			t.Errorf("System key '%s' should not be in path params", name)
		}
	}

	// The lookup helper must match both the exact and the upper-cased key.
	if got := httpReq.CaseInsensitivePathParamLookup("file_id"); got != "file-abc123" {
		t.Errorf("CaseInsensitivePathParamLookup failed: expected 'file-abc123', got '%s'", got)
	}
	if got := httpReq.CaseInsensitivePathParamLookup("FILE_ID"); got != "file-abc123" {
		t.Errorf("CaseInsensitivePathParamLookup should be case-insensitive: expected 'file-abc123', got '%s'", got)
	}
}
|
|
|
|
func TestRequestDecompressionMiddleware_SupportedEncodings(t *testing.T) {
|
|
config := &lib.Config{
|
|
ClientConfig: &configstore.ClientConfig{
|
|
MaxRequestBodySizeMB: 10,
|
|
},
|
|
}
|
|
|
|
plainBody := []byte(`{"model":"openai/gpt-4o-mini","messages":[{"role":"user","content":"hello"}]}`)
|
|
testCases := []struct {
|
|
name string
|
|
encoding string
|
|
encode func([]byte) ([]byte, error)
|
|
}{
|
|
{name: "gzip", encoding: "gzip", encode: gzipCompress},
|
|
{name: "deflate", encoding: "deflate", encode: deflateCompress},
|
|
{name: "brotli", encoding: "br", encode: brotliCompress},
|
|
{name: "zstd", encoding: "zstd", encode: zstdCompress},
|
|
}
|
|
|
|
for _, tc := range testCases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
compressedBody, err := tc.encode(plainBody)
|
|
if err != nil {
|
|
t.Fatalf("failed to encode body: %v", err)
|
|
}
|
|
|
|
ctx := &fasthttp.RequestCtx{}
|
|
ctx.Request.Header.SetMethod("POST")
|
|
ctx.Request.Header.SetContentType("application/json")
|
|
ctx.Request.Header.Set("Content-Encoding", tc.encoding)
|
|
ctx.Request.SetBodyRaw(compressedBody)
|
|
|
|
nextCalled := false
|
|
next := func(ctx *fasthttp.RequestCtx) {
|
|
nextCalled = true
|
|
if string(ctx.Request.Body()) != string(plainBody) {
|
|
t.Fatalf("expected decompressed body, got %q", string(ctx.Request.Body()))
|
|
}
|
|
}
|
|
|
|
handler := RequestDecompressionMiddleware(config)(next)
|
|
handler(ctx)
|
|
|
|
if !nextCalled {
|
|
t.Fatal("next handler was not called")
|
|
}
|
|
if got := string(ctx.Request.Header.Peek("Content-Encoding")); got != "" {
|
|
t.Fatalf("expected content-encoding to be cleared, got %q", got)
|
|
}
|
|
if got := string(ctx.Request.Header.Peek("Content-Length")); got != "" {
|
|
t.Fatalf("expected content-length to be cleared, got %q", got)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestRequestDecompressionMiddleware_InvalidCompressedBody sends a body that
// claims gzip encoding but is not a gzip stream. The middleware must reject it
// with 400 and a BifrostError whose message mentions "invalid compressed
// request body", and must not call next.
func TestRequestDecompressionMiddleware_InvalidCompressedBody(t *testing.T) {
	cfg := &lib.Config{
		ClientConfig: &configstore.ClientConfig{MaxRequestBodySizeMB: 10},
	}

	reqCtx := &fasthttp.RequestCtx{}
	reqCtx.Request.Header.SetMethod("POST")
	reqCtx.Request.Header.Set("Content-Encoding", "gzip")
	reqCtx.Request.SetBodyRaw([]byte("not-a-valid-gzip-payload"))

	forwarded := false
	RequestDecompressionMiddleware(cfg)(func(*fasthttp.RequestCtx) { forwarded = true })(reqCtx)

	switch {
	case forwarded:
		t.Fatal("next handler should not be called for invalid compressed payload")
	case reqCtx.Response.StatusCode() != fasthttp.StatusBadRequest:
		t.Fatalf("expected status 400, got %d", reqCtx.Response.StatusCode())
	}

	var errResp schemas.BifrostError
	if err := json.Unmarshal(reqCtx.Response.Body(), &errResp); err != nil {
		t.Fatalf("failed to decode error response: %v", err)
	}
	if detail := errResp.Error; detail == nil || !strings.Contains(detail.Message, "invalid compressed request body") {
		t.Fatalf("unexpected error message: %#v", detail)
	}
}
|
|
|
|
// TestRequestDecompressionMiddleware_UnsupportedEncoding uses a Content-Encoding
// the middleware does not handle ("snappy"). It expects a 400 response whose
// BifrostError message mentions "unsupported Content-Encoding", and next must
// not run.
func TestRequestDecompressionMiddleware_UnsupportedEncoding(t *testing.T) {
	cfg := &lib.Config{
		ClientConfig: &configstore.ClientConfig{MaxRequestBodySizeMB: 10},
	}

	reqCtx := &fasthttp.RequestCtx{}
	reqCtx.Request.Header.SetMethod("POST")
	reqCtx.Request.Header.Set("Content-Encoding", "snappy")
	reqCtx.Request.SetBodyRaw([]byte("whatever"))

	forwarded := false
	RequestDecompressionMiddleware(cfg)(func(*fasthttp.RequestCtx) { forwarded = true })(reqCtx)

	switch {
	case forwarded:
		t.Fatal("next handler should not be called for unsupported content-encoding")
	case reqCtx.Response.StatusCode() != fasthttp.StatusBadRequest:
		t.Fatalf("expected status 400, got %d", reqCtx.Response.StatusCode())
	}

	var errResp schemas.BifrostError
	if err := json.Unmarshal(reqCtx.Response.Body(), &errResp); err != nil {
		t.Fatalf("failed to decode error response: %v", err)
	}
	if detail := errResp.Error; detail == nil || !strings.Contains(detail.Message, "unsupported Content-Encoding") {
		t.Fatalf("unexpected error message: %#v", detail)
	}
}
|
|
|
|
// TestRequestDecompressionMiddleware_DecompressedSizeLimit gzips a payload that
// is 10 bytes larger than the 1 MB limit once inflated. The middleware must
// answer 413 with a BifrostError mentioning the decompressed size limit and
// must not call next.
func TestRequestDecompressionMiddleware_DecompressedSizeLimit(t *testing.T) {
	cfg := &lib.Config{
		ClientConfig: &configstore.ClientConfig{MaxRequestBodySizeMB: 1},
	}

	// Highly compressible input: tiny on the wire, just over the limit inflated.
	oversized := bytes.Repeat([]byte("a"), (1024*1024)+10)
	payload, err := gzipCompress(oversized)
	if err != nil {
		t.Fatalf("failed to gzip test payload: %v", err)
	}

	reqCtx := &fasthttp.RequestCtx{}
	reqCtx.Request.Header.SetMethod("POST")
	reqCtx.Request.Header.Set("Content-Encoding", "gzip")
	reqCtx.Request.SetBodyRaw(payload)

	forwarded := false
	RequestDecompressionMiddleware(cfg)(func(*fasthttp.RequestCtx) { forwarded = true })(reqCtx)

	switch {
	case forwarded:
		t.Fatal("next handler should not be called when decompressed body exceeds limit")
	case reqCtx.Response.StatusCode() != fasthttp.StatusRequestEntityTooLarge:
		t.Fatalf("expected status 413, got %d", reqCtx.Response.StatusCode())
	}

	var errResp schemas.BifrostError
	if err := json.Unmarshal(reqCtx.Response.Body(), &errResp); err != nil {
		t.Fatalf("failed to decode error response: %v", err)
	}
	if detail := errResp.Error; detail == nil || !strings.Contains(detail.Message, "decompressed request body exceeds max allowed size") {
		t.Fatalf("unexpected error message: %#v", detail)
	}
}
|
|
|
|
// TestRequestDecompressionMiddleware_EmptyBodyWithContentEncoding sends an empty
// body labelled with each supported Content-Encoding. Two outcomes are
// accepted:
//   - the middleware rejects the request with 400, since decoders may fail on
//     empty input;
//   - it forwards the request, in which case the body seen by next must be
//     empty, because there was nothing to decompress.
func TestRequestDecompressionMiddleware_EmptyBodyWithContentEncoding(t *testing.T) {
	config := &lib.Config{
		ClientConfig: &configstore.ClientConfig{
			MaxRequestBodySizeMB: 10,
		},
	}

	encodings := []string{"gzip", "deflate", "br", "zstd"}
	for _, enc := range encodings {
		t.Run(enc, func(t *testing.T) {
			ctx := &fasthttp.RequestCtx{}
			ctx.Request.Header.SetMethod("POST")
			ctx.Request.Header.Set("Content-Encoding", enc)
			ctx.Request.SetBodyRaw([]byte{})

			nextCalled := false
			forwardedLen := -1
			next := func(ctx *fasthttp.RequestCtx) {
				nextCalled = true
				forwardedLen = len(ctx.Request.Body())
			}

			handler := RequestDecompressionMiddleware(config)(next)
			handler(ctx)

			if nextCalled {
				// Forwarding is acceptable, but an empty input must not turn
				// into a non-empty body downstream.
				if forwardedLen != 0 {
					t.Fatalf("expected forwarded body to be empty, got %d bytes", forwardedLen)
				}
				return
			}
			if ctx.Response.StatusCode() != fasthttp.StatusBadRequest {
				t.Fatalf("expected status 400, got %d", ctx.Response.StatusCode())
			}
		})
	}
}
|
|
|
|
// TestRequestDecompressionMiddleware_NoContentEncoding verifies that a request
// without a Content-Encoding header passes through the middleware with its
// body unchanged.
func TestRequestDecompressionMiddleware_NoContentEncoding(t *testing.T) {
	cfg := &lib.Config{
		ClientConfig: &configstore.ClientConfig{MaxRequestBodySizeMB: 10},
	}

	original := []byte(`{"model":"openai/gpt-4o-mini","messages":[{"role":"user","content":"hello"}]}`)

	reqCtx := &fasthttp.RequestCtx{}
	reqCtx.Request.Header.SetMethod("POST")
	reqCtx.Request.Header.SetContentType("application/json")
	reqCtx.Request.SetBodyRaw(original)

	forwarded := false
	RequestDecompressionMiddleware(cfg)(func(inner *fasthttp.RequestCtx) {
		forwarded = true
		if got := string(inner.Request.Body()); got != string(original) {
			t.Fatalf("expected body to be unchanged, got %q", got)
		}
	})(reqCtx)

	if !forwarded {
		t.Fatal("next handler was not called")
	}
}
|
|
|
|
func TestRequestDecompressionMiddleware_ExactSizeLimit(t *testing.T) {
|
|
config := &lib.Config{
|
|
ClientConfig: &configstore.ClientConfig{
|
|
MaxRequestBodySizeMB: 1,
|
|
},
|
|
}
|
|
|
|
plainBody := bytes.Repeat([]byte("a"), 1024*1024) // exactly 1 MB
|
|
compressedBody, err := gzipCompress(plainBody)
|
|
if err != nil {
|
|
t.Fatalf("failed to gzip test payload: %v", err)
|
|
}
|
|
|
|
ctx := &fasthttp.RequestCtx{}
|
|
ctx.Request.Header.SetMethod("POST")
|
|
ctx.Request.Header.Set("Content-Encoding", "gzip")
|
|
ctx.Request.SetBodyRaw(compressedBody)
|
|
|
|
nextCalled := false
|
|
next := func(ctx *fasthttp.RequestCtx) {
|
|
nextCalled = true
|
|
if len(ctx.Request.Body()) != 1024*1024 {
|
|
t.Fatalf("expected body length %d, got %d", 1024*1024, len(ctx.Request.Body()))
|
|
}
|
|
}
|
|
|
|
handler := RequestDecompressionMiddleware(config)(next)
|
|
handler(ctx)
|
|
|
|
if !nextCalled {
|
|
t.Fatal("next handler was not called — body exactly at limit should pass")
|
|
}
|
|
}
|
|
|
|
// --- Streaming decompression path tests ---
|
|
|
|
func TestShouldStreamDecompress(t *testing.T) {
|
|
defaultThreshold := int(schemas.DefaultLargePayloadRequestThresholdBytes)
|
|
tests := []struct {
|
|
name string
|
|
contentLength int
|
|
customThreshold int64 // 0 means no custom threshold (use default)
|
|
want bool
|
|
}{
|
|
{"chunked (CL=-1)", -1, 0, true},
|
|
{"empty body (CL=0)", 0, 0, false},
|
|
{"small body", 100, 0, false},
|
|
{"at default threshold", defaultThreshold, 0, false},
|
|
{"above default threshold", defaultThreshold + 1, 0, true},
|
|
// Custom enterprise threshold (1MB) — body at 2MB should stream.
|
|
{"above custom threshold", 2 * 1024 * 1024, 1 * 1024 * 1024, true},
|
|
// Custom enterprise threshold (20MB) — body at default 10MB+1 should NOT stream.
|
|
{"below custom threshold", defaultThreshold + 1, 20 * 1024 * 1024, false},
|
|
// Chunked always streams regardless of custom threshold.
|
|
{"chunked with custom threshold", -1, 50 * 1024 * 1024, true},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
cfg := &lib.Config{}
|
|
if tt.customThreshold > 0 {
|
|
cfg.StreamingDecompressThreshold = tt.customThreshold
|
|
}
|
|
ctx := &fasthttp.RequestCtx{}
|
|
ctx.Request.Header.Set("Content-Encoding", "gzip")
|
|
if tt.contentLength >= 0 {
|
|
ctx.Request.Header.SetContentLength(tt.contentLength)
|
|
} else {
|
|
// Simulate chunked: set body stream with unknown size
|
|
ctx.Request.SetBodyStream(bytes.NewReader(nil), -1)
|
|
}
|
|
if got := shouldStreamDecompress(cfg, ctx); got != tt.want {
|
|
t.Errorf("shouldStreamDecompress() = %v, want %v (CL=%d, threshold=%d)", got, tt.want, tt.contentLength, tt.customThreshold)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestRequestDecompressionMiddleware_StreamingPath_ChunkedGzip sends a gzip body
// as a stream of unknown length (chunked), which takes the streaming
// decompression path. Next must receive the original bytes with
// Content-Encoding removed.
func TestRequestDecompressionMiddleware_StreamingPath_ChunkedGzip(t *testing.T) {
	cfg := &lib.Config{
		ClientConfig: &configstore.ClientConfig{MaxRequestBodySizeMB: 100},
	}

	original := []byte(`{"model":"openai/gpt-4o-mini","messages":[{"role":"user","content":"hello"}]}`)
	payload, err := gzipCompress(original)
	if err != nil {
		t.Fatalf("failed to gzip: %v", err)
	}

	reqCtx := &fasthttp.RequestCtx{}
	reqCtx.Request.Header.SetMethod("POST")
	reqCtx.Request.Header.Set("Content-Encoding", "gzip")
	// A size of -1 marks the stream as chunked, forcing the streaming path.
	reqCtx.Request.SetBodyStream(bytes.NewReader(payload), -1)

	forwarded := false
	inspect := func(inner *fasthttp.RequestCtx) {
		forwarded = true
		if enc := string(inner.Request.Header.Peek("Content-Encoding")); enc != "" {
			t.Errorf("expected Content-Encoding to be cleared, got %q", enc)
		}
		if body := inner.Request.Body(); string(body) != string(original) {
			t.Errorf("decompressed body mismatch: got %d bytes, want %d bytes", len(body), len(original))
		}
	}
	RequestDecompressionMiddleware(cfg)(inspect)(reqCtx)

	if !forwarded {
		t.Fatal("next handler was not called")
	}
}
|
|
|
|
func TestRequestDecompressionMiddleware_StreamingPath_AllEncodings(t *testing.T) {
|
|
config := &lib.Config{
|
|
ClientConfig: &configstore.ClientConfig{
|
|
MaxRequestBodySizeMB: 100,
|
|
},
|
|
}
|
|
|
|
plainBody := []byte(`{"model":"openai/gpt-4o-mini","messages":[{"role":"user","content":"hello"}]}`)
|
|
testCases := []struct {
|
|
name string
|
|
encoding string
|
|
encode func([]byte) ([]byte, error)
|
|
}{
|
|
{name: "gzip", encoding: "gzip", encode: gzipCompress},
|
|
{name: "deflate", encoding: "deflate", encode: deflateCompress},
|
|
{name: "brotli", encoding: "br", encode: brotliCompress},
|
|
{name: "zstd", encoding: "zstd", encode: zstdCompress},
|
|
}
|
|
|
|
for _, tc := range testCases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
compressedBody, err := tc.encode(plainBody)
|
|
if err != nil {
|
|
t.Fatalf("failed to encode body: %v", err)
|
|
}
|
|
|
|
ctx := &fasthttp.RequestCtx{}
|
|
ctx.Request.Header.SetMethod("POST")
|
|
ctx.Request.Header.Set("Content-Encoding", tc.encoding)
|
|
// Use chunked (-1) to trigger streaming path regardless of compressed size
|
|
ctx.Request.SetBodyStream(bytes.NewReader(compressedBody), -1)
|
|
|
|
nextCalled := false
|
|
next := func(ctx *fasthttp.RequestCtx) {
|
|
nextCalled = true
|
|
body := ctx.Request.Body()
|
|
if string(body) != string(plainBody) {
|
|
t.Fatalf("expected decompressed body, got %q", string(body))
|
|
}
|
|
}
|
|
|
|
handler := RequestDecompressionMiddleware(config)(next)
|
|
handler(ctx)
|
|
|
|
if !nextCalled {
|
|
t.Fatal("next handler was not called")
|
|
}
|
|
if got := string(ctx.Request.Header.Peek("Content-Encoding")); got != "" {
|
|
t.Fatalf("expected content-encoding to be cleared, got %q", got)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestRequestDecompressionMiddleware_StreamingPath_InvalidBody streams a
// non-gzip payload labelled as gzip through the chunked path. The middleware
// must reply 400 without calling next.
func TestRequestDecompressionMiddleware_StreamingPath_InvalidBody(t *testing.T) {
	cfg := &lib.Config{
		ClientConfig: &configstore.ClientConfig{MaxRequestBodySizeMB: 100},
	}

	reqCtx := &fasthttp.RequestCtx{}
	reqCtx.Request.Header.SetMethod("POST")
	reqCtx.Request.Header.Set("Content-Encoding", "gzip")
	// Chunked transfer of bytes that are not a valid gzip stream.
	reqCtx.Request.SetBodyStream(bytes.NewReader([]byte("not-a-valid-gzip-payload")), -1)

	forwarded := false
	RequestDecompressionMiddleware(cfg)(func(*fasthttp.RequestCtx) { forwarded = true })(reqCtx)

	switch {
	case forwarded:
		t.Fatal("next handler should not be called for invalid compressed payload")
	case reqCtx.Response.StatusCode() != fasthttp.StatusBadRequest:
		t.Fatalf("expected status 400, got %d", reqCtx.Response.StatusCode())
	}
}
|
|
|
|
// TestRequestDecompressionMiddleware_StreamingPath_UnsupportedEncoding checks
// that an unknown Content-Encoding ("snappy") on a chunked body is rejected
// with 400 and never reaches next.
func TestRequestDecompressionMiddleware_StreamingPath_UnsupportedEncoding(t *testing.T) {
	cfg := &lib.Config{
		ClientConfig: &configstore.ClientConfig{MaxRequestBodySizeMB: 100},
	}

	reqCtx := &fasthttp.RequestCtx{}
	reqCtx.Request.Header.SetMethod("POST")
	reqCtx.Request.Header.Set("Content-Encoding", "snappy")
	reqCtx.Request.SetBodyStream(bytes.NewReader([]byte("whatever")), -1)

	forwarded := false
	RequestDecompressionMiddleware(cfg)(func(*fasthttp.RequestCtx) { forwarded = true })(reqCtx)

	switch {
	case forwarded:
		t.Fatal("next handler should not be called for unsupported encoding")
	case reqCtx.Response.StatusCode() != fasthttp.StatusBadRequest:
		t.Fatalf("expected status 400, got %d", reqCtx.Response.StatusCode())
	}
}
|
|
|
|
// TestRequestDecompressionMiddleware_BufferedPath_SmallGzip decompresses a small
// gzip body supplied via SetBodyRaw, whose compressed size is at or below the
// default large-payload threshold, and checks that next sees the original
// JSON. The test is skipped if the payload unexpectedly exceeds the threshold.
func TestRequestDecompressionMiddleware_BufferedPath_SmallGzip(t *testing.T) {
	cfg := &lib.Config{
		ClientConfig: &configstore.ClientConfig{MaxRequestBodySizeMB: 10},
	}

	original := []byte(`{"model":"openai/gpt-4o-mini","messages":[{"role":"user","content":"hello"}]}`)
	payload, err := gzipCompress(original)
	if err != nil {
		t.Fatalf("failed to gzip: %v", err)
	}

	// Guard the precondition: a payload above the threshold would not
	// exercise the buffered path.
	if int64(len(payload)) > schemas.DefaultLargePayloadRequestThresholdBytes {
		t.Skip("compressed body unexpectedly exceeds threshold")
	}

	reqCtx := &fasthttp.RequestCtx{}
	reqCtx.Request.Header.SetMethod("POST")
	reqCtx.Request.Header.Set("Content-Encoding", "gzip")
	// Raw body with a known, small size; intended to hit the buffered path.
	reqCtx.Request.SetBodyRaw(payload)

	forwarded := false
	RequestDecompressionMiddleware(cfg)(func(inner *fasthttp.RequestCtx) {
		forwarded = true
		if got := string(inner.Request.Body()); got != string(original) {
			t.Fatalf("expected decompressed body, got %q", got)
		}
	})(reqCtx)

	if !forwarded {
		t.Fatal("next handler was not called")
	}
}
|
|
|
|
// TestRequestDecompressionMiddleware_StreamingPath_LargeGzip sends a gzip body
// with a known Content-Length above the default large-payload threshold. That
// size selects the streaming path. Next must receive a byte-for-byte copy of
// the original data with Content-Encoding removed. The test is skipped if the
// compressed size happens to fall at or below the threshold.
func TestRequestDecompressionMiddleware_StreamingPath_LargeGzip(t *testing.T) {
	cfg := &lib.Config{
		ClientConfig: &configstore.ClientConfig{MaxRequestBodySizeMB: 100},
	}

	// Random bytes are incompressible, so the gzip output stays roughly as
	// large as the input (threshold + 1 MiB) plus gzip overhead.
	original := make([]byte, int(schemas.DefaultLargePayloadRequestThresholdBytes)+1024*1024)
	if _, err := cryptoRand.Read(original); err != nil {
		t.Fatalf("failed to generate random data: %v", err)
	}
	payload, err := gzipCompress(original)
	if err != nil {
		t.Fatalf("failed to gzip: %v", err)
	}

	if int64(len(payload)) <= schemas.DefaultLargePayloadRequestThresholdBytes {
		t.Skipf("compressed body %d bytes is below threshold %d",
			len(payload), schemas.DefaultLargePayloadRequestThresholdBytes)
	}

	reqCtx := &fasthttp.RequestCtx{}
	reqCtx.Request.Header.SetMethod("POST")
	reqCtx.Request.Header.Set("Content-Encoding", "gzip")
	reqCtx.Request.Header.SetContentLength(len(payload))
	reqCtx.Request.SetBodyStream(bytes.NewReader(payload), len(payload))

	forwarded := false
	inspect := func(inner *fasthttp.RequestCtx) {
		forwarded = true
		if enc := string(inner.Request.Header.Peek("Content-Encoding")); enc != "" {
			t.Errorf("expected Content-Encoding to be cleared, got %q", enc)
		}
		got := inner.Request.Body()
		if len(got) != len(original) {
			t.Errorf("decompressed body length: got %d, want %d", len(got), len(original))
		}
		if !bytes.Equal(got, original) {
			t.Error("decompressed body content does not match original")
		}
	}
	RequestDecompressionMiddleware(cfg)(inspect)(reqCtx)

	if !forwarded {
		t.Fatal("next handler was not called")
	}
}
|
|
|
|
// gzipCompress returns data wrapped in a gzip stream written at the default
// compression level. On any error it returns a nil slice and that error.
// If Write fails, the writer is still closed on a best-effort basis, the
// same way the deflate, brotli and zstd helpers handle a failed write.
func gzipCompress(data []byte) ([]byte, error) {
	var buf bytes.Buffer
	gz := gzip.NewWriter(&buf)
	if _, err := gz.Write(data); err != nil {
		// Best-effort close; the write error is the one worth reporting.
		_ = gz.Close()
		return nil, err
	}
	// Close flushes pending compressed data and writes the gzip footer.
	// Without it the returned bytes would be a truncated stream.
	if err := gz.Close(); err != nil {
		return nil, err
	}
	return buf.Bytes(), nil
}
|
|
|
|
// deflateCompress encodes data as a zlib stream (RFC 1950: header, DEFLATE
// payload, Adler-32 trailer) rather than raw DEFLATE. RFC 9110 §8.4.1.2
// defines that zlib framing as the HTTP "deflate" content coding.
// It returns a nil slice together with the first error it encounters.
func deflateCompress(data []byte) ([]byte, error) {
	out := new(bytes.Buffer)
	zw, err := zlib.NewWriterLevel(out, zlib.DefaultCompression)
	if err != nil {
		return nil, err
	}
	if _, werr := zw.Write(data); werr != nil {
		// Close on a best-effort basis; the write failure is what gets reported.
		_ = zw.Close()
		return nil, werr
	}
	// Close flushes the remaining DEFLATE data and emits the checksum trailer.
	if cerr := zw.Close(); cerr != nil {
		return nil, cerr
	}
	return out.Bytes(), nil
}
|
|
|
|
func brotliCompress(data []byte) ([]byte, error) {
|
|
var buf bytes.Buffer
|
|
w := brotli.NewWriter(&buf)
|
|
if _, err := w.Write(data); err != nil {
|
|
_ = w.Close()
|
|
return nil, err
|
|
}
|
|
if err := w.Close(); err != nil {
|
|
return nil, err
|
|
}
|
|
return buf.Bytes(), nil
|
|
}
|
|
|
|
func zstdCompress(data []byte) ([]byte, error) {
|
|
var buf bytes.Buffer
|
|
enc, err := zstd.NewWriter(&buf)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if _, err := io.Copy(enc, bytes.NewReader(data)); err != nil {
|
|
enc.Close()
|
|
return nil, err
|
|
}
|
|
if err := enc.Close(); err != nil {
|
|
return nil, err
|
|
}
|
|
return buf.Bytes(), nil
|
|
}
|