Files
bifrost/core/providers/bedrock/embedding.go
Beyhan Oğur 880f412e2c first commit
2026-04-26 21:52:23 +03:00

272 lines
8.0 KiB
Go

package bedrock
import (
"encoding/json"
"fmt"
"strings"
"github.com/maximhq/bifrost/core/schemas"
)
// ToBedrockTitanEmbeddingRequest converts a Bifrost embedding request to Bedrock Titan format.
//
// The Titan request carries a single InputText string. When Input.Text is set it is used
// verbatim; otherwise every entry of Input.Texts is concatenated in order, each followed
// by " \n".
//
// Params.Dimensions is copied as-is. The "normalize" extra param is promoted to the
// Normalize field when it is a bool and is never forwarded through ExtraParams; every
// other extra param is forwarded untouched.
//
// An error is returned when the request is nil, when its Input is nil, or when neither
// Input.Text nor Input.Texts carries any text.
func ToBedrockTitanEmbeddingRequest(bifrostReq *schemas.BifrostEmbeddingRequest) (*BedrockTitanEmbeddingRequest, error) {
	if bifrostReq == nil {
		return nil, fmt.Errorf("bifrost embedding request is nil")
	}
	// Input must be checked before its fields are read; a nil Input would otherwise panic.
	if bifrostReq.Input == nil || (bifrostReq.Input.Text == nil && len(bifrostReq.Input.Texts) == 0) {
		return nil, fmt.Errorf("no input text provided for embedding")
	}

	titanReq := &BedrockTitanEmbeddingRequest{}

	// Set input text: a single text wins over the list form.
	if bifrostReq.Input.Text != nil {
		titanReq.InputText = *bifrostReq.Input.Text
	} else {
		// The guard above guarantees Texts is non-empty on this path.
		var sb strings.Builder
		for _, text := range bifrostReq.Input.Texts {
			sb.WriteString(text)
			sb.WriteString(" \n")
		}
		titanReq.InputText = sb.String()
	}

	if bifrostReq.Params == nil {
		return titanReq, nil
	}

	titanReq.Dimensions = bifrostReq.Params.Dimensions
	if normalize, ok := bifrostReq.Params.ExtraParams["normalize"]; ok {
		if b, ok := normalize.(bool); ok {
			titanReq.Normalize = &b
		}
	}

	// Forward remaining extra params (excluding normalize, which is a first-class field).
	// The map is copied so the caller's ExtraParams is never aliased or mutated.
	if len(bifrostReq.Params.ExtraParams) > 0 {
		extra := make(map[string]interface{}, len(bifrostReq.Params.ExtraParams))
		for k, v := range bifrostReq.Params.ExtraParams {
			if k != "normalize" {
				extra[k] = v
			}
		}
		if len(extra) > 0 {
			titanReq.ExtraParams = extra
		}
	}

	return titanReq, nil
}
// ToBifrostEmbeddingResponse maps a Bedrock Titan embedding response onto the Bifrost
// embedding response shape.
//
// The result is a "list" object holding exactly one "embedding" entry at index 0 whose
// vector is the response's Embedding. Both PromptTokens and TotalTokens are taken from
// InputTextTokenCount. A nil receiver yields a nil result.
func (response *BedrockTitanEmbeddingResponse) ToBifrostEmbeddingResponse() *schemas.BifrostEmbeddingResponse {
	if response == nil {
		return nil
	}

	tokens := response.InputTextTokenCount
	entry := schemas.EmbeddingData{
		Object:    "embedding",
		Index:     0,
		Embedding: schemas.EmbeddingStruct{EmbeddingArray: response.Embedding},
	}

	return &schemas.BifrostEmbeddingResponse{
		Object: "list",
		Data:   []schemas.EmbeddingData{entry},
		Usage:  &schemas.BifrostLLMUsage{PromptTokens: tokens, TotalTokens: tokens},
	}
}
// ToBedrockCohereEmbeddingRequest builds the Bedrock Cohere embedding request body from a
// Bifrost embedding request. Unlike the direct Cohere API, Bedrock does not accept a
// "model" field in the request body.
//
// Input.Text, when set, becomes a one-element Texts slice; otherwise Input.Texts is used.
// Params.Dimensions is copied to OutputDimension. The extra params "input_type",
// "truncate", "embedding_types", "images", "inputs" and "max_tokens" are promoted to their
// dedicated request fields when they hold the expected Go type; any extra param that is
// unknown, or known but of an unexpected type, is forwarded through ExtraParams instead.
//
// An error is returned when the request is nil or carries no input.
func ToBedrockCohereEmbeddingRequest(bifrostReq *schemas.BifrostEmbeddingRequest) (*BedrockCohereEmbeddingRequest, error) {
	if bifrostReq == nil {
		return nil, fmt.Errorf("bifrost embedding request is nil")
	}
	input := bifrostReq.Input
	if input == nil || (input.Text == nil && len(input.Texts) == 0) {
		return nil, fmt.Errorf("no input provided for embedding")
	}

	out := &BedrockCohereEmbeddingRequest{}

	// Map texts: the guard above guarantees Texts is non-empty whenever Text is nil.
	if input.Text != nil {
		out.Texts = []string{*input.Text}
	} else {
		out.Texts = input.Texts
	}

	params := bifrostReq.Params
	if params == nil {
		return out, nil
	}

	// Walk the extra params once. Recognised keys with the expected type are lifted into
	// typed fields and skipped; everything else lands in a fresh passthrough map so the
	// caller's map is never modified.
	passthrough := make(map[string]interface{}, len(params.ExtraParams))
	for key, val := range params.ExtraParams {
		switch key {
		case "input_type":
			if s, ok := val.(string); ok {
				out.InputType = s
				continue
			}
		case "truncate":
			if s, ok := val.(string); ok {
				out.Truncate = &s
				continue
			}
		case "embedding_types":
			if list, ok := val.([]string); ok {
				out.EmbeddingTypes = list
				continue
			}
		case "images":
			if list, ok := val.([]string); ok {
				out.Images = list
				continue
			}
		case "inputs":
			if list, ok := val.([]BedrockCohereEmbeddingInput); ok {
				out.Inputs = list
				continue
			}
		case "max_tokens":
			switch n := val.(type) {
			case int:
				out.MaxTokens = &n
				continue
			case float64:
				// float64 values are truncated toward zero by the int conversion.
				truncated := int(n)
				out.MaxTokens = &truncated
				continue
			}
		}
		passthrough[key] = val
	}

	// A nil Dimensions leaves OutputDimension at its nil zero value.
	out.OutputDimension = params.Dimensions

	if len(passthrough) > 0 {
		out.ExtraParams = passthrough
	}
	return out, nil
}
// DetermineEmbeddingModelType reports which embedding family a model name belongs to.
//
// It returns "titan" when the name contains "amazon.titan-embed-text" and "cohere" when it
// contains "cohere.embed"; the Titan marker is checked first. Any other name yields an
// empty string and an "unsupported embedding model" error.
func DetermineEmbeddingModelType(model string) (string, error) {
	// Ordered so that earlier markers take precedence when several match.
	families := [...]struct {
		marker string
		kind   string
	}{
		{marker: "amazon.titan-embed-text", kind: "titan"},
		{marker: "cohere.embed", kind: "cohere"},
	}

	for _, family := range families {
		if strings.Contains(model, family.marker) {
			return family.kind, nil
		}
	}
	return "", fmt.Errorf("unsupported embedding model: %s", model)
}
// ToBifrostEmbeddingResponse converts a BedrockCohereEmbeddingResponse to Bifrost format.
//
// The Embeddings payload is decoded according to ResponseType:
//   - "embeddings_by_type": an object keyed by encoding ("float", "base64", "int8",
//     "uint8", "binary", "ubinary"). Entries are appended group by group in the order
//     float, base64, int8, binary, uint8, ubinary, and Index restarts at 0 within each group.
//   - anything else (including the default "embeddings_floats"): a raw array of float
//     vectors, [[...], [...]].
//
// float32 vectors are widened to float64. An error is returned for a nil receiver or when
// the Embeddings payload cannot be decoded in the expected form.
func (r *BedrockCohereEmbeddingResponse) ToBifrostEmbeddingResponse() (*schemas.BifrostEmbeddingResponse, error) {
	if r == nil {
		return nil, fmt.Errorf("nil Bedrock Cohere embedding response")
	}

	result := &schemas.BifrostEmbeddingResponse{Object: "list"}

	// push appends a single embedding entry to the result.
	push := func(idx int, emb schemas.EmbeddingStruct) {
		result.Data = append(result.Data, schemas.EmbeddingData{
			Object:    "embedding",
			Index:     idx,
			Embedding: emb,
		})
	}

	// pushFloats widens every float32 vector to float64 and appends it.
	pushFloats := func(vectors [][]float32) {
		for idx, vec := range vectors {
			widened := make([]float64, len(vec))
			for pos := range vec {
				widened[pos] = float64(vec[pos])
			}
			push(idx, schemas.EmbeddingStruct{EmbeddingArray: widened})
		}
	}

	// Default / "embeddings_floats": raw array form [[...], [...]].
	if r.ResponseType != "embeddings_by_type" {
		var raw [][]float32
		if err := json.Unmarshal(r.Embeddings, &raw); err != nil {
			return nil, fmt.Errorf("error parsing embeddings_floats: %w", err)
		}
		pushFloats(raw)
		return result, nil
	}

	// Object form: {"float": [[...]], "int8": [[...]], "uint8": [[...]], "binary": [[...]], "ubinary": [[...]], "base64": [...]}
	var byType struct {
		Float  [][]float32 `json:"float"`
		Base64 []string    `json:"base64"`
		Int8   [][]int8    `json:"int8"`
		// int32 avoids []byte→base64 JSON issue
		Uint8  [][]int32 `json:"uint8"`
		Binary [][]int8  `json:"binary"`
		// int32 avoids []byte→base64 JSON issue
		Ubinary [][]int32 `json:"ubinary"`
	}
	if err := json.Unmarshal(r.Embeddings, &byType); err != nil {
		return nil, fmt.Errorf("error parsing embeddings_by_type: %w", err)
	}

	// Ranging over an absent (nil) group is a no-op, so no explicit nil checks are needed.
	pushFloats(byType.Float)
	for idx, encoded := range byType.Base64 {
		// Copy so each entry points at its own string rather than the loop variable.
		s := encoded
		push(idx, schemas.EmbeddingStruct{EmbeddingStr: &s})
	}
	for idx, vec := range byType.Int8 {
		push(idx, schemas.EmbeddingStruct{EmbeddingInt8Array: vec})
	}
	for idx, vec := range byType.Binary {
		push(idx, schemas.EmbeddingStruct{EmbeddingInt8Array: vec})
	}
	for idx, vec := range byType.Uint8 {
		push(idx, schemas.EmbeddingStruct{EmbeddingInt32Array: vec})
	}
	for idx, vec := range byType.Ubinary {
		push(idx, schemas.EmbeddingStruct{EmbeddingInt32Array: vec})
	}

	return result, nil
}