558 lines
18 KiB
Go
558 lines
18 KiB
Go
package llmtests
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/base64"
|
|
"fmt"
|
|
"image"
|
|
"image/color"
|
|
"image/jpeg"
|
|
"image/png"
|
|
"os"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
bifrost "github.com/maximhq/bifrost/core"
|
|
"github.com/maximhq/bifrost/core/schemas"
|
|
)
|
|
|
|
// createMaskImageForAzureOpenAI creates a PNG mask image with transparent background for Azure and OpenAI
|
|
// Creates a white rectangle in the center on transparent background (typical inpainting mask pattern)
|
|
// PNG format with alpha channel is required by Azure and OpenAI
|
|
func createMaskImageForAzureOpenAI(width, height int) ([]byte, error) {
|
|
// Create an RGBA image with alpha channel support
|
|
img := image.NewRGBA(image.Rect(0, 0, width, height))
|
|
|
|
// Create a white rectangle in the center (typical mask pattern for inpainting)
|
|
// White areas with full alpha indicate regions to edit
|
|
// Transparent areas indicate regions to preserve
|
|
centerX, centerY := width/2, height/2
|
|
maskWidth, maskHeight := width/3, height/3
|
|
|
|
for y := 0; y < height; y++ {
|
|
for x := 0; x < width; x++ {
|
|
// Check if pixel is within the mask rectangle
|
|
if x >= centerX-maskWidth/2 && x < centerX+maskWidth/2 &&
|
|
y >= centerY-maskHeight/2 && y < centerY+maskHeight/2 {
|
|
// White with full alpha = edit area
|
|
img.Set(x, y, color.RGBA{R: 255, G: 255, B: 255, A: 255})
|
|
} else {
|
|
// Transparent (alpha=0) = preserve area
|
|
img.Set(x, y, color.RGBA{R: 0, G: 0, B: 0, A: 0})
|
|
}
|
|
}
|
|
}
|
|
|
|
// Encode as PNG to preserve alpha channel (required by Azure and OpenAI)
|
|
var buf bytes.Buffer
|
|
if err := png.Encode(&buf, img); err != nil {
|
|
return nil, fmt.Errorf("failed to encode mask image: %w", err)
|
|
}
|
|
|
|
return buf.Bytes(), nil
|
|
}
|
|
|
|
// createSimpleMaskImage creates a simple JPEG mask image for testing (no transparency)
|
|
// Creates a white rectangle in the center on black background (typical inpainting mask pattern)
|
|
// JPEG format doesn't support transparency, so this works with providers that don't require alpha channel
|
|
func createSimpleMaskImage(width, height int) ([]byte, error) {
|
|
// Create an RGB image (no alpha channel)
|
|
img := image.NewRGBA(image.Rect(0, 0, width, height))
|
|
|
|
// Create a white rectangle in the center (typical mask pattern for inpainting)
|
|
// White areas indicate regions to edit, black areas are preserved
|
|
centerX, centerY := width/2, height/2
|
|
maskWidth, maskHeight := width/3, height/3
|
|
|
|
for y := 0; y < height; y++ {
|
|
for x := 0; x < width; x++ {
|
|
// Check if pixel is within the mask rectangle
|
|
if x >= centerX-maskWidth/2 && x < centerX+maskWidth/2 &&
|
|
y >= centerY-maskHeight/2 && y < centerY+maskHeight/2 {
|
|
img.Set(x, y, color.RGBA{R: 255, G: 255, B: 255, A: 255}) // White (edit area)
|
|
} else {
|
|
img.Set(x, y, color.RGBA{R: 0, G: 0, B: 0, A: 255}) // Black (preserve area)
|
|
}
|
|
}
|
|
}
|
|
|
|
// Encode as JPEG (no transparency support, so it works with all providers)
|
|
var buf bytes.Buffer
|
|
if err := jpeg.Encode(&buf, img, &jpeg.Options{Quality: 95}); err != nil {
|
|
return nil, fmt.Errorf("failed to encode mask image: %w", err)
|
|
}
|
|
|
|
return buf.Bytes(), nil
|
|
}
|
|
|
|
// createMaskImage creates a mask image based on the provider requirements
|
|
// Azure and OpenAI require PNG with transparent background (alpha channel)
|
|
// Other providers use JPEG with opaque background
|
|
func createMaskImage(provider schemas.ModelProvider, width, height int) ([]byte, error) {
|
|
if provider == schemas.Azure || provider == schemas.OpenAI {
|
|
return createMaskImageForAzureOpenAI(width, height)
|
|
}
|
|
return createSimpleMaskImage(width, height)
|
|
}
|
|
|
|
// convertImageToPNG converts any image format to PNG (supports transparency)
|
|
// This ensures compatibility with providers that require PNG format
|
|
// Returns the converted image bytes and its dimensions
|
|
func convertImageToPNG(imageBytes []byte) ([]byte, int, int, error) {
|
|
// Decode the image (supports PNG, JPEG, GIF, etc.)
|
|
img, format, err := image.Decode(bytes.NewReader(imageBytes))
|
|
if err != nil {
|
|
return nil, 0, 0, fmt.Errorf("failed to decode image: %w", err)
|
|
}
|
|
|
|
bounds := img.Bounds()
|
|
width := bounds.Dx()
|
|
height := bounds.Dy()
|
|
|
|
// If it's already PNG, return as-is
|
|
if format == "png" {
|
|
return imageBytes, width, height, nil
|
|
}
|
|
|
|
// Convert to RGBA to preserve color information
|
|
rgbaImg := image.NewRGBA(bounds)
|
|
for y := bounds.Min.Y; y < bounds.Max.Y; y++ {
|
|
for x := bounds.Min.X; x < bounds.Max.X; x++ {
|
|
rgbaImg.Set(x, y, img.At(x, y))
|
|
}
|
|
}
|
|
|
|
// Encode as PNG (supports transparency)
|
|
var buf bytes.Buffer
|
|
if err := png.Encode(&buf, rgbaImg); err != nil {
|
|
return nil, 0, 0, fmt.Errorf("failed to encode image as PNG: %w", err)
|
|
}
|
|
|
|
return buf.Bytes(), width, height, nil
|
|
}
|
|
|
|
// convertImageToJPEG converts any image format to JPEG (no transparency)
|
|
// This ensures compatibility with providers that don't support transparency
|
|
// Returns the converted image bytes and its dimensions
|
|
func convertImageToJPEG(imageBytes []byte) ([]byte, int, int, error) {
|
|
// Decode the image (supports PNG, JPEG, GIF, etc.)
|
|
img, format, err := image.Decode(bytes.NewReader(imageBytes))
|
|
if err != nil {
|
|
return nil, 0, 0, fmt.Errorf("failed to decode image: %w", err)
|
|
}
|
|
|
|
bounds := img.Bounds()
|
|
width := bounds.Dx()
|
|
height := bounds.Dy()
|
|
|
|
// If it's already JPEG, return as-is
|
|
if format == "jpeg" || format == "jpg" {
|
|
return imageBytes, width, height, nil
|
|
}
|
|
|
|
// Convert to RGBA to ensure no transparency
|
|
rgbaImg := image.NewRGBA(bounds)
|
|
for y := bounds.Min.Y; y < bounds.Max.Y; y++ {
|
|
for x := bounds.Min.X; x < bounds.Max.X; x++ {
|
|
rgbaImg.Set(x, y, img.At(x, y))
|
|
}
|
|
}
|
|
|
|
// Encode as JPEG (no transparency support)
|
|
var buf bytes.Buffer
|
|
if err := jpeg.Encode(&buf, rgbaImg, &jpeg.Options{Quality: 95}); err != nil {
|
|
return nil, 0, 0, fmt.Errorf("failed to encode image as JPEG: %w", err)
|
|
}
|
|
|
|
return buf.Bytes(), width, height, nil
|
|
}
|
|
|
|
// convertImageForProvider converts an image to the appropriate format based on provider requirements
|
|
// OpenAI requires PNG format, other providers use JPEG
|
|
// Returns the converted image bytes and its dimensions
|
|
func convertImageForProvider(provider schemas.ModelProvider, imageBytes []byte) ([]byte, int, int, error) {
|
|
if provider == schemas.OpenAI {
|
|
return convertImageToPNG(imageBytes)
|
|
}
|
|
return convertImageToJPEG(imageBytes)
|
|
}
|
|
|
|
// decodeBase64ImageToBytes converts a base64 data URL string to []byte
|
|
// Handles both "data:image/png;base64,<data>" and plain base64 strings
|
|
func decodeBase64ImageToBytes(base64Str string) ([]byte, error) {
|
|
// Remove data URL prefix if present
|
|
if strings.HasPrefix(base64Str, "data:") {
|
|
parts := strings.SplitN(base64Str, ",", 2)
|
|
if len(parts) != 2 {
|
|
return nil, fmt.Errorf("invalid data URL format")
|
|
}
|
|
base64Str = parts[1]
|
|
}
|
|
|
|
// Decode base64 string
|
|
decoded, err := base64.StdEncoding.DecodeString(base64Str)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to decode base64: %w", err)
|
|
}
|
|
|
|
if len(decoded) == 0 {
|
|
return nil, fmt.Errorf("decoded image data is empty")
|
|
}
|
|
|
|
return decoded, nil
|
|
}
|
|
|
|
// RunImageEditTest executes the end-to-end image edit test (non-streaming)
|
|
func RunImageEditTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
|
|
if testConfig.ImageEditModel == "" {
|
|
t.Logf("Image edit not configured for provider %s", testConfig.Provider)
|
|
return
|
|
}
|
|
|
|
if !testConfig.Scenarios.ImageEdit {
|
|
t.Logf("Image edit not supported for provider %s", testConfig.Provider)
|
|
return
|
|
}
|
|
|
|
t.Run("ImageEdit", func(t *testing.T) {
|
|
if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
|
|
t.Parallel()
|
|
}
|
|
|
|
retryConfig := GetTestRetryConfigForScenario("ImageEdit", testConfig)
|
|
retryContext := TestRetryContext{
|
|
ScenarioName: "ImageEdit",
|
|
ExpectedBehavior: map[string]interface{}{},
|
|
TestMetadata: map[string]interface{}{
|
|
"provider": testConfig.Provider,
|
|
"model": testConfig.ImageEditModel,
|
|
},
|
|
}
|
|
|
|
expectations := GetExpectationsForScenario("ImageEdit", testConfig, map[string]interface{}{
|
|
"min_images": 1,
|
|
"expected_size": "1024x1024",
|
|
})
|
|
|
|
imageEditRetryConfig := ImageGenerationRetryConfig{
|
|
MaxAttempts: retryConfig.MaxAttempts,
|
|
BaseDelay: retryConfig.BaseDelay,
|
|
MaxDelay: retryConfig.MaxDelay,
|
|
Conditions: []ImageGenerationRetryCondition{},
|
|
OnRetry: retryConfig.OnRetry,
|
|
OnFinalFail: retryConfig.OnFinalFail,
|
|
}
|
|
|
|
// Load test image
|
|
lionBase64, err := GetLionBase64Image()
|
|
if err != nil {
|
|
t.Fatalf("Failed to load test image: %v", err)
|
|
}
|
|
|
|
// Convert base64 to bytes
|
|
imageBytes, err := decodeBase64ImageToBytes(lionBase64)
|
|
if err != nil {
|
|
t.Fatalf("Failed to decode image: %v", err)
|
|
}
|
|
|
|
// Convert input image to JPEG (no transparency) to avoid provider compatibility issues
|
|
imageBytes, imgWidth, imgHeight, err := convertImageToJPEG(imageBytes)
|
|
if err != nil {
|
|
t.Fatalf("Failed to convert image to JPEG: %v", err)
|
|
}
|
|
|
|
// Create mask image based on provider requirements
|
|
// Azure and OpenAI require PNG with transparent background, others use JPEG
|
|
maskBytes, err := createMaskImage(testConfig.Provider, imgWidth, imgHeight)
|
|
if err != nil {
|
|
t.Fatalf("Failed to create mask image: %v", err)
|
|
}
|
|
|
|
// Test basic image edit (inpainting)
|
|
imageEditOperation := func() (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) {
|
|
request := &schemas.BifrostImageEditRequest{
|
|
Provider: testConfig.Provider,
|
|
Model: testConfig.ImageEditModel,
|
|
Input: &schemas.ImageEditInput{
|
|
Images: []schemas.ImageInput{
|
|
{
|
|
Image: imageBytes,
|
|
},
|
|
},
|
|
Prompt: "Add a beautiful sunset in the background",
|
|
},
|
|
Params: &schemas.ImageEditParameters{
|
|
Size: bifrost.Ptr("1024x1024"),
|
|
N: bifrost.Ptr(1),
|
|
Type: bifrost.Ptr("inpainting"),
|
|
Mask: maskBytes,
|
|
},
|
|
Fallbacks: testConfig.ImageEditFallbacks,
|
|
}
|
|
|
|
response, err := client.ImageEditRequest(schemas.NewBifrostContext(ctx, schemas.NoDeadline), request)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if response != nil {
|
|
return response, nil
|
|
}
|
|
return nil, &schemas.BifrostError{
|
|
IsBifrostError: true,
|
|
Error: &schemas.ErrorField{
|
|
Message: "No image edit response returned",
|
|
},
|
|
}
|
|
}
|
|
|
|
imageEditResponse, imageEditError := WithImageGenerationRetry(t, imageEditRetryConfig, retryContext, expectations, "ImageEdit", imageEditOperation)
|
|
|
|
if imageEditError != nil {
|
|
t.Fatalf("❌ Image edit failed: %v", GetErrorMessage(imageEditError))
|
|
}
|
|
|
|
// Validate response
|
|
if imageEditResponse == nil {
|
|
t.Fatal("❌ Image edit returned nil response")
|
|
}
|
|
|
|
if len(imageEditResponse.Data) == 0 {
|
|
t.Fatal("❌ Image edit returned no image data")
|
|
}
|
|
|
|
// Validate first image
|
|
imageData := imageEditResponse.Data[0]
|
|
if imageData.B64JSON == "" && imageData.URL == "" {
|
|
t.Fatal("❌ Image data missing both b64_json and URL")
|
|
}
|
|
|
|
// Validate base64 if present
|
|
if imageData.B64JSON != "" {
|
|
// Decode base64 image data
|
|
decoded, err := base64.StdEncoding.DecodeString(imageData.B64JSON)
|
|
if err != nil {
|
|
t.Fatalf("❌ Failed to decode base64 image data: %v", err)
|
|
}
|
|
if len(decoded) == 0 {
|
|
t.Fatalf("❌ Decoded image data is empty")
|
|
}
|
|
|
|
// Decode image config to validate dimensions
|
|
reader := bytes.NewReader(decoded)
|
|
config, format, err := image.DecodeConfig(reader)
|
|
if err != nil {
|
|
t.Fatalf("❌ Failed to decode image config: %v (format: %s)", err, format)
|
|
}
|
|
|
|
// Validate dimensions are reasonable (at least 256x256)
|
|
if config.Width < 256 || config.Height < 256 {
|
|
t.Errorf("❌ Image dimensions too small: got %dx%d, expected at least 256x256", config.Width, config.Height)
|
|
}
|
|
}
|
|
|
|
// Validate usage if present
|
|
if imageEditResponse.Usage != nil {
|
|
if imageEditResponse.Usage.TotalTokens == 0 {
|
|
t.Logf("⚠️ Usage total_tokens is 0 (may be provider-specific)")
|
|
}
|
|
}
|
|
|
|
// Validate extra fields
|
|
if imageEditResponse.ExtraFields.Provider == "" {
|
|
t.Error("❌ ExtraFields.Provider is empty")
|
|
}
|
|
|
|
if imageEditResponse.ExtraFields.OriginalModelRequested == "" {
|
|
t.Error("❌ ExtraFields.OriginalModelRequested is empty")
|
|
}
|
|
|
|
// Validate RequestType is ImageEditRequest
|
|
if imageEditResponse.ExtraFields.RequestType != schemas.ImageEditRequest {
|
|
t.Errorf("❌ ExtraFields.RequestType mismatch: got %s, expected %s", imageEditResponse.ExtraFields.RequestType, schemas.ImageEditRequest)
|
|
}
|
|
|
|
t.Logf("✅ Image edit successful: ID=%s, Provider=%s, Model=%s, Images=%d",
|
|
imageEditResponse.ID, imageEditResponse.ExtraFields.Provider, imageEditResponse.ExtraFields.OriginalModelRequested, len(imageEditResponse.Data))
|
|
})
|
|
}
|
|
|
|
// RunImageEditStreamTest executes the end-to-end streaming image edit test
|
|
func RunImageEditStreamTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
|
|
if !testConfig.Scenarios.ImageEditStream {
|
|
t.Logf("Image edit streaming not supported for provider %s", testConfig.Provider)
|
|
return
|
|
}
|
|
|
|
if testConfig.ImageEditModel == "" {
|
|
t.Logf("Image edit streaming not configured for provider %s", testConfig.Provider)
|
|
return
|
|
}
|
|
|
|
t.Run("ImageEditStream", func(t *testing.T) {
|
|
if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
|
|
t.Parallel()
|
|
}
|
|
|
|
retryConfig := GetTestRetryConfigForScenario("ImageEditStream", testConfig)
|
|
retryContext := TestRetryContext{
|
|
ScenarioName: "ImageEditStream",
|
|
ExpectedBehavior: map[string]interface{}{
|
|
"should_generate_images": true,
|
|
},
|
|
TestMetadata: map[string]interface{}{
|
|
"provider": testConfig.Provider,
|
|
"model": testConfig.ImageEditModel,
|
|
},
|
|
}
|
|
|
|
// Load test image
|
|
lionBase64, err := GetLionBase64Image()
|
|
if err != nil {
|
|
t.Fatalf("Failed to load test image: %v", err)
|
|
}
|
|
|
|
// Convert base64 to bytes
|
|
imageBytes, err := decodeBase64ImageToBytes(lionBase64)
|
|
if err != nil {
|
|
t.Fatalf("Failed to decode image: %v", err)
|
|
}
|
|
|
|
// Convert input image to JPEG (no transparency) to avoid provider compatibility issues
|
|
imageBytes, imgWidth, imgHeight, err := convertImageToJPEG(imageBytes)
|
|
if err != nil {
|
|
t.Fatalf("Failed to convert image to JPEG: %v", err)
|
|
}
|
|
|
|
// Create mask image based on provider requirements
|
|
// Azure and OpenAI require PNG with transparent background, others use JPEG
|
|
maskBytes, err := createMaskImage(testConfig.Provider, imgWidth, imgHeight)
|
|
if err != nil {
|
|
t.Fatalf("Failed to create mask image: %v", err)
|
|
}
|
|
|
|
request := &schemas.BifrostImageEditRequest{
|
|
Provider: testConfig.Provider,
|
|
Model: testConfig.ImageEditModel,
|
|
Input: &schemas.ImageEditInput{
|
|
Images: []schemas.ImageInput{
|
|
{
|
|
Image: imageBytes,
|
|
},
|
|
},
|
|
Prompt: "Add a futuristic cityscape in the background",
|
|
},
|
|
Params: &schemas.ImageEditParameters{
|
|
Size: bifrost.Ptr("1024x1024"),
|
|
Quality: bifrost.Ptr("low"),
|
|
Type: bifrost.Ptr("inpainting"),
|
|
Mask: maskBytes,
|
|
},
|
|
Fallbacks: testConfig.ImageEditFallbacks,
|
|
}
|
|
streamCtx, cancel := context.WithTimeout(ctx, 60*time.Second)
|
|
defer cancel()
|
|
|
|
validationResult := WithImageGenerationStreamRetry(
|
|
t,
|
|
retryConfig,
|
|
retryContext,
|
|
func() (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
|
|
return client.ImageEditStreamRequest(schemas.NewBifrostContext(streamCtx, schemas.NoDeadline), request)
|
|
},
|
|
func(responseChannel chan *schemas.BifrostStreamChunk) ImageGenerationStreamValidationResult {
|
|
// Validate stream content
|
|
var receivedData bool
|
|
var streamErrors []string
|
|
var validationErrors []string
|
|
hasCompleted := false
|
|
|
|
for {
|
|
select {
|
|
case response, ok := <-responseChannel:
|
|
if !ok {
|
|
goto streamComplete
|
|
}
|
|
|
|
if response == nil {
|
|
streamErrors = append(streamErrors, "Received nil stream response")
|
|
continue
|
|
}
|
|
|
|
if response.BifrostError != nil {
|
|
streamErrors = append(streamErrors, fmt.Sprintf("Error in stream: %s", GetErrorMessage(response.BifrostError)))
|
|
continue
|
|
}
|
|
|
|
if response.BifrostImageGenerationStreamResponse != nil {
|
|
receivedData = true
|
|
imgResp := response.BifrostImageGenerationStreamResponse
|
|
|
|
// Check for completion event (can be ImageGenerationEventTypeCompleted or ImageEditEventTypeCompleted)
|
|
if imgResp.Type == schemas.ImageGenerationEventTypeCompleted || imgResp.Type == schemas.ImageEditEventTypeCompleted {
|
|
hasCompleted = true
|
|
// Validate that completed images have actual data
|
|
if imgResp.URL == "" && imgResp.B64JSON == "" {
|
|
validationErrors = append(validationErrors, "Completion chunk received but image has no URL or B64JSON data")
|
|
}
|
|
}
|
|
}
|
|
case <-streamCtx.Done():
|
|
validationErrors = append(validationErrors, "Stream validation timed out")
|
|
drainCtx, drainCancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
go func() {
|
|
defer drainCancel()
|
|
for {
|
|
select {
|
|
case _, ok := <-responseChannel:
|
|
if !ok {
|
|
return
|
|
}
|
|
case <-drainCtx.Done():
|
|
return
|
|
}
|
|
}
|
|
}()
|
|
goto streamComplete
|
|
}
|
|
}
|
|
streamComplete:
|
|
|
|
// Stream errors should cause the test to fail - convert them to validation errors
|
|
if len(streamErrors) > 0 {
|
|
validationErrors = append(validationErrors, fmt.Sprintf("Stream errors encountered: %s", strings.Join(streamErrors, "; ")))
|
|
}
|
|
|
|
// Test passes only if: data received, completion received, and no errors (including stream errors)
|
|
passed := receivedData && hasCompleted && len(validationErrors) == 0
|
|
if !receivedData {
|
|
validationErrors = append(validationErrors, "No stream data received")
|
|
}
|
|
if !hasCompleted {
|
|
validationErrors = append(validationErrors, "No completion chunk received")
|
|
}
|
|
|
|
return ImageGenerationStreamValidationResult{
|
|
Passed: passed,
|
|
Errors: validationErrors,
|
|
ReceivedData: receivedData,
|
|
StreamErrors: streamErrors,
|
|
}
|
|
},
|
|
)
|
|
|
|
if !validationResult.Passed {
|
|
allErrors := append(validationResult.Errors, validationResult.StreamErrors...)
|
|
t.Fatalf("❌ Image edit stream validation failed: %s", strings.Join(allErrors, "; "))
|
|
}
|
|
|
|
if !validationResult.ReceivedData {
|
|
t.Fatal("❌ No stream data received")
|
|
}
|
|
|
|
t.Logf("✅ Image edit stream successful: ReceivedData=%v, Errors=%d, StreamErrors=%d",
|
|
validationResult.ReceivedData, len(validationResult.Errors), len(validationResult.StreamErrors))
|
|
})
|
|
}
|