222 lines
4.5 KiB
Go
222 lines
4.5 KiB
Go
package websocket
|
|
|
|
import (
|
|
"fmt"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/maximhq/bifrost/core/schemas"
|
|
)
|
|
|
|
// PoolKey uniquely identifies a group of upstream connections.
|
|
type PoolKey struct {
|
|
Provider schemas.ModelProvider
|
|
KeyID string
|
|
Endpoint string
|
|
}
|
|
|
|
// Pool manages a pool of upstream WebSocket connections keyed by (provider, keyID, endpoint).
|
|
// Idle connections are cached for reuse. Connections exceeding max lifetime are discarded.
|
|
type Pool struct {
|
|
mu sync.Mutex
|
|
idle map[PoolKey][]*UpstreamConn
|
|
inFlight int
|
|
|
|
config *schemas.WSPoolConfig
|
|
|
|
closed bool
|
|
done chan struct{}
|
|
}
|
|
|
|
// NewPool creates a new upstream WebSocket connection pool.
|
|
func NewPool(config *schemas.WSPoolConfig) *Pool {
|
|
if config == nil {
|
|
config = &schemas.WSPoolConfig{}
|
|
}
|
|
config.CheckAndSetDefaults()
|
|
p := &Pool{
|
|
idle: make(map[PoolKey][]*UpstreamConn),
|
|
config: config,
|
|
done: make(chan struct{}),
|
|
}
|
|
go p.evictLoop()
|
|
return p
|
|
}
|
|
|
|
// Get retrieves an idle connection for the given key, or dials a new one.
|
|
// The returned connection is removed from the idle pool and must be returned
|
|
// via Return or discarded via Discard.
|
|
func (p *Pool) Get(key PoolKey, headers map[string]string) (*UpstreamConn, error) {
|
|
p.mu.Lock()
|
|
if p.closed {
|
|
p.mu.Unlock()
|
|
return nil, fmt.Errorf("pool is closed")
|
|
}
|
|
|
|
conns := p.idle[key]
|
|
for len(conns) > 0 {
|
|
// Pop from the back (most recently returned)
|
|
conn := conns[len(conns)-1]
|
|
conns = conns[:len(conns)-1]
|
|
p.idle[key] = conns
|
|
|
|
p.mu.Unlock()
|
|
|
|
if conn.IsClosed() || p.isExpired(conn) {
|
|
conn.Close()
|
|
p.mu.Lock()
|
|
conns = p.idle[key]
|
|
continue
|
|
}
|
|
|
|
p.mu.Lock()
|
|
p.inFlight++
|
|
p.mu.Unlock()
|
|
return conn, nil
|
|
}
|
|
|
|
// Check total capacity (idle + in-flight) before dialing
|
|
totalIdle := 0
|
|
for _, c := range p.idle {
|
|
totalIdle += len(c)
|
|
}
|
|
if totalIdle+p.inFlight >= p.config.MaxTotalConnections {
|
|
p.mu.Unlock()
|
|
return nil, fmt.Errorf("pool capacity exhausted: %d idle + %d in-flight >= %d max", totalIdle, p.inFlight, p.config.MaxTotalConnections)
|
|
}
|
|
|
|
// Reserve a slot before unlocking to dial
|
|
p.inFlight++
|
|
p.mu.Unlock()
|
|
|
|
conn, err := p.dial(key, headers)
|
|
if err != nil {
|
|
p.mu.Lock()
|
|
p.inFlight--
|
|
p.mu.Unlock()
|
|
return nil, err
|
|
}
|
|
return conn, nil
|
|
}
|
|
|
|
// Return puts a connection back into the idle pool for reuse.
|
|
// If the connection is expired or the pool is full, it is closed instead.
|
|
func (p *Pool) Return(conn *UpstreamConn) {
|
|
if conn == nil || conn.IsClosed() {
|
|
return
|
|
}
|
|
if p.isExpired(conn) {
|
|
conn.Close()
|
|
p.mu.Lock()
|
|
p.inFlight--
|
|
p.mu.Unlock()
|
|
return
|
|
}
|
|
|
|
key := PoolKey{
|
|
Provider: conn.provider,
|
|
KeyID: conn.keyID,
|
|
Endpoint: conn.endpoint,
|
|
}
|
|
|
|
p.mu.Lock()
|
|
defer p.mu.Unlock()
|
|
|
|
p.inFlight--
|
|
|
|
if p.closed {
|
|
conn.Close()
|
|
return
|
|
}
|
|
|
|
conns := p.idle[key]
|
|
if len(conns) >= p.config.MaxIdlePerKey {
|
|
conn.Close()
|
|
return
|
|
}
|
|
|
|
p.idle[key] = append(conns, conn)
|
|
}
|
|
|
|
// Discard closes a connection without returning it to the pool.
|
|
func (p *Pool) Discard(conn *UpstreamConn) {
|
|
if conn != nil {
|
|
conn.Close()
|
|
p.mu.Lock()
|
|
p.inFlight--
|
|
p.mu.Unlock()
|
|
}
|
|
}
|
|
|
|
// Close shuts down the pool and closes all idle connections.
|
|
func (p *Pool) Close() {
|
|
p.mu.Lock()
|
|
defer p.mu.Unlock()
|
|
|
|
if p.closed {
|
|
return
|
|
}
|
|
p.closed = true
|
|
close(p.done)
|
|
|
|
for key, conns := range p.idle {
|
|
for _, conn := range conns {
|
|
conn.Close()
|
|
}
|
|
delete(p.idle, key)
|
|
}
|
|
}
|
|
|
|
func (p *Pool) dial(key PoolKey, headers map[string]string) (*UpstreamConn, error) {
|
|
wsConn, _, err := Dial(key.Endpoint, headers)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to dial upstream websocket %s: %w", key.Endpoint, err)
|
|
}
|
|
return newUpstreamConn(wsConn, key.Provider, key.KeyID, key.Endpoint), nil
|
|
}
|
|
|
|
func (p *Pool) isExpired(conn *UpstreamConn) bool {
|
|
maxLifetime := time.Duration(p.config.MaxConnectionLifetimeSeconds) * time.Second
|
|
if conn.Age() >= maxLifetime {
|
|
return true
|
|
}
|
|
idleTimeout := time.Duration(p.config.IdleTimeoutSeconds) * time.Second
|
|
return time.Since(conn.LastUsed()) >= idleTimeout
|
|
}
|
|
|
|
// evictLoop periodically removes expired idle connections.
|
|
func (p *Pool) evictLoop() {
|
|
ticker := time.NewTicker(30 * time.Second)
|
|
defer ticker.Stop()
|
|
|
|
for {
|
|
select {
|
|
case <-p.done:
|
|
return
|
|
case <-ticker.C:
|
|
p.evictExpired()
|
|
}
|
|
}
|
|
}
|
|
|
|
func (p *Pool) evictExpired() {
|
|
p.mu.Lock()
|
|
defer p.mu.Unlock()
|
|
|
|
for key, conns := range p.idle {
|
|
alive := conns[:0]
|
|
for _, conn := range conns {
|
|
if conn.IsClosed() || p.isExpired(conn) {
|
|
conn.Close()
|
|
} else {
|
|
alive = append(alive, conn)
|
|
}
|
|
}
|
|
if len(alive) == 0 {
|
|
delete(p.idle, key)
|
|
} else {
|
|
p.idle[key] = alive
|
|
}
|
|
}
|
|
}
|