Files
bifrost/transports/bifrost-http/websocket/pool.go
Beyhan Oğur 880f412e2c first commit
2026-04-26 21:52:23 +03:00

222 lines
4.5 KiB
Go

package websocket
import (
"fmt"
"sync"
"time"
"github.com/maximhq/bifrost/core/schemas"
)
// PoolKey uniquely identifies a group of upstream connections.
type PoolKey struct {
Provider schemas.ModelProvider
KeyID string
Endpoint string
}
// Pool manages a pool of upstream WebSocket connections keyed by (provider, keyID, endpoint).
// Idle connections are cached for reuse. Connections exceeding max lifetime are discarded.
type Pool struct {
mu sync.Mutex
idle map[PoolKey][]*UpstreamConn
inFlight int
config *schemas.WSPoolConfig
closed bool
done chan struct{}
}
// NewPool creates a new upstream WebSocket connection pool.
func NewPool(config *schemas.WSPoolConfig) *Pool {
if config == nil {
config = &schemas.WSPoolConfig{}
}
config.CheckAndSetDefaults()
p := &Pool{
idle: make(map[PoolKey][]*UpstreamConn),
config: config,
done: make(chan struct{}),
}
go p.evictLoop()
return p
}
// Get retrieves an idle connection for the given key, or dials a new one.
// The returned connection is removed from the idle pool and must be returned
// via Return or discarded via Discard.
func (p *Pool) Get(key PoolKey, headers map[string]string) (*UpstreamConn, error) {
p.mu.Lock()
if p.closed {
p.mu.Unlock()
return nil, fmt.Errorf("pool is closed")
}
conns := p.idle[key]
for len(conns) > 0 {
// Pop from the back (most recently returned)
conn := conns[len(conns)-1]
conns = conns[:len(conns)-1]
p.idle[key] = conns
p.mu.Unlock()
if conn.IsClosed() || p.isExpired(conn) {
conn.Close()
p.mu.Lock()
conns = p.idle[key]
continue
}
p.mu.Lock()
p.inFlight++
p.mu.Unlock()
return conn, nil
}
// Check total capacity (idle + in-flight) before dialing
totalIdle := 0
for _, c := range p.idle {
totalIdle += len(c)
}
if totalIdle+p.inFlight >= p.config.MaxTotalConnections {
p.mu.Unlock()
return nil, fmt.Errorf("pool capacity exhausted: %d idle + %d in-flight >= %d max", totalIdle, p.inFlight, p.config.MaxTotalConnections)
}
// Reserve a slot before unlocking to dial
p.inFlight++
p.mu.Unlock()
conn, err := p.dial(key, headers)
if err != nil {
p.mu.Lock()
p.inFlight--
p.mu.Unlock()
return nil, err
}
return conn, nil
}
// Return puts a connection back into the idle pool for reuse.
// If the connection is expired or the pool is full, it is closed instead.
func (p *Pool) Return(conn *UpstreamConn) {
if conn == nil || conn.IsClosed() {
return
}
if p.isExpired(conn) {
conn.Close()
p.mu.Lock()
p.inFlight--
p.mu.Unlock()
return
}
key := PoolKey{
Provider: conn.provider,
KeyID: conn.keyID,
Endpoint: conn.endpoint,
}
p.mu.Lock()
defer p.mu.Unlock()
p.inFlight--
if p.closed {
conn.Close()
return
}
conns := p.idle[key]
if len(conns) >= p.config.MaxIdlePerKey {
conn.Close()
return
}
p.idle[key] = append(conns, conn)
}
// Discard closes a connection without returning it to the pool.
func (p *Pool) Discard(conn *UpstreamConn) {
if conn != nil {
conn.Close()
p.mu.Lock()
p.inFlight--
p.mu.Unlock()
}
}
// Close shuts down the pool and closes all idle connections.
func (p *Pool) Close() {
p.mu.Lock()
defer p.mu.Unlock()
if p.closed {
return
}
p.closed = true
close(p.done)
for key, conns := range p.idle {
for _, conn := range conns {
conn.Close()
}
delete(p.idle, key)
}
}
func (p *Pool) dial(key PoolKey, headers map[string]string) (*UpstreamConn, error) {
wsConn, _, err := Dial(key.Endpoint, headers)
if err != nil {
return nil, fmt.Errorf("failed to dial upstream websocket %s: %w", key.Endpoint, err)
}
return newUpstreamConn(wsConn, key.Provider, key.KeyID, key.Endpoint), nil
}
func (p *Pool) isExpired(conn *UpstreamConn) bool {
maxLifetime := time.Duration(p.config.MaxConnectionLifetimeSeconds) * time.Second
if conn.Age() >= maxLifetime {
return true
}
idleTimeout := time.Duration(p.config.IdleTimeoutSeconds) * time.Second
return time.Since(conn.LastUsed()) >= idleTimeout
}
// evictLoop periodically removes expired idle connections.
func (p *Pool) evictLoop() {
ticker := time.NewTicker(30 * time.Second)
defer ticker.Stop()
for {
select {
case <-p.done:
return
case <-ticker.C:
p.evictExpired()
}
}
}
func (p *Pool) evictExpired() {
p.mu.Lock()
defer p.mu.Unlock()
for key, conns := range p.idle {
alive := conns[:0]
for _, conn := range conns {
if conn.IsClosed() || p.isExpired(conn) {
conn.Close()
} else {
alive = append(alive, conn)
}
}
if len(alive) == 0 {
delete(p.idle, key)
} else {
p.idle[key] = alive
}
}
}