Files
bifrost/core/providers/bedrock/images.go
Beyhan Oğur 880f412e2c first commit
2026-04-26 21:52:23 +03:00

743 lines
25 KiB
Go

package bedrock
import (
"encoding/base64"
"fmt"
"strconv"
"strings"
"github.com/maximhq/bifrost/core/schemas"
)
// mapQualityToBedrock maps quality values to Bedrock format:
// - "low" and "medium" -> "standard"
// - "high" -> "premium"
// - "standard" and "premium" (case-insensitive) -> pass through as lowercase ("standard"/"premium")
func mapQualityToBedrock(quality *string) *string {
if quality == nil {
return nil
}
qualityLower := strings.ToLower(strings.TrimSpace(*quality))
switch qualityLower {
case "low", "medium":
return schemas.Ptr("standard")
case "high":
return schemas.Ptr("premium")
case "standard":
return schemas.Ptr("standard")
case "premium":
return schemas.Ptr("premium")
default:
return quality
}
}
// isStabilityAIModel returns true if the model is a Stability AI model (contains "stability.")
func isStabilityAIModel(model string) bool {
return strings.Contains(strings.ToLower(model), "stability.")
}
// isPromptOnlyImageGenerationModel returns true for image generation models that use a flat
// {"prompt": "..."} payload (no taskType field). Covers Vertex Imagen and similar models.
// Stability AI is excluded here — it's handled separately because it also supports image edit.
func isPromptOnlyImageGenerationModel(model string) bool {
m := strings.ToLower(model)
return strings.Contains(m, "image")
}
// ToStabilityAIImageGenerationRequest converts a Bifrost image generation request to the Stability AI
// flat request format used by Bedrock (stability.stable-image-* models).
func ToStabilityAIImageGenerationRequest(request *schemas.BifrostImageGenerationRequest) (*StabilityAIImageGenerationRequest, error) {
if request == nil {
return nil, fmt.Errorf("request is nil")
}
if request.Input == nil {
return nil, fmt.Errorf("request input is required")
}
req := &StabilityAIImageGenerationRequest{
Prompt: request.Input.Prompt,
}
if request.Params != nil {
if request.Params.AspectRatio != nil {
req.AspectRatio = request.Params.AspectRatio
}
if request.Params.OutputFormat != nil {
req.OutputFormat = request.Params.OutputFormat
}
if request.Params.Seed != nil {
req.Seed = request.Params.Seed
}
if request.Params.NegativePrompt != nil {
req.NegativePrompt = request.Params.NegativePrompt
}
if request.Params.ExtraParams != nil {
// aspect_ratio may also arrive via ExtraParams if not in knownFields; skip if already set
if req.AspectRatio == nil {
if ar, ok := schemas.SafeExtractStringPointer(request.Params.ExtraParams["aspect_ratio"]); ok {
delete(request.Params.ExtraParams, "aspect_ratio")
req.AspectRatio = ar
}
}
req.ExtraParams = request.Params.ExtraParams
}
}
return req, nil
}
// ToBedrockImageGenerationRequest converts a Bifrost image generation request to a Bedrock image generation request
func ToBedrockImageGenerationRequest(request *schemas.BifrostImageGenerationRequest) (*BedrockImageGenerationRequest, error) {
if request == nil {
return nil, fmt.Errorf("request is nil")
}
if request.Input == nil {
return nil, fmt.Errorf("request input is required")
}
bedrockReq := &BedrockImageGenerationRequest{
TaskType: schemas.Ptr(TaskTypeTextImage),
TextToImageParams: &BedrockTextToImageParams{
Text: request.Input.Prompt,
},
ImageGenerationConfig: &ImageGenerationConfig{},
}
if request.Params != nil {
if request.Params.N != nil {
bedrockReq.ImageGenerationConfig.NumberOfImages = request.Params.N
}
if request.Params.NegativePrompt != nil {
bedrockReq.TextToImageParams.NegativeText = request.Params.NegativePrompt
}
if request.Params.Seed != nil {
bedrockReq.ImageGenerationConfig.Seed = request.Params.Seed
}
if request.Params.Quality != nil {
bedrockReq.ImageGenerationConfig.Quality = mapQualityToBedrock(request.Params.Quality)
}
if request.Params.Style != nil {
bedrockReq.TextToImageParams.Style = request.Params.Style
}
if request.Params.Size != nil && strings.TrimSpace(strings.ToLower(*request.Params.Size)) != "auto" {
size := strings.Split(strings.TrimSpace(strings.ToLower(*request.Params.Size)), "x")
if len(size) != 2 {
return nil, fmt.Errorf("invalid size format: expected 'WIDTHxHEIGHT', got %q", *request.Params.Size)
}
width, err := strconv.Atoi(size[0])
if err != nil {
return nil, fmt.Errorf("invalid width in size %q: %w", *request.Params.Size, err)
}
height, err := strconv.Atoi(size[1])
if err != nil {
return nil, fmt.Errorf("invalid height in size %q: %w", *request.Params.Size, err)
}
bedrockReq.ImageGenerationConfig.Width = schemas.Ptr(width)
bedrockReq.ImageGenerationConfig.Height = schemas.Ptr(height)
}
if request.Params.ExtraParams != nil {
if cfgScale, ok := schemas.SafeExtractFloat64Pointer(request.Params.ExtraParams["cfgScale"]); ok {
delete(request.Params.ExtraParams, "cfgScale")
bedrockReq.ImageGenerationConfig.CfgScale = cfgScale
}
bedrockReq.ExtraParams = request.Params.ExtraParams
}
}
return bedrockReq, nil
}
// ToStabilityAIImageGenerationResponse converts a BifrostImageGenerationResponse back to
// the native Bedrock invoke API response format used by Stability AI models.
// Stability AI models use the same BedrockImageGenerationResponse format as Titan/Nova Canvas.
func ToStabilityAIImageGenerationResponse(response *schemas.BifrostImageGenerationResponse) (*BedrockImageGenerationResponse, error) {
if response == nil {
return nil, fmt.Errorf("response is nil")
}
result := &BedrockImageGenerationResponse{}
for _, d := range response.Data {
result.Images = append(result.Images, d.B64JSON)
}
if response.ImageGenerationResponseParameters != nil {
result.FinishReasons = response.ImageGenerationResponseParameters.FinishReasons
result.Seeds = response.ImageGenerationResponseParameters.Seeds
}
return result, nil
}
// ToBedrockImageVariationRequest converts a Bifrost image variation request to a Bedrock image variation request
func ToBedrockImageVariationRequest(request *schemas.BifrostImageVariationRequest) (*BedrockImageVariationRequest, error) {
if request == nil {
return nil, fmt.Errorf("request is nil")
}
if request.Input == nil || request.Input.Image.Image == nil || len(request.Input.Image.Image) == 0 {
return nil, fmt.Errorf("request.Input.Image is required")
}
bedrockReq := &BedrockImageVariationRequest{
TaskType: schemas.Ptr(TaskTypeImageVariation),
ImageVariationParams: &BedrockImageVariationParams{
Images: []string{},
},
ImageGenerationConfig: &ImageGenerationConfig{},
}
// Convert all images to base64 strings
// Primary image from Input.Image
imageBase64 := base64.StdEncoding.EncodeToString(request.Input.Image.Image)
bedrockReq.ImageVariationParams.Images = append(bedrockReq.ImageVariationParams.Images, imageBase64)
// Additional images from ExtraParams (stored as [][]byte)
if request.Params != nil && request.Params.ExtraParams != nil {
if additionalImages, ok := request.Params.ExtraParams["images"]; ok {
delete(request.Params.ExtraParams, "images")
// Handle array of byte arrays (stored by HTTP handler)
if imagesArray, ok := additionalImages.([][]byte); ok {
for _, imgBytes := range imagesArray {
if len(imgBytes) > 0 {
additionalBase64 := base64.StdEncoding.EncodeToString(imgBytes)
bedrockReq.ImageVariationParams.Images = append(bedrockReq.ImageVariationParams.Images, additionalBase64)
}
}
}
}
// Extract optional fields from ExtraParams
if prompt, ok := schemas.SafeExtractStringPointer(request.Params.ExtraParams["prompt"]); ok {
delete(request.Params.ExtraParams, "prompt")
bedrockReq.ImageVariationParams.Text = prompt
}
if negativeText, ok := schemas.SafeExtractStringPointer(request.Params.ExtraParams["negativeText"]); ok {
delete(request.Params.ExtraParams, "negativeText")
bedrockReq.ImageVariationParams.NegativeText = negativeText
}
if similarityStrength, ok := schemas.SafeExtractFloat64Pointer(request.Params.ExtraParams["similarityStrength"]); ok {
delete(request.Params.ExtraParams, "similarityStrength")
// Validate similarityStrength range (0.2 to 1.0)
if *similarityStrength < 0.2 || *similarityStrength > 1.0 {
return nil, fmt.Errorf("similarityStrength must be between 0.2 and 1.0, got %f", *similarityStrength)
}
bedrockReq.ImageVariationParams.SimilarityStrength = similarityStrength
}
bedrockReq.ExtraParams = request.Params.ExtraParams
}
// Map standard params to ImageGenerationConfig
if request.Params != nil {
if request.Params.N != nil {
bedrockReq.ImageGenerationConfig.NumberOfImages = request.Params.N
}
if request.Params.Size != nil && strings.TrimSpace(strings.ToLower(*request.Params.Size)) != "auto" {
size := strings.Split(strings.TrimSpace(strings.ToLower(*request.Params.Size)), "x")
if len(size) != 2 {
return nil, fmt.Errorf("invalid size format: expected 'WIDTHxHEIGHT', got %q", *request.Params.Size)
}
width, err := strconv.Atoi(size[0])
if err != nil {
return nil, fmt.Errorf("invalid width in size %q: %w", *request.Params.Size, err)
}
height, err := strconv.Atoi(size[1])
if err != nil {
return nil, fmt.Errorf("invalid height in size %q: %w", *request.Params.Size, err)
}
bedrockReq.ImageGenerationConfig.Width = schemas.Ptr(width)
bedrockReq.ImageGenerationConfig.Height = schemas.Ptr(height)
}
// Extract quality and cfgScale from ExtraParams
if request.Params.ExtraParams != nil {
if quality, ok := schemas.SafeExtractStringPointer(request.Params.ExtraParams["quality"]); ok {
bedrockReq.ImageGenerationConfig.Quality = mapQualityToBedrock(quality)
}
if cfgScale, ok := schemas.SafeExtractFloat64Pointer(request.Params.ExtraParams["cfgScale"]); ok {
bedrockReq.ImageGenerationConfig.CfgScale = cfgScale
}
}
}
return bedrockReq, nil
}
// ToBedrockImageEditRequest converts a Bifrost image edit request to a Bedrock image edit request
func ToBedrockImageEditRequest(request *schemas.BifrostImageEditRequest) (*BedrockImageEditRequest, error) {
// Validate request
if request == nil || request.Input == nil {
return nil, fmt.Errorf("request or input is nil")
}
if len(request.Input.Images) == 0 || len(request.Input.Images[0].Image) == 0 {
return nil, fmt.Errorf("at least one image is required")
}
// Validate and extract type (required)
if request.Params == nil || request.Params.Type == nil {
return nil, fmt.Errorf("type field is required (must be inpainting, outpainting, or background_removal)")
}
editType := strings.ToLower(*request.Params.Type)
// Convert first image to base64
imageBase64 := base64.StdEncoding.EncodeToString(request.Input.Images[0].Image)
bedrockReq := &BedrockImageEditRequest{}
switch editType {
case "inpainting":
bedrockReq.TaskType = schemas.Ptr(TaskTypeInpainting)
bedrockReq.InPaintingParams = buildInPaintingParams(imageBase64, request)
bedrockReq.ImageGenerationConfig = buildImageGenerationConfig(request.Params)
case "outpainting":
bedrockReq.TaskType = schemas.Ptr(TaskTypeOutpainting)
bedrockReq.OutPaintingParams = buildOutPaintingParams(imageBase64, request)
bedrockReq.ImageGenerationConfig = buildImageGenerationConfig(request.Params)
case "background_removal":
bedrockReq.TaskType = schemas.Ptr(TaskTypeBackgroundRemoval)
bedrockReq.BackgroundRemovalParams = &BedrockBackgroundRemovalParams{
Image: imageBase64,
}
default:
return nil, fmt.Errorf("unsupported type for Bedrock: %s (must be inpainting, outpainting, or background_removal)", editType)
}
bedrockReq.ExtraParams = request.Params.ExtraParams
return bedrockReq, nil
}
// Helper functions
func buildInPaintingParams(imageBase64 string, request *schemas.BifrostImageEditRequest) *BedrockInPaintingParams {
params := &BedrockInPaintingParams{
Image: imageBase64,
Text: request.Input.Prompt,
}
if request.Params.NegativePrompt != nil {
params.NegativeText = request.Params.NegativePrompt
}
if request.Params.ExtraParams != nil {
if maskPrompt, ok := schemas.SafeExtractStringPointer(request.Params.ExtraParams["mask_prompt"]); ok {
delete(request.Params.ExtraParams, "mask_prompt")
params.MaskPrompt = maskPrompt
}
if returnMask, ok := schemas.SafeExtractBoolPointer(request.Params.ExtraParams["return_mask"]); ok {
delete(request.Params.ExtraParams, "return_mask")
params.ReturnMask = returnMask
}
}
// Convert mask to base64 if present
if len(request.Params.Mask) > 0 {
maskBase64 := base64.StdEncoding.EncodeToString(request.Params.Mask)
params.MaskImage = &maskBase64
}
return params
}
func buildOutPaintingParams(imageBase64 string, request *schemas.BifrostImageEditRequest) *BedrockOutPaintingParams {
params := &BedrockOutPaintingParams{
Text: request.Input.Prompt,
Image: imageBase64,
}
if request.Params.NegativePrompt != nil {
params.NegativeText = request.Params.NegativePrompt
}
if request.Params.ExtraParams != nil {
if maskPrompt, ok := schemas.SafeExtractStringPointer(request.Params.ExtraParams["mask_prompt"]); ok {
delete(request.Params.ExtraParams, "mask_prompt")
params.MaskPrompt = maskPrompt
}
if returnMask, ok := schemas.SafeExtractBoolPointer(request.Params.ExtraParams["return_mask"]); ok {
delete(request.Params.ExtraParams, "return_mask")
params.ReturnMask = returnMask
}
if outPaintingMode, ok := schemas.SafeExtractStringPointer(request.Params.ExtraParams["outpainting_mode"]); ok {
// Validate mode
mode := strings.ToUpper(*outPaintingMode)
if mode == "DEFAULT" || mode == "PRECISE" {
delete(request.Params.ExtraParams, "outpainting_mode")
params.OutPaintingMode = &mode
}
}
}
// Convert mask to base64 if present
if len(request.Params.Mask) > 0 {
maskBase64 := base64.StdEncoding.EncodeToString(request.Params.Mask)
params.MaskImage = &maskBase64
}
return params
}
func buildImageGenerationConfig(params *schemas.ImageEditParameters) *ImageGenerationConfig {
config := &ImageGenerationConfig{}
if params.N != nil {
config.NumberOfImages = params.N
}
// Parse size (reuse logic from image generation)
if params.Size != nil && strings.TrimSpace(strings.ToLower(*params.Size)) != "auto" {
size := strings.Split(strings.TrimSpace(strings.ToLower(*params.Size)), "x")
if len(size) == 2 {
width, err := strconv.Atoi(size[0])
if err == nil {
height, err := strconv.Atoi(size[1])
if err == nil {
config.Width = schemas.Ptr(width)
config.Height = schemas.Ptr(height)
}
}
}
}
if params.Quality != nil {
config.Quality = mapQualityToBedrock(params.Quality)
}
if params.Seed != nil {
config.Seed = params.Seed
}
if params.ExtraParams != nil {
if cfgScale, ok := schemas.SafeExtractFloat64Pointer(params.ExtraParams["cfgScale"]); ok {
delete(params.ExtraParams, "cfgScale")
config.CfgScale = cfgScale
}
}
return config
}
// getStabilityAITaskTypeFromParams maps the generic BifrostImageEditParameters.Type value
// to a Stability AI task type string. Returns "" if the value is not a recognized Stability AI task type.
func getStabilityAITaskTypeFromParams(t string) string {
switch strings.ToLower(t) {
case "inpainting", "inpaint":
return "inpaint"
case "outpainting", "outpaint":
return "outpaint"
case "background_removal", "remove_background":
return "remove-bg"
case "erase_object":
return "erase-object"
case "upscale_fast":
return "upscale-fast"
case "upscale_creative":
return "upscale-creative"
case "upscale_conservative":
return "upscale-conservative"
case "recolor":
return "recolor"
case "search_replace":
return "search-replace"
case "control_sketch":
return "control-sketch"
case "control_structure":
return "control-structure"
case "style_guide":
return "style-guide"
case "style_transfer":
return "style-transfer"
default:
return ""
}
}
// getStabilityAIEditTaskType infers the Stability AI edit task from the model name.
// Returns an error if the model name does not match any known pattern.
func getStabilityAIEditTaskType(model string) (string, error) {
m := strings.ToLower(model)
switch {
case strings.Contains(m, "stable-creative-upscale"):
return "upscale-creative", nil
case strings.Contains(m, "stable-conservative-upscale"):
return "upscale-conservative", nil
case strings.Contains(m, "stable-fast-upscale"):
return "upscale-fast", nil
case strings.Contains(m, "stable-image-inpaint"):
return "inpaint", nil
case strings.Contains(m, "stable-outpaint"):
return "outpaint", nil
case strings.Contains(m, "stable-image-search-recolor"):
return "recolor", nil
case strings.Contains(m, "stable-image-search-replace"):
return "search-replace", nil
case strings.Contains(m, "stable-image-erase-object"):
return "erase-object", nil
case strings.Contains(m, "stable-image-remove-background"):
return "remove-bg", nil
case strings.Contains(m, "stable-image-control-sketch"):
return "control-sketch", nil
case strings.Contains(m, "stable-image-control-structure"):
return "control-structure", nil
case strings.Contains(m, "stable-image-style-guide"):
return "style-guide", nil
case strings.Contains(m, "stable-style-transfer"):
return "style-transfer", nil
default:
return "", fmt.Errorf("cannot determine task type from stability ai model name %q", model)
}
}
// ToStabilityAIImageEditRequest converts a Bifrost image edit request to the Stability AI flat request
// format used by Bedrock edit models. Only fields valid for the detected task type are populated.
// deployment is the resolved model identifier (after applying any deployment alias mapping); it is
// used for task-type inference so that alias-mapped models route correctly.
func ToStabilityAIImageEditRequest(request *schemas.BifrostImageEditRequest, deployment string) (*StabilityAIImageEditRequest, error) {
if request == nil || request.Input == nil {
return nil, fmt.Errorf("request or input is nil")
}
var taskType string
if request.Params != nil && request.Params.Type != nil {
taskType = getStabilityAITaskTypeFromParams(*request.Params.Type)
}
if taskType == "" {
var err error
taskType, err = getStabilityAIEditTaskType(deployment)
if err != nil {
return nil, err
}
}
req := &StabilityAIImageEditRequest{}
// Image sourcing
if taskType == "style-transfer" {
if len(request.Input.Images) != 2 {
return nil, fmt.Errorf("style-transfer requires exactly two images: init_image and style_image")
}
if len(request.Input.Images[0].Image) == 0 || len(request.Input.Images[1].Image) == 0 {
return nil, fmt.Errorf("style-transfer requires non-empty init_image and style_image")
}
initB64 := base64.StdEncoding.EncodeToString(request.Input.Images[0].Image)
styleB64 := base64.StdEncoding.EncodeToString(request.Input.Images[1].Image)
req.InitImage = &initB64
req.StyleImage = &styleB64
} else {
if len(request.Input.Images) == 0 || len(request.Input.Images[0].Image) == 0 {
return nil, fmt.Errorf("at least one image is required")
}
imageB64 := base64.StdEncoding.EncodeToString(request.Input.Images[0].Image)
req.Image = &imageB64
}
// Common fields populated based on task allowlist
prompt := request.Input.Prompt
switch taskType {
case "inpaint", "recolor", "search-replace", "control-sketch", "control-structure",
"style-guide", "upscale-creative", "upscale-conservative", "outpaint", "style-transfer":
req.Prompt = &prompt
}
// Negative prompt
if request.Params != nil && request.Params.NegativePrompt != nil {
switch taskType {
case "inpaint", "outpaint", "recolor", "search-replace", "control-sketch",
"control-structure", "style-guide", "upscale-creative", "upscale-conservative", "style-transfer":
req.NegativePrompt = request.Params.NegativePrompt
}
}
// Seed
if request.Params != nil && request.Params.Seed != nil {
switch taskType {
case "inpaint", "outpaint", "recolor", "search-replace", "erase-object", "control-sketch",
"control-structure", "style-guide", "upscale-creative", "upscale-conservative", "style-transfer":
req.Seed = request.Params.Seed
}
}
// Mask (from Params.Mask bytes)
if request.Params != nil && len(request.Params.Mask) > 0 {
switch taskType {
case "inpaint", "erase-object":
maskB64 := base64.StdEncoding.EncodeToString(request.Params.Mask)
req.Mask = &maskB64
}
}
// ExtraParams
if request.Params != nil {
// Typed OutputFormat takes priority over ExtraParams
if request.Params.OutputFormat != nil {
req.OutputFormat = request.Params.OutputFormat
}
if request.Params.ExtraParams != nil {
ep := make(map[string]interface{}, len(request.Params.ExtraParams))
for k, v := range request.Params.ExtraParams {
ep[k] = v
}
// output_format — all tasks (fallback if not already set by typed field)
if req.OutputFormat == nil {
if v, ok := schemas.SafeExtractStringPointer(ep["output_format"]); ok {
delete(ep, "output_format")
req.OutputFormat = v
}
}
// style_preset
switch taskType {
case "inpaint", "outpaint", "recolor", "search-replace", "control-sketch",
"control-structure", "style-guide", "upscale-creative":
if v, ok := schemas.SafeExtractStringPointer(ep["style_preset"]); ok {
delete(ep, "style_preset")
req.StylePreset = v
}
}
// grow_mask
switch taskType {
case "inpaint", "recolor", "search-replace", "erase-object":
if v, ok := schemas.SafeExtractIntPointer(ep["grow_mask"]); ok {
delete(ep, "grow_mask")
req.GrowMask = v
}
}
// outpaint directional fields
if taskType == "outpaint" {
if v, ok := schemas.SafeExtractIntPointer(ep["left"]); ok {
delete(ep, "left")
req.Left = v
}
if v, ok := schemas.SafeExtractIntPointer(ep["right"]); ok {
delete(ep, "right")
req.Right = v
}
if v, ok := schemas.SafeExtractIntPointer(ep["up"]); ok {
delete(ep, "up")
req.Up = v
}
if v, ok := schemas.SafeExtractIntPointer(ep["down"]); ok {
delete(ep, "down")
req.Down = v
}
}
// creativity
switch taskType {
case "upscale-creative", "upscale-conservative", "outpaint":
if v, ok := schemas.SafeExtractFloat64Pointer(ep["creativity"]); ok {
delete(ep, "creativity")
req.Creativity = v
}
}
// select_prompt (recolor)
if taskType == "recolor" {
if v, ok := schemas.SafeExtractStringPointer(ep["select_prompt"]); ok {
delete(ep, "select_prompt")
req.SelectPrompt = v
}
}
// search_prompt (search-replace)
if taskType == "search-replace" {
if v, ok := schemas.SafeExtractStringPointer(ep["search_prompt"]); ok {
delete(ep, "search_prompt")
req.SearchPrompt = v
}
}
// control_strength
switch taskType {
case "control-sketch", "control-structure":
if v, ok := schemas.SafeExtractFloat64Pointer(ep["control_strength"]); ok {
delete(ep, "control_strength")
req.ControlStrength = v
}
}
// style-guide fields
if taskType == "style-guide" {
if v, ok := schemas.SafeExtractStringPointer(ep["aspect_ratio"]); ok {
delete(ep, "aspect_ratio")
req.AspectRatio = v
}
if v, ok := schemas.SafeExtractFloat64Pointer(ep["fidelity"]); ok {
delete(ep, "fidelity")
req.Fidelity = v
}
}
// style-transfer fields
if taskType == "style-transfer" {
if v, ok := schemas.SafeExtractFloat64Pointer(ep["style_strength"]); ok {
delete(ep, "style_strength")
req.StyleStrength = v
}
if v, ok := schemas.SafeExtractFloat64Pointer(ep["composition_fidelity"]); ok {
delete(ep, "composition_fidelity")
req.CompositionFidelity = v
}
if v, ok := schemas.SafeExtractFloat64Pointer(ep["change_strength"]); ok {
delete(ep, "change_strength")
req.ChangeStrength = v
}
}
req.ExtraParams = ep
}
}
// Validate required per-task fields
if taskType == "recolor" && (req.SelectPrompt == nil || *req.SelectPrompt == "") {
return nil, fmt.Errorf("select_prompt is required for stability ai recolor task")
}
if taskType == "search-replace" && (req.SearchPrompt == nil || *req.SearchPrompt == "") {
return nil, fmt.Errorf("search_prompt is required for stability ai search-replace task")
}
return req, nil
}
// ToBifrostImageGenerationResponse converts a Bedrock image generation response to a Bifrost image generation response
func ToBifrostImageGenerationResponse(response *BedrockImageGenerationResponse) *schemas.BifrostImageGenerationResponse {
if response == nil {
return nil
}
bifrostResponse := &schemas.BifrostImageGenerationResponse{}
if len(response.FinishReasons) > 0 || len(response.Seeds) > 0 {
bifrostResponse.ImageGenerationResponseParameters = &schemas.ImageGenerationResponseParameters{
FinishReasons: append([]*string(nil), response.FinishReasons...),
Seeds: append([]int(nil), response.Seeds...),
}
}
for index, image := range response.Images {
bifrostResponse.Data = append(bifrostResponse.Data, schemas.ImageData{
B64JSON: image,
Index: index,
})
}
return bifrostResponse
}