first commit
This commit is contained in:
221
transports/bifrost-http/websocket/pool.go
Normal file
221
transports/bifrost-http/websocket/pool.go
Normal file
@@ -0,0 +1,221 @@
|
||||
package websocket
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// PoolKey uniquely identifies a group of upstream connections.
|
||||
type PoolKey struct {
|
||||
Provider schemas.ModelProvider
|
||||
KeyID string
|
||||
Endpoint string
|
||||
}
|
||||
|
||||
// Pool manages a pool of upstream WebSocket connections keyed by (provider, keyID, endpoint).
|
||||
// Idle connections are cached for reuse. Connections exceeding max lifetime are discarded.
|
||||
type Pool struct {
|
||||
mu sync.Mutex
|
||||
idle map[PoolKey][]*UpstreamConn
|
||||
inFlight int
|
||||
|
||||
config *schemas.WSPoolConfig
|
||||
|
||||
closed bool
|
||||
done chan struct{}
|
||||
}
|
||||
|
||||
// NewPool creates a new upstream WebSocket connection pool.
|
||||
func NewPool(config *schemas.WSPoolConfig) *Pool {
|
||||
if config == nil {
|
||||
config = &schemas.WSPoolConfig{}
|
||||
}
|
||||
config.CheckAndSetDefaults()
|
||||
p := &Pool{
|
||||
idle: make(map[PoolKey][]*UpstreamConn),
|
||||
config: config,
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
go p.evictLoop()
|
||||
return p
|
||||
}
|
||||
|
||||
// Get retrieves an idle connection for the given key, or dials a new one.
|
||||
// The returned connection is removed from the idle pool and must be returned
|
||||
// via Return or discarded via Discard.
|
||||
func (p *Pool) Get(key PoolKey, headers map[string]string) (*UpstreamConn, error) {
|
||||
p.mu.Lock()
|
||||
if p.closed {
|
||||
p.mu.Unlock()
|
||||
return nil, fmt.Errorf("pool is closed")
|
||||
}
|
||||
|
||||
conns := p.idle[key]
|
||||
for len(conns) > 0 {
|
||||
// Pop from the back (most recently returned)
|
||||
conn := conns[len(conns)-1]
|
||||
conns = conns[:len(conns)-1]
|
||||
p.idle[key] = conns
|
||||
|
||||
p.mu.Unlock()
|
||||
|
||||
if conn.IsClosed() || p.isExpired(conn) {
|
||||
conn.Close()
|
||||
p.mu.Lock()
|
||||
conns = p.idle[key]
|
||||
continue
|
||||
}
|
||||
|
||||
p.mu.Lock()
|
||||
p.inFlight++
|
||||
p.mu.Unlock()
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// Check total capacity (idle + in-flight) before dialing
|
||||
totalIdle := 0
|
||||
for _, c := range p.idle {
|
||||
totalIdle += len(c)
|
||||
}
|
||||
if totalIdle+p.inFlight >= p.config.MaxTotalConnections {
|
||||
p.mu.Unlock()
|
||||
return nil, fmt.Errorf("pool capacity exhausted: %d idle + %d in-flight >= %d max", totalIdle, p.inFlight, p.config.MaxTotalConnections)
|
||||
}
|
||||
|
||||
// Reserve a slot before unlocking to dial
|
||||
p.inFlight++
|
||||
p.mu.Unlock()
|
||||
|
||||
conn, err := p.dial(key, headers)
|
||||
if err != nil {
|
||||
p.mu.Lock()
|
||||
p.inFlight--
|
||||
p.mu.Unlock()
|
||||
return nil, err
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// Return puts a connection back into the idle pool for reuse.
|
||||
// If the connection is expired or the pool is full, it is closed instead.
|
||||
func (p *Pool) Return(conn *UpstreamConn) {
|
||||
if conn == nil || conn.IsClosed() {
|
||||
return
|
||||
}
|
||||
if p.isExpired(conn) {
|
||||
conn.Close()
|
||||
p.mu.Lock()
|
||||
p.inFlight--
|
||||
p.mu.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
key := PoolKey{
|
||||
Provider: conn.provider,
|
||||
KeyID: conn.keyID,
|
||||
Endpoint: conn.endpoint,
|
||||
}
|
||||
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
p.inFlight--
|
||||
|
||||
if p.closed {
|
||||
conn.Close()
|
||||
return
|
||||
}
|
||||
|
||||
conns := p.idle[key]
|
||||
if len(conns) >= p.config.MaxIdlePerKey {
|
||||
conn.Close()
|
||||
return
|
||||
}
|
||||
|
||||
p.idle[key] = append(conns, conn)
|
||||
}
|
||||
|
||||
// Discard closes a connection without returning it to the pool.
|
||||
func (p *Pool) Discard(conn *UpstreamConn) {
|
||||
if conn != nil {
|
||||
conn.Close()
|
||||
p.mu.Lock()
|
||||
p.inFlight--
|
||||
p.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// Close shuts down the pool and closes all idle connections.
|
||||
func (p *Pool) Close() {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
if p.closed {
|
||||
return
|
||||
}
|
||||
p.closed = true
|
||||
close(p.done)
|
||||
|
||||
for key, conns := range p.idle {
|
||||
for _, conn := range conns {
|
||||
conn.Close()
|
||||
}
|
||||
delete(p.idle, key)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Pool) dial(key PoolKey, headers map[string]string) (*UpstreamConn, error) {
|
||||
wsConn, _, err := Dial(key.Endpoint, headers)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to dial upstream websocket %s: %w", key.Endpoint, err)
|
||||
}
|
||||
return newUpstreamConn(wsConn, key.Provider, key.KeyID, key.Endpoint), nil
|
||||
}
|
||||
|
||||
func (p *Pool) isExpired(conn *UpstreamConn) bool {
|
||||
maxLifetime := time.Duration(p.config.MaxConnectionLifetimeSeconds) * time.Second
|
||||
if conn.Age() >= maxLifetime {
|
||||
return true
|
||||
}
|
||||
idleTimeout := time.Duration(p.config.IdleTimeoutSeconds) * time.Second
|
||||
return time.Since(conn.LastUsed()) >= idleTimeout
|
||||
}
|
||||
|
||||
// evictLoop periodically removes expired idle connections.
|
||||
func (p *Pool) evictLoop() {
|
||||
ticker := time.NewTicker(30 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-p.done:
|
||||
return
|
||||
case <-ticker.C:
|
||||
p.evictExpired()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Pool) evictExpired() {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
for key, conns := range p.idle {
|
||||
alive := conns[:0]
|
||||
for _, conn := range conns {
|
||||
if conn.IsClosed() || p.isExpired(conn) {
|
||||
conn.Close()
|
||||
} else {
|
||||
alive = append(alive, conn)
|
||||
}
|
||||
}
|
||||
if len(alive) == 0 {
|
||||
delete(p.idle, key)
|
||||
} else {
|
||||
p.idle[key] = alive
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user