first commit
This commit is contained in:
417
core/network/http.go
Normal file
417
core/network/http.go
Normal file
@@ -0,0 +1,417 @@
|
||||
// Package network provides centralized HTTP client management with proxy support.
|
||||
// It allows runtime proxy configuration updates that propagate to all HTTP clients.
|
||||
package network
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"encoding/base64"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/valyala/fasthttp"
|
||||
"github.com/valyala/fasthttp/fasthttpproxy"
|
||||
)
|
||||
|
||||
// ClientPurpose defines the intended use of an HTTP client for proxy filtering
|
||||
type ClientPurpose string
|
||||
|
||||
const (
|
||||
// ClientPurposeSCIM is used for SCIM/OAuth provider requests
|
||||
ClientPurposeSCIM ClientPurpose = "scim"
|
||||
// ClientPurposeInference is used for LLM inference requests
|
||||
ClientPurposeInference ClientPurpose = "inference"
|
||||
// ClientPurposeAPI is used for general API requests (guardrails, etc.)
|
||||
ClientPurposeAPI ClientPurpose = "api"
|
||||
)
|
||||
|
||||
// DefaultClientConfig holds default timeout values for HTTP clients
|
||||
var DefaultClientConfig = struct {
|
||||
ReadTimeout time.Duration
|
||||
WriteTimeout time.Duration
|
||||
MaxIdleConnDuration time.Duration
|
||||
MaxConnDuration time.Duration
|
||||
MaxConnsPerHost int
|
||||
}{
|
||||
ReadTimeout: 60 * time.Second,
|
||||
WriteTimeout: 60 * time.Second,
|
||||
MaxIdleConnDuration: 30 * time.Second,
|
||||
MaxConnDuration: 300 * time.Second,
|
||||
MaxConnsPerHost: 200,
|
||||
}
|
||||
|
||||
// GlobalProxyType represents the type of global proxy
|
||||
type GlobalProxyType string
|
||||
|
||||
const (
|
||||
GlobalProxyTypeHTTP GlobalProxyType = "http"
|
||||
GlobalProxyTypeSOCKS5 GlobalProxyType = "socks5"
|
||||
GlobalProxyTypeTCP GlobalProxyType = "tcp"
|
||||
)
|
||||
|
||||
// GlobalProxyConfig represents the global proxy configuration
|
||||
type GlobalProxyConfig struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
Type GlobalProxyType `json:"type"` // "http", "socks5", "tcp"
|
||||
URL string `json:"url"` // Proxy URL (e.g., http://proxy.example.com:8080)
|
||||
Username string `json:"username,omitempty"` // Optional authentication username
|
||||
Password string `json:"password,omitempty"` // Optional authentication password
|
||||
NoProxy string `json:"no_proxy,omitempty"` // Comma-separated list of hosts to bypass proxy
|
||||
Timeout int `json:"timeout,omitempty"` // Connection timeout in seconds
|
||||
SkipTLSVerify bool `json:"skip_tls_verify,omitempty"` // Skip TLS certificate verification
|
||||
// Entity enablement flags
|
||||
EnableForSCIM bool `json:"enable_for_scim"` // Enable proxy for SCIM requests (enterprise only)
|
||||
EnableForInference bool `json:"enable_for_inference"` // Enable proxy for inference requests
|
||||
EnableForAPI bool `json:"enable_for_api"` // Enable proxy for API requests
|
||||
}
|
||||
|
||||
// HTTPClientFactory manages HTTP clients with centralized proxy configuration.
|
||||
// It supports both fasthttp and standard net/http clients with purpose-based
|
||||
// proxy enablement (SCIM, Inference, API).
|
||||
type HTTPClientFactory struct {
|
||||
mu sync.RWMutex
|
||||
proxyConfig *GlobalProxyConfig
|
||||
|
||||
// Cached clients per purpose - lazily initialized
|
||||
fasthttpClients map[ClientPurpose]*fasthttp.Client
|
||||
httpClients map[ClientPurpose]*http.Client
|
||||
|
||||
logger schemas.Logger
|
||||
}
|
||||
|
||||
// NewHTTPClientFactory creates a new HTTP client factory with the given proxy configuration.
|
||||
// Pass nil for proxyConfig if proxy is not yet configured.
|
||||
func NewHTTPClientFactory(proxyConfig *GlobalProxyConfig, logger schemas.Logger) *HTTPClientFactory {
|
||||
return &HTTPClientFactory{
|
||||
proxyConfig: proxyConfig,
|
||||
fasthttpClients: make(map[ClientPurpose]*fasthttp.Client, 3),
|
||||
httpClients: make(map[ClientPurpose]*http.Client, 3),
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// UpdateProxyConfig updates the proxy configuration and recreates all cached clients.
|
||||
// This is thread-safe and can be called at runtime.
|
||||
func (f *HTTPClientFactory) UpdateProxyConfig(config *GlobalProxyConfig) {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
|
||||
f.proxyConfig = config
|
||||
|
||||
// Clear cached clients - they will be recreated on next request
|
||||
f.fasthttpClients = make(map[ClientPurpose]*fasthttp.Client, 3)
|
||||
f.httpClients = make(map[ClientPurpose]*http.Client, 3)
|
||||
}
|
||||
|
||||
// GetProxyConfig returns the current proxy configuration (thread-safe read)
|
||||
func (f *HTTPClientFactory) GetProxyConfig() *GlobalProxyConfig {
|
||||
f.mu.RLock()
|
||||
defer f.mu.RUnlock()
|
||||
return f.proxyConfig
|
||||
}
|
||||
|
||||
// isProxyEnabledForPurpose checks if proxy should be used for the given purpose
|
||||
func (f *HTTPClientFactory) isProxyEnabledForPurpose(purpose ClientPurpose) bool {
|
||||
if f.proxyConfig == nil || !f.proxyConfig.Enabled {
|
||||
return false
|
||||
}
|
||||
|
||||
switch purpose {
|
||||
case ClientPurposeSCIM:
|
||||
return f.proxyConfig.EnableForSCIM
|
||||
case ClientPurposeInference:
|
||||
return f.proxyConfig.EnableForInference
|
||||
case ClientPurposeAPI:
|
||||
return f.proxyConfig.EnableForAPI
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// shouldBypassProxy checks if a host matches a noProxy pattern
|
||||
// Supported patterns:
|
||||
// - "*" matches all hosts
|
||||
// - ".example.com" matches example.com and all subdomains
|
||||
// - "*.example.com" matches subdomains of example.com only
|
||||
// - exact host match
|
||||
func shouldBypassProxy(host, pattern string) bool {
|
||||
host = strings.ToLower(strings.TrimSpace(host))
|
||||
pattern = strings.ToLower(strings.TrimSpace(pattern))
|
||||
|
||||
if pattern == "*" {
|
||||
return true
|
||||
}
|
||||
if pattern == host {
|
||||
return true
|
||||
}
|
||||
// .example.com matches example.com and *.example.com
|
||||
if strings.HasPrefix(pattern, ".") {
|
||||
suffix := pattern[1:] // remove leading dot
|
||||
return host == suffix || strings.HasSuffix(host, pattern)
|
||||
}
|
||||
// *.example.com matches subdomains only
|
||||
if strings.HasPrefix(pattern, "*.") {
|
||||
suffix := pattern[1:] // keep the dot, e.g., ".example.com"
|
||||
return strings.HasSuffix(host, suffix)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// GetFasthttpClient returns a fasthttp client configured for the given purpose.
|
||||
// If proxy is enabled for this purpose, the client will be configured with proxy settings.
|
||||
// Clients are cached and reused until proxy config changes.
|
||||
func (f *HTTPClientFactory) GetFasthttpClient(purpose ClientPurpose) *fasthttp.Client {
|
||||
f.mu.RLock()
|
||||
if client, ok := f.fasthttpClients[purpose]; ok {
|
||||
f.mu.RUnlock()
|
||||
return client
|
||||
}
|
||||
f.mu.RUnlock()
|
||||
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
|
||||
// Double-check after acquiring write lock
|
||||
if client, ok := f.fasthttpClients[purpose]; ok {
|
||||
return client
|
||||
}
|
||||
|
||||
client := f.createFasthttpClient(purpose)
|
||||
f.fasthttpClients[purpose] = client
|
||||
return client
|
||||
}
|
||||
|
||||
// GetHTTPClient returns a standard net/http client configured for the given purpose.
|
||||
// If proxy is enabled for this purpose, the client will be configured with proxy settings.
|
||||
// Clients are cached and reused until proxy config changes.
|
||||
func (f *HTTPClientFactory) GetHTTPClient(purpose ClientPurpose) *http.Client {
|
||||
f.mu.RLock()
|
||||
if client, ok := f.httpClients[purpose]; ok {
|
||||
f.mu.RUnlock()
|
||||
return client
|
||||
}
|
||||
f.mu.RUnlock()
|
||||
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
|
||||
// Double-check after acquiring write lock
|
||||
if client, ok := f.httpClients[purpose]; ok {
|
||||
return client
|
||||
}
|
||||
|
||||
client := f.createHTTPClient(purpose)
|
||||
f.httpClients[purpose] = client
|
||||
return client
|
||||
}
|
||||
|
||||
// createFasthttpClient creates a new fasthttp client with appropriate proxy settings
|
||||
func (f *HTTPClientFactory) createFasthttpClient(purpose ClientPurpose) *fasthttp.Client {
|
||||
client := &fasthttp.Client{
|
||||
ReadTimeout: DefaultClientConfig.ReadTimeout,
|
||||
WriteTimeout: DefaultClientConfig.WriteTimeout,
|
||||
MaxIdleConnDuration: DefaultClientConfig.MaxIdleConnDuration,
|
||||
MaxConnDuration: DefaultClientConfig.MaxConnDuration,
|
||||
MaxConnsPerHost: DefaultClientConfig.MaxConnsPerHost,
|
||||
MaxConnWaitTimeout: DefaultClientConfig.ReadTimeout,
|
||||
ConnPoolStrategy: fasthttp.FIFO,
|
||||
RetryIfErr: StaleConnectionRetryIfErr,
|
||||
}
|
||||
|
||||
// Configure proxy if enabled for this purpose
|
||||
if f.isProxyEnabledForPurpose(purpose) {
|
||||
f.configureFasthttpProxy(client)
|
||||
}
|
||||
|
||||
// Configure TLS if skip verification is set
|
||||
if f.proxyConfig != nil {
|
||||
if f.proxyConfig.SkipTLSVerify {
|
||||
f.logger.Warn("skipping TLS verification for fasthttp client because skip TLS verify is set to true. It's not recommended to use this in production.")
|
||||
}
|
||||
client.TLSConfig = &tls.Config{
|
||||
InsecureSkipVerify: f.proxyConfig.SkipTLSVerify,
|
||||
MinVersion: tls.VersionTLS12,
|
||||
}
|
||||
}
|
||||
|
||||
return client
|
||||
}
|
||||
|
||||
// StaleConnectionRetryIfErr is a RetryIfErr callback that retries requests when the failure
|
||||
// is due to a stale/dead connection being reused from the pool. This addresses intermittent
|
||||
// "cannot find whitespace in the first line of response" errors caused by connection reuse
|
||||
// with leftover chunked transfer encoding data (see: https://github.com/valyala/fasthttp/issues/1743).
|
||||
//
|
||||
// By default fasthttp only retries idempotent requests (GET/HEAD/PUT). LLM inference requests
|
||||
// use POST, so without this they fail immediately on stale connections. Retrying is safe here
|
||||
// because the error occurs during response header parsing — before the server processes the
|
||||
// new request, or on a connection the server has already closed.
|
||||
func StaleConnectionRetryIfErr(_ *fasthttp.Request, attempts int, err error) (resetTimeout bool, retry bool) {
|
||||
if attempts > 1 {
|
||||
return false, false
|
||||
}
|
||||
if err == nil {
|
||||
return false, false
|
||||
}
|
||||
errStr := err.Error()
|
||||
// io.EOF — server closed the connection (fasthttp converts this to
|
||||
// ErrConnectionClosed AFTER the retry loop, so RetryIfErr sees raw EOF)
|
||||
// "cannot find whitespace in the first line of response" — stale chunked data in buffer
|
||||
// "connection reset by peer" — server RST'd the idle connection (read-side)
|
||||
// "broken pipe" — server closed the idle connection (write-side EPIPE)
|
||||
if err == io.EOF ||
|
||||
strings.Contains(errStr, "cannot find whitespace") ||
|
||||
strings.Contains(errStr, "connection reset by peer") ||
|
||||
strings.Contains(errStr, "broken pipe") {
|
||||
return true, true
|
||||
}
|
||||
return false, false
|
||||
}
|
||||
|
||||
// buildProxyURLWithAuth adds authentication to a proxy URL if credentials are provided
|
||||
func (f *HTTPClientFactory) buildProxyURLWithAuth() string {
|
||||
proxyURL := f.proxyConfig.URL
|
||||
if f.proxyConfig.Username != "" && f.proxyConfig.Password != "" {
|
||||
parsedURL, err := url.Parse(f.proxyConfig.URL)
|
||||
if err == nil {
|
||||
parsedURL.User = url.UserPassword(f.proxyConfig.Username, f.proxyConfig.Password)
|
||||
proxyURL = parsedURL.String()
|
||||
}
|
||||
}
|
||||
return proxyURL
|
||||
}
|
||||
|
||||
// configureFasthttpProxy configures proxy for a fasthttp client
|
||||
func (f *HTTPClientFactory) configureFasthttpProxy(client *fasthttp.Client) {
|
||||
if f.proxyConfig == nil || f.proxyConfig.URL == "" {
|
||||
return
|
||||
}
|
||||
|
||||
proxyURL := f.buildProxyURLWithAuth()
|
||||
var dialFunc fasthttp.DialFunc
|
||||
|
||||
switch f.proxyConfig.Type {
|
||||
case GlobalProxyTypeHTTP:
|
||||
dialFunc = fasthttpproxy.FasthttpHTTPDialer(proxyURL)
|
||||
case GlobalProxyTypeSOCKS5:
|
||||
dialFunc = fasthttpproxy.FasthttpSocksDialer(proxyURL)
|
||||
}
|
||||
|
||||
proxyCfg := f.proxyConfig
|
||||
if dialFunc != nil {
|
||||
client.Dial = func(addr string) (net.Conn, error) {
|
||||
if proxyCfg.NoProxy != "" {
|
||||
host := strings.Split(addr, ":")[0]
|
||||
if host == "" {
|
||||
host = addr
|
||||
}
|
||||
if shouldBypassProxy(host, proxyCfg.NoProxy) {
|
||||
return net.Dial("tcp", addr)
|
||||
}
|
||||
}
|
||||
return dialFunc(addr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// createHTTPClient creates a new standard net/http client with appropriate proxy settings
|
||||
func (f *HTTPClientFactory) createHTTPClient(purpose ClientPurpose) *http.Client {
|
||||
transport := &http.Transport{
|
||||
MaxIdleConns: DefaultClientConfig.MaxConnsPerHost,
|
||||
MaxIdleConnsPerHost: DefaultClientConfig.MaxConnsPerHost,
|
||||
IdleConnTimeout: DefaultClientConfig.MaxIdleConnDuration,
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
ResponseHeaderTimeout: DefaultClientConfig.ReadTimeout,
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
DisableCompression: false,
|
||||
DisableKeepAlives: false,
|
||||
// Disable HTTP/2 — these clients are used for auxiliary purposes (proxy/SCIM/API)
|
||||
// where HTTP/1.1 is sufficient. Without this, Go's http2 package auto-registers
|
||||
// h2 via TLSNextProto in init(), causing unintended HTTP/2 connections.
|
||||
TLSNextProto: make(map[string]func(authority string, c *tls.Conn) http.RoundTripper),
|
||||
}
|
||||
|
||||
// Configure proxy if enabled for this purpose
|
||||
if f.isProxyEnabledForPurpose(purpose) {
|
||||
f.configureHTTPProxy(transport)
|
||||
}
|
||||
|
||||
// Configure TLS if skip verification is set
|
||||
if f.proxyConfig != nil {
|
||||
if f.proxyConfig.SkipTLSVerify {
|
||||
f.logger.Warn("skipping TLS verification for fasthttp client because skip TLS verify is set to true. It's not recommended to use this in production.")
|
||||
}
|
||||
transport.TLSClientConfig = &tls.Config{
|
||||
InsecureSkipVerify: f.proxyConfig.SkipTLSVerify,
|
||||
MinVersion: tls.VersionTLS12,
|
||||
}
|
||||
}
|
||||
|
||||
timeout := DefaultClientConfig.ReadTimeout
|
||||
if f.proxyConfig != nil && f.proxyConfig.Timeout > 0 {
|
||||
timeout = time.Duration(f.proxyConfig.Timeout) * time.Second
|
||||
}
|
||||
|
||||
return &http.Client{
|
||||
Transport: transport,
|
||||
Timeout: timeout,
|
||||
}
|
||||
}
|
||||
|
||||
// configureHTTPProxy configures proxy for a standard net/http transport
|
||||
func (f *HTTPClientFactory) configureHTTPProxy(transport *http.Transport) {
|
||||
if f.proxyConfig == nil || f.proxyConfig.URL == "" {
|
||||
return
|
||||
}
|
||||
|
||||
proxyURL, err := url.Parse(f.proxyConfig.URL)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Add authentication if provided
|
||||
if f.proxyConfig.Username != "" && f.proxyConfig.Password != "" {
|
||||
proxyURL.User = url.UserPassword(f.proxyConfig.Username, f.proxyConfig.Password)
|
||||
|
||||
// For HTTPS requests through HTTP proxy, the CONNECT method is used to establish a tunnel.
|
||||
// Proxy authentication must be sent via ProxyConnectHeader for the CONNECT request.
|
||||
// Without this, the proxy will reject/reset the connection before the TLS handshake.
|
||||
basicAuth := "Basic " + base64.StdEncoding.EncodeToString(
|
||||
[]byte(f.proxyConfig.Username+":"+f.proxyConfig.Password),
|
||||
)
|
||||
transport.ProxyConnectHeader = http.Header{
|
||||
"Proxy-Authorization": {basicAuth},
|
||||
}
|
||||
}
|
||||
|
||||
// Capture noProxy patterns at creation time to avoid data race with UpdateProxyConfig.
|
||||
// The closure below is called for each request and would otherwise read f.proxyConfig
|
||||
// concurrently with writes from UpdateProxyConfig.
|
||||
var noProxyPatterns []string
|
||||
if f.proxyConfig.NoProxy != "" {
|
||||
noProxyPatterns = strings.Split(f.proxyConfig.NoProxy, ",")
|
||||
}
|
||||
|
||||
proxyCfg := f.proxyConfig
|
||||
transport.Proxy = func(req *http.Request) (*url.URL, error) {
|
||||
// Use Hostname() to get the host without port for matching.
|
||||
// req.URL.Host is "host:port" but no_proxy patterns are host-only.
|
||||
if proxyCfg.NoProxy != "" {
|
||||
host := req.URL.Hostname()
|
||||
if host == "" {
|
||||
host = req.URL.Host
|
||||
}
|
||||
for _, np := range noProxyPatterns {
|
||||
if shouldBypassProxy(host, np) {
|
||||
return nil, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
return proxyURL, nil
|
||||
}
|
||||
}
|
||||
522
core/network/http_test.go
Normal file
522
core/network/http_test.go
Normal file
@@ -0,0 +1,522 @@
|
||||
package network
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// TestStaleConnectionRetryIfErr validates the error-matching logic of
|
||||
// StaleConnectionRetryIfErr for different error types and attempt counts.
|
||||
func TestStaleConnectionRetryIfErr(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
attempts int
|
||||
wantReset bool
|
||||
wantRetry bool
|
||||
}{
|
||||
{
|
||||
name: "retries on whitespace error (first attempt)",
|
||||
err: fmt.Errorf(`error when reading response headers: cannot find whitespace in the first line of response "217\r\ndata: ..."`),
|
||||
attempts: 1,
|
||||
wantReset: true,
|
||||
wantRetry: true,
|
||||
},
|
||||
{
|
||||
name: "retries on connection reset by peer",
|
||||
err: fmt.Errorf("read tcp 10.0.0.1:54321->10.0.0.2:443: read: connection reset by peer"),
|
||||
attempts: 1,
|
||||
wantReset: true,
|
||||
wantRetry: true,
|
||||
},
|
||||
{
|
||||
name: "retries on io.EOF (server closed connection)",
|
||||
err: io.EOF,
|
||||
attempts: 1,
|
||||
wantReset: true,
|
||||
wantRetry: true,
|
||||
},
|
||||
{
|
||||
name: "retries on broken pipe (write to closed connection)",
|
||||
err: fmt.Errorf("write tcp 10.0.0.1:53374->10.0.0.2:30000: write: broken pipe"),
|
||||
attempts: 1,
|
||||
wantReset: true,
|
||||
wantRetry: true,
|
||||
},
|
||||
{
|
||||
name: "does not retry on second attempt",
|
||||
err: io.EOF,
|
||||
attempts: 2,
|
||||
wantReset: false,
|
||||
wantRetry: false,
|
||||
},
|
||||
{
|
||||
name: "does not retry on nil error",
|
||||
err: nil,
|
||||
attempts: 1,
|
||||
wantReset: false,
|
||||
wantRetry: false,
|
||||
},
|
||||
{
|
||||
name: "does not retry on unrelated error",
|
||||
err: fmt.Errorf("dial tcp: lookup api.example.com: no such host"),
|
||||
attempts: 1,
|
||||
wantReset: false,
|
||||
wantRetry: false,
|
||||
},
|
||||
{
|
||||
name: "does not retry on timeout",
|
||||
err: fasthttp.ErrTimeout,
|
||||
attempts: 1,
|
||||
wantReset: false,
|
||||
wantRetry: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
resetTimeout, retry := StaleConnectionRetryIfErr(nil, tt.attempts, tt.err)
|
||||
if resetTimeout != tt.wantReset {
|
||||
t.Errorf("resetTimeout = %v, want %v", resetTimeout, tt.wantReset)
|
||||
}
|
||||
if retry != tt.wantRetry {
|
||||
t.Errorf("retry = %v, want %v", retry, tt.wantRetry)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestStaleConnectionRetryWithTTLMismatch simulates the scenario from issue #1613:
|
||||
//
|
||||
// - Server idle timeout: 10 seconds (server closes keep-alive connections after 10s idle)
|
||||
// - Client MaxIdleConnDuration: 15 seconds (client holds connections for 15s)
|
||||
//
|
||||
// Between 10-15 seconds of idle time, the client still considers the connection
|
||||
// valid, but the server has already closed it. The next request on the stale
|
||||
// connection should be retried automatically via StaleConnectionRetryIfErr.
|
||||
//
|
||||
// Without the retry, POST requests fail because fasthttp's default isIdempotent
|
||||
// only retries GET/HEAD/PUT.
|
||||
func TestStaleConnectionRetryWithTTLMismatch(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping TTL mismatch test in short mode (requires 11s wait)")
|
||||
}
|
||||
|
||||
const (
|
||||
serverIdleTimeout = 10 * time.Second
|
||||
clientIdleTimeout = 15 * time.Second
|
||||
waitBetween = 11 * time.Second // > server TTL, < client TTL
|
||||
)
|
||||
|
||||
var requestCount atomic.Int32
|
||||
|
||||
// Start a test server with a 10-second idle timeout.
|
||||
// After 10s of idle time on a keep-alive connection, the server closes it.
|
||||
server := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
requestCount.Add(1)
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.Header().Set("Cache-Control", "no-cache")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
fmt.Fprintf(w, "data: {\"message\": \"ok\", \"request\": %d}\n\n", requestCount.Load())
|
||||
}))
|
||||
server.Config.IdleTimeout = serverIdleTimeout
|
||||
server.Start()
|
||||
defer server.Close()
|
||||
|
||||
t.Run("with_retry_policy_POST_succeeds", func(t *testing.T) {
|
||||
client := &fasthttp.Client{
|
||||
MaxIdleConnDuration: clientIdleTimeout,
|
||||
MaxConnsPerHost: 10,
|
||||
RetryIfErr: StaleConnectionRetryIfErr,
|
||||
}
|
||||
|
||||
// --- First request: fresh connection, must succeed ---
|
||||
req := fasthttp.AcquireRequest()
|
||||
resp := fasthttp.AcquireResponse()
|
||||
defer fasthttp.ReleaseRequest(req)
|
||||
defer fasthttp.ReleaseResponse(resp)
|
||||
|
||||
req.SetRequestURI(server.URL)
|
||||
req.Header.SetMethod(http.MethodPost)
|
||||
req.Header.SetContentType("application/json")
|
||||
req.SetBodyString(`{"prompt": "hello"}`)
|
||||
|
||||
if err := client.Do(req, resp); err != nil {
|
||||
t.Fatalf("First POST request failed: %v", err)
|
||||
}
|
||||
if resp.StatusCode() != 200 {
|
||||
t.Fatalf("First POST request: expected 200, got %d", resp.StatusCode())
|
||||
}
|
||||
|
||||
// Read body to ensure connection is returned to pool
|
||||
_ = resp.Body()
|
||||
t.Logf("First POST request succeeded (status=%d)", resp.StatusCode())
|
||||
|
||||
// --- Wait for server's idle timeout to expire ---
|
||||
// The server will close the connection after 10s, but the client
|
||||
// still holds it in its pool (MaxIdleConnDuration=15s).
|
||||
t.Logf("Waiting %v for server idle timeout (%v) to expire...", waitBetween, serverIdleTimeout)
|
||||
time.Sleep(waitBetween)
|
||||
|
||||
// --- Second request: stale connection, should retry and succeed ---
|
||||
req2 := fasthttp.AcquireRequest()
|
||||
resp2 := fasthttp.AcquireResponse()
|
||||
defer fasthttp.ReleaseRequest(req2)
|
||||
defer fasthttp.ReleaseResponse(resp2)
|
||||
|
||||
req2.SetRequestURI(server.URL)
|
||||
req2.Header.SetMethod(http.MethodPost)
|
||||
req2.Header.SetContentType("application/json")
|
||||
req2.SetBodyString(`{"prompt": "world"}`)
|
||||
|
||||
if err := client.Do(req2, resp2); err != nil {
|
||||
t.Fatalf("Second POST request failed (StaleConnectionRetryIfErr should have retried): %v", err)
|
||||
}
|
||||
if resp2.StatusCode() != 200 {
|
||||
t.Fatalf("Second POST request: expected 200, got %d", resp2.StatusCode())
|
||||
}
|
||||
t.Logf("Second POST request succeeded after TTL mismatch (status=%d)", resp2.StatusCode())
|
||||
})
|
||||
|
||||
t.Run("without_retry_policy_POST_fails", func(t *testing.T) {
|
||||
// Reset request count
|
||||
requestCount.Store(0)
|
||||
|
||||
client := &fasthttp.Client{
|
||||
MaxIdleConnDuration: clientIdleTimeout,
|
||||
MaxConnsPerHost: 10,
|
||||
// No RetryIfErr — uses default isIdempotent (POST not retried)
|
||||
}
|
||||
|
||||
// --- First request: fresh connection, must succeed ---
|
||||
req := fasthttp.AcquireRequest()
|
||||
resp := fasthttp.AcquireResponse()
|
||||
defer fasthttp.ReleaseRequest(req)
|
||||
defer fasthttp.ReleaseResponse(resp)
|
||||
|
||||
req.SetRequestURI(server.URL)
|
||||
req.Header.SetMethod(http.MethodPost)
|
||||
req.Header.SetContentType("application/json")
|
||||
req.SetBodyString(`{"prompt": "hello"}`)
|
||||
|
||||
if err := client.Do(req, resp); err != nil {
|
||||
t.Fatalf("First POST request failed: %v", err)
|
||||
}
|
||||
if resp.StatusCode() != 200 {
|
||||
t.Fatalf("First POST request: expected 200, got %d", resp.StatusCode())
|
||||
}
|
||||
_ = resp.Body()
|
||||
t.Logf("First POST request succeeded (status=%d)", resp.StatusCode())
|
||||
|
||||
// --- Wait for server's idle timeout to expire ---
|
||||
t.Logf("Waiting %v for server idle timeout (%v) to expire...", waitBetween, serverIdleTimeout)
|
||||
time.Sleep(waitBetween)
|
||||
|
||||
// --- Second request: stale connection, POST NOT retried by default ---
|
||||
req2 := fasthttp.AcquireRequest()
|
||||
resp2 := fasthttp.AcquireResponse()
|
||||
defer fasthttp.ReleaseRequest(req2)
|
||||
defer fasthttp.ReleaseResponse(resp2)
|
||||
|
||||
req2.SetRequestURI(server.URL)
|
||||
req2.Header.SetMethod(http.MethodPost)
|
||||
req2.Header.SetContentType("application/json")
|
||||
req2.SetBodyString(`{"prompt": "world"}`)
|
||||
|
||||
err := client.Do(req2, resp2)
|
||||
if err != nil {
|
||||
// Expected: POST request fails on stale connection without retry
|
||||
t.Logf("Second POST request failed as expected without retry policy: %v", err)
|
||||
} else {
|
||||
// The OS may have already delivered the FIN and fasthttp detected it,
|
||||
// creating a new connection transparently. This is acceptable — the
|
||||
// retry policy provides defense-in-depth for cases where FIN delivery
|
||||
// is delayed (common with TLS, proxies, and load balancers in K8s).
|
||||
t.Logf("Second POST request succeeded (OS delivered FIN before reuse) — retry policy still provides defense-in-depth")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestMaxConnDurationForcesReconnection verifies that MaxConnDuration causes
|
||||
// fasthttp to close and replace connections after the configured lifetime,
|
||||
// preventing stale long-lived connections from accumulating during sustained
|
||||
// back-to-back request traffic.
|
||||
//
|
||||
// Uses the server's ConnState callback to reliably count new TCP connections
|
||||
// (r.RemoteAddr is unreliable because the OS can reuse ephemeral ports).
|
||||
func TestMaxConnDurationForcesReconnection(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping MaxConnDuration test in short mode (requires ~4s wait)")
|
||||
}
|
||||
|
||||
const maxConnDuration = 2 * time.Second
|
||||
|
||||
// Track new connections via ConnState (fires once per new TCP accept)
|
||||
var newConnCount atomic.Int32
|
||||
|
||||
server := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
fmt.Fprint(w, "ok")
|
||||
}))
|
||||
server.Config.ConnState = func(_ net.Conn, state http.ConnState) {
|
||||
if state == http.StateNew {
|
||||
newConnCount.Add(1)
|
||||
}
|
||||
}
|
||||
server.Start()
|
||||
defer server.Close()
|
||||
|
||||
t.Run("with_MaxConnDuration_connection_is_recycled", func(t *testing.T) {
|
||||
newConnCount.Store(0)
|
||||
|
||||
client := &fasthttp.Client{
|
||||
MaxConnsPerHost: 1,
|
||||
MaxConnDuration: maxConnDuration,
|
||||
}
|
||||
|
||||
// First request: establishes connection A
|
||||
req := fasthttp.AcquireRequest()
|
||||
resp := fasthttp.AcquireResponse()
|
||||
defer fasthttp.ReleaseRequest(req)
|
||||
defer fasthttp.ReleaseResponse(resp)
|
||||
|
||||
req.SetRequestURI(server.URL)
|
||||
req.Header.SetMethod(http.MethodPost)
|
||||
req.SetBodyString(`{"test": 1}`)
|
||||
|
||||
if err := client.Do(req, resp); err != nil {
|
||||
t.Fatalf("First request failed: %v", err)
|
||||
}
|
||||
_ = resp.Body()
|
||||
|
||||
connsAfterFirst := newConnCount.Load()
|
||||
t.Logf("After first request: %d new connections", connsAfterFirst)
|
||||
|
||||
// Wait for MaxConnDuration to expire
|
||||
t.Logf("Waiting %v for MaxConnDuration to expire...", maxConnDuration+500*time.Millisecond)
|
||||
time.Sleep(maxConnDuration + 500*time.Millisecond)
|
||||
|
||||
// Second request: reuses connection A but sends Connection: close
|
||||
// (fasthttp's MaxConnDuration sets Connection: close on expired conns,
|
||||
// telling the server to close the connection after the response)
|
||||
req2 := fasthttp.AcquireRequest()
|
||||
resp2 := fasthttp.AcquireResponse()
|
||||
defer fasthttp.ReleaseRequest(req2)
|
||||
defer fasthttp.ReleaseResponse(resp2)
|
||||
|
||||
req2.SetRequestURI(server.URL)
|
||||
req2.Header.SetMethod(http.MethodPost)
|
||||
req2.SetBodyString(`{"test": 2}`)
|
||||
|
||||
if err := client.Do(req2, resp2); err != nil {
|
||||
t.Fatalf("Second request failed: %v", err)
|
||||
}
|
||||
_ = resp2.Body()
|
||||
|
||||
// Third request: connection A is now closed by server → must create connection B
|
||||
req3 := fasthttp.AcquireRequest()
|
||||
resp3 := fasthttp.AcquireResponse()
|
||||
defer fasthttp.ReleaseRequest(req3)
|
||||
defer fasthttp.ReleaseResponse(resp3)
|
||||
|
||||
req3.SetRequestURI(server.URL)
|
||||
req3.Header.SetMethod(http.MethodPost)
|
||||
req3.SetBodyString(`{"test": 3}`)
|
||||
|
||||
if err := client.Do(req3, resp3); err != nil {
|
||||
t.Fatalf("Third request failed: %v", err)
|
||||
}
|
||||
|
||||
connsAfterThird := newConnCount.Load()
|
||||
if connsAfterThird < 2 {
|
||||
t.Errorf("expected at least 2 new connections after MaxConnDuration recycling, got %d", connsAfterThird)
|
||||
} else {
|
||||
t.Logf("Connection recycled: %d total new connections", connsAfterThird)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("without_MaxConnDuration_connection_is_reused", func(t *testing.T) {
|
||||
newConnCount.Store(0)
|
||||
|
||||
client := &fasthttp.Client{
|
||||
MaxConnsPerHost: 1,
|
||||
// No MaxConnDuration — connections live forever
|
||||
}
|
||||
|
||||
req := fasthttp.AcquireRequest()
|
||||
resp := fasthttp.AcquireResponse()
|
||||
defer fasthttp.ReleaseRequest(req)
|
||||
defer fasthttp.ReleaseResponse(resp)
|
||||
|
||||
req.SetRequestURI(server.URL)
|
||||
req.Header.SetMethod(http.MethodPost)
|
||||
req.SetBodyString(`{"test": 1}`)
|
||||
|
||||
if err := client.Do(req, resp); err != nil {
|
||||
t.Fatalf("First request failed: %v", err)
|
||||
}
|
||||
_ = resp.Body()
|
||||
|
||||
// Wait same duration as above
|
||||
time.Sleep(maxConnDuration + 500*time.Millisecond)
|
||||
|
||||
req2 := fasthttp.AcquireRequest()
|
||||
resp2 := fasthttp.AcquireResponse()
|
||||
defer fasthttp.ReleaseRequest(req2)
|
||||
defer fasthttp.ReleaseResponse(resp2)
|
||||
|
||||
req2.SetRequestURI(server.URL)
|
||||
req2.Header.SetMethod(http.MethodPost)
|
||||
req2.SetBodyString(`{"test": 2}`)
|
||||
|
||||
if err := client.Do(req2, resp2); err != nil {
|
||||
t.Fatalf("Second request failed: %v", err)
|
||||
}
|
||||
|
||||
totalConns := newConnCount.Load()
|
||||
// Without MaxConnDuration, the same connection should be reused
|
||||
if totalConns == 1 {
|
||||
t.Logf("Connection reused as expected: only 1 new connection total")
|
||||
} else {
|
||||
// OS/server may have closed it — that's acceptable
|
||||
t.Logf("Saw %d new connections (OS/server may have recycled)", totalConns)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestMaxConnWaitTimeoutAlignedWithReadTimeout verifies that when the connection
|
||||
// pool is exhausted, requests wait for MaxConnWaitTimeout (aligned with ReadTimeout)
|
||||
// before failing, not the old hardcoded 10s.
|
||||
func TestMaxConnWaitTimeoutAlignedWithReadTimeout(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping pool exhaustion test in short mode (requires ~4s wait)")
|
||||
}
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Hold the connection for 3 seconds to simulate a slow provider
|
||||
time.Sleep(3 * time.Second)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
fmt.Fprint(w, "ok")
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := &fasthttp.Client{
|
||||
MaxConnsPerHost: 1, // Only 1 connection allowed — second request must wait
|
||||
MaxConnWaitTimeout: 2 * time.Second, // Wait up to 2s for a free connection slot
|
||||
ReadTimeout: 5 * time.Second,
|
||||
WriteTimeout: 5 * time.Second,
|
||||
}
|
||||
|
||||
// Fire first request (occupies the only connection slot for 3s)
|
||||
var wg sync.WaitGroup
|
||||
firstReqErr := make(chan error, 1)
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
req := fasthttp.AcquireRequest()
|
||||
resp := fasthttp.AcquireResponse()
|
||||
defer fasthttp.ReleaseRequest(req)
|
||||
defer fasthttp.ReleaseResponse(resp)
|
||||
|
||||
req.SetRequestURI(server.URL)
|
||||
req.Header.SetMethod(http.MethodPost)
|
||||
req.SetBodyString(`{"slot": "occupied"}`)
|
||||
|
||||
firstReqErr <- client.Do(req, resp)
|
||||
}()
|
||||
|
||||
// Brief pause to ensure first request is in-flight
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Second request: pool is full, should timeout after ~2s (MaxConnWaitTimeout)
|
||||
start := time.Now()
|
||||
req2 := fasthttp.AcquireRequest()
|
||||
resp2 := fasthttp.AcquireResponse()
|
||||
defer fasthttp.ReleaseRequest(req2)
|
||||
defer fasthttp.ReleaseResponse(resp2)
|
||||
|
||||
req2.SetRequestURI(server.URL)
|
||||
req2.Header.SetMethod(http.MethodPost)
|
||||
req2.SetBodyString(`{"waiting": true}`)
|
||||
|
||||
err := client.Do(req2, resp2)
|
||||
elapsed := time.Since(start)
|
||||
|
||||
wg.Wait()
|
||||
|
||||
if firstErr := <-firstReqErr; firstErr != nil {
|
||||
t.Fatalf("first request failed; pool-exhaustion scenario was not exercised: %v", firstErr)
|
||||
}
|
||||
|
||||
if err == nil {
|
||||
// The first request may have finished before MaxConnWaitTimeout expired,
|
||||
// allowing the second request to succeed. This is acceptable.
|
||||
t.Logf("Second request succeeded (first request completed in time, elapsed=%v)", elapsed)
|
||||
return
|
||||
}
|
||||
|
||||
// Verify the wait time is close to MaxConnWaitTimeout (2s), not 0s or 5s
|
||||
if elapsed < 1500*time.Millisecond || elapsed > 3500*time.Millisecond {
|
||||
t.Errorf("expected pool wait ~2s, but elapsed=%v (err=%v)", elapsed, err)
|
||||
} else {
|
||||
t.Logf("Pool exhaustion timeout at %v as expected (err=%v)", elapsed, err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestDefaultClientConfigValues verifies that DefaultClientConfig contains
|
||||
// the expected values for connection pool settings.
|
||||
func TestDefaultClientConfigValues(t *testing.T) {
|
||||
if DefaultClientConfig.ReadTimeout != 60*time.Second {
|
||||
t.Errorf("ReadTimeout = %v, want 60s", DefaultClientConfig.ReadTimeout)
|
||||
}
|
||||
if DefaultClientConfig.WriteTimeout != 60*time.Second {
|
||||
t.Errorf("WriteTimeout = %v, want 60s", DefaultClientConfig.WriteTimeout)
|
||||
}
|
||||
if DefaultClientConfig.MaxIdleConnDuration != 30*time.Second {
|
||||
t.Errorf("MaxIdleConnDuration = %v, want 30s", DefaultClientConfig.MaxIdleConnDuration)
|
||||
}
|
||||
if DefaultClientConfig.MaxConnDuration != 300*time.Second {
|
||||
t.Errorf("MaxConnDuration = %v, want 300s", DefaultClientConfig.MaxConnDuration)
|
||||
}
|
||||
if DefaultClientConfig.MaxConnsPerHost != 200 {
|
||||
t.Errorf("MaxConnsPerHost = %d, want 200", DefaultClientConfig.MaxConnsPerHost)
|
||||
}
|
||||
// Verify the provider-level constant matches
|
||||
if schemas.DefaultMaxConnDurationInSeconds != 300 {
|
||||
t.Errorf("DefaultMaxConnDurationInSeconds = %d, want 300", schemas.DefaultMaxConnDurationInSeconds)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCreateFasthttpClientPoolSettings verifies that the HTTPClientFactory
|
||||
// creates fasthttp clients with the correct pool settings including
|
||||
// MaxConnDuration, MaxConnWaitTimeout, and FIFO ConnPoolStrategy.
|
||||
func TestCreateFasthttpClientPoolSettings(t *testing.T) {
|
||||
factory := NewHTTPClientFactory(nil, nil)
|
||||
client := factory.GetFasthttpClient(ClientPurposeInference)
|
||||
|
||||
if client.MaxConnDuration != DefaultClientConfig.MaxConnDuration {
|
||||
t.Errorf("MaxConnDuration = %v, want %v", client.MaxConnDuration, DefaultClientConfig.MaxConnDuration)
|
||||
}
|
||||
if client.MaxConnWaitTimeout != DefaultClientConfig.ReadTimeout {
|
||||
t.Errorf("MaxConnWaitTimeout = %v, want %v (aligned with ReadTimeout)", client.MaxConnWaitTimeout, DefaultClientConfig.ReadTimeout)
|
||||
}
|
||||
if client.ConnPoolStrategy != fasthttp.FIFO {
|
||||
t.Errorf("ConnPoolStrategy = %v, want FIFO (%v)", client.ConnPoolStrategy, fasthttp.FIFO)
|
||||
}
|
||||
if client.MaxIdleConnDuration != DefaultClientConfig.MaxIdleConnDuration {
|
||||
t.Errorf("MaxIdleConnDuration = %v, want %v", client.MaxIdleConnDuration, DefaultClientConfig.MaxIdleConnDuration)
|
||||
}
|
||||
if client.MaxConnsPerHost != DefaultClientConfig.MaxConnsPerHost {
|
||||
t.Errorf("MaxConnsPerHost = %d, want %d", client.MaxConnsPerHost, DefaultClientConfig.MaxConnsPerHost)
|
||||
}
|
||||
}
|
||||
166
core/network/multipart.go
Normal file
166
core/network/multipart.go
Normal file
@@ -0,0 +1,166 @@
|
||||
package network
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime"
|
||||
"mime/multipart"
|
||||
"strings"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// ParseMultipartFormFields extracts text form fields from a multipart/form-data body,
|
||||
// skipping file parts to avoid loading binary data into memory.
|
||||
func ParseMultipartFormFields(contentType string, body []byte) (map[string]any, error) {
|
||||
_, params, err := mime.ParseMediaType(contentType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
boundary := params["boundary"]
|
||||
if boundary == "" {
|
||||
return nil, fmt.Errorf("no boundary in content-type")
|
||||
}
|
||||
reader := multipart.NewReader(bytes.NewReader(body), boundary)
|
||||
payload := make(map[string]any)
|
||||
for {
|
||||
part, err := reader.NextPart()
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if part.FileName() != "" {
|
||||
_ = part.Close()
|
||||
continue
|
||||
}
|
||||
name := part.FormName()
|
||||
if name != "" {
|
||||
val, readErr := io.ReadAll(part)
|
||||
if readErr != nil {
|
||||
_ = part.Close()
|
||||
return nil, readErr
|
||||
}
|
||||
payload[name] = string(val)
|
||||
}
|
||||
_ = part.Close()
|
||||
}
|
||||
return payload, nil
|
||||
}
|
||||
|
||||
// ReconstructMultipartBody rebuilds a multipart/form-data body from the original,
|
||||
// replacing text field values with those from payload (e.g. updated "model") and
|
||||
// copying file parts byte-for-byte.
|
||||
func ReconstructMultipartBody(origContentType string, origBody []byte, payload map[string]any) ([]byte, string, error) {
|
||||
_, params, err := mime.ParseMediaType(origContentType)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
boundary := params["boundary"]
|
||||
if boundary == "" {
|
||||
return nil, "", fmt.Errorf("no boundary in content-type")
|
||||
}
|
||||
reader := multipart.NewReader(bytes.NewReader(origBody), boundary)
|
||||
var buf bytes.Buffer
|
||||
writer := multipart.NewWriter(&buf)
|
||||
writtenFields := make(map[string]bool)
|
||||
|
||||
for {
|
||||
part, err := reader.NextPart()
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
name := part.FormName()
|
||||
if part.FileName() != "" {
|
||||
fw, createErr := writer.CreatePart(part.Header)
|
||||
if createErr != nil {
|
||||
_ = part.Close()
|
||||
return nil, "", createErr
|
||||
}
|
||||
if _, copyErr := io.Copy(fw, part); copyErr != nil {
|
||||
_ = part.Close()
|
||||
return nil, "", copyErr
|
||||
}
|
||||
} else if name != "" {
|
||||
if val, ok := payload[name]; ok {
|
||||
if err := WriteMultipartField(writer, name, val); err != nil {
|
||||
_ = part.Close()
|
||||
return nil, "", err
|
||||
}
|
||||
} else {
|
||||
origVal, readErr := io.ReadAll(part)
|
||||
if readErr != nil {
|
||||
_ = part.Close()
|
||||
return nil, "", readErr
|
||||
}
|
||||
if err := writer.WriteField(name, string(origVal)); err != nil {
|
||||
_ = part.Close()
|
||||
return nil, "", err
|
||||
}
|
||||
}
|
||||
writtenFields[name] = true
|
||||
}
|
||||
_ = part.Close()
|
||||
}
|
||||
|
||||
for key, val := range payload {
|
||||
if writtenFields[key] {
|
||||
continue
|
||||
}
|
||||
if err := WriteMultipartField(writer, key, val); err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
}
|
||||
|
||||
if err := writer.Close(); err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
return buf.Bytes(), writer.FormDataContentType(), nil
|
||||
}
|
||||
|
||||
// WriteMultipartField writes a single form field to the multipart writer,
|
||||
// handling string, []string, and other value types.
|
||||
func WriteMultipartField(writer *multipart.Writer, name string, val any) error {
|
||||
switch v := val.(type) {
|
||||
case string:
|
||||
return writer.WriteField(name, v)
|
||||
case []string:
|
||||
encoded, err := schemas.MarshalSorted(v)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return writer.WriteField(name, string(encoded))
|
||||
default:
|
||||
return writer.WriteField(name, fmt.Sprintf("%v", val))
|
||||
}
|
||||
}
|
||||
|
||||
// SerializePayloadToRequest writes the modified payload back to req.Body,
|
||||
// using multipart reconstruction for multipart/form-data or JSON for everything else.
|
||||
func SerializePayloadToRequest(req *schemas.HTTPRequest, payload map[string]any, isMultipart bool, origContentType string) error {
|
||||
if isMultipart {
|
||||
newBody, newCT, err := ReconstructMultipartBody(origContentType, req.Body, payload)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
req.Body = newBody
|
||||
for k := range req.Headers {
|
||||
if strings.EqualFold(k, "content-type") {
|
||||
delete(req.Headers, k)
|
||||
}
|
||||
}
|
||||
req.Headers["Content-Type"] = newCT
|
||||
return nil
|
||||
}
|
||||
body, err := schemas.MarshalSorted(payload)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
req.Body = body
|
||||
return nil
|
||||
}
|
||||
332
core/network/multipart_test.go
Normal file
332
core/network/multipart_test.go
Normal file
@@ -0,0 +1,332 @@
|
||||
package network
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"mime"
|
||||
"mime/multipart"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/bytedance/sonic"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// buildMultipartBody is a test helper that creates a multipart/form-data body
|
||||
// with the given text fields and optional file parts.
|
||||
func buildMultipartBody(t *testing.T, fields map[string]string, files map[string][]byte) ([]byte, string) {
|
||||
t.Helper()
|
||||
var buf bytes.Buffer
|
||||
writer := multipart.NewWriter(&buf)
|
||||
for name, val := range fields {
|
||||
if err := writer.WriteField(name, val); err != nil {
|
||||
t.Fatalf("WriteField(%q): %v", name, err)
|
||||
}
|
||||
}
|
||||
for name, data := range files {
|
||||
fw, err := writer.CreateFormFile(name, name+".bin")
|
||||
if err != nil {
|
||||
t.Fatalf("CreateFormFile(%q): %v", name, err)
|
||||
}
|
||||
if _, err := fw.Write(data); err != nil {
|
||||
t.Fatalf("Write file data: %v", err)
|
||||
}
|
||||
}
|
||||
if err := writer.Close(); err != nil {
|
||||
t.Fatalf("writer.Close: %v", err)
|
||||
}
|
||||
return buf.Bytes(), writer.FormDataContentType()
|
||||
}
|
||||
|
||||
func partOrderFromMultipartBody(t *testing.T, contentType string, body []byte) []string {
|
||||
t.Helper()
|
||||
_, params, err := mime.ParseMediaType(contentType)
|
||||
if err != nil {
|
||||
t.Fatalf("ParseMediaType(%q): %v", contentType, err)
|
||||
}
|
||||
boundary := params["boundary"]
|
||||
if boundary == "" {
|
||||
t.Fatalf("no boundary in content-type %q", contentType)
|
||||
}
|
||||
|
||||
reader := multipart.NewReader(bytes.NewReader(body), boundary)
|
||||
var order []string
|
||||
for {
|
||||
part, err := reader.NextPart()
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("NextPart(): %v", err)
|
||||
}
|
||||
order = append(order, part.FormName())
|
||||
_, _ = io.Copy(io.Discard, part)
|
||||
_ = part.Close()
|
||||
}
|
||||
return order
|
||||
}
|
||||
|
||||
func TestParseMultipartFormFields(t *testing.T) {
|
||||
t.Run("extracts text fields and skips files", func(t *testing.T) {
|
||||
body, ct := buildMultipartBody(t,
|
||||
map[string]string{"model": "gpt-4", "prompt": "hello"},
|
||||
map[string][]byte{"image": {0xFF, 0xD8, 0xFF}},
|
||||
)
|
||||
result, err := ParseMultipartFormFields(ct, body)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if result["model"] != "gpt-4" {
|
||||
t.Errorf("model = %v, want gpt-4", result["model"])
|
||||
}
|
||||
if result["prompt"] != "hello" {
|
||||
t.Errorf("prompt = %v, want hello", result["prompt"])
|
||||
}
|
||||
if _, exists := result["image"]; exists {
|
||||
t.Error("file part 'image' should have been skipped")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns error on missing boundary", func(t *testing.T) {
|
||||
_, err := ParseMultipartFormFields("multipart/form-data", []byte("irrelevant"))
|
||||
if err == nil {
|
||||
t.Fatal("expected error for missing boundary")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns error on invalid content-type", func(t *testing.T) {
|
||||
_, err := ParseMultipartFormFields(";;;invalid", []byte("irrelevant"))
|
||||
if err == nil {
|
||||
t.Fatal("expected error for invalid content-type")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns empty map for empty body", func(t *testing.T) {
|
||||
// A valid multipart body with no parts (just the closing boundary).
|
||||
var buf bytes.Buffer
|
||||
writer := multipart.NewWriter(&buf)
|
||||
ct := writer.FormDataContentType()
|
||||
_ = writer.Close()
|
||||
|
||||
result, err := ParseMultipartFormFields(ct, buf.Bytes())
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(result) != 0 {
|
||||
t.Errorf("expected empty map, got %v", result)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestReconstructMultipartBody(t *testing.T) {
|
||||
t.Run("replaces text field value", func(t *testing.T) {
|
||||
body, ct := buildMultipartBody(t,
|
||||
map[string]string{"model": "gpt-3.5", "prompt": "hi"},
|
||||
nil,
|
||||
)
|
||||
payload := map[string]any{"model": "gpt-4"}
|
||||
newBody, newCT, err := ReconstructMultipartBody(ct, body, payload)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if !strings.HasPrefix(newCT, "multipart/form-data") {
|
||||
t.Errorf("content-type = %v, want multipart/form-data prefix", newCT)
|
||||
}
|
||||
// Parse the reconstructed body and verify the value was replaced.
|
||||
parsed, err := ParseMultipartFormFields(newCT, newBody)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse reconstructed body: %v", err)
|
||||
}
|
||||
if parsed["model"] != "gpt-4" {
|
||||
t.Errorf("model = %v, want gpt-4", parsed["model"])
|
||||
}
|
||||
if parsed["prompt"] != "hi" {
|
||||
t.Errorf("prompt = %v, want hi (should be preserved)", parsed["prompt"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("adds new fields from payload", func(t *testing.T) {
|
||||
body, ct := buildMultipartBody(t,
|
||||
map[string]string{"model": "gpt-4"},
|
||||
nil,
|
||||
)
|
||||
payload := map[string]any{"model": "gpt-4", "temperature": "0.7"}
|
||||
newBody, newCT, err := ReconstructMultipartBody(ct, body, payload)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
parsed, err := ParseMultipartFormFields(newCT, newBody)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse: %v", err)
|
||||
}
|
||||
if parsed["temperature"] != "0.7" {
|
||||
t.Errorf("temperature = %v, want 0.7", parsed["temperature"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("preserves file parts while updating text fields", func(t *testing.T) {
|
||||
body, ct := buildMultipartBody(t,
|
||||
map[string]string{"prompt": "hi"},
|
||||
map[string][]byte{"file": []byte("audio-bytes")},
|
||||
)
|
||||
payload := map[string]any{"model": "whisper-1", "prompt": "updated"}
|
||||
newBody, newCT, err := ReconstructMultipartBody(ct, body, payload)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
order := partOrderFromMultipartBody(t, newCT, newBody)
|
||||
if len(order) != 3 {
|
||||
t.Fatalf("unexpected part count %d: %v", len(order), order)
|
||||
}
|
||||
if order[0] != "prompt" || order[1] != "file" || order[2] != "model" {
|
||||
t.Fatalf("unexpected multipart order %v, want [prompt file model]", order)
|
||||
}
|
||||
|
||||
parsed, err := ParseMultipartFormFields(newCT, newBody)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse reconstructed body: %v", err)
|
||||
}
|
||||
if parsed["prompt"] != "updated" {
|
||||
t.Fatalf("prompt = %v, want updated", parsed["prompt"])
|
||||
}
|
||||
if parsed["model"] != "whisper-1" {
|
||||
t.Fatalf("model = %v, want whisper-1", parsed["model"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns error on missing boundary", func(t *testing.T) {
|
||||
_, _, err := ReconstructMultipartBody("multipart/form-data", []byte("data"), map[string]any{})
|
||||
if err == nil {
|
||||
t.Fatal("expected error for missing boundary")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestWriteMultipartField(t *testing.T) {
|
||||
t.Run("writes string value", func(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
writer := multipart.NewWriter(&buf)
|
||||
if err := WriteMultipartField(writer, "key", "value"); err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
_ = writer.Close()
|
||||
parsed, err := ParseMultipartFormFields(writer.FormDataContentType(), buf.Bytes())
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
if parsed["key"] != "value" {
|
||||
t.Errorf("key = %v, want value", parsed["key"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("writes []string as JSON array", func(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
writer := multipart.NewWriter(&buf)
|
||||
if err := WriteMultipartField(writer, "tags", []string{"a", "b"}); err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
_ = writer.Close()
|
||||
parsed, err := ParseMultipartFormFields(writer.FormDataContentType(), buf.Bytes())
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
val, ok := parsed["tags"].(string)
|
||||
if !ok {
|
||||
t.Fatalf("tags not a string, got %T", parsed["tags"])
|
||||
}
|
||||
var arr []string
|
||||
if err := sonic.UnmarshalString(val, &arr); err != nil {
|
||||
t.Fatalf("failed to unmarshal tags JSON: %v", err)
|
||||
}
|
||||
if len(arr) != 2 || arr[0] != "a" || arr[1] != "b" {
|
||||
t.Errorf("tags = %v, want [a b]", arr)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("writes non-string with Sprintf", func(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
writer := multipart.NewWriter(&buf)
|
||||
if err := WriteMultipartField(writer, "count", 42); err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
_ = writer.Close()
|
||||
parsed, err := ParseMultipartFormFields(writer.FormDataContentType(), buf.Bytes())
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
if parsed["count"] != "42" {
|
||||
t.Errorf("count = %v, want 42", parsed["count"])
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestSerializePayloadToRequest(t *testing.T) {
|
||||
t.Run("JSON path", func(t *testing.T) {
|
||||
req := &schemas.HTTPRequest{
|
||||
Headers: map[string]string{"Content-Type": "application/json"},
|
||||
Body: []byte(`{"old":"data"}`),
|
||||
}
|
||||
payload := map[string]any{"model": "gpt-4", "prompt": "test"}
|
||||
if err := SerializePayloadToRequest(req, payload, false, "application/json"); err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
var result map[string]any
|
||||
if err := sonic.Unmarshal(req.Body, &result); err != nil {
|
||||
t.Fatalf("failed to unmarshal result: %v", err)
|
||||
}
|
||||
if result["model"] != "gpt-4" {
|
||||
t.Errorf("model = %v, want gpt-4", result["model"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("multipart path", func(t *testing.T) {
|
||||
body, ct := buildMultipartBody(t,
|
||||
map[string]string{"model": "gpt-3.5"},
|
||||
nil,
|
||||
)
|
||||
req := &schemas.HTTPRequest{
|
||||
Headers: map[string]string{"Content-Type": ct},
|
||||
Body: body,
|
||||
}
|
||||
payload := map[string]any{"model": "gpt-4"}
|
||||
if err := SerializePayloadToRequest(req, payload, true, ct); err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
// Verify content-type was updated
|
||||
newCT := req.Headers["Content-Type"]
|
||||
if !strings.HasPrefix(newCT, "multipart/form-data") {
|
||||
t.Errorf("content-type = %v, want multipart/form-data prefix", newCT)
|
||||
}
|
||||
// Verify the body contains the updated model
|
||||
parsed, err := ParseMultipartFormFields(newCT, req.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
if parsed["model"] != "gpt-4" {
|
||||
t.Errorf("model = %v, want gpt-4", parsed["model"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("multipart path removes old content-type header case-insensitively", func(t *testing.T) {
|
||||
body, ct := buildMultipartBody(t,
|
||||
map[string]string{"field": "val"},
|
||||
nil,
|
||||
)
|
||||
req := &schemas.HTTPRequest{
|
||||
Headers: map[string]string{"content-type": ct},
|
||||
Body: body,
|
||||
}
|
||||
payload := map[string]any{"field": "val"}
|
||||
if err := SerializePayloadToRequest(req, payload, true, ct); err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
// The old lowercase "content-type" should be gone, replaced by "Content-Type".
|
||||
if _, exists := req.Headers["content-type"]; exists {
|
||||
t.Error("old lowercase content-type header should have been removed")
|
||||
}
|
||||
if _, exists := req.Headers["Content-Type"]; !exists {
|
||||
t.Error("new Content-Type header should be set")
|
||||
}
|
||||
})
|
||||
}
|
||||
Reference in New Issue
Block a user