first commit
This commit is contained in:
160
transports/bifrost-http/websocket/pool_test.go
Normal file
160
transports/bifrost-http/websocket/pool_test.go
Normal file
@@ -0,0 +1,160 @@
|
||||
package websocket
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
ws "github.com/fasthttp/websocket"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func startTestWSServer(t *testing.T) *httptest.Server {
|
||||
t.Helper()
|
||||
upgrader := ws.Upgrader{
|
||||
CheckOrigin: func(r *http.Request) bool { return true },
|
||||
}
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
for {
|
||||
mt, msg, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
conn.WriteMessage(mt, msg)
|
||||
}
|
||||
}))
|
||||
return server
|
||||
}
|
||||
|
||||
func TestPoolGetAndReturn(t *testing.T) {
|
||||
server := startTestWSServer(t)
|
||||
defer server.Close()
|
||||
|
||||
config := &schemas.WSPoolConfig{
|
||||
MaxIdlePerKey: 5,
|
||||
MaxTotalConnections: 10,
|
||||
IdleTimeoutSeconds: 300,
|
||||
MaxConnectionLifetimeSeconds: 3600,
|
||||
}
|
||||
pool := NewPool(config)
|
||||
defer pool.Close()
|
||||
|
||||
wsURL := "ws" + strings.TrimPrefix(server.URL, "http")
|
||||
key := PoolKey{Provider: schemas.OpenAI, KeyID: "test-key", Endpoint: wsURL}
|
||||
|
||||
// Get a new connection (pool is empty, should dial)
|
||||
conn, err := pool.Get(key, nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, conn)
|
||||
assert.Equal(t, schemas.OpenAI, conn.Provider())
|
||||
assert.Equal(t, "test-key", conn.KeyID())
|
||||
assert.False(t, conn.IsClosed())
|
||||
|
||||
// Return to pool
|
||||
pool.Return(conn)
|
||||
|
||||
// Get again — should reuse the same connection
|
||||
conn2, err := pool.Get(key, nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, conn2)
|
||||
assert.Same(t, conn, conn2)
|
||||
pool.Return(conn2)
|
||||
}
|
||||
|
||||
func TestPoolMaxIdlePerKey(t *testing.T) {
|
||||
server := startTestWSServer(t)
|
||||
defer server.Close()
|
||||
|
||||
config := &schemas.WSPoolConfig{
|
||||
MaxIdlePerKey: 2,
|
||||
MaxTotalConnections: 10,
|
||||
IdleTimeoutSeconds: 300,
|
||||
MaxConnectionLifetimeSeconds: 3600,
|
||||
}
|
||||
pool := NewPool(config)
|
||||
defer pool.Close()
|
||||
|
||||
wsURL := "ws" + strings.TrimPrefix(server.URL, "http")
|
||||
key := PoolKey{Provider: schemas.OpenAI, KeyID: "test-key", Endpoint: wsURL}
|
||||
|
||||
// Get 3 connections
|
||||
var conns []*UpstreamConn
|
||||
for range 3 {
|
||||
conn, err := pool.Get(key, nil)
|
||||
require.NoError(t, err)
|
||||
conns = append(conns, conn)
|
||||
}
|
||||
|
||||
// Return all 3 — only 2 should be kept (MaxIdlePerKey=2)
|
||||
for _, conn := range conns {
|
||||
pool.Return(conn)
|
||||
}
|
||||
|
||||
pool.mu.Lock()
|
||||
idleCount := len(pool.idle[key])
|
||||
pool.mu.Unlock()
|
||||
|
||||
assert.Equal(t, 2, idleCount)
|
||||
}
|
||||
|
||||
func TestPoolClose(t *testing.T) {
|
||||
server := startTestWSServer(t)
|
||||
defer server.Close()
|
||||
|
||||
config := &schemas.WSPoolConfig{}
|
||||
config.CheckAndSetDefaults()
|
||||
pool := NewPool(config)
|
||||
|
||||
wsURL := "ws" + strings.TrimPrefix(server.URL, "http")
|
||||
key := PoolKey{Provider: schemas.OpenAI, KeyID: "test-key", Endpoint: wsURL}
|
||||
|
||||
conn, err := pool.Get(key, nil)
|
||||
require.NoError(t, err)
|
||||
pool.Return(conn)
|
||||
|
||||
pool.Close()
|
||||
|
||||
// Getting from a closed pool should fail
|
||||
_, err = pool.Get(key, nil)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestPoolExpiredConnection(t *testing.T) {
|
||||
server := startTestWSServer(t)
|
||||
defer server.Close()
|
||||
|
||||
config := &schemas.WSPoolConfig{
|
||||
MaxIdlePerKey: 5,
|
||||
MaxTotalConnections: 10,
|
||||
IdleTimeoutSeconds: 1,
|
||||
MaxConnectionLifetimeSeconds: 1,
|
||||
}
|
||||
pool := NewPool(config)
|
||||
defer pool.Close()
|
||||
|
||||
wsURL := "ws" + strings.TrimPrefix(server.URL, "http")
|
||||
key := PoolKey{Provider: schemas.OpenAI, KeyID: "test-key", Endpoint: wsURL}
|
||||
|
||||
conn, err := pool.Get(key, nil)
|
||||
require.NoError(t, err)
|
||||
pool.Return(conn)
|
||||
|
||||
// Wait for connection to expire
|
||||
time.Sleep(1500 * time.Millisecond)
|
||||
|
||||
// Get should dial a new connection (old one expired)
|
||||
conn2, err := pool.Get(key, nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, conn2)
|
||||
assert.NotSame(t, conn, conn2)
|
||||
pool.Discard(conn2)
|
||||
}
|
||||
Reference in New Issue
Block a user