360 lines
11 KiB
Go
360 lines
11 KiB
Go
package utils
|
|
|
|
import (
|
|
"fmt"
|
|
"net"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"sync/atomic"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/maximhq/bifrost/core/network"
|
|
"github.com/valyala/fasthttp"
|
|
)
|
|
|
|
// TestConfigureDialer_SetsRetryIfErr verifies that ConfigureDialer installs
|
|
// the StaleConnectionRetryIfErr callback on the client.
|
|
func TestConfigureDialer_SetsRetryIfErr(t *testing.T) {
|
|
client := &fasthttp.Client{}
|
|
if client.RetryIfErr != nil {
|
|
t.Fatal("precondition: RetryIfErr should be nil on a new client")
|
|
}
|
|
|
|
ConfigureDialer(client)
|
|
|
|
if client.RetryIfErr == nil {
|
|
t.Fatal("ConfigureDialer should set RetryIfErr")
|
|
}
|
|
|
|
// Verify it behaves like StaleConnectionRetryIfErr
|
|
reset, retry := client.RetryIfErr(nil, 1, fmt.Errorf("cannot find whitespace in the first line of response"))
|
|
if !reset || !retry {
|
|
t.Error("RetryIfErr should retry on whitespace error")
|
|
}
|
|
reset, retry = client.RetryIfErr(nil, 1, fmt.Errorf("dial tcp: no such host"))
|
|
if reset || retry {
|
|
t.Error("RetryIfErr should not retry on unrelated errors")
|
|
}
|
|
}
|
|
|
|
// TestConfigureDialer_SetsDial verifies that ConfigureDialer installs a custom
|
|
// Dial function on the client when no existing Dial is present.
|
|
func TestConfigureDialer_SetsDial(t *testing.T) {
|
|
client := &fasthttp.Client{}
|
|
if client.Dial != nil {
|
|
t.Fatal("precondition: Dial should be nil on a new client")
|
|
}
|
|
|
|
ConfigureDialer(client)
|
|
|
|
if client.Dial == nil {
|
|
t.Fatal("ConfigureDialer should set a Dial function")
|
|
}
|
|
}
|
|
|
|
// TestConfigureDialer_ComposesWithExistingDial verifies that when a custom Dial
|
|
// function is already set (e.g., from ConfigureProxy), ConfigureDialer wraps it
|
|
// and still enables TCP keepalive on the resulting connection.
|
|
func TestConfigureDialer_ComposesWithExistingDial(t *testing.T) {
|
|
var proxyDialCalled atomic.Bool
|
|
|
|
client := &fasthttp.Client{}
|
|
// Simulate a proxy dial function (set by ConfigureProxy)
|
|
client.Dial = func(addr string) (net.Conn, error) {
|
|
proxyDialCalled.Store(true)
|
|
return net.Dial("tcp", addr)
|
|
}
|
|
|
|
ConfigureDialer(client)
|
|
|
|
// Start a test server to connect to
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
fmt.Fprint(w, "ok")
|
|
}))
|
|
defer server.Close()
|
|
|
|
req := fasthttp.AcquireRequest()
|
|
resp := fasthttp.AcquireResponse()
|
|
defer fasthttp.ReleaseRequest(req)
|
|
defer fasthttp.ReleaseResponse(resp)
|
|
|
|
req.SetRequestURI(server.URL)
|
|
req.Header.SetMethod(http.MethodGet)
|
|
|
|
if err := client.Do(req, resp); err != nil {
|
|
t.Fatalf("request failed: %v", err)
|
|
}
|
|
if resp.StatusCode() != 200 {
|
|
t.Fatalf("expected 200, got %d", resp.StatusCode())
|
|
}
|
|
if !proxyDialCalled.Load() {
|
|
t.Error("ConfigureDialer should have called the existing proxy dial function")
|
|
}
|
|
}
|
|
|
|
// TestConfigureDialer_TCPKeepAliveEnabled verifies that connections created
|
|
// through ConfigureDialer have TCP keepalive enabled.
|
|
func TestConfigureDialer_TCPKeepAliveEnabled(t *testing.T) {
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
fmt.Fprint(w, "ok")
|
|
}))
|
|
defer server.Close()
|
|
|
|
// Test without existing dial (direct connection path)
|
|
t.Run("without_existing_dial", func(t *testing.T) {
|
|
client := &fasthttp.Client{}
|
|
ConfigureDialer(client)
|
|
|
|
// The Dial function should create connections with keepalive
|
|
// We can verify by making a connection and checking the TCP options
|
|
req := fasthttp.AcquireRequest()
|
|
resp := fasthttp.AcquireResponse()
|
|
defer fasthttp.ReleaseRequest(req)
|
|
defer fasthttp.ReleaseResponse(resp)
|
|
|
|
req.SetRequestURI(server.URL)
|
|
req.Header.SetMethod(http.MethodGet)
|
|
|
|
if err := client.Do(req, resp); err != nil {
|
|
t.Fatalf("request failed: %v", err)
|
|
}
|
|
if resp.StatusCode() != 200 {
|
|
t.Fatalf("expected 200, got %d", resp.StatusCode())
|
|
}
|
|
})
|
|
|
|
// Test with existing dial (proxy composition path)
|
|
t.Run("with_existing_dial", func(t *testing.T) {
|
|
var connFromProxy net.Conn
|
|
client := &fasthttp.Client{}
|
|
client.Dial = func(addr string) (net.Conn, error) {
|
|
conn, err := net.Dial("tcp", addr)
|
|
connFromProxy = conn
|
|
return conn, err
|
|
}
|
|
ConfigureDialer(client)
|
|
|
|
req := fasthttp.AcquireRequest()
|
|
resp := fasthttp.AcquireResponse()
|
|
defer fasthttp.ReleaseRequest(req)
|
|
defer fasthttp.ReleaseResponse(resp)
|
|
|
|
req.SetRequestURI(server.URL)
|
|
req.Header.SetMethod(http.MethodGet)
|
|
|
|
if err := client.Do(req, resp); err != nil {
|
|
t.Fatalf("request failed: %v", err)
|
|
}
|
|
|
|
// Verify the proxy-returned connection is a TCP connection
|
|
// (ConfigureDialer enables keepalive via SetKeepAliveConfig on it)
|
|
if connFromProxy == nil {
|
|
t.Fatal("proxy dial should have been called")
|
|
}
|
|
if _, ok := connFromProxy.(*net.TCPConn); !ok {
|
|
t.Errorf("expected *net.TCPConn, got %T", connFromProxy)
|
|
}
|
|
})
|
|
}
|
|
|
|
// TestConfigureDialer_ReturnValue verifies that ConfigureDialer returns the
|
|
// same client pointer it received (for chaining).
|
|
func TestConfigureDialer_ReturnValue(t *testing.T) {
|
|
client := &fasthttp.Client{}
|
|
result := ConfigureDialer(client)
|
|
if result != client {
|
|
t.Error("ConfigureDialer should return the same client pointer")
|
|
}
|
|
}
|
|
|
|
// TestConfigureDialer_Idempotent verifies that calling ConfigureDialer multiple
|
|
// times doesn't break the client.
|
|
func TestConfigureDialer_Idempotent(t *testing.T) {
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
fmt.Fprint(w, "ok")
|
|
}))
|
|
defer server.Close()
|
|
|
|
client := &fasthttp.Client{}
|
|
ConfigureDialer(client)
|
|
ConfigureDialer(client) // called again
|
|
|
|
req := fasthttp.AcquireRequest()
|
|
resp := fasthttp.AcquireResponse()
|
|
defer fasthttp.ReleaseRequest(req)
|
|
defer fasthttp.ReleaseResponse(resp)
|
|
|
|
req.SetRequestURI(server.URL)
|
|
req.Header.SetMethod(http.MethodPost)
|
|
req.SetBodyString(`{"test": true}`)
|
|
|
|
if err := client.Do(req, resp); err != nil {
|
|
t.Fatalf("request failed after double ConfigureDialer: %v", err)
|
|
}
|
|
if resp.StatusCode() != 200 {
|
|
t.Fatalf("expected 200, got %d", resp.StatusCode())
|
|
}
|
|
}
|
|
|
|
// TestConfigureDialer_WithRetryOnStaleConnection is an integration test that
|
|
// verifies ConfigureDialer enables successful POST retry after TTL mismatch.
|
|
// This combines both the retry and keepalive behaviors.
|
|
func TestConfigureDialer_WithRetryOnStaleConnection(t *testing.T) {
|
|
if testing.Short() {
|
|
t.Skip("skipping TTL mismatch test in short mode (requires 11s wait)")
|
|
}
|
|
|
|
const (
|
|
serverIdleTimeout = 10 * time.Second
|
|
clientIdleTimeout = 15 * time.Second
|
|
waitBetween = 11 * time.Second
|
|
)
|
|
|
|
var requestCount atomic.Int32
|
|
|
|
server := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
requestCount.Add(1)
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.WriteHeader(http.StatusOK)
|
|
fmt.Fprintf(w, `{"ok": true, "request": %d}`, requestCount.Load())
|
|
}))
|
|
server.Config.IdleTimeout = serverIdleTimeout
|
|
server.Start()
|
|
defer server.Close()
|
|
|
|
client := &fasthttp.Client{
|
|
MaxIdleConnDuration: clientIdleTimeout,
|
|
MaxConnsPerHost: 10,
|
|
}
|
|
// Use ConfigureDialer (the function under test) instead of manually setting RetryIfErr
|
|
ConfigureDialer(client)
|
|
|
|
// First request: establish connection in pool
|
|
req := fasthttp.AcquireRequest()
|
|
resp := fasthttp.AcquireResponse()
|
|
defer fasthttp.ReleaseRequest(req)
|
|
defer fasthttp.ReleaseResponse(resp)
|
|
|
|
req.SetRequestURI(server.URL)
|
|
req.Header.SetMethod(http.MethodPost)
|
|
req.SetBodyString(`{"prompt": "hello"}`)
|
|
|
|
if err := client.Do(req, resp); err != nil {
|
|
t.Fatalf("First POST failed: %v", err)
|
|
}
|
|
if resp.StatusCode() != 200 {
|
|
t.Fatalf("First POST: expected 200, got %d", resp.StatusCode())
|
|
}
|
|
_ = resp.Body()
|
|
|
|
// Wait for server TTL to expire
|
|
t.Logf("Waiting %v for server idle timeout to expire...", waitBetween)
|
|
time.Sleep(waitBetween)
|
|
|
|
// Second request: stale connection should be retried by ConfigureDialer's retry policy
|
|
req2 := fasthttp.AcquireRequest()
|
|
resp2 := fasthttp.AcquireResponse()
|
|
defer fasthttp.ReleaseRequest(req2)
|
|
defer fasthttp.ReleaseResponse(resp2)
|
|
|
|
req2.SetRequestURI(server.URL)
|
|
req2.Header.SetMethod(http.MethodPost)
|
|
req2.SetBodyString(`{"prompt": "world"}`)
|
|
|
|
if err := client.Do(req2, resp2); err != nil {
|
|
t.Fatalf("Second POST failed (ConfigureDialer retry should have saved it): %v", err)
|
|
}
|
|
if resp2.StatusCode() != 200 {
|
|
t.Fatalf("Second POST: expected 200, got %d", resp2.StatusCode())
|
|
}
|
|
t.Logf("Second POST succeeded after TTL mismatch via ConfigureDialer")
|
|
}
|
|
|
|
// TestConfigureRetry_Deprecated verifies the deprecated ConfigureRetry still works.
|
|
func TestConfigureRetry_Deprecated(t *testing.T) {
|
|
client := &fasthttp.Client{}
|
|
result := ConfigureRetry(client)
|
|
|
|
if result != client {
|
|
t.Error("ConfigureRetry should return the same client pointer")
|
|
}
|
|
if client.RetryIfErr == nil {
|
|
t.Fatal("ConfigureRetry should set RetryIfErr")
|
|
}
|
|
|
|
// Verify it uses the same StaleConnectionRetryIfErr
|
|
reset, retry := client.RetryIfErr(nil, 1, fmt.Errorf("cannot find whitespace"))
|
|
if !reset || !retry {
|
|
t.Error("ConfigureRetry should install StaleConnectionRetryIfErr")
|
|
}
|
|
}
|
|
|
|
// TestConfigureDialer_DialError verifies that dial errors from the existing
|
|
// dial function are properly propagated (not swallowed).
|
|
func TestConfigureDialer_DialError(t *testing.T) {
|
|
expectedErr := fmt.Errorf("proxy connection refused")
|
|
client := &fasthttp.Client{}
|
|
client.Dial = func(addr string) (net.Conn, error) {
|
|
return nil, expectedErr
|
|
}
|
|
|
|
ConfigureDialer(client)
|
|
|
|
req := fasthttp.AcquireRequest()
|
|
resp := fasthttp.AcquireResponse()
|
|
defer fasthttp.ReleaseRequest(req)
|
|
defer fasthttp.ReleaseResponse(resp)
|
|
|
|
req.SetRequestURI("http://localhost:1/test")
|
|
req.Header.SetMethod(http.MethodPost)
|
|
|
|
err := client.Do(req, resp)
|
|
if err == nil {
|
|
t.Fatal("expected error from failed proxy dial")
|
|
}
|
|
t.Logf("Got expected error: %v", err)
|
|
}
|
|
|
|
// TestStaleConnectionRetryIfErr_WrappedErrors verifies behavior with wrapped errors.
|
|
func TestStaleConnectionRetryIfErr_WrappedErrors(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
err error
|
|
wantRetry bool
|
|
}{
|
|
{
|
|
name: "wrapped whitespace error",
|
|
err: fmt.Errorf("fasthttp: %w", fmt.Errorf("cannot find whitespace in header")),
|
|
wantRetry: true,
|
|
},
|
|
{
|
|
name: "wrapped connection reset",
|
|
err: fmt.Errorf("during POST: connection reset by peer"),
|
|
wantRetry: true,
|
|
},
|
|
{
|
|
name: "wrapped broken pipe",
|
|
err: fmt.Errorf("during POST: %w", fmt.Errorf("write tcp 10.0.0.1:53374->10.0.0.2:30000: write: broken pipe")),
|
|
wantRetry: true,
|
|
},
|
|
{
|
|
name: "ErrConnectionClosed from fasthttp",
|
|
err: fasthttp.ErrConnectionClosed,
|
|
wantRetry: false, // Not matched - this error appears AFTER the retry loop
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
_, retry := network.StaleConnectionRetryIfErr(nil, 1, tt.err)
|
|
if retry != tt.wantRetry {
|
|
t.Errorf("retry = %v, want %v", retry, tt.wantRetry)
|
|
}
|
|
})
|
|
}
|
|
}
|