743 lines
25 KiB
Go
743 lines
25 KiB
Go
package bedrock
|
|
|
|
import (
|
|
"encoding/base64"
|
|
"fmt"
|
|
"strconv"
|
|
"strings"
|
|
|
|
"github.com/maximhq/bifrost/core/schemas"
|
|
)
|
|
|
|
// mapQualityToBedrock maps quality values to Bedrock format:
|
|
// - "low" and "medium" -> "standard"
|
|
// - "high" -> "premium"
|
|
// - "standard" and "premium" (case-insensitive) -> pass through as lowercase ("standard"/"premium")
|
|
func mapQualityToBedrock(quality *string) *string {
|
|
if quality == nil {
|
|
return nil
|
|
}
|
|
|
|
qualityLower := strings.ToLower(strings.TrimSpace(*quality))
|
|
|
|
switch qualityLower {
|
|
case "low", "medium":
|
|
return schemas.Ptr("standard")
|
|
case "high":
|
|
return schemas.Ptr("premium")
|
|
case "standard":
|
|
return schemas.Ptr("standard")
|
|
case "premium":
|
|
return schemas.Ptr("premium")
|
|
default:
|
|
return quality
|
|
}
|
|
}
|
|
|
|
// isStabilityAIModel reports whether model names a Stability AI model, i.e.
// whether it contains the substring "stability." when compared
// case-insensitively.
func isStabilityAIModel(model string) bool {
	const vendorMarker = "stability."
	lowered := strings.ToLower(model)
	return strings.Contains(lowered, vendorMarker)
}
|
|
|
|
// isPromptOnlyImageGenerationModel reports whether model should be treated as
// an image generation model that takes a flat {"prompt": "..."} payload (no
// taskType field). The check is a case-insensitive search for the substring
// "image" in the model name.
//
// NOTE(review): Stability AI identifiers such as "stability.stable-image-..."
// also contain "image" and therefore match here; presumably callers test
// isStabilityAIModel first — confirm against the call sites.
func isPromptOnlyImageGenerationModel(model string) bool {
	const marker = "image"
	return strings.Contains(strings.ToLower(model), marker)
}
|
|
|
|
// ToStabilityAIImageGenerationRequest converts a Bifrost image generation request to the Stability AI
|
|
// flat request format used by Bedrock (stability.stable-image-* models).
|
|
func ToStabilityAIImageGenerationRequest(request *schemas.BifrostImageGenerationRequest) (*StabilityAIImageGenerationRequest, error) {
|
|
if request == nil {
|
|
return nil, fmt.Errorf("request is nil")
|
|
}
|
|
if request.Input == nil {
|
|
return nil, fmt.Errorf("request input is required")
|
|
}
|
|
|
|
req := &StabilityAIImageGenerationRequest{
|
|
Prompt: request.Input.Prompt,
|
|
}
|
|
|
|
if request.Params != nil {
|
|
if request.Params.AspectRatio != nil {
|
|
req.AspectRatio = request.Params.AspectRatio
|
|
}
|
|
if request.Params.OutputFormat != nil {
|
|
req.OutputFormat = request.Params.OutputFormat
|
|
}
|
|
if request.Params.Seed != nil {
|
|
req.Seed = request.Params.Seed
|
|
}
|
|
if request.Params.NegativePrompt != nil {
|
|
req.NegativePrompt = request.Params.NegativePrompt
|
|
}
|
|
if request.Params.ExtraParams != nil {
|
|
// aspect_ratio may also arrive via ExtraParams if not in knownFields; skip if already set
|
|
if req.AspectRatio == nil {
|
|
if ar, ok := schemas.SafeExtractStringPointer(request.Params.ExtraParams["aspect_ratio"]); ok {
|
|
delete(request.Params.ExtraParams, "aspect_ratio")
|
|
req.AspectRatio = ar
|
|
}
|
|
}
|
|
req.ExtraParams = request.Params.ExtraParams
|
|
}
|
|
}
|
|
|
|
return req, nil
|
|
}
|
|
|
|
// ToBedrockImageGenerationRequest converts a Bifrost image generation request to a Bedrock image generation request
|
|
func ToBedrockImageGenerationRequest(request *schemas.BifrostImageGenerationRequest) (*BedrockImageGenerationRequest, error) {
|
|
if request == nil {
|
|
return nil, fmt.Errorf("request is nil")
|
|
}
|
|
|
|
if request.Input == nil {
|
|
return nil, fmt.Errorf("request input is required")
|
|
}
|
|
|
|
bedrockReq := &BedrockImageGenerationRequest{
|
|
TaskType: schemas.Ptr(TaskTypeTextImage),
|
|
TextToImageParams: &BedrockTextToImageParams{
|
|
Text: request.Input.Prompt,
|
|
},
|
|
ImageGenerationConfig: &ImageGenerationConfig{},
|
|
}
|
|
|
|
if request.Params != nil {
|
|
if request.Params.N != nil {
|
|
bedrockReq.ImageGenerationConfig.NumberOfImages = request.Params.N
|
|
}
|
|
if request.Params.NegativePrompt != nil {
|
|
bedrockReq.TextToImageParams.NegativeText = request.Params.NegativePrompt
|
|
}
|
|
if request.Params.Seed != nil {
|
|
bedrockReq.ImageGenerationConfig.Seed = request.Params.Seed
|
|
}
|
|
if request.Params.Quality != nil {
|
|
bedrockReq.ImageGenerationConfig.Quality = mapQualityToBedrock(request.Params.Quality)
|
|
}
|
|
if request.Params.Style != nil {
|
|
bedrockReq.TextToImageParams.Style = request.Params.Style
|
|
}
|
|
if request.Params.Size != nil && strings.TrimSpace(strings.ToLower(*request.Params.Size)) != "auto" {
|
|
|
|
size := strings.Split(strings.TrimSpace(strings.ToLower(*request.Params.Size)), "x")
|
|
if len(size) != 2 {
|
|
return nil, fmt.Errorf("invalid size format: expected 'WIDTHxHEIGHT', got %q", *request.Params.Size)
|
|
}
|
|
|
|
width, err := strconv.Atoi(size[0])
|
|
if err != nil {
|
|
return nil, fmt.Errorf("invalid width in size %q: %w", *request.Params.Size, err)
|
|
}
|
|
|
|
height, err := strconv.Atoi(size[1])
|
|
if err != nil {
|
|
return nil, fmt.Errorf("invalid height in size %q: %w", *request.Params.Size, err)
|
|
}
|
|
|
|
bedrockReq.ImageGenerationConfig.Width = schemas.Ptr(width)
|
|
bedrockReq.ImageGenerationConfig.Height = schemas.Ptr(height)
|
|
}
|
|
if request.Params.ExtraParams != nil {
|
|
if cfgScale, ok := schemas.SafeExtractFloat64Pointer(request.Params.ExtraParams["cfgScale"]); ok {
|
|
delete(request.Params.ExtraParams, "cfgScale")
|
|
bedrockReq.ImageGenerationConfig.CfgScale = cfgScale
|
|
}
|
|
bedrockReq.ExtraParams = request.Params.ExtraParams
|
|
}
|
|
}
|
|
|
|
return bedrockReq, nil
|
|
}
|
|
|
|
// ToStabilityAIImageGenerationResponse converts a BifrostImageGenerationResponse back to
|
|
// the native Bedrock invoke API response format used by Stability AI models.
|
|
// Stability AI models use the same BedrockImageGenerationResponse format as Titan/Nova Canvas.
|
|
func ToStabilityAIImageGenerationResponse(response *schemas.BifrostImageGenerationResponse) (*BedrockImageGenerationResponse, error) {
|
|
if response == nil {
|
|
return nil, fmt.Errorf("response is nil")
|
|
}
|
|
result := &BedrockImageGenerationResponse{}
|
|
for _, d := range response.Data {
|
|
result.Images = append(result.Images, d.B64JSON)
|
|
}
|
|
if response.ImageGenerationResponseParameters != nil {
|
|
result.FinishReasons = response.ImageGenerationResponseParameters.FinishReasons
|
|
result.Seeds = response.ImageGenerationResponseParameters.Seeds
|
|
}
|
|
return result, nil
|
|
}
|
|
|
|
// ToBedrockImageVariationRequest converts a Bifrost image variation request to a Bedrock image variation request
|
|
func ToBedrockImageVariationRequest(request *schemas.BifrostImageVariationRequest) (*BedrockImageVariationRequest, error) {
|
|
if request == nil {
|
|
return nil, fmt.Errorf("request is nil")
|
|
}
|
|
|
|
if request.Input == nil || request.Input.Image.Image == nil || len(request.Input.Image.Image) == 0 {
|
|
return nil, fmt.Errorf("request.Input.Image is required")
|
|
}
|
|
|
|
bedrockReq := &BedrockImageVariationRequest{
|
|
TaskType: schemas.Ptr(TaskTypeImageVariation),
|
|
ImageVariationParams: &BedrockImageVariationParams{
|
|
Images: []string{},
|
|
},
|
|
ImageGenerationConfig: &ImageGenerationConfig{},
|
|
}
|
|
|
|
// Convert all images to base64 strings
|
|
// Primary image from Input.Image
|
|
imageBase64 := base64.StdEncoding.EncodeToString(request.Input.Image.Image)
|
|
bedrockReq.ImageVariationParams.Images = append(bedrockReq.ImageVariationParams.Images, imageBase64)
|
|
|
|
// Additional images from ExtraParams (stored as [][]byte)
|
|
if request.Params != nil && request.Params.ExtraParams != nil {
|
|
if additionalImages, ok := request.Params.ExtraParams["images"]; ok {
|
|
delete(request.Params.ExtraParams, "images")
|
|
// Handle array of byte arrays (stored by HTTP handler)
|
|
if imagesArray, ok := additionalImages.([][]byte); ok {
|
|
for _, imgBytes := range imagesArray {
|
|
if len(imgBytes) > 0 {
|
|
additionalBase64 := base64.StdEncoding.EncodeToString(imgBytes)
|
|
bedrockReq.ImageVariationParams.Images = append(bedrockReq.ImageVariationParams.Images, additionalBase64)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Extract optional fields from ExtraParams
|
|
if prompt, ok := schemas.SafeExtractStringPointer(request.Params.ExtraParams["prompt"]); ok {
|
|
delete(request.Params.ExtraParams, "prompt")
|
|
bedrockReq.ImageVariationParams.Text = prompt
|
|
}
|
|
if negativeText, ok := schemas.SafeExtractStringPointer(request.Params.ExtraParams["negativeText"]); ok {
|
|
delete(request.Params.ExtraParams, "negativeText")
|
|
bedrockReq.ImageVariationParams.NegativeText = negativeText
|
|
}
|
|
|
|
if similarityStrength, ok := schemas.SafeExtractFloat64Pointer(request.Params.ExtraParams["similarityStrength"]); ok {
|
|
delete(request.Params.ExtraParams, "similarityStrength")
|
|
// Validate similarityStrength range (0.2 to 1.0)
|
|
if *similarityStrength < 0.2 || *similarityStrength > 1.0 {
|
|
return nil, fmt.Errorf("similarityStrength must be between 0.2 and 1.0, got %f", *similarityStrength)
|
|
}
|
|
bedrockReq.ImageVariationParams.SimilarityStrength = similarityStrength
|
|
}
|
|
bedrockReq.ExtraParams = request.Params.ExtraParams
|
|
}
|
|
|
|
// Map standard params to ImageGenerationConfig
|
|
if request.Params != nil {
|
|
if request.Params.N != nil {
|
|
bedrockReq.ImageGenerationConfig.NumberOfImages = request.Params.N
|
|
}
|
|
|
|
if request.Params.Size != nil && strings.TrimSpace(strings.ToLower(*request.Params.Size)) != "auto" {
|
|
size := strings.Split(strings.TrimSpace(strings.ToLower(*request.Params.Size)), "x")
|
|
if len(size) != 2 {
|
|
return nil, fmt.Errorf("invalid size format: expected 'WIDTHxHEIGHT', got %q", *request.Params.Size)
|
|
}
|
|
|
|
width, err := strconv.Atoi(size[0])
|
|
if err != nil {
|
|
return nil, fmt.Errorf("invalid width in size %q: %w", *request.Params.Size, err)
|
|
}
|
|
|
|
height, err := strconv.Atoi(size[1])
|
|
if err != nil {
|
|
return nil, fmt.Errorf("invalid height in size %q: %w", *request.Params.Size, err)
|
|
}
|
|
|
|
bedrockReq.ImageGenerationConfig.Width = schemas.Ptr(width)
|
|
bedrockReq.ImageGenerationConfig.Height = schemas.Ptr(height)
|
|
}
|
|
|
|
// Extract quality and cfgScale from ExtraParams
|
|
if request.Params.ExtraParams != nil {
|
|
if quality, ok := schemas.SafeExtractStringPointer(request.Params.ExtraParams["quality"]); ok {
|
|
bedrockReq.ImageGenerationConfig.Quality = mapQualityToBedrock(quality)
|
|
}
|
|
|
|
if cfgScale, ok := schemas.SafeExtractFloat64Pointer(request.Params.ExtraParams["cfgScale"]); ok {
|
|
bedrockReq.ImageGenerationConfig.CfgScale = cfgScale
|
|
}
|
|
}
|
|
}
|
|
|
|
return bedrockReq, nil
|
|
}
|
|
|
|
// ToBedrockImageEditRequest converts a Bifrost image edit request to a Bedrock image edit request
|
|
func ToBedrockImageEditRequest(request *schemas.BifrostImageEditRequest) (*BedrockImageEditRequest, error) {
|
|
// Validate request
|
|
if request == nil || request.Input == nil {
|
|
return nil, fmt.Errorf("request or input is nil")
|
|
}
|
|
|
|
if len(request.Input.Images) == 0 || len(request.Input.Images[0].Image) == 0 {
|
|
return nil, fmt.Errorf("at least one image is required")
|
|
}
|
|
|
|
// Validate and extract type (required)
|
|
if request.Params == nil || request.Params.Type == nil {
|
|
return nil, fmt.Errorf("type field is required (must be inpainting, outpainting, or background_removal)")
|
|
}
|
|
|
|
editType := strings.ToLower(*request.Params.Type)
|
|
|
|
// Convert first image to base64
|
|
imageBase64 := base64.StdEncoding.EncodeToString(request.Input.Images[0].Image)
|
|
|
|
bedrockReq := &BedrockImageEditRequest{}
|
|
|
|
switch editType {
|
|
case "inpainting":
|
|
bedrockReq.TaskType = schemas.Ptr(TaskTypeInpainting)
|
|
bedrockReq.InPaintingParams = buildInPaintingParams(imageBase64, request)
|
|
bedrockReq.ImageGenerationConfig = buildImageGenerationConfig(request.Params)
|
|
|
|
case "outpainting":
|
|
bedrockReq.TaskType = schemas.Ptr(TaskTypeOutpainting)
|
|
bedrockReq.OutPaintingParams = buildOutPaintingParams(imageBase64, request)
|
|
bedrockReq.ImageGenerationConfig = buildImageGenerationConfig(request.Params)
|
|
|
|
case "background_removal":
|
|
bedrockReq.TaskType = schemas.Ptr(TaskTypeBackgroundRemoval)
|
|
bedrockReq.BackgroundRemovalParams = &BedrockBackgroundRemovalParams{
|
|
Image: imageBase64,
|
|
}
|
|
|
|
default:
|
|
return nil, fmt.Errorf("unsupported type for Bedrock: %s (must be inpainting, outpainting, or background_removal)", editType)
|
|
}
|
|
|
|
bedrockReq.ExtraParams = request.Params.ExtraParams
|
|
return bedrockReq, nil
|
|
}
|
|
|
|
// Helper functions
|
|
func buildInPaintingParams(imageBase64 string, request *schemas.BifrostImageEditRequest) *BedrockInPaintingParams {
|
|
params := &BedrockInPaintingParams{
|
|
Image: imageBase64,
|
|
Text: request.Input.Prompt,
|
|
}
|
|
|
|
if request.Params.NegativePrompt != nil {
|
|
params.NegativeText = request.Params.NegativePrompt
|
|
}
|
|
|
|
if request.Params.ExtraParams != nil {
|
|
if maskPrompt, ok := schemas.SafeExtractStringPointer(request.Params.ExtraParams["mask_prompt"]); ok {
|
|
delete(request.Params.ExtraParams, "mask_prompt")
|
|
params.MaskPrompt = maskPrompt
|
|
}
|
|
if returnMask, ok := schemas.SafeExtractBoolPointer(request.Params.ExtraParams["return_mask"]); ok {
|
|
delete(request.Params.ExtraParams, "return_mask")
|
|
params.ReturnMask = returnMask
|
|
}
|
|
}
|
|
|
|
// Convert mask to base64 if present
|
|
if len(request.Params.Mask) > 0 {
|
|
maskBase64 := base64.StdEncoding.EncodeToString(request.Params.Mask)
|
|
params.MaskImage = &maskBase64
|
|
}
|
|
|
|
return params
|
|
}
|
|
|
|
func buildOutPaintingParams(imageBase64 string, request *schemas.BifrostImageEditRequest) *BedrockOutPaintingParams {
|
|
params := &BedrockOutPaintingParams{
|
|
Text: request.Input.Prompt,
|
|
Image: imageBase64,
|
|
}
|
|
|
|
if request.Params.NegativePrompt != nil {
|
|
params.NegativeText = request.Params.NegativePrompt
|
|
}
|
|
|
|
if request.Params.ExtraParams != nil {
|
|
if maskPrompt, ok := schemas.SafeExtractStringPointer(request.Params.ExtraParams["mask_prompt"]); ok {
|
|
delete(request.Params.ExtraParams, "mask_prompt")
|
|
params.MaskPrompt = maskPrompt
|
|
}
|
|
if returnMask, ok := schemas.SafeExtractBoolPointer(request.Params.ExtraParams["return_mask"]); ok {
|
|
delete(request.Params.ExtraParams, "return_mask")
|
|
params.ReturnMask = returnMask
|
|
}
|
|
if outPaintingMode, ok := schemas.SafeExtractStringPointer(request.Params.ExtraParams["outpainting_mode"]); ok {
|
|
// Validate mode
|
|
mode := strings.ToUpper(*outPaintingMode)
|
|
if mode == "DEFAULT" || mode == "PRECISE" {
|
|
delete(request.Params.ExtraParams, "outpainting_mode")
|
|
params.OutPaintingMode = &mode
|
|
}
|
|
}
|
|
}
|
|
|
|
// Convert mask to base64 if present
|
|
if len(request.Params.Mask) > 0 {
|
|
maskBase64 := base64.StdEncoding.EncodeToString(request.Params.Mask)
|
|
params.MaskImage = &maskBase64
|
|
}
|
|
|
|
return params
|
|
}
|
|
|
|
func buildImageGenerationConfig(params *schemas.ImageEditParameters) *ImageGenerationConfig {
|
|
config := &ImageGenerationConfig{}
|
|
|
|
if params.N != nil {
|
|
config.NumberOfImages = params.N
|
|
}
|
|
|
|
// Parse size (reuse logic from image generation)
|
|
if params.Size != nil && strings.TrimSpace(strings.ToLower(*params.Size)) != "auto" {
|
|
size := strings.Split(strings.TrimSpace(strings.ToLower(*params.Size)), "x")
|
|
if len(size) == 2 {
|
|
width, err := strconv.Atoi(size[0])
|
|
if err == nil {
|
|
height, err := strconv.Atoi(size[1])
|
|
if err == nil {
|
|
config.Width = schemas.Ptr(width)
|
|
config.Height = schemas.Ptr(height)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
if params.Quality != nil {
|
|
config.Quality = mapQualityToBedrock(params.Quality)
|
|
}
|
|
|
|
if params.Seed != nil {
|
|
config.Seed = params.Seed
|
|
}
|
|
|
|
if params.ExtraParams != nil {
|
|
if cfgScale, ok := schemas.SafeExtractFloat64Pointer(params.ExtraParams["cfgScale"]); ok {
|
|
delete(params.ExtraParams, "cfgScale")
|
|
config.CfgScale = cfgScale
|
|
}
|
|
}
|
|
|
|
return config
|
|
}
|
|
|
|
// getStabilityAITaskTypeFromParams converts the generic edit-type value carried in
// BifrostImageEditParameters.Type into a Stability AI task type string. The comparison is
// case-insensitive. It returns "" when the value is not a recognized task type.
func getStabilityAITaskTypeFromParams(t string) string {
	key := strings.ToLower(t)

	switch key {
	// Names with aliases collapse onto one canonical task.
	case "inpainting", "inpaint":
		return "inpaint"
	case "outpainting", "outpaint":
		return "outpaint"
	case "background_removal", "remove_background":
		return "remove-bg"

	// Already in its final form.
	case "recolor":
		return key

	// snake_case names whose task type is the same words joined with hyphens.
	case "erase_object",
		"upscale_fast", "upscale_creative", "upscale_conservative",
		"search_replace",
		"control_sketch", "control_structure",
		"style_guide", "style_transfer":
		return strings.ReplaceAll(key, "_", "-")
	}

	return ""
}
|
|
|
|
// getStabilityAIEditTaskType infers the Stability AI edit task from the model name by
// looking for a known fragment in it (case-insensitively). It returns an error when the
// name contains none of the known fragments.
func getStabilityAIEditTaskType(model string) (string, error) {
	// Checked top to bottom; the first fragment found decides the task.
	known := [...]struct {
		fragment string
		task     string
	}{
		{"stable-creative-upscale", "upscale-creative"},
		{"stable-conservative-upscale", "upscale-conservative"},
		{"stable-fast-upscale", "upscale-fast"},
		{"stable-image-inpaint", "inpaint"},
		{"stable-outpaint", "outpaint"},
		{"stable-image-search-recolor", "recolor"},
		{"stable-image-search-replace", "search-replace"},
		{"stable-image-erase-object", "erase-object"},
		{"stable-image-remove-background", "remove-bg"},
		{"stable-image-control-sketch", "control-sketch"},
		{"stable-image-control-structure", "control-structure"},
		{"stable-image-style-guide", "style-guide"},
		{"stable-style-transfer", "style-transfer"},
	}

	lowered := strings.ToLower(model)
	for _, entry := range known {
		if strings.Contains(lowered, entry.fragment) {
			return entry.task, nil
		}
	}

	return "", fmt.Errorf("cannot determine task type from stability ai model name %q", model)
}
|
|
|
|
// ToStabilityAIImageEditRequest converts a Bifrost image edit request to the Stability AI flat request
|
|
// format used by Bedrock edit models. Only fields valid for the detected task type are populated.
|
|
// deployment is the resolved model identifier (after applying any deployment alias mapping); it is
|
|
// used for task-type inference so that alias-mapped models route correctly.
|
|
func ToStabilityAIImageEditRequest(request *schemas.BifrostImageEditRequest, deployment string) (*StabilityAIImageEditRequest, error) {
|
|
if request == nil || request.Input == nil {
|
|
return nil, fmt.Errorf("request or input is nil")
|
|
}
|
|
|
|
var taskType string
|
|
if request.Params != nil && request.Params.Type != nil {
|
|
taskType = getStabilityAITaskTypeFromParams(*request.Params.Type)
|
|
}
|
|
if taskType == "" {
|
|
var err error
|
|
taskType, err = getStabilityAIEditTaskType(deployment)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
req := &StabilityAIImageEditRequest{}
|
|
|
|
// Image sourcing
|
|
if taskType == "style-transfer" {
|
|
if len(request.Input.Images) != 2 {
|
|
return nil, fmt.Errorf("style-transfer requires exactly two images: init_image and style_image")
|
|
}
|
|
if len(request.Input.Images[0].Image) == 0 || len(request.Input.Images[1].Image) == 0 {
|
|
return nil, fmt.Errorf("style-transfer requires non-empty init_image and style_image")
|
|
}
|
|
initB64 := base64.StdEncoding.EncodeToString(request.Input.Images[0].Image)
|
|
styleB64 := base64.StdEncoding.EncodeToString(request.Input.Images[1].Image)
|
|
req.InitImage = &initB64
|
|
req.StyleImage = &styleB64
|
|
} else {
|
|
if len(request.Input.Images) == 0 || len(request.Input.Images[0].Image) == 0 {
|
|
return nil, fmt.Errorf("at least one image is required")
|
|
}
|
|
imageB64 := base64.StdEncoding.EncodeToString(request.Input.Images[0].Image)
|
|
req.Image = &imageB64
|
|
}
|
|
|
|
// Common fields populated based on task allowlist
|
|
prompt := request.Input.Prompt
|
|
switch taskType {
|
|
case "inpaint", "recolor", "search-replace", "control-sketch", "control-structure",
|
|
"style-guide", "upscale-creative", "upscale-conservative", "outpaint", "style-transfer":
|
|
req.Prompt = &prompt
|
|
}
|
|
|
|
// Negative prompt
|
|
if request.Params != nil && request.Params.NegativePrompt != nil {
|
|
switch taskType {
|
|
case "inpaint", "outpaint", "recolor", "search-replace", "control-sketch",
|
|
"control-structure", "style-guide", "upscale-creative", "upscale-conservative", "style-transfer":
|
|
req.NegativePrompt = request.Params.NegativePrompt
|
|
}
|
|
}
|
|
|
|
// Seed
|
|
if request.Params != nil && request.Params.Seed != nil {
|
|
switch taskType {
|
|
case "inpaint", "outpaint", "recolor", "search-replace", "erase-object", "control-sketch",
|
|
"control-structure", "style-guide", "upscale-creative", "upscale-conservative", "style-transfer":
|
|
req.Seed = request.Params.Seed
|
|
}
|
|
}
|
|
|
|
// Mask (from Params.Mask bytes)
|
|
if request.Params != nil && len(request.Params.Mask) > 0 {
|
|
switch taskType {
|
|
case "inpaint", "erase-object":
|
|
maskB64 := base64.StdEncoding.EncodeToString(request.Params.Mask)
|
|
req.Mask = &maskB64
|
|
}
|
|
}
|
|
|
|
// ExtraParams
|
|
if request.Params != nil {
|
|
// Typed OutputFormat takes priority over ExtraParams
|
|
if request.Params.OutputFormat != nil {
|
|
req.OutputFormat = request.Params.OutputFormat
|
|
}
|
|
|
|
if request.Params.ExtraParams != nil {
|
|
ep := make(map[string]interface{}, len(request.Params.ExtraParams))
|
|
for k, v := range request.Params.ExtraParams {
|
|
ep[k] = v
|
|
}
|
|
|
|
// output_format — all tasks (fallback if not already set by typed field)
|
|
if req.OutputFormat == nil {
|
|
if v, ok := schemas.SafeExtractStringPointer(ep["output_format"]); ok {
|
|
delete(ep, "output_format")
|
|
req.OutputFormat = v
|
|
}
|
|
}
|
|
|
|
// style_preset
|
|
switch taskType {
|
|
case "inpaint", "outpaint", "recolor", "search-replace", "control-sketch",
|
|
"control-structure", "style-guide", "upscale-creative":
|
|
if v, ok := schemas.SafeExtractStringPointer(ep["style_preset"]); ok {
|
|
delete(ep, "style_preset")
|
|
req.StylePreset = v
|
|
}
|
|
}
|
|
|
|
// grow_mask
|
|
switch taskType {
|
|
case "inpaint", "recolor", "search-replace", "erase-object":
|
|
if v, ok := schemas.SafeExtractIntPointer(ep["grow_mask"]); ok {
|
|
delete(ep, "grow_mask")
|
|
req.GrowMask = v
|
|
}
|
|
}
|
|
|
|
// outpaint directional fields
|
|
if taskType == "outpaint" {
|
|
if v, ok := schemas.SafeExtractIntPointer(ep["left"]); ok {
|
|
delete(ep, "left")
|
|
req.Left = v
|
|
}
|
|
if v, ok := schemas.SafeExtractIntPointer(ep["right"]); ok {
|
|
delete(ep, "right")
|
|
req.Right = v
|
|
}
|
|
if v, ok := schemas.SafeExtractIntPointer(ep["up"]); ok {
|
|
delete(ep, "up")
|
|
req.Up = v
|
|
}
|
|
if v, ok := schemas.SafeExtractIntPointer(ep["down"]); ok {
|
|
delete(ep, "down")
|
|
req.Down = v
|
|
}
|
|
}
|
|
|
|
// creativity
|
|
switch taskType {
|
|
case "upscale-creative", "upscale-conservative", "outpaint":
|
|
if v, ok := schemas.SafeExtractFloat64Pointer(ep["creativity"]); ok {
|
|
delete(ep, "creativity")
|
|
req.Creativity = v
|
|
}
|
|
}
|
|
|
|
// select_prompt (recolor)
|
|
if taskType == "recolor" {
|
|
if v, ok := schemas.SafeExtractStringPointer(ep["select_prompt"]); ok {
|
|
delete(ep, "select_prompt")
|
|
req.SelectPrompt = v
|
|
}
|
|
}
|
|
|
|
// search_prompt (search-replace)
|
|
if taskType == "search-replace" {
|
|
if v, ok := schemas.SafeExtractStringPointer(ep["search_prompt"]); ok {
|
|
delete(ep, "search_prompt")
|
|
req.SearchPrompt = v
|
|
}
|
|
}
|
|
|
|
// control_strength
|
|
switch taskType {
|
|
case "control-sketch", "control-structure":
|
|
if v, ok := schemas.SafeExtractFloat64Pointer(ep["control_strength"]); ok {
|
|
delete(ep, "control_strength")
|
|
req.ControlStrength = v
|
|
}
|
|
}
|
|
|
|
// style-guide fields
|
|
if taskType == "style-guide" {
|
|
if v, ok := schemas.SafeExtractStringPointer(ep["aspect_ratio"]); ok {
|
|
delete(ep, "aspect_ratio")
|
|
req.AspectRatio = v
|
|
}
|
|
if v, ok := schemas.SafeExtractFloat64Pointer(ep["fidelity"]); ok {
|
|
delete(ep, "fidelity")
|
|
req.Fidelity = v
|
|
}
|
|
}
|
|
|
|
// style-transfer fields
|
|
if taskType == "style-transfer" {
|
|
if v, ok := schemas.SafeExtractFloat64Pointer(ep["style_strength"]); ok {
|
|
delete(ep, "style_strength")
|
|
req.StyleStrength = v
|
|
}
|
|
if v, ok := schemas.SafeExtractFloat64Pointer(ep["composition_fidelity"]); ok {
|
|
delete(ep, "composition_fidelity")
|
|
req.CompositionFidelity = v
|
|
}
|
|
if v, ok := schemas.SafeExtractFloat64Pointer(ep["change_strength"]); ok {
|
|
delete(ep, "change_strength")
|
|
req.ChangeStrength = v
|
|
}
|
|
}
|
|
|
|
req.ExtraParams = ep
|
|
}
|
|
}
|
|
|
|
// Validate required per-task fields
|
|
if taskType == "recolor" && (req.SelectPrompt == nil || *req.SelectPrompt == "") {
|
|
return nil, fmt.Errorf("select_prompt is required for stability ai recolor task")
|
|
}
|
|
if taskType == "search-replace" && (req.SearchPrompt == nil || *req.SearchPrompt == "") {
|
|
return nil, fmt.Errorf("search_prompt is required for stability ai search-replace task")
|
|
}
|
|
|
|
return req, nil
|
|
}
|
|
|
|
// ToBifrostImageGenerationResponse converts a Bedrock image generation response to a Bifrost image generation response
|
|
func ToBifrostImageGenerationResponse(response *BedrockImageGenerationResponse) *schemas.BifrostImageGenerationResponse {
|
|
if response == nil {
|
|
return nil
|
|
}
|
|
|
|
bifrostResponse := &schemas.BifrostImageGenerationResponse{}
|
|
|
|
if len(response.FinishReasons) > 0 || len(response.Seeds) > 0 {
|
|
bifrostResponse.ImageGenerationResponseParameters = &schemas.ImageGenerationResponseParameters{
|
|
FinishReasons: append([]*string(nil), response.FinishReasons...),
|
|
Seeds: append([]int(nil), response.Seeds...),
|
|
}
|
|
}
|
|
|
|
for index, image := range response.Images {
|
|
bifrostResponse.Data = append(bifrostResponse.Data, schemas.ImageData{
|
|
B64JSON: image,
|
|
Index: index,
|
|
})
|
|
}
|
|
|
|
return bifrostResponse
|
|
}
|