Files
bifrost/transports/bifrost-http/websocket/pool_test.go
Beyhan Oğur 880f412e2c first commit
2026-04-26 21:52:23 +03:00

161 lines
3.7 KiB
Go

package websocket
import (
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
ws "github.com/fasthttp/websocket"
"github.com/maximhq/bifrost/core/schemas"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// startTestWSServer starts an httptest server whose handler upgrades every
// request to a WebSocket and echoes each received message back to the sender
// with the same message type, until a read or write on the connection fails.
//
// The server is also registered with t.Cleanup so it is shut down even if a
// caller forgets to close it. httptest.Server.Close may be called more than
// once, so callers that already `defer server.Close()` keep working.
func startTestWSServer(t *testing.T) *httptest.Server {
	t.Helper()
	upgrader := ws.Upgrader{
		// Accept any Origin: test clients are not browsers.
		CheckOrigin: func(r *http.Request) bool { return true },
	}
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		conn, err := upgrader.Upgrade(w, r, nil)
		if err != nil {
			// Per the websocket package docs, Upgrade has already replied
			// with an HTTP error. Do not use t here: this goroutine may
			// outlive the test.
			return
		}
		defer conn.Close()
		for {
			mt, msg, err := conn.ReadMessage()
			if err != nil {
				// Peer closed the connection or the read failed.
				return
			}
			if err := conn.WriteMessage(mt, msg); err != nil {
				// Stop echoing once the connection is no longer writable
				// instead of spinning on further reads.
				return
			}
		}
	}))
	t.Cleanup(server.Close)
	return server
}
// TestPoolGetAndReturn checks two things. First, Get on an empty pool yields
// an open connection tagged with the requested provider and key ID. Second, a
// connection handed back via Return is the very same object served by the
// next Get for that key.
func TestPoolGetAndReturn(t *testing.T) {
	srv := startTestWSServer(t)
	defer srv.Close()

	pool := NewPool(&schemas.WSPoolConfig{
		MaxIdlePerKey:                5,
		MaxTotalConnections:          10,
		IdleTimeoutSeconds:           300,
		MaxConnectionLifetimeSeconds: 3600,
	})
	defer pool.Close()

	endpoint := "ws" + strings.TrimPrefix(srv.URL, "http")
	key := PoolKey{Provider: schemas.OpenAI, KeyID: "test-key", Endpoint: endpoint}

	// Nothing is idle for this key yet, so the pool has to dial.
	first, err := pool.Get(key, nil)
	require.NoError(t, err)
	require.NotNil(t, first)
	assert.Equal(t, schemas.OpenAI, first.Provider())
	assert.Equal(t, "test-key", first.KeyID())
	assert.False(t, first.IsClosed())

	pool.Return(first)

	// The connection just returned should be reused, not replaced.
	second, err := pool.Get(key, nil)
	require.NoError(t, err)
	require.NotNil(t, second)
	assert.Same(t, first, second)
	pool.Return(second)
}
// TestPoolMaxIdlePerKey checks that Return caps the idle list per key. Three
// connections are checked out for one key and all returned, but with
// MaxIdlePerKey set to 2 only two may remain in pool.idle for that key.
func TestPoolMaxIdlePerKey(t *testing.T) {
	srv := startTestWSServer(t)
	defer srv.Close()

	cfg := &schemas.WSPoolConfig{
		MaxIdlePerKey:                2,
		MaxTotalConnections:          10,
		IdleTimeoutSeconds:           300,
		MaxConnectionLifetimeSeconds: 3600,
	}
	pool := NewPool(cfg)
	defer pool.Close()

	key := PoolKey{
		Provider: schemas.OpenAI,
		KeyID:    "test-key",
		Endpoint: "ws" + strings.TrimPrefix(srv.URL, "http"),
	}

	// Check out one more connection than the idle cap allows.
	const checkedOut = 3
	conns := make([]*UpstreamConn, 0, checkedOut)
	for i := 0; i < checkedOut; i++ {
		c, err := pool.Get(key, nil)
		require.NoError(t, err)
		conns = append(conns, c)
	}

	// Hand every connection back; the pool should keep only MaxIdlePerKey.
	for i := range conns {
		pool.Return(conns[i])
	}

	// Inspect the idle list while holding the pool's own mutex.
	pool.mu.Lock()
	idleCount := len(pool.idle[key])
	pool.mu.Unlock()
	assert.Equal(t, 2, idleCount)
}
// TestPoolClose checks that a closed pool refuses to hand out connections.
// After Close, Get for a key the pool has already served must return an error.
func TestPoolClose(t *testing.T) {
	srv := startTestWSServer(t)
	defer srv.Close()

	// Start from an empty config and let CheckAndSetDefaults fill it in.
	cfg := new(schemas.WSPoolConfig)
	cfg.CheckAndSetDefaults()
	pool := NewPool(cfg)

	endpoint := "ws" + strings.TrimPrefix(srv.URL, "http")
	key := PoolKey{Provider: schemas.OpenAI, KeyID: "test-key", Endpoint: endpoint}

	// Use the pool once (Get followed by Return) before shutting it down.
	c, err := pool.Get(key, nil)
	require.NoError(t, err)
	pool.Return(c)

	pool.Close()

	_, err = pool.Get(key, nil)
	assert.Error(t, err)
}
// TestPoolExpiredConnection checks that an idle connection is not handed out
// again once it has outlived the configured limits. Both
// IdleTimeoutSeconds and MaxConnectionLifetimeSeconds are 1 here; after
// waiting past them, Get must produce a different connection object.
func TestPoolExpiredConnection(t *testing.T) {
	srv := startTestWSServer(t)
	defer srv.Close()

	pool := NewPool(&schemas.WSPoolConfig{
		MaxIdlePerKey:                5,
		MaxTotalConnections:          10,
		IdleTimeoutSeconds:           1,
		MaxConnectionLifetimeSeconds: 1,
	})
	defer pool.Close()

	endpoint := "ws" + strings.TrimPrefix(srv.URL, "http")
	key := PoolKey{Provider: schemas.OpenAI, KeyID: "test-key", Endpoint: endpoint}

	stale, err := pool.Get(key, nil)
	require.NoError(t, err)
	pool.Return(stale)

	// Half a second of slack past the one-second limits.
	const expiryWait = 1500 * time.Millisecond
	time.Sleep(expiryWait)

	// The stale connection must not be reused; a new one should be dialed.
	fresh, err := pool.Get(key, nil)
	require.NoError(t, err)
	require.NotNil(t, fresh)
	assert.NotSame(t, stale, fresh)
	pool.Discard(fresh)
}