first commit

This commit is contained in:
Beyhan Oğur
2026-04-26 21:52:23 +03:00
commit 880f412e2c
2662 changed files with 866266 additions and 0 deletions

View File

@@ -0,0 +1,136 @@
// Package utils provides common utility functions used across different provider implementations.
// This file contains audio-related utility functions for format conversion.
package utils
import (
"bytes"
"encoding/binary"
)
// PCMConfig holds the configuration for PCM audio data
type PCMConfig struct {
SampleRate int // Sample rate in Hz (e.g., 24000)
NumChannels int // Number of audio channels (1 = mono, 2 = stereo)
BitsPerSample int // Bits per sample (e.g., 16)
}
// DefaultGeminiPCMConfig returns the default PCM configuration for Gemini TTS
// Gemini TTS returns audio in PCM format with the following specs:
// - Format: signed 16-bit little-endian (s16le)
// - Sample rate: 24000 Hz
// - Channels: 1 (mono)
func DefaultGeminiPCMConfig() PCMConfig {
return PCMConfig{
SampleRate: 24000,
NumChannels: 1,
BitsPerSample: 16,
}
}
// ConvertPCMToWAV converts raw PCM audio data to WAV format
// The PCM data is expected to be in signed little-endian format (s16le for 16-bit)
func ConvertPCMToWAV(pcmData []byte, config PCMConfig) ([]byte, error) {
byteRate := config.SampleRate * config.NumChannels * config.BitsPerSample / 8
blockAlign := config.NumChannels * config.BitsPerSample / 8
dataSize := uint32(len(pcmData))
fileSize := 36 + dataSize // 36 bytes for header + data
var buf bytes.Buffer
// RIFF header
buf.WriteString("RIFF")
binary.Write(&buf, binary.LittleEndian, fileSize)
buf.WriteString("WAVE")
// fmt subchunk
buf.WriteString("fmt ")
binary.Write(&buf, binary.LittleEndian, uint32(16)) // Subchunk1Size (16 for PCM)
binary.Write(&buf, binary.LittleEndian, uint16(1)) // AudioFormat (1 = PCM)
binary.Write(&buf, binary.LittleEndian, uint16(config.NumChannels)) // NumChannels
binary.Write(&buf, binary.LittleEndian, uint32(config.SampleRate)) // SampleRate
binary.Write(&buf, binary.LittleEndian, uint32(byteRate)) // ByteRate
binary.Write(&buf, binary.LittleEndian, uint16(blockAlign)) // BlockAlign
binary.Write(&buf, binary.LittleEndian, uint16(config.BitsPerSample)) // BitsPerSample
// data subchunk
buf.WriteString("data")
binary.Write(&buf, binary.LittleEndian, dataSize)
buf.Write(pcmData)
return buf.Bytes(), nil
}
var (
riff = []byte("RIFF")
wave = []byte("WAVE")
id3 = []byte("ID3")
form = []byte("FORM")
aiff = []byte("AIFF")
aifc = []byte("AIFC")
flac = []byte("fLaC")
oggs = []byte("OggS")
adif = []byte("ADIF")
)
// DetectAudioMimeType attempts to detect the MIME type from audio file headers.
// Supports detection of: WAV, MP3, AIFF, AAC, OGG Vorbis, and FLAC formats.
func DetectAudioMimeType(audioData []byte) string {
if len(audioData) < 4 {
return "audio/mp3"
}
// WAV (RIFF/WAVE)
if len(audioData) >= 12 &&
bytes.Equal(audioData[:4], riff) &&
bytes.Equal(audioData[8:12], wave) {
return "audio/wav"
}
// MP3: ID3v2 tag (keep this check for MP3)
if len(audioData) >= 3 && bytes.Equal(audioData[:3], id3) {
return "audio/mp3"
}
// AAC: ADIF or ADTS (0xFFF sync) - check before MP3 frame sync to avoid misclassification
if bytes.HasPrefix(audioData, adif) {
return "audio/aac"
}
if len(audioData) >= 2 && audioData[0] == 0xFF && (audioData[1]&0xF6) == 0xF0 {
return "audio/aac"
}
// AIFF / AIFC (map both to audio/aiff)
if len(audioData) >= 12 && bytes.Equal(audioData[:4], form) &&
(bytes.Equal(audioData[8:12], aiff) || bytes.Equal(audioData[8:12], aifc)) {
return "audio/aiff"
}
// FLAC
if bytes.HasPrefix(audioData, flac) {
return "audio/flac"
}
// OGG container
if bytes.HasPrefix(audioData, oggs) {
return "audio/ogg"
}
// MP3: MPEG frame sync (cover common variants) - check after AAC to avoid misclassification
if len(audioData) >= 2 && audioData[0] == 0xFF &&
(audioData[1] == 0xFB || audioData[1] == 0xF3 || audioData[1] == 0xF2 || audioData[1] == 0xFA) {
return "audio/mp3"
}
// Fallback within supported set
return "audio/mp3"
}
// AudioFilenameFromBytes returns a filename with the correct extension for the given audio data.
// Falls back to "audio.mp3" if the format cannot be detected.
func AudioFilenameFromBytes(audioData []byte) string {
mimeToExt := map[string]string{
"audio/wav": "audio.wav",
"audio/aac": "audio.aac",
"audio/aiff": "audio.aiff",
"audio/flac": "audio.flac",
"audio/ogg": "audio.ogg",
}
mime := DetectAudioMimeType(audioData)
if ext, ok := mimeToExt[mime]; ok {
return ext
}
return "audio.mp3"
}

View File

@@ -0,0 +1,192 @@
package utils
import (
"compress/gzip"
"compress/zlib"
"io"
"sync"
"github.com/andybalholm/brotli"
"github.com/klauspost/compress/zstd"
)
// ---------------------------------------------------------------------------
// Pooled decompression readers
//
// Each encoding gets a sync.Pool with Acquire/Release helpers that follow the
// same contract:
// - Acquire(r io.Reader) returns a ready-to-read decompressor, reusing a
// pooled instance when possible, falling back to a fresh allocation.
// - Release returns the decompressor to the pool for future reuse.
// Callers MUST call Release exactly once after the reader is fully consumed.
//
// All pool operations are panic-safe: type assertions use the comma-ok form,
// Reset calls are wrapped in recover, and nil checks guard every dereference.
// A corrupt or wrong-typed pooled instance is silently discarded (GC reclaims
// it) and a fresh allocation takes its place.
//
// Gzip, deflate, and brotli readers are stateless between streams — Close (if
// applicable) then Reset is safe. Zstd decoders run background goroutines so
// Close is terminal; pooled decoders are reset without closing.
// ---------------------------------------------------------------------------
// safeReset calls resetFn and recovers from any panic. Returns true on success.
// A corrupt pooled reader may panic inside Reset; this prevents that from
// bringing down the server.
func safeReset(resetFn func() error) (ok bool) {
defer func() {
if r := recover(); r != nil {
ok = false
}
}()
return resetFn() == nil
}
// ---- gzip ----
var gzipReaderPool = sync.Pool{
New: func() any {
return &gzip.Reader{}
},
}
// AcquireGzipReader gets a gzip.Reader from the pool and resets it to read from r,
// or creates a new one if the pool is empty or reset fails.
func AcquireGzipReader(r io.Reader) (*gzip.Reader, error) {
if v := gzipReaderPool.Get(); v != nil {
if gz, ok := v.(*gzip.Reader); ok {
if safeReset(func() error { return gz.Reset(r) }) {
return gz, nil
}
}
// Wrong type or reset failed/panicked — discard, let GC reclaim.
}
return gzip.NewReader(r)
}
// ReleaseGzipReader closes and returns a gzip.Reader to the pool.
func ReleaseGzipReader(gz *gzip.Reader) {
if gz == nil {
return
}
_ = gz.Close()
gzipReaderPool.Put(gz)
}
// ---- deflate ----
//
// HTTP Content-Encoding "deflate" is zlib-wrapped DEFLATE (RFC 1950), NOT raw
// DEFLATE (RFC 1951). This matches fasthttp's implementation and the HTTP spec
// (RFC 9110 §8.4.1.2). We use compress/zlib, not compress/flate.
// deflateReader is the interface that zlib readers support for pooling.
// The concrete type from zlib.NewReader is unexported, but implements Resetter.
type deflateReader interface {
io.ReadCloser
Reset(r io.Reader, dict []byte) error
}
// No New func: zlib.NewReader validates the header eagerly, so it cannot be
// created with a nil reader. Pooled readers are populated via Release.
var deflateReaderPool = sync.Pool{}
// AcquireFlateReader gets a zlib (HTTP "deflate") reader from the pool and
// resets it to read from r, or creates a new one if the pool is empty or
// reset fails.
func AcquireFlateReader(r io.Reader) (io.ReadCloser, error) {
if v := deflateReaderPool.Get(); v != nil {
if dr, ok := v.(deflateReader); ok {
if safeReset(func() error { return dr.Reset(r, nil) }) {
return dr, nil
}
}
}
return zlib.NewReader(r)
}
// ReleaseFlateReader closes and returns a deflate reader to the pool.
func ReleaseFlateReader(fr io.ReadCloser) {
if fr == nil {
return
}
_ = fr.Close()
deflateReaderPool.Put(fr)
}
// ---- brotli ----
var brotliReaderPool = sync.Pool{
New: func() any {
return brotli.NewReader(nil)
},
}
// AcquireBrotliReader gets a brotli.Reader from the pool and resets it to read
// from r, or creates a new one if the pool is empty or reset panics.
func AcquireBrotliReader(r io.Reader) *brotli.Reader {
if v := brotliReaderPool.Get(); v != nil {
if br, ok := v.(*brotli.Reader); ok {
// brotli.Reset is void (no error), but wrap in safeReset for
// consistency: a corrupt pooled reader could panic on Reset.
if safeReset(func() error { br.Reset(r); return nil }) {
return br
}
}
// Wrong type or reset panicked — discard, let GC reclaim.
}
return brotli.NewReader(r)
}
// ReleaseBrotliReader returns a brotli.Reader to the pool.
// Brotli readers have no Close method; Reset(nil) is sufficient to drop the
// reference to the underlying reader.
func ReleaseBrotliReader(br *brotli.Reader) {
if br == nil {
return
}
br.Reset(nil)
brotliReaderPool.Put(br)
}
// ---- zstd ----
var zstdDecoderPool = sync.Pool{
New: func() any {
dec, err := zstd.NewReader(nil, zstd.WithDecoderConcurrency(1))
if err != nil {
// NewReader(nil) failing is unexpected; return nil so Acquire
// falls through to a fresh allocation with the real reader.
return nil
}
return dec
},
}
// AcquireZstdDecoder gets a zstd.Decoder from the pool and resets it to read
// from r, or creates a new one if the pool is empty or reset fails.
// Decoders are created with concurrency=1 to minimise goroutine overhead.
func AcquireZstdDecoder(r io.Reader) (*zstd.Decoder, error) {
if v := zstdDecoderPool.Get(); v != nil {
if dec, ok := v.(*zstd.Decoder); ok && dec != nil {
if safeReset(func() error { return dec.Reset(r) }) {
return dec, nil
}
// Reset failed/panicked — release references before discarding.
// Don't call Close (terminal); Reset(nil) is safe per pool contract.
_ = dec.Reset(nil)
}
}
return zstd.NewReader(r, zstd.WithDecoderConcurrency(1))
}
// ReleaseZstdDecoder returns a zstd.Decoder to the pool.
// Unlike other decoders, zstd.Close() is terminal (stops background goroutines
// permanently). We only call Reset(nil) to release the source reference, then
// re-pool. Close is never called on pooled decoders.
func ReleaseZstdDecoder(dec *zstd.Decoder) {
if dec == nil {
return
}
_ = dec.Reset(nil)
zstdDecoderPool.Put(dec)
}

View File

@@ -0,0 +1,516 @@
package utils
import (
"bytes"
"compress/gzip"
"compress/zlib"
"fmt"
"io"
"strings"
"sync"
"testing"
"github.com/andybalholm/brotli"
"github.com/klauspost/compress/zstd"
)
const poolTestIterations = 10
var testPayload = []byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hello world"}]}`)
// ---------------------------------------------------------------------------
// helpers — one compressor per encoding
// ---------------------------------------------------------------------------
func compressGzip(data []byte) []byte {
var buf bytes.Buffer
gz := gzip.NewWriter(&buf)
if _, err := gz.Write(data); err != nil {
panic(fmt.Errorf("gzip write: %w", err))
}
if err := gz.Close(); err != nil {
panic(fmt.Errorf("gzip close: %w", err))
}
return buf.Bytes()
}
// compressFlate produces zlib-wrapped DEFLATE (RFC 1950) — the correct format
// for HTTP Content-Encoding "deflate" per RFC 9110 §8.4.1.2.
func compressFlate(data []byte) []byte {
var buf bytes.Buffer
w, err := zlib.NewWriterLevel(&buf, zlib.DefaultCompression)
if err != nil {
panic(fmt.Errorf("zlib new writer: %w", err))
}
if _, err := w.Write(data); err != nil {
panic(fmt.Errorf("zlib write: %w", err))
}
if err := w.Close(); err != nil {
panic(fmt.Errorf("zlib close: %w", err))
}
return buf.Bytes()
}
func compressBrotli(data []byte) []byte {
var buf bytes.Buffer
w := brotli.NewWriter(&buf)
if _, err := w.Write(data); err != nil {
panic(fmt.Errorf("brotli write: %w", err))
}
if err := w.Close(); err != nil {
panic(fmt.Errorf("brotli close: %w", err))
}
return buf.Bytes()
}
func compressZstd(data []byte) []byte {
var buf bytes.Buffer
enc, err := zstd.NewWriter(&buf)
if err != nil {
panic(fmt.Errorf("zstd new writer: %w", err))
}
if _, err := enc.Write(data); err != nil {
panic(fmt.Errorf("zstd write: %w", err))
}
if err := enc.Close(); err != nil {
panic(fmt.Errorf("zstd close: %w", err))
}
return buf.Bytes()
}
// ---------------------------------------------------------------------------
// Pool cycle tests — each runs Acquire → ReadAll → Release N times to verify
// that pooled instances are reused correctly across iterations.
// ---------------------------------------------------------------------------
func TestAcquireReleaseGzipReader(t *testing.T) {
compressed := compressGzip(testPayload)
for i := 0; i < poolTestIterations; i++ {
gz, err := AcquireGzipReader(bytes.NewReader(compressed))
if err != nil {
t.Fatalf("iteration %d: AcquireGzipReader error: %v", i, err)
}
got, err := io.ReadAll(gz)
if err != nil {
t.Fatalf("iteration %d: ReadAll error: %v", i, err)
}
if !bytes.Equal(got, testPayload) {
t.Errorf("iteration %d: got %q, want %q", i, got, testPayload)
}
ReleaseGzipReader(gz)
}
}
func TestAcquireReleaseFlateReader(t *testing.T) {
compressed := compressFlate(testPayload)
for i := 0; i < poolTestIterations; i++ {
fr, err := AcquireFlateReader(bytes.NewReader(compressed))
if err != nil {
t.Fatalf("iteration %d: AcquireFlateReader error: %v", i, err)
}
got, err := io.ReadAll(fr)
if err != nil {
t.Fatalf("iteration %d: ReadAll error: %v", i, err)
}
if !bytes.Equal(got, testPayload) {
t.Errorf("iteration %d: got %q, want %q", i, got, testPayload)
}
ReleaseFlateReader(fr)
}
}
func TestAcquireReleaseBrotliReader(t *testing.T) {
compressed := compressBrotli(testPayload)
for i := 0; i < poolTestIterations; i++ {
br := AcquireBrotliReader(bytes.NewReader(compressed))
got, err := io.ReadAll(br)
if err != nil {
t.Fatalf("iteration %d: ReadAll error: %v", i, err)
}
if !bytes.Equal(got, testPayload) {
t.Errorf("iteration %d: got %q, want %q", i, got, testPayload)
}
ReleaseBrotliReader(br)
}
}
func TestAcquireReleaseZstdDecoder(t *testing.T) {
compressed := compressZstd(testPayload)
for i := 0; i < poolTestIterations; i++ {
dec, err := AcquireZstdDecoder(bytes.NewReader(compressed))
if err != nil {
t.Fatalf("iteration %d: AcquireZstdDecoder error: %v", i, err)
}
got, err := io.ReadAll(dec)
if err != nil {
t.Fatalf("iteration %d: ReadAll error: %v", i, err)
}
if !bytes.Equal(got, testPayload) {
t.Errorf("iteration %d: got %q, want %q", i, got, testPayload)
}
ReleaseZstdDecoder(dec)
}
}
// ---------------------------------------------------------------------------
// Panic-safety tests — verify that corrupt or wrong-typed pooled instances
// are handled gracefully (no panics, fallback to fresh allocation).
//
// Each poison test drains the target pool first so stale values from earlier
// tests cannot interfere. sync.Pool.Get offers no ordering guarantee, so
// draining ensures the poisoned value is the only item available.
// ---------------------------------------------------------------------------
// drainPool removes all previously pooled values so the next Put/Get pair
// exercises the intended fallback path. Temporarily sets New to nil to ensure
// Get returns nil when the pool is empty.
func drainPool(p *sync.Pool) {
origNew := p.New
p.New = nil
for p.Get() != nil {
}
p.New = origNew
}
// TestPool_WrongType_NoPanic poisons each pool with a wrong-typed value,
// then verifies Acquire still succeeds by falling back to a fresh allocation.
func TestPool_WrongType_NoPanic(t *testing.T) {
wrongValue := "not a reader"
compressed := compressGzip(testPayload)
t.Run("gzip", func(t *testing.T) {
drainPool(&gzipReaderPool)
gzipReaderPool.Put(wrongValue)
gz, err := AcquireGzipReader(bytes.NewReader(compressed))
if err != nil {
t.Fatalf("AcquireGzipReader should fall back, got error: %v", err)
}
got, _ := io.ReadAll(gz)
if !bytes.Equal(got, testPayload) {
t.Errorf("got %q, want %q", got, testPayload)
}
ReleaseGzipReader(gz)
})
t.Run("deflate", func(t *testing.T) {
drainPool(&deflateReaderPool)
deflateReaderPool.Put(wrongValue)
fr, err := AcquireFlateReader(bytes.NewReader(compressFlate(testPayload)))
if err != nil {
t.Fatalf("AcquireFlateReader should fall back, got error: %v", err)
}
got, _ := io.ReadAll(fr)
if !bytes.Equal(got, testPayload) {
t.Errorf("got %q, want %q", got, testPayload)
}
ReleaseFlateReader(fr)
})
t.Run("brotli", func(t *testing.T) {
drainPool(&brotliReaderPool)
brotliReaderPool.Put(wrongValue)
br := AcquireBrotliReader(bytes.NewReader(compressBrotli(testPayload)))
got, err := io.ReadAll(br)
if err != nil {
t.Fatalf("ReadAll error: %v", err)
}
if !bytes.Equal(got, testPayload) {
t.Errorf("got %q, want %q", got, testPayload)
}
ReleaseBrotliReader(br)
})
t.Run("zstd", func(t *testing.T) {
drainPool(&zstdDecoderPool)
zstdDecoderPool.Put(wrongValue)
dec, err := AcquireZstdDecoder(bytes.NewReader(compressZstd(testPayload)))
if err != nil {
t.Fatalf("AcquireZstdDecoder should fall back, got error: %v", err)
}
got, _ := io.ReadAll(dec)
if !bytes.Equal(got, testPayload) {
t.Errorf("got %q, want %q", got, testPayload)
}
ReleaseZstdDecoder(dec)
})
}
// TestPool_NilInPool_NoPanic puts an explicit nil into each pool, then
// verifies Acquire still succeeds by falling back to a fresh allocation.
func TestPool_NilInPool_NoPanic(t *testing.T) {
t.Run("gzip", func(t *testing.T) {
drainPool(&gzipReaderPool)
gzipReaderPool.Put((*gzip.Reader)(nil))
gz, err := AcquireGzipReader(bytes.NewReader(compressGzip(testPayload)))
if err != nil {
t.Fatalf("should fall back, got error: %v", err)
}
got, _ := io.ReadAll(gz)
if !bytes.Equal(got, testPayload) {
t.Errorf("got %q, want %q", got, testPayload)
}
ReleaseGzipReader(gz)
})
t.Run("zstd", func(t *testing.T) {
drainPool(&zstdDecoderPool)
zstdDecoderPool.Put((*zstd.Decoder)(nil))
dec, err := AcquireZstdDecoder(bytes.NewReader(compressZstd(testPayload)))
if err != nil {
t.Fatalf("should fall back, got error: %v", err)
}
got, _ := io.ReadAll(dec)
if !bytes.Equal(got, testPayload) {
t.Errorf("got %q, want %q", got, testPayload)
}
ReleaseZstdDecoder(dec)
})
}
// TestPool_InvalidData_NoPanic verifies that Acquire handles corrupt/invalid
// compressed data without panicking. The error should surface as a return
// value or a read error, never a panic.
func TestPool_InvalidData_NoPanic(t *testing.T) {
garbage := []byte("this is not compressed data at all")
t.Run("gzip", func(t *testing.T) {
// gzip.NewReader validates the header immediately, so Acquire returns error.
_, err := AcquireGzipReader(bytes.NewReader(garbage))
if err == nil {
t.Fatal("expected error for invalid gzip data")
}
})
t.Run("deflate", func(t *testing.T) {
// zlib.NewReader validates the header eagerly, so Acquire returns error.
_, err := AcquireFlateReader(bytes.NewReader(garbage))
if err == nil {
t.Fatal("expected error for invalid deflate (zlib) data")
}
})
t.Run("brotli", func(t *testing.T) {
// brotli doesn't validate eagerly; error surfaces on Read.
br := AcquireBrotliReader(bytes.NewReader(garbage))
_, readErr := io.ReadAll(br)
if readErr == nil {
t.Fatal("expected read error for invalid brotli data")
}
ReleaseBrotliReader(br)
})
t.Run("zstd", func(t *testing.T) {
// zstd.Decoder doesn't validate eagerly; error surfaces on Read.
dec, err := AcquireZstdDecoder(bytes.NewReader(garbage))
if err != nil {
t.Fatalf("AcquireZstdDecoder should not error eagerly: %v", err)
}
_, readErr := io.ReadAll(dec)
if readErr == nil {
t.Fatal("expected read error for invalid zstd data")
}
ReleaseZstdDecoder(dec)
})
}
// TestPool_CorruptPooledInstance_NoPanic simulates a corrupt pooled reader
// whose Reset panics. Verifies safeReset catches the panic and Acquire
// falls through to a fresh allocation.
func TestPool_CorruptPooledInstance_NoPanic(t *testing.T) {
t.Run("safeReset_catches_panic", func(t *testing.T) {
ok := safeReset(func() error {
panic("simulated corrupt reader")
})
if ok {
t.Fatal("safeReset should return false when resetFn panics")
}
})
t.Run("gzip_after_corrupt", func(t *testing.T) {
// Poison pool with a gzip.Reader that has corrupt internal state:
// a zero-value reader that has been closed without ever being used.
drainPool(&gzipReaderPool)
corrupt := &gzip.Reader{}
gzipReaderPool.Put(corrupt)
// Acquire should recover, discard the corrupt reader, and create fresh.
compressed := compressGzip(testPayload)
gz, err := AcquireGzipReader(bytes.NewReader(compressed))
if err != nil {
t.Fatalf("should recover from corrupt pooled reader, got: %v", err)
}
got, _ := io.ReadAll(gz)
if !bytes.Equal(got, testPayload) {
t.Errorf("got %q, want %q", got, testPayload)
}
ReleaseGzipReader(gz)
})
}
// TestRelease_Nil_NoPanic verifies Release functions are safe to call with nil.
func TestRelease_Nil_NoPanic(t *testing.T) {
ReleaseGzipReader(nil)
ReleaseFlateReader(nil)
ReleaseBrotliReader(nil)
ReleaseZstdDecoder(nil)
}
// TestPool_RecoveryAndReuse verifies that after a corrupt instance is discarded,
// the pool recovers and subsequent cycles work normally.
func TestPool_RecoveryAndReuse(t *testing.T) {
// Drain then poison all pools with wrong types.
drainPool(&gzipReaderPool)
drainPool(&deflateReaderPool)
drainPool(&brotliReaderPool)
drainPool(&zstdDecoderPool)
gzipReaderPool.Put("wrong")
deflateReaderPool.Put(42)
brotliReaderPool.Put(struct{}{})
zstdDecoderPool.Put(false)
// Run a normal Acquire → ReadAll → Release cycle for each.
// This verifies the pool recovers: the wrong-typed value is discarded,
// a fresh instance is created and released back, making the pool healthy.
t.Run("gzip", func(t *testing.T) {
compressed := compressGzip(testPayload)
for i := 0; i < 3; i++ {
gz, err := AcquireGzipReader(bytes.NewReader(compressed))
if err != nil {
t.Fatalf("iteration %d: %v", i, err)
}
got, _ := io.ReadAll(gz)
if !bytes.Equal(got, testPayload) {
t.Errorf("iteration %d: mismatch", i)
}
ReleaseGzipReader(gz)
}
})
t.Run("deflate", func(t *testing.T) {
compressed := compressFlate(testPayload)
for i := 0; i < 3; i++ {
fr, err := AcquireFlateReader(bytes.NewReader(compressed))
if err != nil {
t.Fatalf("iteration %d: %v", i, err)
}
got, _ := io.ReadAll(fr)
if !bytes.Equal(got, testPayload) {
t.Errorf("iteration %d: mismatch", i)
}
ReleaseFlateReader(fr)
}
})
t.Run("brotli", func(t *testing.T) {
compressed := compressBrotli(testPayload)
for i := 0; i < 3; i++ {
br := AcquireBrotliReader(bytes.NewReader(compressed))
got, err := io.ReadAll(br)
if err != nil {
t.Fatalf("iteration %d: %v", i, err)
}
if !bytes.Equal(got, testPayload) {
t.Errorf("iteration %d: mismatch", i)
}
ReleaseBrotliReader(br)
}
})
t.Run("zstd", func(t *testing.T) {
compressed := compressZstd(testPayload)
for i := 0; i < 3; i++ {
dec, err := AcquireZstdDecoder(bytes.NewReader(compressed))
if err != nil {
t.Fatalf("iteration %d: %v", i, err)
}
got, _ := io.ReadAll(dec)
if !bytes.Equal(got, testPayload) {
t.Errorf("iteration %d: mismatch", i)
}
ReleaseZstdDecoder(dec)
}
})
}
// TestPool_EmptyReader_NoPanic verifies Acquire handles an empty reader
// (zero bytes) without panicking. Gzip/zstd should return an error (no header),
// deflate/brotli should return empty or error on Read.
func TestPool_EmptyReader_NoPanic(t *testing.T) {
empty := bytes.NewReader(nil)
t.Run("gzip", func(t *testing.T) {
_, err := AcquireGzipReader(empty)
if err == nil {
t.Fatal("expected error for empty gzip input")
}
})
t.Run("deflate", func(t *testing.T) {
// zlib.NewReader validates the header eagerly; empty input has no header.
_, err := AcquireFlateReader(bytes.NewReader(nil))
if err == nil {
t.Fatal("expected error for empty deflate (zlib) input")
}
})
t.Run("brotli", func(t *testing.T) {
br := AcquireBrotliReader(bytes.NewReader(nil))
data, _ := io.ReadAll(br)
if len(data) != 0 {
t.Errorf("expected empty output, got %d bytes", len(data))
}
ReleaseBrotliReader(br)
})
t.Run("zstd", func(t *testing.T) {
dec, err := AcquireZstdDecoder(bytes.NewReader(nil))
if err != nil {
// Some versions error eagerly, that's fine.
return
}
data, _ := io.ReadAll(dec)
// Empty input with no zstd frame yields empty output or read error.
_ = data
ReleaseZstdDecoder(dec)
})
}
// TestSafeReset verifies safeReset correctly handles panics and errors.
func TestSafeReset(t *testing.T) {
t.Run("success", func(t *testing.T) {
ok := safeReset(func() error { return nil })
if !ok {
t.Fatal("expected true for successful reset")
}
})
t.Run("error", func(t *testing.T) {
ok := safeReset(func() error { return io.ErrUnexpectedEOF })
if ok {
t.Fatal("expected false for failed reset")
}
})
t.Run("panic_string", func(t *testing.T) {
ok := safeReset(func() error { panic("boom") })
if ok {
t.Fatal("expected false for panicking reset")
}
})
t.Run("panic_nonnil", func(t *testing.T) {
ok := safeReset(func() error { panic("") })
if ok {
t.Fatal("expected false for empty-string panic")
}
})
t.Run("panic_error", func(t *testing.T) {
ok := safeReset(func() error {
panic(fmt.Errorf("internal corruption: %s", strings.Repeat("x", 100)))
})
if ok {
t.Fatal("expected false for error panic")
}
})
}

View File

@@ -0,0 +1,359 @@
package utils
import (
"fmt"
"net"
"net/http"
"net/http/httptest"
"sync/atomic"
"testing"
"time"
"github.com/maximhq/bifrost/core/network"
"github.com/valyala/fasthttp"
)
// TestConfigureDialer_SetsRetryIfErr verifies that ConfigureDialer installs
// the StaleConnectionRetryIfErr callback on the client.
func TestConfigureDialer_SetsRetryIfErr(t *testing.T) {
client := &fasthttp.Client{}
if client.RetryIfErr != nil {
t.Fatal("precondition: RetryIfErr should be nil on a new client")
}
ConfigureDialer(client)
if client.RetryIfErr == nil {
t.Fatal("ConfigureDialer should set RetryIfErr")
}
// Verify it behaves like StaleConnectionRetryIfErr
reset, retry := client.RetryIfErr(nil, 1, fmt.Errorf("cannot find whitespace in the first line of response"))
if !reset || !retry {
t.Error("RetryIfErr should retry on whitespace error")
}
reset, retry = client.RetryIfErr(nil, 1, fmt.Errorf("dial tcp: no such host"))
if reset || retry {
t.Error("RetryIfErr should not retry on unrelated errors")
}
}
// TestConfigureDialer_SetsDial verifies that ConfigureDialer installs a custom
// Dial function on the client when no existing Dial is present.
func TestConfigureDialer_SetsDial(t *testing.T) {
client := &fasthttp.Client{}
if client.Dial != nil {
t.Fatal("precondition: Dial should be nil on a new client")
}
ConfigureDialer(client)
if client.Dial == nil {
t.Fatal("ConfigureDialer should set a Dial function")
}
}
// TestConfigureDialer_ComposesWithExistingDial verifies that when a custom Dial
// function is already set (e.g., from ConfigureProxy), ConfigureDialer wraps it
// and still enables TCP keepalive on the resulting connection.
func TestConfigureDialer_ComposesWithExistingDial(t *testing.T) {
var proxyDialCalled atomic.Bool
client := &fasthttp.Client{}
// Simulate a proxy dial function (set by ConfigureProxy)
client.Dial = func(addr string) (net.Conn, error) {
proxyDialCalled.Store(true)
return net.Dial("tcp", addr)
}
ConfigureDialer(client)
// Start a test server to connect to
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
fmt.Fprint(w, "ok")
}))
defer server.Close()
req := fasthttp.AcquireRequest()
resp := fasthttp.AcquireResponse()
defer fasthttp.ReleaseRequest(req)
defer fasthttp.ReleaseResponse(resp)
req.SetRequestURI(server.URL)
req.Header.SetMethod(http.MethodGet)
if err := client.Do(req, resp); err != nil {
t.Fatalf("request failed: %v", err)
}
if resp.StatusCode() != 200 {
t.Fatalf("expected 200, got %d", resp.StatusCode())
}
if !proxyDialCalled.Load() {
t.Error("ConfigureDialer should have called the existing proxy dial function")
}
}
// TestConfigureDialer_TCPKeepAliveEnabled verifies that connections created
// through ConfigureDialer have TCP keepalive enabled.
func TestConfigureDialer_TCPKeepAliveEnabled(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
fmt.Fprint(w, "ok")
}))
defer server.Close()
// Test without existing dial (direct connection path)
t.Run("without_existing_dial", func(t *testing.T) {
client := &fasthttp.Client{}
ConfigureDialer(client)
// The Dial function should create connections with keepalive
// We can verify by making a connection and checking the TCP options
req := fasthttp.AcquireRequest()
resp := fasthttp.AcquireResponse()
defer fasthttp.ReleaseRequest(req)
defer fasthttp.ReleaseResponse(resp)
req.SetRequestURI(server.URL)
req.Header.SetMethod(http.MethodGet)
if err := client.Do(req, resp); err != nil {
t.Fatalf("request failed: %v", err)
}
if resp.StatusCode() != 200 {
t.Fatalf("expected 200, got %d", resp.StatusCode())
}
})
// Test with existing dial (proxy composition path)
t.Run("with_existing_dial", func(t *testing.T) {
var connFromProxy net.Conn
client := &fasthttp.Client{}
client.Dial = func(addr string) (net.Conn, error) {
conn, err := net.Dial("tcp", addr)
connFromProxy = conn
return conn, err
}
ConfigureDialer(client)
req := fasthttp.AcquireRequest()
resp := fasthttp.AcquireResponse()
defer fasthttp.ReleaseRequest(req)
defer fasthttp.ReleaseResponse(resp)
req.SetRequestURI(server.URL)
req.Header.SetMethod(http.MethodGet)
if err := client.Do(req, resp); err != nil {
t.Fatalf("request failed: %v", err)
}
// Verify the proxy-returned connection is a TCP connection
// (ConfigureDialer enables keepalive via SetKeepAliveConfig on it)
if connFromProxy == nil {
t.Fatal("proxy dial should have been called")
}
if _, ok := connFromProxy.(*net.TCPConn); !ok {
t.Errorf("expected *net.TCPConn, got %T", connFromProxy)
}
})
}
// TestConfigureDialer_ReturnValue verifies that ConfigureDialer returns the
// same client pointer it received (for chaining).
func TestConfigureDialer_ReturnValue(t *testing.T) {
client := &fasthttp.Client{}
result := ConfigureDialer(client)
if result != client {
t.Error("ConfigureDialer should return the same client pointer")
}
}
// TestConfigureDialer_Idempotent verifies that calling ConfigureDialer multiple
// times doesn't break the client.
func TestConfigureDialer_Idempotent(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
fmt.Fprint(w, "ok")
}))
defer server.Close()
client := &fasthttp.Client{}
ConfigureDialer(client)
ConfigureDialer(client) // called again
req := fasthttp.AcquireRequest()
resp := fasthttp.AcquireResponse()
defer fasthttp.ReleaseRequest(req)
defer fasthttp.ReleaseResponse(resp)
req.SetRequestURI(server.URL)
req.Header.SetMethod(http.MethodPost)
req.SetBodyString(`{"test": true}`)
if err := client.Do(req, resp); err != nil {
t.Fatalf("request failed after double ConfigureDialer: %v", err)
}
if resp.StatusCode() != 200 {
t.Fatalf("expected 200, got %d", resp.StatusCode())
}
}
// TestConfigureDialer_WithRetryOnStaleConnection is an integration test that
// verifies ConfigureDialer enables successful POST retry after TTL mismatch.
// This combines both the retry and keepalive behaviors.
func TestConfigureDialer_WithRetryOnStaleConnection(t *testing.T) {
if testing.Short() {
t.Skip("skipping TTL mismatch test in short mode (requires 11s wait)")
}
const (
serverIdleTimeout = 10 * time.Second
clientIdleTimeout = 15 * time.Second
waitBetween = 11 * time.Second
)
var requestCount atomic.Int32
server := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
requestCount.Add(1)
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
fmt.Fprintf(w, `{"ok": true, "request": %d}`, requestCount.Load())
}))
server.Config.IdleTimeout = serverIdleTimeout
server.Start()
defer server.Close()
client := &fasthttp.Client{
MaxIdleConnDuration: clientIdleTimeout,
MaxConnsPerHost: 10,
}
// Use ConfigureDialer (the function under test) instead of manually setting RetryIfErr
ConfigureDialer(client)
// First request: establish connection in pool
req := fasthttp.AcquireRequest()
resp := fasthttp.AcquireResponse()
defer fasthttp.ReleaseRequest(req)
defer fasthttp.ReleaseResponse(resp)
req.SetRequestURI(server.URL)
req.Header.SetMethod(http.MethodPost)
req.SetBodyString(`{"prompt": "hello"}`)
if err := client.Do(req, resp); err != nil {
t.Fatalf("First POST failed: %v", err)
}
if resp.StatusCode() != 200 {
t.Fatalf("First POST: expected 200, got %d", resp.StatusCode())
}
_ = resp.Body()
// Wait for server TTL to expire
t.Logf("Waiting %v for server idle timeout to expire...", waitBetween)
time.Sleep(waitBetween)
// Second request: stale connection should be retried by ConfigureDialer's retry policy
req2 := fasthttp.AcquireRequest()
resp2 := fasthttp.AcquireResponse()
defer fasthttp.ReleaseRequest(req2)
defer fasthttp.ReleaseResponse(resp2)
req2.SetRequestURI(server.URL)
req2.Header.SetMethod(http.MethodPost)
req2.SetBodyString(`{"prompt": "world"}`)
if err := client.Do(req2, resp2); err != nil {
t.Fatalf("Second POST failed (ConfigureDialer retry should have saved it): %v", err)
}
if resp2.StatusCode() != 200 {
t.Fatalf("Second POST: expected 200, got %d", resp2.StatusCode())
}
t.Logf("Second POST succeeded after TTL mismatch via ConfigureDialer")
}
// TestConfigureRetry_Deprecated verifies the deprecated ConfigureRetry still works.
func TestConfigureRetry_Deprecated(t *testing.T) {
client := &fasthttp.Client{}
result := ConfigureRetry(client)
if result != client {
t.Error("ConfigureRetry should return the same client pointer")
}
if client.RetryIfErr == nil {
t.Fatal("ConfigureRetry should set RetryIfErr")
}
// Verify it uses the same StaleConnectionRetryIfErr
reset, retry := client.RetryIfErr(nil, 1, fmt.Errorf("cannot find whitespace"))
if !reset || !retry {
t.Error("ConfigureRetry should install StaleConnectionRetryIfErr")
}
}
// TestConfigureDialer_DialError verifies that dial errors from the existing
// dial function are properly propagated (not swallowed).
func TestConfigureDialer_DialError(t *testing.T) {
expectedErr := fmt.Errorf("proxy connection refused")
client := &fasthttp.Client{}
client.Dial = func(addr string) (net.Conn, error) {
return nil, expectedErr
}
ConfigureDialer(client)
req := fasthttp.AcquireRequest()
resp := fasthttp.AcquireResponse()
defer fasthttp.ReleaseRequest(req)
defer fasthttp.ReleaseResponse(resp)
req.SetRequestURI("http://localhost:1/test")
req.Header.SetMethod(http.MethodPost)
err := client.Do(req, resp)
if err == nil {
t.Fatal("expected error from failed proxy dial")
}
t.Logf("Got expected error: %v", err)
}
// TestStaleConnectionRetryIfErr_WrappedErrors verifies behavior with wrapped errors.
func TestStaleConnectionRetryIfErr_WrappedErrors(t *testing.T) {
tests := []struct {
name string
err error
wantRetry bool
}{
{
name: "wrapped whitespace error",
err: fmt.Errorf("fasthttp: %w", fmt.Errorf("cannot find whitespace in header")),
wantRetry: true,
},
{
name: "wrapped connection reset",
err: fmt.Errorf("during POST: connection reset by peer"),
wantRetry: true,
},
{
name: "wrapped broken pipe",
err: fmt.Errorf("during POST: %w", fmt.Errorf("write tcp 10.0.0.1:53374->10.0.0.2:30000: write: broken pipe")),
wantRetry: true,
},
{
name: "ErrConnectionClosed from fasthttp",
err: fasthttp.ErrConnectionClosed,
wantRetry: false, // Not matched - this error appears AFTER the retry loop
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, retry := network.StaleConnectionRetryIfErr(nil, 1, tt.err)
if retry != tt.wantRetry {
t.Errorf("retry = %v, want %v", retry, tt.wantRetry)
}
})
}
}

View File

@@ -0,0 +1,14 @@
package utils
import (
"encoding/base64"
"fmt"
"net/http"
)
// FileBytesToBase64DataURL converts raw file bytes to base64 data URL format
func FileBytesToBase64DataURL(fileBytes []byte) string {
mimeType := http.DetectContentType(fileBytes)
b64Data := base64.StdEncoding.EncodeToString(fileBytes)
return fmt.Sprintf("data:%s;base64,%s", mimeType, b64Data)
}

View File

@@ -0,0 +1,302 @@
package utils
import (
"strings"
"testing"
"github.com/valyala/fasthttp"
)
func TestIsHTMLResponse(t *testing.T) {
tests := []struct {
name string
contentType string
body []byte
expectedIsHTML bool
description string
}{
{
name: "HTML with Content-Type header",
contentType: "text/html; charset=utf-8",
body: []byte("<html><body>Error</body></html>"),
expectedIsHTML: true,
description: "Should detect HTML from Content-Type header",
},
{
name: "HTML without Content-Type",
contentType: "application/octet-stream",
body: []byte("<!DOCTYPE html><html><head><title>Error 500</title></head></html>"),
expectedIsHTML: true,
description: "Should detect HTML from DOCTYPE",
},
{
name: "HTML with h1 tag",
contentType: "application/octet-stream",
body: []byte("<h1>Service Unavailable</h1>"),
expectedIsHTML: true,
description: "Should detect HTML from h1 tag",
},
{
name: "JSON response",
contentType: "application/json",
body: []byte(`{"error": "invalid request"}`),
expectedIsHTML: false,
description: "Should not detect JSON as HTML",
},
{
name: "Plain text response",
contentType: "text/plain",
body: []byte("Invalid request"),
expectedIsHTML: false,
description: "Should not detect plain text as HTML",
},
{
name: "Empty body",
contentType: "text/html",
body: []byte(""),
expectedIsHTML: true,
description: "Should detect HTML from Content-Type even with empty body",
},
{
name: "Very short body",
contentType: "application/json",
body: []byte("abc"),
expectedIsHTML: false,
description: "Should not detect very short body as HTML",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
resp := &fasthttp.Response{}
resp.Header.Set("Content-Type", tt.contentType)
result := IsHTMLResponse(resp, tt.body)
if result != tt.expectedIsHTML {
t.Errorf("isHTMLResponse() = %v, want %v. %s", result, tt.expectedIsHTML, tt.description)
}
})
}
}
func TestExtractHTMLErrorMessage(t *testing.T) {
tests := []struct {
name string
htmlBody []byte
expectMsg string
description string
}{
{
name: "Extract from title tag",
htmlBody: []byte(`
<!DOCTYPE html>
<html>
<head><title>404 Not Found</title></head>
<body><p>The page was not found</p></body>
</html>
`),
expectMsg: "404 Not Found",
description: "Should extract title from title tag",
},
{
name: "Extract from h1 tag",
htmlBody: []byte(`
<!DOCTYPE html>
<html>
<body>
<h1>Service Unavailable</h1>
<p>The service is currently unavailable</p>
</body>
</html>
`),
expectMsg: "Service Unavailable",
description: "Should extract from h1 tag when title is missing",
},
{
name: "Extract from h2 tag",
htmlBody: []byte(`
<!DOCTYPE html>
<html>
<body>
<h2 class="error-header">Authentication Failed</h2>
<p>Please check your credentials</p>
</body>
</html>
`),
expectMsg: "Authentication Failed",
description: "Should extract from h2 tag with attributes",
},
{
name: "Extract visible text when no headers",
htmlBody: []byte(`
<html>
<body>
<div>There was an error processing your request. Please try again later.</div>
</body>
</html>
`),
expectMsg: "There was an error processing your request. Please try again later.",
description: "Should extract visible text from div when no headers found",
},
{
name: "Ignore script and style tags",
htmlBody: []byte(`
<html>
<head><title>Error</title></head>
<body>
<script>var x = 'ignore me';</script>
<style>.error { color: red; }</style>
<h1>Actual Error Message</h1>
</body>
</html>
`),
expectMsg: "Actual Error Message",
description: "Should ignore script and style content",
},
{
name: "Extract from first valid h1",
htmlBody: []byte(`
<html>
<body>
<h1></h1>
<h1>Second header with actual content</h1>
</body>
</html>
`),
expectMsg: "Second header with actual content",
description: "Should extract from first non-empty header",
},
{
name: "Handle meta description",
htmlBody: []byte(`
<html>
<head>
<meta name="description" content="Rate limit exceeded. Please wait 60 seconds.">
</head>
<body></body>
</html>
`),
expectMsg: "Rate limit exceeded. Please wait 60 seconds.",
description: "Should extract from meta description",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := ExtractHTMLErrorMessage(tt.htmlBody)
if result != tt.expectMsg {
t.Errorf("extractHTMLErrorMessage() = %q, want %q. %s", result, tt.expectMsg, tt.description)
}
})
}
}
func TestHandleProviderAPIErrorWithHTML(t *testing.T) {
tests := []struct {
name string
statusCode int
contentType string
body []byte
description string
expectedInMessage string
}{
{
name: "HTML 500 error - lazy detection",
statusCode: 500,
contentType: "text/html; charset=utf-8",
body: []byte(`
<!DOCTYPE html>
<html>
<head><title>Internal Server Error</title></head>
<body><h1>Something went wrong</h1></body>
</html>
`),
description: "Should detect and handle HTML only after JSON parse fails",
expectedInMessage: "HTML response received from provider",
},
{
name: "HTML 403 error - lazy detection",
statusCode: 403,
contentType: "text/html",
body: []byte(`
<html>
<body>
<h1>Forbidden</h1>
<p>Access denied</p>
</body>
</html>
`),
description: "Should detect HTML on parse failure",
expectedInMessage: "HTML response received from provider",
},
{
name: "Invalid JSON with HTML fallback",
statusCode: 400,
contentType: "application/json",
body: []byte(`not valid json`),
description: "Should fall back to raw string when not HTML",
expectedInMessage: "provider API error",
},
{
name: "Valid JSON error response",
statusCode: 400,
contentType: "application/json",
body: []byte(`{"error": {"message": "Invalid request"}, "code": "invalid_request"}`),
description: "Should handle valid JSON without HTML detection",
expectedInMessage: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
resp := &fasthttp.Response{}
resp.SetStatusCode(tt.statusCode)
resp.Header.Set("Content-Type", tt.contentType)
resp.SetBody(tt.body)
var errorResp map[string]interface{}
bifrostErr := HandleProviderAPIError(resp, &errorResp)
if bifrostErr == nil {
t.Errorf("HandleProviderAPIError() returned nil error")
return
}
if bifrostErr.StatusCode == nil || *bifrostErr.StatusCode != tt.statusCode {
t.Errorf("HandleProviderAPIError() status code = %v, want %v", bifrostErr.StatusCode, tt.statusCode)
}
if bifrostErr.Error == nil {
t.Errorf("HandleProviderAPIError() error field is nil")
return
}
// Check if expected message is in the response
if tt.expectedInMessage != "" && !strings.Contains(bifrostErr.Error.Message, tt.expectedInMessage) {
t.Errorf("Expected message to contain %q, got %q", tt.expectedInMessage, bifrostErr.Error.Message)
}
t.Logf("Handled %s: status=%d, message=%q", tt.name, *bifrostErr.StatusCode, bifrostErr.Error.Message)
})
}
}
func BenchmarkIsHTMLResponse(b *testing.B) {
resp := &fasthttp.Response{}
resp.Header.Set("Content-Type", "text/html; charset=utf-8")
body := []byte(`<!DOCTYPE html><html><head><title>Error</title></head><body><h1>Test Error</h1></body></html>`)
b.ResetTimer()
for i := 0; i < b.N; i++ {
IsHTMLResponse(resp, body)
}
}
func BenchmarkExtractHTMLErrorMessage(b *testing.B) {
body := []byte(`<!DOCTYPE html><html><head><title>Internal Server Error</title></head><body><h1>Something went wrong</h1><p>This is a detailed error message</p></body></html>`)
b.ResetTimer()
for i := 0; i < b.N; i++ {
ExtractHTMLErrorMessage(body)
}
}

View File

@@ -0,0 +1,284 @@
package utils
import (
"errors"
"io"
"sync"
"sync/atomic"
"testing"
"time"
)
// ---------------------------------------------------------------------------
// helpers
// ---------------------------------------------------------------------------
// readCloserSpy implements io.ReadCloser and records how many times Close() was called.
type readCloserSpy struct {
mu sync.Mutex
closed int
}
func (c *readCloserSpy) Read([]byte) (int, error) { return 0, io.EOF }
func (c *readCloserSpy) Close() error {
c.mu.Lock()
defer c.mu.Unlock()
c.closed++
return nil
}
func (c *readCloserSpy) closeCount() int {
c.mu.Lock()
defer c.mu.Unlock()
return c.closed
}
// zeroThenBlockReader returns (0, nil) on the first read, then blocks forever.
type zeroThenBlockReader struct {
first atomic.Bool
pipeRd *io.PipeReader
}
func (r *zeroThenBlockReader) Read(p []byte) (int, error) {
if r.first.CompareAndSwap(false, true) {
return 0, nil // zero-byte read
}
// block until pipe is closed
return r.pipeRd.Read(p)
}
// ---------------------------------------------------------------------------
// tests
// ---------------------------------------------------------------------------
func TestIdleTimeoutReader_NormalRead(t *testing.T) {
t.Parallel()
pr, pw := io.Pipe()
defer pr.Close()
// Use pr as bodyStream — closing pr unblocks reads.
wrapped, cleanup := NewIdleTimeoutReader(pr, pr, 500*time.Millisecond)
defer cleanup()
// Writer sends 5 chunks quickly.
go func() {
for i := 0; i < 5; i++ {
time.Sleep(10 * time.Millisecond)
pw.Write([]byte("chunk"))
}
pw.Close()
}()
buf := make([]byte, 64)
var totalBytes int
for {
n, err := wrapped.Read(buf)
totalBytes += n
if err != nil {
if err != io.EOF {
t.Fatalf("unexpected error: %v", err)
}
break
}
}
if totalBytes != 5*len("chunk") {
t.Fatalf("expected %d bytes, got %d", 5*len("chunk"), totalBytes)
}
}
func TestIdleTimeoutReader_TimeoutClosesStream(t *testing.T) {
t.Parallel()
pr, pw := io.Pipe()
defer pw.Close()
// 100ms timeout, write nothing — should timeout and close the pipe reader.
wrapped, cleanup := NewIdleTimeoutReader(pr, pr, 100*time.Millisecond)
defer cleanup()
start := time.Now()
buf := make([]byte, 64)
_, err := wrapped.Read(buf)
elapsed := time.Since(start)
if err == nil {
t.Fatal("expected an error from timed-out read, got nil")
}
// Should complete within ~200ms (100ms timeout + margin), not hang.
if elapsed > 500*time.Millisecond {
t.Fatalf("read took %v, expected ~100ms timeout", elapsed)
}
}
func TestIdleTimeoutReader_TimeoutAfterPartialData(t *testing.T) {
t.Parallel()
pr, pw := io.Pipe()
// 200ms idle timeout.
wrapped, cleanup := NewIdleTimeoutReader(pr, pr, 200*time.Millisecond)
defer cleanup()
// Writer sends 3 chunks then stops.
go func() {
for i := 0; i < 3; i++ {
time.Sleep(20 * time.Millisecond)
pw.Write([]byte("data"))
}
// stop writing — idle timeout should fire after 200ms and close pr
}()
buf := make([]byte, 64)
chunksRead := 0
for {
n, err := wrapped.Read(buf)
if n > 0 {
chunksRead++
}
if err != nil {
break
}
}
if chunksRead != 3 {
t.Fatalf("expected 3 chunks before timeout, got %d", chunksRead)
}
pw.Close()
}
func TestIdleTimeoutReader_ResetOnData(t *testing.T) {
t.Parallel()
pr, pw := io.Pipe()
// 200ms timeout, but data arrives every 150ms — should never timeout.
wrapped, cleanup := NewIdleTimeoutReader(pr, pr, 200*time.Millisecond)
defer cleanup()
go func() {
for i := 0; i < 5; i++ {
time.Sleep(150 * time.Millisecond)
pw.Write([]byte("ok"))
}
pw.Close()
}()
buf := make([]byte, 64)
chunksRead := 0
for {
n, err := wrapped.Read(buf)
if n > 0 {
chunksRead++
}
if err != nil {
if err != io.EOF {
t.Fatalf("expected EOF after all chunks, got: %v", err)
}
break
}
}
if chunksRead != 5 {
t.Fatalf("expected 5 chunks (timer should reset), got %d", chunksRead)
}
}
func TestIdleTimeoutReader_CleanupStopsTimer(t *testing.T) {
t.Parallel()
pr, pw := io.Pipe()
defer pr.Close()
defer pw.Close()
spy := &readCloserSpy{}
_, cleanup := NewIdleTimeoutReader(pr, spy, 100*time.Millisecond)
// Call cleanup immediately — timer should be stopped.
cleanup()
// Wait well past the timeout duration.
time.Sleep(250 * time.Millisecond)
if spy.closeCount() != 0 {
t.Fatalf("expected closer to NOT be called after cleanup, but was called %d times", spy.closeCount())
}
}
func TestIdleTimeoutReader_DoubleCloseIsSafe(t *testing.T) {
t.Parallel()
spy := &readCloserSpy{}
br := &zeroThenBlockReader{first: atomic.Bool{}, pipeRd: nil}
// Use spy as bodyStream — it implements both io.Reader and io.Closer.
_, cleanup := NewIdleTimeoutReader(br, spy, 50*time.Millisecond)
defer cleanup()
// Let the timer fire (closes spy via sync.Once).
time.Sleep(100 * time.Millisecond)
// Manually close again — should not panic.
spy.Close()
// sync.Once ensures the idle timer's close ran exactly once.
// The manual close above adds another, so total should be 2
// (the once.Do protects the timer path, not external callers).
// The key guarantee: no panic.
if spy.closeCount() < 1 {
t.Fatal("expected at least one close call")
}
}
func TestIdleTimeoutReader_ZeroBytesDoNotResetTimer(t *testing.T) {
t.Parallel()
pr, pw := io.Pipe()
defer pw.Close()
// Use pr as bodyStream — when idle timeout fires, it closes pr,
// which causes reads on pr to return io.ErrClosedPipe.
zr := &zeroThenBlockReader{pipeRd: pr}
wrapped, cleanup := NewIdleTimeoutReader(zr, pr, 100*time.Millisecond)
defer cleanup()
done := make(chan error, 1)
go func() {
buf := make([]byte, 64)
// First read returns (0, nil), second read blocks until pipe is closed.
for {
_, err := wrapped.Read(buf)
if err != nil {
done <- err
return
}
}
}()
select {
case <-done:
// Timer fired and closed the pipe — Read() returned an error. Good.
case <-time.After(500 * time.Millisecond):
t.Fatal("expected idle timeout to fire, but read is still blocking")
}
}
func TestIdleTimeoutReader_ErrorFromClosedPipe(t *testing.T) {
t.Parallel()
pr, pw := io.Pipe()
defer pw.Close()
// Use pr as bodyStream — when idle timeout fires, it closes pr,
// which makes Read return io.ErrClosedPipe.
wrapped, cleanup := NewIdleTimeoutReader(pr, pr, 50*time.Millisecond)
defer cleanup()
buf := make([]byte, 64)
_, err := wrapped.Read(buf)
if err == nil {
t.Fatal("expected error from closed pipe")
}
// The error should indicate the pipe was closed.
if !errors.Is(err, io.ErrClosedPipe) && !errors.Is(err, io.EOF) {
// Some implementations return io.ErrClosedPipe, others EOF.
t.Logf("got error: %v (acceptable)", err)
}
}

View File

@@ -0,0 +1,50 @@
package utils
import (
"strconv"
"strings"
)
// ConvertSizeToAspectRatioAndResolution converts a standard size string (e.g., "1024x1024")
// to an aspect ratio and image size tier.
// aspectRatio is one of "1:1", "3:4", "4:3", "9:16", "16:9" (empty if unrecognised).
// imageSize is one of "1K", "2K", "4K" (empty if out of range).
func ConvertSizeToAspectRatioAndResolution(size string) (aspectRatio, imageSize string) {
parts := strings.Split(size, "x")
if len(parts) != 2 {
return "", ""
}
width, err1 := strconv.Atoi(parts[0])
height, err2 := strconv.Atoi(parts[1])
if err1 != nil || err2 != nil {
return "", ""
}
if width <= 0 || height <= 0 {
return "", ""
}
if width <= 1024 && height <= 1024 {
imageSize = "1K"
} else if width <= 2048 && height <= 2048 {
imageSize = "2K"
} else if width <= 4096 && height <= 4096 {
imageSize = "4K"
}
ratio := float64(width) / float64(height)
if ratio >= 0.99 && ratio <= 1.01 {
aspectRatio = "1:1"
} else if ratio >= 0.74 && ratio <= 0.76 {
aspectRatio = "3:4"
} else if ratio >= 1.32 && ratio <= 1.34 {
aspectRatio = "4:3"
} else if ratio >= 0.56 && ratio <= 0.57 {
aspectRatio = "9:16"
} else if ratio >= 1.77 && ratio <= 1.78 {
aspectRatio = "16:9"
}
return aspectRatio, imageSize
}

View File

@@ -0,0 +1,339 @@
package utils
import (
"bytes"
"io"
"math"
"github.com/bytedance/sonic"
"github.com/maximhq/bifrost/core/schemas"
"github.com/valyala/fasthttp"
)
// LargeResponseReader wraps an io.Reader and releases the fasthttp response on Close.
// Used by providers to keep the response alive while the transport streams it to the client.
type LargeResponseReader struct {
io.Reader
Resp *fasthttp.Response
cleanup func()
consumed bool // true after Read returns io.EOF — body fully consumed through Reader chain
}
// Read delegates to the wrapped Reader and tracks EOF so Close() can skip
// a redundant (and potentially blocking) drain of the body stream.
func (r *LargeResponseReader) Read(p []byte) (int, error) {
n, err := r.Reader.Read(p)
if err == io.EOF {
r.consumed = true
}
return n, err
}
// Close drains any unconsumed body stream and releases the underlying fasthttp
// response back to the pool. Draining prevents "whitespace in header" errors on
// connection reuse when the client disconnects before the full response is consumed
// (see: fasthttp#1743).
//
// When the body was already fully consumed through the Reader chain (consumed == true),
// the drain is skipped. For identity-encoded responses (no Content-Length), the body
// stream is a fasthttp closeReader that blocks until the TCP connection closes — which
// can take minutes if the upstream server keeps the connection alive.
func (r *LargeResponseReader) Close() error {
if r == nil || r.Resp == nil {
return nil
}
if !r.consumed {
if bodyStream := r.Resp.BodyStream(); bodyStream != nil {
_, _ = io.Copy(io.Discard, bodyStream)
if closer, ok := bodyStream.(io.Closer); ok {
_ = closer.Close()
}
}
}
if r.cleanup != nil {
r.cleanup()
r.cleanup = nil
}
fasthttp.ReleaseResponse(r.Resp)
r.Resp = nil
return nil
}
// BuildLargeResponseClient creates a streaming-enabled fasthttp client for large response detection.
// The client caps buffering at the threshold and enables response body streaming.
//
// ReadTimeout/WriteTimeout/MaxConnDuration are zeroed: large-response bodies may take arbitrarily
// long to download, and fasthttp's ReadTimeout bounds *full* body read — not idle. Idle detection
// on stalled streams is handled separately (see NewIdleTimeoutReader / SetupStreamingPassthrough).
func BuildLargeResponseClient(base *fasthttp.Client, responseThreshold int64) *fasthttp.Client {
client := CloneFastHTTPClientConfig(base)
if responseThreshold > 0 && responseThreshold <= int64(math.MaxInt) {
client.MaxResponseBodySize = int(responseThreshold)
}
client.StreamResponseBody = true
client.ReadTimeout = 0
client.WriteTimeout = 0
client.MaxConnDuration = 0
return client
}
// PrepareResponseStreaming configures response body streaming when a large response
// threshold is set in context. Returns the client to use for MakeRequestWithContext.
// When threshold > 0: sets resp.StreamBody = true and returns a streaming-enabled client.
// When threshold <= 0: returns the original client unchanged (no-op for feature-off path).
func PrepareResponseStreaming(ctx *schemas.BifrostContext, client *fasthttp.Client, resp *fasthttp.Response) *fasthttp.Client {
responseThreshold, _ := ctx.Value(schemas.BifrostContextKeyLargeResponseThreshold).(int64)
if responseThreshold <= 0 {
return client
}
resp.StreamBody = true
return BuildLargeResponseClient(client, responseThreshold)
}
// MaterializeStreamErrorBody reads a streamed error body into resp so that resp.Body()
// returns the error payload for parsing. No-op when response streaming is not active.
func MaterializeStreamErrorBody(ctx *schemas.BifrostContext, resp *fasthttp.Response) {
responseThreshold, _ := ctx.Value(schemas.BifrostContextKeyLargeResponseThreshold).(int64)
if responseThreshold <= 0 {
return
}
if bodyStream := resp.BodyStream(); bodyStream != nil {
gz, reader, wasGzip := decompressBodyStreamIfGzip(resp, bodyStream)
if wasGzip {
defer ReleaseGzipReader(gz)
}
bodyBytes, readErr := io.ReadAll(io.LimitReader(reader, 512*1024)) // 512KB cap for error bodies
if readErr != nil {
return
}
resp.SetBody(bodyBytes)
}
}
// FinalizeResponseWithLargeDetection processes the response body with optional large response
// detection. Takes ownership semantics: when isLargeResponse is true, the caller must NOT
// release resp (it's wrapped in a reader stored in context). When false, resp is unchanged
// and the caller should release as normal.
//
// Returns:
// - (body, false, nil) — normal path; body ready for parsing; resp NOT released.
// - (nil, true, nil) — large response detected; context keys set for streaming;
// caller must set respOwned = false.
// - (nil, false, err) — error; resp NOT released.
func FinalizeResponseWithLargeDetection(
ctx *schemas.BifrostContext,
resp *fasthttp.Response,
logger schemas.Logger,
) ([]byte, bool, *schemas.BifrostError) {
responseThreshold, _ := ctx.Value(schemas.BifrostContextKeyLargeResponseThreshold).(int64)
// No threshold — normal buffered read (feature-off path)
if responseThreshold <= 0 {
body, err := CheckAndDecodeBody(resp)
if err != nil {
return nil, false, NewBifrostOperationError(schemas.ErrProviderResponseDecode, err)
}
// Copy body before caller releases resp
return append([]byte(nil), body...), false, nil
}
contentLength := resp.Header.ContentLength()
// Known small response — read from stream, return body for normal parsing
if contentLength > 0 && int64(contentLength) <= responseThreshold {
if bodyStream := resp.BodyStream(); bodyStream != nil {
gz, reader, wasGzip := decompressBodyStreamIfGzip(resp, bodyStream)
if wasGzip {
defer ReleaseGzipReader(gz)
}
bodyBytes, readErr := io.ReadAll(reader)
if readErr != nil {
return nil, false, NewBifrostOperationError(schemas.ErrProviderResponseDecode, readErr)
}
return bodyBytes, false, nil
}
// No stream — buffered fallback
body, err := CheckAndDecodeBody(resp)
if err != nil {
return nil, false, NewBifrostOperationError(schemas.ErrProviderResponseDecode, err)
}
return append([]byte(nil), body...), false, nil
}
// Unknown Content-Length (chunked transfer encoding) — buffer up to responseThreshold
// to determine if response is truly large. Responses within threshold are returned
// buffered for normal parsing/logging; only responses exceeding threshold are streamed.
if contentLength <= 0 {
if bodyStream := resp.BodyStream(); bodyStream != nil {
gz, reader, wasGzip := decompressBodyStreamIfGzip(resp, bodyStream)
releaseGzip := func() {}
if wasGzip {
releaseGzip = func() {
ReleaseGzipReader(gz)
}
}
bodyBytes, readErr := io.ReadAll(io.LimitReader(reader, responseThreshold+1))
if readErr != nil {
releaseGzip()
return nil, false, NewBifrostOperationError(schemas.ErrProviderResponseDecode, readErr)
}
if int64(len(bodyBytes)) <= responseThreshold {
releaseGzip()
return bodyBytes, false, nil
}
// Exceeds threshold without Content-Length — set up large response streaming.
combinedReader := io.MultiReader(bytes.NewReader(bodyBytes), reader)
closableReader := &LargeResponseReader{
Reader: combinedReader,
Resp: resp,
cleanup: releaseGzip,
}
ctx.SetValue(schemas.BifrostContextKeyLargeResponseMode, true)
ctx.SetValue(schemas.BifrostContextKeyLargeResponseReader, closableReader)
ctx.SetValue(schemas.BifrostContextKeyLargeResponseContentLength, contentLength)
if ct := string(resp.Header.ContentType()); ct != "" {
ctx.SetValue(schemas.BifrostContextKeyLargeResponseContentType, ct)
}
previewLen := min(len(bodyBytes), 1048576)
ctx.SetValue(schemas.BifrostContextKeyLargePayloadResponsePreview, string(bodyBytes[:previewLen]))
return nil, true, nil
}
// No stream — buffered fallback
body, err := CheckAndDecodeBody(resp)
if err != nil {
return nil, false, NewBifrostOperationError(schemas.ErrProviderResponseDecode, err)
}
return append([]byte(nil), body...), false, nil
}
// Known large response (Content-Length > threshold) — prefetch first 64KB for
// metadata extraction, then stream the rest without full materialization.
bodyStream := resp.BodyStream()
if bodyStream == nil {
// No stream available — fall back to buffered read
if logger != nil {
logger.Warn("large-response fallback to buffered path: content_length=%d threshold=%d body_stream_nil=true", contentLength, responseThreshold)
}
body, err := CheckAndDecodeBody(resp)
if err != nil {
return nil, false, NewBifrostOperationError(schemas.ErrProviderResponseDecode, err)
}
return append([]byte(nil), body...), false, nil
}
// Decompress on-the-fly if provider returned gzip-encoded response.
// Clears Content-Encoding so the transport doesn't re-add it to the client response.
gz, decompressedStream, wasGzip := decompressBodyStreamIfGzip(resp, bodyStream)
if wasGzip {
contentLength = -1 // decompressed size unknown; transport will use chunked encoding
}
prefetchSize := 64 * 1024 // default
if ps, ok := ctx.Value(schemas.BifrostContextKeyLargePayloadPrefetchSize).(int); ok && ps > 0 {
prefetchSize = ps
}
prefetchBuf := make([]byte, prefetchSize)
n, readErr := io.ReadFull(decompressedStream, prefetchBuf)
if readErr != nil && readErr != io.EOF && readErr != io.ErrUnexpectedEOF {
if wasGzip {
ReleaseGzipReader(gz)
}
return nil, false, NewBifrostOperationError(schemas.ErrProviderResponseDecode, readErr)
}
prefetchBuf = prefetchBuf[:n]
combinedReader := io.MultiReader(bytes.NewReader(prefetchBuf), decompressedStream)
closableReader := &LargeResponseReader{
Reader: combinedReader,
Resp: resp,
cleanup: func() {
if wasGzip {
ReleaseGzipReader(gz)
}
},
}
ctx.SetValue(schemas.BifrostContextKeyLargeResponseMode, true)
ctx.SetValue(schemas.BifrostContextKeyLargeResponseReader, closableReader)
ctx.SetValue(schemas.BifrostContextKeyLargeResponseContentLength, contentLength)
if ct := string(resp.Header.ContentType()); ct != "" {
ctx.SetValue(schemas.BifrostContextKeyLargeResponseContentType, ct)
}
previewLen := min(n, 1048576)
ctx.SetValue(schemas.BifrostContextKeyLargePayloadResponsePreview, string(prefetchBuf[:previewLen]))
return nil, true, nil
}
// ParseOpenAIUsageFromBytes parses OpenAI-format usage from raw JSON bytes into BifrostLLMUsage.
// Handles both Chat Completions (prompt_tokens/completion_tokens) and Responses API
// (input_tokens/output_tokens) field names. Expects the "usage" object bytes directly,
// not the full response body.
func ParseOpenAIUsageFromBytes(data []byte) *schemas.BifrostLLMUsage {
var usage struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
// Responses API uses different field names
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
}
if err := sonic.Unmarshal(data, &usage); err != nil {
return nil
}
result := &schemas.BifrostLLMUsage{}
if usage.PromptTokens > 0 {
result.PromptTokens = usage.PromptTokens
} else if usage.InputTokens > 0 {
result.PromptTokens = usage.InputTokens
}
if usage.CompletionTokens > 0 {
result.CompletionTokens = usage.CompletionTokens
} else if usage.OutputTokens > 0 {
result.CompletionTokens = usage.OutputTokens
}
if usage.TotalTokens > 0 {
result.TotalTokens = usage.TotalTokens
} else {
result.TotalTokens = result.PromptTokens + result.CompletionTokens
}
if result.TotalTokens == 0 {
return nil
}
return result
}
// SetupStreamingPassthrough configures large response passthrough for streaming
// responses when large payload mode is active. Wraps the response body stream
// in a LargeResponseReader and sets context keys for the transport layer.
// Returns true if passthrough was set up. When true, the caller should return
// a closed channel and must NOT release resp — it's owned by the reader in context.
func SetupStreamingPassthrough(ctx *schemas.BifrostContext, resp *fasthttp.Response) bool {
isLargePayload, _ := ctx.Value(schemas.BifrostContextKeyLargePayloadMode).(bool)
if !isLargePayload {
return false
}
reader, releaseGzip := DecompressStreamBody(resp)
// Wrap reader with idle timeout to detect stalled streams.
reader, stopIdleTimeout := NewIdleTimeoutReader(reader, resp.BodyStream(), GetStreamIdleTimeout(ctx))
closableReader := &LargeResponseReader{
Reader: reader,
Resp: resp,
cleanup: func() {
stopIdleTimeout()
releaseGzip()
},
}
ctx.SetValue(schemas.BifrostContextKeyLargeResponseMode, true)
ctx.SetValue(schemas.BifrostContextKeyLargeResponseReader, closableReader)
if ct := string(resp.Header.ContentType()); ct != "" {
ctx.SetValue(schemas.BifrostContextKeyLargeResponseContentType, ct)
}
return true
}

View File

@@ -0,0 +1,406 @@
package utils
import (
"context"
"net"
"sync/atomic"
"testing"
"time"
"github.com/maximhq/bifrost/core/schemas"
"github.com/valyala/fasthttp"
"github.com/valyala/fasthttp/fasthttputil"
)
// newTestServer creates an in-memory fasthttp server that responds after the given delay.
// Returns a client configured to talk to it and a cleanup function.
func newTestServer(t *testing.T, delay time.Duration, statusCode int) (*fasthttp.Client, func()) {
t.Helper()
ln := fasthttputil.NewInmemoryListener()
server := &fasthttp.Server{
Handler: func(ctx *fasthttp.RequestCtx) {
if delay > 0 {
time.Sleep(delay)
}
ctx.SetStatusCode(statusCode)
ctx.SetBody([]byte(`{"ok":true}`))
},
}
go server.Serve(ln) //nolint:errcheck
client := &fasthttp.Client{
Dial: func(addr string) (net.Conn, error) {
return ln.Dial()
},
ReadTimeout: 5 * time.Second,
WriteTimeout: 5 * time.Second,
}
cleanup := func() {
ln.Close()
}
return client, cleanup
}
func TestMakeRequestWithContext_SuccessReturnsNoopWait(t *testing.T) {
client, cleanup := newTestServer(t, 0, 200)
defer cleanup()
req := fasthttp.AcquireRequest()
resp := fasthttp.AcquireResponse()
defer fasthttp.ReleaseRequest(req)
defer fasthttp.ReleaseResponse(resp)
req.SetRequestURI("http://test/")
latency, bifrostErr, wait := MakeRequestWithContext(context.Background(), client, req, resp)
defer wait()
if bifrostErr != nil {
t.Fatalf("expected no error, got: %v", bifrostErr.Error.Message)
}
if latency <= 0 {
t.Fatal("expected positive latency")
}
if resp.StatusCode() != 200 {
t.Fatalf("expected status 200, got %d", resp.StatusCode())
}
}
func TestMakeRequestWithContext_DeadlineExceededReturnsTimeoutError(t *testing.T) {
// Server takes 500ms to respond
client, cleanup := newTestServer(t, 500*time.Millisecond, 200)
defer cleanup()
req := fasthttp.AcquireRequest()
resp := fasthttp.AcquireResponse()
req.SetRequestURI("http://test/")
// Deadline exceeded almost immediately
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
defer cancel()
_, bifrostErr, wait := MakeRequestWithContext(ctx, client, req, resp)
// Should get a timeout error with 504 status
if bifrostErr == nil {
t.Fatal("expected timeout error")
}
if bifrostErr.Error.Type == nil || *bifrostErr.Error.Type != schemas.RequestTimedOut {
t.Fatalf("expected RequestTimedOut error type, got: %v", bifrostErr.Error.Type)
}
if bifrostErr.StatusCode == nil || *bifrostErr.StatusCode != 504 {
t.Fatalf("expected status 504, got: %v", bifrostErr.StatusCode)
}
// wait() should block until the goroutine finishes, then we can safely release
start := time.Now()
wait()
elapsed := time.Since(start)
// The wait should have taken roughly the remaining server delay (~490ms)
if elapsed < 200*time.Millisecond {
t.Fatalf("wait() returned too quickly (%v), expected it to block until goroutine finishes", elapsed)
}
// Now safe to release
fasthttp.ReleaseRequest(req)
fasthttp.ReleaseResponse(resp)
}
func TestMakeRequestWithContext_ContextCancelReturnsCancelledError(t *testing.T) {
// Server takes 500ms to respond
client, cleanup := newTestServer(t, 500*time.Millisecond, 200)
defer cleanup()
req := fasthttp.AcquireRequest()
resp := fasthttp.AcquireResponse()
req.SetRequestURI("http://test/")
// Cancel context explicitly (not deadline)
ctx, cancel := context.WithCancel(context.Background())
go func() {
time.Sleep(10 * time.Millisecond)
cancel()
}()
_, bifrostErr, wait := MakeRequestWithContext(ctx, client, req, resp)
// Should get a cancellation error with 499 status
if bifrostErr == nil {
t.Fatal("expected cancellation error")
}
if bifrostErr.Error.Type == nil || *bifrostErr.Error.Type != schemas.RequestCancelled {
t.Fatalf("expected RequestCancelled error type, got: %v", bifrostErr.Error.Type)
}
if bifrostErr.StatusCode == nil || *bifrostErr.StatusCode != 499 {
t.Fatalf("expected status 499, got: %v", bifrostErr.StatusCode)
}
// wait() should block until the goroutine finishes
start := time.Now()
wait()
elapsed := time.Since(start)
if elapsed < 200*time.Millisecond {
t.Fatalf("wait() returned too quickly (%v), expected it to block until goroutine finishes", elapsed)
}
fasthttp.ReleaseRequest(req)
fasthttp.ReleaseResponse(resp)
}
func TestMakeRequestWithContext_WaitPreventsDataRace(t *testing.T) {
// This test verifies the fix for the data race. Under -race, accessing resp
// while client.Do is still writing to it would be flagged. The wait function
// ensures we don't release until the goroutine is done.
//
// Run with: go test -race -run TestMakeRequestWithContext_WaitPreventsDataRace
// Server responds after 200ms
client, cleanup := newTestServer(t, 200*time.Millisecond, 200)
defer cleanup()
for range 10 {
func() {
req := fasthttp.AcquireRequest()
resp := fasthttp.AcquireResponse()
req.SetRequestURI("http://test/")
// Cancel context after 5ms — well before server responds
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Millisecond)
defer cancel()
_, _, wait := MakeRequestWithContext(ctx, client, req, resp)
// Simulate the real caller pattern: defer wait() before defer Release.
// Go defers are LIFO, so wait() runs first, then Release.
// This is the pattern that prevents the data race.
defer fasthttp.ReleaseRequest(req)
defer fasthttp.ReleaseResponse(resp)
defer wait()
}()
}
}
func TestMakeRequestWithContext_WaitIsIdempotent(t *testing.T) {
client, cleanup := newTestServer(t, 50*time.Millisecond, 200)
defer cleanup()
req := fasthttp.AcquireRequest()
resp := fasthttp.AcquireResponse()
defer fasthttp.ReleaseRequest(req)
defer fasthttp.ReleaseResponse(resp)
req.SetRequestURI("http://test/")
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Millisecond)
defer cancel()
_, _, wait := MakeRequestWithContext(ctx, client, req, resp)
// First call should block
wait()
// Second call should not deadlock (channel already drained)
// Note: this will deadlock if the implementation is wrong, so the test
// would time out rather than fail gracefully.
done := make(chan struct{})
go func() {
wait()
close(done)
}()
select {
case <-done:
// Second wait() completed — but note this actually WILL deadlock with
// the current implementation since <-errChan can only be read once.
// This documents the behavior: wait() should only be called once.
case <-time.After(100 * time.Millisecond):
// Expected: second wait() blocks forever because errChan is already drained.
// This is fine — callers should only call wait() once (via a single defer).
}
}
func TestMakeRequestWithContext_SuccessWaitDoesNotBlock(t *testing.T) {
client, cleanup := newTestServer(t, 0, 200)
defer cleanup()
req := fasthttp.AcquireRequest()
resp := fasthttp.AcquireResponse()
defer fasthttp.ReleaseRequest(req)
defer fasthttp.ReleaseResponse(resp)
req.SetRequestURI("http://test/")
_, _, wait := MakeRequestWithContext(context.Background(), client, req, resp)
// On the success path, wait should be a noop that returns immediately
start := time.Now()
wait()
if time.Since(start) > 10*time.Millisecond {
t.Fatal("wait() on success path should be a noop and return immediately")
}
}
func TestMakeRequestWithContext_ConcurrentRequestsWithCancellation(t *testing.T) {
// Simulate the production scenario: multiple concurrent requests where some
// contexts cancel while the HTTP call is in-flight. Under -race, this would
// detect the original bug where deferred Release races with client.Do.
client, cleanup := newTestServer(t, 100*time.Millisecond, 200)
defer cleanup()
const numRequests = 20
var completed atomic.Int32
done := make(chan struct{})
for range numRequests {
go func() {
defer func() {
if completed.Add(1) == numRequests {
close(done)
}
}()
req := fasthttp.AcquireRequest()
resp := fasthttp.AcquireResponse()
req.SetRequestURI("http://test/")
// Half the requests cancel early, half complete normally
var ctx context.Context
var cancel context.CancelFunc
if completed.Load()%2 == 0 {
ctx, cancel = context.WithTimeout(context.Background(), 5*time.Millisecond)
} else {
ctx, cancel = context.WithTimeout(context.Background(), 5*time.Second)
}
_, _, wait := MakeRequestWithContext(ctx, client, req, resp)
// Correct pattern: wait before release
wait()
cancel()
fasthttp.ReleaseRequest(req)
fasthttp.ReleaseResponse(resp)
}()
}
select {
case <-done:
// All requests completed
case <-time.After(10 * time.Second):
t.Fatalf("timed out waiting for requests, only %d/%d completed", completed.Load(), numRequests)
}
}
func TestNewBifrostTimeoutError(t *testing.T) {
err := NewBifrostTimeoutError("test timeout", context.DeadlineExceeded)
if !err.IsBifrostError {
t.Fatal("expected IsBifrostError to be true")
}
if err.StatusCode == nil || *err.StatusCode != 504 {
t.Fatalf("expected StatusCode 504, got %v", err.StatusCode)
}
if err.Error.Type == nil || *err.Error.Type != schemas.RequestTimedOut {
t.Fatalf("expected RequestTimedOut type, got %v", err.Error.Type)
}
if err.Error.Message != "test timeout" {
t.Fatalf("expected 'test timeout', got %s", err.Error.Message)
}
// Note: ExtraFields.Provider is populated by bifrost.go's dispatcher via
// PopulateExtraFields, not by NewBifrostTimeoutError — the constructor has
// no provider context.
}
func TestMakeRequestWithContext_ClientError(t *testing.T) {
// Test that client errors still return noop wait function
client := &fasthttp.Client{
Dial: func(addr string) (net.Conn, error) {
return nil, &net.OpError{Op: "dial", Net: "tcp", Err: &net.DNSError{Err: "no such host", Name: "nonexistent.invalid"}}
},
ReadTimeout: 1 * time.Second,
WriteTimeout: 1 * time.Second,
}
req := fasthttp.AcquireRequest()
resp := fasthttp.AcquireResponse()
defer fasthttp.ReleaseRequest(req)
defer fasthttp.ReleaseResponse(resp)
req.SetRequestURI("http://nonexistent.invalid/")
_, bifrostErr, wait := MakeRequestWithContext(context.Background(), client, req, resp)
defer wait()
if bifrostErr == nil {
t.Fatal("expected error for nonexistent host")
}
// wait should be noop since the goroutine completed (with error)
start := time.Now()
wait()
if time.Since(start) > 10*time.Millisecond {
t.Fatal("wait() should be noop on error path")
}
}
func TestMakeRequestWithContext_DeferOrderingPattern(t *testing.T) {
// Verify the exact defer pattern used by callers works correctly under -race.
// This mirrors the real provider code pattern.
client, cleanup := newTestServer(t, 150*time.Millisecond, 200)
defer cleanup()
// Track the order of operations
var order []string
var orderDone = make(chan struct{})
go func() {
defer close(orderDone)
req := fasthttp.AcquireRequest()
resp := fasthttp.AcquireResponse()
req.SetRequestURI("http://test/")
// Mimic the real provider pattern with defer ordering:
// These defers run in reverse order (LIFO)
defer func() {
fasthttp.ReleaseRequest(req)
order = append(order, "release-req")
}()
defer func() {
fasthttp.ReleaseResponse(resp)
order = append(order, "release-resp")
}()
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
defer cancel()
_, _, wait := MakeRequestWithContext(ctx, client, req, resp)
// This defer runs FIRST (last declared = first to run)
defer func() {
wait()
order = append(order, "wait-done")
}()
}()
select {
case <-orderDone:
case <-time.After(5 * time.Second):
t.Fatal("timed out")
}
// Verify order: wait must complete before any release
if len(order) != 3 {
t.Fatalf("expected 3 operations, got %d: %v", len(order), order)
}
if order[0] != "wait-done" {
t.Fatalf("expected wait-done first, got: %v", order)
}
if order[1] != "release-resp" {
t.Fatalf("expected release-resp second, got: %v", order)
}
if order[2] != "release-req" {
t.Fatalf("expected release-req third, got: %v", order)
}
}

View File

@@ -0,0 +1,265 @@
package utils
import (
"container/list"
"strings"
"sync"
"github.com/maximhq/bifrost/core/schemas"
)
const DefaultModelParamsCacheSize = 2048
// ModelParams holds cached parameters for a model.
// Add new fields here as more model-level parameters need caching.
type ModelParams struct {
MaxOutputTokens *int
}
type modelParamsCacheEntry struct {
model string
params ModelParams
}
// inflightCall represents an in-progress cache miss handler invocation.
// Multiple goroutines waiting for the same model share one call.
type inflightCall struct {
done chan struct{}
result *ModelParams
}
type modelParamsCache struct {
mu sync.RWMutex
capacity int
items map[string]*list.Element
order *list.List // front = most recently inserted/updated
cacheMissHandler func(model string) *ModelParams
inflightMu sync.Mutex
inflight map[string]*inflightCall
}
var (
globalModelParamsCache *modelParamsCache
cacheOnce sync.Once
)
// knownAnthropicMaxOutputTokens provides static fallback defaults for Claude models
// when both cache and DB miss handler return nothing. Only Anthropic requires max_tokens.
var knownAnthropicMaxOutputTokens = map[string]int{
"claude-opus-4-6": 128000,
"claude-sonnet-4-6": 64000,
"claude-haiku-4-5": 64000,
"claude-sonnet-4-5": 64000,
"claude-opus-4-5": 64000,
"claude-opus-4-1": 32000,
"claude-sonnet-4": 64000,
"claude-opus-4": 32000,
"claude-sonnet-4-0": 64000,
"claude-opus-4-0": 32000,
"claude-3-5-sonnet": 8192,
"claude-3-5-haiku": 8192,
"claude-3-7-sonnet": 8192,
"claude-3-opus": 4096,
"claude-3-sonnet": 4096,
"claude-3-haiku": 4096,
}
func newModelParamsCache(capacity int) *modelParamsCache {
return &modelParamsCache{
capacity: capacity,
items: make(map[string]*list.Element, capacity),
order: list.New(),
inflight: make(map[string]*inflightCall),
}
}
func getModelParamsCache() *modelParamsCache {
cacheOnce.Do(func() {
globalModelParamsCache = newModelParamsCache(DefaultModelParamsCacheSize)
})
return globalModelParamsCache
}
func (c *modelParamsCache) Get(model string) (ModelParams, bool) {
c.mu.Lock()
elem, ok := c.items[model]
if ok {
c.order.MoveToFront(elem)
params := elem.Value.(*modelParamsCacheEntry).params
c.mu.Unlock()
return params, true
}
handler := c.cacheMissHandler
c.mu.Unlock()
if handler == nil {
return ModelParams{}, false
}
// Deduplicate concurrent miss handler calls for the same model.
c.inflightMu.Lock()
if call, ok := c.inflight[model]; ok {
c.inflightMu.Unlock()
<-call.done
if call.result == nil {
return ModelParams{}, false
}
return *call.result, true
}
call := &inflightCall{done: make(chan struct{})}
c.inflight[model] = call
c.inflightMu.Unlock()
result := handler(model)
call.result = result
close(call.done)
c.inflightMu.Lock()
delete(c.inflight, model)
c.inflightMu.Unlock()
if result == nil {
return ModelParams{}, false
}
c.Set(model, *result)
return *result, true
}
func (c *modelParamsCache) Set(model string, params ModelParams) {
c.mu.Lock()
defer c.mu.Unlock()
if elem, ok := c.items[model]; ok {
elem.Value.(*modelParamsCacheEntry).params = params
c.order.MoveToFront(elem)
return
}
if c.order.Len() >= c.capacity {
c.evict()
}
entry := &modelParamsCacheEntry{model: model, params: params}
elem := c.order.PushFront(entry)
c.items[model] = elem
}
func (c *modelParamsCache) BulkSet(entries map[string]ModelParams) {
c.mu.Lock()
defer c.mu.Unlock()
for model, params := range entries {
if elem, ok := c.items[model]; ok {
elem.Value.(*modelParamsCacheEntry).params = params
c.order.MoveToFront(elem)
continue
}
if c.order.Len() >= c.capacity {
c.evict()
}
entry := &modelParamsCacheEntry{model: model, params: params}
elem := c.order.PushFront(entry)
c.items[model] = elem
}
}
func (c *modelParamsCache) evict() {
tail := c.order.Back()
if tail == nil {
return
}
c.order.Remove(tail)
delete(c.items, tail.Value.(*modelParamsCacheEntry).model)
}
// GetModelParams returns the cached parameters for a model.
// On cache miss, calls the registered miss handler (if any) to load from DB.
func GetModelParams(model string) (ModelParams, bool) {
return getModelParamsCache().Get(model)
}
// SetModelParams sets the parameters for a model in the cache.
func SetModelParams(model string, params ModelParams) {
getModelParamsCache().Set(model, params)
}
// BulkSetModelParams sets parameters for multiple models at once.
func BulkSetModelParams(entries map[string]ModelParams) {
getModelParamsCache().BulkSet(entries)
}
// SetCacheMissHandler registers a callback invoked on cache miss.
// The handler should query the DB for the model's parameters and return them,
// or nil if not found. The result is automatically cached.
func SetCacheMissHandler(fn func(model string) *ModelParams) {
c := getModelParamsCache()
c.mu.Lock()
defer c.mu.Unlock()
c.cacheMissHandler = fn
}
// GetMaxOutputTokens returns the cached max_output_tokens for a model.
// Returns 0, false on cache miss or if max_output_tokens is not set.
func GetMaxOutputTokens(model string) (int, bool) {
params, ok := GetModelParams(model)
if !ok || params.MaxOutputTokens == nil {
return 0, false
}
return *params.MaxOutputTokens, true
}
// GetMaxOutputTokensOrDefault returns the cached max_output_tokens for a model,
// or the provided default value on cache miss. For Claude models, falls back to
// known static defaults before using the caller's default.
func GetMaxOutputTokensOrDefault(model string, defaultValue int) int {
if m, ok := GetMaxOutputTokens(model); ok {
return m
}
if strings.Contains(model, "claude") {
base := normalizeClaudeModelName(model)
if base != model {
if m, ok := GetMaxOutputTokens(base); ok {
return m
}
}
if m, ok := knownAnthropicMaxOutputTokens[base]; ok {
return m
}
}
return defaultValue
}
// normalizeClaudeModelName extracts the base Claude model name from
// provider-specific model ID formats.
//
// Examples:
//
// "claude-sonnet-4-20250514" → "claude-sonnet-4"
// "anthropic.claude-sonnet-4-20250514-v1:0" → "claude-sonnet-4"
// "us.anthropic.claude-sonnet-4-20250514-v1:0" → "claude-sonnet-4"
// "claude-3-5-sonnet-20241022" → "claude-3-5-sonnet"
func normalizeClaudeModelName(model string) string {
// Strip region + provider prefixes (us.anthropic., anthropic., etc.)
if idx := strings.LastIndex(model, "."); idx >= 0 {
model = model[idx+1:]
}
// Strip Bedrock version suffix (":0", ":1", etc.) and the preceding "-v1"/"-v2"
if idx := strings.Index(model, ":"); idx >= 0 {
model = model[:idx]
if len(model) >= 3 {
suffix := model[len(model)-3:]
if suffix == "-v1" || suffix == "-v2" {
model = model[:len(model)-3]
}
}
}
// Strip "-v1", "-v2" even without colon (e.g., "anthropic.claude-opus-4-6-v1")
if strings.HasSuffix(model, "-v1") || strings.HasSuffix(model, "-v2") {
model = model[:len(model)-3]
}
// Strip date version suffix using schemas.BaseModelName
return schemas.BaseModelName(model)
}

View File

@@ -0,0 +1,337 @@
package utils
import (
"fmt"
"sync"
"testing"
)
func intPtr(v int) *int { return &v }
func TestModelParamsCacheGetSet(t *testing.T) {
cache := newModelParamsCache(10)
cache.Set("claude-sonnet-4-20250514", ModelParams{MaxOutputTokens: intPtr(8192)})
val, ok := cache.Get("claude-sonnet-4-20250514")
if !ok || val.MaxOutputTokens == nil || *val.MaxOutputTokens != 8192 {
t.Errorf("expected 8192, got %+v (ok=%v)", val, ok)
}
}
func TestModelParamsCacheMiss(t *testing.T) {
cache := newModelParamsCache(10)
val, ok := cache.Get("nonexistent-model")
if ok || val.MaxOutputTokens != nil {
t.Errorf("expected miss, got %+v (ok=%v)", val, ok)
}
}
func TestModelParamsCacheUpdate(t *testing.T) {
cache := newModelParamsCache(10)
cache.Set("claude-sonnet-4", ModelParams{MaxOutputTokens: intPtr(8192)})
cache.Set("claude-sonnet-4", ModelParams{MaxOutputTokens: intPtr(16384)})
val, ok := cache.Get("claude-sonnet-4")
if !ok || val.MaxOutputTokens == nil || *val.MaxOutputTokens != 16384 {
t.Errorf("expected 16384 after update, got %+v (ok=%v)", val, ok)
}
}
func TestModelParamsCacheEviction(t *testing.T) {
cache := newModelParamsCache(3)
cache.Set("model-a", ModelParams{MaxOutputTokens: intPtr(1000)})
cache.Set("model-b", ModelParams{MaxOutputTokens: intPtr(2000)})
cache.Set("model-c", ModelParams{MaxOutputTokens: intPtr(3000)})
// This should evict model-a (oldest insertion)
cache.Set("model-d", ModelParams{MaxOutputTokens: intPtr(4000)})
if _, ok := cache.Get("model-a"); ok {
t.Error("model-a should have been evicted")
}
if val, ok := cache.Get("model-b"); !ok || *val.MaxOutputTokens != 2000 {
t.Errorf("model-b should still exist, got %+v (ok=%v)", val, ok)
}
if val, ok := cache.Get("model-d"); !ok || *val.MaxOutputTokens != 4000 {
t.Errorf("model-d should exist, got %+v (ok=%v)", val, ok)
}
}
func TestModelParamsCacheBulkSet(t *testing.T) {
cache := newModelParamsCache(100)
entries := map[string]ModelParams{
"claude-sonnet-4": {MaxOutputTokens: intPtr(8192)},
"claude-opus-4": {MaxOutputTokens: intPtr(4096)},
"gpt-4o": {MaxOutputTokens: intPtr(16384)},
"gemini-2.0-flash": {MaxOutputTokens: intPtr(8192)},
}
cache.BulkSet(entries)
for model, expected := range entries {
val, ok := cache.Get(model)
if !ok || *val.MaxOutputTokens != *expected.MaxOutputTokens {
t.Errorf("BulkSet: model %s expected %d, got %+v (ok=%v)", model, *expected.MaxOutputTokens, val, ok)
}
}
}
func TestModelParamsCacheBulkSetOverflow(t *testing.T) {
cache := newModelParamsCache(3)
entries := map[string]ModelParams{
"model-1": {MaxOutputTokens: intPtr(1000)},
"model-2": {MaxOutputTokens: intPtr(2000)},
"model-3": {MaxOutputTokens: intPtr(3000)},
"model-4": {MaxOutputTokens: intPtr(4000)},
"model-5": {MaxOutputTokens: intPtr(5000)},
}
cache.BulkSet(entries)
if cache.order.Len() != 3 {
t.Errorf("expected 3 entries after overflow BulkSet, got %d", cache.order.Len())
}
}
func TestModelParamsCacheBulkSetUpdate(t *testing.T) {
cache := newModelParamsCache(10)
cache.Set("claude-sonnet-4", ModelParams{MaxOutputTokens: intPtr(4096)})
cache.BulkSet(map[string]ModelParams{
"claude-sonnet-4": {MaxOutputTokens: intPtr(8192)},
})
val, ok := cache.Get("claude-sonnet-4")
if !ok || *val.MaxOutputTokens != 8192 {
t.Errorf("BulkSet should update existing entry, got %+v (ok=%v)", val, ok)
}
}
func TestModelParamsCacheConcurrency(t *testing.T) {
cache := newModelParamsCache(100)
var wg sync.WaitGroup
for i := 0; i < 50; i++ {
wg.Add(1)
go func(i int) {
defer wg.Done()
model := fmt.Sprintf("model-%d", i)
cache.Set(model, ModelParams{MaxOutputTokens: intPtr(i * 1000)})
cache.Get(model)
}(i)
}
wg.Wait()
if cache.order.Len() > 100 {
t.Errorf("cache exceeded capacity: %d", cache.order.Len())
}
}
func TestGetMaxOutputTokens(t *testing.T) {
cache := getModelParamsCache()
cache.Set("test-max-output", ModelParams{MaxOutputTokens: intPtr(16384)})
val, ok := GetMaxOutputTokens("test-max-output")
if !ok || val != 16384 {
t.Errorf("expected 16384, got %d (ok=%v)", val, ok)
}
val, ok = GetMaxOutputTokens("missing-model-get")
if ok || val != 0 {
t.Errorf("expected miss, got %d (ok=%v)", val, ok)
}
}
func TestGetMaxOutputTokensNilField(t *testing.T) {
cache := getModelParamsCache()
cache.Set("test-nil-field", ModelParams{})
val, ok := GetMaxOutputTokens("test-nil-field")
if ok || val != 0 {
t.Errorf("expected miss for nil MaxOutputTokens, got %d (ok=%v)", val, ok)
}
}
func TestGetMaxOutputTokensOrDefault(t *testing.T) {
cache := getModelParamsCache()
cache.Set("test-or-default", ModelParams{MaxOutputTokens: intPtr(16384)})
val := GetMaxOutputTokensOrDefault("test-or-default", 4096)
if val != 16384 {
t.Errorf("expected cached value 16384, got %d", val)
}
val = GetMaxOutputTokensOrDefault("missing-model-default", 4096)
if val != 4096 {
t.Errorf("expected default 4096 for missing non-claude model, got %d", val)
}
}
func TestCacheMissHandler(t *testing.T) {
cache := newModelParamsCache(10)
called := false
cache.cacheMissHandler = func(model string) *ModelParams {
called = true
if model == "db-model" {
return &ModelParams{MaxOutputTokens: intPtr(32000)}
}
return nil
}
// Miss handler returns a value → should be cached
val, ok := cache.Get("db-model")
if !ok || val.MaxOutputTokens == nil || *val.MaxOutputTokens != 32000 {
t.Errorf("expected 32000 from miss handler, got %+v (ok=%v)", val, ok)
}
if !called {
t.Error("miss handler was not called")
}
// Verify it was cached (handler should not be called again)
called = false
val, ok = cache.Get("db-model")
if !ok || *val.MaxOutputTokens != 32000 {
t.Errorf("expected cached 32000, got %+v (ok=%v)", val, ok)
}
if called {
t.Error("miss handler should not be called for cached entry")
}
// Miss handler returns nil → should return false
val, ok = cache.Get("unknown-model")
if ok {
t.Errorf("expected miss for unknown model, got %+v", val)
}
}
func TestCacheMissHandlerNil(t *testing.T) {
cache := newModelParamsCache(10)
// No handler registered
val, ok := cache.Get("any-model")
if ok {
t.Errorf("expected miss with nil handler, got %+v", val)
}
}
func TestNormalizeClaudeModelName(t *testing.T) {
tests := []struct {
input string
expected string
desc string
}{
// Anthropic direct (bare model names)
{"claude-sonnet-4-5", "claude-sonnet-4-5", "Anthropic: no version suffix"},
{"claude-sonnet-4-20250514", "claude-sonnet-4", "Anthropic: date suffix"},
{"claude-opus-4-5", "claude-opus-4-5", "Anthropic: no version suffix"},
{"claude-opus-4-6-20250514", "claude-opus-4-6", "Anthropic: date suffix"},
{"claude-sonnet-4-6", "claude-sonnet-4-6", "Anthropic: no version suffix"},
{"claude-3-5-sonnet-20241022", "claude-3-5-sonnet", "Anthropic: legacy date suffix"},
{"claude-3-7-sonnet-20250219", "claude-3-7-sonnet", "Anthropic: legacy date suffix"},
// Bedrock (anthropic. prefix + -v1:0 suffix)
{"anthropic.claude-3-sonnet-20240229-v1:0", "claude-3-sonnet", "Bedrock: prefix + v1:0"},
{"anthropic.claude-opus-4-6-v1", "claude-opus-4-6", "Bedrock: prefix + v1 no colon"},
{"anthropic.claude-3-7-sonnet-v1", "claude-3-7-sonnet", "Bedrock: prefix + v1 no colon"},
{"anthropic.claude-sonnet-4-20250514-v1:0", "claude-sonnet-4", "Bedrock: prefix + date + v1:0"},
{"anthropic.claude-3-5-sonnet-20241022-v1:0", "claude-3-5-sonnet", "Bedrock: prefix + legacy date + v1:0"},
// Bedrock with region prefix
{"us.anthropic.claude-sonnet-4-6", "claude-sonnet-4-6", "Bedrock regional: us prefix"},
{"us.anthropic.claude-3-sonnet-20240229-v1:0", "claude-3-sonnet", "Bedrock regional: us + v1:0"},
{"global.anthropic.claude-opus-4-6-20260301-v1:0", "claude-opus-4-6", "Bedrock regional: global + date + v1:0"},
{"eu.anthropic.claude-sonnet-4-5-20250929-v1:0", "claude-sonnet-4-5", "Bedrock regional: eu + date + v1:0"},
// Vertex (same as Anthropic direct — deployment is bare model name)
{"claude-sonnet-4-5", "claude-sonnet-4-5", "Vertex: bare model"},
{"claude-sonnet-4-20250514", "claude-sonnet-4", "Vertex: date suffix"},
// Azure (deployment names — typically bare model names)
{"claude-opus-4-5", "claude-opus-4-5", "Azure: deployment name"},
}
for _, tt := range tests {
t.Run(tt.desc, func(t *testing.T) {
got := normalizeClaudeModelName(tt.input)
if got != tt.expected {
t.Errorf("normalizeClaudeModelName(%q) = %q, want %q", tt.input, got, tt.expected)
}
})
}
}
func TestGetMaxOutputTokensOrDefaultStaticFallback(t *testing.T) {
// Use a fresh cache with no entries to test static fallback only
// We test via the normalizeClaudeModelName + map lookup directly
// since the global cache may have entries from other tests
tests := []struct {
model string
expected int
desc string
}{
// Anthropic direct
{"claude-sonnet-4-20250514", 64000, "Anthropic: claude-sonnet-4"},
{"claude-opus-4-6-20250514", 128000, "Anthropic: claude-opus-4-6"},
{"claude-3-5-sonnet-20241022", 8192, "Anthropic: claude-3-5-sonnet"},
// Bedrock
{"anthropic.claude-sonnet-4-20250514-v1:0", 64000, "Bedrock: claude-sonnet-4"},
{"anthropic.claude-opus-4-6-v1", 128000, "Bedrock: claude-opus-4-6"},
{"anthropic.claude-3-5-sonnet-20241022-v1:0", 8192, "Bedrock: claude-3-5-sonnet"},
// Bedrock with region prefix
{"us.anthropic.claude-opus-4-6-v1:0", 128000, "Bedrock regional: claude-opus-4-6"},
{"global.anthropic.claude-sonnet-4-5-20250929-v1:0", 64000, "Bedrock regional: claude-sonnet-4-5"},
{"eu.anthropic.claude-3-haiku-20240307-v1:0", 4096, "Bedrock regional: claude-3-haiku"},
// Vertex
{"claude-opus-4-5", 64000, "Vertex: claude-opus-4-5"},
{"claude-haiku-4-5", 64000, "Vertex: claude-haiku-4-5"},
// Azure
{"claude-3-5-sonnet-20241022", 8192, "Azure: claude-3-5-sonnet"},
{"claude-sonnet-4-6", 64000, "Azure: claude-sonnet-4-6"},
// Non-Claude models should return the default
{"gpt-4o", 4096, "Non-Claude: gpt-4o"},
{"gemini-2.0-flash", 4096, "Non-Claude: gemini-2.0-flash"},
{"command-r-plus", 4096, "Non-Claude: command-r-plus"},
}
for _, tt := range tests {
t.Run(tt.desc, func(t *testing.T) {
// Test the static fallback logic directly
got := staticAnthropicFallback(tt.model, 4096)
if got != tt.expected {
t.Errorf("staticAnthropicFallback(%q, 4096) = %d, want %d", tt.model, got, tt.expected)
}
})
}
}
// staticAnthropicFallback is a test helper that mimics the fallback logic
// in GetMaxOutputTokensOrDefault without going through the global cache.
func staticAnthropicFallback(model string, defaultValue int) int {
if !contains(model, "claude") {
return defaultValue
}
base := normalizeClaudeModelName(model)
if m, ok := knownAnthropicMaxOutputTokens[base]; ok {
return m
}
return defaultValue
}
func contains(s, substr string) bool {
return len(s) >= len(substr) && (s == substr || len(substr) == 0 || indexSubstring(s, substr) >= 0)
}
func indexSubstring(s, substr string) int {
for i := 0; i <= len(s)-len(substr); i++ {
if s[i:i+len(substr)] == substr {
return i
}
}
return -1
}

View File

@@ -0,0 +1,356 @@
// Package utils — list_models.go
// Centralised pipeline for filtering and backfilling models in ListModels responses.
//
// Every provider's ToBifrostListModelsResponse follows the same logical steps:
// 1. Resolve each API model's name (alias lookup → alias key; else raw model ID)
// 2. Filter (allowlist + blacklist check on the resolved name)
// 3. Backfill entries that were not returned by the API but should appear in output
//
// Providers plug in custom MatchFns to extend the default matching behaviour.
// Example: Bedrock adds region-prefix-aware matching on top of DefaultMatchFns.
package utils
import (
"sort"
"strings"
"github.com/maximhq/bifrost/core/schemas"
"golang.org/x/text/cases"
"golang.org/x/text/language"
)
// ToDisplayName converts a raw model ID or alias key into a human-readable display name.
// Splits on "-" or "_", title-cases each word, and joins with spaces.
//
// "gemini-pro" → "Gemini Pro"
// "claude_3_opus" → "Claude 3 Opus"
// "gpt-4-turbo" → "Gpt 4 Turbo"
func ToDisplayName(id string) string {
caser := cases.Title(language.English)
parts := strings.FieldsFunc(id, func(r rune) bool {
return r == '-' || r == '_'
})
if len(parts) == 0 {
return ""
}
for i, part := range parts {
if part != "" {
parts[i] = caser.String(strings.ToLower(part))
}
}
return strings.Join(parts, " ")
}
// MatchFn reports whether two model ID strings should be treated as equivalent.
// Functions are applied in order during every comparison — the first one that
// returns true short-circuits the rest.
//
// Example built-in fns (see DefaultMatchFns):
//
// exactMatch("gpt-4", "gpt-4") → true
// sameBaseModel("claude-3-5-sonnet-20241022", "claude-3-5") → true
type MatchFn func(a, b string) bool
// DefaultMatchFns returns the standard matching functions used by most providers.
// Currently only performs case-insensitive exact matching.
//
// SameBaseModel (strips version suffixes, e.g. "claude-3-5-sonnet-20241022" ≈ "claude-3-5-sonnet")
// is intentionally excluded — users should use aliases for explicit version-to-base-name mapping.
// It can be appended here if fuzzy base-model matching is ever needed globally.
func DefaultMatchFns() []MatchFn {
return []MatchFn{
func(a, b string) bool { return strings.EqualFold(a, b) },
}
}
// matches reports whether a and b are considered equal by any of the provided fns.
// Returns true on the first fn that returns true.
func matches(a, b string, fns []MatchFn) bool {
for _, fn := range fns {
if fn(a, b) {
return true
}
}
return false
}
// FilterResult is the outcome of running Pipeline.FilterModel for a single model
// from the provider's API response. Each returned result represents one alias
// entry (or the raw model ID when no alias matched) that passed all filters.
type FilterResult struct {
// ResolvedID is the user-facing model name to use as the ID suffix.
// If the model matched an alias VALUE, this is the alias KEY.
// Otherwise this is the original model ID from the API response.
//
// Example: API returns "gpt-4-turbo", aliases={"my-gpt4":"gpt-4-turbo"}
// → ResolvedID = "my-gpt4"
// Example: API returns "gpt-3.5-turbo", no alias match
// → ResolvedID = "gpt-3.5-turbo"
ResolvedID string
// AliasValue is the provider-specific model ID when the model was matched
// via an alias. Set as the model.Alias field so callers know the underlying ID.
// Empty when the model was matched directly (no alias involved).
//
// Example: API returns "gpt-4-turbo", alias key "my-gpt4" matched
// → AliasValue = "gpt-4-turbo"
AliasValue string
}
// Pipeline holds all the context needed to filter and backfill models in a
// single ListModels response. Construct one per ToBifrostListModelsResponse call
// and use its methods instead of passing params + matchFns to every function.
//
// pipeline := &providerUtils.ListModelsPipeline{
// AllowedModels: key.Models,
// BlacklistedModels: key.BlacklistedModels,
// Aliases: key.Aliases,
// Unfiltered: request.Unfiltered,
// ProviderKey: schemas.OpenAI,
// MatchFns: providerUtils.DefaultMatchFns(),
// }
// if pipeline.ShouldEarlyExit() { return empty }
// result := pipeline.FilterModel(model.ID)
// pipeline.BackfillModels(included)
type ListModelsPipeline struct {
AllowedModels schemas.WhiteList
BlacklistedModels schemas.BlackList
// Aliases maps user-facing alias keys to provider-specific model IDs.
// e.g. {"my-gpt4": "gpt-4-turbo-2024-04-09"}
Aliases map[string]string
Unfiltered bool
ProviderKey schemas.ModelProvider
// MatchFns is the ordered list of equivalence functions used for every
// model ID comparison. Use DefaultMatchFns() for standard behaviour;
// providers may append additional fns (e.g. Bedrock's region-prefix remover).
MatchFns []MatchFn
}
// ShouldEarlyExit reports whether ToBifrostListModelsResponse should immediately
// return an empty response without processing any models.
//
// Returns true when:
// - not unfiltered AND allowlist is empty AND no aliases configured
// (there is nothing to match against — all models would be filtered out anyway)
// - not unfiltered AND blacklist blocks everything
//
// Note: allowlist empty + aliases present → do NOT early exit.
// The aliases drive backfill in the wildcard-allowlist case (Case B of BackfillModels).
func (p *ListModelsPipeline) ShouldEarlyExit() bool {
if p.Unfiltered {
return false
}
if p.BlacklistedModels.IsBlockAll() {
return true
}
if p.AllowedModels.IsEmpty() && len(p.Aliases) == 0 {
return true
}
return false
}
// aliasMatch holds a single alias key/value pair returned by resolveModelID.
type aliasMatch struct {
key string
value string
}
// resolveModelID returns all alias entries whose VALUE matches modelID using the pipeline's MatchFns,
// plus the raw model ID itself as an additional entry so both the alias key and the original model
// name appear in the list-models output.
// Results are sorted by alias key (case-insensitive) for deterministic ordering.
//
// If one or more aliases match → returns one aliasMatch per matching alias key, plus the raw ID.
//
// Example: modelID="gpt-4-turbo", aliases={"my-gpt4":"gpt-4-turbo","gpt4-alias":"gpt-4-turbo"}
// → [{key:"gpt-4-turbo", value:""}, {key:"gpt4-alias", value:"gpt-4-turbo"}, {key:"my-gpt4", value:"gpt-4-turbo"}]
//
// If no alias matches → returns a single entry with the original model ID and no alias value.
//
// Example: modelID="gpt-3.5-turbo", no alias match
// → [{key:"gpt-3.5-turbo", value:""}]
func (p *ListModelsPipeline) resolveModelID(modelID string) []aliasMatch {
var candidates []aliasMatch
for aliasKey, providerID := range p.Aliases {
if matches(modelID, providerID, p.MatchFns) {
candidates = append(candidates, aliasMatch{key: aliasKey, value: providerID})
}
}
if len(candidates) == 0 {
return []aliasMatch{{key: modelID, value: ""}}
}
// Also include the raw model ID so both the alias key and the original name appear in output.
candidates = append(candidates, aliasMatch{key: modelID, value: ""})
sort.Slice(candidates, func(i, j int) bool {
return strings.ToLower(candidates[i].key) < strings.ToLower(candidates[j].key)
})
return candidates
}
// FilterModel applies the full filter pipeline for a single model from the API response.
//
// Steps:
// 1. Resolve name — check alias VALUES for a match (uses MatchFns).
// If matched: resolvedName = alias KEY, aliasValue = provider ID.
// If not matched: resolvedName = original modelID, aliasValue = "".
// 2. Allowlist check (only when allowlist is restricted, i.e. not wildcard):
// Skip if resolvedName is not in AllowedModels.
// 3. Blacklist check (always):
// Skip if resolvedName is blacklisted. Blacklist takes precedence over everything.
// 4. Return one FilterResult per passing candidate.
//
// An empty slice means the model should be skipped entirely.
// When multiple aliases map to the same provider model ID, each alias that passes
// the filters produces its own FilterResult entry.
//
// Examples:
//
// allowedModels=["my-gpt4"], aliases={"my-gpt4":"gpt-4-turbo"}, blacklist=[]
// FilterModel("gpt-4-turbo") → [{ResolvedID:"my-gpt4", AliasValue:"gpt-4-turbo"}]
// FilterModel("gpt-3.5") → [] (not in allowlist)
//
// allowedModels=*, aliases={"my-gpt4":"gpt-4-turbo","gpt4-alias":"gpt-4-turbo"}, blacklist=[]
// FilterModel("gpt-4-turbo") → [{ResolvedID:"gpt-4-turbo", AliasValue:""},
// {ResolvedID:"gpt4-alias", AliasValue:"gpt-4-turbo"},
// {ResolvedID:"my-gpt4", AliasValue:"gpt-4-turbo"}]
//
// allowedModels=["gpt-3.5"], aliases={}, blacklist=[]
// FilterModel("gpt-3.5") → [{ResolvedID:"gpt-3.5", AliasValue:""}]
// FilterModel("gpt-4") → []
func (p *ListModelsPipeline) FilterModel(modelID string) []FilterResult {
// Step 1: resolve name — collect all alias matches (or the raw ID if none match).
candidates := p.resolveModelID(modelID)
var results []FilterResult
for _, candidate := range candidates {
resolvedName := candidate.key
// Step 2: allowlist check.
// IsRestricted() is true for both an explicit list AND an empty list (deny-all).
// Only a wildcard allowlist marker bypasses this check (pass-through).
if !p.Unfiltered && p.AllowedModels.IsRestricted() {
allowed := false
for _, entry := range p.AllowedModels {
if matches(resolvedName, entry, p.MatchFns) {
allowed = true
break
}
}
if !allowed {
continue
}
}
// Step 3: blacklist check — blacklist always wins regardless of allowlist or aliases.
if !p.Unfiltered {
blacklisted := false
for _, entry := range p.BlacklistedModels {
if matches(resolvedName, entry, p.MatchFns) {
blacklisted = true
break
}
}
if blacklisted {
continue
}
}
results = append(results, FilterResult{
ResolvedID: resolvedName,
AliasValue: candidate.value,
})
}
return results
}
// BackfillModels adds model entries that were configured by the caller but not
// returned by the provider's API response (or not matched during filtering).
//
// The `included` map tracks model IDs (lowercased) already added during the
// filter pass, used to avoid duplicates.
//
// Two cases depending on whether the allowlist is restricted:
//
// Case A — allowlist restricted (caller specified explicit model names):
//
// Add each allowlist entry that is not yet in `included`, skip if blacklisted.
// If the entry has an alias mapping (aliases[entry] exists), set Alias to the
// provider-specific ID so callers can route to the right model.
//
// Example: allowedModels=["my-gpt4","gpt-3.5"], aliases={"my-gpt4":"gpt-4-turbo"}
// "my-gpt4" not in included → add {ID:"openai/my-gpt4", Alias:"gpt-4-turbo"}
// "gpt-3.5" not in included → add {ID:"openai/gpt-3.5"}
//
// Case B — allowlist wildcard (*) only:
//
// We don't know all model names (no explicit list), so we only backfill entries
// that were explicitly configured via aliases and not yet matched from the API.
// Note: an empty allowlist is deny-all (IsRestricted()==true), not wildcard.
//
// Example: aliases={"my-gpt4":"gpt-4-turbo"}, "my-gpt4" not in included
// → add {ID:"openai/my-gpt4", Alias:"gpt-4-turbo"}
//
// Blacklist always wins — nothing blacklisted is added in either case.
func (p *ListModelsPipeline) BackfillModels(included map[string]bool) []schemas.Model {
var result []schemas.Model
if !p.Unfiltered && p.AllowedModels.IsRestricted() {
// Case A: backfill explicit allowlist entries not yet matched.
for _, entry := range p.AllowedModels {
if included[strings.ToLower(entry)] {
continue
}
// Blacklist check.
blacklisted := false
for _, bl := range p.BlacklistedModels {
if matches(entry, bl, p.MatchFns) {
blacklisted = true
break
}
}
if blacklisted {
continue
}
m := schemas.Model{
ID: string(p.ProviderKey) + "/" + entry,
Name: schemas.Ptr(ToDisplayName(entry)),
}
// If this allowlist entry has an alias, surface the provider-specific ID.
for aliasKey, providerID := range p.Aliases {
if matches(entry, aliasKey, p.MatchFns) {
m.Alias = schemas.Ptr(providerID)
break
}
}
result = append(result, m)
}
return result
}
// Case B: wildcard allowlist — backfill only explicitly configured aliases.
if !p.Unfiltered && len(p.Aliases) > 0 {
for aliasKey, providerID := range p.Aliases {
if included[strings.ToLower(aliasKey)] {
continue
}
// Blacklist check.
blacklisted := false
for _, bl := range p.BlacklistedModels {
if matches(aliasKey, bl, p.MatchFns) {
blacklisted = true
break
}
}
if blacklisted {
continue
}
result = append(result, schemas.Model{
ID: string(p.ProviderKey) + "/" + aliasKey,
Name: schemas.Ptr(ToDisplayName(aliasKey)),
Alias: schemas.Ptr(providerID),
})
}
}
return result
}

View File

@@ -0,0 +1,112 @@
package utils
import (
schemas "github.com/maximhq/bifrost/core/schemas"
)
// SerialListHelper manages serial key pagination for list operations.
// It ensures that all pages from one key are exhausted before moving to the next,
// guaranteeing only one API call per pagination request regardless of key count.
type SerialListHelper struct {
Keys []schemas.Key
Cursor *schemas.SerialCursor
Logger schemas.Logger
}
// NewSerialListHelper creates a new SerialListHelper from the provided keys and encoded cursor.
// If the cursor is empty or nil, pagination starts from the first key.
// If the cursor is invalid, an error is returned.
func NewSerialListHelper(keys []schemas.Key, encodedCursor *string, logger schemas.Logger) (*SerialListHelper, error) {
helper := &SerialListHelper{
Keys: keys,
Logger: logger,
}
if encodedCursor != nil && *encodedCursor != "" {
cursor, err := schemas.DecodeSerialCursor(*encodedCursor)
if err != nil {
return nil, err
}
helper.Cursor = cursor
}
return helper, nil
}
// GetCurrentKey returns the key to query and its native cursor.
// Returns (key, nativeCursor, true) if there's a key to query.
// Returns (Key{}, "", false) if all keys are exhausted.
func (h *SerialListHelper) GetCurrentKey() (schemas.Key, string, bool) {
if len(h.Keys) == 0 {
return schemas.Key{}, "", false
}
keyIndex := 0
nativeCursor := ""
if h.Cursor != nil {
keyIndex = h.Cursor.KeyIndex
nativeCursor = h.Cursor.Cursor
}
// Check if key index is within bounds
if keyIndex >= len(h.Keys) {
return schemas.Key{}, "", false
}
return h.Keys[keyIndex], nativeCursor, true
}
// BuildNextCursor creates the cursor for the next pagination request.
// Parameters:
// - hasMore: whether the current key has more pages
// - nativeCursor: the native cursor returned by the current key's API
//
// Returns:
// - encodedCursor: the encoded cursor for the next request (empty if all keys exhausted)
// - moreAvailable: true if there are more results available (either from current key or remaining keys)
func (h *SerialListHelper) BuildNextCursor(hasMore bool, nativeCursor string) (string, bool) {
if len(h.Keys) == 0 {
return "", false
}
currentKeyIndex := 0
if h.Cursor != nil {
currentKeyIndex = h.Cursor.KeyIndex
}
if hasMore {
// Current key has more pages - return cursor for same key
nextCursor := schemas.NewSerialCursor(currentKeyIndex, nativeCursor)
return schemas.EncodeSerialCursor(nextCursor), true
}
// Current key exhausted - check if there are more keys
nextKeyIndex := currentKeyIndex + 1
if nextKeyIndex >= len(h.Keys) {
// All keys exhausted
return "", false
}
// Move to next key with empty cursor (start fresh)
nextCursor := schemas.NewSerialCursor(nextKeyIndex, "")
return schemas.EncodeSerialCursor(nextCursor), true
}
// GetCurrentKeyIndex returns the current key index being processed.
func (h *SerialListHelper) GetCurrentKeyIndex() int {
if h.Cursor != nil {
return h.Cursor.KeyIndex
}
return 0
}
// HasMoreKeys returns true if there are more keys after the current one.
func (h *SerialListHelper) HasMoreKeys() bool {
currentKeyIndex := 0
if h.Cursor != nil {
currentKeyIndex = h.Cursor.KeyIndex
}
return currentKeyIndex < len(h.Keys)-1
}

195
core/providers/utils/sse.go Normal file
View File

@@ -0,0 +1,195 @@
package utils
import (
"bufio"
"bytes"
"io"
"github.com/maximhq/bifrost/core/schemas"
)
const (
sseInitialBufSize = 8 * 1024 // 8KB — sufficient for >99.9% of SSE lines
sseMaxBufSize = 10 * 1024 * 1024 // 10MB — allow large tokens (tool calls, audio)
)
// SSEDataReader reads SSE data-only events (Format A: OpenAI, Gemini, Cohere, etc.).
// ReadDataLine returns the next SSE data payload, stripping the "data:" prefix.
// Returns (nil, io.EOF) at end of stream or on "data: [DONE]".
type SSEDataReader interface {
ReadDataLine() ([]byte, error)
}
// SSEEventReader reads SSE events with type and data (Format B: Anthropic, Replicate, etc.).
// ReadEvent returns the complete event once an empty-line delimiter is encountered.
// Multiple "data:" lines within one event are concatenated with newlines.
// Returns ("", nil, io.EOF) at end of stream.
type SSEEventReader interface {
ReadEvent() (eventType string, data []byte, err error)
}
// SSEReaderFactory creates SSE readers for streaming response processing.
// Enterprise injects this via BifrostContextKeySSEReaderFactory to replace
// the default bufio.Scanner-based implementations with streaming readers.
type SSEReaderFactory struct {
NewDataReader func(reader io.Reader) SSEDataReader
NewEventReader func(reader io.Reader) SSEEventReader
}
// GetSSEDataReader returns an SSEDataReader for the given reader.
// If enterprise has injected an SSEReaderFactory via context, uses that.
// Otherwise returns a default implementation wrapping bufio.NewScanner.
func GetSSEDataReader(ctx *schemas.BifrostContext, reader io.Reader) SSEDataReader {
if ctx != nil {
if factory, ok := ctx.Value(schemas.BifrostContextKeySSEReaderFactory).(*SSEReaderFactory); ok && factory != nil && factory.NewDataReader != nil {
return factory.NewDataReader(reader)
}
}
return newDefaultSSEDataReader(reader)
}
// GetSSEEventReader returns an SSEEventReader for the given reader.
// If enterprise has injected an SSEReaderFactory via context, uses that.
// Otherwise returns a default implementation wrapping bufio.NewScanner.
func GetSSEEventReader(ctx *schemas.BifrostContext, reader io.Reader) SSEEventReader {
if ctx != nil {
if factory, ok := ctx.Value(schemas.BifrostContextKeySSEReaderFactory).(*SSEReaderFactory); ok && factory != nil && factory.NewEventReader != nil {
return factory.NewEventReader(reader)
}
}
return newDefaultSSEEventReader(reader)
}
// Reusable byte prefixes for SSE field parsing.
var (
sseDataPrefix = []byte("data:")
sseDoneMarker = []byte("[DONE]")
sseEventPrefix = []byte("event:")
sseIDPrefix = []byte("id:")
sseRetryPrefix = []byte("retry:")
)
// defaultSSEDataReader implements SSEDataReader using bufio.NewScanner.
// Handles Format A SSE streams (data-only: OpenAI, Gemini, Cohere, etc.).
type defaultSSEDataReader struct {
scanner *bufio.Scanner
}
func newDefaultSSEDataReader(reader io.Reader) *defaultSSEDataReader {
scanner := bufio.NewScanner(reader)
scanner.Buffer(make([]byte, 0, sseInitialBufSize), sseMaxBufSize)
return &defaultSSEDataReader{scanner: scanner}
}
func (r *defaultSSEDataReader) ReadDataLine() ([]byte, error) {
for r.scanner.Scan() {
line := r.scanner.Bytes()
// Skip empty lines and comments
if len(line) == 0 || line[0] == ':' {
continue
}
// Parse "data:" lines
if bytes.HasPrefix(line, sseDataPrefix) {
data := line[5:] // len("data:") == 5
if len(data) > 0 && data[0] == ' ' {
data = data[1:]
}
if len(data) == 0 {
continue
}
if bytes.Equal(data, sseDoneMarker) {
return nil, io.EOF
}
// Copy to decouple from scanner's internal buffer
return append([]byte(nil), data...), nil
}
// Skip known SSE fields (event, id, retry)
if bytes.HasPrefix(line, sseEventPrefix) ||
bytes.HasPrefix(line, sseIDPrefix) ||
bytes.HasPrefix(line, sseRetryPrefix) {
continue
}
// Non-SSE line: return as-is (raw JSON error fallback, e.g. OpenAI)
return append([]byte(nil), line...), nil
}
if err := r.scanner.Err(); err != nil {
return nil, err
}
return nil, io.EOF
}
// defaultSSEEventReader implements SSEEventReader using bufio.NewScanner.
// Handles Format B SSE streams (event+data: Anthropic, Replicate, Mistral, etc.).
// Events are delimited by empty lines; multiple "data:" lines are concatenated.
type defaultSSEEventReader struct {
scanner *bufio.Scanner
eventType string
eventData []byte
}
func newDefaultSSEEventReader(reader io.Reader) *defaultSSEEventReader {
scanner := bufio.NewScanner(reader)
scanner.Buffer(make([]byte, 0, sseInitialBufSize), sseMaxBufSize)
return &defaultSSEEventReader{scanner: scanner}
}
func (r *defaultSSEEventReader) ReadEvent() (string, []byte, error) {
for r.scanner.Scan() {
line := r.scanner.Bytes()
// Skip comments
if len(line) > 0 && line[0] == ':' {
continue
}
// Empty line = event boundary
if len(line) == 0 {
if r.eventType == "" && len(r.eventData) == 0 {
continue
}
eventType := r.eventType
eventData := make([]byte, len(r.eventData))
copy(eventData, r.eventData)
r.eventType = ""
r.eventData = r.eventData[:0]
return eventType, eventData, nil
}
// Parse SSE fields
if bytes.HasPrefix(line, sseEventPrefix) {
field := line[6:] // len("event:") == 6
if len(field) > 0 && field[0] == ' ' {
field = field[1:]
}
r.eventType = string(field)
} else if bytes.HasPrefix(line, sseDataPrefix) {
data := line[5:] // len("data:") == 5
if len(data) > 0 && data[0] == ' ' {
data = data[1:]
}
if len(r.eventData) > 0 {
r.eventData = append(r.eventData, '\n')
}
r.eventData = append(r.eventData, data...)
}
// id:, retry:, and other fields are silently skipped
}
// Scanner done — return any accumulated event before EOF
if r.eventType != "" || len(r.eventData) > 0 {
eventType := r.eventType
eventData := make([]byte, len(r.eventData))
copy(eventData, r.eventData)
r.eventType = ""
r.eventData = r.eventData[:0]
return eventType, eventData, nil
}
if err := r.scanner.Err(); err != nil {
return "", nil, err
}
return "", nil, io.EOF
}

View File

@@ -0,0 +1,75 @@
package utils
import (
"context"
schemas "github.com/maximhq/bifrost/core/schemas"
)
// CheckFirstStreamChunkForError reads the first chunk from a streaming channel to detect
// errors returned inside HTTP 200 SSE streams (e.g., providers that send rate limit
// errors as SSE events instead of HTTP 429).
//
// If the first chunk is an error, it drains the source channel in the background
// (so the provider goroutine can exit cleanly) and returns the error for synchronous
// handling, enabling retries and fallbacks. The returned drainDone channel is closed
// once the drain completes — callers must wait on it before releasing any resources
// (e.g., plugin pipelines) that the provider goroutine's postHookRunner may still reference.
//
// If the first chunk is valid data, it returns a wrapped channel that re-emits
// the first chunk followed by all remaining chunks from the source. drainDone is
// closed when the wrapper goroutine finishes forwarding the source stream.
//
// If the source channel is closed immediately (empty stream), it returns a
// nil channel with nil error. drainDone is already closed.
//
// The ctx argument cancels the background forwarding goroutine if the consumer
// abandons the returned wrapped channel. On ctx.Done the goroutine drains the
// source stream so the upstream provider's blocked send can exit cleanly.
func CheckFirstStreamChunkForError(
ctx context.Context,
stream chan *schemas.BifrostStreamChunk,
) (chan *schemas.BifrostStreamChunk, <-chan struct{}, *schemas.BifrostError) {
firstChunk, ok := <-stream
if !ok {
// Channel closed immediately (empty stream) — return nil so callers
// can distinguish this from a live stream channel.
done := make(chan struct{})
close(done)
return nil, done, nil
}
// Check if first chunk is an error
if firstChunk.BifrostError != nil && firstChunk.BifrostError.Error != nil &&
(firstChunk.BifrostError.Error.Message != "" || firstChunk.BifrostError.Error.Code != nil || firstChunk.BifrostError.Error.Type != nil) {
// Drain source channel to let the provider goroutine exit cleanly
done := make(chan struct{})
go func() {
defer close(done)
for range stream {
}
}()
return nil, done, firstChunk.BifrostError
}
// First chunk is valid data — wrap channel to re-inject it
done := make(chan struct{})
wrapped := make(chan *schemas.BifrostStreamChunk, max(cap(stream), 1))
wrapped <- firstChunk
go func() {
defer close(done)
defer close(wrapped)
for chunk := range stream {
select {
case wrapped <- chunk:
case <-ctx.Done():
// Consumer abandoned the wrapped channel. Drain the source so the
// provider's blocked send unblocks and its goroutine can exit.
for range stream {
}
return
}
}
}()
return wrapped, done, nil
}

View File

@@ -0,0 +1,252 @@
package utils
import (
"context"
"testing"
"time"
schemas "github.com/maximhq/bifrost/core/schemas"
)
func TestCheckFirstStreamChunk_ErrorInFirstChunk(t *testing.T) {
stream := make(chan *schemas.BifrostStreamChunk, 2)
stream <- &schemas.BifrostStreamChunk{
BifrostError: &schemas.BifrostError{
Error: &schemas.ErrorField{
Code: schemas.Ptr("limit_burst_rate"),
Message: "Request rate increased too quickly",
},
},
}
close(stream)
_, drainDone, err := CheckFirstStreamChunkForError(context.Background(), stream)
if err == nil {
t.Fatal("expected error, got nil")
}
<-drainDone
if err.Error.Message != "Request rate increased too quickly" {
t.Errorf("unexpected error message: %s", err.Error.Message)
}
if err.Error.Code == nil || *err.Error.Code != "limit_burst_rate" {
t.Errorf("unexpected error code: %v", err.Error.Code)
}
}
func TestCheckFirstStreamChunk_ValidFirstChunk(t *testing.T) {
stream := make(chan *schemas.BifrostStreamChunk, 3)
chunk1 := &schemas.BifrostStreamChunk{
BifrostChatResponse: &schemas.BifrostChatResponse{
ID: "chatcmpl-123",
},
}
chunk2 := &schemas.BifrostStreamChunk{
BifrostChatResponse: &schemas.BifrostChatResponse{
ID: "chatcmpl-123",
},
}
stream <- chunk1
stream <- chunk2
close(stream)
wrapped, _, err := CheckFirstStreamChunkForError(context.Background(), stream)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
// First chunk should be re-injected
got1 := <-wrapped
if got1.BifrostChatResponse == nil || got1.BifrostChatResponse.ID != "chatcmpl-123" {
t.Error("first chunk not re-injected correctly")
}
// Second chunk should follow
got2 := <-wrapped
if got2.BifrostChatResponse == nil || got2.BifrostChatResponse.ID != "chatcmpl-123" {
t.Error("second chunk not forwarded correctly")
}
// Channel should be closed
_, ok := <-wrapped
if ok {
t.Error("expected wrapped channel to be closed")
}
}
func TestCheckFirstStreamChunk_EmptyStream(t *testing.T) {
stream := make(chan *schemas.BifrostStreamChunk)
close(stream)
wrapped, drainDone, err := CheckFirstStreamChunkForError(context.Background(), stream)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
// Empty stream should return nil channel
if wrapped != nil {
t.Error("expected nil channel for empty stream")
}
// drainDone should be already closed
select {
case <-drainDone:
default:
t.Error("expected drainDone to be closed for empty stream")
}
}
func TestCheckFirstStreamChunk_ErrorInSecondChunk(t *testing.T) {
stream := make(chan *schemas.BifrostStreamChunk, 3)
stream <- &schemas.BifrostStreamChunk{
BifrostChatResponse: &schemas.BifrostChatResponse{
ID: "chatcmpl-123",
},
}
stream <- &schemas.BifrostStreamChunk{
BifrostError: &schemas.BifrostError{
Error: &schemas.ErrorField{
Message: "some error in second chunk",
},
},
}
close(stream)
// Should NOT return error — only first chunk matters for retry
wrapped, _, err := CheckFirstStreamChunkForError(context.Background(), stream)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
// Read all chunks
got1 := <-wrapped
if got1.BifrostChatResponse == nil {
t.Error("first chunk should be valid data")
}
got2 := <-wrapped
if got2.BifrostError == nil {
t.Error("second chunk should be the error")
}
_, ok := <-wrapped
if ok {
t.Error("expected wrapped channel to be closed")
}
}
func TestCheckFirstStreamChunk_ErrorDrainsSource(t *testing.T) {
stream := make(chan *schemas.BifrostStreamChunk, 5)
stream <- &schemas.BifrostStreamChunk{
BifrostError: &schemas.BifrostError{
Error: &schemas.ErrorField{
Message: "rate limit error",
},
},
}
// Add more chunks that should be drained
stream <- &schemas.BifrostStreamChunk{
BifrostChatResponse: &schemas.BifrostChatResponse{ID: "1"},
}
stream <- &schemas.BifrostStreamChunk{
BifrostChatResponse: &schemas.BifrostChatResponse{ID: "2"},
}
close(stream)
_, drainDone, err := CheckFirstStreamChunkForError(context.Background(), stream)
if err == nil {
t.Fatal("expected error, got nil")
}
<-drainDone
if err.Error.Message != "rate limit error" {
t.Errorf("unexpected error message: %s", err.Error.Message)
}
if drainDone == nil {
t.Fatal("expected drainDone channel, got nil")
}
// Wait for drain to complete — verifies the channel signals properly
<-drainDone
}
func TestCheckFirstStreamChunk_ErrorWithEmptyMessage(t *testing.T) {
// Error with empty message and no code/type should NOT be treated as an error
stream := make(chan *schemas.BifrostStreamChunk, 2)
stream <- &schemas.BifrostStreamChunk{
BifrostError: &schemas.BifrostError{
Error: &schemas.ErrorField{
Message: "",
},
},
}
close(stream)
wrapped, _, err := CheckFirstStreamChunkForError(context.Background(), stream)
if err != nil {
t.Fatalf("unexpected error for empty message: %v", err)
}
// Should be treated as valid chunk
<-wrapped
}
func TestCheckFirstStreamChunk_CtxCancelUnblocksWrapper(t *testing.T) {
// Source with cap=1 so wrapped also has cap=1. wrapped is left full by
// the re-injected first chunk, which makes the forwarder goroutine block
// on its next send — the exact leak condition this test guards against.
src := make(chan *schemas.BifrostStreamChunk, 1)
src <- &schemas.BifrostStreamChunk{
BifrostChatResponse: &schemas.BifrostChatResponse{ID: "1"},
}
ctx, cancel := context.WithCancel(context.Background())
wrapped, drainDone, err := CheckFirstStreamChunkForError(ctx, src)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if wrapped == nil {
t.Fatal("expected wrapped channel, got nil")
}
// Push a second chunk; forwarder will read it from src and then block
// trying to send into the full wrapped channel (we intentionally never
// read from wrapped).
src <- &schemas.BifrostStreamChunk{
BifrostChatResponse: &schemas.BifrostChatResponse{ID: "2"},
}
// Cancel — forwarder must stop trying to send to wrapped and drain src.
cancel()
// Simulate the upstream producer still emitting, then closing. The
// drain loop should consume these and terminate.
src <- &schemas.BifrostStreamChunk{
BifrostChatResponse: &schemas.BifrostChatResponse{ID: "3"},
}
close(src)
select {
case <-drainDone:
case <-time.After(time.Second):
t.Fatal("drainDone did not close after ctx cancel; forwarder goroutine leaked")
}
}
func TestCheckFirstStreamChunk_CodeOnlyError(t *testing.T) {
// Error with code but no message should be treated as an error
stream := make(chan *schemas.BifrostStreamChunk, 2)
stream <- &schemas.BifrostStreamChunk{
BifrostError: &schemas.BifrostError{
Error: &schemas.ErrorField{
Code: schemas.Ptr("limit_burst_rate"),
},
},
}
close(stream)
_, drainDone, err := CheckFirstStreamChunkForError(context.Background(), stream)
if err == nil {
t.Fatal("expected error for code-only error, got nil")
}
<-drainDone
if err.Error.Code == nil || *err.Error.Code != "limit_burst_rate" {
t.Errorf("unexpected error code: %v", err.Error.Code)
}
}

View File

@@ -0,0 +1,218 @@
package utils
import (
"bufio"
"context"
"fmt"
"net"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/valyala/fasthttp"
)
// TestBuildStreamingClient_ZerosReadWriteTimeout verifies the streaming client
// has ReadTimeout=0 / WriteTimeout=0 / MaxConnDuration=0 while preserving other
// config from the base.
func TestBuildStreamingClient_ZerosReadWriteTimeout(t *testing.T) {
base := &fasthttp.Client{
ReadTimeout: 30 * time.Second,
WriteTimeout: 30 * time.Second,
MaxConnDuration: 5 * time.Minute,
MaxConnWaitTimeout: 15 * time.Second,
MaxConnsPerHost: 123,
}
ConfigureDialer(base)
stream := BuildStreamingClient(base)
if stream.ReadTimeout != 0 {
t.Errorf("ReadTimeout: got %v, want 0", stream.ReadTimeout)
}
if stream.WriteTimeout != 0 {
t.Errorf("WriteTimeout: got %v, want 0", stream.WriteTimeout)
}
if stream.MaxConnDuration != 0 {
t.Errorf("MaxConnDuration: got %v, want 0", stream.MaxConnDuration)
}
if !stream.StreamResponseBody {
t.Error("StreamResponseBody: got false, want true")
}
if stream.MaxConnWaitTimeout != base.MaxConnWaitTimeout {
t.Errorf("MaxConnWaitTimeout should be preserved: got %v, want %v",
stream.MaxConnWaitTimeout, base.MaxConnWaitTimeout)
}
if stream.MaxConnsPerHost != base.MaxConnsPerHost {
t.Errorf("MaxConnsPerHost should be preserved: got %v, want %v",
stream.MaxConnsPerHost, base.MaxConnsPerHost)
}
}
// TestBuildStreamingClient_BaseUnchanged verifies BuildStreamingClient does not
// mutate the base client (since unary callers still need the 30s timeout).
func TestBuildStreamingClient_BaseUnchanged(t *testing.T) {
base := &fasthttp.Client{
ReadTimeout: 30 * time.Second,
WriteTimeout: 30 * time.Second,
MaxConnDuration: 5 * time.Minute,
}
_ = BuildStreamingClient(base)
if base.ReadTimeout != 30*time.Second {
t.Errorf("base ReadTimeout mutated: got %v, want 30s", base.ReadTimeout)
}
if base.MaxConnDuration != 5*time.Minute {
t.Errorf("base MaxConnDuration mutated: got %v, want 5m", base.MaxConnDuration)
}
}
// TestBuildStreamingClient_LongStreamSurvives verifies that a stream sending
// chunks every 500ms for 2.5s (total) is not killed by the base client's 1s
// ReadTimeout. Before the fix, fasthttp would abort at ~1s.
func TestBuildStreamingClient_LongStreamSurvives(t *testing.T) {
const chunkInterval = 500 * time.Millisecond
const totalChunks = 5 // 2.5s total, well past base ReadTimeout=1s
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/event-stream")
w.Header().Set("Cache-Control", "no-cache")
w.WriteHeader(http.StatusOK)
flusher, _ := w.(http.Flusher)
for i := 0; i < totalChunks; i++ {
fmt.Fprintf(w, "data: chunk-%d\n\n", i)
if flusher != nil {
flusher.Flush()
}
time.Sleep(chunkInterval)
}
}))
defer srv.Close()
base := &fasthttp.Client{
ReadTimeout: 1 * time.Second, // would abort the stream without the fix
WriteTimeout: 1 * time.Second,
}
ConfigureDialer(base)
stream := BuildStreamingClient(base)
req := fasthttp.AcquireRequest()
resp := fasthttp.AcquireResponse()
defer fasthttp.ReleaseRequest(req)
defer fasthttp.ReleaseResponse(resp)
req.SetRequestURI(srv.URL)
req.Header.SetMethod(http.MethodGet)
resp.StreamBody = true
if err := stream.Do(req, resp); err != nil {
t.Fatalf("Do: %v", err)
}
if resp.StatusCode() != http.StatusOK {
t.Fatalf("status: %d", resp.StatusCode())
}
scanner := bufio.NewScanner(resp.BodyStream())
got := 0
for scanner.Scan() {
if line := scanner.Text(); len(line) >= 5 && line[:5] == "data:" {
got++
}
}
if err := scanner.Err(); err != nil {
t.Fatalf("scanner: %v", err)
}
if got != totalChunks {
t.Errorf("chunks received: got %d, want %d (stream was likely killed early)", got, totalChunks)
}
}
// TestBuildStreamingHTTPClient_ZerosTimeout verifies the net/http streaming
// client has Timeout=0 and shares the base's Transport.
func TestBuildStreamingHTTPClient_ZerosTimeout(t *testing.T) {
transport := &http.Transport{ResponseHeaderTimeout: 10 * time.Second}
base := &http.Client{
Transport: transport,
Timeout: 30 * time.Second,
}
stream := BuildStreamingHTTPClient(base)
if stream.Timeout != 0 {
t.Errorf("Timeout: got %v, want 0", stream.Timeout)
}
if stream.Transport != base.Transport {
t.Error("Transport: streaming client should share base's Transport")
}
if base.Timeout != 30*time.Second {
t.Errorf("base Timeout mutated: got %v, want 30s", base.Timeout)
}
}
// TestBuildStreamingHTTPClient_Nil verifies nil base returns empty client
// (not a panic).
func TestBuildStreamingHTTPClient_Nil(t *testing.T) {
stream := BuildStreamingHTTPClient(nil)
if stream == nil {
t.Fatal("BuildStreamingHTTPClient(nil) returned nil")
}
if stream.Timeout != 0 {
t.Errorf("Timeout: got %v, want 0", stream.Timeout)
}
}
// TestBuildStreamingHTTPClient_LongStreamSurvives verifies that the streaming
// client can read a response body that takes longer than the base client's
// Timeout — proving Timeout=0 actually lifts the whole-request deadline.
func TestBuildStreamingHTTPClient_LongStreamSurvives(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/event-stream")
w.WriteHeader(http.StatusOK)
flusher, _ := w.(http.Flusher)
for i := 0; i < 4; i++ {
fmt.Fprintf(w, "data: chunk-%d\n\n", i)
if flusher != nil {
flusher.Flush()
}
time.Sleep(400 * time.Millisecond)
}
}))
defer srv.Close()
base := &http.Client{
Transport: &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: (&net.Dialer{Timeout: 5 * time.Second}).DialContext,
ResponseHeaderTimeout: 5 * time.Second,
},
Timeout: 500 * time.Millisecond, // would abort the stream without the fix
}
stream := BuildStreamingHTTPClient(base)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, srv.URL, nil)
if err != nil {
t.Fatalf("NewRequestWithContext: %v", err)
}
resp, err := stream.Do(req)
if err != nil {
t.Fatalf("Do: %v", err)
}
defer resp.Body.Close()
scanner := bufio.NewScanner(resp.Body)
got := 0
for scanner.Scan() {
if line := scanner.Text(); len(line) >= 5 && line[:5] == "data:" {
got++
}
}
if err := scanner.Err(); err != nil {
t.Fatalf("scanner: %v", err)
}
if got != 4 {
t.Errorf("chunks received: got %d, want 4 (stream was likely killed by Timeout)", got)
}
}

View File

@@ -0,0 +1,274 @@
package utils
import (
"bytes"
"strings"
"github.com/bytedance/sonic"
)
const maxTerminalDetectorBufferBytes = 256 * 1024
var (
sseFrameDelimiterLF = []byte("\n\n")
sseFrameDelimiterCRLF = []byte("\r\n\r\n")
)
// StreamTerminalDetector incrementally parses stream frames and detects
// semantic completion markers such as finishReason or [DONE].
type StreamTerminalDetector struct {
pending bytes.Buffer
}
// ObserveChunk ingests a new raw stream chunk and returns true if a terminal
// marker was detected in a parsed frame payload.
func (d *StreamTerminalDetector) ObserveChunk(chunk []byte) bool {
if len(chunk) == 0 {
return false
}
// Fast path: detect terminal markers when a single chunk already contains
// a complete payload (SSE data line or plain JSON body).
// Skip this when the chunk already contains full SSE frame delimiters,
// because multi-event chunks need frame-by-frame parsing.
if !containsSSEFrameDelimiter(chunk) && d.detectInFrame(chunk) {
return true
}
d.pending.Write(chunk)
for {
data := d.pending.Bytes()
delimIdx, delimLen := findFirstSSEFrameDelimiter(data)
if delimIdx < 0 {
break
}
frame := append([]byte(nil), data[:delimIdx]...)
d.pending.Next(delimIdx + delimLen)
if d.detectInFrame(frame) {
return true
}
}
// Some passthrough streams emit plain JSON chunks (no SSE "\n\n" framing).
// Try parsing the current pending buffer as a whole JSON payload.
if d.detectInUndelimitedPending() {
return true
}
// Keep memory bounded if the upstream never emits a frame delimiter.
if d.pending.Len() > maxTerminalDetectorBufferBytes {
drain := d.pending.Bytes()
if idx, delimLen := findLastSSEFrameDelimiter(drain); idx >= 0 {
d.pending.Next(idx + delimLen)
} else {
trimTo := maxTerminalDetectorBufferBytes / 2
keptPrefix := append([]byte(nil), drain[:trimTo]...)
d.pending.Reset()
d.pending.Write(keptPrefix)
}
}
return false
}
func (d *StreamTerminalDetector) detectInUndelimitedPending() bool {
if d.pending.Len() == 0 {
return false
}
payload := bytes.TrimSpace(d.pending.Bytes())
if len(payload) == 0 {
return false
}
return hasFinishReasonMarker(payload)
}
func (d *StreamTerminalDetector) detectInFrame(frame []byte) bool {
payload := extractSSEDataPayload(frame)
if len(payload) == 0 {
return false
}
text := strings.TrimSpace(string(payload))
if text == "" {
return false
}
if text == "[DONE]" {
return true
}
return hasFinishReasonMarker([]byte(text))
}
func extractSSEDataPayload(frame []byte) []byte {
trimmed := bytes.TrimSpace(frame)
if len(trimmed) == 0 {
return nil
}
if !hasSSEDataLinePrefix(trimmed) {
return trimmed
}
lines := bytes.Split(trimmed, []byte("\n"))
var payload bytes.Buffer
for _, line := range lines {
line = bytes.TrimSpace(line)
if len(line) == 0 || bytes.HasPrefix(line, []byte(":")) {
continue
}
if bytes.HasPrefix(line, []byte("data:")) {
data := bytes.TrimSpace(bytes.TrimPrefix(line, []byte("data:")))
if len(data) == 0 {
continue
}
if payload.Len() > 0 {
payload.WriteByte('\n')
}
payload.Write(data)
}
}
if payload.Len() == 0 {
return nil
}
return payload.Bytes()
}
func hasSSEDataLinePrefix(frame []byte) bool {
lines := bytes.Split(frame, []byte("\n"))
for _, line := range lines {
line = bytes.TrimSpace(line)
if len(line) == 0 || bytes.HasPrefix(line, []byte(":")) {
continue
}
if bytes.HasPrefix(line, []byte("data:")) {
return true
}
}
return false
}
func hasFinishReasonMarker(payload []byte) bool {
var root any
if err := sonic.Unmarshal(payload, &root); err != nil {
return false
}
switch v := root.(type) {
case map[string]any:
return hasTerminalMarkerInTopLevelObject(v)
case []any:
for _, item := range v {
obj, ok := item.(map[string]any)
if !ok {
continue
}
if hasTerminalMarkerInTopLevelObject(obj) {
return true
}
}
}
return false
}
func hasTerminalMarkerInTopLevelObject(root map[string]any) bool {
if hasValidFinishReasonValue(root) {
return true
}
// Gemini/Vertex streamGenerateContent often signals terminal state in
// top-level candidates[*].finishReason.
if candidatesValue, ok := root["candidates"]; ok {
if candidates, ok := candidatesValue.([]any); ok {
return allCandidatesFinished(candidates)
}
}
// usageMetadata can show up before terminal chunks in long-running streams.
// Treat it as terminal only when all candidates are finished.
if usageMetadata, ok := root["usageMetadata"]; ok {
if usageMap, ok := usageMetadata.(map[string]any); ok && len(usageMap) > 0 {
if candidatesValue, ok := root["candidates"]; ok {
if candidates, ok := candidatesValue.([]any); ok && allCandidatesFinished(candidates) {
return true
}
}
}
}
if promptFeedback, ok := root["promptFeedback"]; ok {
if feedbackMap, ok := promptFeedback.(map[string]any); ok {
if reason, ok := feedbackMap["blockReason"].(string); ok && strings.TrimSpace(reason) != "" {
return true
}
}
}
return false
}
func allCandidatesFinished(candidates []any) bool {
if len(candidates) == 0 {
return false
}
for _, candidate := range candidates {
candidateMap, ok := candidate.(map[string]any)
if !ok {
return false
}
if !hasValidFinishReasonValue(candidateMap) {
return false
}
}
return true
}
func hasValidFinishReasonValue(node map[string]any) bool {
for _, key := range []string{"finishReason", "finish_reason"} {
value, ok := node[key]
if !ok {
continue
}
if str, ok := value.(string); ok && strings.TrimSpace(str) != "" && str != "FINISH_REASON_UNSPECIFIED" {
return true
}
}
return false
}
func findFirstSSEFrameDelimiter(data []byte) (idx int, delimLen int) {
idxLF := bytes.Index(data, sseFrameDelimiterLF)
idxCRLF := bytes.Index(data, sseFrameDelimiterCRLF)
switch {
case idxLF < 0 && idxCRLF < 0:
return -1, 0
case idxLF < 0:
return idxCRLF, len(sseFrameDelimiterCRLF)
case idxCRLF < 0:
return idxLF, len(sseFrameDelimiterLF)
case idxCRLF < idxLF:
return idxCRLF, len(sseFrameDelimiterCRLF)
default:
return idxLF, len(sseFrameDelimiterLF)
}
}
func findLastSSEFrameDelimiter(data []byte) (idx int, delimLen int) {
idxLF := bytes.LastIndex(data, sseFrameDelimiterLF)
idxCRLF := bytes.LastIndex(data, sseFrameDelimiterCRLF)
switch {
case idxLF < 0 && idxCRLF < 0:
return -1, 0
case idxLF < 0:
return idxCRLF, len(sseFrameDelimiterCRLF)
case idxCRLF < 0:
return idxLF, len(sseFrameDelimiterLF)
case idxCRLF > idxLF:
return idxCRLF, len(sseFrameDelimiterCRLF)
default:
return idxLF, len(sseFrameDelimiterLF)
}
}
func containsSSEFrameDelimiter(data []byte) bool {
return bytes.Contains(data, sseFrameDelimiterLF) || bytes.Contains(data, sseFrameDelimiterCRLF)
}

View File

@@ -0,0 +1,155 @@
package utils
import (
"bytes"
"testing"
)
func TestStreamTerminalDetectorObserveChunkSSEFinishReasonAcrossChunks(t *testing.T) {
detector := &StreamTerminalDetector{}
chunks := [][]byte{
[]byte("data: {\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"hi\"}]},"),
[]byte("\"finishReason\":\"STOP\"}]}\n\n"),
}
if detector.ObserveChunk(chunks[0]) {
t.Fatalf("unexpected terminal detection on first chunk")
}
if !detector.ObserveChunk(chunks[1]) {
t.Fatalf("expected terminal detection for candidates finishReason")
}
}
func TestStreamTerminalDetectorObserveChunkSSETopLevelFinishReasonAcrossChunks(t *testing.T) {
detector := &StreamTerminalDetector{}
chunks := [][]byte{
[]byte("data: {\"id\":\"abc\","),
[]byte("\"finishReason\":\"STOP\"}\n\n"),
}
if detector.ObserveChunk(chunks[0]) {
t.Fatalf("unexpected terminal detection on first chunk")
}
if !detector.ObserveChunk(chunks[1]) {
t.Fatalf("expected terminal detection for top-level finishReason")
}
}
func TestStreamTerminalDetectorObserveChunkDoneMarker(t *testing.T) {
detector := &StreamTerminalDetector{}
if !detector.ObserveChunk([]byte("data: [DONE]\n\n")) {
t.Fatalf("expected [DONE] marker to be terminal")
}
}
func TestStreamTerminalDetectorObserveChunkDoneMarkerCRLF(t *testing.T) {
detector := &StreamTerminalDetector{}
if !detector.ObserveChunk([]byte("data: [DONE]\r\n\r\n")) {
t.Fatalf("expected [DONE] marker with CRLF delimiter to be terminal")
}
}
func TestStreamTerminalDetectorObserveChunkIgnoresUnspecifiedFinishReason(t *testing.T) {
detector := &StreamTerminalDetector{}
if detector.ObserveChunk([]byte("data: {\"finishReason\":\"FINISH_REASON_UNSPECIFIED\"}\n\n")) {
t.Fatalf("unexpected terminal detection for FINISH_REASON_UNSPECIFIED")
}
}
func TestStreamTerminalDetectorObserveChunkPlainJSONAcrossChunks(t *testing.T) {
detector := &StreamTerminalDetector{}
if detector.ObserveChunk([]byte("{\"content\":\"hello\",")) {
t.Fatalf("unexpected terminal detection for incomplete json")
}
if !detector.ObserveChunk([]byte("\"finishReason\":\"STOP\"}")) {
t.Fatalf("expected terminal detection for plain json stream")
}
}
func TestStreamTerminalDetectorObserveChunkJSONWithDataURITokenAndDelimiter(t *testing.T) {
detector := &StreamTerminalDetector{}
chunk := []byte("{\"finishReason\":\"STOP\",\"content\":{\"parts\":[{\"inlineData\":{\"mimeType\":\"image/png\",\"data\":\"data:image/png;base64,AAAA\"}}]}}\n\n")
if !detector.ObserveChunk(chunk) {
t.Fatalf("expected terminal detection for delimited JSON containing data URI token")
}
}
func TestStreamTerminalDetectorObserveChunkMultiEventSSEInSingleChunk(t *testing.T) {
detector := &StreamTerminalDetector{}
chunk := []byte("data: {}\n\ndata: {\"finishReason\":\"STOP\"}\n\n")
if !detector.ObserveChunk(chunk) {
t.Fatalf("expected terminal detection for multi-event SSE chunk")
}
}
func TestStreamTerminalDetectorObserveChunkMetadataOnlyFrameIsNotTerminal(t *testing.T) {
detector := &StreamTerminalDetector{}
if detector.ObserveChunk([]byte("data: {\"usageMetadata\":{\"totalTokenCount\":12}}\n\n")) {
t.Fatalf("unexpected terminal detection for metadata-only frame")
}
}
func TestStreamTerminalDetectorObserveChunkMetadataWithFinishedCandidateIsTerminal(t *testing.T) {
detector := &StreamTerminalDetector{}
if !detector.ObserveChunk([]byte("data: {\"usageMetadata\":{\"totalTokenCount\":12},\"candidates\":[{\"finishReason\":\"STOP\"}]}\n\n")) {
t.Fatalf("expected terminal detection for metadata with finished candidate")
}
}
func TestStreamTerminalDetectorObserveChunkMetadataWithUnspecifiedCandidateIsNotTerminal(t *testing.T) {
detector := &StreamTerminalDetector{}
if detector.ObserveChunk([]byte("data: {\"usageMetadata\":{\"totalTokenCount\":12},\"candidates\":[{\"finishReason\":\"FINISH_REASON_UNSPECIFIED\"}]}\n\n")) {
t.Fatalf("unexpected terminal detection for metadata with unfinished candidate")
}
}
func TestStreamTerminalDetectorObserveChunkMetadataWithMixedCandidatesIsNotTerminal(t *testing.T) {
detector := &StreamTerminalDetector{}
if detector.ObserveChunk([]byte("data: {\"usageMetadata\":{\"totalTokenCount\":12},\"candidates\":[{\"finishReason\":\"STOP\"},{}]}\n\n")) {
t.Fatalf("unexpected terminal detection for metadata with mixed finished/unfinished candidates")
}
}
func TestStreamTerminalDetectorObserveChunkTopLevelArrayWithCandidatesFinishReason(t *testing.T) {
detector := &StreamTerminalDetector{}
chunk := []byte("data: [{\"candidates\":[{\"finishReason\":\"STOP\"}]}]\n\n")
if !detector.ObserveChunk(chunk) {
t.Fatalf("expected terminal detection for top-level array payload")
}
}
func TestStreamTerminalDetectorObserveChunkCandidatesRequireAllFinished(t *testing.T) {
detector := &StreamTerminalDetector{}
chunk := []byte("data: {\"candidates\":[{\"finishReason\":\"STOP\"},{\"finishReason\":\"FINISH_REASON_UNSPECIFIED\"}]}\n\n")
if detector.ObserveChunk(chunk) {
t.Fatalf("unexpected terminal detection when not all candidates are finished")
}
}
func TestStreamTerminalDetectorObserveChunkCandidatesAllFinishedIsTerminal(t *testing.T) {
detector := &StreamTerminalDetector{}
chunk := []byte("data: {\"candidates\":[{\"finishReason\":\"STOP\"},{\"finishReason\":\"MAX_TOKENS\"}]}\n\n")
if !detector.ObserveChunk(chunk) {
t.Fatalf("expected terminal detection when all candidates are finished")
}
}
func TestStreamTerminalDetectorObserveChunkUndelimitedOverflowKeepsPrefix(t *testing.T) {
detector := &StreamTerminalDetector{}
originalPrefix := bytes.Repeat([]byte("a"), maxTerminalDetectorBufferBytes/2)
chunk := append([]byte(nil), originalPrefix...)
chunk = append(chunk, bytes.Repeat([]byte("b"), maxTerminalDetectorBufferBytes)...)
if detector.ObserveChunk(chunk) {
t.Fatalf("unexpected terminal detection for non-json buffer")
}
trimTo := maxTerminalDetectorBufferBytes / 2
got := detector.pending.Bytes()
if len(got) != trimTo {
t.Fatalf("expected pending length %d, got %d", trimTo, len(got))
}
if !bytes.Equal(got, originalPrefix) {
t.Fatalf("expected pending buffer to keep original prefix")
}
}

View File

@@ -0,0 +1,171 @@
package utils
import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"math/big"
"testing"
"time"
"github.com/maximhq/bifrost/core/schemas"
"github.com/valyala/fasthttp"
)
// testLogger is a minimal logger for tests that implements schemas.Logger.
type testLogger struct{}
func (testLogger) Debug(string, ...any) {}
func (testLogger) Info(string, ...any) {}
func (testLogger) Warn(string, ...any) {}
func (testLogger) Error(string, ...any) {}
func (testLogger) Fatal(string, ...any) {}
func (testLogger) SetLevel(schemas.LogLevel) {}
func (testLogger) SetOutputType(schemas.LoggerOutputType) {}
func (testLogger) LogHTTPRequest(schemas.LogLevel, string) schemas.LogEventBuilder {
return schemas.NoopLogEvent
}
// validTestCertPEM returns a minimal valid PEM-encoded CA certificate for testing.
func validTestCertPEM(t *testing.T) string {
t.Helper()
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
t.Fatalf("generate key: %v", err)
}
template := &x509.Certificate{
SerialNumber: big.NewInt(1),
Subject: pkix.Name{CommonName: "test-ca"},
NotBefore: time.Now(),
NotAfter: time.Now().Add(24 * time.Hour),
KeyUsage: x509.KeyUsageCertSign,
BasicConstraintsValid: true,
IsCA: true,
}
certDER, err := x509.CreateCertificate(rand.Reader, template, template, &key.PublicKey, key)
if err != nil {
t.Fatalf("create certificate: %v", err)
}
block := &pem.Block{Type: "CERTIFICATE", Bytes: certDER}
return string(pem.EncodeToMemory(block))
}
func TestConfigureTLS_ReturnsUnchangedWhenNeitherSet(t *testing.T) {
client := &fasthttp.Client{}
logger := testLogger{}
result := ConfigureTLS(client, schemas.NetworkConfig{}, logger)
if result != client {
t.Error("ConfigureTLS should return the same client when neither InsecureSkipVerify nor CACertPEM is set")
}
if client.TLSConfig != nil {
t.Error("TLSConfig should remain nil when no TLS options are set")
}
}
func TestConfigureTLS_SetsInsecureSkipVerify(t *testing.T) {
client := &fasthttp.Client{}
logger := testLogger{}
result := ConfigureTLS(client, schemas.NetworkConfig{InsecureSkipVerify: true}, logger)
if result != client {
t.Error("ConfigureTLS should return the same client")
}
if client.TLSConfig == nil {
t.Fatal("TLSConfig should be set when InsecureSkipVerify is true")
}
if !client.TLSConfig.InsecureSkipVerify {
t.Error("InsecureSkipVerify should be true")
}
}
func TestConfigureTLS_AppliesCACertPEM(t *testing.T) {
client := &fasthttp.Client{}
logger := testLogger{}
caPEM := validTestCertPEM(t)
result := ConfigureTLS(client, schemas.NetworkConfig{CACertPEM: caPEM}, logger)
if result != client {
t.Error("ConfigureTLS should return the same client")
}
if client.TLSConfig == nil {
t.Fatal("TLSConfig should be set when CACertPEM is provided")
}
if client.TLSConfig.RootCAs == nil {
t.Error("RootCAs should be set when CACertPEM is provided")
}
}
func TestConfigureTLS_HandlesInvalidCACertPEM(t *testing.T) {
client := &fasthttp.Client{}
logger := testLogger{}
result := ConfigureTLS(client, schemas.NetworkConfig{CACertPEM: "not-valid-pem"}, logger)
if result != client {
t.Error("ConfigureTLS should return the same client even when CACertPEM is invalid")
}
// Invalid PEM logs warning and skips RootCAs; TLSConfig may still be set with MinVersion
if client.TLSConfig != nil && client.TLSConfig.RootCAs != nil {
t.Error("RootCAs should not be set when CACertPEM is invalid")
}
}
func TestConfigureTLS_MergesWithExistingTLSConfig(t *testing.T) {
// Simulate client that already has TLSConfig from ConfigureProxy
existingRootCAs, _ := x509.SystemCertPool()
if existingRootCAs == nil {
existingRootCAs = x509.NewCertPool()
}
client := &fasthttp.Client{
TLSConfig: &tls.Config{
RootCAs: existingRootCAs,
MinVersion: tls.VersionTLS12,
},
}
logger := testLogger{}
caPEM := validTestCertPEM(t)
result := ConfigureTLS(client, schemas.NetworkConfig{CACertPEM: caPEM}, logger)
if result != client {
t.Error("ConfigureTLS should return the same client")
}
if client.TLSConfig == nil {
t.Fatal("TLSConfig should remain set")
}
if client.TLSConfig.RootCAs == nil {
t.Error("RootCAs should be set (merged with existing)")
}
}
func TestConfigureTLS_InsecureSkipVerifyAndCACertPEM(t *testing.T) {
client := &fasthttp.Client{}
logger := testLogger{}
caPEM := validTestCertPEM(t)
result := ConfigureTLS(client, schemas.NetworkConfig{
InsecureSkipVerify: true,
CACertPEM: caPEM,
}, logger)
if result != client {
t.Error("ConfigureTLS should return the same client")
}
if client.TLSConfig == nil {
t.Fatal("TLSConfig should be set")
}
if !client.TLSConfig.InsecureSkipVerify {
t.Error("InsecureSkipVerify should be true when both options are set")
}
if client.TLSConfig.RootCAs == nil {
t.Error("RootCAs should be set when CACertPEM is provided")
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,122 @@
package utils
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestSetJSONField(t *testing.T) {
tests := []struct {
name string
data []byte
path string
value interface{}
expected string
}{
{
name: "set_string_on_empty_object",
data: []byte(`{}`),
path: "name",
value: "test",
expected: `{"name":"test"}`,
},
{
name: "set_nested_path",
data: []byte(`{}`),
path: "file.displayName",
value: "photo.jpg",
expected: `{"file":{"displayName":"photo.jpg"}}`,
},
{
name: "set_boolean",
data: []byte(`{"model":"x"}`),
path: "stream",
value: true,
expected: `{"model":"x","stream":true}`,
},
{
name: "set_string_array",
data: []byte(`{}`),
path: "betas",
value: []string{"a", "b"},
expected: `{"betas":["a","b"]}`,
},
{
name: "preserves_existing_fields",
data: []byte(`{"a":1,"b":2}`),
path: "c",
value: 3,
expected: `{"a":1,"b":2,"c":3}`,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := SetJSONField(tt.data, tt.path, tt.value)
require.NoError(t, err)
assert.Equal(t, tt.expected, string(result), "exact byte-level ordering must match")
})
}
}
func TestDeleteJSONField(t *testing.T) {
tests := []struct {
name string
data []byte
path string
expected string
wantErr bool
}{
{
name: "delete_existing_field",
data: []byte(`{"a":1,"b":2}`),
path: "a",
expected: `{"b":2}`,
},
{
name: "delete_nonexistent_field",
data: []byte(`{"a":1}`),
path: "b",
expected: `{"a":1}`,
},
{
name: "sequential_deletes",
data: []byte(`{"a":1,"b":2,"c":3}`),
path: "", // handled in validate
expected: `{}`,
},
{
name: "preserves_remaining_order",
data: []byte(`{"x":1,"y":2,"z":3}`),
path: "y",
expected: `{"x":1,"z":3}`,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.name == "sequential_deletes" {
data := []byte(`{"a":1,"b":2,"c":3}`)
var err error
data, err = DeleteJSONField(data, "a")
require.NoError(t, err)
data, err = DeleteJSONField(data, "b")
require.NoError(t, err)
data, err = DeleteJSONField(data, "c")
require.NoError(t, err)
assert.Equal(t, tt.expected, string(data), "exact byte-level ordering must match")
return
}
result, err := DeleteJSONField(tt.data, tt.path)
if tt.wantErr {
require.Error(t, err)
return
}
require.NoError(t, err)
assert.Equal(t, tt.expected, string(result), "exact byte-level ordering must match")
})
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,34 @@
package utils
import (
"net/url"
"strings"
schemas "github.com/maximhq/bifrost/core/schemas"
)
// StripVideoIDProviderSuffix removes ":<provider>" from a video ID if present.
func StripVideoIDProviderSuffix(videoID string, provider schemas.ModelProvider) string {
suffix := ":" + string(provider)
stripped := strings.TrimSuffix(videoID, suffix)
// URL decode the ID to restore original characters (e.g., %2F -> /)
if decoded, err := url.PathUnescape(stripped); err == nil {
return decoded
}
return stripped
}
// AddVideoIDProviderSuffix ensures a video ID is scoped as "<id>:<provider>".
func AddVideoIDProviderSuffix(videoID string, provider schemas.ModelProvider) string {
if videoID == "" {
return videoID
}
suffix := ":" + string(provider)
if strings.HasSuffix(videoID, suffix) {
return videoID
}
// URL-encode the video ID to make it safe for URL paths
// This converts / to %2F and other special characters
escapedVideoID := url.PathEscape(videoID)
return escapedVideoID + suffix
}