first commit
This commit is contained in:
247
core/providers/gemini/embedding.go
Normal file
247
core/providers/gemini/embedding.go
Normal file
@@ -0,0 +1,247 @@
|
||||
package gemini
|
||||
|
||||
import (
|
||||
"github.com/maximhq/bifrost/core/providers/utils"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// ToGeminiEmbeddingRequest converts a BifrostRequest with embedding input to Gemini's batch embedding request format
|
||||
// GeminiGenerationRequest contains requests array for batch embed content endpoint
|
||||
func ToGeminiEmbeddingRequest(bifrostReq *schemas.BifrostEmbeddingRequest) *GeminiBatchEmbeddingRequest {
|
||||
if bifrostReq == nil || bifrostReq.Input == nil || (bifrostReq.Input.Text == nil && bifrostReq.Input.Texts == nil) {
|
||||
return nil
|
||||
}
|
||||
|
||||
embeddingInput := bifrostReq.Input
|
||||
|
||||
// Collect all texts to embed
|
||||
var texts []string
|
||||
if embeddingInput.Text != nil {
|
||||
texts = append(texts, *embeddingInput.Text)
|
||||
}
|
||||
if len(embeddingInput.Texts) > 0 {
|
||||
texts = append(texts, embeddingInput.Texts...)
|
||||
}
|
||||
|
||||
if len(texts) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Create batch embedding request with one request per text
|
||||
batchRequest := &GeminiBatchEmbeddingRequest{
|
||||
Requests: make([]GeminiEmbeddingRequest, len(texts)),
|
||||
}
|
||||
if bifrostReq.Params != nil {
|
||||
batchRequest.ExtraParams = bifrostReq.Params.ExtraParams
|
||||
}
|
||||
|
||||
// Create individual embedding requests for each text
|
||||
for i, text := range texts {
|
||||
embeddingReq := GeminiEmbeddingRequest{
|
||||
Model: "models/" + bifrostReq.Model,
|
||||
Content: &Content{
|
||||
Parts: []*Part{
|
||||
{
|
||||
Text: text,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Add parameters if available
|
||||
if bifrostReq.Params != nil {
|
||||
if bifrostReq.Params.Dimensions != nil {
|
||||
embeddingReq.OutputDimensionality = bifrostReq.Params.Dimensions
|
||||
}
|
||||
|
||||
// Handle extra parameters
|
||||
if bifrostReq.Params.ExtraParams != nil {
|
||||
if taskType, ok := schemas.SafeExtractStringPointer(bifrostReq.Params.ExtraParams["taskType"]); ok {
|
||||
delete(batchRequest.ExtraParams, "taskType")
|
||||
embeddingReq.TaskType = taskType
|
||||
}
|
||||
if title, ok := schemas.SafeExtractStringPointer(bifrostReq.Params.ExtraParams["title"]); ok {
|
||||
delete(batchRequest.ExtraParams, "title")
|
||||
embeddingReq.Title = title
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
batchRequest.Requests[i] = embeddingReq
|
||||
}
|
||||
|
||||
return batchRequest
|
||||
}
|
||||
|
||||
// ToGeminiEmbeddingResponse converts a BifrostResponse with embedding data to Gemini's embedding response format
|
||||
func ToGeminiEmbeddingResponse(bifrostResp *schemas.BifrostEmbeddingResponse) *GeminiEmbeddingResponse {
|
||||
if bifrostResp == nil || len(bifrostResp.Data) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
geminiResp := &GeminiEmbeddingResponse{
|
||||
Embeddings: make([]GeminiEmbedding, len(bifrostResp.Data)),
|
||||
}
|
||||
|
||||
// Convert each embedding from Bifrost format to Gemini format
|
||||
for i, embedding := range bifrostResp.Data {
|
||||
var values []float64
|
||||
|
||||
// Extract embedding values from BifrostEmbeddingResponse
|
||||
if embedding.Embedding.EmbeddingArray != nil {
|
||||
values = append([]float64(nil), embedding.Embedding.EmbeddingArray...)
|
||||
} else if len(embedding.Embedding.Embedding2DArray) > 0 {
|
||||
// If it's a 2D array, take the first array
|
||||
values = append([]float64(nil), embedding.Embedding.Embedding2DArray[0]...)
|
||||
}
|
||||
|
||||
geminiEmbedding := GeminiEmbedding{
|
||||
Values: values,
|
||||
}
|
||||
|
||||
// Add statistics if available (token count from usage metadata)
|
||||
if bifrostResp.Usage != nil {
|
||||
geminiEmbedding.Statistics = &ContentEmbeddingStatistics{
|
||||
TokenCount: int32(bifrostResp.Usage.PromptTokens),
|
||||
}
|
||||
}
|
||||
|
||||
geminiResp.Embeddings[i] = geminiEmbedding
|
||||
}
|
||||
|
||||
// Set metadata if available (for Vertex API compatibility)
|
||||
if bifrostResp.Usage != nil {
|
||||
geminiResp.Metadata = &EmbedContentMetadata{
|
||||
BillableCharacterCount: int32(bifrostResp.Usage.PromptTokens),
|
||||
}
|
||||
}
|
||||
|
||||
return geminiResp
|
||||
}
|
||||
|
||||
// ToBifrostEmbeddingResponse converts a Gemini embedding response to BifrostEmbeddingResponse format
|
||||
func ToBifrostEmbeddingResponse(geminiResp *GeminiEmbeddingResponse, model string) *schemas.BifrostEmbeddingResponse {
|
||||
if geminiResp == nil || len(geminiResp.Embeddings) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
bifrostResp := &schemas.BifrostEmbeddingResponse{
|
||||
Data: make([]schemas.EmbeddingData, len(geminiResp.Embeddings)),
|
||||
Model: model,
|
||||
Object: "list",
|
||||
}
|
||||
|
||||
// Convert each embedding from Gemini format to Bifrost format
|
||||
for i, geminiEmbedding := range geminiResp.Embeddings {
|
||||
embeddingData := schemas.EmbeddingData{
|
||||
Index: i,
|
||||
Object: "embedding",
|
||||
Embedding: schemas.EmbeddingStruct{
|
||||
EmbeddingArray: geminiEmbedding.Values,
|
||||
},
|
||||
}
|
||||
|
||||
bifrostResp.Data[i] = embeddingData
|
||||
}
|
||||
|
||||
// Convert usage metadata if available
|
||||
if geminiResp.Metadata != nil || (len(geminiResp.Embeddings) > 0 && geminiResp.Embeddings[0].Statistics != nil) {
|
||||
bifrostResp.Usage = &schemas.BifrostLLMUsage{}
|
||||
|
||||
// Use statistics from the first embedding if available
|
||||
if geminiResp.Embeddings[0].Statistics != nil {
|
||||
bifrostResp.Usage.PromptTokens = int(geminiResp.Embeddings[0].Statistics.TokenCount)
|
||||
} else if geminiResp.Metadata != nil {
|
||||
// Fall back to metadata if statistics are not available
|
||||
bifrostResp.Usage.PromptTokens = int(geminiResp.Metadata.BillableCharacterCount)
|
||||
}
|
||||
|
||||
// Set total tokens same as prompt tokens for embeddings
|
||||
bifrostResp.Usage.TotalTokens = bifrostResp.Usage.PromptTokens
|
||||
}
|
||||
|
||||
return bifrostResp
|
||||
}
|
||||
|
||||
// ToBifrostEmbeddingRequest converts a GeminiGenerationRequest to BifrostEmbeddingRequest format
|
||||
func (request *GeminiGenerationRequest) ToBifrostEmbeddingRequest(ctx *schemas.BifrostContext) *schemas.BifrostEmbeddingRequest {
|
||||
if request == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
provider, model := schemas.ParseModelString(request.Model, utils.CheckAndSetDefaultProvider(ctx, schemas.Gemini))
|
||||
|
||||
// Create the embedding request
|
||||
bifrostReq := &schemas.BifrostEmbeddingRequest{
|
||||
Provider: provider,
|
||||
Model: model,
|
||||
Fallbacks: schemas.ParseFallbacks(request.Fallbacks),
|
||||
}
|
||||
|
||||
// SDK batch embedding request contains multiple embedding requests with same parameters but different text fields.
|
||||
if len(request.Requests) > 0 {
|
||||
var texts []string
|
||||
for _, req := range request.Requests {
|
||||
if req.Content != nil && len(req.Content.Parts) > 0 {
|
||||
for _, part := range req.Content.Parts {
|
||||
if part != nil && part.Text != "" {
|
||||
texts = append(texts, part.Text)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(texts) > 0 {
|
||||
bifrostReq.Input = &schemas.EmbeddingInput{}
|
||||
if len(texts) == 1 {
|
||||
bifrostReq.Input.Text = &texts[0]
|
||||
} else {
|
||||
bifrostReq.Input.Texts = texts
|
||||
}
|
||||
}
|
||||
|
||||
embeddingRequest := request.Requests[0]
|
||||
|
||||
// Convert parameters
|
||||
if embeddingRequest.OutputDimensionality != nil || embeddingRequest.TaskType != nil || embeddingRequest.Title != nil {
|
||||
bifrostReq.Params = &schemas.EmbeddingParameters{}
|
||||
|
||||
if embeddingRequest.OutputDimensionality != nil {
|
||||
bifrostReq.Params.Dimensions = embeddingRequest.OutputDimensionality
|
||||
}
|
||||
|
||||
// Handle extra parameters
|
||||
if embeddingRequest.TaskType != nil || embeddingRequest.Title != nil {
|
||||
bifrostReq.Params.ExtraParams = make(map[string]interface{})
|
||||
if embeddingRequest.TaskType != nil {
|
||||
bifrostReq.Params.ExtraParams["taskType"] = embeddingRequest.TaskType
|
||||
}
|
||||
if embeddingRequest.Title != nil {
|
||||
bifrostReq.Params.ExtraParams["title"] = embeddingRequest.Title
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Generation-style requests (e.g., non-Imagen :predict) carry text in contents[].parts[].
|
||||
// If no SDK requests[] were provided, derive embedding input from contents.
|
||||
if bifrostReq.Input == nil {
|
||||
var texts []string
|
||||
for _, content := range request.Contents {
|
||||
for _, part := range content.Parts {
|
||||
if part != nil && part.Text != "" {
|
||||
texts = append(texts, part.Text)
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(texts) > 0 {
|
||||
bifrostReq.Input = &schemas.EmbeddingInput{}
|
||||
if len(texts) == 1 {
|
||||
bifrostReq.Input.Text = &texts[0]
|
||||
} else {
|
||||
bifrostReq.Input.Texts = texts
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return bifrostReq
|
||||
}
|
||||
Reference in New Issue
Block a user