Files
bifrost/transports/bifrost-http/lib/streamreader_test.go
Beyhan Oğur 880f412e2c first commit
2026-04-26 21:52:23 +03:00

715 lines
17 KiB
Go

package lib
import (
"fmt"
"io"
"sync"
"testing"
)
// TestSSEStreamReaderSingleEventPerRead checks that every Read call yields
// exactly one previously sent event, in send order, and that Read reports
// io.EOF once the producer has called Done.
func TestSSEStreamReaderSingleEventPerRead(t *testing.T) {
	reader := NewSSEStreamReader()
	payloads := [][]byte{
		[]byte("data: {\"chunk\":1}\n\n"),
		[]byte("data: {\"chunk\":2}\n\n"),
		[]byte("data: {\"chunk\":3}\n\n"),
	}

	// Buffered so the producer can record a failure without blocking.
	sendErr := make(chan error, 1)
	go func() {
		for idx := range payloads {
			if reader.Send(payloads[idx]) {
				continue
			}
			// A rejected Send aborts the producer without calling Done.
			select {
			case sendErr <- fmt.Errorf("Send returned false unexpectedly"):
			default:
			}
			return
		}
		reader.Done()
	}()

	scratch := make([]byte, 4096)
	for idx := range payloads {
		n, err := reader.Read(scratch)
		if err != nil {
			t.Fatalf("event %d: unexpected error: %v", idx, err)
		}
		if got := string(scratch[:n]); got != string(payloads[idx]) {
			t.Errorf("event %d: got %q, want %q", idx, got, payloads[idx])
		}
	}

	// With every event consumed, the following Read must signal end of stream.
	if n, err := reader.Read(scratch); err != io.EOF {
		t.Errorf("expected io.EOF, got err=%v n=%d", err, n)
	}

	// Surface a producer-side failure, if one was recorded.
	select {
	case err := <-sendErr:
		t.Error(err)
	default:
	}
}
// TestSSEStreamReaderPartialRead drains a single event through a 5-byte
// buffer and checks that the concatenated pieces equal the original event.
func TestSSEStreamReaderPartialRead(t *testing.T) {
	reader := NewSSEStreamReader()
	payload := []byte("data: {\"content\":\"hello world\"}\n\n")
	go func() {
		reader.Send(payload)
		reader.Done()
	}()

	// The buffer is deliberately smaller than the event so that several
	// Read calls are needed.
	chunk := make([]byte, 5)
	var (
		assembled []byte
		err       error
	)
	for err == nil {
		var n int
		n, err = reader.Read(chunk)
		// Keep whatever was delivered, even when an error accompanies it.
		assembled = append(assembled, chunk[:n]...)
	}
	if err != io.EOF {
		t.Fatalf("unexpected error: %v", err)
	}

	if string(assembled) != string(payload) {
		t.Errorf("reassembled data: got %q, want %q", assembled, payload)
	}
}
// TestSSEStreamReaderEOFOnDone checks that Read returns io.EOF when Done was
// called before any event was sent.
func TestSSEStreamReaderEOFOnDone(t *testing.T) {
	reader := NewSSEStreamReader()
	reader.Done() // end the stream before anything is produced

	n, err := reader.Read(make([]byte, 4096))
	if err == io.EOF {
		return
	}
	t.Errorf("expected io.EOF, got err=%v n=%d", err, n)
}
// TestSSEStreamReaderCloseSignalsProducer checks that Send reports false
// once the reader has been closed.
func TestSSEStreamReaderCloseSignalsProducer(t *testing.T) {
	reader := NewSSEStreamReader()
	reader.Close()

	accepted := reader.Send([]byte("data: test\n\n"))
	if accepted {
		t.Error("Send should return false after Close")
	}
}
// TestSSEStreamReaderIdempotentClose calls Close three times in a row; the
// test passes as long as none of the calls panics.
func TestSSEStreamReaderIdempotentClose(t *testing.T) {
	reader := NewSSEStreamReader()
	for attempt := 0; attempt < 3; attempt++ {
		reader.Close()
	}
}
// TestSSEStreamReaderConcurrent runs one producer and one consumer goroutine
// against the same reader and checks that the consumer sees exactly
// numEvents successful reads before io.EOF.
func TestSSEStreamReaderConcurrent(t *testing.T) {
	const numEvents = 100
	reader := NewSSEStreamReader()

	// Producer: stops early if the reader rejects a Send, then ends the stream.
	go func() {
		for sent := 0; sent < numEvents; sent++ {
			if !reader.Send([]byte("data: event\n\n")) {
				break
			}
		}
		reader.Done()
	}()

	// Room for both possible consumer failures; report never blocks.
	failures := make(chan error, 2)
	report := func(err error) {
		select {
		case failures <- err:
		default:
		}
	}

	var consumer sync.WaitGroup
	consumer.Add(1)
	go func() {
		defer consumer.Done()
		scratch := make([]byte, 4096)
		received := 0
		for {
			_, err := reader.Read(scratch)
			if err != nil {
				if err != io.EOF {
					report(fmt.Errorf("unexpected error: %v", err))
				}
				break
			}
			received++
		}
		if received != numEvents {
			report(fmt.Errorf("got %d events, want %d", received, numEvents))
		}
	}()

	consumer.Wait()
	// The consumer has finished, so no further sends can happen on failures.
	close(failures)
	for err := range failures {
		t.Error(err)
	}
}
// TestSSEStreamReaderSendEvent checks the bytes SendEvent produces both with
// and without an event type.
func TestSSEStreamReaderSendEvent(t *testing.T) {
	cases := []struct {
		label     string
		eventType string
		payload   []byte
		expected  string
	}{
		{
			label:     "data only",
			eventType: "",
			payload:   []byte(`{"chunk":1}`),
			expected:  "data: {\"chunk\":1}\n\n",
		},
		{
			label:     "with event type",
			eventType: "response.delta",
			payload:   []byte(`{"delta":"hi"}`),
			expected:  "event: response.delta\ndata: {\"delta\":\"hi\"}\n\n",
		},
	}

	for _, tc := range cases {
		t.Run(tc.label, func(t *testing.T) {
			reader := NewSSEStreamReader()
			// Inputs are passed as arguments so the goroutine owns its copies.
			go func(eventType string, payload []byte) {
				reader.SendEvent(eventType, payload)
				reader.Done()
			}(tc.eventType, tc.payload)

			scratch := make([]byte, 4096)
			n, err := reader.Read(scratch)
			if err != nil {
				t.Fatalf("unexpected error: %v", err)
			}
			got := string(scratch[:n])
			if got != tc.expected {
				t.Errorf("got %q, want %q", got, tc.expected)
			}
		})
	}
}
// TestSSEStreamReaderSendError checks that SendError frames its payload as
// an "error" event.
func TestSSEStreamReaderSendError(t *testing.T) {
	const want = "event: error\ndata: {\"error\":\"bad\"}\n\n"

	reader := NewSSEStreamReader()
	go func() {
		reader.SendError([]byte(`{"error":"bad"}`))
		reader.Done()
	}()

	scratch := make([]byte, 4096)
	n, err := reader.Read(scratch)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	got := string(scratch[:n])
	if got != want {
		t.Errorf("got %q, want %q", got, want)
	}
}
// TestSSEStreamReaderSendDone checks that SendDone emits the "[DONE]" data
// line.
func TestSSEStreamReaderSendDone(t *testing.T) {
	const want = "data: [DONE]\n\n"

	reader := NewSSEStreamReader()
	go func() {
		reader.SendDone()
		reader.Done()
	}()

	scratch := make([]byte, 4096)
	n, err := reader.Read(scratch)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	got := string(scratch[:n])
	if got != want {
		t.Errorf("got %q, want %q", got, want)
	}
}
// TestSSEStreamReaderSendEventAfterClose checks that SendEvent, SendError and
// SendDone all report false once the reader has been closed.
func TestSSEStreamReaderSendEventAfterClose(t *testing.T) {
	reader := NewSSEStreamReader()
	reader.Close()

	// Each attempt is made in order; a true result is a failure.
	attempts := []struct {
		failure string
		send    func() bool
	}{
		{
			failure: "SendEvent should return false after Close",
			send:    func() bool { return reader.SendEvent("test", []byte("data")) },
		},
		{
			failure: "SendError should return false after Close",
			send:    func() bool { return reader.SendError([]byte("err")) },
		},
		{
			failure: "SendDone should return false after Close",
			send:    func() bool { return reader.SendDone() },
		},
	}
	for _, attempt := range attempts {
		if attempt.send() {
			t.Error(attempt.failure)
		}
	}
}
// TestSSEStreamReaderSendEventByteAccuracy compares the bytes produced by
// SendEvent against frames assembled by hand with append, the way the
// handlers' earlier manual buffer assembly is described as having done.
func TestSSEStreamReaderSendEventByteAccuracy(t *testing.T) {
	chatChunk := []byte(`{"id":"chatcmpl-123","choices":[{"delta":{"content":"Hello"}}]}`)
	itemAdded := []byte(`{"type":"response.output_item.added","item":{"id":"item_1"}}`)
	interrupted := []byte(`{"error":"stream interrupted"}`)

	// assemble builds a frame by hand: prefix, payload, then the blank-line
	// terminator.
	assemble := func(prefix string, payload []byte) []byte {
		frame := make([]byte, 0, len(prefix)+len(payload)+2)
		frame = append(frame, prefix...)
		frame = append(frame, payload...)
		return append(frame, '\n', '\n')
	}

	cases := []struct {
		label     string
		eventType string
		payload   []byte
		expected  []byte
	}{
		{
			label:     "standard SSE data (old inference.go pattern)",
			eventType: "",
			payload:   chatChunk,
			expected:  assemble("data: ", chatChunk),
		},
		{
			label:     "OpenAI responses format with event type (old inference.go pattern)",
			eventType: "response.output_item.added",
			payload:   itemAdded,
			expected:  assemble("event: "+"response.output_item.added"+"\ndata: ", itemAdded),
		},
		{
			label:     "error event (old interceptor pattern)",
			eventType: "error",
			payload:   interrupted,
			expected:  assemble("event: error\ndata: ", interrupted),
		},
	}

	for _, tc := range cases {
		t.Run(tc.label, func(t *testing.T) {
			reader := NewSSEStreamReader()
			go func(eventType string, payload []byte) {
				reader.SendEvent(eventType, payload)
				reader.Done()
			}(tc.eventType, tc.payload)

			scratch := make([]byte, 4096)
			n, err := reader.Read(scratch)
			if err != nil {
				t.Fatalf("unexpected error: %v", err)
			}
			if got := scratch[:n]; string(got) != string(tc.expected) {
				t.Errorf("byte mismatch:\n got: %q\n want: %q", got, tc.expected)
			}
		})
	}
}
// TestSSEStreamReaderSendErrorByteAccuracy checks that SendError emits
// "event: error\ndata: " followed by the payload and a blank line, with no
// other bytes.
func TestSSEStreamReaderSendErrorByteAccuracy(t *testing.T) {
	payload := []byte(`{"error":{"type":"internal_error","message":"An error occurred"}}`)
	// Equivalent to appending the prefix, the payload and two newlines by hand.
	expected := "event: error\ndata: " + string(payload) + "\n\n"

	reader := NewSSEStreamReader()
	go func() {
		reader.SendError(payload)
		reader.Done()
	}()

	scratch := make([]byte, 4096)
	n, err := reader.Read(scratch)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	got := string(scratch[:n])
	if got != expected {
		t.Errorf("byte mismatch:\n got: %q\n want: %q", got, expected)
	}
}
// TestSSEStreamReaderMixedMethodStream sends plain data events, a typed
// event, an error event and the done marker through one reader and checks
// that they are read back in order, followed by io.EOF.
func TestSSEStreamReaderMixedMethodStream(t *testing.T) {
	reader := NewSSEStreamReader()

	go func() {
		reader.SendEvent("", []byte(`{"chunk":1}`))
		reader.SendEvent("response.delta", []byte(`{"delta":"hi"}`))
		reader.SendEvent("", []byte(`{"chunk":2}`))
		reader.SendError([]byte(`{"error":"timeout"}`))
		reader.SendDone()
		reader.Done()
	}()

	// One entry per producer call above, in the same order.
	frames := []string{
		"data: {\"chunk\":1}\n\n",
		"event: response.delta\ndata: {\"delta\":\"hi\"}\n\n",
		"data: {\"chunk\":2}\n\n",
		"event: error\ndata: {\"error\":\"timeout\"}\n\n",
		"data: [DONE]\n\n",
	}

	scratch := make([]byte, 4096)
	for idx := range frames {
		n, err := reader.Read(scratch)
		if err != nil {
			t.Fatalf("event %d: unexpected error: %v", idx, err)
		}
		got := string(scratch[:n])
		if got != frames[idx] {
			t.Errorf("event %d:\n got: %q\n want: %q", idx, got, frames[idx])
		}
	}

	// Nothing is left, so the stream must now be at its end.
	if n, err := reader.Read(scratch); err != io.EOF {
		t.Errorf("expected EOF, got err=%v n=%d", err, n)
	}
}
// TestSSEStreamReaderRawAndWrapperMixed interleaves raw Send calls (arbitrary
// bytes and an already-framed SSE string) with the SendEvent wrapper and
// checks that each item is read back unchanged and in order.
func TestSSEStreamReaderRawAndWrapperMixed(t *testing.T) {
	reader := NewSSEStreamReader()

	rawBinary := []byte{0x00, 0x00, 0x00, 0x42, 0x00, 0x00, 0x00, 0x2A} // fake binary frame
	framedSSE := []byte("event: content_block_delta\ndata: {\"delta\":\"test\"}\n\n")

	go func() {
		reader.Send(rawBinary)                         // bytes passed through as-is
		reader.Send(framedSSE)                         // caller-formatted SSE text
		reader.SendEvent("", []byte(`{"final":true}`)) // framed by the wrapper
		reader.Done()
	}()

	frames := [][]byte{
		rawBinary,
		framedSSE,
		[]byte("data: {\"final\":true}\n\n"),
	}

	scratch := make([]byte, 4096)
	for idx := range frames {
		n, err := reader.Read(scratch)
		if err != nil {
			t.Fatalf("event %d: unexpected error: %v", idx, err)
		}
		got := scratch[:n]
		if string(got) != string(frames[idx]) {
			t.Errorf("event %d:\n got: %q\n want: %q", idx, got, frames[idx])
		}
	}
}
// TestSSEStreamReaderSendEventEmptyData checks the frame SendEvent produces
// for a zero-length payload, with and without an event type.
func TestSSEStreamReaderSendEventEmptyData(t *testing.T) {
	cases := []struct {
		label     string
		eventType string
		expected  string
	}{
		{label: "empty data no event type", eventType: "", expected: "data: \n\n"},
		{label: "empty data with event type", eventType: "heartbeat", expected: "event: heartbeat\ndata: \n\n"},
	}

	for _, tc := range cases {
		t.Run(tc.label, func(t *testing.T) {
			reader := NewSSEStreamReader()
			go func(eventType string) {
				// Non-nil but empty payload.
				reader.SendEvent(eventType, []byte{})
				reader.Done()
			}(tc.eventType)

			scratch := make([]byte, 4096)
			n, err := reader.Read(scratch)
			if err != nil {
				t.Fatalf("unexpected error: %v", err)
			}
			got := string(scratch[:n])
			if got != tc.expected {
				t.Errorf("got %q, want %q", got, tc.expected)
			}
		})
	}
}
// TestSSEStreamReaderSendEventNilData checks that a nil payload is framed
// the same way as an empty one: "data: \n\n".
func TestSSEStreamReaderSendEventNilData(t *testing.T) {
	const want = "data: \n\n"

	reader := NewSSEStreamReader()
	go func() {
		reader.SendEvent("", nil)
		reader.Done()
	}()

	scratch := make([]byte, 4096)
	n, err := reader.Read(scratch)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	got := string(scratch[:n])
	if got != want {
		t.Errorf("got %q, want %q", got, want)
	}
}
// TestSSEStreamReaderSendEventLargePayload sends one event carrying a 64 KiB
// body and reads it back through a 1 KiB buffer, checking that the
// reassembled bytes match the expected frame exactly.
func TestSSEStreamReaderSendEventLargePayload(t *testing.T) {
	const (
		bodySize = 65536
		opening  = `{"content":"`
	)

	// JSON object whose string value cycles through 'A'..'Z'.
	payload := make([]byte, 0, len(opening)+bodySize+2)
	payload = append(payload, opening...)
	for i := 0; i < bodySize; i++ {
		payload = append(payload, 'A'+byte(i%26))
	}
	payload = append(payload, '"', '}')

	reader := NewSSEStreamReader()
	go func() {
		reader.SendEvent("response.delta", payload)
		reader.Done()
	}()

	// The small buffer forces the event to be delivered over many reads.
	chunk := make([]byte, 1024)
	var (
		assembled []byte
		err       error
	)
	for err == nil {
		var n int
		n, err = reader.Read(chunk)
		assembled = append(assembled, chunk[:n]...)
	}
	if err != io.EOF {
		t.Fatalf("unexpected error: %v", err)
	}

	expected := "event: response.delta\ndata: " + string(payload) + "\n\n"
	if string(assembled) == expected {
		return
	}
	t.Errorf("large payload mismatch: got len=%d, want len=%d", len(assembled), len(expected))
	// Show the leading bytes of both sides to help locate the divergence.
	if len(assembled) > 40 {
		t.Errorf(" got prefix: %q", assembled[:40])
		t.Errorf(" want prefix: %q", expected[:40])
	}
}
// TestSSEStreamReaderMidStreamDisconnect simulates a client disconnecting
// mid-stream while the producer is using SendEvent: the consumer reads three
// events, calls Close, and then waits for the producer goroutine to exit.
func TestSSEStreamReaderMidStreamDisconnect(t *testing.T) {
	r := NewSSEStreamReader()

	// Pure completion signal: closed when the producer goroutine returns.
	producerDone := make(chan struct{})
	go func() {
		defer close(producerDone)
		for i := 0; i < 100; i++ {
			// A false result means the reader no longer accepts events.
			if !r.SendEvent("", []byte(fmt.Sprintf(`{"chunk":%d}`, i))) {
				return
			}
		}
	}()

	// Read a few events, then simulate the client disconnecting.
	buf := make([]byte, 4096)
	for i := 0; i < 3; i++ {
		if _, err := r.Read(buf); err != nil {
			t.Fatalf("event %d: unexpected error: %v", i, err)
		}
	}

	// Client disconnects.
	r.Close()

	// The producer must terminate after Close. If Close does not release a
	// blocked SendEvent, this receive blocks and the test hangs until the
	// go test timeout instead of failing with a message.
	<-producerDone
}
// TestSSEStreamReaderSendErrorThenDone covers a stream that emits two data
// events, then an error event, and is ended with Done without a "[DONE]"
// marker; the reader must deliver all three frames and then io.EOF.
func TestSSEStreamReaderSendErrorThenDone(t *testing.T) {
	reader := NewSSEStreamReader()

	go func() {
		// Normal traffic first.
		reader.SendEvent("", []byte(`{"chunk":1}`))
		reader.SendEvent("", []byte(`{"chunk":2}`))
		// Then the failure, immediately followed by end of stream.
		reader.SendError([]byte(`{"error":"rate_limit"}`))
		reader.Done()
	}()

	frames := []string{
		"data: {\"chunk\":1}\n\n",
		"data: {\"chunk\":2}\n\n",
		"event: error\ndata: {\"error\":\"rate_limit\"}\n\n",
	}

	scratch := make([]byte, 4096)
	for idx := range frames {
		n, err := reader.Read(scratch)
		if err != nil {
			t.Fatalf("event %d: unexpected error: %v", idx, err)
		}
		got := string(scratch[:n])
		if got != frames[idx] {
			t.Errorf("event %d: got %q, want %q", idx, got, frames[idx])
		}
	}

	// No "[DONE]" frame was sent, so the next Read must be end of stream.
	if n, err := reader.Read(scratch); err != io.EOF {
		t.Errorf("expected EOF after error event, got err=%v n=%d data=%q", err, n, scratch[:n])
	}
}
// TestSSEStreamReaderSendDoneByteExact reads the SendDone frame into a
// buffer of exactly len("data: [DONE]\n\n") bytes and checks both the byte
// count and the content.
func TestSSEStreamReaderSendDoneByteExact(t *testing.T) {
	expected := []byte("data: [DONE]\n\n")

	reader := NewSSEStreamReader()
	go func() {
		reader.SendDone()
		reader.Done()
	}()

	// The buffer has no spare room, so the frame must fit it exactly.
	exact := make([]byte, len(expected))
	n, err := reader.Read(exact)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if n != len(expected) {
		t.Errorf("expected %d bytes, got %d", len(expected), n)
	}
	if got := exact[:n]; string(got) != string(expected) {
		t.Errorf("got %q, want %q", got, expected)
	}
}
// TestSSEStreamReaderConcurrentSendEvent runs several producers calling
// SendEvent at once and checks that every read yields one well-formed
// "data: ...\n\n" frame and that the total frame count is correct.
func TestSSEStreamReaderConcurrentSendEvent(t *testing.T) {
	const (
		numProducers      = 5
		eventsPerProducer = 20
	)
	reader := NewSSEStreamReader()

	var producers sync.WaitGroup
	producers.Add(numProducers)
	produce := func(id int) {
		defer producers.Done()
		for seq := 0; seq < eventsPerProducer; seq++ {
			body := []byte(fmt.Sprintf(`{"p":%d,"i":%d}`, id, seq))
			if !reader.SendEvent("", body) {
				return
			}
		}
	}
	for id := 0; id < numProducers; id++ {
		go produce(id)
	}

	// End the stream only after every producer has returned.
	go func() {
		producers.Wait()
		reader.Done()
	}()

	scratch := make([]byte, 4096)
	received := 0
	for {
		n, err := reader.Read(scratch)
		if err == io.EOF {
			break
		}
		if err != nil {
			t.Fatalf("unexpected error: %v", err)
		}

		// A frame needs the "data: " prefix, the blank-line terminator and
		// at least 8 bytes overall.
		frame := string(scratch[:n])
		wellFormed := len(frame) >= 8 &&
			frame[:6] == "data: " &&
			frame[len(frame)-2:] == "\n\n"
		if !wellFormed {
			t.Errorf("event %d: invalid SSE format: %q", received, frame)
		}
		received++
	}

	if want := numProducers * eventsPerProducer; received != want {
		t.Errorf("got %d events, want %d", received, want)
	}
}
// TestSSEStreamReaderCloseUnblocksProducer starts a producer that issues
// three Sends with no consumer reading, then calls Close and waits for the
// producer to return. It fails if the third Send reports true.
//
// NOTE(review): nothing orders Close relative to the producer's first two
// Sends, so Close may run before either of them; only the third Send's
// result is asserted.
func TestSSEStreamReaderCloseUnblocksProducer(t *testing.T) {
	reader := NewSSEStreamReader()

	finished := make(chan struct{})
	// Buffered so the producer can record a failure without blocking.
	failure := make(chan error, 1)
	go func() {
		defer close(finished)

		// Intended to fill the reader's internal buffer (assumed to hold
		// one event; confirm against the implementation).
		reader.Send([]byte("data: first\n\n"))
		// Intended to block until Close is called.
		reader.Send([]byte("data: second\n\n"))

		// By now the reader is expected to be closed, so this must be rejected.
		if !reader.Send([]byte("data: third\n\n")) {
			return
		}
		select {
		case failure <- fmt.Errorf("Send should return false after Close"):
		default:
		}
	}()

	// Close is what releases a producer blocked in Send.
	reader.Close()
	<-finished

	select {
	case err := <-failure:
		t.Error(err)
	default:
	}
}