Files
bifrost/core/internal/llmtests/text_completion_stream.go
Beyhan Oğur 880f412e2c first commit
2026-04-26 21:52:23 +03:00

490 lines
15 KiB
Go

package llmtests
import (
"context"
"fmt"
"os"
"strings"
"testing"
"time"
bifrost "github.com/maximhq/bifrost/core"
"github.com/maximhq/bifrost/core/schemas"
)
// RunTextCompletionStreamTest executes the text completion streaming test scenario.
//
// If testConfig.Scenarios.TextCompletionStream is false the function logs that
// the scenario is unsupported for the provider and returns without running
// anything. Otherwise it runs three groups of subtests against client:
//
//   - TextCompletionStream: streams a short story, accumulates the text of all
//     chunks, and validates the consolidated result (content, keywords and
//     latency on the last chunk).
//   - TextCompletionStreamVariedPrompts: streams three different prompts and
//     requires at least one chunk for each.
//   - TextCompletionStreamParameters: streams one prompt under three
//     temperature / max-token / top-p combinations and requires at least one
//     chunk for each.
//
// Every subtest calls t.Parallel() unless the SKIP_PARALLEL_TESTS environment
// variable is set to "true". The model used is testConfig.TextModel, falling
// back to testConfig.ChatModel when TextModel is empty.
func RunTextCompletionStreamTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
	if !testConfig.Scenarios.TextCompletionStream {
		t.Logf("Text completion stream not supported for provider %s", testConfig.Provider)
		return
	}

	t.Run("TextCompletionStream", func(t *testing.T) {
		if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
			t.Parallel()
		}

		// Create a text completion prompt
		prompt := "Write a short story about a robot learning to paint. Keep it under 150 words."
		input := &schemas.TextCompletionInput{
			PromptStr: &prompt,
		}

		// Use TextModel if available, otherwise fall back to ChatModel
		model := testConfig.TextModel
		if model == "" {
			model = testConfig.ChatModel
		}

		request := &schemas.BifrostTextCompletionRequest{
			Provider: testConfig.Provider,
			Model:    model,
			Input:    input,
			Params: &schemas.TextCompletionParameters{
				MaxTokens: bifrost.Ptr(150),
			},
			Fallbacks: testConfig.TextCompletionFallbacks,
		}

		// Use retry framework for stream requests
		retryConfig := StreamingRetryConfig()
		retryContext := TestRetryContext{
			ScenarioName: "TextCompletionStream",
			ExpectedBehavior: map[string]interface{}{
				"should_stream_content": true,
				"should_tell_story":     true,
				"topic":                 "robot painting",
			},
			TestMetadata: map[string]interface{}{
				"provider": testConfig.Provider,
				"model":    model,
			},
		}

		// Use proper streaming retry wrapper for the stream request
		responseChannel, err := WithStreamRetry(t, retryConfig, retryContext, func() (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
			bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
			return client.TextCompletionStreamRequest(bfCtx, request)
		})

		// Enhanced error handling
		RequireNoError(t, err, "Text completion stream request failed")
		if responseChannel == nil {
			t.Fatal("Response channel should not be nil")
		}

		var fullContent strings.Builder
		var responseCount int
		var lastResponse *schemas.BifrostStreamChunk

		// Create a timeout context for the stream reading. It bounds only the
		// read loop below; the request itself was issued with ctx.
		streamCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
		defer cancel()

		t.Logf("📡 Starting to read text completion streaming response...")

		// Read streaming responses until the channel is closed.
	readLoop:
		for {
			select {
			case response, ok := <-responseChannel:
				if !ok {
					// Channel closed, streaming completed
					t.Logf("✅ Text completion streaming completed. Total chunks received: %d", responseCount)
					break readLoop
				}
				if response == nil {
					t.Fatal("Streaming response should not be nil")
				}

				// Keep a copy of the most recent chunk for the final validation.
				lastResponse = DeepCopyBifrostStreamChunk(response)

				// Basic validation of streaming response structure. Structural
				// problems are logged as warnings rather than failing the test.
				if response.BifrostTextCompletionResponse != nil {
					if response.BifrostTextCompletionResponse.ExtraFields.Provider != testConfig.Provider {
						t.Logf("⚠️ Warning: Provider mismatch - expected %s, got %s", testConfig.Provider, response.BifrostTextCompletionResponse.ExtraFields.Provider)
					}
					if response.BifrostTextCompletionResponse.ID == "" {
						t.Logf("⚠️ Warning: Response ID is empty")
					}

					// Log latency for each chunk (can be 0 for inter-chunks)
					t.Logf("📊 Chunk %d latency: %d ms", responseCount+1, response.BifrostTextCompletionResponse.ExtraFields.Latency)

					// Validate text completion response structure
					if response.BifrostTextCompletionResponse.Choices == nil {
						t.Logf("⚠️ Warning: Choices should not be nil in text completion streaming")
					}

					// Process each choice in the response (similar to chat completion)
					for _, choice := range response.BifrostTextCompletionResponse.Choices {
						// For text completion, we expect either streaming deltas or text completion choices
						if choice.TextCompletionResponseChoice != nil {
							// Handle direct text completion response choice (converted by providers)
							if choice.TextCompletionResponseChoice.Text != nil {
								fullContent.WriteString(*choice.TextCompletionResponseChoice.Text)
								t.Logf("✍️ Text completion: %s", *choice.TextCompletionResponseChoice.Text)
							}
							// Check finish reason if present
							if choice.FinishReason != nil {
								t.Logf("🏁 Finish reason: %s", *choice.FinishReason)
							}
						} else {
							t.Logf("⚠️ Warning: Choice %d has no text completion or stream response content", choice.Index)
						}
					}
				}

				responseCount++

				// Safety check to prevent infinite loops in case of issues
				if responseCount > 500 {
					t.Fatal("Received too many streaming chunks, something might be wrong")
				}
			case <-streamCtx.Done():
				t.Fatal("Timeout waiting for text completion streaming response")
			}
		}

		// Validate final streaming response
		finalContent := strings.TrimSpace(fullContent.String())

		// Create a consolidated response for validation
		consolidatedResponse := createConsolidatedTextCompletionResponse(finalContent, lastResponse, testConfig.Provider)

		// Enhanced validation expectations for text completion streaming
		expectations := GetExpectationsForScenario("TextCompletionStream", testConfig, map[string]interface{}{})
		expectations = ModifyExpectationsForProvider(expectations, testConfig.Provider)
		expectations.ShouldContainKeywords = append(expectations.ShouldContainKeywords, "robot") // Should include story elements

		// Validate the consolidated text completion streaming response. The
		// result is only acted on after the basic streaming checks below.
		validationResult := ValidateTextCompletionResponse(t, consolidatedResponse, nil, expectations, "TextCompletionStream")

		// Basic streaming validation
		if responseCount == 0 {
			t.Fatal("Should receive at least one streaming response")
		}
		if finalContent == "" {
			t.Fatal("Final content should not be empty")
		}
		if len(finalContent) < 5 {
			t.Fatal("Final content should be substantial")
		}

		// Validate latency is present in the last chunk (total latency)
		if lastResponse != nil && lastResponse.BifrostTextCompletionResponse != nil {
			if lastResponse.BifrostTextCompletionResponse.ExtraFields.Latency <= 0 {
				t.Fatalf("❌ Last streaming chunk missing latency information (got %d ms)", lastResponse.BifrostTextCompletionResponse.ExtraFields.Latency)
			} else {
				t.Logf("✅ Total streaming latency: %d ms", lastResponse.BifrostTextCompletionResponse.ExtraFields.Latency)
			}
		}

		if !validationResult.Passed {
			t.Fatalf("❌ Text completion streaming validation failed: %v", validationResult.Errors)
		}

		t.Logf("📊 Text completion streaming metrics: %d chunks, %d chars", responseCount, len(finalContent))
		t.Logf("✅ Text completion streaming test completed successfully")
		t.Logf("📝 Final content (%d chars): %s", len(finalContent), finalContent)
	})

	// Test text completion streaming with different prompts
	t.Run("TextCompletionStreamVariedPrompts", func(t *testing.T) {
		if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
			t.Parallel()
		}

		// Use TextModel if available, otherwise fall back to ChatModel
		model := testConfig.TextModel
		if model == "" {
			model = testConfig.ChatModel
		}

		testPrompts := []struct {
			name   string
			prompt string
			expect string
		}{
			{
				name:   "SimpleCompletion",
				prompt: "The quick brown fox",
				expect: "completion",
			},
			{
				name:   "Question",
				prompt: "What is artificial intelligence? AI is",
				expect: "definition",
			},
			{
				name:   "CodeCompletion",
				prompt: "def fibonacci(n):\n if n <= 1:",
				expect: "code",
			},
		}

		for _, testCase := range testPrompts {
			// Per-iteration copy: the parallel subtest below resumes only after
			// this loop has finished, and it also takes &testCase.prompt. Without
			// the copy, toolchains older than Go 1.22 would run every subtest
			// against the last table entry.
			testCase := testCase

			t.Run(testCase.name, func(t *testing.T) {
				if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
					t.Parallel()
				}

				input := &schemas.TextCompletionInput{
					PromptStr: &testCase.prompt,
				}

				request := &schemas.BifrostTextCompletionRequest{
					Provider: testConfig.Provider,
					Model:    model,
					Input:    input,
					Params: &schemas.TextCompletionParameters{
						MaxTokens:   bifrost.Ptr(50),
						Temperature: bifrost.Ptr(0.7),
					},
					Fallbacks: testConfig.TextCompletionFallbacks,
				}

				// Use retry framework for stream requests
				retryConfig := StreamingRetryConfig()
				retryContext := TestRetryContext{
					ScenarioName: fmt.Sprintf("TextCompletionStreamVariedPrompts_%s", testCase.name),
					ExpectedBehavior: map[string]interface{}{
						"should_stream_content": true,
						"prompt_type":           testCase.name,
					},
					TestMetadata: map[string]interface{}{
						"provider": testConfig.Provider,
						"model":    model,
					},
				}

				responseChannel, err := WithStreamRetry(t, retryConfig, retryContext, func() (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
					bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
					return client.TextCompletionStreamRequest(bfCtx, request)
				})

				RequireNoError(t, err, "Text completion stream with varied prompts failed")
				if responseChannel == nil {
					t.Fatal("Response channel should not be nil")
				}

				var responseCount int
				var content strings.Builder

				streamCtx, cancel := context.WithTimeout(ctx, 20*time.Second)
				defer cancel()

				t.Logf("Testing text completion streaming with prompt: %s", testCase.name)

			readLoop:
				for {
					select {
					case response, ok := <-responseChannel:
						if !ok {
							break readLoop
						}
						if response == nil {
							t.Fatal("Streaming response should not be nil")
						}
						responseCount++

						// Extract content from choices
						if response.BifrostTextCompletionResponse != nil {
							for _, choice := range response.BifrostTextCompletionResponse.Choices {
								if choice.TextCompletionResponseChoice != nil {
									delta := choice.TextCompletionResponseChoice.Text
									if delta != nil {
										content.WriteString(*delta)
									}
								}
							}
						}

						// Stop reading after 100 chunks; whatever was collected so
						// far is treated as the result. The channel is not drained.
						if responseCount > 100 {
							break readLoop
						}
					case <-streamCtx.Done():
						t.Fatal("Timeout waiting for text completion streaming response")
					}
				}

				finalContent := strings.TrimSpace(content.String())

				if responseCount == 0 {
					t.Fatal("Should receive at least one streaming response")
				}

				// Empty content is only a warning here, not a failure.
				if finalContent == "" {
					t.Logf("⚠️ Warning: No content generated for prompt: %s", testCase.prompt)
				} else {
					t.Logf("✅ Generated content for %s: %s", testCase.name, finalContent)
				}
			})
		}
	})

	// Test text completion streaming with different parameters
	t.Run("TextCompletionStreamParameters", func(t *testing.T) {
		if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
			t.Parallel()
		}

		// Use TextModel if available, otherwise fall back to ChatModel
		model := testConfig.TextModel
		if model == "" {
			model = testConfig.ChatModel
		}

		prompt := "Once upon a time in a distant galaxy"

		parameterTests := []struct {
			name        string
			temperature *float64
			maxTokens   *int
			topP        *float64
		}{
			{
				name:        "HighCreativity",
				temperature: bifrost.Ptr(0.9),
				maxTokens:   bifrost.Ptr(100),
				topP:        bifrost.Ptr(0.9),
			},
			{
				name:        "LowCreativity",
				temperature: bifrost.Ptr(0.1),
				maxTokens:   bifrost.Ptr(50),
				topP:        bifrost.Ptr(0.5),
			},
			{
				name:        "Balanced",
				temperature: bifrost.Ptr(0.5),
				maxTokens:   bifrost.Ptr(75),
				topP:        bifrost.Ptr(0.8),
			},
		}

		for _, paramTest := range parameterTests {
			// Per-iteration copy for the parallel subtest closure below; needed
			// on toolchains older than Go 1.22, where the range variable is
			// shared across iterations.
			paramTest := paramTest

			t.Run(paramTest.name, func(t *testing.T) {
				if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
					t.Parallel()
				}

				input := &schemas.TextCompletionInput{
					PromptStr: &prompt,
				}

				request := &schemas.BifrostTextCompletionRequest{
					Provider: testConfig.Provider,
					Model:    model,
					Input:    input,
					Params: &schemas.TextCompletionParameters{
						MaxTokens:   paramTest.maxTokens,
						Temperature: paramTest.temperature,
						TopP:        paramTest.topP,
					},
					Fallbacks: testConfig.TextCompletionFallbacks,
				}

				// Use retry framework for stream requests
				retryConfig := StreamingRetryConfig()
				retryContext := TestRetryContext{
					ScenarioName: fmt.Sprintf("TextCompletionStreamParameters_%s", paramTest.name),
					ExpectedBehavior: map[string]interface{}{
						"should_stream_content": true,
						"parameter_test":        paramTest.name,
					},
					TestMetadata: map[string]interface{}{
						"provider": testConfig.Provider,
						"model":    model,
					},
				}

				responseChannel, err := WithStreamRetry(t, retryConfig, retryContext, func() (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
					bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
					return client.TextCompletionStreamRequest(bfCtx, request)
				})

				RequireNoError(t, err, "Text completion stream with parameters failed")
				if responseChannel == nil {
					t.Fatal("Response channel should not be nil")
				}

				var responseCount int

				streamCtx, cancel := context.WithTimeout(ctx, 20*time.Second)
				defer cancel()

				t.Logf("🔧 Testing text completion streaming with parameters: %s", paramTest.name)

			readLoop:
				for {
					select {
					case response, ok := <-responseChannel:
						if !ok {
							break readLoop
						}
						// Nil chunks are tolerated here and simply not counted.
						if response != nil {
							responseCount++
						}
						// Stop reading after 150 chunks; the channel is not drained.
						if responseCount > 150 {
							break readLoop
						}
					case <-streamCtx.Done():
						t.Fatal("Timeout waiting for text completion streaming response")
					}
				}

				if responseCount == 0 {
					t.Fatal("Should receive at least one streaming response")
				}

				t.Logf("✅ Parameter test %s completed with %d chunks", paramTest.name, responseCount)
			})
		}
	})
}
// createConsolidatedTextCompletionResponse builds a single text completion
// response out of a finished stream so it can be run through validation.
//
// The result always has Object "text_completion" and exactly one choice
// (index 0) whose text is finalContent. When lastResponse carries a text
// completion payload, its Usage, Model, ID and ExtraFields are copied over
// (ExtraFields replaces the provider/request-type defaults set here), and the
// finish reason of its first choice is adopted if one is set. Otherwise the
// ExtraFields contain only provider and the text completion request type.
func createConsolidatedTextCompletionResponse(finalContent string, lastResponse *schemas.BifrostStreamChunk, provider schemas.ModelProvider) *schemas.BifrostTextCompletionResponse {
	merged := &schemas.BifrostTextCompletionResponse{
		Object: "text_completion",
		ExtraFields: schemas.BifrostResponseExtraFields{
			Provider:    provider,
			RequestType: schemas.TextCompletionRequest,
		},
		Choices: []schemas.BifrostResponseChoice{
			{
				Index: 0,
				TextCompletionResponseChoice: &schemas.TextCompletionResponseChoice{
					Text: &finalContent,
				},
			},
		},
	}

	// Nothing to merge in when the stream produced no usable final chunk.
	if lastResponse == nil || lastResponse.BifrostTextCompletionResponse == nil {
		return merged
	}

	tail := lastResponse.BifrostTextCompletionResponse

	// Carry over identifying metadata and usage from the final chunk.
	merged.ID = tail.ID
	merged.Model = tail.Model
	merged.Usage = tail.Usage
	merged.ExtraFields = tail.ExtraFields

	// Only the first choice of the final chunk is consulted for a finish reason.
	if len(tail.Choices) > 0 {
		if reason := tail.Choices[0].FinishReason; reason != nil {
			merged.Choices[0].FinishReason = reason
		}
	}

	return merged
}