first commit
This commit is contained in:
136
core/providers/utils/audio.go
Normal file
136
core/providers/utils/audio.go
Normal file
@@ -0,0 +1,136 @@
|
||||
// Package utils provides common utility functions used across different provider implementations.
|
||||
// This file contains audio-related utility functions for format conversion.
|
||||
package utils
|
||||
|
||||
import (
	"bytes"
	"encoding/binary"
	"fmt"
)
|
||||
|
||||
// PCMConfig describes the layout of a raw PCM audio stream.
type PCMConfig struct {
	SampleRate    int // Samples per second, in Hz (e.g., 24000).
	NumChannels   int // Number of audio channels (1 = mono, 2 = stereo).
	BitsPerSample int // Width of a single sample in bits (e.g., 16).
}

// DefaultGeminiPCMConfig returns the default PCM configuration for Gemini TTS:
// signed 16-bit little-endian (s16le) samples at 24000 Hz, mono.
func DefaultGeminiPCMConfig() PCMConfig {
	return PCMConfig{
		SampleRate:    24000,
		NumChannels:   1,
		BitsPerSample: 16,
	}
}

// ConvertPCMToWAV wraps raw PCM audio data in a canonical 44-byte WAV header
// (RIFF header, 16-byte PCM "fmt " chunk, "data" chunk) and returns the
// resulting file contents. pcmData is copied verbatim after the header; it is
// expected to already be little-endian PCM matching config.
//
// An error is returned when config cannot be represented in a WAV header
// (non-positive values, or values that overflow the 16/32-bit header fields)
// or when pcmData is too large for the 32-bit RIFF size fields. The returned
// slice is nil whenever the error is non-nil.
func ConvertPCMToWAV(pcmData []byte, config PCMConfig) ([]byte, error) {
	const (
		headerSize   = 44 // RIFF header (12) + fmt chunk (24) + data chunk header (8)
		riffOverhead = 36 // header bytes counted by the RIFF chunk size field
		maxUint16    = 0xFFFF
		maxUint32    = 0xFFFFFFFF
	)

	if config.SampleRate <= 0 || config.NumChannels <= 0 || config.BitsPerSample <= 0 {
		return nil, fmt.Errorf(
			"invalid PCM config: sample rate (%d), channels (%d) and bits per sample (%d) must be positive",
			config.SampleRate, config.NumChannels, config.BitsPerSample)
	}

	// Work in uint64 so the range checks themselves cannot overflow.
	sampleRate := uint64(config.SampleRate)
	channels := uint64(config.NumChannels)
	bits := uint64(config.BitsPerSample)
	if sampleRate > maxUint32 || channels > maxUint16 || bits > maxUint16 {
		return nil, fmt.Errorf(
			"invalid PCM config: sample rate (%d), channels (%d) or bits per sample (%d) exceeds WAV header limits",
			config.SampleRate, config.NumChannels, config.BitsPerSample)
	}

	byteRate := sampleRate * channels * bits / 8
	blockAlign := channels * bits / 8
	if byteRate > maxUint32 || blockAlign > maxUint16 {
		return nil, fmt.Errorf(
			"invalid PCM config: byte rate (%d) or block align (%d) exceeds WAV header limits",
			byteRate, blockAlign)
	}

	dataSize := uint64(len(pcmData))
	if dataSize > maxUint32-riffOverhead {
		return nil, fmt.Errorf("PCM data too large for WAV container: %d bytes", len(pcmData))
	}

	// Allocate the exact output size once and fill the header in place.
	wav := make([]byte, headerSize, headerSize+len(pcmData))
	le := binary.LittleEndian

	// RIFF header
	copy(wav[0:4], "RIFF")
	le.PutUint32(wav[4:8], uint32(riffOverhead+dataSize))
	copy(wav[8:12], "WAVE")

	// fmt subchunk
	copy(wav[12:16], "fmt ")
	le.PutUint32(wav[16:20], 16) // Subchunk1Size (16 for PCM)
	le.PutUint16(wav[20:22], 1)  // AudioFormat (1 = PCM)
	le.PutUint16(wav[22:24], uint16(channels))
	le.PutUint32(wav[24:28], uint32(sampleRate))
	le.PutUint32(wav[28:32], uint32(byteRate))
	le.PutUint16(wav[32:34], uint16(blockAlign))
	le.PutUint16(wav[34:36], uint16(bits))

	// data subchunk
	copy(wav[36:40], "data")
	le.PutUint32(wav[40:44], uint32(dataSize))

	return append(wav, pcmData...), nil
}
|
||||
|
||||
// Magic byte sequences checked by DetectAudioMimeType.
var (
	riff = []byte("RIFF") // container tag at offset 0 of a WAV file
	wave = []byte("WAVE") // form type at offset 8 of a WAV file
	id3  = []byte("ID3")  // ID3v2 tag prefix (MP3)
	form = []byte("FORM") // container tag at offset 0 of an AIFF/AIFC file
	aiff = []byte("AIFF") // form type at offset 8 of an AIFF file
	aifc = []byte("AIFC") // form type at offset 8 of an AIFC file
	flac = []byte("fLaC") // FLAC stream marker
	oggs = []byte("OggS") // Ogg page marker
	adif = []byte("ADIF") // AAC ADIF header marker
)

// DetectAudioMimeType sniffs the leading bytes of audioData and returns the
// matching MIME type: "audio/wav", "audio/mp3", "audio/aac", "audio/aiff",
// "audio/flac" or "audio/ogg". Inputs shorter than four bytes and inputs that
// match none of the known signatures are reported as "audio/mp3".
func DetectAudioMimeType(audioData []byte) string {
	const fallback = "audio/mp3"
	if len(audioData) < 4 {
		return fallback
	}

	// RIFF and IFF containers carry their form type at bytes 8..11; leave it
	// nil when the input is too short so the comparisons below simply fail.
	var formType []byte
	if len(audioData) >= 12 {
		formType = audioData[8:12]
	}

	switch {
	case bytes.HasPrefix(audioData, riff) && bytes.Equal(formType, wave):
		return "audio/wav"
	case bytes.HasPrefix(audioData, id3):
		// MP3 starting with an ID3v2 tag.
		return "audio/mp3"
	case bytes.HasPrefix(audioData, adif):
		return "audio/aac"
	case audioData[0] == 0xFF && audioData[1]&0xF6 == 0xF0:
		// AAC ADTS sync word; must be tested before treating 0xFF as an
		// MPEG audio frame so AAC is not misclassified as MP3.
		return "audio/aac"
	case bytes.HasPrefix(audioData, form) &&
		(bytes.Equal(formType, aiff) || bytes.Equal(formType, aifc)):
		// AIFF and AIFC both map to audio/aiff.
		return "audio/aiff"
	case bytes.HasPrefix(audioData, flac):
		return "audio/flac"
	case bytes.HasPrefix(audioData, oggs):
		return "audio/ogg"
	default:
		// Covers MPEG frame-sync headers (0xFF 0xFB/0xF3/0xF2/0xFA) as well as
		// anything unrecognised: both resolve to MP3.
		return fallback
	}
}
|
||||
|
||||
// AudioFilenameFromBytes returns a filename with the correct extension for the given audio data.
|
||||
// Falls back to "audio.mp3" if the format cannot be detected.
|
||||
func AudioFilenameFromBytes(audioData []byte) string {
|
||||
mimeToExt := map[string]string{
|
||||
"audio/wav": "audio.wav",
|
||||
"audio/aac": "audio.aac",
|
||||
"audio/aiff": "audio.aiff",
|
||||
"audio/flac": "audio.flac",
|
||||
"audio/ogg": "audio.ogg",
|
||||
}
|
||||
mime := DetectAudioMimeType(audioData)
|
||||
if ext, ok := mimeToExt[mime]; ok {
|
||||
return ext
|
||||
}
|
||||
return "audio.mp3"
|
||||
}
|
||||
192
core/providers/utils/decompression.go
Normal file
192
core/providers/utils/decompression.go
Normal file
@@ -0,0 +1,192 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"compress/gzip"
|
||||
"compress/zlib"
|
||||
"io"
|
||||
"sync"
|
||||
|
||||
"github.com/andybalholm/brotli"
|
||||
"github.com/klauspost/compress/zstd"
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Pooled decompression readers
|
||||
//
|
||||
// Each encoding gets a sync.Pool with Acquire/Release helpers that follow the
|
||||
// same contract:
|
||||
// - Acquire(r io.Reader) returns a ready-to-read decompressor, reusing a
|
||||
// pooled instance when possible, falling back to a fresh allocation.
|
||||
// - Release returns the decompressor to the pool for future reuse.
|
||||
// Callers MUST call Release exactly once after the reader is fully consumed.
|
||||
//
|
||||
// All pool operations are panic-safe: type assertions use the comma-ok form,
|
||||
// Reset calls are wrapped in recover, and nil checks guard every dereference.
|
||||
// A corrupt or wrong-typed pooled instance is silently discarded (GC reclaims
|
||||
// it) and a fresh allocation takes its place.
|
||||
//
|
||||
// Gzip, deflate, and brotli readers are stateless between streams — Close (if
|
||||
// applicable) then Reset is safe. Zstd decoders run background goroutines so
|
||||
// Close is terminal; pooled decoders are reset without closing.
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// safeReset runs resetFn and reports whether it completed without returning
// an error and without panicking. Any panic raised by resetFn is recovered
// here so that a misbehaving pooled reader cannot take the process down.
func safeReset(resetFn func() error) (ok bool) {
	defer func() {
		if recovered := recover(); recovered != nil {
			// The named result is overwritten from the deferred call.
			ok = false
		}
	}()
	err := resetFn()
	ok = err == nil
	return ok
}
|
||||
|
||||
// ---- gzip ----
|
||||
|
||||
// gzipReaderPool holds reusable *gzip.Reader values. New hands out a
// zero-value reader, which must be Reset before it is read from.
var gzipReaderPool = sync.Pool{
	New: func() any { return new(gzip.Reader) },
}
|
||||
|
||||
// AcquireGzipReader gets a gzip.Reader from the pool and resets it to read from r,
|
||||
// or creates a new one if the pool is empty or reset fails.
|
||||
func AcquireGzipReader(r io.Reader) (*gzip.Reader, error) {
|
||||
if v := gzipReaderPool.Get(); v != nil {
|
||||
if gz, ok := v.(*gzip.Reader); ok {
|
||||
if safeReset(func() error { return gz.Reset(r) }) {
|
||||
return gz, nil
|
||||
}
|
||||
}
|
||||
// Wrong type or reset failed/panicked — discard, let GC reclaim.
|
||||
}
|
||||
return gzip.NewReader(r)
|
||||
}
|
||||
|
||||
// ReleaseGzipReader closes and returns a gzip.Reader to the pool.
|
||||
func ReleaseGzipReader(gz *gzip.Reader) {
|
||||
if gz == nil {
|
||||
return
|
||||
}
|
||||
_ = gz.Close()
|
||||
gzipReaderPool.Put(gz)
|
||||
}
|
||||
|
||||
// ---- deflate ----
|
||||
//
|
||||
// HTTP Content-Encoding "deflate" is zlib-wrapped DEFLATE (RFC 1950), NOT raw
|
||||
// DEFLATE (RFC 1951). This matches fasthttp's implementation and the HTTP spec
|
||||
// (RFC 9110 §8.4.1.2). We use compress/zlib, not compress/flate.
|
||||
|
||||
// deflateReader is the method set a zlib reader must offer to be pooled:
// reading, closing, and being re-targeted at a new stream via Reset. The
// concrete type returned by zlib.NewReader is unexported, so it is handled
// through this interface.
type deflateReader interface {
	io.ReadCloser
	zlib.Resetter
}

// deflateReaderPool deliberately has no New func: zlib.NewReader parses the
// stream header as soon as it is called, so a reader cannot be built without
// a real input. The pool is filled only by ReleaseFlateReader.
var deflateReaderPool sync.Pool
|
||||
|
||||
// AcquireFlateReader gets a zlib (HTTP "deflate") reader from the pool and
|
||||
// resets it to read from r, or creates a new one if the pool is empty or
|
||||
// reset fails.
|
||||
func AcquireFlateReader(r io.Reader) (io.ReadCloser, error) {
|
||||
if v := deflateReaderPool.Get(); v != nil {
|
||||
if dr, ok := v.(deflateReader); ok {
|
||||
if safeReset(func() error { return dr.Reset(r, nil) }) {
|
||||
return dr, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
return zlib.NewReader(r)
|
||||
}
|
||||
|
||||
// ReleaseFlateReader closes and returns a deflate reader to the pool.
|
||||
func ReleaseFlateReader(fr io.ReadCloser) {
|
||||
if fr == nil {
|
||||
return
|
||||
}
|
||||
_ = fr.Close()
|
||||
deflateReaderPool.Put(fr)
|
||||
}
|
||||
|
||||
// ---- brotli ----
|
||||
|
||||
var brotliReaderPool = sync.Pool{
|
||||
New: func() any {
|
||||
return brotli.NewReader(nil)
|
||||
},
|
||||
}
|
||||
|
||||
// AcquireBrotliReader gets a brotli.Reader from the pool and resets it to read
|
||||
// from r, or creates a new one if the pool is empty or reset panics.
|
||||
func AcquireBrotliReader(r io.Reader) *brotli.Reader {
|
||||
if v := brotliReaderPool.Get(); v != nil {
|
||||
if br, ok := v.(*brotli.Reader); ok {
|
||||
// brotli.Reset is void (no error), but wrap in safeReset for
|
||||
// consistency: a corrupt pooled reader could panic on Reset.
|
||||
if safeReset(func() error { br.Reset(r); return nil }) {
|
||||
return br
|
||||
}
|
||||
}
|
||||
// Wrong type or reset panicked — discard, let GC reclaim.
|
||||
}
|
||||
return brotli.NewReader(r)
|
||||
}
|
||||
|
||||
// ReleaseBrotliReader returns a brotli.Reader to the pool.
|
||||
// Brotli readers have no Close method; Reset(nil) is sufficient to drop the
|
||||
// reference to the underlying reader.
|
||||
func ReleaseBrotliReader(br *brotli.Reader) {
|
||||
if br == nil {
|
||||
return
|
||||
}
|
||||
br.Reset(nil)
|
||||
brotliReaderPool.Put(br)
|
||||
}
|
||||
|
||||
// ---- zstd ----
|
||||
|
||||
var zstdDecoderPool = sync.Pool{
|
||||
New: func() any {
|
||||
dec, err := zstd.NewReader(nil, zstd.WithDecoderConcurrency(1))
|
||||
if err != nil {
|
||||
// NewReader(nil) failing is unexpected; return nil so Acquire
|
||||
// falls through to a fresh allocation with the real reader.
|
||||
return nil
|
||||
}
|
||||
return dec
|
||||
},
|
||||
}
|
||||
|
||||
// AcquireZstdDecoder gets a zstd.Decoder from the pool and resets it to read
|
||||
// from r, or creates a new one if the pool is empty or reset fails.
|
||||
// Decoders are created with concurrency=1 to minimise goroutine overhead.
|
||||
func AcquireZstdDecoder(r io.Reader) (*zstd.Decoder, error) {
|
||||
if v := zstdDecoderPool.Get(); v != nil {
|
||||
if dec, ok := v.(*zstd.Decoder); ok && dec != nil {
|
||||
if safeReset(func() error { return dec.Reset(r) }) {
|
||||
return dec, nil
|
||||
}
|
||||
// Reset failed/panicked — release references before discarding.
|
||||
// Don't call Close (terminal); Reset(nil) is safe per pool contract.
|
||||
_ = dec.Reset(nil)
|
||||
}
|
||||
}
|
||||
return zstd.NewReader(r, zstd.WithDecoderConcurrency(1))
|
||||
}
|
||||
|
||||
// ReleaseZstdDecoder returns a zstd.Decoder to the pool.
|
||||
// Unlike other decoders, zstd.Close() is terminal (stops background goroutines
|
||||
// permanently). We only call Reset(nil) to release the source reference, then
|
||||
// re-pool. Close is never called on pooled decoders.
|
||||
func ReleaseZstdDecoder(dec *zstd.Decoder) {
|
||||
if dec == nil {
|
||||
return
|
||||
}
|
||||
_ = dec.Reset(nil)
|
||||
zstdDecoderPool.Put(dec)
|
||||
}
|
||||
516
core/providers/utils/decompression_test.go
Normal file
516
core/providers/utils/decompression_test.go
Normal file
@@ -0,0 +1,516 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"compress/zlib"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/andybalholm/brotli"
|
||||
"github.com/klauspost/compress/zstd"
|
||||
)
|
||||
|
||||
// poolTestIterations is how many Acquire/Release cycles each pool test runs.
const poolTestIterations = 10

// testPayload is the plaintext every compression helper and pool test
// round-trips.
var testPayload = []byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hello world"}]}`)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// helpers — one compressor per encoding
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// compressGzip returns data encoded as a single gzip stream. It panics if the
// writer reports an error, which keeps call sites in tests free of error
// handling.
func compressGzip(data []byte) []byte {
	out := new(bytes.Buffer)
	zw := gzip.NewWriter(out)
	_, err := zw.Write(data)
	if err != nil {
		panic(fmt.Errorf("gzip write: %w", err))
	}
	// Close flushes the remaining compressed bytes and the gzip footer.
	err = zw.Close()
	if err != nil {
		panic(fmt.Errorf("gzip close: %w", err))
	}
	return out.Bytes()
}
|
||||
|
||||
// compressFlate returns data as zlib-wrapped DEFLATE (RFC 1950), the format
// HTTP Content-Encoding "deflate" refers to (RFC 9110 §8.4.1.2). It panics if
// the writer reports an error.
func compressFlate(data []byte) []byte {
	out := new(bytes.Buffer)
	zw, err := zlib.NewWriterLevel(out, zlib.DefaultCompression)
	if err != nil {
		panic(fmt.Errorf("zlib new writer: %w", err))
	}
	_, err = zw.Write(data)
	if err != nil {
		panic(fmt.Errorf("zlib write: %w", err))
	}
	// Close flushes the remaining compressed bytes and the checksum trailer.
	err = zw.Close()
	if err != nil {
		panic(fmt.Errorf("zlib close: %w", err))
	}
	return out.Bytes()
}
|
||||
|
||||
func compressBrotli(data []byte) []byte {
|
||||
var buf bytes.Buffer
|
||||
w := brotli.NewWriter(&buf)
|
||||
if _, err := w.Write(data); err != nil {
|
||||
panic(fmt.Errorf("brotli write: %w", err))
|
||||
}
|
||||
if err := w.Close(); err != nil {
|
||||
panic(fmt.Errorf("brotli close: %w", err))
|
||||
}
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
func compressZstd(data []byte) []byte {
|
||||
var buf bytes.Buffer
|
||||
enc, err := zstd.NewWriter(&buf)
|
||||
if err != nil {
|
||||
panic(fmt.Errorf("zstd new writer: %w", err))
|
||||
}
|
||||
if _, err := enc.Write(data); err != nil {
|
||||
panic(fmt.Errorf("zstd write: %w", err))
|
||||
}
|
||||
if err := enc.Close(); err != nil {
|
||||
panic(fmt.Errorf("zstd close: %w", err))
|
||||
}
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Pool cycle tests — each runs Acquire → ReadAll → Release N times to verify
|
||||
// that pooled instances are reused correctly across iterations.
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestAcquireReleaseGzipReader(t *testing.T) {
|
||||
compressed := compressGzip(testPayload)
|
||||
for i := 0; i < poolTestIterations; i++ {
|
||||
gz, err := AcquireGzipReader(bytes.NewReader(compressed))
|
||||
if err != nil {
|
||||
t.Fatalf("iteration %d: AcquireGzipReader error: %v", i, err)
|
||||
}
|
||||
got, err := io.ReadAll(gz)
|
||||
if err != nil {
|
||||
t.Fatalf("iteration %d: ReadAll error: %v", i, err)
|
||||
}
|
||||
if !bytes.Equal(got, testPayload) {
|
||||
t.Errorf("iteration %d: got %q, want %q", i, got, testPayload)
|
||||
}
|
||||
ReleaseGzipReader(gz)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAcquireReleaseFlateReader(t *testing.T) {
|
||||
compressed := compressFlate(testPayload)
|
||||
for i := 0; i < poolTestIterations; i++ {
|
||||
fr, err := AcquireFlateReader(bytes.NewReader(compressed))
|
||||
if err != nil {
|
||||
t.Fatalf("iteration %d: AcquireFlateReader error: %v", i, err)
|
||||
}
|
||||
got, err := io.ReadAll(fr)
|
||||
if err != nil {
|
||||
t.Fatalf("iteration %d: ReadAll error: %v", i, err)
|
||||
}
|
||||
if !bytes.Equal(got, testPayload) {
|
||||
t.Errorf("iteration %d: got %q, want %q", i, got, testPayload)
|
||||
}
|
||||
ReleaseFlateReader(fr)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAcquireReleaseBrotliReader(t *testing.T) {
|
||||
compressed := compressBrotli(testPayload)
|
||||
for i := 0; i < poolTestIterations; i++ {
|
||||
br := AcquireBrotliReader(bytes.NewReader(compressed))
|
||||
got, err := io.ReadAll(br)
|
||||
if err != nil {
|
||||
t.Fatalf("iteration %d: ReadAll error: %v", i, err)
|
||||
}
|
||||
if !bytes.Equal(got, testPayload) {
|
||||
t.Errorf("iteration %d: got %q, want %q", i, got, testPayload)
|
||||
}
|
||||
ReleaseBrotliReader(br)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAcquireReleaseZstdDecoder(t *testing.T) {
|
||||
compressed := compressZstd(testPayload)
|
||||
for i := 0; i < poolTestIterations; i++ {
|
||||
dec, err := AcquireZstdDecoder(bytes.NewReader(compressed))
|
||||
if err != nil {
|
||||
t.Fatalf("iteration %d: AcquireZstdDecoder error: %v", i, err)
|
||||
}
|
||||
got, err := io.ReadAll(dec)
|
||||
if err != nil {
|
||||
t.Fatalf("iteration %d: ReadAll error: %v", i, err)
|
||||
}
|
||||
if !bytes.Equal(got, testPayload) {
|
||||
t.Errorf("iteration %d: got %q, want %q", i, got, testPayload)
|
||||
}
|
||||
ReleaseZstdDecoder(dec)
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Panic-safety tests — verify that corrupt or wrong-typed pooled instances
|
||||
// are handled gracefully (no panics, fallback to fresh allocation).
|
||||
//
|
||||
// Each poison test drains the target pool first so stale values from earlier
|
||||
// tests cannot interfere. sync.Pool.Get offers no ordering guarantee, so
|
||||
// draining ensures the poisoned value is the only item available.
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// drainPool empties p by calling Get until it returns nil. p.New is cleared
// for the duration of the loop, since Get would otherwise keep producing
// fresh values, and is put back before drainPool returns.
func drainPool(p *sync.Pool) {
	savedNew := p.New
	defer func() { p.New = savedNew }()

	p.New = nil
	for {
		if p.Get() == nil {
			return
		}
	}
}
|
||||
|
||||
// TestPool_WrongType_NoPanic poisons each pool with a wrong-typed value,
|
||||
// then verifies Acquire still succeeds by falling back to a fresh allocation.
|
||||
func TestPool_WrongType_NoPanic(t *testing.T) {
|
||||
wrongValue := "not a reader"
|
||||
compressed := compressGzip(testPayload)
|
||||
|
||||
t.Run("gzip", func(t *testing.T) {
|
||||
drainPool(&gzipReaderPool)
|
||||
gzipReaderPool.Put(wrongValue)
|
||||
gz, err := AcquireGzipReader(bytes.NewReader(compressed))
|
||||
if err != nil {
|
||||
t.Fatalf("AcquireGzipReader should fall back, got error: %v", err)
|
||||
}
|
||||
got, _ := io.ReadAll(gz)
|
||||
if !bytes.Equal(got, testPayload) {
|
||||
t.Errorf("got %q, want %q", got, testPayload)
|
||||
}
|
||||
ReleaseGzipReader(gz)
|
||||
})
|
||||
|
||||
t.Run("deflate", func(t *testing.T) {
|
||||
drainPool(&deflateReaderPool)
|
||||
deflateReaderPool.Put(wrongValue)
|
||||
fr, err := AcquireFlateReader(bytes.NewReader(compressFlate(testPayload)))
|
||||
if err != nil {
|
||||
t.Fatalf("AcquireFlateReader should fall back, got error: %v", err)
|
||||
}
|
||||
got, _ := io.ReadAll(fr)
|
||||
if !bytes.Equal(got, testPayload) {
|
||||
t.Errorf("got %q, want %q", got, testPayload)
|
||||
}
|
||||
ReleaseFlateReader(fr)
|
||||
})
|
||||
|
||||
t.Run("brotli", func(t *testing.T) {
|
||||
drainPool(&brotliReaderPool)
|
||||
brotliReaderPool.Put(wrongValue)
|
||||
br := AcquireBrotliReader(bytes.NewReader(compressBrotli(testPayload)))
|
||||
got, err := io.ReadAll(br)
|
||||
if err != nil {
|
||||
t.Fatalf("ReadAll error: %v", err)
|
||||
}
|
||||
if !bytes.Equal(got, testPayload) {
|
||||
t.Errorf("got %q, want %q", got, testPayload)
|
||||
}
|
||||
ReleaseBrotliReader(br)
|
||||
})
|
||||
|
||||
t.Run("zstd", func(t *testing.T) {
|
||||
drainPool(&zstdDecoderPool)
|
||||
zstdDecoderPool.Put(wrongValue)
|
||||
dec, err := AcquireZstdDecoder(bytes.NewReader(compressZstd(testPayload)))
|
||||
if err != nil {
|
||||
t.Fatalf("AcquireZstdDecoder should fall back, got error: %v", err)
|
||||
}
|
||||
got, _ := io.ReadAll(dec)
|
||||
if !bytes.Equal(got, testPayload) {
|
||||
t.Errorf("got %q, want %q", got, testPayload)
|
||||
}
|
||||
ReleaseZstdDecoder(dec)
|
||||
})
|
||||
}
|
||||
|
||||
// TestPool_NilInPool_NoPanic puts an explicit nil into each pool, then
|
||||
// verifies Acquire still succeeds by falling back to a fresh allocation.
|
||||
func TestPool_NilInPool_NoPanic(t *testing.T) {
|
||||
t.Run("gzip", func(t *testing.T) {
|
||||
drainPool(&gzipReaderPool)
|
||||
gzipReaderPool.Put((*gzip.Reader)(nil))
|
||||
gz, err := AcquireGzipReader(bytes.NewReader(compressGzip(testPayload)))
|
||||
if err != nil {
|
||||
t.Fatalf("should fall back, got error: %v", err)
|
||||
}
|
||||
got, _ := io.ReadAll(gz)
|
||||
if !bytes.Equal(got, testPayload) {
|
||||
t.Errorf("got %q, want %q", got, testPayload)
|
||||
}
|
||||
ReleaseGzipReader(gz)
|
||||
})
|
||||
|
||||
t.Run("zstd", func(t *testing.T) {
|
||||
drainPool(&zstdDecoderPool)
|
||||
zstdDecoderPool.Put((*zstd.Decoder)(nil))
|
||||
dec, err := AcquireZstdDecoder(bytes.NewReader(compressZstd(testPayload)))
|
||||
if err != nil {
|
||||
t.Fatalf("should fall back, got error: %v", err)
|
||||
}
|
||||
got, _ := io.ReadAll(dec)
|
||||
if !bytes.Equal(got, testPayload) {
|
||||
t.Errorf("got %q, want %q", got, testPayload)
|
||||
}
|
||||
ReleaseZstdDecoder(dec)
|
||||
})
|
||||
}
|
||||
|
||||
// TestPool_InvalidData_NoPanic verifies that Acquire handles corrupt/invalid
|
||||
// compressed data without panicking. The error should surface as a return
|
||||
// value or a read error, never a panic.
|
||||
func TestPool_InvalidData_NoPanic(t *testing.T) {
|
||||
garbage := []byte("this is not compressed data at all")
|
||||
|
||||
t.Run("gzip", func(t *testing.T) {
|
||||
// gzip.NewReader validates the header immediately, so Acquire returns error.
|
||||
_, err := AcquireGzipReader(bytes.NewReader(garbage))
|
||||
if err == nil {
|
||||
t.Fatal("expected error for invalid gzip data")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("deflate", func(t *testing.T) {
|
||||
// zlib.NewReader validates the header eagerly, so Acquire returns error.
|
||||
_, err := AcquireFlateReader(bytes.NewReader(garbage))
|
||||
if err == nil {
|
||||
t.Fatal("expected error for invalid deflate (zlib) data")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("brotli", func(t *testing.T) {
|
||||
// brotli doesn't validate eagerly; error surfaces on Read.
|
||||
br := AcquireBrotliReader(bytes.NewReader(garbage))
|
||||
_, readErr := io.ReadAll(br)
|
||||
if readErr == nil {
|
||||
t.Fatal("expected read error for invalid brotli data")
|
||||
}
|
||||
ReleaseBrotliReader(br)
|
||||
})
|
||||
|
||||
t.Run("zstd", func(t *testing.T) {
|
||||
// zstd.Decoder doesn't validate eagerly; error surfaces on Read.
|
||||
dec, err := AcquireZstdDecoder(bytes.NewReader(garbage))
|
||||
if err != nil {
|
||||
t.Fatalf("AcquireZstdDecoder should not error eagerly: %v", err)
|
||||
}
|
||||
_, readErr := io.ReadAll(dec)
|
||||
if readErr == nil {
|
||||
t.Fatal("expected read error for invalid zstd data")
|
||||
}
|
||||
ReleaseZstdDecoder(dec)
|
||||
})
|
||||
}
|
||||
|
||||
// TestPool_CorruptPooledInstance_NoPanic simulates a corrupt pooled reader
|
||||
// whose Reset panics. Verifies safeReset catches the panic and Acquire
|
||||
// falls through to a fresh allocation.
|
||||
func TestPool_CorruptPooledInstance_NoPanic(t *testing.T) {
|
||||
t.Run("safeReset_catches_panic", func(t *testing.T) {
|
||||
ok := safeReset(func() error {
|
||||
panic("simulated corrupt reader")
|
||||
})
|
||||
if ok {
|
||||
t.Fatal("safeReset should return false when resetFn panics")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("gzip_after_corrupt", func(t *testing.T) {
|
||||
// Poison pool with a gzip.Reader that has corrupt internal state:
|
||||
// a zero-value reader that has been closed without ever being used.
|
||||
drainPool(&gzipReaderPool)
|
||||
corrupt := &gzip.Reader{}
|
||||
gzipReaderPool.Put(corrupt)
|
||||
|
||||
// Acquire should recover, discard the corrupt reader, and create fresh.
|
||||
compressed := compressGzip(testPayload)
|
||||
gz, err := AcquireGzipReader(bytes.NewReader(compressed))
|
||||
if err != nil {
|
||||
t.Fatalf("should recover from corrupt pooled reader, got: %v", err)
|
||||
}
|
||||
got, _ := io.ReadAll(gz)
|
||||
if !bytes.Equal(got, testPayload) {
|
||||
t.Errorf("got %q, want %q", got, testPayload)
|
||||
}
|
||||
ReleaseGzipReader(gz)
|
||||
})
|
||||
}
|
||||
|
||||
// TestRelease_Nil_NoPanic verifies Release functions are safe to call with nil.
|
||||
func TestRelease_Nil_NoPanic(t *testing.T) {
|
||||
ReleaseGzipReader(nil)
|
||||
ReleaseFlateReader(nil)
|
||||
ReleaseBrotliReader(nil)
|
||||
ReleaseZstdDecoder(nil)
|
||||
}
|
||||
|
||||
// TestPool_RecoveryAndReuse verifies that after a corrupt instance is discarded,
|
||||
// the pool recovers and subsequent cycles work normally.
|
||||
func TestPool_RecoveryAndReuse(t *testing.T) {
|
||||
// Drain then poison all pools with wrong types.
|
||||
drainPool(&gzipReaderPool)
|
||||
drainPool(&deflateReaderPool)
|
||||
drainPool(&brotliReaderPool)
|
||||
drainPool(&zstdDecoderPool)
|
||||
gzipReaderPool.Put("wrong")
|
||||
deflateReaderPool.Put(42)
|
||||
brotliReaderPool.Put(struct{}{})
|
||||
zstdDecoderPool.Put(false)
|
||||
|
||||
// Run a normal Acquire → ReadAll → Release cycle for each.
|
||||
// This verifies the pool recovers: the wrong-typed value is discarded,
|
||||
// a fresh instance is created and released back, making the pool healthy.
|
||||
t.Run("gzip", func(t *testing.T) {
|
||||
compressed := compressGzip(testPayload)
|
||||
for i := 0; i < 3; i++ {
|
||||
gz, err := AcquireGzipReader(bytes.NewReader(compressed))
|
||||
if err != nil {
|
||||
t.Fatalf("iteration %d: %v", i, err)
|
||||
}
|
||||
got, _ := io.ReadAll(gz)
|
||||
if !bytes.Equal(got, testPayload) {
|
||||
t.Errorf("iteration %d: mismatch", i)
|
||||
}
|
||||
ReleaseGzipReader(gz)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("deflate", func(t *testing.T) {
|
||||
compressed := compressFlate(testPayload)
|
||||
for i := 0; i < 3; i++ {
|
||||
fr, err := AcquireFlateReader(bytes.NewReader(compressed))
|
||||
if err != nil {
|
||||
t.Fatalf("iteration %d: %v", i, err)
|
||||
}
|
||||
got, _ := io.ReadAll(fr)
|
||||
if !bytes.Equal(got, testPayload) {
|
||||
t.Errorf("iteration %d: mismatch", i)
|
||||
}
|
||||
ReleaseFlateReader(fr)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("brotli", func(t *testing.T) {
|
||||
compressed := compressBrotli(testPayload)
|
||||
for i := 0; i < 3; i++ {
|
||||
br := AcquireBrotliReader(bytes.NewReader(compressed))
|
||||
got, err := io.ReadAll(br)
|
||||
if err != nil {
|
||||
t.Fatalf("iteration %d: %v", i, err)
|
||||
}
|
||||
if !bytes.Equal(got, testPayload) {
|
||||
t.Errorf("iteration %d: mismatch", i)
|
||||
}
|
||||
ReleaseBrotliReader(br)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("zstd", func(t *testing.T) {
|
||||
compressed := compressZstd(testPayload)
|
||||
for i := 0; i < 3; i++ {
|
||||
dec, err := AcquireZstdDecoder(bytes.NewReader(compressed))
|
||||
if err != nil {
|
||||
t.Fatalf("iteration %d: %v", i, err)
|
||||
}
|
||||
got, _ := io.ReadAll(dec)
|
||||
if !bytes.Equal(got, testPayload) {
|
||||
t.Errorf("iteration %d: mismatch", i)
|
||||
}
|
||||
ReleaseZstdDecoder(dec)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestPool_EmptyReader_NoPanic verifies Acquire handles an empty reader
|
||||
// (zero bytes) without panicking. Gzip/zstd should return an error (no header),
|
||||
// deflate/brotli should return empty or error on Read.
|
||||
func TestPool_EmptyReader_NoPanic(t *testing.T) {
|
||||
empty := bytes.NewReader(nil)
|
||||
|
||||
t.Run("gzip", func(t *testing.T) {
|
||||
_, err := AcquireGzipReader(empty)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for empty gzip input")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("deflate", func(t *testing.T) {
|
||||
// zlib.NewReader validates the header eagerly; empty input has no header.
|
||||
_, err := AcquireFlateReader(bytes.NewReader(nil))
|
||||
if err == nil {
|
||||
t.Fatal("expected error for empty deflate (zlib) input")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("brotli", func(t *testing.T) {
|
||||
br := AcquireBrotliReader(bytes.NewReader(nil))
|
||||
data, _ := io.ReadAll(br)
|
||||
if len(data) != 0 {
|
||||
t.Errorf("expected empty output, got %d bytes", len(data))
|
||||
}
|
||||
ReleaseBrotliReader(br)
|
||||
})
|
||||
|
||||
t.Run("zstd", func(t *testing.T) {
|
||||
dec, err := AcquireZstdDecoder(bytes.NewReader(nil))
|
||||
if err != nil {
|
||||
// Some versions error eagerly, that's fine.
|
||||
return
|
||||
}
|
||||
data, _ := io.ReadAll(dec)
|
||||
// Empty input with no zstd frame yields empty output or read error.
|
||||
_ = data
|
||||
ReleaseZstdDecoder(dec)
|
||||
})
|
||||
}
|
||||
|
||||
// TestSafeReset verifies safeReset correctly handles panics and errors.
|
||||
func TestSafeReset(t *testing.T) {
|
||||
t.Run("success", func(t *testing.T) {
|
||||
ok := safeReset(func() error { return nil })
|
||||
if !ok {
|
||||
t.Fatal("expected true for successful reset")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("error", func(t *testing.T) {
|
||||
ok := safeReset(func() error { return io.ErrUnexpectedEOF })
|
||||
if ok {
|
||||
t.Fatal("expected false for failed reset")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("panic_string", func(t *testing.T) {
|
||||
ok := safeReset(func() error { panic("boom") })
|
||||
if ok {
|
||||
t.Fatal("expected false for panicking reset")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("panic_nonnil", func(t *testing.T) {
|
||||
ok := safeReset(func() error { panic("") })
|
||||
if ok {
|
||||
t.Fatal("expected false for empty-string panic")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("panic_error", func(t *testing.T) {
|
||||
ok := safeReset(func() error {
|
||||
panic(fmt.Errorf("internal corruption: %s", strings.Repeat("x", 100)))
|
||||
})
|
||||
if ok {
|
||||
t.Fatal("expected false for error panic")
|
||||
}
|
||||
})
|
||||
}
|
||||
359
core/providers/utils/dialer_test.go
Normal file
359
core/providers/utils/dialer_test.go
Normal file
@@ -0,0 +1,359 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/maximhq/bifrost/core/network"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// TestConfigureDialer_SetsRetryIfErr verifies that ConfigureDialer installs
|
||||
// the StaleConnectionRetryIfErr callback on the client.
|
||||
func TestConfigureDialer_SetsRetryIfErr(t *testing.T) {
|
||||
client := &fasthttp.Client{}
|
||||
if client.RetryIfErr != nil {
|
||||
t.Fatal("precondition: RetryIfErr should be nil on a new client")
|
||||
}
|
||||
|
||||
ConfigureDialer(client)
|
||||
|
||||
if client.RetryIfErr == nil {
|
||||
t.Fatal("ConfigureDialer should set RetryIfErr")
|
||||
}
|
||||
|
||||
// Verify it behaves like StaleConnectionRetryIfErr
|
||||
reset, retry := client.RetryIfErr(nil, 1, fmt.Errorf("cannot find whitespace in the first line of response"))
|
||||
if !reset || !retry {
|
||||
t.Error("RetryIfErr should retry on whitespace error")
|
||||
}
|
||||
reset, retry = client.RetryIfErr(nil, 1, fmt.Errorf("dial tcp: no such host"))
|
||||
if reset || retry {
|
||||
t.Error("RetryIfErr should not retry on unrelated errors")
|
||||
}
|
||||
}
|
||||
|
||||
// TestConfigureDialer_SetsDial verifies that ConfigureDialer installs a custom
|
||||
// Dial function on the client when no existing Dial is present.
|
||||
func TestConfigureDialer_SetsDial(t *testing.T) {
|
||||
client := &fasthttp.Client{}
|
||||
if client.Dial != nil {
|
||||
t.Fatal("precondition: Dial should be nil on a new client")
|
||||
}
|
||||
|
||||
ConfigureDialer(client)
|
||||
|
||||
if client.Dial == nil {
|
||||
t.Fatal("ConfigureDialer should set a Dial function")
|
||||
}
|
||||
}
|
||||
|
||||
// TestConfigureDialer_ComposesWithExistingDial verifies that when a custom Dial
|
||||
// function is already set (e.g., from ConfigureProxy), ConfigureDialer wraps it
|
||||
// and still enables TCP keepalive on the resulting connection.
|
||||
func TestConfigureDialer_ComposesWithExistingDial(t *testing.T) {
|
||||
var proxyDialCalled atomic.Bool
|
||||
|
||||
client := &fasthttp.Client{}
|
||||
// Simulate a proxy dial function (set by ConfigureProxy)
|
||||
client.Dial = func(addr string) (net.Conn, error) {
|
||||
proxyDialCalled.Store(true)
|
||||
return net.Dial("tcp", addr)
|
||||
}
|
||||
|
||||
ConfigureDialer(client)
|
||||
|
||||
// Start a test server to connect to
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
fmt.Fprint(w, "ok")
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
req := fasthttp.AcquireRequest()
|
||||
resp := fasthttp.AcquireResponse()
|
||||
defer fasthttp.ReleaseRequest(req)
|
||||
defer fasthttp.ReleaseResponse(resp)
|
||||
|
||||
req.SetRequestURI(server.URL)
|
||||
req.Header.SetMethod(http.MethodGet)
|
||||
|
||||
if err := client.Do(req, resp); err != nil {
|
||||
t.Fatalf("request failed: %v", err)
|
||||
}
|
||||
if resp.StatusCode() != 200 {
|
||||
t.Fatalf("expected 200, got %d", resp.StatusCode())
|
||||
}
|
||||
if !proxyDialCalled.Load() {
|
||||
t.Error("ConfigureDialer should have called the existing proxy dial function")
|
||||
}
|
||||
}
|
||||
|
||||
// TestConfigureDialer_TCPKeepAliveEnabled verifies that connections created
|
||||
// through ConfigureDialer have TCP keepalive enabled.
|
||||
func TestConfigureDialer_TCPKeepAliveEnabled(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
fmt.Fprint(w, "ok")
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
// Test without existing dial (direct connection path)
|
||||
t.Run("without_existing_dial", func(t *testing.T) {
|
||||
client := &fasthttp.Client{}
|
||||
ConfigureDialer(client)
|
||||
|
||||
// The Dial function should create connections with keepalive
|
||||
// We can verify by making a connection and checking the TCP options
|
||||
req := fasthttp.AcquireRequest()
|
||||
resp := fasthttp.AcquireResponse()
|
||||
defer fasthttp.ReleaseRequest(req)
|
||||
defer fasthttp.ReleaseResponse(resp)
|
||||
|
||||
req.SetRequestURI(server.URL)
|
||||
req.Header.SetMethod(http.MethodGet)
|
||||
|
||||
if err := client.Do(req, resp); err != nil {
|
||||
t.Fatalf("request failed: %v", err)
|
||||
}
|
||||
if resp.StatusCode() != 200 {
|
||||
t.Fatalf("expected 200, got %d", resp.StatusCode())
|
||||
}
|
||||
})
|
||||
|
||||
// Test with existing dial (proxy composition path)
|
||||
t.Run("with_existing_dial", func(t *testing.T) {
|
||||
var connFromProxy net.Conn
|
||||
client := &fasthttp.Client{}
|
||||
client.Dial = func(addr string) (net.Conn, error) {
|
||||
conn, err := net.Dial("tcp", addr)
|
||||
connFromProxy = conn
|
||||
return conn, err
|
||||
}
|
||||
ConfigureDialer(client)
|
||||
|
||||
req := fasthttp.AcquireRequest()
|
||||
resp := fasthttp.AcquireResponse()
|
||||
defer fasthttp.ReleaseRequest(req)
|
||||
defer fasthttp.ReleaseResponse(resp)
|
||||
|
||||
req.SetRequestURI(server.URL)
|
||||
req.Header.SetMethod(http.MethodGet)
|
||||
|
||||
if err := client.Do(req, resp); err != nil {
|
||||
t.Fatalf("request failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify the proxy-returned connection is a TCP connection
|
||||
// (ConfigureDialer enables keepalive via SetKeepAliveConfig on it)
|
||||
if connFromProxy == nil {
|
||||
t.Fatal("proxy dial should have been called")
|
||||
}
|
||||
if _, ok := connFromProxy.(*net.TCPConn); !ok {
|
||||
t.Errorf("expected *net.TCPConn, got %T", connFromProxy)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestConfigureDialer_ReturnValue verifies that ConfigureDialer returns the
|
||||
// same client pointer it received (for chaining).
|
||||
func TestConfigureDialer_ReturnValue(t *testing.T) {
|
||||
client := &fasthttp.Client{}
|
||||
result := ConfigureDialer(client)
|
||||
if result != client {
|
||||
t.Error("ConfigureDialer should return the same client pointer")
|
||||
}
|
||||
}
|
||||
|
||||
// TestConfigureDialer_Idempotent verifies that calling ConfigureDialer multiple
|
||||
// times doesn't break the client.
|
||||
func TestConfigureDialer_Idempotent(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
fmt.Fprint(w, "ok")
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := &fasthttp.Client{}
|
||||
ConfigureDialer(client)
|
||||
ConfigureDialer(client) // called again
|
||||
|
||||
req := fasthttp.AcquireRequest()
|
||||
resp := fasthttp.AcquireResponse()
|
||||
defer fasthttp.ReleaseRequest(req)
|
||||
defer fasthttp.ReleaseResponse(resp)
|
||||
|
||||
req.SetRequestURI(server.URL)
|
||||
req.Header.SetMethod(http.MethodPost)
|
||||
req.SetBodyString(`{"test": true}`)
|
||||
|
||||
if err := client.Do(req, resp); err != nil {
|
||||
t.Fatalf("request failed after double ConfigureDialer: %v", err)
|
||||
}
|
||||
if resp.StatusCode() != 200 {
|
||||
t.Fatalf("expected 200, got %d", resp.StatusCode())
|
||||
}
|
||||
}
|
||||
|
||||
// TestConfigureDialer_WithRetryOnStaleConnection is an integration test that
|
||||
// verifies ConfigureDialer enables successful POST retry after TTL mismatch.
|
||||
// This combines both the retry and keepalive behaviors.
|
||||
func TestConfigureDialer_WithRetryOnStaleConnection(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping TTL mismatch test in short mode (requires 11s wait)")
|
||||
}
|
||||
|
||||
const (
|
||||
serverIdleTimeout = 10 * time.Second
|
||||
clientIdleTimeout = 15 * time.Second
|
||||
waitBetween = 11 * time.Second
|
||||
)
|
||||
|
||||
var requestCount atomic.Int32
|
||||
|
||||
server := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
requestCount.Add(1)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
fmt.Fprintf(w, `{"ok": true, "request": %d}`, requestCount.Load())
|
||||
}))
|
||||
server.Config.IdleTimeout = serverIdleTimeout
|
||||
server.Start()
|
||||
defer server.Close()
|
||||
|
||||
client := &fasthttp.Client{
|
||||
MaxIdleConnDuration: clientIdleTimeout,
|
||||
MaxConnsPerHost: 10,
|
||||
}
|
||||
// Use ConfigureDialer (the function under test) instead of manually setting RetryIfErr
|
||||
ConfigureDialer(client)
|
||||
|
||||
// First request: establish connection in pool
|
||||
req := fasthttp.AcquireRequest()
|
||||
resp := fasthttp.AcquireResponse()
|
||||
defer fasthttp.ReleaseRequest(req)
|
||||
defer fasthttp.ReleaseResponse(resp)
|
||||
|
||||
req.SetRequestURI(server.URL)
|
||||
req.Header.SetMethod(http.MethodPost)
|
||||
req.SetBodyString(`{"prompt": "hello"}`)
|
||||
|
||||
if err := client.Do(req, resp); err != nil {
|
||||
t.Fatalf("First POST failed: %v", err)
|
||||
}
|
||||
if resp.StatusCode() != 200 {
|
||||
t.Fatalf("First POST: expected 200, got %d", resp.StatusCode())
|
||||
}
|
||||
_ = resp.Body()
|
||||
|
||||
// Wait for server TTL to expire
|
||||
t.Logf("Waiting %v for server idle timeout to expire...", waitBetween)
|
||||
time.Sleep(waitBetween)
|
||||
|
||||
// Second request: stale connection should be retried by ConfigureDialer's retry policy
|
||||
req2 := fasthttp.AcquireRequest()
|
||||
resp2 := fasthttp.AcquireResponse()
|
||||
defer fasthttp.ReleaseRequest(req2)
|
||||
defer fasthttp.ReleaseResponse(resp2)
|
||||
|
||||
req2.SetRequestURI(server.URL)
|
||||
req2.Header.SetMethod(http.MethodPost)
|
||||
req2.SetBodyString(`{"prompt": "world"}`)
|
||||
|
||||
if err := client.Do(req2, resp2); err != nil {
|
||||
t.Fatalf("Second POST failed (ConfigureDialer retry should have saved it): %v", err)
|
||||
}
|
||||
if resp2.StatusCode() != 200 {
|
||||
t.Fatalf("Second POST: expected 200, got %d", resp2.StatusCode())
|
||||
}
|
||||
t.Logf("Second POST succeeded after TTL mismatch via ConfigureDialer")
|
||||
}
|
||||
|
||||
// TestConfigureRetry_Deprecated verifies the deprecated ConfigureRetry still works.
|
||||
func TestConfigureRetry_Deprecated(t *testing.T) {
|
||||
client := &fasthttp.Client{}
|
||||
result := ConfigureRetry(client)
|
||||
|
||||
if result != client {
|
||||
t.Error("ConfigureRetry should return the same client pointer")
|
||||
}
|
||||
if client.RetryIfErr == nil {
|
||||
t.Fatal("ConfigureRetry should set RetryIfErr")
|
||||
}
|
||||
|
||||
// Verify it uses the same StaleConnectionRetryIfErr
|
||||
reset, retry := client.RetryIfErr(nil, 1, fmt.Errorf("cannot find whitespace"))
|
||||
if !reset || !retry {
|
||||
t.Error("ConfigureRetry should install StaleConnectionRetryIfErr")
|
||||
}
|
||||
}
|
||||
|
||||
// TestConfigureDialer_DialError verifies that dial errors from the existing
|
||||
// dial function are properly propagated (not swallowed).
|
||||
func TestConfigureDialer_DialError(t *testing.T) {
|
||||
expectedErr := fmt.Errorf("proxy connection refused")
|
||||
client := &fasthttp.Client{}
|
||||
client.Dial = func(addr string) (net.Conn, error) {
|
||||
return nil, expectedErr
|
||||
}
|
||||
|
||||
ConfigureDialer(client)
|
||||
|
||||
req := fasthttp.AcquireRequest()
|
||||
resp := fasthttp.AcquireResponse()
|
||||
defer fasthttp.ReleaseRequest(req)
|
||||
defer fasthttp.ReleaseResponse(resp)
|
||||
|
||||
req.SetRequestURI("http://localhost:1/test")
|
||||
req.Header.SetMethod(http.MethodPost)
|
||||
|
||||
err := client.Do(req, resp)
|
||||
if err == nil {
|
||||
t.Fatal("expected error from failed proxy dial")
|
||||
}
|
||||
t.Logf("Got expected error: %v", err)
|
||||
}
|
||||
|
||||
// TestStaleConnectionRetryIfErr_WrappedErrors verifies behavior with wrapped errors.
|
||||
func TestStaleConnectionRetryIfErr_WrappedErrors(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
wantRetry bool
|
||||
}{
|
||||
{
|
||||
name: "wrapped whitespace error",
|
||||
err: fmt.Errorf("fasthttp: %w", fmt.Errorf("cannot find whitespace in header")),
|
||||
wantRetry: true,
|
||||
},
|
||||
{
|
||||
name: "wrapped connection reset",
|
||||
err: fmt.Errorf("during POST: connection reset by peer"),
|
||||
wantRetry: true,
|
||||
},
|
||||
{
|
||||
name: "wrapped broken pipe",
|
||||
err: fmt.Errorf("during POST: %w", fmt.Errorf("write tcp 10.0.0.1:53374->10.0.0.2:30000: write: broken pipe")),
|
||||
wantRetry: true,
|
||||
},
|
||||
{
|
||||
name: "ErrConnectionClosed from fasthttp",
|
||||
err: fasthttp.ErrConnectionClosed,
|
||||
wantRetry: false, // Not matched - this error appears AFTER the retry loop
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, retry := network.StaleConnectionRetryIfErr(nil, 1, tt.err)
|
||||
if retry != tt.wantRetry {
|
||||
t.Errorf("retry = %v, want %v", retry, tt.wantRetry)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
14
core/providers/utils/file.go
Normal file
14
core/providers/utils/file.go
Normal file
@@ -0,0 +1,14 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// FileBytesToBase64DataURL renders raw file bytes as a "data:" URL of the form
// data:<mime>;base64,<payload>. The MIME type is sniffed from the content with
// http.DetectContentType and the payload uses standard (padded) base64.
func FileBytesToBase64DataURL(fileBytes []byte) string {
	return fmt.Sprintf(
		"data:%s;base64,%s",
		http.DetectContentType(fileBytes),
		base64.StdEncoding.EncodeToString(fileBytes),
	)
}
|
||||
302
core/providers/utils/html_response_handler_test.go
Normal file
302
core/providers/utils/html_response_handler_test.go
Normal file
@@ -0,0 +1,302 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
func TestIsHTMLResponse(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
contentType string
|
||||
body []byte
|
||||
expectedIsHTML bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "HTML with Content-Type header",
|
||||
contentType: "text/html; charset=utf-8",
|
||||
body: []byte("<html><body>Error</body></html>"),
|
||||
expectedIsHTML: true,
|
||||
description: "Should detect HTML from Content-Type header",
|
||||
},
|
||||
{
|
||||
name: "HTML without Content-Type",
|
||||
contentType: "application/octet-stream",
|
||||
body: []byte("<!DOCTYPE html><html><head><title>Error 500</title></head></html>"),
|
||||
expectedIsHTML: true,
|
||||
description: "Should detect HTML from DOCTYPE",
|
||||
},
|
||||
{
|
||||
name: "HTML with h1 tag",
|
||||
contentType: "application/octet-stream",
|
||||
body: []byte("<h1>Service Unavailable</h1>"),
|
||||
expectedIsHTML: true,
|
||||
description: "Should detect HTML from h1 tag",
|
||||
},
|
||||
{
|
||||
name: "JSON response",
|
||||
contentType: "application/json",
|
||||
body: []byte(`{"error": "invalid request"}`),
|
||||
expectedIsHTML: false,
|
||||
description: "Should not detect JSON as HTML",
|
||||
},
|
||||
{
|
||||
name: "Plain text response",
|
||||
contentType: "text/plain",
|
||||
body: []byte("Invalid request"),
|
||||
expectedIsHTML: false,
|
||||
description: "Should not detect plain text as HTML",
|
||||
},
|
||||
{
|
||||
name: "Empty body",
|
||||
contentType: "text/html",
|
||||
body: []byte(""),
|
||||
expectedIsHTML: true,
|
||||
description: "Should detect HTML from Content-Type even with empty body",
|
||||
},
|
||||
{
|
||||
name: "Very short body",
|
||||
contentType: "application/json",
|
||||
body: []byte("abc"),
|
||||
expectedIsHTML: false,
|
||||
description: "Should not detect very short body as HTML",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
resp := &fasthttp.Response{}
|
||||
resp.Header.Set("Content-Type", tt.contentType)
|
||||
|
||||
result := IsHTMLResponse(resp, tt.body)
|
||||
if result != tt.expectedIsHTML {
|
||||
t.Errorf("isHTMLResponse() = %v, want %v. %s", result, tt.expectedIsHTML, tt.description)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractHTMLErrorMessage(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
htmlBody []byte
|
||||
expectMsg string
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "Extract from title tag",
|
||||
htmlBody: []byte(`
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head><title>404 Not Found</title></head>
|
||||
<body><p>The page was not found</p></body>
|
||||
</html>
|
||||
`),
|
||||
expectMsg: "404 Not Found",
|
||||
description: "Should extract title from title tag",
|
||||
},
|
||||
{
|
||||
name: "Extract from h1 tag",
|
||||
htmlBody: []byte(`
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<body>
|
||||
<h1>Service Unavailable</h1>
|
||||
<p>The service is currently unavailable</p>
|
||||
</body>
|
||||
</html>
|
||||
`),
|
||||
expectMsg: "Service Unavailable",
|
||||
description: "Should extract from h1 tag when title is missing",
|
||||
},
|
||||
{
|
||||
name: "Extract from h2 tag",
|
||||
htmlBody: []byte(`
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<body>
|
||||
<h2 class="error-header">Authentication Failed</h2>
|
||||
<p>Please check your credentials</p>
|
||||
</body>
|
||||
</html>
|
||||
`),
|
||||
expectMsg: "Authentication Failed",
|
||||
description: "Should extract from h2 tag with attributes",
|
||||
},
|
||||
{
|
||||
name: "Extract visible text when no headers",
|
||||
htmlBody: []byte(`
|
||||
<html>
|
||||
<body>
|
||||
<div>There was an error processing your request. Please try again later.</div>
|
||||
</body>
|
||||
</html>
|
||||
`),
|
||||
expectMsg: "There was an error processing your request. Please try again later.",
|
||||
description: "Should extract visible text from div when no headers found",
|
||||
},
|
||||
{
|
||||
name: "Ignore script and style tags",
|
||||
htmlBody: []byte(`
|
||||
<html>
|
||||
<head><title>Error</title></head>
|
||||
<body>
|
||||
<script>var x = 'ignore me';</script>
|
||||
<style>.error { color: red; }</style>
|
||||
<h1>Actual Error Message</h1>
|
||||
</body>
|
||||
</html>
|
||||
`),
|
||||
expectMsg: "Actual Error Message",
|
||||
description: "Should ignore script and style content",
|
||||
},
|
||||
{
|
||||
name: "Extract from first valid h1",
|
||||
htmlBody: []byte(`
|
||||
<html>
|
||||
<body>
|
||||
<h1></h1>
|
||||
<h1>Second header with actual content</h1>
|
||||
</body>
|
||||
</html>
|
||||
`),
|
||||
expectMsg: "Second header with actual content",
|
||||
description: "Should extract from first non-empty header",
|
||||
},
|
||||
{
|
||||
name: "Handle meta description",
|
||||
htmlBody: []byte(`
|
||||
<html>
|
||||
<head>
|
||||
<meta name="description" content="Rate limit exceeded. Please wait 60 seconds.">
|
||||
</head>
|
||||
<body></body>
|
||||
</html>
|
||||
`),
|
||||
expectMsg: "Rate limit exceeded. Please wait 60 seconds.",
|
||||
description: "Should extract from meta description",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := ExtractHTMLErrorMessage(tt.htmlBody)
|
||||
if result != tt.expectMsg {
|
||||
t.Errorf("extractHTMLErrorMessage() = %q, want %q. %s", result, tt.expectMsg, tt.description)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleProviderAPIErrorWithHTML(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
statusCode int
|
||||
contentType string
|
||||
body []byte
|
||||
description string
|
||||
expectedInMessage string
|
||||
}{
|
||||
{
|
||||
name: "HTML 500 error - lazy detection",
|
||||
statusCode: 500,
|
||||
contentType: "text/html; charset=utf-8",
|
||||
body: []byte(`
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head><title>Internal Server Error</title></head>
|
||||
<body><h1>Something went wrong</h1></body>
|
||||
</html>
|
||||
`),
|
||||
description: "Should detect and handle HTML only after JSON parse fails",
|
||||
expectedInMessage: "HTML response received from provider",
|
||||
},
|
||||
{
|
||||
name: "HTML 403 error - lazy detection",
|
||||
statusCode: 403,
|
||||
contentType: "text/html",
|
||||
body: []byte(`
|
||||
<html>
|
||||
<body>
|
||||
<h1>Forbidden</h1>
|
||||
<p>Access denied</p>
|
||||
</body>
|
||||
</html>
|
||||
`),
|
||||
description: "Should detect HTML on parse failure",
|
||||
expectedInMessage: "HTML response received from provider",
|
||||
},
|
||||
{
|
||||
name: "Invalid JSON with HTML fallback",
|
||||
statusCode: 400,
|
||||
contentType: "application/json",
|
||||
body: []byte(`not valid json`),
|
||||
description: "Should fall back to raw string when not HTML",
|
||||
expectedInMessage: "provider API error",
|
||||
},
|
||||
{
|
||||
name: "Valid JSON error response",
|
||||
statusCode: 400,
|
||||
contentType: "application/json",
|
||||
body: []byte(`{"error": {"message": "Invalid request"}, "code": "invalid_request"}`),
|
||||
description: "Should handle valid JSON without HTML detection",
|
||||
expectedInMessage: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
resp := &fasthttp.Response{}
|
||||
resp.SetStatusCode(tt.statusCode)
|
||||
resp.Header.Set("Content-Type", tt.contentType)
|
||||
resp.SetBody(tt.body)
|
||||
|
||||
var errorResp map[string]interface{}
|
||||
bifrostErr := HandleProviderAPIError(resp, &errorResp)
|
||||
|
||||
if bifrostErr == nil {
|
||||
t.Errorf("HandleProviderAPIError() returned nil error")
|
||||
return
|
||||
}
|
||||
|
||||
if bifrostErr.StatusCode == nil || *bifrostErr.StatusCode != tt.statusCode {
|
||||
t.Errorf("HandleProviderAPIError() status code = %v, want %v", bifrostErr.StatusCode, tt.statusCode)
|
||||
}
|
||||
|
||||
if bifrostErr.Error == nil {
|
||||
t.Errorf("HandleProviderAPIError() error field is nil")
|
||||
return
|
||||
}
|
||||
|
||||
// Check if expected message is in the response
|
||||
if tt.expectedInMessage != "" && !strings.Contains(bifrostErr.Error.Message, tt.expectedInMessage) {
|
||||
t.Errorf("Expected message to contain %q, got %q", tt.expectedInMessage, bifrostErr.Error.Message)
|
||||
}
|
||||
|
||||
t.Logf("Handled %s: status=%d, message=%q", tt.name, *bifrostErr.StatusCode, bifrostErr.Error.Message)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkIsHTMLResponse(b *testing.B) {
|
||||
resp := &fasthttp.Response{}
|
||||
resp.Header.Set("Content-Type", "text/html; charset=utf-8")
|
||||
body := []byte(`<!DOCTYPE html><html><head><title>Error</title></head><body><h1>Test Error</h1></body></html>`)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
IsHTMLResponse(resp, body)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkExtractHTMLErrorMessage(b *testing.B) {
|
||||
body := []byte(`<!DOCTYPE html><html><head><title>Internal Server Error</title></head><body><h1>Something went wrong</h1><p>This is a detailed error message</p></body></html>`)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
ExtractHTMLErrorMessage(body)
|
||||
}
|
||||
}
|
||||
284
core/providers/utils/idle_timeout_reader_test.go
Normal file
284
core/providers/utils/idle_timeout_reader_test.go
Normal file
@@ -0,0 +1,284 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// readCloserSpy is a test double satisfying io.ReadCloser. Its Read always
// reports end-of-stream, and it counts how often Close has been invoked so
// tests can assert on the number of close calls.
type readCloserSpy struct {
	mu     sync.Mutex // guards closed
	closed int        // number of Close calls so far
}

// Read never produces data; it always returns (0, io.EOF).
func (c *readCloserSpy) Read(_ []byte) (int, error) {
	return 0, io.EOF
}

// Close records one more close call and always succeeds.
func (c *readCloserSpy) Close() error {
	c.mu.Lock()
	c.closed++
	c.mu.Unlock()
	return nil
}

// closeCount returns how many times Close has been called.
func (c *readCloserSpy) closeCount() int {
	c.mu.Lock()
	count := c.closed
	c.mu.Unlock()
	return count
}
|
||||
|
||||
// zeroThenBlockReader is a test reader whose first Read yields (0, nil) and
// whose later reads are forwarded to pipeRd, so they block until the pipe
// delivers data or is closed.
type zeroThenBlockReader struct {
	first  atomic.Bool    // flips to true once the zero-byte read was served
	pipeRd *io.PipeReader // source for every read after the first
}

// Read serves the one-off zero-byte read, then delegates to the pipe reader.
func (r *zeroThenBlockReader) Read(p []byte) (int, error) {
	if isFirst := r.first.CompareAndSwap(false, true); !isFirst {
		// Every subsequent call waits on the pipe.
		return r.pipeRd.Read(p)
	}
	return 0, nil
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestIdleTimeoutReader_NormalRead(t *testing.T) {
|
||||
t.Parallel()
|
||||
pr, pw := io.Pipe()
|
||||
defer pr.Close()
|
||||
|
||||
// Use pr as bodyStream — closing pr unblocks reads.
|
||||
wrapped, cleanup := NewIdleTimeoutReader(pr, pr, 500*time.Millisecond)
|
||||
defer cleanup()
|
||||
|
||||
// Writer sends 5 chunks quickly.
|
||||
go func() {
|
||||
for i := 0; i < 5; i++ {
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
pw.Write([]byte("chunk"))
|
||||
}
|
||||
pw.Close()
|
||||
}()
|
||||
|
||||
buf := make([]byte, 64)
|
||||
var totalBytes int
|
||||
for {
|
||||
n, err := wrapped.Read(buf)
|
||||
totalBytes += n
|
||||
if err != nil {
|
||||
if err != io.EOF {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if totalBytes != 5*len("chunk") {
|
||||
t.Fatalf("expected %d bytes, got %d", 5*len("chunk"), totalBytes)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIdleTimeoutReader_TimeoutClosesStream(t *testing.T) {
|
||||
t.Parallel()
|
||||
pr, pw := io.Pipe()
|
||||
defer pw.Close()
|
||||
|
||||
// 100ms timeout, write nothing — should timeout and close the pipe reader.
|
||||
wrapped, cleanup := NewIdleTimeoutReader(pr, pr, 100*time.Millisecond)
|
||||
defer cleanup()
|
||||
|
||||
start := time.Now()
|
||||
buf := make([]byte, 64)
|
||||
_, err := wrapped.Read(buf)
|
||||
elapsed := time.Since(start)
|
||||
|
||||
if err == nil {
|
||||
t.Fatal("expected an error from timed-out read, got nil")
|
||||
}
|
||||
|
||||
// Should complete within ~200ms (100ms timeout + margin), not hang.
|
||||
if elapsed > 500*time.Millisecond {
|
||||
t.Fatalf("read took %v, expected ~100ms timeout", elapsed)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIdleTimeoutReader_TimeoutAfterPartialData(t *testing.T) {
|
||||
t.Parallel()
|
||||
pr, pw := io.Pipe()
|
||||
|
||||
// 200ms idle timeout.
|
||||
wrapped, cleanup := NewIdleTimeoutReader(pr, pr, 200*time.Millisecond)
|
||||
defer cleanup()
|
||||
|
||||
// Writer sends 3 chunks then stops.
|
||||
go func() {
|
||||
for i := 0; i < 3; i++ {
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
pw.Write([]byte("data"))
|
||||
}
|
||||
// stop writing — idle timeout should fire after 200ms and close pr
|
||||
}()
|
||||
|
||||
buf := make([]byte, 64)
|
||||
chunksRead := 0
|
||||
for {
|
||||
n, err := wrapped.Read(buf)
|
||||
if n > 0 {
|
||||
chunksRead++
|
||||
}
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if chunksRead != 3 {
|
||||
t.Fatalf("expected 3 chunks before timeout, got %d", chunksRead)
|
||||
}
|
||||
|
||||
pw.Close()
|
||||
}
|
||||
|
||||
func TestIdleTimeoutReader_ResetOnData(t *testing.T) {
|
||||
t.Parallel()
|
||||
pr, pw := io.Pipe()
|
||||
|
||||
// 200ms timeout, but data arrives every 150ms — should never timeout.
|
||||
wrapped, cleanup := NewIdleTimeoutReader(pr, pr, 200*time.Millisecond)
|
||||
defer cleanup()
|
||||
|
||||
go func() {
|
||||
for i := 0; i < 5; i++ {
|
||||
time.Sleep(150 * time.Millisecond)
|
||||
pw.Write([]byte("ok"))
|
||||
}
|
||||
pw.Close()
|
||||
}()
|
||||
|
||||
buf := make([]byte, 64)
|
||||
chunksRead := 0
|
||||
for {
|
||||
n, err := wrapped.Read(buf)
|
||||
if n > 0 {
|
||||
chunksRead++
|
||||
}
|
||||
if err != nil {
|
||||
if err != io.EOF {
|
||||
t.Fatalf("expected EOF after all chunks, got: %v", err)
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if chunksRead != 5 {
|
||||
t.Fatalf("expected 5 chunks (timer should reset), got %d", chunksRead)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIdleTimeoutReader_CleanupStopsTimer(t *testing.T) {
|
||||
t.Parallel()
|
||||
pr, pw := io.Pipe()
|
||||
defer pr.Close()
|
||||
defer pw.Close()
|
||||
|
||||
spy := &readCloserSpy{}
|
||||
|
||||
_, cleanup := NewIdleTimeoutReader(pr, spy, 100*time.Millisecond)
|
||||
// Call cleanup immediately — timer should be stopped.
|
||||
cleanup()
|
||||
|
||||
// Wait well past the timeout duration.
|
||||
time.Sleep(250 * time.Millisecond)
|
||||
|
||||
if spy.closeCount() != 0 {
|
||||
t.Fatalf("expected closer to NOT be called after cleanup, but was called %d times", spy.closeCount())
|
||||
}
|
||||
}
|
||||
|
||||
func TestIdleTimeoutReader_DoubleCloseIsSafe(t *testing.T) {
|
||||
t.Parallel()
|
||||
spy := &readCloserSpy{}
|
||||
|
||||
br := &zeroThenBlockReader{first: atomic.Bool{}, pipeRd: nil}
|
||||
// Use spy as bodyStream — it implements both io.Reader and io.Closer.
|
||||
_, cleanup := NewIdleTimeoutReader(br, spy, 50*time.Millisecond)
|
||||
defer cleanup()
|
||||
|
||||
// Let the timer fire (closes spy via sync.Once).
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Manually close again — should not panic.
|
||||
spy.Close()
|
||||
|
||||
// sync.Once ensures the idle timer's close ran exactly once.
|
||||
// The manual close above adds another, so total should be 2
|
||||
// (the once.Do protects the timer path, not external callers).
|
||||
// The key guarantee: no panic.
|
||||
if spy.closeCount() < 1 {
|
||||
t.Fatal("expected at least one close call")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIdleTimeoutReader_ZeroBytesDoNotResetTimer(t *testing.T) {
|
||||
t.Parallel()
|
||||
pr, pw := io.Pipe()
|
||||
defer pw.Close()
|
||||
|
||||
// Use pr as bodyStream — when idle timeout fires, it closes pr,
|
||||
// which causes reads on pr to return io.ErrClosedPipe.
|
||||
zr := &zeroThenBlockReader{pipeRd: pr}
|
||||
wrapped, cleanup := NewIdleTimeoutReader(zr, pr, 100*time.Millisecond)
|
||||
defer cleanup()
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
buf := make([]byte, 64)
|
||||
// First read returns (0, nil), second read blocks until pipe is closed.
|
||||
for {
|
||||
_, err := wrapped.Read(buf)
|
||||
if err != nil {
|
||||
done <- err
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
// Timer fired and closed the pipe — Read() returned an error. Good.
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
t.Fatal("expected idle timeout to fire, but read is still blocking")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIdleTimeoutReader_ErrorFromClosedPipe(t *testing.T) {
|
||||
t.Parallel()
|
||||
pr, pw := io.Pipe()
|
||||
defer pw.Close()
|
||||
|
||||
// Use pr as bodyStream — when idle timeout fires, it closes pr,
|
||||
// which makes Read return io.ErrClosedPipe.
|
||||
wrapped, cleanup := NewIdleTimeoutReader(pr, pr, 50*time.Millisecond)
|
||||
defer cleanup()
|
||||
|
||||
buf := make([]byte, 64)
|
||||
_, err := wrapped.Read(buf)
|
||||
|
||||
if err == nil {
|
||||
t.Fatal("expected error from closed pipe")
|
||||
}
|
||||
// The error should indicate the pipe was closed.
|
||||
if !errors.Is(err, io.ErrClosedPipe) && !errors.Is(err, io.EOF) {
|
||||
// Some implementations return io.ErrClosedPipe, others EOF.
|
||||
t.Logf("got error: %v (acceptable)", err)
|
||||
}
|
||||
}
|
||||
50
core/providers/utils/images.go
Normal file
50
core/providers/utils/images.go
Normal file
@@ -0,0 +1,50 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ConvertSizeToAspectRatioAndResolution maps a "WIDTHxHEIGHT" size string
// (e.g. "1024x1024") to an aspect-ratio label and an image-size tier.
//
// aspectRatio is one of "1:1", "3:4", "4:3", "9:16", "16:9", or empty when the
// width/height ratio falls outside the tolerance bands below.
// imageSize is "1K", "2K" or "4K" depending on whether both dimensions fit in
// 1024, 2048 or 4096 pixels, or empty when either dimension exceeds 4096.
// Both results are empty when size is not exactly two positive base-10
// integers separated by a single lowercase "x".
func ConvertSizeToAspectRatioAndResolution(size string) (aspectRatio, imageSize string) {
	rawWidth, rawHeight, found := strings.Cut(size, "x")
	if !found {
		return "", ""
	}

	// A second "x" ends up in rawHeight and is rejected by Atoi.
	width, widthErr := strconv.Atoi(rawWidth)
	height, heightErr := strconv.Atoi(rawHeight)
	if widthErr != nil || heightErr != nil || width <= 0 || height <= 0 {
		return "", ""
	}

	// The tier is decided by the larger of the two dimensions.
	longest := width
	if height > longest {
		longest = height
	}
	switch {
	case longest <= 1024:
		imageSize = "1K"
	case longest <= 2048:
		imageSize = "2K"
	case longest <= 4096:
		imageSize = "4K"
	}

	// Each supported ratio is matched within a small tolerance band.
	switch ratio := float64(width) / float64(height); {
	case ratio >= 0.99 && ratio <= 1.01:
		aspectRatio = "1:1"
	case ratio >= 0.74 && ratio <= 0.76:
		aspectRatio = "3:4"
	case ratio >= 1.32 && ratio <= 1.34:
		aspectRatio = "4:3"
	case ratio >= 0.56 && ratio <= 0.57:
		aspectRatio = "9:16"
	case ratio >= 1.77 && ratio <= 1.78:
		aspectRatio = "16:9"
	}

	return aspectRatio, imageSize
}
|
||||
339
core/providers/utils/large_response.go
Normal file
339
core/providers/utils/large_response.go
Normal file
@@ -0,0 +1,339 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"math"
|
||||
|
||||
"github.com/bytedance/sonic"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// LargeResponseReader wraps an io.Reader and releases the fasthttp response on Close.
|
||||
// Used by providers to keep the response alive while the transport streams it to the client.
|
||||
type LargeResponseReader struct {
|
||||
io.Reader
|
||||
Resp *fasthttp.Response
|
||||
cleanup func()
|
||||
consumed bool // true after Read returns io.EOF — body fully consumed through Reader chain
|
||||
}
|
||||
|
||||
// Read delegates to the wrapped Reader and tracks EOF so Close() can skip
|
||||
// a redundant (and potentially blocking) drain of the body stream.
|
||||
func (r *LargeResponseReader) Read(p []byte) (int, error) {
|
||||
n, err := r.Reader.Read(p)
|
||||
if err == io.EOF {
|
||||
r.consumed = true
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
// Close drains any unconsumed body stream and releases the underlying fasthttp
|
||||
// response back to the pool. Draining prevents "whitespace in header" errors on
|
||||
// connection reuse when the client disconnects before the full response is consumed
|
||||
// (see: fasthttp#1743).
|
||||
//
|
||||
// When the body was already fully consumed through the Reader chain (consumed == true),
|
||||
// the drain is skipped. For identity-encoded responses (no Content-Length), the body
|
||||
// stream is a fasthttp closeReader that blocks until the TCP connection closes — which
|
||||
// can take minutes if the upstream server keeps the connection alive.
|
||||
func (r *LargeResponseReader) Close() error {
|
||||
if r == nil || r.Resp == nil {
|
||||
return nil
|
||||
}
|
||||
if !r.consumed {
|
||||
if bodyStream := r.Resp.BodyStream(); bodyStream != nil {
|
||||
_, _ = io.Copy(io.Discard, bodyStream)
|
||||
if closer, ok := bodyStream.(io.Closer); ok {
|
||||
_ = closer.Close()
|
||||
}
|
||||
}
|
||||
}
|
||||
if r.cleanup != nil {
|
||||
r.cleanup()
|
||||
r.cleanup = nil
|
||||
}
|
||||
fasthttp.ReleaseResponse(r.Resp)
|
||||
r.Resp = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
// BuildLargeResponseClient creates a streaming-enabled fasthttp client for large response detection.
|
||||
// The client caps buffering at the threshold and enables response body streaming.
|
||||
//
|
||||
// ReadTimeout/WriteTimeout/MaxConnDuration are zeroed: large-response bodies may take arbitrarily
|
||||
// long to download, and fasthttp's ReadTimeout bounds *full* body read — not idle. Idle detection
|
||||
// on stalled streams is handled separately (see NewIdleTimeoutReader / SetupStreamingPassthrough).
|
||||
func BuildLargeResponseClient(base *fasthttp.Client, responseThreshold int64) *fasthttp.Client {
|
||||
client := CloneFastHTTPClientConfig(base)
|
||||
if responseThreshold > 0 && responseThreshold <= int64(math.MaxInt) {
|
||||
client.MaxResponseBodySize = int(responseThreshold)
|
||||
}
|
||||
client.StreamResponseBody = true
|
||||
client.ReadTimeout = 0
|
||||
client.WriteTimeout = 0
|
||||
client.MaxConnDuration = 0
|
||||
return client
|
||||
}
|
||||
|
||||
// PrepareResponseStreaming configures response body streaming when a large response
|
||||
// threshold is set in context. Returns the client to use for MakeRequestWithContext.
|
||||
// When threshold > 0: sets resp.StreamBody = true and returns a streaming-enabled client.
|
||||
// When threshold <= 0: returns the original client unchanged (no-op for feature-off path).
|
||||
func PrepareResponseStreaming(ctx *schemas.BifrostContext, client *fasthttp.Client, resp *fasthttp.Response) *fasthttp.Client {
|
||||
responseThreshold, _ := ctx.Value(schemas.BifrostContextKeyLargeResponseThreshold).(int64)
|
||||
if responseThreshold <= 0 {
|
||||
return client
|
||||
}
|
||||
resp.StreamBody = true
|
||||
return BuildLargeResponseClient(client, responseThreshold)
|
||||
}
|
||||
|
||||
// MaterializeStreamErrorBody reads a streamed error body into resp so that resp.Body()
|
||||
// returns the error payload for parsing. No-op when response streaming is not active.
|
||||
func MaterializeStreamErrorBody(ctx *schemas.BifrostContext, resp *fasthttp.Response) {
|
||||
responseThreshold, _ := ctx.Value(schemas.BifrostContextKeyLargeResponseThreshold).(int64)
|
||||
if responseThreshold <= 0 {
|
||||
return
|
||||
}
|
||||
if bodyStream := resp.BodyStream(); bodyStream != nil {
|
||||
gz, reader, wasGzip := decompressBodyStreamIfGzip(resp, bodyStream)
|
||||
if wasGzip {
|
||||
defer ReleaseGzipReader(gz)
|
||||
}
|
||||
bodyBytes, readErr := io.ReadAll(io.LimitReader(reader, 512*1024)) // 512KB cap for error bodies
|
||||
if readErr != nil {
|
||||
return
|
||||
}
|
||||
resp.SetBody(bodyBytes)
|
||||
}
|
||||
}
|
||||
|
||||
// FinalizeResponseWithLargeDetection processes the response body with optional large response
|
||||
// detection. Takes ownership semantics: when isLargeResponse is true, the caller must NOT
|
||||
// release resp (it's wrapped in a reader stored in context). When false, resp is unchanged
|
||||
// and the caller should release as normal.
|
||||
//
|
||||
// Returns:
|
||||
// - (body, false, nil) — normal path; body ready for parsing; resp NOT released.
|
||||
// - (nil, true, nil) — large response detected; context keys set for streaming;
|
||||
// caller must set respOwned = false.
|
||||
// - (nil, false, err) — error; resp NOT released.
|
||||
func FinalizeResponseWithLargeDetection(
|
||||
ctx *schemas.BifrostContext,
|
||||
resp *fasthttp.Response,
|
||||
logger schemas.Logger,
|
||||
) ([]byte, bool, *schemas.BifrostError) {
|
||||
responseThreshold, _ := ctx.Value(schemas.BifrostContextKeyLargeResponseThreshold).(int64)
|
||||
|
||||
// No threshold — normal buffered read (feature-off path)
|
||||
if responseThreshold <= 0 {
|
||||
body, err := CheckAndDecodeBody(resp)
|
||||
if err != nil {
|
||||
return nil, false, NewBifrostOperationError(schemas.ErrProviderResponseDecode, err)
|
||||
}
|
||||
// Copy body before caller releases resp
|
||||
return append([]byte(nil), body...), false, nil
|
||||
}
|
||||
|
||||
contentLength := resp.Header.ContentLength()
|
||||
|
||||
// Known small response — read from stream, return body for normal parsing
|
||||
if contentLength > 0 && int64(contentLength) <= responseThreshold {
|
||||
if bodyStream := resp.BodyStream(); bodyStream != nil {
|
||||
gz, reader, wasGzip := decompressBodyStreamIfGzip(resp, bodyStream)
|
||||
if wasGzip {
|
||||
defer ReleaseGzipReader(gz)
|
||||
}
|
||||
bodyBytes, readErr := io.ReadAll(reader)
|
||||
if readErr != nil {
|
||||
return nil, false, NewBifrostOperationError(schemas.ErrProviderResponseDecode, readErr)
|
||||
}
|
||||
return bodyBytes, false, nil
|
||||
}
|
||||
// No stream — buffered fallback
|
||||
body, err := CheckAndDecodeBody(resp)
|
||||
if err != nil {
|
||||
return nil, false, NewBifrostOperationError(schemas.ErrProviderResponseDecode, err)
|
||||
}
|
||||
return append([]byte(nil), body...), false, nil
|
||||
}
|
||||
|
||||
// Unknown Content-Length (chunked transfer encoding) — buffer up to responseThreshold
|
||||
// to determine if response is truly large. Responses within threshold are returned
|
||||
// buffered for normal parsing/logging; only responses exceeding threshold are streamed.
|
||||
if contentLength <= 0 {
|
||||
if bodyStream := resp.BodyStream(); bodyStream != nil {
|
||||
gz, reader, wasGzip := decompressBodyStreamIfGzip(resp, bodyStream)
|
||||
releaseGzip := func() {}
|
||||
if wasGzip {
|
||||
releaseGzip = func() {
|
||||
ReleaseGzipReader(gz)
|
||||
}
|
||||
}
|
||||
bodyBytes, readErr := io.ReadAll(io.LimitReader(reader, responseThreshold+1))
|
||||
if readErr != nil {
|
||||
releaseGzip()
|
||||
return nil, false, NewBifrostOperationError(schemas.ErrProviderResponseDecode, readErr)
|
||||
}
|
||||
if int64(len(bodyBytes)) <= responseThreshold {
|
||||
releaseGzip()
|
||||
return bodyBytes, false, nil
|
||||
}
|
||||
// Exceeds threshold without Content-Length — set up large response streaming.
|
||||
combinedReader := io.MultiReader(bytes.NewReader(bodyBytes), reader)
|
||||
closableReader := &LargeResponseReader{
|
||||
Reader: combinedReader,
|
||||
Resp: resp,
|
||||
cleanup: releaseGzip,
|
||||
}
|
||||
ctx.SetValue(schemas.BifrostContextKeyLargeResponseMode, true)
|
||||
ctx.SetValue(schemas.BifrostContextKeyLargeResponseReader, closableReader)
|
||||
ctx.SetValue(schemas.BifrostContextKeyLargeResponseContentLength, contentLength)
|
||||
if ct := string(resp.Header.ContentType()); ct != "" {
|
||||
ctx.SetValue(schemas.BifrostContextKeyLargeResponseContentType, ct)
|
||||
}
|
||||
previewLen := min(len(bodyBytes), 1048576)
|
||||
ctx.SetValue(schemas.BifrostContextKeyLargePayloadResponsePreview, string(bodyBytes[:previewLen]))
|
||||
return nil, true, nil
|
||||
}
|
||||
// No stream — buffered fallback
|
||||
body, err := CheckAndDecodeBody(resp)
|
||||
if err != nil {
|
||||
return nil, false, NewBifrostOperationError(schemas.ErrProviderResponseDecode, err)
|
||||
}
|
||||
return append([]byte(nil), body...), false, nil
|
||||
}
|
||||
|
||||
// Known large response (Content-Length > threshold) — prefetch first 64KB for
|
||||
// metadata extraction, then stream the rest without full materialization.
|
||||
bodyStream := resp.BodyStream()
|
||||
if bodyStream == nil {
|
||||
// No stream available — fall back to buffered read
|
||||
if logger != nil {
|
||||
logger.Warn("large-response fallback to buffered path: content_length=%d threshold=%d body_stream_nil=true", contentLength, responseThreshold)
|
||||
}
|
||||
body, err := CheckAndDecodeBody(resp)
|
||||
if err != nil {
|
||||
return nil, false, NewBifrostOperationError(schemas.ErrProviderResponseDecode, err)
|
||||
}
|
||||
return append([]byte(nil), body...), false, nil
|
||||
}
|
||||
|
||||
// Decompress on-the-fly if provider returned gzip-encoded response.
|
||||
// Clears Content-Encoding so the transport doesn't re-add it to the client response.
|
||||
gz, decompressedStream, wasGzip := decompressBodyStreamIfGzip(resp, bodyStream)
|
||||
if wasGzip {
|
||||
contentLength = -1 // decompressed size unknown; transport will use chunked encoding
|
||||
}
|
||||
|
||||
prefetchSize := 64 * 1024 // default
|
||||
if ps, ok := ctx.Value(schemas.BifrostContextKeyLargePayloadPrefetchSize).(int); ok && ps > 0 {
|
||||
prefetchSize = ps
|
||||
}
|
||||
prefetchBuf := make([]byte, prefetchSize)
|
||||
n, readErr := io.ReadFull(decompressedStream, prefetchBuf)
|
||||
if readErr != nil && readErr != io.EOF && readErr != io.ErrUnexpectedEOF {
|
||||
if wasGzip {
|
||||
ReleaseGzipReader(gz)
|
||||
}
|
||||
return nil, false, NewBifrostOperationError(schemas.ErrProviderResponseDecode, readErr)
|
||||
}
|
||||
prefetchBuf = prefetchBuf[:n]
|
||||
|
||||
combinedReader := io.MultiReader(bytes.NewReader(prefetchBuf), decompressedStream)
|
||||
closableReader := &LargeResponseReader{
|
||||
Reader: combinedReader,
|
||||
Resp: resp,
|
||||
cleanup: func() {
|
||||
if wasGzip {
|
||||
ReleaseGzipReader(gz)
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
ctx.SetValue(schemas.BifrostContextKeyLargeResponseMode, true)
|
||||
ctx.SetValue(schemas.BifrostContextKeyLargeResponseReader, closableReader)
|
||||
ctx.SetValue(schemas.BifrostContextKeyLargeResponseContentLength, contentLength)
|
||||
if ct := string(resp.Header.ContentType()); ct != "" {
|
||||
ctx.SetValue(schemas.BifrostContextKeyLargeResponseContentType, ct)
|
||||
}
|
||||
previewLen := min(n, 1048576)
|
||||
ctx.SetValue(schemas.BifrostContextKeyLargePayloadResponsePreview, string(prefetchBuf[:previewLen]))
|
||||
|
||||
return nil, true, nil
|
||||
}
|
||||
|
||||
// ParseOpenAIUsageFromBytes parses OpenAI-format usage from raw JSON bytes into BifrostLLMUsage.
|
||||
// Handles both Chat Completions (prompt_tokens/completion_tokens) and Responses API
|
||||
// (input_tokens/output_tokens) field names. Expects the "usage" object bytes directly,
|
||||
// not the full response body.
|
||||
func ParseOpenAIUsageFromBytes(data []byte) *schemas.BifrostLLMUsage {
|
||||
var usage struct {
|
||||
PromptTokens int `json:"prompt_tokens"`
|
||||
CompletionTokens int `json:"completion_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
// Responses API uses different field names
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
}
|
||||
if err := sonic.Unmarshal(data, &usage); err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
result := &schemas.BifrostLLMUsage{}
|
||||
if usage.PromptTokens > 0 {
|
||||
result.PromptTokens = usage.PromptTokens
|
||||
} else if usage.InputTokens > 0 {
|
||||
result.PromptTokens = usage.InputTokens
|
||||
}
|
||||
if usage.CompletionTokens > 0 {
|
||||
result.CompletionTokens = usage.CompletionTokens
|
||||
} else if usage.OutputTokens > 0 {
|
||||
result.CompletionTokens = usage.OutputTokens
|
||||
}
|
||||
if usage.TotalTokens > 0 {
|
||||
result.TotalTokens = usage.TotalTokens
|
||||
} else {
|
||||
result.TotalTokens = result.PromptTokens + result.CompletionTokens
|
||||
}
|
||||
|
||||
if result.TotalTokens == 0 {
|
||||
return nil
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// SetupStreamingPassthrough configures large response passthrough for streaming
|
||||
// responses when large payload mode is active. Wraps the response body stream
|
||||
// in a LargeResponseReader and sets context keys for the transport layer.
|
||||
// Returns true if passthrough was set up. When true, the caller should return
|
||||
// a closed channel and must NOT release resp — it's owned by the reader in context.
|
||||
func SetupStreamingPassthrough(ctx *schemas.BifrostContext, resp *fasthttp.Response) bool {
|
||||
isLargePayload, _ := ctx.Value(schemas.BifrostContextKeyLargePayloadMode).(bool)
|
||||
if !isLargePayload {
|
||||
return false
|
||||
}
|
||||
|
||||
reader, releaseGzip := DecompressStreamBody(resp)
|
||||
|
||||
// Wrap reader with idle timeout to detect stalled streams.
|
||||
reader, stopIdleTimeout := NewIdleTimeoutReader(reader, resp.BodyStream(), GetStreamIdleTimeout(ctx))
|
||||
|
||||
closableReader := &LargeResponseReader{
|
||||
Reader: reader,
|
||||
Resp: resp,
|
||||
cleanup: func() {
|
||||
stopIdleTimeout()
|
||||
releaseGzip()
|
||||
},
|
||||
}
|
||||
|
||||
ctx.SetValue(schemas.BifrostContextKeyLargeResponseMode, true)
|
||||
ctx.SetValue(schemas.BifrostContextKeyLargeResponseReader, closableReader)
|
||||
if ct := string(resp.Header.ContentType()); ct != "" {
|
||||
ctx.SetValue(schemas.BifrostContextKeyLargeResponseContentType, ct)
|
||||
}
|
||||
return true
|
||||
}
|
||||
406
core/providers/utils/make_request_test.go
Normal file
406
core/providers/utils/make_request_test.go
Normal file
@@ -0,0 +1,406 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/valyala/fasthttp"
|
||||
"github.com/valyala/fasthttp/fasthttputil"
|
||||
)
|
||||
|
||||
// newTestServer creates an in-memory fasthttp server that responds after the given delay.
|
||||
// Returns a client configured to talk to it and a cleanup function.
|
||||
func newTestServer(t *testing.T, delay time.Duration, statusCode int) (*fasthttp.Client, func()) {
|
||||
t.Helper()
|
||||
ln := fasthttputil.NewInmemoryListener()
|
||||
|
||||
server := &fasthttp.Server{
|
||||
Handler: func(ctx *fasthttp.RequestCtx) {
|
||||
if delay > 0 {
|
||||
time.Sleep(delay)
|
||||
}
|
||||
ctx.SetStatusCode(statusCode)
|
||||
ctx.SetBody([]byte(`{"ok":true}`))
|
||||
},
|
||||
}
|
||||
|
||||
go server.Serve(ln) //nolint:errcheck
|
||||
|
||||
client := &fasthttp.Client{
|
||||
Dial: func(addr string) (net.Conn, error) {
|
||||
return ln.Dial()
|
||||
},
|
||||
ReadTimeout: 5 * time.Second,
|
||||
WriteTimeout: 5 * time.Second,
|
||||
}
|
||||
|
||||
cleanup := func() {
|
||||
ln.Close()
|
||||
}
|
||||
|
||||
return client, cleanup
|
||||
}
|
||||
|
||||
func TestMakeRequestWithContext_SuccessReturnsNoopWait(t *testing.T) {
|
||||
client, cleanup := newTestServer(t, 0, 200)
|
||||
defer cleanup()
|
||||
|
||||
req := fasthttp.AcquireRequest()
|
||||
resp := fasthttp.AcquireResponse()
|
||||
defer fasthttp.ReleaseRequest(req)
|
||||
defer fasthttp.ReleaseResponse(resp)
|
||||
|
||||
req.SetRequestURI("http://test/")
|
||||
|
||||
latency, bifrostErr, wait := MakeRequestWithContext(context.Background(), client, req, resp)
|
||||
defer wait()
|
||||
|
||||
if bifrostErr != nil {
|
||||
t.Fatalf("expected no error, got: %v", bifrostErr.Error.Message)
|
||||
}
|
||||
if latency <= 0 {
|
||||
t.Fatal("expected positive latency")
|
||||
}
|
||||
if resp.StatusCode() != 200 {
|
||||
t.Fatalf("expected status 200, got %d", resp.StatusCode())
|
||||
}
|
||||
}
|
||||
|
||||
func TestMakeRequestWithContext_DeadlineExceededReturnsTimeoutError(t *testing.T) {
|
||||
// Server takes 500ms to respond
|
||||
client, cleanup := newTestServer(t, 500*time.Millisecond, 200)
|
||||
defer cleanup()
|
||||
|
||||
req := fasthttp.AcquireRequest()
|
||||
resp := fasthttp.AcquireResponse()
|
||||
req.SetRequestURI("http://test/")
|
||||
|
||||
// Deadline exceeded almost immediately
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
_, bifrostErr, wait := MakeRequestWithContext(ctx, client, req, resp)
|
||||
|
||||
// Should get a timeout error with 504 status
|
||||
if bifrostErr == nil {
|
||||
t.Fatal("expected timeout error")
|
||||
}
|
||||
if bifrostErr.Error.Type == nil || *bifrostErr.Error.Type != schemas.RequestTimedOut {
|
||||
t.Fatalf("expected RequestTimedOut error type, got: %v", bifrostErr.Error.Type)
|
||||
}
|
||||
if bifrostErr.StatusCode == nil || *bifrostErr.StatusCode != 504 {
|
||||
t.Fatalf("expected status 504, got: %v", bifrostErr.StatusCode)
|
||||
}
|
||||
|
||||
// wait() should block until the goroutine finishes, then we can safely release
|
||||
start := time.Now()
|
||||
wait()
|
||||
elapsed := time.Since(start)
|
||||
|
||||
// The wait should have taken roughly the remaining server delay (~490ms)
|
||||
if elapsed < 200*time.Millisecond {
|
||||
t.Fatalf("wait() returned too quickly (%v), expected it to block until goroutine finishes", elapsed)
|
||||
}
|
||||
|
||||
// Now safe to release
|
||||
fasthttp.ReleaseRequest(req)
|
||||
fasthttp.ReleaseResponse(resp)
|
||||
}
|
||||
|
||||
func TestMakeRequestWithContext_ContextCancelReturnsCancelledError(t *testing.T) {
|
||||
// Server takes 500ms to respond
|
||||
client, cleanup := newTestServer(t, 500*time.Millisecond, 200)
|
||||
defer cleanup()
|
||||
|
||||
req := fasthttp.AcquireRequest()
|
||||
resp := fasthttp.AcquireResponse()
|
||||
req.SetRequestURI("http://test/")
|
||||
|
||||
// Cancel context explicitly (not deadline)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
go func() {
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
cancel()
|
||||
}()
|
||||
|
||||
_, bifrostErr, wait := MakeRequestWithContext(ctx, client, req, resp)
|
||||
|
||||
// Should get a cancellation error with 499 status
|
||||
if bifrostErr == nil {
|
||||
t.Fatal("expected cancellation error")
|
||||
}
|
||||
if bifrostErr.Error.Type == nil || *bifrostErr.Error.Type != schemas.RequestCancelled {
|
||||
t.Fatalf("expected RequestCancelled error type, got: %v", bifrostErr.Error.Type)
|
||||
}
|
||||
if bifrostErr.StatusCode == nil || *bifrostErr.StatusCode != 499 {
|
||||
t.Fatalf("expected status 499, got: %v", bifrostErr.StatusCode)
|
||||
}
|
||||
|
||||
// wait() should block until the goroutine finishes
|
||||
start := time.Now()
|
||||
wait()
|
||||
elapsed := time.Since(start)
|
||||
|
||||
if elapsed < 200*time.Millisecond {
|
||||
t.Fatalf("wait() returned too quickly (%v), expected it to block until goroutine finishes", elapsed)
|
||||
}
|
||||
|
||||
fasthttp.ReleaseRequest(req)
|
||||
fasthttp.ReleaseResponse(resp)
|
||||
}
|
||||
|
||||
func TestMakeRequestWithContext_WaitPreventsDataRace(t *testing.T) {
|
||||
// This test verifies the fix for the data race. Under -race, accessing resp
|
||||
// while client.Do is still writing to it would be flagged. The wait function
|
||||
// ensures we don't release until the goroutine is done.
|
||||
//
|
||||
// Run with: go test -race -run TestMakeRequestWithContext_WaitPreventsDataRace
|
||||
|
||||
// Server responds after 200ms
|
||||
client, cleanup := newTestServer(t, 200*time.Millisecond, 200)
|
||||
defer cleanup()
|
||||
|
||||
for range 10 {
|
||||
func() {
|
||||
req := fasthttp.AcquireRequest()
|
||||
resp := fasthttp.AcquireResponse()
|
||||
req.SetRequestURI("http://test/")
|
||||
|
||||
// Cancel context after 5ms — well before server responds
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
_, _, wait := MakeRequestWithContext(ctx, client, req, resp)
|
||||
|
||||
// Simulate the real caller pattern: defer wait() before defer Release.
|
||||
// Go defers are LIFO, so wait() runs first, then Release.
|
||||
// This is the pattern that prevents the data race.
|
||||
defer fasthttp.ReleaseRequest(req)
|
||||
defer fasthttp.ReleaseResponse(resp)
|
||||
defer wait()
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
func TestMakeRequestWithContext_WaitIsIdempotent(t *testing.T) {
|
||||
client, cleanup := newTestServer(t, 50*time.Millisecond, 200)
|
||||
defer cleanup()
|
||||
|
||||
req := fasthttp.AcquireRequest()
|
||||
resp := fasthttp.AcquireResponse()
|
||||
defer fasthttp.ReleaseRequest(req)
|
||||
defer fasthttp.ReleaseResponse(resp)
|
||||
|
||||
req.SetRequestURI("http://test/")
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
_, _, wait := MakeRequestWithContext(ctx, client, req, resp)
|
||||
|
||||
// First call should block
|
||||
wait()
|
||||
// Second call should not deadlock (channel already drained)
|
||||
// Note: this will deadlock if the implementation is wrong, so the test
|
||||
// would time out rather than fail gracefully.
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
wait()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
// Second wait() completed — but note this actually WILL deadlock with
|
||||
// the current implementation since <-errChan can only be read once.
|
||||
// This documents the behavior: wait() should only be called once.
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
// Expected: second wait() blocks forever because errChan is already drained.
|
||||
// This is fine — callers should only call wait() once (via a single defer).
|
||||
}
|
||||
}
|
||||
|
||||
func TestMakeRequestWithContext_SuccessWaitDoesNotBlock(t *testing.T) {
|
||||
client, cleanup := newTestServer(t, 0, 200)
|
||||
defer cleanup()
|
||||
|
||||
req := fasthttp.AcquireRequest()
|
||||
resp := fasthttp.AcquireResponse()
|
||||
defer fasthttp.ReleaseRequest(req)
|
||||
defer fasthttp.ReleaseResponse(resp)
|
||||
|
||||
req.SetRequestURI("http://test/")
|
||||
|
||||
_, _, wait := MakeRequestWithContext(context.Background(), client, req, resp)
|
||||
|
||||
// On the success path, wait should be a noop that returns immediately
|
||||
start := time.Now()
|
||||
wait()
|
||||
if time.Since(start) > 10*time.Millisecond {
|
||||
t.Fatal("wait() on success path should be a noop and return immediately")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMakeRequestWithContext_ConcurrentRequestsWithCancellation(t *testing.T) {
|
||||
// Simulate the production scenario: multiple concurrent requests where some
|
||||
// contexts cancel while the HTTP call is in-flight. Under -race, this would
|
||||
// detect the original bug where deferred Release races with client.Do.
|
||||
client, cleanup := newTestServer(t, 100*time.Millisecond, 200)
|
||||
defer cleanup()
|
||||
|
||||
const numRequests = 20
|
||||
var completed atomic.Int32
|
||||
|
||||
done := make(chan struct{})
|
||||
for range numRequests {
|
||||
go func() {
|
||||
defer func() {
|
||||
if completed.Add(1) == numRequests {
|
||||
close(done)
|
||||
}
|
||||
}()
|
||||
|
||||
req := fasthttp.AcquireRequest()
|
||||
resp := fasthttp.AcquireResponse()
|
||||
req.SetRequestURI("http://test/")
|
||||
|
||||
// Half the requests cancel early, half complete normally
|
||||
var ctx context.Context
|
||||
var cancel context.CancelFunc
|
||||
if completed.Load()%2 == 0 {
|
||||
ctx, cancel = context.WithTimeout(context.Background(), 5*time.Millisecond)
|
||||
} else {
|
||||
ctx, cancel = context.WithTimeout(context.Background(), 5*time.Second)
|
||||
}
|
||||
|
||||
_, _, wait := MakeRequestWithContext(ctx, client, req, resp)
|
||||
// Correct pattern: wait before release
|
||||
wait()
|
||||
cancel()
|
||||
fasthttp.ReleaseRequest(req)
|
||||
fasthttp.ReleaseResponse(resp)
|
||||
}()
|
||||
}
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
// All requests completed
|
||||
case <-time.After(10 * time.Second):
|
||||
t.Fatalf("timed out waiting for requests, only %d/%d completed", completed.Load(), numRequests)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewBifrostTimeoutError(t *testing.T) {
|
||||
err := NewBifrostTimeoutError("test timeout", context.DeadlineExceeded)
|
||||
|
||||
if !err.IsBifrostError {
|
||||
t.Fatal("expected IsBifrostError to be true")
|
||||
}
|
||||
if err.StatusCode == nil || *err.StatusCode != 504 {
|
||||
t.Fatalf("expected StatusCode 504, got %v", err.StatusCode)
|
||||
}
|
||||
if err.Error.Type == nil || *err.Error.Type != schemas.RequestTimedOut {
|
||||
t.Fatalf("expected RequestTimedOut type, got %v", err.Error.Type)
|
||||
}
|
||||
if err.Error.Message != "test timeout" {
|
||||
t.Fatalf("expected 'test timeout', got %s", err.Error.Message)
|
||||
}
|
||||
// Note: ExtraFields.Provider is populated by bifrost.go's dispatcher via
|
||||
// PopulateExtraFields, not by NewBifrostTimeoutError — the constructor has
|
||||
// no provider context.
|
||||
}
|
||||
|
||||
func TestMakeRequestWithContext_ClientError(t *testing.T) {
|
||||
// Test that client errors still return noop wait function
|
||||
client := &fasthttp.Client{
|
||||
Dial: func(addr string) (net.Conn, error) {
|
||||
return nil, &net.OpError{Op: "dial", Net: "tcp", Err: &net.DNSError{Err: "no such host", Name: "nonexistent.invalid"}}
|
||||
},
|
||||
ReadTimeout: 1 * time.Second,
|
||||
WriteTimeout: 1 * time.Second,
|
||||
}
|
||||
|
||||
req := fasthttp.AcquireRequest()
|
||||
resp := fasthttp.AcquireResponse()
|
||||
defer fasthttp.ReleaseRequest(req)
|
||||
defer fasthttp.ReleaseResponse(resp)
|
||||
|
||||
req.SetRequestURI("http://nonexistent.invalid/")
|
||||
|
||||
_, bifrostErr, wait := MakeRequestWithContext(context.Background(), client, req, resp)
|
||||
defer wait()
|
||||
|
||||
if bifrostErr == nil {
|
||||
t.Fatal("expected error for nonexistent host")
|
||||
}
|
||||
// wait should be noop since the goroutine completed (with error)
|
||||
start := time.Now()
|
||||
wait()
|
||||
if time.Since(start) > 10*time.Millisecond {
|
||||
t.Fatal("wait() should be noop on error path")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMakeRequestWithContext_DeferOrderingPattern(t *testing.T) {
|
||||
// Verify the exact defer pattern used by callers works correctly under -race.
|
||||
// This mirrors the real provider code pattern.
|
||||
client, cleanup := newTestServer(t, 150*time.Millisecond, 200)
|
||||
defer cleanup()
|
||||
|
||||
// Track the order of operations
|
||||
var order []string
|
||||
var orderDone = make(chan struct{})
|
||||
|
||||
go func() {
|
||||
defer close(orderDone)
|
||||
|
||||
req := fasthttp.AcquireRequest()
|
||||
resp := fasthttp.AcquireResponse()
|
||||
req.SetRequestURI("http://test/")
|
||||
|
||||
// Mimic the real provider pattern with defer ordering:
|
||||
// These defers run in reverse order (LIFO)
|
||||
defer func() {
|
||||
fasthttp.ReleaseRequest(req)
|
||||
order = append(order, "release-req")
|
||||
}()
|
||||
defer func() {
|
||||
fasthttp.ReleaseResponse(resp)
|
||||
order = append(order, "release-resp")
|
||||
}()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
_, _, wait := MakeRequestWithContext(ctx, client, req, resp)
|
||||
// This defer runs FIRST (last declared = first to run)
|
||||
defer func() {
|
||||
wait()
|
||||
order = append(order, "wait-done")
|
||||
}()
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-orderDone:
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("timed out")
|
||||
}
|
||||
|
||||
// Verify order: wait must complete before any release
|
||||
if len(order) != 3 {
|
||||
t.Fatalf("expected 3 operations, got %d: %v", len(order), order)
|
||||
}
|
||||
if order[0] != "wait-done" {
|
||||
t.Fatalf("expected wait-done first, got: %v", order)
|
||||
}
|
||||
if order[1] != "release-resp" {
|
||||
t.Fatalf("expected release-resp second, got: %v", order)
|
||||
}
|
||||
if order[2] != "release-req" {
|
||||
t.Fatalf("expected release-req third, got: %v", order)
|
||||
}
|
||||
}
|
||||
265
core/providers/utils/modelparamscache.go
Normal file
265
core/providers/utils/modelparamscache.go
Normal file
@@ -0,0 +1,265 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"container/list"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// DefaultModelParamsCacheSize is the capacity given to the package-level model
// params cache when getModelParamsCache creates it.
const DefaultModelParamsCacheSize = 2048
|
||||
|
||||
// ModelParams is the set of per-model parameters kept in the cache.
// New model-level parameters that need caching should be added as fields here.
type ModelParams struct {
	// MaxOutputTokens is the model's output token limit; nil when not known.
	MaxOutputTokens *int
}
|
||||
|
||||
type modelParamsCacheEntry struct {
|
||||
model string
|
||||
params ModelParams
|
||||
}
|
||||
|
||||
// inflightCall represents an in-progress cache miss handler invocation.
|
||||
// Multiple goroutines waiting for the same model share one call.
|
||||
type inflightCall struct {
|
||||
done chan struct{}
|
||||
result *ModelParams
|
||||
}
|
||||
|
||||
type modelParamsCache struct {
|
||||
mu sync.RWMutex
|
||||
capacity int
|
||||
items map[string]*list.Element
|
||||
order *list.List // front = most recently inserted/updated
|
||||
cacheMissHandler func(model string) *ModelParams
|
||||
|
||||
inflightMu sync.Mutex
|
||||
inflight map[string]*inflightCall
|
||||
}
|
||||
|
||||
var (
|
||||
globalModelParamsCache *modelParamsCache
|
||||
cacheOnce sync.Once
|
||||
)
|
||||
|
||||
// knownAnthropicMaxOutputTokens holds static max-output-token defaults for
// Claude models, intended as a fallback when both the cache and the DB miss
// handler return nothing. Only Anthropic requires max_tokens.
//
// Entries are grouped by limit, largest first.
var knownAnthropicMaxOutputTokens = map[string]int{
	// 128k
	"claude-opus-4-6": 128000,

	// 64k
	"claude-sonnet-4-6": 64000,
	"claude-haiku-4-5":  64000,
	"claude-sonnet-4-5": 64000,
	"claude-opus-4-5":   64000,
	"claude-sonnet-4":   64000,
	"claude-sonnet-4-0": 64000,

	// 32k
	"claude-opus-4-1": 32000,
	"claude-opus-4":   32000,
	"claude-opus-4-0": 32000,

	// 8k
	"claude-3-5-sonnet": 8192,
	"claude-3-5-haiku":  8192,
	"claude-3-7-sonnet": 8192,

	// 4k
	"claude-3-opus":   4096,
	"claude-3-sonnet": 4096,
	"claude-3-haiku":  4096,
}
|
||||
|
||||
func newModelParamsCache(capacity int) *modelParamsCache {
|
||||
return &modelParamsCache{
|
||||
capacity: capacity,
|
||||
items: make(map[string]*list.Element, capacity),
|
||||
order: list.New(),
|
||||
inflight: make(map[string]*inflightCall),
|
||||
}
|
||||
}
|
||||
|
||||
func getModelParamsCache() *modelParamsCache {
|
||||
cacheOnce.Do(func() {
|
||||
globalModelParamsCache = newModelParamsCache(DefaultModelParamsCacheSize)
|
||||
})
|
||||
return globalModelParamsCache
|
||||
}
|
||||
|
||||
func (c *modelParamsCache) Get(model string) (ModelParams, bool) {
|
||||
c.mu.Lock()
|
||||
elem, ok := c.items[model]
|
||||
if ok {
|
||||
c.order.MoveToFront(elem)
|
||||
params := elem.Value.(*modelParamsCacheEntry).params
|
||||
c.mu.Unlock()
|
||||
return params, true
|
||||
}
|
||||
handler := c.cacheMissHandler
|
||||
c.mu.Unlock()
|
||||
|
||||
if handler == nil {
|
||||
return ModelParams{}, false
|
||||
}
|
||||
|
||||
// Deduplicate concurrent miss handler calls for the same model.
|
||||
c.inflightMu.Lock()
|
||||
if call, ok := c.inflight[model]; ok {
|
||||
c.inflightMu.Unlock()
|
||||
<-call.done
|
||||
if call.result == nil {
|
||||
return ModelParams{}, false
|
||||
}
|
||||
return *call.result, true
|
||||
}
|
||||
call := &inflightCall{done: make(chan struct{})}
|
||||
c.inflight[model] = call
|
||||
c.inflightMu.Unlock()
|
||||
|
||||
result := handler(model)
|
||||
call.result = result
|
||||
close(call.done)
|
||||
|
||||
c.inflightMu.Lock()
|
||||
delete(c.inflight, model)
|
||||
c.inflightMu.Unlock()
|
||||
|
||||
if result == nil {
|
||||
return ModelParams{}, false
|
||||
}
|
||||
c.Set(model, *result)
|
||||
return *result, true
|
||||
}
|
||||
|
||||
func (c *modelParamsCache) Set(model string, params ModelParams) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if elem, ok := c.items[model]; ok {
|
||||
elem.Value.(*modelParamsCacheEntry).params = params
|
||||
c.order.MoveToFront(elem)
|
||||
return
|
||||
}
|
||||
|
||||
if c.order.Len() >= c.capacity {
|
||||
c.evict()
|
||||
}
|
||||
|
||||
entry := &modelParamsCacheEntry{model: model, params: params}
|
||||
elem := c.order.PushFront(entry)
|
||||
c.items[model] = elem
|
||||
}
|
||||
|
||||
func (c *modelParamsCache) BulkSet(entries map[string]ModelParams) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
for model, params := range entries {
|
||||
if elem, ok := c.items[model]; ok {
|
||||
elem.Value.(*modelParamsCacheEntry).params = params
|
||||
c.order.MoveToFront(elem)
|
||||
continue
|
||||
}
|
||||
|
||||
if c.order.Len() >= c.capacity {
|
||||
c.evict()
|
||||
}
|
||||
|
||||
entry := &modelParamsCacheEntry{model: model, params: params}
|
||||
elem := c.order.PushFront(entry)
|
||||
c.items[model] = elem
|
||||
}
|
||||
}
|
||||
|
||||
func (c *modelParamsCache) evict() {
|
||||
tail := c.order.Back()
|
||||
if tail == nil {
|
||||
return
|
||||
}
|
||||
c.order.Remove(tail)
|
||||
delete(c.items, tail.Value.(*modelParamsCacheEntry).model)
|
||||
}
|
||||
|
||||
// GetModelParams returns the cached parameters for a model.
|
||||
// On cache miss, calls the registered miss handler (if any) to load from DB.
|
||||
func GetModelParams(model string) (ModelParams, bool) {
|
||||
return getModelParamsCache().Get(model)
|
||||
}
|
||||
|
||||
// SetModelParams sets the parameters for a model in the cache.
|
||||
func SetModelParams(model string, params ModelParams) {
|
||||
getModelParamsCache().Set(model, params)
|
||||
}
|
||||
|
||||
// BulkSetModelParams sets parameters for multiple models at once.
|
||||
func BulkSetModelParams(entries map[string]ModelParams) {
|
||||
getModelParamsCache().BulkSet(entries)
|
||||
}
|
||||
|
||||
// SetCacheMissHandler registers a callback invoked on cache miss.
|
||||
// The handler should query the DB for the model's parameters and return them,
|
||||
// or nil if not found. The result is automatically cached.
|
||||
func SetCacheMissHandler(fn func(model string) *ModelParams) {
|
||||
c := getModelParamsCache()
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.cacheMissHandler = fn
|
||||
}
|
||||
|
||||
// GetMaxOutputTokens returns the cached max_output_tokens for a model.
|
||||
// Returns 0, false on cache miss or if max_output_tokens is not set.
|
||||
func GetMaxOutputTokens(model string) (int, bool) {
|
||||
params, ok := GetModelParams(model)
|
||||
if !ok || params.MaxOutputTokens == nil {
|
||||
return 0, false
|
||||
}
|
||||
return *params.MaxOutputTokens, true
|
||||
}
|
||||
|
||||
// GetMaxOutputTokensOrDefault returns the cached max_output_tokens for a model,
|
||||
// or the provided default value on cache miss. For Claude models, falls back to
|
||||
// known static defaults before using the caller's default.
|
||||
func GetMaxOutputTokensOrDefault(model string, defaultValue int) int {
|
||||
if m, ok := GetMaxOutputTokens(model); ok {
|
||||
return m
|
||||
}
|
||||
if strings.Contains(model, "claude") {
|
||||
base := normalizeClaudeModelName(model)
|
||||
if base != model {
|
||||
if m, ok := GetMaxOutputTokens(base); ok {
|
||||
return m
|
||||
}
|
||||
}
|
||||
if m, ok := knownAnthropicMaxOutputTokens[base]; ok {
|
||||
return m
|
||||
}
|
||||
}
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
// normalizeClaudeModelName extracts the base Claude model name from
|
||||
// provider-specific model ID formats.
|
||||
//
|
||||
// Examples:
|
||||
//
|
||||
// "claude-sonnet-4-20250514" → "claude-sonnet-4"
|
||||
// "anthropic.claude-sonnet-4-20250514-v1:0" → "claude-sonnet-4"
|
||||
// "us.anthropic.claude-sonnet-4-20250514-v1:0" → "claude-sonnet-4"
|
||||
// "claude-3-5-sonnet-20241022" → "claude-3-5-sonnet"
|
||||
func normalizeClaudeModelName(model string) string {
|
||||
// Strip region + provider prefixes (us.anthropic., anthropic., etc.)
|
||||
if idx := strings.LastIndex(model, "."); idx >= 0 {
|
||||
model = model[idx+1:]
|
||||
}
|
||||
// Strip Bedrock version suffix (":0", ":1", etc.) and the preceding "-v1"/"-v2"
|
||||
if idx := strings.Index(model, ":"); idx >= 0 {
|
||||
model = model[:idx]
|
||||
if len(model) >= 3 {
|
||||
suffix := model[len(model)-3:]
|
||||
if suffix == "-v1" || suffix == "-v2" {
|
||||
model = model[:len(model)-3]
|
||||
}
|
||||
}
|
||||
}
|
||||
// Strip "-v1", "-v2" even without colon (e.g., "anthropic.claude-opus-4-6-v1")
|
||||
if strings.HasSuffix(model, "-v1") || strings.HasSuffix(model, "-v2") {
|
||||
model = model[:len(model)-3]
|
||||
}
|
||||
// Strip date version suffix using schemas.BaseModelName
|
||||
return schemas.BaseModelName(model)
|
||||
}
|
||||
337
core/providers/utils/modelparamscache_test.go
Normal file
337
core/providers/utils/modelparamscache_test.go
Normal file
@@ -0,0 +1,337 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// intPtr returns a pointer to a fresh int holding v.
func intPtr(v int) *int {
	p := new(int)
	*p = v
	return p
}
|
||||
|
||||
func TestModelParamsCacheGetSet(t *testing.T) {
|
||||
cache := newModelParamsCache(10)
|
||||
|
||||
cache.Set("claude-sonnet-4-20250514", ModelParams{MaxOutputTokens: intPtr(8192)})
|
||||
val, ok := cache.Get("claude-sonnet-4-20250514")
|
||||
if !ok || val.MaxOutputTokens == nil || *val.MaxOutputTokens != 8192 {
|
||||
t.Errorf("expected 8192, got %+v (ok=%v)", val, ok)
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelParamsCacheMiss(t *testing.T) {
|
||||
cache := newModelParamsCache(10)
|
||||
|
||||
val, ok := cache.Get("nonexistent-model")
|
||||
if ok || val.MaxOutputTokens != nil {
|
||||
t.Errorf("expected miss, got %+v (ok=%v)", val, ok)
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelParamsCacheUpdate(t *testing.T) {
|
||||
cache := newModelParamsCache(10)
|
||||
|
||||
cache.Set("claude-sonnet-4", ModelParams{MaxOutputTokens: intPtr(8192)})
|
||||
cache.Set("claude-sonnet-4", ModelParams{MaxOutputTokens: intPtr(16384)})
|
||||
|
||||
val, ok := cache.Get("claude-sonnet-4")
|
||||
if !ok || val.MaxOutputTokens == nil || *val.MaxOutputTokens != 16384 {
|
||||
t.Errorf("expected 16384 after update, got %+v (ok=%v)", val, ok)
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelParamsCacheEviction(t *testing.T) {
|
||||
cache := newModelParamsCache(3)
|
||||
|
||||
cache.Set("model-a", ModelParams{MaxOutputTokens: intPtr(1000)})
|
||||
cache.Set("model-b", ModelParams{MaxOutputTokens: intPtr(2000)})
|
||||
cache.Set("model-c", ModelParams{MaxOutputTokens: intPtr(3000)})
|
||||
// This should evict model-a (oldest insertion)
|
||||
cache.Set("model-d", ModelParams{MaxOutputTokens: intPtr(4000)})
|
||||
|
||||
if _, ok := cache.Get("model-a"); ok {
|
||||
t.Error("model-a should have been evicted")
|
||||
}
|
||||
if val, ok := cache.Get("model-b"); !ok || *val.MaxOutputTokens != 2000 {
|
||||
t.Errorf("model-b should still exist, got %+v (ok=%v)", val, ok)
|
||||
}
|
||||
if val, ok := cache.Get("model-d"); !ok || *val.MaxOutputTokens != 4000 {
|
||||
t.Errorf("model-d should exist, got %+v (ok=%v)", val, ok)
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelParamsCacheBulkSet(t *testing.T) {
|
||||
cache := newModelParamsCache(100)
|
||||
|
||||
entries := map[string]ModelParams{
|
||||
"claude-sonnet-4": {MaxOutputTokens: intPtr(8192)},
|
||||
"claude-opus-4": {MaxOutputTokens: intPtr(4096)},
|
||||
"gpt-4o": {MaxOutputTokens: intPtr(16384)},
|
||||
"gemini-2.0-flash": {MaxOutputTokens: intPtr(8192)},
|
||||
}
|
||||
cache.BulkSet(entries)
|
||||
|
||||
for model, expected := range entries {
|
||||
val, ok := cache.Get(model)
|
||||
if !ok || *val.MaxOutputTokens != *expected.MaxOutputTokens {
|
||||
t.Errorf("BulkSet: model %s expected %d, got %+v (ok=%v)", model, *expected.MaxOutputTokens, val, ok)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelParamsCacheBulkSetOverflow(t *testing.T) {
|
||||
cache := newModelParamsCache(3)
|
||||
|
||||
entries := map[string]ModelParams{
|
||||
"model-1": {MaxOutputTokens: intPtr(1000)},
|
||||
"model-2": {MaxOutputTokens: intPtr(2000)},
|
||||
"model-3": {MaxOutputTokens: intPtr(3000)},
|
||||
"model-4": {MaxOutputTokens: intPtr(4000)},
|
||||
"model-5": {MaxOutputTokens: intPtr(5000)},
|
||||
}
|
||||
cache.BulkSet(entries)
|
||||
|
||||
if cache.order.Len() != 3 {
|
||||
t.Errorf("expected 3 entries after overflow BulkSet, got %d", cache.order.Len())
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelParamsCacheBulkSetUpdate(t *testing.T) {
|
||||
cache := newModelParamsCache(10)
|
||||
|
||||
cache.Set("claude-sonnet-4", ModelParams{MaxOutputTokens: intPtr(4096)})
|
||||
cache.BulkSet(map[string]ModelParams{
|
||||
"claude-sonnet-4": {MaxOutputTokens: intPtr(8192)},
|
||||
})
|
||||
|
||||
val, ok := cache.Get("claude-sonnet-4")
|
||||
if !ok || *val.MaxOutputTokens != 8192 {
|
||||
t.Errorf("BulkSet should update existing entry, got %+v (ok=%v)", val, ok)
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelParamsCacheConcurrency(t *testing.T) {
|
||||
cache := newModelParamsCache(100)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 50; i++ {
|
||||
wg.Add(1)
|
||||
go func(i int) {
|
||||
defer wg.Done()
|
||||
model := fmt.Sprintf("model-%d", i)
|
||||
cache.Set(model, ModelParams{MaxOutputTokens: intPtr(i * 1000)})
|
||||
cache.Get(model)
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
if cache.order.Len() > 100 {
|
||||
t.Errorf("cache exceeded capacity: %d", cache.order.Len())
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetMaxOutputTokens(t *testing.T) {
|
||||
cache := getModelParamsCache()
|
||||
cache.Set("test-max-output", ModelParams{MaxOutputTokens: intPtr(16384)})
|
||||
|
||||
val, ok := GetMaxOutputTokens("test-max-output")
|
||||
if !ok || val != 16384 {
|
||||
t.Errorf("expected 16384, got %d (ok=%v)", val, ok)
|
||||
}
|
||||
|
||||
val, ok = GetMaxOutputTokens("missing-model-get")
|
||||
if ok || val != 0 {
|
||||
t.Errorf("expected miss, got %d (ok=%v)", val, ok)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetMaxOutputTokensNilField(t *testing.T) {
|
||||
cache := getModelParamsCache()
|
||||
cache.Set("test-nil-field", ModelParams{})
|
||||
|
||||
val, ok := GetMaxOutputTokens("test-nil-field")
|
||||
if ok || val != 0 {
|
||||
t.Errorf("expected miss for nil MaxOutputTokens, got %d (ok=%v)", val, ok)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetMaxOutputTokensOrDefault(t *testing.T) {
|
||||
cache := getModelParamsCache()
|
||||
cache.Set("test-or-default", ModelParams{MaxOutputTokens: intPtr(16384)})
|
||||
|
||||
val := GetMaxOutputTokensOrDefault("test-or-default", 4096)
|
||||
if val != 16384 {
|
||||
t.Errorf("expected cached value 16384, got %d", val)
|
||||
}
|
||||
|
||||
val = GetMaxOutputTokensOrDefault("missing-model-default", 4096)
|
||||
if val != 4096 {
|
||||
t.Errorf("expected default 4096 for missing non-claude model, got %d", val)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCacheMissHandler(t *testing.T) {
|
||||
cache := newModelParamsCache(10)
|
||||
called := false
|
||||
cache.cacheMissHandler = func(model string) *ModelParams {
|
||||
called = true
|
||||
if model == "db-model" {
|
||||
return &ModelParams{MaxOutputTokens: intPtr(32000)}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Miss handler returns a value → should be cached
|
||||
val, ok := cache.Get("db-model")
|
||||
if !ok || val.MaxOutputTokens == nil || *val.MaxOutputTokens != 32000 {
|
||||
t.Errorf("expected 32000 from miss handler, got %+v (ok=%v)", val, ok)
|
||||
}
|
||||
if !called {
|
||||
t.Error("miss handler was not called")
|
||||
}
|
||||
|
||||
// Verify it was cached (handler should not be called again)
|
||||
called = false
|
||||
val, ok = cache.Get("db-model")
|
||||
if !ok || *val.MaxOutputTokens != 32000 {
|
||||
t.Errorf("expected cached 32000, got %+v (ok=%v)", val, ok)
|
||||
}
|
||||
if called {
|
||||
t.Error("miss handler should not be called for cached entry")
|
||||
}
|
||||
|
||||
// Miss handler returns nil → should return false
|
||||
val, ok = cache.Get("unknown-model")
|
||||
if ok {
|
||||
t.Errorf("expected miss for unknown model, got %+v", val)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCacheMissHandlerNil(t *testing.T) {
|
||||
cache := newModelParamsCache(10)
|
||||
// No handler registered
|
||||
val, ok := cache.Get("any-model")
|
||||
if ok {
|
||||
t.Errorf("expected miss with nil handler, got %+v", val)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeClaudeModelName(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
expected string
|
||||
desc string
|
||||
}{
|
||||
// Anthropic direct (bare model names)
|
||||
{"claude-sonnet-4-5", "claude-sonnet-4-5", "Anthropic: no version suffix"},
|
||||
{"claude-sonnet-4-20250514", "claude-sonnet-4", "Anthropic: date suffix"},
|
||||
{"claude-opus-4-5", "claude-opus-4-5", "Anthropic: no version suffix"},
|
||||
{"claude-opus-4-6-20250514", "claude-opus-4-6", "Anthropic: date suffix"},
|
||||
{"claude-sonnet-4-6", "claude-sonnet-4-6", "Anthropic: no version suffix"},
|
||||
{"claude-3-5-sonnet-20241022", "claude-3-5-sonnet", "Anthropic: legacy date suffix"},
|
||||
{"claude-3-7-sonnet-20250219", "claude-3-7-sonnet", "Anthropic: legacy date suffix"},
|
||||
|
||||
// Bedrock (anthropic. prefix + -v1:0 suffix)
|
||||
{"anthropic.claude-3-sonnet-20240229-v1:0", "claude-3-sonnet", "Bedrock: prefix + v1:0"},
|
||||
{"anthropic.claude-opus-4-6-v1", "claude-opus-4-6", "Bedrock: prefix + v1 no colon"},
|
||||
{"anthropic.claude-3-7-sonnet-v1", "claude-3-7-sonnet", "Bedrock: prefix + v1 no colon"},
|
||||
{"anthropic.claude-sonnet-4-20250514-v1:0", "claude-sonnet-4", "Bedrock: prefix + date + v1:0"},
|
||||
{"anthropic.claude-3-5-sonnet-20241022-v1:0", "claude-3-5-sonnet", "Bedrock: prefix + legacy date + v1:0"},
|
||||
|
||||
// Bedrock with region prefix
|
||||
{"us.anthropic.claude-sonnet-4-6", "claude-sonnet-4-6", "Bedrock regional: us prefix"},
|
||||
{"us.anthropic.claude-3-sonnet-20240229-v1:0", "claude-3-sonnet", "Bedrock regional: us + v1:0"},
|
||||
{"global.anthropic.claude-opus-4-6-20260301-v1:0", "claude-opus-4-6", "Bedrock regional: global + date + v1:0"},
|
||||
{"eu.anthropic.claude-sonnet-4-5-20250929-v1:0", "claude-sonnet-4-5", "Bedrock regional: eu + date + v1:0"},
|
||||
|
||||
// Vertex (same as Anthropic direct — deployment is bare model name)
|
||||
{"claude-sonnet-4-5", "claude-sonnet-4-5", "Vertex: bare model"},
|
||||
{"claude-sonnet-4-20250514", "claude-sonnet-4", "Vertex: date suffix"},
|
||||
|
||||
// Azure (deployment names — typically bare model names)
|
||||
{"claude-opus-4-5", "claude-opus-4-5", "Azure: deployment name"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.desc, func(t *testing.T) {
|
||||
got := normalizeClaudeModelName(tt.input)
|
||||
if got != tt.expected {
|
||||
t.Errorf("normalizeClaudeModelName(%q) = %q, want %q", tt.input, got, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetMaxOutputTokensOrDefaultStaticFallback(t *testing.T) {
|
||||
// Use a fresh cache with no entries to test static fallback only
|
||||
// We test via the normalizeClaudeModelName + map lookup directly
|
||||
// since the global cache may have entries from other tests
|
||||
tests := []struct {
|
||||
model string
|
||||
expected int
|
||||
desc string
|
||||
}{
|
||||
// Anthropic direct
|
||||
{"claude-sonnet-4-20250514", 64000, "Anthropic: claude-sonnet-4"},
|
||||
{"claude-opus-4-6-20250514", 128000, "Anthropic: claude-opus-4-6"},
|
||||
{"claude-3-5-sonnet-20241022", 8192, "Anthropic: claude-3-5-sonnet"},
|
||||
|
||||
// Bedrock
|
||||
{"anthropic.claude-sonnet-4-20250514-v1:0", 64000, "Bedrock: claude-sonnet-4"},
|
||||
{"anthropic.claude-opus-4-6-v1", 128000, "Bedrock: claude-opus-4-6"},
|
||||
{"anthropic.claude-3-5-sonnet-20241022-v1:0", 8192, "Bedrock: claude-3-5-sonnet"},
|
||||
|
||||
// Bedrock with region prefix
|
||||
{"us.anthropic.claude-opus-4-6-v1:0", 128000, "Bedrock regional: claude-opus-4-6"},
|
||||
{"global.anthropic.claude-sonnet-4-5-20250929-v1:0", 64000, "Bedrock regional: claude-sonnet-4-5"},
|
||||
{"eu.anthropic.claude-3-haiku-20240307-v1:0", 4096, "Bedrock regional: claude-3-haiku"},
|
||||
|
||||
// Vertex
|
||||
{"claude-opus-4-5", 64000, "Vertex: claude-opus-4-5"},
|
||||
{"claude-haiku-4-5", 64000, "Vertex: claude-haiku-4-5"},
|
||||
|
||||
// Azure
|
||||
{"claude-3-5-sonnet-20241022", 8192, "Azure: claude-3-5-sonnet"},
|
||||
{"claude-sonnet-4-6", 64000, "Azure: claude-sonnet-4-6"},
|
||||
|
||||
// Non-Claude models should return the default
|
||||
{"gpt-4o", 4096, "Non-Claude: gpt-4o"},
|
||||
{"gemini-2.0-flash", 4096, "Non-Claude: gemini-2.0-flash"},
|
||||
{"command-r-plus", 4096, "Non-Claude: command-r-plus"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.desc, func(t *testing.T) {
|
||||
// Test the static fallback logic directly
|
||||
got := staticAnthropicFallback(tt.model, 4096)
|
||||
if got != tt.expected {
|
||||
t.Errorf("staticAnthropicFallback(%q, 4096) = %d, want %d", tt.model, got, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// staticAnthropicFallback is a test helper that mimics the fallback logic
|
||||
// in GetMaxOutputTokensOrDefault without going through the global cache.
|
||||
func staticAnthropicFallback(model string, defaultValue int) int {
|
||||
if !contains(model, "claude") {
|
||||
return defaultValue
|
||||
}
|
||||
base := normalizeClaudeModelName(model)
|
||||
if m, ok := knownAnthropicMaxOutputTokens[base]; ok {
|
||||
return m
|
||||
}
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
func contains(s, substr string) bool {
|
||||
return len(s) >= len(substr) && (s == substr || len(substr) == 0 || indexSubstring(s, substr) >= 0)
|
||||
}
|
||||
|
||||
// indexSubstring returns the byte offset of the first occurrence of substr in
// s, or -1 if substr does not occur. An empty substr is found at offset 0.
func indexSubstring(s, substr string) int {
	width := len(substr)
	for start := 0; start+width <= len(s); start++ {
		if s[start:start+width] == substr {
			return start
		}
	}
	return -1
}
|
||||
356
core/providers/utils/models.go
Normal file
356
core/providers/utils/models.go
Normal file
@@ -0,0 +1,356 @@
|
||||
// Package utils — models.go
|
||||
// Centralised pipeline for filtering and backfilling models in ListModels responses.
|
||||
//
|
||||
// Every provider's ToBifrostListModelsResponse follows the same logical steps:
|
||||
// 1. Resolve each API model's name (alias lookup → alias key; else raw model ID)
|
||||
// 2. Filter (allowlist + blacklist check on the resolved name)
|
||||
// 3. Backfill entries that were not returned by the API but should appear in output
|
||||
//
|
||||
// Providers plug in custom MatchFns to extend the default matching behaviour.
|
||||
// Example: Bedrock adds region-prefix-aware matching on top of DefaultMatchFns.
|
||||
package utils
|
||||
|
||||
import (
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"golang.org/x/text/cases"
|
||||
"golang.org/x/text/language"
|
||||
)
|
||||
|
||||
// ToDisplayName converts a raw model ID or alias key into a human-readable display name.
|
||||
// Splits on "-" or "_", title-cases each word, and joins with spaces.
|
||||
//
|
||||
// "gemini-pro" → "Gemini Pro"
|
||||
// "claude_3_opus" → "Claude 3 Opus"
|
||||
// "gpt-4-turbo" → "Gpt 4 Turbo"
|
||||
func ToDisplayName(id string) string {
|
||||
caser := cases.Title(language.English)
|
||||
parts := strings.FieldsFunc(id, func(r rune) bool {
|
||||
return r == '-' || r == '_'
|
||||
})
|
||||
if len(parts) == 0 {
|
||||
return ""
|
||||
}
|
||||
for i, part := range parts {
|
||||
if part != "" {
|
||||
parts[i] = caser.String(strings.ToLower(part))
|
||||
}
|
||||
}
|
||||
return strings.Join(parts, " ")
|
||||
}
|
||||
|
||||
// MatchFn reports whether two model ID strings should be treated as
// equivalent. A list of MatchFns is evaluated in order by matches, which stops
// at the first function that returns true.
//
// DefaultMatchFns supplies the standard list, a single case-insensitive
// comparison:
//
//	"gpt-4" vs "GPT-4" → true
//
// Providers may append further, looser comparisons to that list.
type MatchFn func(a, b string) bool
|
||||
|
||||
// DefaultMatchFns returns the standard matching functions used by most providers.
|
||||
// Currently only performs case-insensitive exact matching.
|
||||
//
|
||||
// SameBaseModel (strips version suffixes, e.g. "claude-3-5-sonnet-20241022" ≈ "claude-3-5-sonnet")
|
||||
// is intentionally excluded — users should use aliases for explicit version-to-base-name mapping.
|
||||
// It can be appended here if fuzzy base-model matching is ever needed globally.
|
||||
func DefaultMatchFns() []MatchFn {
|
||||
return []MatchFn{
|
||||
func(a, b string) bool { return strings.EqualFold(a, b) },
|
||||
}
|
||||
}
|
||||
|
||||
// matches reports whether a and b are considered equal by any of the provided fns.
|
||||
// Returns true on the first fn that returns true.
|
||||
func matches(a, b string, fns []MatchFn) bool {
|
||||
for _, fn := range fns {
|
||||
if fn(a, b) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// FilterResult describes one entry that survived ListModelsPipeline.FilterModel
// for a model returned by the provider's API. A single API model can produce
// several results: one per matching alias plus one for the raw model ID.
type FilterResult struct {
	// ResolvedID is the user-facing model name to use as the ID suffix: the
	// alias KEY when the entry came from an alias, otherwise the model ID
	// exactly as the API returned it.
	//
	//   API returns "gpt-4-turbo", aliases={"my-gpt4":"gpt-4-turbo"}
	//     → ResolvedID = "my-gpt4"
	//   API returns "gpt-3.5-turbo", no alias match
	//     → ResolvedID = "gpt-3.5-turbo"
	ResolvedID string

	// AliasValue is the provider-specific model ID behind an alias entry, so
	// callers know the underlying model. It is empty for entries that did not
	// come from an alias.
	//
	//   API returns "gpt-4-turbo", alias key "my-gpt4" matched
	//     → AliasValue = "gpt-4-turbo"
	AliasValue string
}
|
||||
|
||||
// Pipeline holds all the context needed to filter and backfill models in a
|
||||
// single ListModels response. Construct one per ToBifrostListModelsResponse call
|
||||
// and use its methods instead of passing params + matchFns to every function.
|
||||
//
|
||||
// pipeline := &providerUtils.ListModelsPipeline{
|
||||
// AllowedModels: key.Models,
|
||||
// BlacklistedModels: key.BlacklistedModels,
|
||||
// Aliases: key.Aliases,
|
||||
// Unfiltered: request.Unfiltered,
|
||||
// ProviderKey: schemas.OpenAI,
|
||||
// MatchFns: providerUtils.DefaultMatchFns(),
|
||||
// }
|
||||
// if pipeline.ShouldEarlyExit() { return empty }
|
||||
// result := pipeline.FilterModel(model.ID)
|
||||
// pipeline.BackfillModels(included)
|
||||
type ListModelsPipeline struct {
|
||||
AllowedModels schemas.WhiteList
|
||||
BlacklistedModels schemas.BlackList
|
||||
// Aliases maps user-facing alias keys to provider-specific model IDs.
|
||||
// e.g. {"my-gpt4": "gpt-4-turbo-2024-04-09"}
|
||||
Aliases map[string]string
|
||||
Unfiltered bool
|
||||
ProviderKey schemas.ModelProvider
|
||||
// MatchFns is the ordered list of equivalence functions used for every
|
||||
// model ID comparison. Use DefaultMatchFns() for standard behaviour;
|
||||
// providers may append additional fns (e.g. Bedrock's region-prefix remover).
|
||||
MatchFns []MatchFn
|
||||
}
|
||||
|
||||
// ShouldEarlyExit reports whether ToBifrostListModelsResponse should immediately
|
||||
// return an empty response without processing any models.
|
||||
//
|
||||
// Returns true when:
|
||||
// - not unfiltered AND allowlist is empty AND no aliases configured
|
||||
// (there is nothing to match against — all models would be filtered out anyway)
|
||||
// - not unfiltered AND blacklist blocks everything
|
||||
//
|
||||
// Note: allowlist empty + aliases present → do NOT early exit.
|
||||
// The aliases drive backfill in the wildcard-allowlist case (Case B of BackfillModels).
|
||||
func (p *ListModelsPipeline) ShouldEarlyExit() bool {
|
||||
if p.Unfiltered {
|
||||
return false
|
||||
}
|
||||
if p.BlacklistedModels.IsBlockAll() {
|
||||
return true
|
||||
}
|
||||
if p.AllowedModels.IsEmpty() && len(p.Aliases) == 0 {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// aliasMatch is one candidate produced by resolveModelID.
type aliasMatch struct {
	// key is the user-facing name: an alias key, or the raw model ID.
	key string
	// value is the provider model ID behind an alias; empty for the raw ID.
	value string
}
|
||||
|
||||
// resolveModelID returns all alias entries whose VALUE matches modelID using the pipeline's MatchFns,
|
||||
// plus the raw model ID itself as an additional entry so both the alias key and the original model
|
||||
// name appear in the list-models output.
|
||||
// Results are sorted by alias key (case-insensitive) for deterministic ordering.
|
||||
//
|
||||
// If one or more aliases match → returns one aliasMatch per matching alias key, plus the raw ID.
|
||||
//
|
||||
// Example: modelID="gpt-4-turbo", aliases={"my-gpt4":"gpt-4-turbo","gpt4-alias":"gpt-4-turbo"}
|
||||
// → [{key:"gpt-4-turbo", value:""}, {key:"gpt4-alias", value:"gpt-4-turbo"}, {key:"my-gpt4", value:"gpt-4-turbo"}]
|
||||
//
|
||||
// If no alias matches → returns a single entry with the original model ID and no alias value.
|
||||
//
|
||||
// Example: modelID="gpt-3.5-turbo", no alias match
|
||||
// → [{key:"gpt-3.5-turbo", value:""}]
|
||||
func (p *ListModelsPipeline) resolveModelID(modelID string) []aliasMatch {
|
||||
var candidates []aliasMatch
|
||||
for aliasKey, providerID := range p.Aliases {
|
||||
if matches(modelID, providerID, p.MatchFns) {
|
||||
candidates = append(candidates, aliasMatch{key: aliasKey, value: providerID})
|
||||
}
|
||||
}
|
||||
if len(candidates) == 0 {
|
||||
return []aliasMatch{{key: modelID, value: ""}}
|
||||
}
|
||||
// Also include the raw model ID so both the alias key and the original name appear in output.
|
||||
candidates = append(candidates, aliasMatch{key: modelID, value: ""})
|
||||
sort.Slice(candidates, func(i, j int) bool {
|
||||
return strings.ToLower(candidates[i].key) < strings.ToLower(candidates[j].key)
|
||||
})
|
||||
return candidates
|
||||
}
|
||||
|
||||
// FilterModel applies the full filter pipeline for a single model from the API response.
|
||||
//
|
||||
// Steps:
|
||||
// 1. Resolve name — check alias VALUES for a match (uses MatchFns).
|
||||
// If matched: resolvedName = alias KEY, aliasValue = provider ID.
|
||||
// If not matched: resolvedName = original modelID, aliasValue = "".
|
||||
// 2. Allowlist check (only when allowlist is restricted, i.e. not wildcard):
|
||||
// Skip if resolvedName is not in AllowedModels.
|
||||
// 3. Blacklist check (always):
|
||||
// Skip if resolvedName is blacklisted. Blacklist takes precedence over everything.
|
||||
// 4. Return one FilterResult per passing candidate.
|
||||
//
|
||||
// An empty slice means the model should be skipped entirely.
|
||||
// When multiple aliases map to the same provider model ID, each alias that passes
|
||||
// the filters produces its own FilterResult entry.
|
||||
//
|
||||
// Examples:
|
||||
//
|
||||
// allowedModels=["my-gpt4"], aliases={"my-gpt4":"gpt-4-turbo"}, blacklist=[]
|
||||
// FilterModel("gpt-4-turbo") → [{ResolvedID:"my-gpt4", AliasValue:"gpt-4-turbo"}]
|
||||
// FilterModel("gpt-3.5") → [] (not in allowlist)
|
||||
//
|
||||
// allowedModels=*, aliases={"my-gpt4":"gpt-4-turbo","gpt4-alias":"gpt-4-turbo"}, blacklist=[]
|
||||
// FilterModel("gpt-4-turbo") → [{ResolvedID:"gpt-4-turbo", AliasValue:""},
|
||||
// {ResolvedID:"gpt4-alias", AliasValue:"gpt-4-turbo"},
|
||||
// {ResolvedID:"my-gpt4", AliasValue:"gpt-4-turbo"}]
|
||||
//
|
||||
// allowedModels=["gpt-3.5"], aliases={}, blacklist=[]
|
||||
// FilterModel("gpt-3.5") → [{ResolvedID:"gpt-3.5", AliasValue:""}]
|
||||
// FilterModel("gpt-4") → []
|
||||
func (p *ListModelsPipeline) FilterModel(modelID string) []FilterResult {
|
||||
// Step 1: resolve name — collect all alias matches (or the raw ID if none match).
|
||||
candidates := p.resolveModelID(modelID)
|
||||
|
||||
var results []FilterResult
|
||||
for _, candidate := range candidates {
|
||||
resolvedName := candidate.key
|
||||
|
||||
// Step 2: allowlist check.
|
||||
// IsRestricted() is true for both an explicit list AND an empty list (deny-all).
|
||||
// Only a wildcard allowlist marker bypasses this check (pass-through).
|
||||
if !p.Unfiltered && p.AllowedModels.IsRestricted() {
|
||||
allowed := false
|
||||
for _, entry := range p.AllowedModels {
|
||||
if matches(resolvedName, entry, p.MatchFns) {
|
||||
allowed = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !allowed {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// Step 3: blacklist check — blacklist always wins regardless of allowlist or aliases.
|
||||
if !p.Unfiltered {
|
||||
blacklisted := false
|
||||
for _, entry := range p.BlacklistedModels {
|
||||
if matches(resolvedName, entry, p.MatchFns) {
|
||||
blacklisted = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if blacklisted {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
results = append(results, FilterResult{
|
||||
ResolvedID: resolvedName,
|
||||
AliasValue: candidate.value,
|
||||
})
|
||||
}
|
||||
return results
|
||||
}
|
||||
|
||||
// BackfillModels adds model entries that were configured by the caller but not
|
||||
// returned by the provider's API response (or not matched during filtering).
|
||||
//
|
||||
// The `included` map tracks model IDs (lowercased) already added during the
|
||||
// filter pass, used to avoid duplicates.
|
||||
//
|
||||
// Two cases depending on whether the allowlist is restricted:
|
||||
//
|
||||
// Case A — allowlist restricted (caller specified explicit model names):
|
||||
//
|
||||
// Add each allowlist entry that is not yet in `included`, skip if blacklisted.
|
||||
// If the entry has an alias mapping (aliases[entry] exists), set Alias to the
|
||||
// provider-specific ID so callers can route to the right model.
|
||||
//
|
||||
// Example: allowedModels=["my-gpt4","gpt-3.5"], aliases={"my-gpt4":"gpt-4-turbo"}
|
||||
// "my-gpt4" not in included → add {ID:"openai/my-gpt4", Alias:"gpt-4-turbo"}
|
||||
// "gpt-3.5" not in included → add {ID:"openai/gpt-3.5"}
|
||||
//
|
||||
// Case B — allowlist wildcard (*) only:
|
||||
//
|
||||
// We don't know all model names (no explicit list), so we only backfill entries
|
||||
// that were explicitly configured via aliases and not yet matched from the API.
|
||||
// Note: an empty allowlist is deny-all (IsRestricted()==true), not wildcard.
|
||||
//
|
||||
// Example: aliases={"my-gpt4":"gpt-4-turbo"}, "my-gpt4" not in included
|
||||
// → add {ID:"openai/my-gpt4", Alias:"gpt-4-turbo"}
|
||||
//
|
||||
// Blacklist always wins — nothing blacklisted is added in either case.
|
||||
func (p *ListModelsPipeline) BackfillModels(included map[string]bool) []schemas.Model {
|
||||
var result []schemas.Model
|
||||
|
||||
if !p.Unfiltered && p.AllowedModels.IsRestricted() {
|
||||
// Case A: backfill explicit allowlist entries not yet matched.
|
||||
for _, entry := range p.AllowedModels {
|
||||
if included[strings.ToLower(entry)] {
|
||||
continue
|
||||
}
|
||||
// Blacklist check.
|
||||
blacklisted := false
|
||||
for _, bl := range p.BlacklistedModels {
|
||||
if matches(entry, bl, p.MatchFns) {
|
||||
blacklisted = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if blacklisted {
|
||||
continue
|
||||
}
|
||||
m := schemas.Model{
|
||||
ID: string(p.ProviderKey) + "/" + entry,
|
||||
Name: schemas.Ptr(ToDisplayName(entry)),
|
||||
}
|
||||
// If this allowlist entry has an alias, surface the provider-specific ID.
|
||||
for aliasKey, providerID := range p.Aliases {
|
||||
if matches(entry, aliasKey, p.MatchFns) {
|
||||
m.Alias = schemas.Ptr(providerID)
|
||||
break
|
||||
}
|
||||
}
|
||||
result = append(result, m)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// Case B: wildcard allowlist — backfill only explicitly configured aliases.
|
||||
if !p.Unfiltered && len(p.Aliases) > 0 {
|
||||
for aliasKey, providerID := range p.Aliases {
|
||||
if included[strings.ToLower(aliasKey)] {
|
||||
continue
|
||||
}
|
||||
// Blacklist check.
|
||||
blacklisted := false
|
||||
for _, bl := range p.BlacklistedModels {
|
||||
if matches(aliasKey, bl, p.MatchFns) {
|
||||
blacklisted = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if blacklisted {
|
||||
continue
|
||||
}
|
||||
result = append(result, schemas.Model{
|
||||
ID: string(p.ProviderKey) + "/" + aliasKey,
|
||||
Name: schemas.Ptr(ToDisplayName(aliasKey)),
|
||||
Alias: schemas.Ptr(providerID),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
112
core/providers/utils/pagination.go
Normal file
112
core/providers/utils/pagination.go
Normal file
@@ -0,0 +1,112 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
schemas "github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// SerialListHelper manages serial key pagination for list operations.
|
||||
// It ensures that all pages from one key are exhausted before moving to the next,
|
||||
// guaranteeing only one API call per pagination request regardless of key count.
|
||||
type SerialListHelper struct {
|
||||
Keys []schemas.Key
|
||||
Cursor *schemas.SerialCursor
|
||||
Logger schemas.Logger
|
||||
}
|
||||
|
||||
// NewSerialListHelper creates a new SerialListHelper from the provided keys and encoded cursor.
|
||||
// If the cursor is empty or nil, pagination starts from the first key.
|
||||
// If the cursor is invalid, an error is returned.
|
||||
func NewSerialListHelper(keys []schemas.Key, encodedCursor *string, logger schemas.Logger) (*SerialListHelper, error) {
|
||||
helper := &SerialListHelper{
|
||||
Keys: keys,
|
||||
Logger: logger,
|
||||
}
|
||||
|
||||
if encodedCursor != nil && *encodedCursor != "" {
|
||||
cursor, err := schemas.DecodeSerialCursor(*encodedCursor)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
helper.Cursor = cursor
|
||||
}
|
||||
|
||||
return helper, nil
|
||||
}
|
||||
|
||||
// GetCurrentKey returns the key to query and its native cursor.
|
||||
// Returns (key, nativeCursor, true) if there's a key to query.
|
||||
// Returns (Key{}, "", false) if all keys are exhausted.
|
||||
func (h *SerialListHelper) GetCurrentKey() (schemas.Key, string, bool) {
|
||||
if len(h.Keys) == 0 {
|
||||
return schemas.Key{}, "", false
|
||||
}
|
||||
|
||||
keyIndex := 0
|
||||
nativeCursor := ""
|
||||
|
||||
if h.Cursor != nil {
|
||||
keyIndex = h.Cursor.KeyIndex
|
||||
nativeCursor = h.Cursor.Cursor
|
||||
}
|
||||
|
||||
// Check if key index is within bounds
|
||||
if keyIndex >= len(h.Keys) {
|
||||
return schemas.Key{}, "", false
|
||||
}
|
||||
|
||||
return h.Keys[keyIndex], nativeCursor, true
|
||||
}
|
||||
|
||||
// BuildNextCursor creates the cursor for the next pagination request.
|
||||
// Parameters:
|
||||
// - hasMore: whether the current key has more pages
|
||||
// - nativeCursor: the native cursor returned by the current key's API
|
||||
//
|
||||
// Returns:
|
||||
// - encodedCursor: the encoded cursor for the next request (empty if all keys exhausted)
|
||||
// - moreAvailable: true if there are more results available (either from current key or remaining keys)
|
||||
func (h *SerialListHelper) BuildNextCursor(hasMore bool, nativeCursor string) (string, bool) {
|
||||
if len(h.Keys) == 0 {
|
||||
return "", false
|
||||
}
|
||||
|
||||
currentKeyIndex := 0
|
||||
if h.Cursor != nil {
|
||||
currentKeyIndex = h.Cursor.KeyIndex
|
||||
}
|
||||
|
||||
if hasMore {
|
||||
// Current key has more pages - return cursor for same key
|
||||
nextCursor := schemas.NewSerialCursor(currentKeyIndex, nativeCursor)
|
||||
return schemas.EncodeSerialCursor(nextCursor), true
|
||||
}
|
||||
|
||||
// Current key exhausted - check if there are more keys
|
||||
nextKeyIndex := currentKeyIndex + 1
|
||||
if nextKeyIndex >= len(h.Keys) {
|
||||
// All keys exhausted
|
||||
return "", false
|
||||
}
|
||||
|
||||
// Move to next key with empty cursor (start fresh)
|
||||
nextCursor := schemas.NewSerialCursor(nextKeyIndex, "")
|
||||
return schemas.EncodeSerialCursor(nextCursor), true
|
||||
}
|
||||
|
||||
// GetCurrentKeyIndex returns the current key index being processed.
|
||||
func (h *SerialListHelper) GetCurrentKeyIndex() int {
|
||||
if h.Cursor != nil {
|
||||
return h.Cursor.KeyIndex
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// HasMoreKeys returns true if there are more keys after the current one.
|
||||
func (h *SerialListHelper) HasMoreKeys() bool {
|
||||
currentKeyIndex := 0
|
||||
if h.Cursor != nil {
|
||||
currentKeyIndex = h.Cursor.KeyIndex
|
||||
}
|
||||
return currentKeyIndex < len(h.Keys)-1
|
||||
}
|
||||
|
||||
195
core/providers/utils/sse.go
Normal file
195
core/providers/utils/sse.go
Normal file
@@ -0,0 +1,195 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"io"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
const (
	// sseInitialBufSize is the starting capacity (8 KiB) of the scanner buffer
	// used by the default SSE readers.
	sseInitialBufSize = 8 << 10
	// sseMaxBufSize is the upper bound (10 MiB) the scanner buffer may grow to,
	// leaving room for large lines such as tool calls or audio payloads.
	sseMaxBufSize = 10 << 20
)
|
||||
|
||||
// SSEDataReader consumes data-only SSE streams (Format A: OpenAI, Gemini,
// Cohere, etc.).
//
// ReadDataLine yields the next data payload with the "data:" prefix removed.
// It returns (nil, io.EOF) when the stream ends or a "data: [DONE]" marker is
// read.
type SSEDataReader interface {
	ReadDataLine() ([]byte, error)
}
|
||||
|
||||
// SSEEventReader consumes SSE streams whose events carry both a type and data
// (Format B: Anthropic, Replicate, etc.).
//
// ReadEvent yields one complete event each time an empty-line delimiter is
// reached; when an event has several "data:" lines they are joined with
// newlines. It returns ("", nil, io.EOF) when the stream ends.
type SSEEventReader interface {
	ReadEvent() (eventType string, data []byte, err error)
}
|
||||
|
||||
// SSEReaderFactory creates SSE readers for streaming response processing.
|
||||
// Enterprise injects this via BifrostContextKeySSEReaderFactory to replace
|
||||
// the default bufio.Scanner-based implementations with streaming readers.
|
||||
type SSEReaderFactory struct {
|
||||
NewDataReader func(reader io.Reader) SSEDataReader
|
||||
NewEventReader func(reader io.Reader) SSEEventReader
|
||||
}
|
||||
|
||||
// GetSSEDataReader returns an SSEDataReader for the given reader.
|
||||
// If enterprise has injected an SSEReaderFactory via context, uses that.
|
||||
// Otherwise returns a default implementation wrapping bufio.NewScanner.
|
||||
func GetSSEDataReader(ctx *schemas.BifrostContext, reader io.Reader) SSEDataReader {
|
||||
if ctx != nil {
|
||||
if factory, ok := ctx.Value(schemas.BifrostContextKeySSEReaderFactory).(*SSEReaderFactory); ok && factory != nil && factory.NewDataReader != nil {
|
||||
return factory.NewDataReader(reader)
|
||||
}
|
||||
}
|
||||
return newDefaultSSEDataReader(reader)
|
||||
}
|
||||
|
||||
// GetSSEEventReader returns an SSEEventReader for the given reader.
|
||||
// If enterprise has injected an SSEReaderFactory via context, uses that.
|
||||
// Otherwise returns a default implementation wrapping bufio.NewScanner.
|
||||
func GetSSEEventReader(ctx *schemas.BifrostContext, reader io.Reader) SSEEventReader {
|
||||
if ctx != nil {
|
||||
if factory, ok := ctx.Value(schemas.BifrostContextKeySSEReaderFactory).(*SSEReaderFactory); ok && factory != nil && factory.NewEventReader != nil {
|
||||
return factory.NewEventReader(reader)
|
||||
}
|
||||
}
|
||||
return newDefaultSSEEventReader(reader)
|
||||
}
|
||||
|
||||
// Byte-slice forms of the SSE field prefixes and the stream terminator that the
// default readers look for, built once at package level so line parsing does
// not recreate them.
var (
	sseDataPrefix  = []byte("data:")  // payload field
	sseDoneMarker  = []byte("[DONE]") // end-of-stream payload
	sseEventPrefix = []byte("event:") // event-type field
	sseIDPrefix    = []byte("id:")    // event-ID field
	sseRetryPrefix = []byte("retry:") // reconnection-time field
)
|
||||
|
||||
// defaultSSEDataReader implements SSEDataReader using bufio.NewScanner.
|
||||
// Handles Format A SSE streams (data-only: OpenAI, Gemini, Cohere, etc.).
|
||||
type defaultSSEDataReader struct {
|
||||
scanner *bufio.Scanner
|
||||
}
|
||||
|
||||
func newDefaultSSEDataReader(reader io.Reader) *defaultSSEDataReader {
|
||||
scanner := bufio.NewScanner(reader)
|
||||
scanner.Buffer(make([]byte, 0, sseInitialBufSize), sseMaxBufSize)
|
||||
return &defaultSSEDataReader{scanner: scanner}
|
||||
}
|
||||
|
||||
func (r *defaultSSEDataReader) ReadDataLine() ([]byte, error) {
|
||||
for r.scanner.Scan() {
|
||||
line := r.scanner.Bytes()
|
||||
// Skip empty lines and comments
|
||||
if len(line) == 0 || line[0] == ':' {
|
||||
continue
|
||||
}
|
||||
|
||||
// Parse "data:" lines
|
||||
if bytes.HasPrefix(line, sseDataPrefix) {
|
||||
data := line[5:] // len("data:") == 5
|
||||
if len(data) > 0 && data[0] == ' ' {
|
||||
data = data[1:]
|
||||
}
|
||||
if len(data) == 0 {
|
||||
continue
|
||||
}
|
||||
if bytes.Equal(data, sseDoneMarker) {
|
||||
return nil, io.EOF
|
||||
}
|
||||
// Copy to decouple from scanner's internal buffer
|
||||
return append([]byte(nil), data...), nil
|
||||
}
|
||||
|
||||
// Skip known SSE fields (event, id, retry)
|
||||
if bytes.HasPrefix(line, sseEventPrefix) ||
|
||||
bytes.HasPrefix(line, sseIDPrefix) ||
|
||||
bytes.HasPrefix(line, sseRetryPrefix) {
|
||||
continue
|
||||
}
|
||||
|
||||
// Non-SSE line: return as-is (raw JSON error fallback, e.g. OpenAI)
|
||||
return append([]byte(nil), line...), nil
|
||||
}
|
||||
if err := r.scanner.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return nil, io.EOF
|
||||
}
|
||||
|
||||
// defaultSSEEventReader implements SSEEventReader using bufio.NewScanner.
|
||||
// Handles Format B SSE streams (event+data: Anthropic, Replicate, Mistral, etc.).
|
||||
// Events are delimited by empty lines; multiple "data:" lines are concatenated.
|
||||
type defaultSSEEventReader struct {
|
||||
scanner *bufio.Scanner
|
||||
eventType string
|
||||
eventData []byte
|
||||
}
|
||||
|
||||
func newDefaultSSEEventReader(reader io.Reader) *defaultSSEEventReader {
|
||||
scanner := bufio.NewScanner(reader)
|
||||
scanner.Buffer(make([]byte, 0, sseInitialBufSize), sseMaxBufSize)
|
||||
return &defaultSSEEventReader{scanner: scanner}
|
||||
}
|
||||
|
||||
func (r *defaultSSEEventReader) ReadEvent() (string, []byte, error) {
|
||||
for r.scanner.Scan() {
|
||||
line := r.scanner.Bytes()
|
||||
|
||||
// Skip comments
|
||||
if len(line) > 0 && line[0] == ':' {
|
||||
continue
|
||||
}
|
||||
|
||||
// Empty line = event boundary
|
||||
if len(line) == 0 {
|
||||
if r.eventType == "" && len(r.eventData) == 0 {
|
||||
continue
|
||||
}
|
||||
eventType := r.eventType
|
||||
eventData := make([]byte, len(r.eventData))
|
||||
copy(eventData, r.eventData)
|
||||
r.eventType = ""
|
||||
r.eventData = r.eventData[:0]
|
||||
return eventType, eventData, nil
|
||||
}
|
||||
|
||||
// Parse SSE fields
|
||||
if bytes.HasPrefix(line, sseEventPrefix) {
|
||||
field := line[6:] // len("event:") == 6
|
||||
if len(field) > 0 && field[0] == ' ' {
|
||||
field = field[1:]
|
||||
}
|
||||
r.eventType = string(field)
|
||||
} else if bytes.HasPrefix(line, sseDataPrefix) {
|
||||
data := line[5:] // len("data:") == 5
|
||||
if len(data) > 0 && data[0] == ' ' {
|
||||
data = data[1:]
|
||||
}
|
||||
if len(r.eventData) > 0 {
|
||||
r.eventData = append(r.eventData, '\n')
|
||||
}
|
||||
r.eventData = append(r.eventData, data...)
|
||||
}
|
||||
// id:, retry:, and other fields are silently skipped
|
||||
}
|
||||
|
||||
// Scanner done — return any accumulated event before EOF
|
||||
if r.eventType != "" || len(r.eventData) > 0 {
|
||||
eventType := r.eventType
|
||||
eventData := make([]byte, len(r.eventData))
|
||||
copy(eventData, r.eventData)
|
||||
r.eventType = ""
|
||||
r.eventData = r.eventData[:0]
|
||||
return eventType, eventData, nil
|
||||
}
|
||||
|
||||
if err := r.scanner.Err(); err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
return "", nil, io.EOF
|
||||
}
|
||||
75
core/providers/utils/stream.go
Normal file
75
core/providers/utils/stream.go
Normal file
@@ -0,0 +1,75 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
schemas "github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// CheckFirstStreamChunkForError reads the first chunk from a streaming channel to detect
|
||||
// errors returned inside HTTP 200 SSE streams (e.g., providers that send rate limit
|
||||
// errors as SSE events instead of HTTP 429).
|
||||
//
|
||||
// If the first chunk is an error, it drains the source channel in the background
|
||||
// (so the provider goroutine can exit cleanly) and returns the error for synchronous
|
||||
// handling, enabling retries and fallbacks. The returned drainDone channel is closed
|
||||
// once the drain completes — callers must wait on it before releasing any resources
|
||||
// (e.g., plugin pipelines) that the provider goroutine's postHookRunner may still reference.
|
||||
//
|
||||
// If the first chunk is valid data, it returns a wrapped channel that re-emits
|
||||
// the first chunk followed by all remaining chunks from the source. drainDone is
|
||||
// closed when the wrapper goroutine finishes forwarding the source stream.
|
||||
//
|
||||
// If the source channel is closed immediately (empty stream), it returns a
|
||||
// nil channel with nil error. drainDone is already closed.
|
||||
//
|
||||
// The ctx argument cancels the background forwarding goroutine if the consumer
|
||||
// abandons the returned wrapped channel. On ctx.Done the goroutine drains the
|
||||
// source stream so the upstream provider's blocked send can exit cleanly.
|
||||
func CheckFirstStreamChunkForError(
|
||||
ctx context.Context,
|
||||
stream chan *schemas.BifrostStreamChunk,
|
||||
) (chan *schemas.BifrostStreamChunk, <-chan struct{}, *schemas.BifrostError) {
|
||||
firstChunk, ok := <-stream
|
||||
if !ok {
|
||||
// Channel closed immediately (empty stream) — return nil so callers
|
||||
// can distinguish this from a live stream channel.
|
||||
done := make(chan struct{})
|
||||
close(done)
|
||||
return nil, done, nil
|
||||
}
|
||||
|
||||
// Check if first chunk is an error
|
||||
if firstChunk.BifrostError != nil && firstChunk.BifrostError.Error != nil &&
|
||||
(firstChunk.BifrostError.Error.Message != "" || firstChunk.BifrostError.Error.Code != nil || firstChunk.BifrostError.Error.Type != nil) {
|
||||
// Drain source channel to let the provider goroutine exit cleanly
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
for range stream {
|
||||
}
|
||||
}()
|
||||
return nil, done, firstChunk.BifrostError
|
||||
}
|
||||
|
||||
// First chunk is valid data — wrap channel to re-inject it
|
||||
done := make(chan struct{})
|
||||
wrapped := make(chan *schemas.BifrostStreamChunk, max(cap(stream), 1))
|
||||
wrapped <- firstChunk
|
||||
go func() {
|
||||
defer close(done)
|
||||
defer close(wrapped)
|
||||
for chunk := range stream {
|
||||
select {
|
||||
case wrapped <- chunk:
|
||||
case <-ctx.Done():
|
||||
// Consumer abandoned the wrapped channel. Drain the source so the
|
||||
// provider's blocked send unblocks and its goroutine can exit.
|
||||
for range stream {
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
return wrapped, done, nil
|
||||
}
|
||||
252
core/providers/utils/stream_test.go
Normal file
252
core/providers/utils/stream_test.go
Normal file
@@ -0,0 +1,252 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
schemas "github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
func TestCheckFirstStreamChunk_ErrorInFirstChunk(t *testing.T) {
|
||||
stream := make(chan *schemas.BifrostStreamChunk, 2)
|
||||
stream <- &schemas.BifrostStreamChunk{
|
||||
BifrostError: &schemas.BifrostError{
|
||||
Error: &schemas.ErrorField{
|
||||
Code: schemas.Ptr("limit_burst_rate"),
|
||||
Message: "Request rate increased too quickly",
|
||||
},
|
||||
},
|
||||
}
|
||||
close(stream)
|
||||
|
||||
_, drainDone, err := CheckFirstStreamChunkForError(context.Background(), stream)
|
||||
if err == nil {
|
||||
t.Fatal("expected error, got nil")
|
||||
}
|
||||
<-drainDone
|
||||
if err.Error.Message != "Request rate increased too quickly" {
|
||||
t.Errorf("unexpected error message: %s", err.Error.Message)
|
||||
}
|
||||
if err.Error.Code == nil || *err.Error.Code != "limit_burst_rate" {
|
||||
t.Errorf("unexpected error code: %v", err.Error.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckFirstStreamChunk_ValidFirstChunk(t *testing.T) {
|
||||
stream := make(chan *schemas.BifrostStreamChunk, 3)
|
||||
chunk1 := &schemas.BifrostStreamChunk{
|
||||
BifrostChatResponse: &schemas.BifrostChatResponse{
|
||||
ID: "chatcmpl-123",
|
||||
},
|
||||
}
|
||||
chunk2 := &schemas.BifrostStreamChunk{
|
||||
BifrostChatResponse: &schemas.BifrostChatResponse{
|
||||
ID: "chatcmpl-123",
|
||||
},
|
||||
}
|
||||
stream <- chunk1
|
||||
stream <- chunk2
|
||||
close(stream)
|
||||
|
||||
wrapped, _, err := CheckFirstStreamChunkForError(context.Background(), stream)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// First chunk should be re-injected
|
||||
got1 := <-wrapped
|
||||
if got1.BifrostChatResponse == nil || got1.BifrostChatResponse.ID != "chatcmpl-123" {
|
||||
t.Error("first chunk not re-injected correctly")
|
||||
}
|
||||
|
||||
// Second chunk should follow
|
||||
got2 := <-wrapped
|
||||
if got2.BifrostChatResponse == nil || got2.BifrostChatResponse.ID != "chatcmpl-123" {
|
||||
t.Error("second chunk not forwarded correctly")
|
||||
}
|
||||
|
||||
// Channel should be closed
|
||||
_, ok := <-wrapped
|
||||
if ok {
|
||||
t.Error("expected wrapped channel to be closed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckFirstStreamChunk_EmptyStream(t *testing.T) {
|
||||
stream := make(chan *schemas.BifrostStreamChunk)
|
||||
close(stream)
|
||||
|
||||
wrapped, drainDone, err := CheckFirstStreamChunkForError(context.Background(), stream)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// Empty stream should return nil channel
|
||||
if wrapped != nil {
|
||||
t.Error("expected nil channel for empty stream")
|
||||
}
|
||||
|
||||
// drainDone should be already closed
|
||||
select {
|
||||
case <-drainDone:
|
||||
default:
|
||||
t.Error("expected drainDone to be closed for empty stream")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckFirstStreamChunk_ErrorInSecondChunk(t *testing.T) {
|
||||
stream := make(chan *schemas.BifrostStreamChunk, 3)
|
||||
stream <- &schemas.BifrostStreamChunk{
|
||||
BifrostChatResponse: &schemas.BifrostChatResponse{
|
||||
ID: "chatcmpl-123",
|
||||
},
|
||||
}
|
||||
stream <- &schemas.BifrostStreamChunk{
|
||||
BifrostError: &schemas.BifrostError{
|
||||
Error: &schemas.ErrorField{
|
||||
Message: "some error in second chunk",
|
||||
},
|
||||
},
|
||||
}
|
||||
close(stream)
|
||||
|
||||
// Should NOT return error — only first chunk matters for retry
|
||||
wrapped, _, err := CheckFirstStreamChunkForError(context.Background(), stream)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// Read all chunks
|
||||
got1 := <-wrapped
|
||||
if got1.BifrostChatResponse == nil {
|
||||
t.Error("first chunk should be valid data")
|
||||
}
|
||||
got2 := <-wrapped
|
||||
if got2.BifrostError == nil {
|
||||
t.Error("second chunk should be the error")
|
||||
}
|
||||
|
||||
_, ok := <-wrapped
|
||||
if ok {
|
||||
t.Error("expected wrapped channel to be closed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckFirstStreamChunk_ErrorDrainsSource(t *testing.T) {
|
||||
stream := make(chan *schemas.BifrostStreamChunk, 5)
|
||||
stream <- &schemas.BifrostStreamChunk{
|
||||
BifrostError: &schemas.BifrostError{
|
||||
Error: &schemas.ErrorField{
|
||||
Message: "rate limit error",
|
||||
},
|
||||
},
|
||||
}
|
||||
// Add more chunks that should be drained
|
||||
stream <- &schemas.BifrostStreamChunk{
|
||||
BifrostChatResponse: &schemas.BifrostChatResponse{ID: "1"},
|
||||
}
|
||||
stream <- &schemas.BifrostStreamChunk{
|
||||
BifrostChatResponse: &schemas.BifrostChatResponse{ID: "2"},
|
||||
}
|
||||
close(stream)
|
||||
|
||||
_, drainDone, err := CheckFirstStreamChunkForError(context.Background(), stream)
|
||||
if err == nil {
|
||||
t.Fatal("expected error, got nil")
|
||||
}
|
||||
<-drainDone
|
||||
if err.Error.Message != "rate limit error" {
|
||||
t.Errorf("unexpected error message: %s", err.Error.Message)
|
||||
}
|
||||
if drainDone == nil {
|
||||
t.Fatal("expected drainDone channel, got nil")
|
||||
}
|
||||
// Wait for drain to complete — verifies the channel signals properly
|
||||
<-drainDone
|
||||
}
|
||||
|
||||
func TestCheckFirstStreamChunk_ErrorWithEmptyMessage(t *testing.T) {
|
||||
// Error with empty message and no code/type should NOT be treated as an error
|
||||
stream := make(chan *schemas.BifrostStreamChunk, 2)
|
||||
stream <- &schemas.BifrostStreamChunk{
|
||||
BifrostError: &schemas.BifrostError{
|
||||
Error: &schemas.ErrorField{
|
||||
Message: "",
|
||||
},
|
||||
},
|
||||
}
|
||||
close(stream)
|
||||
|
||||
wrapped, _, err := CheckFirstStreamChunkForError(context.Background(), stream)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error for empty message: %v", err)
|
||||
}
|
||||
// Should be treated as valid chunk
|
||||
<-wrapped
|
||||
}
|
||||
|
||||
func TestCheckFirstStreamChunk_CtxCancelUnblocksWrapper(t *testing.T) {
|
||||
// Source with cap=1 so wrapped also has cap=1. wrapped is left full by
|
||||
// the re-injected first chunk, which makes the forwarder goroutine block
|
||||
// on its next send — the exact leak condition this test guards against.
|
||||
src := make(chan *schemas.BifrostStreamChunk, 1)
|
||||
src <- &schemas.BifrostStreamChunk{
|
||||
BifrostChatResponse: &schemas.BifrostChatResponse{ID: "1"},
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
wrapped, drainDone, err := CheckFirstStreamChunkForError(ctx, src)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if wrapped == nil {
|
||||
t.Fatal("expected wrapped channel, got nil")
|
||||
}
|
||||
|
||||
// Push a second chunk; forwarder will read it from src and then block
|
||||
// trying to send into the full wrapped channel (we intentionally never
|
||||
// read from wrapped).
|
||||
src <- &schemas.BifrostStreamChunk{
|
||||
BifrostChatResponse: &schemas.BifrostChatResponse{ID: "2"},
|
||||
}
|
||||
|
||||
// Cancel — forwarder must stop trying to send to wrapped and drain src.
|
||||
cancel()
|
||||
|
||||
// Simulate the upstream producer still emitting, then closing. The
|
||||
// drain loop should consume these and terminate.
|
||||
src <- &schemas.BifrostStreamChunk{
|
||||
BifrostChatResponse: &schemas.BifrostChatResponse{ID: "3"},
|
||||
}
|
||||
close(src)
|
||||
|
||||
select {
|
||||
case <-drainDone:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("drainDone did not close after ctx cancel; forwarder goroutine leaked")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckFirstStreamChunk_CodeOnlyError(t *testing.T) {
|
||||
// Error with code but no message should be treated as an error
|
||||
stream := make(chan *schemas.BifrostStreamChunk, 2)
|
||||
stream <- &schemas.BifrostStreamChunk{
|
||||
BifrostError: &schemas.BifrostError{
|
||||
Error: &schemas.ErrorField{
|
||||
Code: schemas.Ptr("limit_burst_rate"),
|
||||
},
|
||||
},
|
||||
}
|
||||
close(stream)
|
||||
|
||||
_, drainDone, err := CheckFirstStreamChunkForError(context.Background(), stream)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for code-only error, got nil")
|
||||
}
|
||||
<-drainDone
|
||||
if err.Error.Code == nil || *err.Error.Code != "limit_burst_rate" {
|
||||
t.Errorf("unexpected error code: %v", err.Error.Code)
|
||||
}
|
||||
}
|
||||
218
core/providers/utils/streaming_client_test.go
Normal file
218
core/providers/utils/streaming_client_test.go
Normal file
@@ -0,0 +1,218 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// TestBuildStreamingClient_ZerosReadWriteTimeout verifies the streaming client
|
||||
// has ReadTimeout=0 / WriteTimeout=0 / MaxConnDuration=0 while preserving other
|
||||
// config from the base.
|
||||
func TestBuildStreamingClient_ZerosReadWriteTimeout(t *testing.T) {
|
||||
base := &fasthttp.Client{
|
||||
ReadTimeout: 30 * time.Second,
|
||||
WriteTimeout: 30 * time.Second,
|
||||
MaxConnDuration: 5 * time.Minute,
|
||||
MaxConnWaitTimeout: 15 * time.Second,
|
||||
MaxConnsPerHost: 123,
|
||||
}
|
||||
ConfigureDialer(base)
|
||||
|
||||
stream := BuildStreamingClient(base)
|
||||
|
||||
if stream.ReadTimeout != 0 {
|
||||
t.Errorf("ReadTimeout: got %v, want 0", stream.ReadTimeout)
|
||||
}
|
||||
if stream.WriteTimeout != 0 {
|
||||
t.Errorf("WriteTimeout: got %v, want 0", stream.WriteTimeout)
|
||||
}
|
||||
if stream.MaxConnDuration != 0 {
|
||||
t.Errorf("MaxConnDuration: got %v, want 0", stream.MaxConnDuration)
|
||||
}
|
||||
if !stream.StreamResponseBody {
|
||||
t.Error("StreamResponseBody: got false, want true")
|
||||
}
|
||||
if stream.MaxConnWaitTimeout != base.MaxConnWaitTimeout {
|
||||
t.Errorf("MaxConnWaitTimeout should be preserved: got %v, want %v",
|
||||
stream.MaxConnWaitTimeout, base.MaxConnWaitTimeout)
|
||||
}
|
||||
if stream.MaxConnsPerHost != base.MaxConnsPerHost {
|
||||
t.Errorf("MaxConnsPerHost should be preserved: got %v, want %v",
|
||||
stream.MaxConnsPerHost, base.MaxConnsPerHost)
|
||||
}
|
||||
}
|
||||
|
||||
// TestBuildStreamingClient_BaseUnchanged verifies BuildStreamingClient does not
|
||||
// mutate the base client (since unary callers still need the 30s timeout).
|
||||
func TestBuildStreamingClient_BaseUnchanged(t *testing.T) {
|
||||
base := &fasthttp.Client{
|
||||
ReadTimeout: 30 * time.Second,
|
||||
WriteTimeout: 30 * time.Second,
|
||||
MaxConnDuration: 5 * time.Minute,
|
||||
}
|
||||
_ = BuildStreamingClient(base)
|
||||
|
||||
if base.ReadTimeout != 30*time.Second {
|
||||
t.Errorf("base ReadTimeout mutated: got %v, want 30s", base.ReadTimeout)
|
||||
}
|
||||
if base.MaxConnDuration != 5*time.Minute {
|
||||
t.Errorf("base MaxConnDuration mutated: got %v, want 5m", base.MaxConnDuration)
|
||||
}
|
||||
}
|
||||
|
||||
// TestBuildStreamingClient_LongStreamSurvives verifies that a stream sending
|
||||
// chunks every 500ms for 2.5s (total) is not killed by the base client's 1s
|
||||
// ReadTimeout. Before the fix, fasthttp would abort at ~1s.
|
||||
func TestBuildStreamingClient_LongStreamSurvives(t *testing.T) {
|
||||
const chunkInterval = 500 * time.Millisecond
|
||||
const totalChunks = 5 // 2.5s total, well past base ReadTimeout=1s
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.Header().Set("Cache-Control", "no-cache")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
flusher, _ := w.(http.Flusher)
|
||||
for i := 0; i < totalChunks; i++ {
|
||||
fmt.Fprintf(w, "data: chunk-%d\n\n", i)
|
||||
if flusher != nil {
|
||||
flusher.Flush()
|
||||
}
|
||||
time.Sleep(chunkInterval)
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
base := &fasthttp.Client{
|
||||
ReadTimeout: 1 * time.Second, // would abort the stream without the fix
|
||||
WriteTimeout: 1 * time.Second,
|
||||
}
|
||||
ConfigureDialer(base)
|
||||
stream := BuildStreamingClient(base)
|
||||
|
||||
req := fasthttp.AcquireRequest()
|
||||
resp := fasthttp.AcquireResponse()
|
||||
defer fasthttp.ReleaseRequest(req)
|
||||
defer fasthttp.ReleaseResponse(resp)
|
||||
|
||||
req.SetRequestURI(srv.URL)
|
||||
req.Header.SetMethod(http.MethodGet)
|
||||
resp.StreamBody = true
|
||||
|
||||
if err := stream.Do(req, resp); err != nil {
|
||||
t.Fatalf("Do: %v", err)
|
||||
}
|
||||
if resp.StatusCode() != http.StatusOK {
|
||||
t.Fatalf("status: %d", resp.StatusCode())
|
||||
}
|
||||
|
||||
scanner := bufio.NewScanner(resp.BodyStream())
|
||||
got := 0
|
||||
for scanner.Scan() {
|
||||
if line := scanner.Text(); len(line) >= 5 && line[:5] == "data:" {
|
||||
got++
|
||||
}
|
||||
}
|
||||
if err := scanner.Err(); err != nil {
|
||||
t.Fatalf("scanner: %v", err)
|
||||
}
|
||||
if got != totalChunks {
|
||||
t.Errorf("chunks received: got %d, want %d (stream was likely killed early)", got, totalChunks)
|
||||
}
|
||||
}
|
||||
|
||||
// TestBuildStreamingHTTPClient_ZerosTimeout verifies the net/http streaming
|
||||
// client has Timeout=0 and shares the base's Transport.
|
||||
func TestBuildStreamingHTTPClient_ZerosTimeout(t *testing.T) {
|
||||
transport := &http.Transport{ResponseHeaderTimeout: 10 * time.Second}
|
||||
base := &http.Client{
|
||||
Transport: transport,
|
||||
Timeout: 30 * time.Second,
|
||||
}
|
||||
|
||||
stream := BuildStreamingHTTPClient(base)
|
||||
|
||||
if stream.Timeout != 0 {
|
||||
t.Errorf("Timeout: got %v, want 0", stream.Timeout)
|
||||
}
|
||||
if stream.Transport != base.Transport {
|
||||
t.Error("Transport: streaming client should share base's Transport")
|
||||
}
|
||||
if base.Timeout != 30*time.Second {
|
||||
t.Errorf("base Timeout mutated: got %v, want 30s", base.Timeout)
|
||||
}
|
||||
}
|
||||
|
||||
// TestBuildStreamingHTTPClient_Nil verifies nil base returns empty client
|
||||
// (not a panic).
|
||||
func TestBuildStreamingHTTPClient_Nil(t *testing.T) {
|
||||
stream := BuildStreamingHTTPClient(nil)
|
||||
if stream == nil {
|
||||
t.Fatal("BuildStreamingHTTPClient(nil) returned nil")
|
||||
}
|
||||
if stream.Timeout != 0 {
|
||||
t.Errorf("Timeout: got %v, want 0", stream.Timeout)
|
||||
}
|
||||
}
|
||||
|
||||
// TestBuildStreamingHTTPClient_LongStreamSurvives verifies that the streaming
|
||||
// client can read a response body that takes longer than the base client's
|
||||
// Timeout — proving Timeout=0 actually lifts the whole-request deadline.
|
||||
func TestBuildStreamingHTTPClient_LongStreamSurvives(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
flusher, _ := w.(http.Flusher)
|
||||
for i := 0; i < 4; i++ {
|
||||
fmt.Fprintf(w, "data: chunk-%d\n\n", i)
|
||||
if flusher != nil {
|
||||
flusher.Flush()
|
||||
}
|
||||
time.Sleep(400 * time.Millisecond)
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
base := &http.Client{
|
||||
Transport: &http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
DialContext: (&net.Dialer{Timeout: 5 * time.Second}).DialContext,
|
||||
ResponseHeaderTimeout: 5 * time.Second,
|
||||
},
|
||||
Timeout: 500 * time.Millisecond, // would abort the stream without the fix
|
||||
}
|
||||
stream := BuildStreamingHTTPClient(base)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, srv.URL, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("NewRequestWithContext: %v", err)
|
||||
}
|
||||
resp, err := stream.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Do: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
got := 0
|
||||
for scanner.Scan() {
|
||||
if line := scanner.Text(); len(line) >= 5 && line[:5] == "data:" {
|
||||
got++
|
||||
}
|
||||
}
|
||||
if err := scanner.Err(); err != nil {
|
||||
t.Fatalf("scanner: %v", err)
|
||||
}
|
||||
if got != 4 {
|
||||
t.Errorf("chunks received: got %d, want 4 (stream was likely killed by Timeout)", got)
|
||||
}
|
||||
}
|
||||
274
core/providers/utils/streamterminaldetector.go
Normal file
274
core/providers/utils/streamterminaldetector.go
Normal file
@@ -0,0 +1,274 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"strings"
|
||||
|
||||
"github.com/bytedance/sonic"
|
||||
)
|
||||
|
||||
// maxTerminalDetectorBufferBytes is the size (256 KiB) above which
// StreamTerminalDetector trims its buffer of not-yet-framed bytes.
const maxTerminalDetectorBufferBytes = 256 << 10

// Byte sequences that end an SSE frame: a blank line, written with either
// bare-LF or CRLF line endings.
var (
	sseFrameDelimiterLF   = []byte("\n\n")
	sseFrameDelimiterCRLF = []byte("\r\n\r\n")
)
|
||||
|
||||
// StreamTerminalDetector consumes a stream chunk by chunk and reports when a
// semantic end-of-stream marker (a finish reason or the [DONE] sentinel) has
// been seen. The zero value is ready to use.
type StreamTerminalDetector struct {
	// pending accumulates bytes that have not yet been consumed as a
	// complete, delimiter-terminated frame.
	pending bytes.Buffer
}
|
||||
|
||||
// ObserveChunk ingests a new raw stream chunk and returns true if a terminal
|
||||
// marker was detected in a parsed frame payload.
|
||||
func (d *StreamTerminalDetector) ObserveChunk(chunk []byte) bool {
|
||||
if len(chunk) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
// Fast path: detect terminal markers when a single chunk already contains
|
||||
// a complete payload (SSE data line or plain JSON body).
|
||||
// Skip this when the chunk already contains full SSE frame delimiters,
|
||||
// because multi-event chunks need frame-by-frame parsing.
|
||||
if !containsSSEFrameDelimiter(chunk) && d.detectInFrame(chunk) {
|
||||
return true
|
||||
}
|
||||
|
||||
d.pending.Write(chunk)
|
||||
|
||||
for {
|
||||
data := d.pending.Bytes()
|
||||
delimIdx, delimLen := findFirstSSEFrameDelimiter(data)
|
||||
if delimIdx < 0 {
|
||||
break
|
||||
}
|
||||
|
||||
frame := append([]byte(nil), data[:delimIdx]...)
|
||||
d.pending.Next(delimIdx + delimLen)
|
||||
if d.detectInFrame(frame) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Some passthrough streams emit plain JSON chunks (no SSE "\n\n" framing).
|
||||
// Try parsing the current pending buffer as a whole JSON payload.
|
||||
if d.detectInUndelimitedPending() {
|
||||
return true
|
||||
}
|
||||
|
||||
// Keep memory bounded if the upstream never emits a frame delimiter.
|
||||
if d.pending.Len() > maxTerminalDetectorBufferBytes {
|
||||
drain := d.pending.Bytes()
|
||||
if idx, delimLen := findLastSSEFrameDelimiter(drain); idx >= 0 {
|
||||
d.pending.Next(idx + delimLen)
|
||||
} else {
|
||||
trimTo := maxTerminalDetectorBufferBytes / 2
|
||||
keptPrefix := append([]byte(nil), drain[:trimTo]...)
|
||||
d.pending.Reset()
|
||||
d.pending.Write(keptPrefix)
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (d *StreamTerminalDetector) detectInUndelimitedPending() bool {
|
||||
if d.pending.Len() == 0 {
|
||||
return false
|
||||
}
|
||||
payload := bytes.TrimSpace(d.pending.Bytes())
|
||||
if len(payload) == 0 {
|
||||
return false
|
||||
}
|
||||
return hasFinishReasonMarker(payload)
|
||||
}
|
||||
|
||||
func (d *StreamTerminalDetector) detectInFrame(frame []byte) bool {
|
||||
payload := extractSSEDataPayload(frame)
|
||||
if len(payload) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
text := strings.TrimSpace(string(payload))
|
||||
if text == "" {
|
||||
return false
|
||||
}
|
||||
if text == "[DONE]" {
|
||||
return true
|
||||
}
|
||||
|
||||
return hasFinishReasonMarker([]byte(text))
|
||||
}
|
||||
|
||||
func extractSSEDataPayload(frame []byte) []byte {
|
||||
trimmed := bytes.TrimSpace(frame)
|
||||
if len(trimmed) == 0 {
|
||||
return nil
|
||||
}
|
||||
if !hasSSEDataLinePrefix(trimmed) {
|
||||
return trimmed
|
||||
}
|
||||
|
||||
lines := bytes.Split(trimmed, []byte("\n"))
|
||||
var payload bytes.Buffer
|
||||
for _, line := range lines {
|
||||
line = bytes.TrimSpace(line)
|
||||
if len(line) == 0 || bytes.HasPrefix(line, []byte(":")) {
|
||||
continue
|
||||
}
|
||||
if bytes.HasPrefix(line, []byte("data:")) {
|
||||
data := bytes.TrimSpace(bytes.TrimPrefix(line, []byte("data:")))
|
||||
if len(data) == 0 {
|
||||
continue
|
||||
}
|
||||
if payload.Len() > 0 {
|
||||
payload.WriteByte('\n')
|
||||
}
|
||||
payload.Write(data)
|
||||
}
|
||||
}
|
||||
if payload.Len() == 0 {
|
||||
return nil
|
||||
}
|
||||
return payload.Bytes()
|
||||
}
|
||||
|
||||
// hasSSEDataLinePrefix reports whether any line of frame, after trimming
// surrounding whitespace, begins with the SSE "data:" field name.
func hasSSEDataLinePrefix(frame []byte) bool {
	dataField := []byte("data:")
	newline := []byte("\n")

	for remaining := frame; len(remaining) > 0; {
		var line []byte
		line, remaining, _ = bytes.Cut(remaining, newline)
		// Blank lines and ":" comment lines can never match the prefix.
		if bytes.HasPrefix(bytes.TrimSpace(line), dataField) {
			return true
		}
	}
	return false
}
|
||||
|
||||
func hasFinishReasonMarker(payload []byte) bool {
|
||||
var root any
|
||||
if err := sonic.Unmarshal(payload, &root); err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
switch v := root.(type) {
|
||||
case map[string]any:
|
||||
return hasTerminalMarkerInTopLevelObject(v)
|
||||
case []any:
|
||||
for _, item := range v {
|
||||
obj, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if hasTerminalMarkerInTopLevelObject(obj) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// hasTerminalMarkerInTopLevelObject reports whether a decoded top-level JSON
// object marks the end of a stream. It is terminal when any of these hold:
//
//   - the object itself has a valid finishReason / finish_reason;
//   - it has a non-empty "candidates" array in which every candidate has a
//     valid finish reason;
//   - it has a promptFeedback object with a non-blank blockReason.
//
// usageMetadata is deliberately not a signal on its own: it can show up before
// the terminal chunk in long-running streams.
func hasTerminalMarkerInTopLevelObject(root map[string]any) bool {
	if hasValidFinishReasonValue(root) {
		return true
	}

	// Gemini/Vertex streamGenerateContent often signals terminal state in
	// top-level candidates[*].finishReason. Unfinished (or empty) candidates
	// must not short-circuit the promptFeedback check below.
	if candidates, ok := root["candidates"].([]any); ok && allCandidatesFinished(candidates) {
		return true
	}

	// A block reason on the prompt is treated as terminal regardless of what
	// the candidates array contains.
	if feedback, ok := root["promptFeedback"].(map[string]any); ok {
		if reason, ok := feedback["blockReason"].(string); ok && strings.TrimSpace(reason) != "" {
			return true
		}
	}
	return false
}

// allCandidatesFinished reports whether candidates is non-empty and every
// element is a JSON object carrying a valid finish reason.
func allCandidatesFinished(candidates []any) bool {
	if len(candidates) == 0 {
		return false
	}
	for _, candidate := range candidates {
		candidateMap, ok := candidate.(map[string]any)
		if !ok || !hasValidFinishReasonValue(candidateMap) {
			return false
		}
	}
	return true
}

// hasValidFinishReasonValue reports whether node has a "finishReason" or
// "finish_reason" key whose value is a string that, after trimming
// whitespace, is neither empty nor the FINISH_REASON_UNSPECIFIED placeholder.
func hasValidFinishReasonValue(node map[string]any) bool {
	for _, key := range []string{"finishReason", "finish_reason"} {
		raw, ok := node[key].(string)
		if !ok {
			continue
		}
		// Trim once so both the emptiness and the placeholder checks see the
		// same normalized value.
		switch strings.TrimSpace(raw) {
		case "", "FINISH_REASON_UNSPECIFIED":
			continue
		}
		return true
	}
	return false
}
|
||||
|
||||
func findFirstSSEFrameDelimiter(data []byte) (idx int, delimLen int) {
|
||||
idxLF := bytes.Index(data, sseFrameDelimiterLF)
|
||||
idxCRLF := bytes.Index(data, sseFrameDelimiterCRLF)
|
||||
|
||||
switch {
|
||||
case idxLF < 0 && idxCRLF < 0:
|
||||
return -1, 0
|
||||
case idxLF < 0:
|
||||
return idxCRLF, len(sseFrameDelimiterCRLF)
|
||||
case idxCRLF < 0:
|
||||
return idxLF, len(sseFrameDelimiterLF)
|
||||
case idxCRLF < idxLF:
|
||||
return idxCRLF, len(sseFrameDelimiterCRLF)
|
||||
default:
|
||||
return idxLF, len(sseFrameDelimiterLF)
|
||||
}
|
||||
}
|
||||
|
||||
func findLastSSEFrameDelimiter(data []byte) (idx int, delimLen int) {
|
||||
idxLF := bytes.LastIndex(data, sseFrameDelimiterLF)
|
||||
idxCRLF := bytes.LastIndex(data, sseFrameDelimiterCRLF)
|
||||
|
||||
switch {
|
||||
case idxLF < 0 && idxCRLF < 0:
|
||||
return -1, 0
|
||||
case idxLF < 0:
|
||||
return idxCRLF, len(sseFrameDelimiterCRLF)
|
||||
case idxCRLF < 0:
|
||||
return idxLF, len(sseFrameDelimiterLF)
|
||||
case idxCRLF > idxLF:
|
||||
return idxCRLF, len(sseFrameDelimiterCRLF)
|
||||
default:
|
||||
return idxLF, len(sseFrameDelimiterLF)
|
||||
}
|
||||
}
|
||||
|
||||
func containsSSEFrameDelimiter(data []byte) bool {
|
||||
return bytes.Contains(data, sseFrameDelimiterLF) || bytes.Contains(data, sseFrameDelimiterCRLF)
|
||||
}
|
||||
155
core/providers/utils/streamterminaldetector_test.go
Normal file
155
core/providers/utils/streamterminaldetector_test.go
Normal file
@@ -0,0 +1,155 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestStreamTerminalDetectorObserveChunkSSEFinishReasonAcrossChunks(t *testing.T) {
|
||||
detector := &StreamTerminalDetector{}
|
||||
|
||||
chunks := [][]byte{
|
||||
[]byte("data: {\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"hi\"}]},"),
|
||||
[]byte("\"finishReason\":\"STOP\"}]}\n\n"),
|
||||
}
|
||||
|
||||
if detector.ObserveChunk(chunks[0]) {
|
||||
t.Fatalf("unexpected terminal detection on first chunk")
|
||||
}
|
||||
if !detector.ObserveChunk(chunks[1]) {
|
||||
t.Fatalf("expected terminal detection for candidates finishReason")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamTerminalDetectorObserveChunkSSETopLevelFinishReasonAcrossChunks(t *testing.T) {
|
||||
detector := &StreamTerminalDetector{}
|
||||
chunks := [][]byte{
|
||||
[]byte("data: {\"id\":\"abc\","),
|
||||
[]byte("\"finishReason\":\"STOP\"}\n\n"),
|
||||
}
|
||||
if detector.ObserveChunk(chunks[0]) {
|
||||
t.Fatalf("unexpected terminal detection on first chunk")
|
||||
}
|
||||
if !detector.ObserveChunk(chunks[1]) {
|
||||
t.Fatalf("expected terminal detection for top-level finishReason")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamTerminalDetectorObserveChunkDoneMarker(t *testing.T) {
|
||||
detector := &StreamTerminalDetector{}
|
||||
if !detector.ObserveChunk([]byte("data: [DONE]\n\n")) {
|
||||
t.Fatalf("expected [DONE] marker to be terminal")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamTerminalDetectorObserveChunkDoneMarkerCRLF(t *testing.T) {
|
||||
detector := &StreamTerminalDetector{}
|
||||
if !detector.ObserveChunk([]byte("data: [DONE]\r\n\r\n")) {
|
||||
t.Fatalf("expected [DONE] marker with CRLF delimiter to be terminal")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamTerminalDetectorObserveChunkIgnoresUnspecifiedFinishReason(t *testing.T) {
|
||||
detector := &StreamTerminalDetector{}
|
||||
if detector.ObserveChunk([]byte("data: {\"finishReason\":\"FINISH_REASON_UNSPECIFIED\"}\n\n")) {
|
||||
t.Fatalf("unexpected terminal detection for FINISH_REASON_UNSPECIFIED")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamTerminalDetectorObserveChunkPlainJSONAcrossChunks(t *testing.T) {
|
||||
detector := &StreamTerminalDetector{}
|
||||
if detector.ObserveChunk([]byte("{\"content\":\"hello\",")) {
|
||||
t.Fatalf("unexpected terminal detection for incomplete json")
|
||||
}
|
||||
if !detector.ObserveChunk([]byte("\"finishReason\":\"STOP\"}")) {
|
||||
t.Fatalf("expected terminal detection for plain json stream")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamTerminalDetectorObserveChunkJSONWithDataURITokenAndDelimiter(t *testing.T) {
|
||||
detector := &StreamTerminalDetector{}
|
||||
chunk := []byte("{\"finishReason\":\"STOP\",\"content\":{\"parts\":[{\"inlineData\":{\"mimeType\":\"image/png\",\"data\":\"data:image/png;base64,AAAA\"}}]}}\n\n")
|
||||
if !detector.ObserveChunk(chunk) {
|
||||
t.Fatalf("expected terminal detection for delimited JSON containing data URI token")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamTerminalDetectorObserveChunkMultiEventSSEInSingleChunk(t *testing.T) {
|
||||
detector := &StreamTerminalDetector{}
|
||||
chunk := []byte("data: {}\n\ndata: {\"finishReason\":\"STOP\"}\n\n")
|
||||
if !detector.ObserveChunk(chunk) {
|
||||
t.Fatalf("expected terminal detection for multi-event SSE chunk")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamTerminalDetectorObserveChunkMetadataOnlyFrameIsNotTerminal(t *testing.T) {
|
||||
detector := &StreamTerminalDetector{}
|
||||
if detector.ObserveChunk([]byte("data: {\"usageMetadata\":{\"totalTokenCount\":12}}\n\n")) {
|
||||
t.Fatalf("unexpected terminal detection for metadata-only frame")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamTerminalDetectorObserveChunkMetadataWithFinishedCandidateIsTerminal(t *testing.T) {
|
||||
detector := &StreamTerminalDetector{}
|
||||
if !detector.ObserveChunk([]byte("data: {\"usageMetadata\":{\"totalTokenCount\":12},\"candidates\":[{\"finishReason\":\"STOP\"}]}\n\n")) {
|
||||
t.Fatalf("expected terminal detection for metadata with finished candidate")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamTerminalDetectorObserveChunkMetadataWithUnspecifiedCandidateIsNotTerminal(t *testing.T) {
|
||||
detector := &StreamTerminalDetector{}
|
||||
if detector.ObserveChunk([]byte("data: {\"usageMetadata\":{\"totalTokenCount\":12},\"candidates\":[{\"finishReason\":\"FINISH_REASON_UNSPECIFIED\"}]}\n\n")) {
|
||||
t.Fatalf("unexpected terminal detection for metadata with unfinished candidate")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamTerminalDetectorObserveChunkMetadataWithMixedCandidatesIsNotTerminal(t *testing.T) {
|
||||
detector := &StreamTerminalDetector{}
|
||||
if detector.ObserveChunk([]byte("data: {\"usageMetadata\":{\"totalTokenCount\":12},\"candidates\":[{\"finishReason\":\"STOP\"},{}]}\n\n")) {
|
||||
t.Fatalf("unexpected terminal detection for metadata with mixed finished/unfinished candidates")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamTerminalDetectorObserveChunkTopLevelArrayWithCandidatesFinishReason(t *testing.T) {
|
||||
detector := &StreamTerminalDetector{}
|
||||
chunk := []byte("data: [{\"candidates\":[{\"finishReason\":\"STOP\"}]}]\n\n")
|
||||
if !detector.ObserveChunk(chunk) {
|
||||
t.Fatalf("expected terminal detection for top-level array payload")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamTerminalDetectorObserveChunkCandidatesRequireAllFinished(t *testing.T) {
|
||||
detector := &StreamTerminalDetector{}
|
||||
chunk := []byte("data: {\"candidates\":[{\"finishReason\":\"STOP\"},{\"finishReason\":\"FINISH_REASON_UNSPECIFIED\"}]}\n\n")
|
||||
if detector.ObserveChunk(chunk) {
|
||||
t.Fatalf("unexpected terminal detection when not all candidates are finished")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamTerminalDetectorObserveChunkCandidatesAllFinishedIsTerminal(t *testing.T) {
|
||||
detector := &StreamTerminalDetector{}
|
||||
chunk := []byte("data: {\"candidates\":[{\"finishReason\":\"STOP\"},{\"finishReason\":\"MAX_TOKENS\"}]}\n\n")
|
||||
if !detector.ObserveChunk(chunk) {
|
||||
t.Fatalf("expected terminal detection when all candidates are finished")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamTerminalDetectorObserveChunkUndelimitedOverflowKeepsPrefix(t *testing.T) {
|
||||
detector := &StreamTerminalDetector{}
|
||||
originalPrefix := bytes.Repeat([]byte("a"), maxTerminalDetectorBufferBytes/2)
|
||||
chunk := append([]byte(nil), originalPrefix...)
|
||||
chunk = append(chunk, bytes.Repeat([]byte("b"), maxTerminalDetectorBufferBytes)...)
|
||||
|
||||
if detector.ObserveChunk(chunk) {
|
||||
t.Fatalf("unexpected terminal detection for non-json buffer")
|
||||
}
|
||||
|
||||
trimTo := maxTerminalDetectorBufferBytes / 2
|
||||
got := detector.pending.Bytes()
|
||||
if len(got) != trimTo {
|
||||
t.Fatalf("expected pending length %d, got %d", trimTo, len(got))
|
||||
}
|
||||
if !bytes.Equal(got, originalPrefix) {
|
||||
t.Fatalf("expected pending buffer to keep original prefix")
|
||||
}
|
||||
}
|
||||
171
core/providers/utils/tls_test.go
Normal file
171
core/providers/utils/tls_test.go
Normal file
@@ -0,0 +1,171 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/pem"
|
||||
"math/big"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// testLogger is a minimal logger for tests that implements schemas.Logger.
|
||||
type testLogger struct{}
|
||||
|
||||
func (testLogger) Debug(string, ...any) {}
|
||||
func (testLogger) Info(string, ...any) {}
|
||||
func (testLogger) Warn(string, ...any) {}
|
||||
func (testLogger) Error(string, ...any) {}
|
||||
func (testLogger) Fatal(string, ...any) {}
|
||||
func (testLogger) SetLevel(schemas.LogLevel) {}
|
||||
func (testLogger) SetOutputType(schemas.LoggerOutputType) {}
|
||||
func (testLogger) LogHTTPRequest(schemas.LogLevel, string) schemas.LogEventBuilder {
|
||||
return schemas.NoopLogEvent
|
||||
}
|
||||
|
||||
// validTestCertPEM generates a fresh self-signed CA certificate (ECDSA P-256,
// CommonName "test-ca", valid for 24 hours from now) and returns it
// PEM-encoded. Any failure aborts the calling test.
func validTestCertPEM(t *testing.T) string {
	t.Helper()

	signer, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	if err != nil {
		t.Fatalf("generate key: %v", err)
	}

	ca := &x509.Certificate{
		SerialNumber:          big.NewInt(1),
		Subject:               pkix.Name{CommonName: "test-ca"},
		NotBefore:             time.Now(),
		NotAfter:              time.Now().Add(24 * time.Hour),
		KeyUsage:              x509.KeyUsageCertSign,
		BasicConstraintsValid: true,
		IsCA:                  true,
	}

	// Self-signed: the same template is used as both certificate and parent.
	der, err := x509.CreateCertificate(rand.Reader, ca, ca, &signer.PublicKey, signer)
	if err != nil {
		t.Fatalf("create certificate: %v", err)
	}

	encoded := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: der})
	return string(encoded)
}
|
||||
|
||||
func TestConfigureTLS_ReturnsUnchangedWhenNeitherSet(t *testing.T) {
|
||||
client := &fasthttp.Client{}
|
||||
logger := testLogger{}
|
||||
|
||||
result := ConfigureTLS(client, schemas.NetworkConfig{}, logger)
|
||||
|
||||
if result != client {
|
||||
t.Error("ConfigureTLS should return the same client when neither InsecureSkipVerify nor CACertPEM is set")
|
||||
}
|
||||
if client.TLSConfig != nil {
|
||||
t.Error("TLSConfig should remain nil when no TLS options are set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigureTLS_SetsInsecureSkipVerify(t *testing.T) {
|
||||
client := &fasthttp.Client{}
|
||||
logger := testLogger{}
|
||||
|
||||
result := ConfigureTLS(client, schemas.NetworkConfig{InsecureSkipVerify: true}, logger)
|
||||
|
||||
if result != client {
|
||||
t.Error("ConfigureTLS should return the same client")
|
||||
}
|
||||
if client.TLSConfig == nil {
|
||||
t.Fatal("TLSConfig should be set when InsecureSkipVerify is true")
|
||||
}
|
||||
if !client.TLSConfig.InsecureSkipVerify {
|
||||
t.Error("InsecureSkipVerify should be true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigureTLS_AppliesCACertPEM(t *testing.T) {
|
||||
client := &fasthttp.Client{}
|
||||
logger := testLogger{}
|
||||
caPEM := validTestCertPEM(t)
|
||||
|
||||
result := ConfigureTLS(client, schemas.NetworkConfig{CACertPEM: caPEM}, logger)
|
||||
|
||||
if result != client {
|
||||
t.Error("ConfigureTLS should return the same client")
|
||||
}
|
||||
if client.TLSConfig == nil {
|
||||
t.Fatal("TLSConfig should be set when CACertPEM is provided")
|
||||
}
|
||||
if client.TLSConfig.RootCAs == nil {
|
||||
t.Error("RootCAs should be set when CACertPEM is provided")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigureTLS_HandlesInvalidCACertPEM(t *testing.T) {
|
||||
client := &fasthttp.Client{}
|
||||
logger := testLogger{}
|
||||
|
||||
result := ConfigureTLS(client, schemas.NetworkConfig{CACertPEM: "not-valid-pem"}, logger)
|
||||
|
||||
if result != client {
|
||||
t.Error("ConfigureTLS should return the same client even when CACertPEM is invalid")
|
||||
}
|
||||
// Invalid PEM logs warning and skips RootCAs; TLSConfig may still be set with MinVersion
|
||||
if client.TLSConfig != nil && client.TLSConfig.RootCAs != nil {
|
||||
t.Error("RootCAs should not be set when CACertPEM is invalid")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigureTLS_MergesWithExistingTLSConfig(t *testing.T) {
|
||||
// Simulate client that already has TLSConfig from ConfigureProxy
|
||||
existingRootCAs, _ := x509.SystemCertPool()
|
||||
if existingRootCAs == nil {
|
||||
existingRootCAs = x509.NewCertPool()
|
||||
}
|
||||
client := &fasthttp.Client{
|
||||
TLSConfig: &tls.Config{
|
||||
RootCAs: existingRootCAs,
|
||||
MinVersion: tls.VersionTLS12,
|
||||
},
|
||||
}
|
||||
logger := testLogger{}
|
||||
caPEM := validTestCertPEM(t)
|
||||
|
||||
result := ConfigureTLS(client, schemas.NetworkConfig{CACertPEM: caPEM}, logger)
|
||||
|
||||
if result != client {
|
||||
t.Error("ConfigureTLS should return the same client")
|
||||
}
|
||||
if client.TLSConfig == nil {
|
||||
t.Fatal("TLSConfig should remain set")
|
||||
}
|
||||
if client.TLSConfig.RootCAs == nil {
|
||||
t.Error("RootCAs should be set (merged with existing)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigureTLS_InsecureSkipVerifyAndCACertPEM(t *testing.T) {
|
||||
client := &fasthttp.Client{}
|
||||
logger := testLogger{}
|
||||
caPEM := validTestCertPEM(t)
|
||||
|
||||
result := ConfigureTLS(client, schemas.NetworkConfig{
|
||||
InsecureSkipVerify: true,
|
||||
CACertPEM: caPEM,
|
||||
}, logger)
|
||||
|
||||
if result != client {
|
||||
t.Error("ConfigureTLS should return the same client")
|
||||
}
|
||||
if client.TLSConfig == nil {
|
||||
t.Fatal("TLSConfig should be set")
|
||||
}
|
||||
if !client.TLSConfig.InsecureSkipVerify {
|
||||
t.Error("InsecureSkipVerify should be true when both options are set")
|
||||
}
|
||||
if client.TLSConfig.RootCAs == nil {
|
||||
t.Error("RootCAs should be set when CACertPEM is provided")
|
||||
}
|
||||
}
|
||||
2779
core/providers/utils/utils.go
Normal file
2779
core/providers/utils/utils.go
Normal file
File diff suppressed because it is too large
Load Diff
122
core/providers/utils/utils_json_test.go
Normal file
122
core/providers/utils/utils_json_test.go
Normal file
@@ -0,0 +1,122 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestSetJSONField(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
data []byte
|
||||
path string
|
||||
value interface{}
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "set_string_on_empty_object",
|
||||
data: []byte(`{}`),
|
||||
path: "name",
|
||||
value: "test",
|
||||
expected: `{"name":"test"}`,
|
||||
},
|
||||
{
|
||||
name: "set_nested_path",
|
||||
data: []byte(`{}`),
|
||||
path: "file.displayName",
|
||||
value: "photo.jpg",
|
||||
expected: `{"file":{"displayName":"photo.jpg"}}`,
|
||||
},
|
||||
{
|
||||
name: "set_boolean",
|
||||
data: []byte(`{"model":"x"}`),
|
||||
path: "stream",
|
||||
value: true,
|
||||
expected: `{"model":"x","stream":true}`,
|
||||
},
|
||||
{
|
||||
name: "set_string_array",
|
||||
data: []byte(`{}`),
|
||||
path: "betas",
|
||||
value: []string{"a", "b"},
|
||||
expected: `{"betas":["a","b"]}`,
|
||||
},
|
||||
{
|
||||
name: "preserves_existing_fields",
|
||||
data: []byte(`{"a":1,"b":2}`),
|
||||
path: "c",
|
||||
value: 3,
|
||||
expected: `{"a":1,"b":2,"c":3}`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := SetJSONField(tt.data, tt.path, tt.value)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.expected, string(result), "exact byte-level ordering must match")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteJSONField(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
data []byte
|
||||
path string
|
||||
expected string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "delete_existing_field",
|
||||
data: []byte(`{"a":1,"b":2}`),
|
||||
path: "a",
|
||||
expected: `{"b":2}`,
|
||||
},
|
||||
{
|
||||
name: "delete_nonexistent_field",
|
||||
data: []byte(`{"a":1}`),
|
||||
path: "b",
|
||||
expected: `{"a":1}`,
|
||||
},
|
||||
{
|
||||
name: "sequential_deletes",
|
||||
data: []byte(`{"a":1,"b":2,"c":3}`),
|
||||
path: "", // handled in validate
|
||||
expected: `{}`,
|
||||
},
|
||||
{
|
||||
name: "preserves_remaining_order",
|
||||
data: []byte(`{"x":1,"y":2,"z":3}`),
|
||||
path: "y",
|
||||
expected: `{"x":1,"z":3}`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.name == "sequential_deletes" {
|
||||
data := []byte(`{"a":1,"b":2,"c":3}`)
|
||||
var err error
|
||||
data, err = DeleteJSONField(data, "a")
|
||||
require.NoError(t, err)
|
||||
data, err = DeleteJSONField(data, "b")
|
||||
require.NoError(t, err)
|
||||
data, err = DeleteJSONField(data, "c")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.expected, string(data), "exact byte-level ordering must match")
|
||||
return
|
||||
}
|
||||
|
||||
result, err := DeleteJSONField(tt.data, tt.path)
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.expected, string(result), "exact byte-level ordering must match")
|
||||
})
|
||||
}
|
||||
}
|
||||
1441
core/providers/utils/utils_test.go
Normal file
1441
core/providers/utils/utils_test.go
Normal file
File diff suppressed because it is too large
Load Diff
34
core/providers/utils/videos.go
Normal file
34
core/providers/utils/videos.go
Normal file
@@ -0,0 +1,34 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
schemas "github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// StripVideoIDProviderSuffix removes ":<provider>" from a video ID if present.
|
||||
func StripVideoIDProviderSuffix(videoID string, provider schemas.ModelProvider) string {
|
||||
suffix := ":" + string(provider)
|
||||
stripped := strings.TrimSuffix(videoID, suffix)
|
||||
// URL decode the ID to restore original characters (e.g., %2F -> /)
|
||||
if decoded, err := url.PathUnescape(stripped); err == nil {
|
||||
return decoded
|
||||
}
|
||||
return stripped
|
||||
}
|
||||
|
||||
// AddVideoIDProviderSuffix ensures a video ID is scoped as "<id>:<provider>".
|
||||
func AddVideoIDProviderSuffix(videoID string, provider schemas.ModelProvider) string {
|
||||
if videoID == "" {
|
||||
return videoID
|
||||
}
|
||||
suffix := ":" + string(provider)
|
||||
if strings.HasSuffix(videoID, suffix) {
|
||||
return videoID
|
||||
}
|
||||
// URL-encode the video ID to make it safe for URL paths
|
||||
// This converts / to %2F and other special characters
|
||||
escapedVideoID := url.PathEscape(videoID)
|
||||
return escapedVideoID + suffix
|
||||
}
|
||||
Reference in New Issue
Block a user