248 lines
7.5 KiB
Go
248 lines
7.5 KiB
Go
package gemini
|
|
|
|
import (
|
|
"github.com/maximhq/bifrost/core/providers/utils"
|
|
"github.com/maximhq/bifrost/core/schemas"
|
|
)
|
|
|
|
// ToGeminiEmbeddingRequest converts a BifrostRequest with embedding input to Gemini's batch embedding request format
|
|
// GeminiGenerationRequest contains requests array for batch embed content endpoint
|
|
func ToGeminiEmbeddingRequest(bifrostReq *schemas.BifrostEmbeddingRequest) *GeminiBatchEmbeddingRequest {
|
|
if bifrostReq == nil || bifrostReq.Input == nil || (bifrostReq.Input.Text == nil && bifrostReq.Input.Texts == nil) {
|
|
return nil
|
|
}
|
|
|
|
embeddingInput := bifrostReq.Input
|
|
|
|
// Collect all texts to embed
|
|
var texts []string
|
|
if embeddingInput.Text != nil {
|
|
texts = append(texts, *embeddingInput.Text)
|
|
}
|
|
if len(embeddingInput.Texts) > 0 {
|
|
texts = append(texts, embeddingInput.Texts...)
|
|
}
|
|
|
|
if len(texts) == 0 {
|
|
return nil
|
|
}
|
|
|
|
// Create batch embedding request with one request per text
|
|
batchRequest := &GeminiBatchEmbeddingRequest{
|
|
Requests: make([]GeminiEmbeddingRequest, len(texts)),
|
|
}
|
|
if bifrostReq.Params != nil {
|
|
batchRequest.ExtraParams = bifrostReq.Params.ExtraParams
|
|
}
|
|
|
|
// Create individual embedding requests for each text
|
|
for i, text := range texts {
|
|
embeddingReq := GeminiEmbeddingRequest{
|
|
Model: "models/" + bifrostReq.Model,
|
|
Content: &Content{
|
|
Parts: []*Part{
|
|
{
|
|
Text: text,
|
|
},
|
|
},
|
|
},
|
|
}
|
|
|
|
// Add parameters if available
|
|
if bifrostReq.Params != nil {
|
|
if bifrostReq.Params.Dimensions != nil {
|
|
embeddingReq.OutputDimensionality = bifrostReq.Params.Dimensions
|
|
}
|
|
|
|
// Handle extra parameters
|
|
if bifrostReq.Params.ExtraParams != nil {
|
|
if taskType, ok := schemas.SafeExtractStringPointer(bifrostReq.Params.ExtraParams["taskType"]); ok {
|
|
delete(batchRequest.ExtraParams, "taskType")
|
|
embeddingReq.TaskType = taskType
|
|
}
|
|
if title, ok := schemas.SafeExtractStringPointer(bifrostReq.Params.ExtraParams["title"]); ok {
|
|
delete(batchRequest.ExtraParams, "title")
|
|
embeddingReq.Title = title
|
|
}
|
|
}
|
|
}
|
|
|
|
batchRequest.Requests[i] = embeddingReq
|
|
}
|
|
|
|
return batchRequest
|
|
}
|
|
|
|
// ToGeminiEmbeddingResponse converts a BifrostResponse with embedding data to Gemini's embedding response format
|
|
func ToGeminiEmbeddingResponse(bifrostResp *schemas.BifrostEmbeddingResponse) *GeminiEmbeddingResponse {
|
|
if bifrostResp == nil || len(bifrostResp.Data) == 0 {
|
|
return nil
|
|
}
|
|
|
|
geminiResp := &GeminiEmbeddingResponse{
|
|
Embeddings: make([]GeminiEmbedding, len(bifrostResp.Data)),
|
|
}
|
|
|
|
// Convert each embedding from Bifrost format to Gemini format
|
|
for i, embedding := range bifrostResp.Data {
|
|
var values []float64
|
|
|
|
// Extract embedding values from BifrostEmbeddingResponse
|
|
if embedding.Embedding.EmbeddingArray != nil {
|
|
values = append([]float64(nil), embedding.Embedding.EmbeddingArray...)
|
|
} else if len(embedding.Embedding.Embedding2DArray) > 0 {
|
|
// If it's a 2D array, take the first array
|
|
values = append([]float64(nil), embedding.Embedding.Embedding2DArray[0]...)
|
|
}
|
|
|
|
geminiEmbedding := GeminiEmbedding{
|
|
Values: values,
|
|
}
|
|
|
|
// Add statistics if available (token count from usage metadata)
|
|
if bifrostResp.Usage != nil {
|
|
geminiEmbedding.Statistics = &ContentEmbeddingStatistics{
|
|
TokenCount: int32(bifrostResp.Usage.PromptTokens),
|
|
}
|
|
}
|
|
|
|
geminiResp.Embeddings[i] = geminiEmbedding
|
|
}
|
|
|
|
// Set metadata if available (for Vertex API compatibility)
|
|
if bifrostResp.Usage != nil {
|
|
geminiResp.Metadata = &EmbedContentMetadata{
|
|
BillableCharacterCount: int32(bifrostResp.Usage.PromptTokens),
|
|
}
|
|
}
|
|
|
|
return geminiResp
|
|
}
|
|
|
|
// ToBifrostEmbeddingResponse converts a Gemini embedding response to BifrostEmbeddingResponse format
|
|
func ToBifrostEmbeddingResponse(geminiResp *GeminiEmbeddingResponse, model string) *schemas.BifrostEmbeddingResponse {
|
|
if geminiResp == nil || len(geminiResp.Embeddings) == 0 {
|
|
return nil
|
|
}
|
|
|
|
bifrostResp := &schemas.BifrostEmbeddingResponse{
|
|
Data: make([]schemas.EmbeddingData, len(geminiResp.Embeddings)),
|
|
Model: model,
|
|
Object: "list",
|
|
}
|
|
|
|
// Convert each embedding from Gemini format to Bifrost format
|
|
for i, geminiEmbedding := range geminiResp.Embeddings {
|
|
embeddingData := schemas.EmbeddingData{
|
|
Index: i,
|
|
Object: "embedding",
|
|
Embedding: schemas.EmbeddingStruct{
|
|
EmbeddingArray: geminiEmbedding.Values,
|
|
},
|
|
}
|
|
|
|
bifrostResp.Data[i] = embeddingData
|
|
}
|
|
|
|
// Convert usage metadata if available
|
|
if geminiResp.Metadata != nil || (len(geminiResp.Embeddings) > 0 && geminiResp.Embeddings[0].Statistics != nil) {
|
|
bifrostResp.Usage = &schemas.BifrostLLMUsage{}
|
|
|
|
// Use statistics from the first embedding if available
|
|
if geminiResp.Embeddings[0].Statistics != nil {
|
|
bifrostResp.Usage.PromptTokens = int(geminiResp.Embeddings[0].Statistics.TokenCount)
|
|
} else if geminiResp.Metadata != nil {
|
|
// Fall back to metadata if statistics are not available
|
|
bifrostResp.Usage.PromptTokens = int(geminiResp.Metadata.BillableCharacterCount)
|
|
}
|
|
|
|
// Set total tokens same as prompt tokens for embeddings
|
|
bifrostResp.Usage.TotalTokens = bifrostResp.Usage.PromptTokens
|
|
}
|
|
|
|
return bifrostResp
|
|
}
|
|
|
|
// ToBifrostEmbeddingRequest converts a GeminiGenerationRequest to BifrostEmbeddingRequest format
|
|
func (request *GeminiGenerationRequest) ToBifrostEmbeddingRequest(ctx *schemas.BifrostContext) *schemas.BifrostEmbeddingRequest {
|
|
if request == nil {
|
|
return nil
|
|
}
|
|
|
|
provider, model := schemas.ParseModelString(request.Model, utils.CheckAndSetDefaultProvider(ctx, schemas.Gemini))
|
|
|
|
// Create the embedding request
|
|
bifrostReq := &schemas.BifrostEmbeddingRequest{
|
|
Provider: provider,
|
|
Model: model,
|
|
Fallbacks: schemas.ParseFallbacks(request.Fallbacks),
|
|
}
|
|
|
|
// SDK batch embedding request contains multiple embedding requests with same parameters but different text fields.
|
|
if len(request.Requests) > 0 {
|
|
var texts []string
|
|
for _, req := range request.Requests {
|
|
if req.Content != nil && len(req.Content.Parts) > 0 {
|
|
for _, part := range req.Content.Parts {
|
|
if part != nil && part.Text != "" {
|
|
texts = append(texts, part.Text)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
if len(texts) > 0 {
|
|
bifrostReq.Input = &schemas.EmbeddingInput{}
|
|
if len(texts) == 1 {
|
|
bifrostReq.Input.Text = &texts[0]
|
|
} else {
|
|
bifrostReq.Input.Texts = texts
|
|
}
|
|
}
|
|
|
|
embeddingRequest := request.Requests[0]
|
|
|
|
// Convert parameters
|
|
if embeddingRequest.OutputDimensionality != nil || embeddingRequest.TaskType != nil || embeddingRequest.Title != nil {
|
|
bifrostReq.Params = &schemas.EmbeddingParameters{}
|
|
|
|
if embeddingRequest.OutputDimensionality != nil {
|
|
bifrostReq.Params.Dimensions = embeddingRequest.OutputDimensionality
|
|
}
|
|
|
|
// Handle extra parameters
|
|
if embeddingRequest.TaskType != nil || embeddingRequest.Title != nil {
|
|
bifrostReq.Params.ExtraParams = make(map[string]interface{})
|
|
if embeddingRequest.TaskType != nil {
|
|
bifrostReq.Params.ExtraParams["taskType"] = embeddingRequest.TaskType
|
|
}
|
|
if embeddingRequest.Title != nil {
|
|
bifrostReq.Params.ExtraParams["title"] = embeddingRequest.Title
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Generation-style requests (e.g., non-Imagen :predict) carry text in contents[].parts[].
|
|
// If no SDK requests[] were provided, derive embedding input from contents.
|
|
if bifrostReq.Input == nil {
|
|
var texts []string
|
|
for _, content := range request.Contents {
|
|
for _, part := range content.Parts {
|
|
if part != nil && part.Text != "" {
|
|
texts = append(texts, part.Text)
|
|
}
|
|
}
|
|
}
|
|
if len(texts) > 0 {
|
|
bifrostReq.Input = &schemas.EmbeddingInput{}
|
|
if len(texts) == 1 {
|
|
bifrostReq.Input.Text = &texts[0]
|
|
} else {
|
|
bifrostReq.Input.Texts = texts
|
|
}
|
|
}
|
|
}
|
|
|
|
return bifrostReq
|
|
}
|