first commit
This commit is contained in:
300
core/internal/llmtests/image_generation.go
Normal file
300
core/internal/llmtests/image_generation.go
Normal file
@@ -0,0 +1,300 @@
|
||||
package llmtests
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"image"
|
||||
_ "image/jpeg"
|
||||
_ "image/png"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
bifrost "github.com/maximhq/bifrost/core"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// RunImageGenerationTest executes the end-to-end image generation test (non-streaming)
|
||||
func RunImageGenerationTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
|
||||
if !testConfig.Scenarios.ImageGeneration {
|
||||
t.Logf("Image generation not supported for provider %s", testConfig.Provider)
|
||||
return
|
||||
}
|
||||
|
||||
if testConfig.ImageGenerationModel == "" {
|
||||
t.Logf("Image generation not configured for provider %s", testConfig.Provider)
|
||||
return
|
||||
}
|
||||
|
||||
t.Run("ImageGeneration", func(t *testing.T) {
|
||||
if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
|
||||
t.Parallel()
|
||||
}
|
||||
|
||||
retryConfig := GetTestRetryConfigForScenario("ImageGeneration", testConfig)
|
||||
retryContext := TestRetryContext{
|
||||
ScenarioName: "ImageGeneration",
|
||||
ExpectedBehavior: map[string]interface{}{},
|
||||
TestMetadata: map[string]interface{}{
|
||||
"provider": testConfig.Provider,
|
||||
"model": testConfig.ImageGenerationModel,
|
||||
},
|
||||
}
|
||||
|
||||
expectations := GetExpectationsForScenario("ImageGeneration", testConfig, map[string]interface{}{
|
||||
"min_images": 1,
|
||||
"expected_size": "1024x1024",
|
||||
})
|
||||
|
||||
imageGenerationRetryConfig := ImageGenerationRetryConfig{
|
||||
MaxAttempts: retryConfig.MaxAttempts,
|
||||
BaseDelay: retryConfig.BaseDelay,
|
||||
MaxDelay: retryConfig.MaxDelay,
|
||||
Conditions: []ImageGenerationRetryCondition{},
|
||||
OnRetry: retryConfig.OnRetry,
|
||||
OnFinalFail: retryConfig.OnFinalFail,
|
||||
}
|
||||
// Test basic image generation
|
||||
imageGenerationOperation := func() (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) {
|
||||
request := &schemas.BifrostImageGenerationRequest{
|
||||
Provider: testConfig.Provider,
|
||||
Model: testConfig.ImageGenerationModel,
|
||||
Input: &schemas.ImageGenerationInput{
|
||||
Prompt: "A serene Japanese garden with cherry blossoms in spring",
|
||||
},
|
||||
Params: &schemas.ImageGenerationParameters{
|
||||
Size: bifrost.Ptr("1024x1024"),
|
||||
N: bifrost.Ptr(1),
|
||||
},
|
||||
Fallbacks: testConfig.ImageGenerationFallbacks,
|
||||
}
|
||||
|
||||
response, err := client.ImageGenerationRequest(schemas.NewBifrostContext(ctx, schemas.NoDeadline), request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if response != nil {
|
||||
return response, nil
|
||||
}
|
||||
return nil, &schemas.BifrostError{
|
||||
IsBifrostError: true,
|
||||
Error: &schemas.ErrorField{
|
||||
Message: "No image generation response returned",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
imageGenerationResponse, imageGenerationError := WithImageGenerationRetry(t, imageGenerationRetryConfig, retryContext, expectations, "ImageGeneration", imageGenerationOperation)
|
||||
|
||||
if imageGenerationError != nil {
|
||||
t.Fatalf("❌ Image generation failed: %v", GetErrorMessage(imageGenerationError))
|
||||
}
|
||||
|
||||
// Validate response
|
||||
if imageGenerationResponse == nil {
|
||||
t.Fatal("❌ Image generation returned nil response")
|
||||
}
|
||||
|
||||
if len(imageGenerationResponse.Data) == 0 {
|
||||
t.Fatal("❌ Image generation returned no image data")
|
||||
}
|
||||
|
||||
// Validate first image
|
||||
imageData := imageGenerationResponse.Data[0]
|
||||
if imageData.B64JSON == "" && imageData.URL == "" {
|
||||
t.Fatal("❌ Image data missing both b64_json and URL")
|
||||
}
|
||||
|
||||
// Validate base64 if present
|
||||
if imageData.B64JSON != "" {
|
||||
// Decode base64 image data
|
||||
decoded, err := base64.StdEncoding.DecodeString(imageData.B64JSON)
|
||||
if err != nil {
|
||||
t.Fatalf("❌ Failed to decode base64 image data: %v", err)
|
||||
}
|
||||
if len(decoded) == 0 {
|
||||
t.Fatalf("❌ Decoded image data is empty")
|
||||
}
|
||||
|
||||
// Decode image config to validate dimensions
|
||||
reader := bytes.NewReader(decoded)
|
||||
config, format, err := image.DecodeConfig(reader)
|
||||
if err != nil {
|
||||
t.Fatalf("❌ Failed to decode image config: %v (format: %s)", err, format)
|
||||
}
|
||||
|
||||
// Validate dimensions are 1024x1024 as requested
|
||||
expectedWidth, expectedHeight := 1024, 1024
|
||||
if config.Width != expectedWidth || config.Height != expectedHeight {
|
||||
t.Errorf("❌ Image dimensions mismatch: got %dx%d, expected %dx%d", config.Width, config.Height, expectedWidth, expectedHeight)
|
||||
}
|
||||
}
|
||||
|
||||
// Validate usage if present
|
||||
if imageGenerationResponse.Usage != nil {
|
||||
if imageGenerationResponse.Usage.TotalTokens == 0 {
|
||||
t.Logf("⚠️ Usage total_tokens is 0 (may be provider-specific)")
|
||||
}
|
||||
}
|
||||
|
||||
// Validate extra fields
|
||||
if imageGenerationResponse.ExtraFields.Provider == "" {
|
||||
t.Error("❌ ExtraFields.Provider is empty")
|
||||
}
|
||||
|
||||
if imageGenerationResponse.ExtraFields.OriginalModelRequested == "" {
|
||||
t.Error("❌ ExtraFields.OriginalModelRequested is empty")
|
||||
}
|
||||
|
||||
t.Logf("✅ Image generation successful: ID=%s, Provider=%s, Model=%s, Images=%d",
|
||||
imageGenerationResponse.ID, imageGenerationResponse.ExtraFields.Provider, imageGenerationResponse.ExtraFields.OriginalModelRequested, len(imageGenerationResponse.Data))
|
||||
})
|
||||
}
|
||||
|
||||
// RunImageGenerationStreamTest executes the end-to-end streaming image generation test
|
||||
func RunImageGenerationStreamTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
|
||||
if !testConfig.Scenarios.ImageGenerationStream {
|
||||
t.Logf("Image generation streaming not supported for provider %s", testConfig.Provider)
|
||||
return
|
||||
}
|
||||
|
||||
if testConfig.ImageGenerationModel == "" {
|
||||
t.Logf("Image generation streaming not configured for provider %s", testConfig.Provider)
|
||||
return
|
||||
}
|
||||
|
||||
t.Run("ImageGenerationStream", func(t *testing.T) {
|
||||
if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
|
||||
t.Parallel()
|
||||
}
|
||||
|
||||
retryConfig := GetTestRetryConfigForScenario("ImageGenerationStream", testConfig)
|
||||
retryContext := TestRetryContext{
|
||||
ScenarioName: "ImageGenerationStream",
|
||||
ExpectedBehavior: map[string]interface{}{
|
||||
"should_generate_images": true,
|
||||
},
|
||||
TestMetadata: map[string]interface{}{
|
||||
"provider": testConfig.Provider,
|
||||
"model": testConfig.ImageGenerationModel,
|
||||
},
|
||||
}
|
||||
|
||||
request := &schemas.BifrostImageGenerationRequest{
|
||||
Provider: testConfig.Provider,
|
||||
Model: testConfig.ImageGenerationModel,
|
||||
Input: &schemas.ImageGenerationInput{
|
||||
Prompt: "A futuristic cityscape at sunset with flying cars",
|
||||
},
|
||||
Params: &schemas.ImageGenerationParameters{
|
||||
Size: bifrost.Ptr("1024x1024"),
|
||||
Quality: bifrost.Ptr("low"),
|
||||
},
|
||||
Fallbacks: testConfig.ImageGenerationFallbacks,
|
||||
}
|
||||
streamCtx, cancel := context.WithTimeout(ctx, 60*time.Second)
|
||||
defer cancel()
|
||||
|
||||
validationResult := WithImageGenerationStreamRetry(
|
||||
t,
|
||||
retryConfig,
|
||||
retryContext,
|
||||
func() (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
|
||||
return client.ImageGenerationStreamRequest(schemas.NewBifrostContext(streamCtx, schemas.NoDeadline), request)
|
||||
},
|
||||
func(responseChannel chan *schemas.BifrostStreamChunk) ImageGenerationStreamValidationResult {
|
||||
// Validate stream content
|
||||
var receivedData bool
|
||||
var streamErrors []string
|
||||
var validationErrors []string
|
||||
hasCompleted := false
|
||||
|
||||
for {
|
||||
select {
|
||||
case response, ok := <-responseChannel:
|
||||
if !ok {
|
||||
goto streamComplete
|
||||
}
|
||||
|
||||
if response == nil {
|
||||
streamErrors = append(streamErrors, "Received nil stream response")
|
||||
continue
|
||||
}
|
||||
|
||||
if response.BifrostError != nil {
|
||||
streamErrors = append(streamErrors, fmt.Sprintf("Error in stream: %s", GetErrorMessage(response.BifrostError)))
|
||||
continue
|
||||
}
|
||||
|
||||
if response.BifrostImageGenerationStreamResponse != nil {
|
||||
receivedData = true
|
||||
imgResp := response.BifrostImageGenerationStreamResponse
|
||||
|
||||
if imgResp.Type == schemas.ImageGenerationEventTypeCompleted {
|
||||
hasCompleted = true
|
||||
// Validate that completed images have actual data
|
||||
if imgResp.URL == "" && imgResp.B64JSON == "" {
|
||||
validationErrors = append(validationErrors, "Completion chunk received but image has no URL or B64JSON data")
|
||||
}
|
||||
}
|
||||
}
|
||||
case <-streamCtx.Done():
|
||||
validationErrors = append(validationErrors, "Stream validation timed out")
|
||||
drainCtx, drainCancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
go func() {
|
||||
defer drainCancel()
|
||||
for {
|
||||
select {
|
||||
case _, ok := <-responseChannel:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
case <-drainCtx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
goto streamComplete
|
||||
}
|
||||
}
|
||||
streamComplete:
|
||||
|
||||
// Stream errors should cause the test to fail - convert them to validation errors
|
||||
if len(streamErrors) > 0 {
|
||||
validationErrors = append(validationErrors, fmt.Sprintf("Stream errors encountered: %s", strings.Join(streamErrors, "; ")))
|
||||
}
|
||||
|
||||
// Test passes only if: data received, completion received, and no errors (including stream errors)
|
||||
passed := receivedData && hasCompleted && len(validationErrors) == 0
|
||||
if !receivedData {
|
||||
validationErrors = append(validationErrors, "No stream data received")
|
||||
}
|
||||
if !hasCompleted {
|
||||
validationErrors = append(validationErrors, "No completion chunk received")
|
||||
}
|
||||
|
||||
return ImageGenerationStreamValidationResult{
|
||||
Passed: passed,
|
||||
Errors: validationErrors,
|
||||
ReceivedData: receivedData,
|
||||
StreamErrors: streamErrors,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
if !validationResult.Passed {
|
||||
allErrors := append(validationResult.Errors, validationResult.StreamErrors...)
|
||||
t.Fatalf("❌ Image generation stream validation failed: %s", strings.Join(allErrors, "; "))
|
||||
}
|
||||
|
||||
if !validationResult.ReceivedData {
|
||||
t.Fatal("❌ No stream data received")
|
||||
}
|
||||
|
||||
t.Logf("✅ Image generation stream successful: ReceivedData=%v, Errors=%d, StreamErrors=%d",
|
||||
validationResult.ReceivedData, len(validationResult.Errors), len(validationResult.StreamErrors))
|
||||
})
|
||||
}
|
||||
Reference in New Issue
Block a user