Files
bifrost/core/internal/llmtests/image_edit.go
Beyhan Oğur 880f412e2c first commit
2026-04-26 21:52:23 +03:00

558 lines
18 KiB
Go

package llmtests
import (
"bytes"
"context"
"encoding/base64"
"fmt"
"image"
"image/color"
"image/jpeg"
"image/png"
"os"
"strings"
"testing"
"time"
bifrost "github.com/maximhq/bifrost/core"
"github.com/maximhq/bifrost/core/schemas"
)
// createMaskImageForAzureOpenAI builds the inpainting mask used for the Azure
// and OpenAI providers and returns it PNG-encoded.
//
// The mask is a width x height RGBA canvas that is fully transparent except
// for an opaque white rectangle, one third of each dimension in size, centred
// on the canvas. Opaque white marks the region to edit; transparent pixels
// mark the region to preserve. PNG is used so the alpha channel survives
// encoding.
//
// An error is returned only when PNG encoding fails.
func createMaskImageForAzureOpenAI(width, height int) ([]byte, error) {
	// A freshly allocated RGBA canvas is zero-filled, i.e. already fully
	// transparent, so only the edit rectangle has to be painted.
	canvas := image.NewRGBA(image.Rect(0, 0, width, height))
	opaqueWhite := color.RGBA{R: 255, G: 255, B: 255, A: 255}

	midX, midY := width/2, height/2
	halfW, halfH := (width/3)/2, (height/3)/2
	left, right := midX-halfW, midX+halfW
	top, bottom := midY-halfH, midY+halfH

	for row := top; row < bottom; row++ {
		// Never paint outside the canvas, whatever the requested size.
		if row < 0 || row >= height {
			continue
		}
		for col := left; col < right; col++ {
			if col < 0 || col >= width {
				continue
			}
			canvas.SetRGBA(col, row, opaqueWhite)
		}
	}

	var encoded bytes.Buffer
	if err := png.Encode(&encoded, canvas); err != nil {
		return nil, fmt.Errorf("failed to encode mask image: %w", err)
	}
	return encoded.Bytes(), nil
}
// createSimpleMaskImage builds an opaque inpainting mask and returns it
// JPEG-encoded (quality 95).
//
// The mask is a width x height canvas painted black, with a white rectangle
// one third of each dimension in size centred on it. White marks the region to
// edit; black marks the region to preserve. JPEG carries no alpha channel, so
// this mask is meant for providers that do not need transparency.
//
// An error is returned only when JPEG encoding fails.
func createSimpleMaskImage(width, height int) ([]byte, error) {
	canvas := image.NewRGBA(image.Rect(0, 0, width, height))

	var (
		editWhite     = color.RGBA{R: 255, G: 255, B: 255, A: 255}
		preserveBlack = color.RGBA{R: 0, G: 0, B: 0, A: 255}
	)

	midX, midY := width/2, height/2
	halfW, halfH := (width/3)/2, (height/3)/2

	for row := 0; row < height; row++ {
		// The vertical test is the same for every pixel of the row.
		rowInMask := row >= midY-halfH && row < midY+halfH
		for col := 0; col < width; col++ {
			shade := preserveBlack
			if rowInMask && col >= midX-halfW && col < midX+halfW {
				shade = editWhite
			}
			canvas.SetRGBA(col, row, shade)
		}
	}

	var encoded bytes.Buffer
	if err := jpeg.Encode(&encoded, canvas, &jpeg.Options{Quality: 95}); err != nil {
		return nil, fmt.Errorf("failed to encode mask image: %w", err)
	}
	return encoded.Bytes(), nil
}
// createMaskImage returns an inpainting mask of the given size in the format
// chosen for provider: a PNG with a transparent background for Azure and
// OpenAI, and an opaque JPEG for every other provider.
func createMaskImage(provider schemas.ModelProvider, width, height int) ([]byte, error) {
	switch provider {
	case schemas.Azure, schemas.OpenAI:
		return createMaskImageForAzureOpenAI(width, height)
	default:
		return createSimpleMaskImage(width, height)
	}
}
// convertImageToPNG returns imageBytes as a PNG together with the image's
// width and height in pixels.
//
// The input may be in any format registered with the image package (this file
// imports the PNG and JPEG decoders). Input that is already PNG is returned
// unchanged; anything else is copied onto an RGBA canvas and re-encoded.
//
// An error is returned when the input cannot be decoded or the PNG encoder
// fails.
func convertImageToPNG(imageBytes []byte) ([]byte, int, int, error) {
	decoded, kind, err := image.Decode(bytes.NewReader(imageBytes))
	if err != nil {
		return nil, 0, 0, fmt.Errorf("failed to decode image: %w", err)
	}

	frame := decoded.Bounds()
	w, h := frame.Dx(), frame.Dy()

	// PNG input needs no work: hand the caller's bytes straight back.
	if kind == "png" {
		return imageBytes, w, h, nil
	}

	// Copy pixel by pixel onto an RGBA canvas with the same bounds as the
	// source, then encode that canvas.
	canvas := image.NewRGBA(frame)
	for row := frame.Min.Y; row < frame.Max.Y; row++ {
		for col := frame.Min.X; col < frame.Max.X; col++ {
			canvas.Set(col, row, decoded.At(col, row))
		}
	}

	var encoded bytes.Buffer
	if err := png.Encode(&encoded, canvas); err != nil {
		return nil, 0, 0, fmt.Errorf("failed to encode image as PNG: %w", err)
	}
	return encoded.Bytes(), w, h, nil
}
// convertImageToJPEG returns imageBytes as a JPEG (quality 95) together with
// the image's width and height in pixels.
//
// The input may be in any format registered with the image package (this file
// imports the PNG and JPEG decoders). Input that is already JPEG is returned
// unchanged; anything else is copied onto an RGBA canvas and re-encoded. JPEG
// has no alpha channel, so transparency in the source is not preserved.
//
// An error is returned when the input cannot be decoded or the JPEG encoder
// fails.
func convertImageToJPEG(imageBytes []byte) ([]byte, int, int, error) {
	decoded, kind, err := image.Decode(bytes.NewReader(imageBytes))
	if err != nil {
		return nil, 0, 0, fmt.Errorf("failed to decode image: %w", err)
	}

	frame := decoded.Bounds()
	w, h := frame.Dx(), frame.Dy()

	switch kind {
	case "jpeg", "jpg":
		// Already JPEG: hand the caller's bytes straight back.
		return imageBytes, w, h, nil
	}

	// Copy pixel by pixel onto an RGBA canvas with the same bounds as the
	// source, then encode that canvas.
	canvas := image.NewRGBA(frame)
	for row := frame.Min.Y; row < frame.Max.Y; row++ {
		for col := frame.Min.X; col < frame.Max.X; col++ {
			canvas.Set(col, row, decoded.At(col, row))
		}
	}

	var encoded bytes.Buffer
	if err := jpeg.Encode(&encoded, canvas, &jpeg.Options{Quality: 95}); err != nil {
		return nil, 0, 0, fmt.Errorf("failed to encode image as JPEG: %w", err)
	}
	return encoded.Bytes(), w, h, nil
}
// convertImageForProvider re-encodes imageBytes into the format chosen for
// provider and returns the encoded bytes with the image's width and height.
// OpenAI gets PNG; every other provider gets JPEG.
func convertImageForProvider(provider schemas.ModelProvider, imageBytes []byte) ([]byte, int, int, error) {
	// JPEG is the default; only OpenAI is switched to PNG.
	convert := convertImageToJPEG
	if provider == schemas.OpenAI {
		convert = convertImageToPNG
	}
	return convert(imageBytes)
}
// decodeBase64ImageToBytes decodes a base64-encoded image into raw bytes.
//
// base64Str may be either a bare standard-alphabet base64 string or a data URL
// such as "data:image/png;base64,<data>"; for a data URL everything up to and
// including the first comma is discarded before decoding.
//
// An error is returned when a data URL has no comma, when the payload is not
// valid base64, or when the decoded payload is empty.
func decodeBase64ImageToBytes(base64Str string) ([]byte, error) {
	payload := base64Str
	if strings.HasPrefix(payload, "data:") {
		_, rest, found := strings.Cut(payload, ",")
		if !found {
			return nil, fmt.Errorf("invalid data URL format")
		}
		payload = rest
	}

	raw, err := base64.StdEncoding.DecodeString(payload)
	switch {
	case err != nil:
		return nil, fmt.Errorf("failed to decode base64: %w", err)
	case len(raw) == 0:
		return nil, fmt.Errorf("decoded image data is empty")
	}
	return raw, nil
}
// RunImageEditTest runs the non-streaming image edit (inpainting) scenario end
// to end against the configured provider.
//
// The scenario is skipped with a log line when no image edit model is
// configured or when the scenario is disabled for the provider. Otherwise an
// "ImageEdit" subtest loads the lion fixture, converts it to JPEG, builds a
// provider-specific mask of the same size and sends both through
// client.ImageEditRequest, wrapped in the image generation retry helper. The
// response must contain at least one image with base64 or URL data; a base64
// payload must decode to an image of at least 256x256, and the response's
// extra fields must name the provider, the requested model and the image edit
// request type.
func RunImageEditTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
	switch {
	case testConfig.ImageEditModel == "":
		t.Logf("Image edit not configured for provider %s", testConfig.Provider)
		return
	case !testConfig.Scenarios.ImageEdit:
		t.Logf("Image edit not supported for provider %s", testConfig.Provider)
		return
	}

	t.Run("ImageEdit", func(t *testing.T) {
		if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
			t.Parallel()
		}

		const scenario = "ImageEdit"

		baseRetry := GetTestRetryConfigForScenario(scenario, testConfig)
		retryMeta := TestRetryContext{
			ScenarioName:     scenario,
			ExpectedBehavior: map[string]any{},
			TestMetadata: map[string]any{
				"provider": testConfig.Provider,
				"model":    testConfig.ImageEditModel,
			},
		}
		expected := GetExpectationsForScenario(scenario, testConfig, map[string]any{
			"min_images":    1,
			"expected_size": "1024x1024",
		})
		// Reuse the scenario's timing and callbacks, with no retry conditions.
		editRetry := ImageGenerationRetryConfig{
			MaxAttempts: baseRetry.MaxAttempts,
			BaseDelay:   baseRetry.BaseDelay,
			MaxDelay:    baseRetry.MaxDelay,
			Conditions:  []ImageGenerationRetryCondition{},
			OnRetry:     baseRetry.OnRetry,
			OnFinalFail: baseRetry.OnFinalFail,
		}

		// Fixture: base64 -> raw bytes -> JPEG (no transparency, to avoid
		// provider compatibility issues).
		fixtureB64, err := GetLionBase64Image()
		if err != nil {
			t.Fatalf("Failed to load test image: %v", err)
		}
		rawImage, err := decodeBase64ImageToBytes(fixtureB64)
		if err != nil {
			t.Fatalf("Failed to decode image: %v", err)
		}
		jpegImage, imgW, imgH, err := convertImageToJPEG(rawImage)
		if err != nil {
			t.Fatalf("Failed to convert image to JPEG: %v", err)
		}

		// The mask matches the fixture's dimensions; its format depends on
		// the provider (see createMaskImage).
		mask, err := createMaskImage(testConfig.Provider, imgW, imgH)
		if err != nil {
			t.Fatalf("Failed to create mask image: %v", err)
		}

		// One inpainting request; a nil response without an error is turned
		// into a BifrostError so the retry helper sees it as a failure.
		sendEdit := func() (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) {
			req := &schemas.BifrostImageEditRequest{
				Provider: testConfig.Provider,
				Model:    testConfig.ImageEditModel,
				Input: &schemas.ImageEditInput{
					Images: []schemas.ImageInput{
						{Image: jpegImage},
					},
					Prompt: "Add a beautiful sunset in the background",
				},
				Params: &schemas.ImageEditParameters{
					Size: bifrost.Ptr("1024x1024"),
					N:    bifrost.Ptr(1),
					Type: bifrost.Ptr("inpainting"),
					Mask: mask,
				},
				Fallbacks: testConfig.ImageEditFallbacks,
			}

			resp, bfErr := client.ImageEditRequest(schemas.NewBifrostContext(ctx, schemas.NoDeadline), req)
			switch {
			case bfErr != nil:
				return nil, bfErr
			case resp == nil:
				return nil, &schemas.BifrostError{
					IsBifrostError: true,
					Error: &schemas.ErrorField{
						Message: "No image edit response returned",
					},
				}
			default:
				return resp, nil
			}
		}

		result, resultErr := WithImageGenerationRetry(t, editRetry, retryMeta, expected, scenario, sendEdit)
		if resultErr != nil {
			t.Fatalf("❌ Image edit failed: %v", GetErrorMessage(resultErr))
		}

		// Response shape.
		if result == nil {
			t.Fatal("❌ Image edit returned nil response")
		}
		if len(result.Data) == 0 {
			t.Fatal("❌ Image edit returned no image data")
		}

		// Only the first image is inspected.
		first := result.Data[0]
		if first.B64JSON == "" && first.URL == "" {
			t.Fatal("❌ Image data missing both b64_json and URL")
		}

		// A base64 payload must be decodable and reasonably sized.
		if first.B64JSON != "" {
			payload, err := base64.StdEncoding.DecodeString(first.B64JSON)
			if err != nil {
				t.Fatalf("❌ Failed to decode base64 image data: %v", err)
			}
			if len(payload) == 0 {
				t.Fatalf("❌ Decoded image data is empty")
			}

			// Only the header is decoded; that is enough for dimensions.
			cfg, kind, err := image.DecodeConfig(bytes.NewReader(payload))
			if err != nil {
				t.Fatalf("❌ Failed to decode image config: %v (format: %s)", err, kind)
			}
			if cfg.Width < 256 || cfg.Height < 256 {
				t.Errorf("❌ Image dimensions too small: got %dx%d, expected at least 256x256", cfg.Width, cfg.Height)
			}
		}

		// Usage is optional; a zero total is only worth a warning.
		if usage := result.Usage; usage != nil && usage.TotalTokens == 0 {
			t.Logf("⚠️ Usage total_tokens is 0 (may be provider-specific)")
		}

		// Metadata attached to the response.
		extra := result.ExtraFields
		if extra.Provider == "" {
			t.Error("❌ ExtraFields.Provider is empty")
		}
		if extra.OriginalModelRequested == "" {
			t.Error("❌ ExtraFields.OriginalModelRequested is empty")
		}
		if extra.RequestType != schemas.ImageEditRequest {
			t.Errorf("❌ ExtraFields.RequestType mismatch: got %s, expected %s", extra.RequestType, schemas.ImageEditRequest)
		}

		t.Logf("✅ Image edit successful: ID=%s, Provider=%s, Model=%s, Images=%d",
			result.ID, extra.Provider, extra.OriginalModelRequested, len(result.Data))
	})
}
// RunImageEditStreamTest runs the streaming image edit (inpainting) scenario
// end to end against the configured provider.
//
// The scenario is skipped with a log line when it is disabled for the provider
// or when no image edit model is configured. Otherwise an "ImageEditStream"
// subtest loads the lion fixture, converts it to JPEG, builds a
// provider-specific mask of the same size and opens a stream through
// client.ImageEditStreamRequest, wrapped in the streaming retry helper. An
// attempt passes only when at least one image chunk and a completion chunk
// (carrying a URL or base64 payload) were received and no stream or validation
// errors were recorded.
//
// Every attempt gets its own 60 second timeout, so an attempt that times out
// does not leave later retries with an already-expired context.
func RunImageEditStreamTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, testConfig ComprehensiveTestConfig) {
	if !testConfig.Scenarios.ImageEditStream {
		t.Logf("Image edit streaming not supported for provider %s", testConfig.Provider)
		return
	}
	if testConfig.ImageEditModel == "" {
		t.Logf("Image edit streaming not configured for provider %s", testConfig.Provider)
		return
	}

	t.Run("ImageEditStream", func(t *testing.T) {
		if os.Getenv("SKIP_PARALLEL_TESTS") != "true" {
			t.Parallel()
		}

		retryConfig := GetTestRetryConfigForScenario("ImageEditStream", testConfig)
		retryContext := TestRetryContext{
			ScenarioName: "ImageEditStream",
			ExpectedBehavior: map[string]any{
				"should_generate_images": true,
			},
			TestMetadata: map[string]any{
				"provider": testConfig.Provider,
				"model":    testConfig.ImageEditModel,
			},
		}

		// Load test image
		lionBase64, err := GetLionBase64Image()
		if err != nil {
			t.Fatalf("Failed to load test image: %v", err)
		}

		// Convert base64 to bytes
		imageBytes, err := decodeBase64ImageToBytes(lionBase64)
		if err != nil {
			t.Fatalf("Failed to decode image: %v", err)
		}

		// Convert input image to JPEG (no transparency) to avoid provider compatibility issues
		imageBytes, imgWidth, imgHeight, err := convertImageToJPEG(imageBytes)
		if err != nil {
			t.Fatalf("Failed to convert image to JPEG: %v", err)
		}

		// Create mask image based on provider requirements
		// Azure and OpenAI require PNG with transparent background, others use JPEG
		maskBytes, err := createMaskImage(testConfig.Provider, imgWidth, imgHeight)
		if err != nil {
			t.Fatalf("Failed to create mask image: %v", err)
		}

		request := &schemas.BifrostImageEditRequest{
			Provider: testConfig.Provider,
			Model:    testConfig.ImageEditModel,
			Input: &schemas.ImageEditInput{
				Images: []schemas.ImageInput{
					{
						Image: imageBytes,
					},
				},
				Prompt: "Add a futuristic cityscape in the background",
			},
			Params: &schemas.ImageEditParameters{
				Size:    bifrost.Ptr("1024x1024"),
				Quality: bifrost.Ptr("low"),
				Type:    bifrost.Ptr("inpainting"),
				Mask:    maskBytes,
			},
			Fallbacks: testConfig.ImageEditFallbacks,
		}

		// The timeout context is (re)created by the request closure on every
		// attempt. A single context created here would be shared by all
		// retries: once the first attempt hit the deadline, every later
		// attempt would start with an expired context and fail immediately.
		//
		// NOTE(review): this relies on WithImageGenerationStreamRetry calling
		// the request closure and then the validator sequentially for each
		// attempt — confirm against the helper.
		const streamTimeout = 60 * time.Second
		streamCtx := ctx
		cancel := context.CancelFunc(func() {})
		// Deferred through a closure so the cancel func of the LAST attempt
		// is the one released when the subtest returns.
		defer func() { cancel() }()

		validationResult := WithImageGenerationStreamRetry(
			t,
			retryConfig,
			retryContext,
			func() (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
				// Release the previous attempt (no-op on the first call) and
				// open a fresh timeout window for this one.
				cancel()
				streamCtx, cancel = context.WithTimeout(ctx, streamTimeout)
				return client.ImageEditStreamRequest(schemas.NewBifrostContext(streamCtx, schemas.NoDeadline), request)
			},
			func(responseChannel chan *schemas.BifrostStreamChunk) ImageGenerationStreamValidationResult {
				// Pin the context belonging to the attempt being validated.
				attemptCtx := streamCtx

				var (
					receivedData     bool
					hasCompleted     bool
					streamErrors     []string
					validationErrors []string
				)

			readLoop:
				for {
					select {
					case response, ok := <-responseChannel:
						if !ok {
							// Channel closed: the stream is finished.
							break readLoop
						}
						if response == nil {
							streamErrors = append(streamErrors, "Received nil stream response")
							continue
						}
						if response.BifrostError != nil {
							streamErrors = append(streamErrors, fmt.Sprintf("Error in stream: %s", GetErrorMessage(response.BifrostError)))
							continue
						}
						if response.BifrostImageGenerationStreamResponse != nil {
							receivedData = true
							imgResp := response.BifrostImageGenerationStreamResponse
							// Check for completion event (can be ImageGenerationEventTypeCompleted or ImageEditEventTypeCompleted)
							if imgResp.Type == schemas.ImageGenerationEventTypeCompleted || imgResp.Type == schemas.ImageEditEventTypeCompleted {
								hasCompleted = true
								// Validate that completed images have actual data
								if imgResp.URL == "" && imgResp.B64JSON == "" {
									validationErrors = append(validationErrors, "Completion chunk received but image has no URL or B64JSON data")
								}
							}
						}
					case <-attemptCtx.Done():
						validationErrors = append(validationErrors, "Stream validation timed out")
						// Keep reading the channel in the background for up to
						// five seconds, until it is closed or the drain
						// deadline passes.
						drainCtx, drainCancel := context.WithTimeout(context.Background(), 5*time.Second)
						go func() {
							defer drainCancel()
							for {
								select {
								case _, ok := <-responseChannel:
									if !ok {
										return
									}
								case <-drainCtx.Done():
									return
								}
							}
						}()
						break readLoop
					}
				}

				// Stream errors should cause the test to fail - convert them to validation errors
				if len(streamErrors) > 0 {
					validationErrors = append(validationErrors, fmt.Sprintf("Stream errors encountered: %s", strings.Join(streamErrors, "; ")))
				}

				// Test passes only if: data received, completion received, and no errors (including stream errors)
				passed := receivedData && hasCompleted && len(validationErrors) == 0

				// Recorded after computing passed; both conditions already
				// force passed to false on their own.
				if !receivedData {
					validationErrors = append(validationErrors, "No stream data received")
				}
				if !hasCompleted {
					validationErrors = append(validationErrors, "No completion chunk received")
				}

				return ImageGenerationStreamValidationResult{
					Passed:       passed,
					Errors:       validationErrors,
					ReceivedData: receivedData,
					StreamErrors: streamErrors,
				}
			},
		)

		if !validationResult.Passed {
			allErrors := append(validationResult.Errors, validationResult.StreamErrors...)
			t.Fatalf("❌ Image edit stream validation failed: %s", strings.Join(allErrors, "; "))
		}
		if !validationResult.ReceivedData {
			t.Fatal("❌ No stream data received")
		}

		t.Logf("✅ Image edit stream successful: ReceivedData=%v, Errors=%d, StreamErrors=%d",
			validationResult.ReceivedData, len(validationResult.Errors), len(validationResult.StreamErrors))
	})
}