first commit
This commit is contained in:
359
core/providers/utils/dialer_test.go
Normal file
359
core/providers/utils/dialer_test.go
Normal file
@@ -0,0 +1,359 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/maximhq/bifrost/core/network"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// TestConfigureDialer_SetsRetryIfErr verifies that ConfigureDialer installs
|
||||
// the StaleConnectionRetryIfErr callback on the client.
|
||||
func TestConfigureDialer_SetsRetryIfErr(t *testing.T) {
|
||||
client := &fasthttp.Client{}
|
||||
if client.RetryIfErr != nil {
|
||||
t.Fatal("precondition: RetryIfErr should be nil on a new client")
|
||||
}
|
||||
|
||||
ConfigureDialer(client)
|
||||
|
||||
if client.RetryIfErr == nil {
|
||||
t.Fatal("ConfigureDialer should set RetryIfErr")
|
||||
}
|
||||
|
||||
// Verify it behaves like StaleConnectionRetryIfErr
|
||||
reset, retry := client.RetryIfErr(nil, 1, fmt.Errorf("cannot find whitespace in the first line of response"))
|
||||
if !reset || !retry {
|
||||
t.Error("RetryIfErr should retry on whitespace error")
|
||||
}
|
||||
reset, retry = client.RetryIfErr(nil, 1, fmt.Errorf("dial tcp: no such host"))
|
||||
if reset || retry {
|
||||
t.Error("RetryIfErr should not retry on unrelated errors")
|
||||
}
|
||||
}
|
||||
|
||||
// TestConfigureDialer_SetsDial verifies that ConfigureDialer installs a custom
|
||||
// Dial function on the client when no existing Dial is present.
|
||||
func TestConfigureDialer_SetsDial(t *testing.T) {
|
||||
client := &fasthttp.Client{}
|
||||
if client.Dial != nil {
|
||||
t.Fatal("precondition: Dial should be nil on a new client")
|
||||
}
|
||||
|
||||
ConfigureDialer(client)
|
||||
|
||||
if client.Dial == nil {
|
||||
t.Fatal("ConfigureDialer should set a Dial function")
|
||||
}
|
||||
}
|
||||
|
||||
// TestConfigureDialer_ComposesWithExistingDial verifies that when a custom Dial
|
||||
// function is already set (e.g., from ConfigureProxy), ConfigureDialer wraps it
|
||||
// and still enables TCP keepalive on the resulting connection.
|
||||
func TestConfigureDialer_ComposesWithExistingDial(t *testing.T) {
|
||||
var proxyDialCalled atomic.Bool
|
||||
|
||||
client := &fasthttp.Client{}
|
||||
// Simulate a proxy dial function (set by ConfigureProxy)
|
||||
client.Dial = func(addr string) (net.Conn, error) {
|
||||
proxyDialCalled.Store(true)
|
||||
return net.Dial("tcp", addr)
|
||||
}
|
||||
|
||||
ConfigureDialer(client)
|
||||
|
||||
// Start a test server to connect to
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
fmt.Fprint(w, "ok")
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
req := fasthttp.AcquireRequest()
|
||||
resp := fasthttp.AcquireResponse()
|
||||
defer fasthttp.ReleaseRequest(req)
|
||||
defer fasthttp.ReleaseResponse(resp)
|
||||
|
||||
req.SetRequestURI(server.URL)
|
||||
req.Header.SetMethod(http.MethodGet)
|
||||
|
||||
if err := client.Do(req, resp); err != nil {
|
||||
t.Fatalf("request failed: %v", err)
|
||||
}
|
||||
if resp.StatusCode() != 200 {
|
||||
t.Fatalf("expected 200, got %d", resp.StatusCode())
|
||||
}
|
||||
if !proxyDialCalled.Load() {
|
||||
t.Error("ConfigureDialer should have called the existing proxy dial function")
|
||||
}
|
||||
}
|
||||
|
||||
// TestConfigureDialer_TCPKeepAliveEnabled verifies that connections created
|
||||
// through ConfigureDialer have TCP keepalive enabled.
|
||||
func TestConfigureDialer_TCPKeepAliveEnabled(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
fmt.Fprint(w, "ok")
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
// Test without existing dial (direct connection path)
|
||||
t.Run("without_existing_dial", func(t *testing.T) {
|
||||
client := &fasthttp.Client{}
|
||||
ConfigureDialer(client)
|
||||
|
||||
// The Dial function should create connections with keepalive
|
||||
// We can verify by making a connection and checking the TCP options
|
||||
req := fasthttp.AcquireRequest()
|
||||
resp := fasthttp.AcquireResponse()
|
||||
defer fasthttp.ReleaseRequest(req)
|
||||
defer fasthttp.ReleaseResponse(resp)
|
||||
|
||||
req.SetRequestURI(server.URL)
|
||||
req.Header.SetMethod(http.MethodGet)
|
||||
|
||||
if err := client.Do(req, resp); err != nil {
|
||||
t.Fatalf("request failed: %v", err)
|
||||
}
|
||||
if resp.StatusCode() != 200 {
|
||||
t.Fatalf("expected 200, got %d", resp.StatusCode())
|
||||
}
|
||||
})
|
||||
|
||||
// Test with existing dial (proxy composition path)
|
||||
t.Run("with_existing_dial", func(t *testing.T) {
|
||||
var connFromProxy net.Conn
|
||||
client := &fasthttp.Client{}
|
||||
client.Dial = func(addr string) (net.Conn, error) {
|
||||
conn, err := net.Dial("tcp", addr)
|
||||
connFromProxy = conn
|
||||
return conn, err
|
||||
}
|
||||
ConfigureDialer(client)
|
||||
|
||||
req := fasthttp.AcquireRequest()
|
||||
resp := fasthttp.AcquireResponse()
|
||||
defer fasthttp.ReleaseRequest(req)
|
||||
defer fasthttp.ReleaseResponse(resp)
|
||||
|
||||
req.SetRequestURI(server.URL)
|
||||
req.Header.SetMethod(http.MethodGet)
|
||||
|
||||
if err := client.Do(req, resp); err != nil {
|
||||
t.Fatalf("request failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify the proxy-returned connection is a TCP connection
|
||||
// (ConfigureDialer enables keepalive via SetKeepAliveConfig on it)
|
||||
if connFromProxy == nil {
|
||||
t.Fatal("proxy dial should have been called")
|
||||
}
|
||||
if _, ok := connFromProxy.(*net.TCPConn); !ok {
|
||||
t.Errorf("expected *net.TCPConn, got %T", connFromProxy)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestConfigureDialer_ReturnValue verifies that ConfigureDialer returns the
|
||||
// same client pointer it received (for chaining).
|
||||
func TestConfigureDialer_ReturnValue(t *testing.T) {
|
||||
client := &fasthttp.Client{}
|
||||
result := ConfigureDialer(client)
|
||||
if result != client {
|
||||
t.Error("ConfigureDialer should return the same client pointer")
|
||||
}
|
||||
}
|
||||
|
||||
// TestConfigureDialer_Idempotent verifies that calling ConfigureDialer multiple
|
||||
// times doesn't break the client.
|
||||
func TestConfigureDialer_Idempotent(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
fmt.Fprint(w, "ok")
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := &fasthttp.Client{}
|
||||
ConfigureDialer(client)
|
||||
ConfigureDialer(client) // called again
|
||||
|
||||
req := fasthttp.AcquireRequest()
|
||||
resp := fasthttp.AcquireResponse()
|
||||
defer fasthttp.ReleaseRequest(req)
|
||||
defer fasthttp.ReleaseResponse(resp)
|
||||
|
||||
req.SetRequestURI(server.URL)
|
||||
req.Header.SetMethod(http.MethodPost)
|
||||
req.SetBodyString(`{"test": true}`)
|
||||
|
||||
if err := client.Do(req, resp); err != nil {
|
||||
t.Fatalf("request failed after double ConfigureDialer: %v", err)
|
||||
}
|
||||
if resp.StatusCode() != 200 {
|
||||
t.Fatalf("expected 200, got %d", resp.StatusCode())
|
||||
}
|
||||
}
|
||||
|
||||
// TestConfigureDialer_WithRetryOnStaleConnection is an integration test that
|
||||
// verifies ConfigureDialer enables successful POST retry after TTL mismatch.
|
||||
// This combines both the retry and keepalive behaviors.
|
||||
func TestConfigureDialer_WithRetryOnStaleConnection(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping TTL mismatch test in short mode (requires 11s wait)")
|
||||
}
|
||||
|
||||
const (
|
||||
serverIdleTimeout = 10 * time.Second
|
||||
clientIdleTimeout = 15 * time.Second
|
||||
waitBetween = 11 * time.Second
|
||||
)
|
||||
|
||||
var requestCount atomic.Int32
|
||||
|
||||
server := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
requestCount.Add(1)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
fmt.Fprintf(w, `{"ok": true, "request": %d}`, requestCount.Load())
|
||||
}))
|
||||
server.Config.IdleTimeout = serverIdleTimeout
|
||||
server.Start()
|
||||
defer server.Close()
|
||||
|
||||
client := &fasthttp.Client{
|
||||
MaxIdleConnDuration: clientIdleTimeout,
|
||||
MaxConnsPerHost: 10,
|
||||
}
|
||||
// Use ConfigureDialer (the function under test) instead of manually setting RetryIfErr
|
||||
ConfigureDialer(client)
|
||||
|
||||
// First request: establish connection in pool
|
||||
req := fasthttp.AcquireRequest()
|
||||
resp := fasthttp.AcquireResponse()
|
||||
defer fasthttp.ReleaseRequest(req)
|
||||
defer fasthttp.ReleaseResponse(resp)
|
||||
|
||||
req.SetRequestURI(server.URL)
|
||||
req.Header.SetMethod(http.MethodPost)
|
||||
req.SetBodyString(`{"prompt": "hello"}`)
|
||||
|
||||
if err := client.Do(req, resp); err != nil {
|
||||
t.Fatalf("First POST failed: %v", err)
|
||||
}
|
||||
if resp.StatusCode() != 200 {
|
||||
t.Fatalf("First POST: expected 200, got %d", resp.StatusCode())
|
||||
}
|
||||
_ = resp.Body()
|
||||
|
||||
// Wait for server TTL to expire
|
||||
t.Logf("Waiting %v for server idle timeout to expire...", waitBetween)
|
||||
time.Sleep(waitBetween)
|
||||
|
||||
// Second request: stale connection should be retried by ConfigureDialer's retry policy
|
||||
req2 := fasthttp.AcquireRequest()
|
||||
resp2 := fasthttp.AcquireResponse()
|
||||
defer fasthttp.ReleaseRequest(req2)
|
||||
defer fasthttp.ReleaseResponse(resp2)
|
||||
|
||||
req2.SetRequestURI(server.URL)
|
||||
req2.Header.SetMethod(http.MethodPost)
|
||||
req2.SetBodyString(`{"prompt": "world"}`)
|
||||
|
||||
if err := client.Do(req2, resp2); err != nil {
|
||||
t.Fatalf("Second POST failed (ConfigureDialer retry should have saved it): %v", err)
|
||||
}
|
||||
if resp2.StatusCode() != 200 {
|
||||
t.Fatalf("Second POST: expected 200, got %d", resp2.StatusCode())
|
||||
}
|
||||
t.Logf("Second POST succeeded after TTL mismatch via ConfigureDialer")
|
||||
}
|
||||
|
||||
// TestConfigureRetry_Deprecated verifies the deprecated ConfigureRetry still works.
|
||||
func TestConfigureRetry_Deprecated(t *testing.T) {
|
||||
client := &fasthttp.Client{}
|
||||
result := ConfigureRetry(client)
|
||||
|
||||
if result != client {
|
||||
t.Error("ConfigureRetry should return the same client pointer")
|
||||
}
|
||||
if client.RetryIfErr == nil {
|
||||
t.Fatal("ConfigureRetry should set RetryIfErr")
|
||||
}
|
||||
|
||||
// Verify it uses the same StaleConnectionRetryIfErr
|
||||
reset, retry := client.RetryIfErr(nil, 1, fmt.Errorf("cannot find whitespace"))
|
||||
if !reset || !retry {
|
||||
t.Error("ConfigureRetry should install StaleConnectionRetryIfErr")
|
||||
}
|
||||
}
|
||||
|
||||
// TestConfigureDialer_DialError verifies that dial errors from the existing
|
||||
// dial function are properly propagated (not swallowed).
|
||||
func TestConfigureDialer_DialError(t *testing.T) {
|
||||
expectedErr := fmt.Errorf("proxy connection refused")
|
||||
client := &fasthttp.Client{}
|
||||
client.Dial = func(addr string) (net.Conn, error) {
|
||||
return nil, expectedErr
|
||||
}
|
||||
|
||||
ConfigureDialer(client)
|
||||
|
||||
req := fasthttp.AcquireRequest()
|
||||
resp := fasthttp.AcquireResponse()
|
||||
defer fasthttp.ReleaseRequest(req)
|
||||
defer fasthttp.ReleaseResponse(resp)
|
||||
|
||||
req.SetRequestURI("http://localhost:1/test")
|
||||
req.Header.SetMethod(http.MethodPost)
|
||||
|
||||
err := client.Do(req, resp)
|
||||
if err == nil {
|
||||
t.Fatal("expected error from failed proxy dial")
|
||||
}
|
||||
t.Logf("Got expected error: %v", err)
|
||||
}
|
||||
|
||||
// TestStaleConnectionRetryIfErr_WrappedErrors verifies behavior with wrapped errors.
|
||||
func TestStaleConnectionRetryIfErr_WrappedErrors(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
wantRetry bool
|
||||
}{
|
||||
{
|
||||
name: "wrapped whitespace error",
|
||||
err: fmt.Errorf("fasthttp: %w", fmt.Errorf("cannot find whitespace in header")),
|
||||
wantRetry: true,
|
||||
},
|
||||
{
|
||||
name: "wrapped connection reset",
|
||||
err: fmt.Errorf("during POST: connection reset by peer"),
|
||||
wantRetry: true,
|
||||
},
|
||||
{
|
||||
name: "wrapped broken pipe",
|
||||
err: fmt.Errorf("during POST: %w", fmt.Errorf("write tcp 10.0.0.1:53374->10.0.0.2:30000: write: broken pipe")),
|
||||
wantRetry: true,
|
||||
},
|
||||
{
|
||||
name: "ErrConnectionClosed from fasthttp",
|
||||
err: fasthttp.ErrConnectionClosed,
|
||||
wantRetry: false, // Not matched - this error appears AFTER the retry loop
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, retry := network.StaleConnectionRetryIfErr(nil, 1, tt.err)
|
||||
if retry != tt.wantRetry {
|
||||
t.Errorf("retry = %v, want %v", retry, tt.wantRetry)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user