Files
bifrost/core/schemas/embedding.go
Beyhan Oğur 880f412e2c first commit
2026-04-26 21:52:23 +03:00

188 lines
5.0 KiB
Go

package schemas
import (
"fmt"
)
type BifrostEmbeddingRequest struct {
Provider ModelProvider `json:"provider"`
Model string `json:"model"`
Input *EmbeddingInput `json:"input,omitempty"`
Params *EmbeddingParameters `json:"params,omitempty"`
Fallbacks []Fallback `json:"fallbacks,omitempty"`
RawRequestBody []byte `json:"-"` // set bifrost-use-raw-request-body to true in ctx to use the raw request body. Bifrost will directly send this to the downstream provider.
}
func (r *BifrostEmbeddingRequest) GetRawRequestBody() []byte {
return r.RawRequestBody
}
type BifrostEmbeddingResponse struct {
Data []EmbeddingData `json:"data"` // Maps to "data" field in provider responses (e.g., OpenAI embedding format)
Model string `json:"model"`
Object string `json:"object"` // "list"
Usage *BifrostLLMUsage `json:"usage"`
ExtraFields BifrostResponseExtraFields `json:"extra_fields"`
}
// EmbeddingInput represents the input for an embedding request.
type EmbeddingInput struct {
Text *string
Texts []string
Embedding []int
Embeddings [][]int
}
func (e *EmbeddingInput) MarshalJSON() ([]byte, error) {
// enforce one-of
set := 0
if e.Text != nil {
set++
}
if e.Texts != nil {
set++
}
if e.Embedding != nil {
set++
}
if e.Embeddings != nil {
set++
}
if set == 0 {
return nil, fmt.Errorf("embedding input is empty")
}
if set > 1 {
return nil, fmt.Errorf("embedding input must set exactly one of: text, texts, embedding, embeddings")
}
if e.Text != nil {
return MarshalSorted(*e.Text)
}
if e.Texts != nil {
return MarshalSorted(e.Texts)
}
if e.Embedding != nil {
return MarshalSorted(e.Embedding)
}
if e.Embeddings != nil {
return MarshalSorted(e.Embeddings)
}
return nil, fmt.Errorf("invalid embedding input")
}
func (e *EmbeddingInput) UnmarshalJSON(data []byte) error {
e.Text = nil
e.Texts = nil
e.Embedding = nil
e.Embeddings = nil
// Try string
var s string
if err := Unmarshal(data, &s); err == nil {
e.Text = &s
return nil
}
// Try []string
var ss []string
if err := Unmarshal(data, &ss); err == nil {
e.Texts = ss
return nil
}
// Try []int
var i []int
if err := Unmarshal(data, &i); err == nil {
e.Embedding = i
return nil
}
// Try [][]int
var i2 [][]int
if err := Unmarshal(data, &i2); err == nil {
e.Embeddings = i2
return nil
}
return fmt.Errorf("unsupported embedding input shape")
}
type EmbeddingParameters struct {
EncodingFormat *string `json:"encoding_format,omitempty"` // Format for embedding output (e.g., "float", "base64")
Dimensions *int `json:"dimensions,omitempty"` // Number of dimensions for embedding output
// Dynamic parameters that can be provider-specific, they are directly
// added to the request as is.
ExtraParams map[string]interface{} `json:"-"`
}
type EmbeddingData struct {
Index int `json:"index"`
Object string `json:"object"` // "embedding"
Embedding EmbeddingStruct `json:"embedding"` // can be string, []float64, [][]float64, []int8, or []int32
}
type EmbeddingStruct struct {
// Embedding responses preserve provider precision in normalized API output.
EmbeddingStr *string
EmbeddingArray []float64
Embedding2DArray [][]float64
EmbeddingInt8Array []int8 // for int8 / binary formats
EmbeddingInt32Array []int32 // for uint8 / ubinary formats
}
func (be EmbeddingStruct) MarshalJSON() ([]byte, error) {
if be.EmbeddingStr != nil {
return MarshalSorted(be.EmbeddingStr)
}
if be.EmbeddingArray != nil {
return MarshalSorted(be.EmbeddingArray)
}
if be.Embedding2DArray != nil {
return MarshalSorted(be.Embedding2DArray)
}
if be.EmbeddingInt8Array != nil {
return Marshal(be.EmbeddingInt8Array)
}
if be.EmbeddingInt32Array != nil {
return Marshal(be.EmbeddingInt32Array)
}
return nil, fmt.Errorf("no embedding found")
}
func (be *EmbeddingStruct) UnmarshalJSON(data []byte) error {
// First, try to unmarshal as a direct string
var stringContent string
if err := Unmarshal(data, &stringContent); err == nil {
be.EmbeddingStr = &stringContent
return nil
}
// Try to unmarshal as a direct array of float64
var arrayContent []float64
if err := Unmarshal(data, &arrayContent); err == nil {
be.EmbeddingArray = arrayContent
return nil
}
// Try to unmarshal as a direct 2D array of float64
var arrayContent2D [][]float64
if err := Unmarshal(data, &arrayContent2D); err == nil {
be.Embedding2DArray = arrayContent2D
return nil
}
// Try to unmarshal as a direct array of int8
var int8Content []int8
if err := Unmarshal(data, &int8Content); err == nil {
be.EmbeddingInt8Array = int8Content
return nil
}
// Try to unmarshal as a direct array of int32
var int32Content []int32
if err := Unmarshal(data, &int32Content); err == nil {
be.EmbeddingInt32Array = int32Content
return nil
}
return fmt.Errorf("embedding field is neither a string, []float64, [][]float64, []int8, nor []int32")
}