188 lines
5.0 KiB
Go
188 lines
5.0 KiB
Go
package schemas
|
|
|
|
import (
|
|
"fmt"
|
|
)
|
|
|
|
type BifrostEmbeddingRequest struct {
|
|
Provider ModelProvider `json:"provider"`
|
|
Model string `json:"model"`
|
|
Input *EmbeddingInput `json:"input,omitempty"`
|
|
Params *EmbeddingParameters `json:"params,omitempty"`
|
|
Fallbacks []Fallback `json:"fallbacks,omitempty"`
|
|
RawRequestBody []byte `json:"-"` // set bifrost-use-raw-request-body to true in ctx to use the raw request body. Bifrost will directly send this to the downstream provider.
|
|
}
|
|
|
|
func (r *BifrostEmbeddingRequest) GetRawRequestBody() []byte {
|
|
return r.RawRequestBody
|
|
}
|
|
|
|
type BifrostEmbeddingResponse struct {
|
|
Data []EmbeddingData `json:"data"` // Maps to "data" field in provider responses (e.g., OpenAI embedding format)
|
|
Model string `json:"model"`
|
|
Object string `json:"object"` // "list"
|
|
Usage *BifrostLLMUsage `json:"usage"`
|
|
ExtraFields BifrostResponseExtraFields `json:"extra_fields"`
|
|
}
|
|
|
|
// EmbeddingInput is the input of an embedding request. It is a one-of:
// exactly one field is expected to be set, and the set field is encoded as a
// bare JSON value rather than as an object:
//
//   - Text:       a single string             -> "..."
//   - Texts:      a list of strings           -> ["...", ...]
//   - Embedding:  a single array of integers  -> [1, 2, ...]
//   - Embeddings: a list of integer arrays    -> [[1, 2], ...]
//
// The fields carry no json tags because the type's MarshalJSON and
// UnmarshalJSON methods perform the encoding.
type EmbeddingInput struct {
	Text       *string
	Texts      []string
	Embedding  []int
	Embeddings [][]int
}
|
|
|
|
func (e *EmbeddingInput) MarshalJSON() ([]byte, error) {
|
|
// enforce one-of
|
|
set := 0
|
|
if e.Text != nil {
|
|
set++
|
|
}
|
|
if e.Texts != nil {
|
|
set++
|
|
}
|
|
if e.Embedding != nil {
|
|
set++
|
|
}
|
|
if e.Embeddings != nil {
|
|
set++
|
|
}
|
|
if set == 0 {
|
|
return nil, fmt.Errorf("embedding input is empty")
|
|
}
|
|
if set > 1 {
|
|
return nil, fmt.Errorf("embedding input must set exactly one of: text, texts, embedding, embeddings")
|
|
}
|
|
|
|
if e.Text != nil {
|
|
return MarshalSorted(*e.Text)
|
|
}
|
|
if e.Texts != nil {
|
|
return MarshalSorted(e.Texts)
|
|
}
|
|
if e.Embedding != nil {
|
|
return MarshalSorted(e.Embedding)
|
|
}
|
|
if e.Embeddings != nil {
|
|
return MarshalSorted(e.Embeddings)
|
|
}
|
|
|
|
return nil, fmt.Errorf("invalid embedding input")
|
|
}
|
|
|
|
func (e *EmbeddingInput) UnmarshalJSON(data []byte) error {
|
|
e.Text = nil
|
|
e.Texts = nil
|
|
e.Embedding = nil
|
|
e.Embeddings = nil
|
|
// Try string
|
|
var s string
|
|
if err := Unmarshal(data, &s); err == nil {
|
|
e.Text = &s
|
|
return nil
|
|
}
|
|
// Try []string
|
|
var ss []string
|
|
if err := Unmarshal(data, &ss); err == nil {
|
|
e.Texts = ss
|
|
return nil
|
|
}
|
|
// Try []int
|
|
var i []int
|
|
if err := Unmarshal(data, &i); err == nil {
|
|
e.Embedding = i
|
|
return nil
|
|
}
|
|
// Try [][]int
|
|
var i2 [][]int
|
|
if err := Unmarshal(data, &i2); err == nil {
|
|
e.Embeddings = i2
|
|
return nil
|
|
}
|
|
|
|
return fmt.Errorf("unsupported embedding input shape")
|
|
}
|
|
|
|
// EmbeddingParameters holds optional tuning knobs for an embedding request.
// Nil pointer fields are omitted from the JSON encoding.
type EmbeddingParameters struct {
	// EncodingFormat selects the output format of the embeddings,
	// e.g. "float" or "base64".
	EncodingFormat *string `json:"encoding_format,omitempty"`

	// Dimensions is the requested number of dimensions per embedding.
	Dimensions *int `json:"dimensions,omitempty"`

	// ExtraParams carries dynamic, possibly provider-specific parameters.
	// They are excluded from this struct's JSON encoding (json:"-") and are
	// instead added to the outgoing request as-is.
	ExtraParams map[string]interface{} `json:"-"`
}
|
|
|
|
type EmbeddingData struct {
|
|
Index int `json:"index"`
|
|
Object string `json:"object"` // "embedding"
|
|
Embedding EmbeddingStruct `json:"embedding"` // can be string, []float64, [][]float64, []int8, or []int32
|
|
}
|
|
|
|
// EmbeddingStruct holds one embedding in whichever representation the
// provider returned; normally only one field is populated. Embedding
// responses preserve provider precision in normalized API output.
type EmbeddingStruct struct {
	// EmbeddingStr is a string-encoded embedding.
	EmbeddingStr *string
	// EmbeddingArray is a flat vector of floats.
	EmbeddingArray []float64
	// Embedding2DArray is a nested (2-D) array of floats.
	Embedding2DArray [][]float64
	// EmbeddingInt8Array is used for int8 / binary formats.
	EmbeddingInt8Array []int8
	// EmbeddingInt32Array is used for uint8 / ubinary formats.
	EmbeddingInt32Array []int32
}
|
|
|
|
func (be EmbeddingStruct) MarshalJSON() ([]byte, error) {
|
|
if be.EmbeddingStr != nil {
|
|
return MarshalSorted(be.EmbeddingStr)
|
|
}
|
|
if be.EmbeddingArray != nil {
|
|
return MarshalSorted(be.EmbeddingArray)
|
|
}
|
|
if be.Embedding2DArray != nil {
|
|
return MarshalSorted(be.Embedding2DArray)
|
|
}
|
|
if be.EmbeddingInt8Array != nil {
|
|
return Marshal(be.EmbeddingInt8Array)
|
|
}
|
|
if be.EmbeddingInt32Array != nil {
|
|
return Marshal(be.EmbeddingInt32Array)
|
|
}
|
|
return nil, fmt.Errorf("no embedding found")
|
|
}
|
|
|
|
func (be *EmbeddingStruct) UnmarshalJSON(data []byte) error {
|
|
// First, try to unmarshal as a direct string
|
|
var stringContent string
|
|
if err := Unmarshal(data, &stringContent); err == nil {
|
|
be.EmbeddingStr = &stringContent
|
|
return nil
|
|
}
|
|
|
|
// Try to unmarshal as a direct array of float64
|
|
var arrayContent []float64
|
|
if err := Unmarshal(data, &arrayContent); err == nil {
|
|
be.EmbeddingArray = arrayContent
|
|
return nil
|
|
}
|
|
|
|
// Try to unmarshal as a direct 2D array of float64
|
|
var arrayContent2D [][]float64
|
|
if err := Unmarshal(data, &arrayContent2D); err == nil {
|
|
be.Embedding2DArray = arrayContent2D
|
|
return nil
|
|
}
|
|
|
|
// Try to unmarshal as a direct array of int8
|
|
var int8Content []int8
|
|
if err := Unmarshal(data, &int8Content); err == nil {
|
|
be.EmbeddingInt8Array = int8Content
|
|
return nil
|
|
}
|
|
|
|
// Try to unmarshal as a direct array of int32
|
|
var int32Content []int32
|
|
if err := Unmarshal(data, &int32Content); err == nil {
|
|
be.EmbeddingInt32Array = int32Content
|
|
return nil
|
|
}
|
|
|
|
return fmt.Errorf("embedding field is neither a string, []float64, [][]float64, []int8, nor []int32")
|
|
}
|