first commit
This commit is contained in:
370
core/providers/gemini/batch.go
Normal file
370
core/providers/gemini/batch.go
Normal file
@@ -0,0 +1,370 @@
|
||||
package gemini
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/bytedance/sonic"
|
||||
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// ToBifrostBatchStatus converts Gemini batch job state to Bifrost status.
|
||||
func ToBifrostBatchStatus(geminiState string) schemas.BatchStatus {
|
||||
switch geminiState {
|
||||
case GeminiBatchStatePending, GeminiBatchStateRunning:
|
||||
return schemas.BatchStatusInProgress
|
||||
case GeminiBatchStateSucceeded:
|
||||
return schemas.BatchStatusCompleted
|
||||
case GeminiBatchStateFailed:
|
||||
return schemas.BatchStatusFailed
|
||||
case GeminiBatchStateCancelling:
|
||||
return schemas.BatchStatusCancelling
|
||||
case GeminiBatchStateCancelled:
|
||||
return schemas.BatchStatusCancelled
|
||||
case GeminiBatchStateExpired:
|
||||
return schemas.BatchStatusExpired
|
||||
default:
|
||||
return schemas.BatchStatus(geminiState)
|
||||
}
|
||||
}
|
||||
|
||||
// ToGeminiBatchStatus converts Bifrost batch status to Gemini batch job state.
|
||||
func ToGeminiBatchStatus(status schemas.BatchStatus) string {
|
||||
switch status {
|
||||
case schemas.BatchStatusValidating, schemas.BatchStatusInProgress:
|
||||
return GeminiBatchStateRunning
|
||||
case schemas.BatchStatusFinalizing:
|
||||
return GeminiBatchStateRunning
|
||||
case schemas.BatchStatusCompleted, schemas.BatchStatusEnded:
|
||||
return GeminiBatchStateSucceeded
|
||||
case schemas.BatchStatusFailed:
|
||||
return GeminiBatchStateFailed
|
||||
case schemas.BatchStatusCancelling:
|
||||
return GeminiBatchStateCancelling
|
||||
case schemas.BatchStatusCancelled:
|
||||
return GeminiBatchStateCancelled
|
||||
case schemas.BatchStatusExpired:
|
||||
return GeminiBatchStateExpired
|
||||
default:
|
||||
return GeminiBatchStateUnspecified
|
||||
}
|
||||
}
|
||||
|
||||
// ToGeminiBatchJobResponse converts Bifrost batch create response to Gemini batch job response format.
|
||||
func ToGeminiBatchJobResponse(resp *schemas.BifrostBatchCreateResponse) *GeminiBatchJobResponse {
|
||||
if resp == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
succeededCount := resp.RequestCounts.Succeeded
|
||||
if succeededCount == 0 {
|
||||
succeededCount = resp.RequestCounts.Completed
|
||||
}
|
||||
|
||||
geminiResp := &GeminiBatchJobResponse{
|
||||
Name: resp.ID,
|
||||
Metadata: &GeminiBatchMetadata{
|
||||
Name: resp.ID,
|
||||
Type: "type.googleapis.com/google.ai.generativelanguage.v1beta.BatchPredictionJob",
|
||||
CreateTime: formatGeminiTimestamp(resp.CreatedAt),
|
||||
UpdateTime: formatGeminiTimestamp(resp.CreatedAt),
|
||||
State: ToGeminiBatchStatus(resp.Status),
|
||||
BatchStats: &GeminiBatchStats{
|
||||
RequestCount: resp.RequestCounts.Total,
|
||||
PendingRequestCount: max(0, resp.RequestCounts.Total-succeededCount-resp.RequestCounts.Failed),
|
||||
SuccessfulRequestCount: succeededCount,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if resp.OperationName != nil && *resp.OperationName != "" {
|
||||
geminiResp.Metadata.Name = *resp.OperationName
|
||||
geminiResp.Name = *resp.OperationName
|
||||
}
|
||||
|
||||
if resp.InputFileID != "" {
|
||||
geminiResp.Metadata.InputConfig = &GeminiBatchMetadataInputConfig{
|
||||
FileName: resp.InputFileID,
|
||||
}
|
||||
}
|
||||
|
||||
if resp.OutputFileID != nil && *resp.OutputFileID != "" {
|
||||
geminiResp.Dest = &GeminiBatchDest{
|
||||
FileName: *resp.OutputFileID,
|
||||
}
|
||||
geminiResp.Metadata.Output = &GeminiBatchMetadataOutputConfig{
|
||||
ResponsesFile: *resp.OutputFileID,
|
||||
}
|
||||
}
|
||||
|
||||
if resp.Status == schemas.BatchStatusCompleted ||
|
||||
resp.Status == schemas.BatchStatusEnded ||
|
||||
resp.Status == schemas.BatchStatusFailed ||
|
||||
resp.Status == schemas.BatchStatusExpired ||
|
||||
resp.Status == schemas.BatchStatusCancelled {
|
||||
geminiResp.Done = true
|
||||
}
|
||||
|
||||
return geminiResp
|
||||
}
|
||||
|
||||
// ToGeminiBatchRetrieveResponse converts a Bifrost batch retrieve response to Gemini batch job response format.
|
||||
func ToGeminiBatchRetrieveResponse(resp *schemas.BifrostBatchRetrieveResponse) *GeminiBatchJobResponse {
|
||||
if resp == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
succeededCount := resp.RequestCounts.Succeeded
|
||||
if succeededCount == 0 {
|
||||
succeededCount = resp.RequestCounts.Completed
|
||||
}
|
||||
|
||||
pendingCount := resp.RequestCounts.Pending
|
||||
if pendingCount == 0 && resp.RequestCounts.Total > 0 {
|
||||
processedCount := resp.RequestCounts.Completed
|
||||
if processedCount == 0 {
|
||||
processedCount = succeededCount
|
||||
}
|
||||
pendingCount = resp.RequestCounts.Total - processedCount - resp.RequestCounts.Failed
|
||||
if pendingCount < 0 {
|
||||
pendingCount = 0
|
||||
}
|
||||
}
|
||||
|
||||
geminiResp := &GeminiBatchJobResponse{
|
||||
Name: resp.ID,
|
||||
Metadata: &GeminiBatchMetadata{
|
||||
Name: resp.ID,
|
||||
Type: "type.googleapis.com/google.ai.generativelanguage.v1beta.BatchPredictionJob",
|
||||
CreateTime: formatGeminiTimestamp(resp.CreatedAt),
|
||||
UpdateTime: formatGeminiTimestamp(resp.CreatedAt),
|
||||
State: ToGeminiBatchStatus(resp.Status),
|
||||
BatchStats: &GeminiBatchStats{
|
||||
RequestCount: resp.RequestCounts.Total,
|
||||
PendingRequestCount: pendingCount,
|
||||
SuccessfulRequestCount: succeededCount,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if resp.OperationName != nil && *resp.OperationName != "" {
|
||||
geminiResp.Metadata.Name = *resp.OperationName
|
||||
geminiResp.Name = *resp.OperationName
|
||||
}
|
||||
|
||||
if resp.Done != nil {
|
||||
geminiResp.Done = *resp.Done
|
||||
} else {
|
||||
geminiResp.Done = resp.Status == schemas.BatchStatusCompleted ||
|
||||
resp.Status == schemas.BatchStatusEnded ||
|
||||
resp.Status == schemas.BatchStatusFailed ||
|
||||
resp.Status == schemas.BatchStatusExpired ||
|
||||
resp.Status == schemas.BatchStatusCancelled
|
||||
}
|
||||
|
||||
if resp.InputFileID != "" {
|
||||
geminiResp.Metadata.InputConfig = &GeminiBatchMetadataInputConfig{
|
||||
FileName: resp.InputFileID,
|
||||
}
|
||||
}
|
||||
|
||||
if resp.OutputFileID != nil && *resp.OutputFileID != "" {
|
||||
geminiResp.Dest = &GeminiBatchDest{
|
||||
FileName: *resp.OutputFileID,
|
||||
}
|
||||
geminiResp.Metadata.Output = &GeminiBatchMetadataOutputConfig{
|
||||
ResponsesFile: *resp.OutputFileID,
|
||||
}
|
||||
}
|
||||
|
||||
// Set end time from the most relevant terminal timestamp
|
||||
var endTime int64
|
||||
if resp.CompletedAt != nil {
|
||||
endTime = *resp.CompletedAt
|
||||
} else if resp.FailedAt != nil {
|
||||
endTime = *resp.FailedAt
|
||||
} else if resp.ExpiredAt != nil {
|
||||
endTime = *resp.ExpiredAt
|
||||
} else if resp.CancelledAt != nil {
|
||||
endTime = *resp.CancelledAt
|
||||
}
|
||||
if endTime > 0 {
|
||||
geminiResp.Metadata.EndTime = formatGeminiTimestamp(endTime)
|
||||
}
|
||||
|
||||
return geminiResp
|
||||
}
|
||||
|
||||
// ToGeminiBatchListResponse converts a Bifrost batch list response to Gemini format.
|
||||
func ToGeminiBatchListResponse(resp *schemas.BifrostBatchListResponse) *GeminiBatchListResponse {
|
||||
if resp == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
operations := make([]GeminiBatchJobResponse, 0, len(resp.Data))
|
||||
for i := range resp.Data {
|
||||
if geminiResp := ToGeminiBatchRetrieveResponse(&resp.Data[i]); geminiResp != nil {
|
||||
operations = append(operations, *geminiResp)
|
||||
}
|
||||
}
|
||||
|
||||
geminiListResp := &GeminiBatchListResponse{
|
||||
Operations: operations,
|
||||
}
|
||||
|
||||
if resp.NextCursor != nil {
|
||||
geminiListResp.NextPageToken = *resp.NextCursor
|
||||
}
|
||||
|
||||
return geminiListResp
|
||||
}
|
||||
|
||||
// parseGeminiTimestamp converts Gemini RFC3339 timestamp to Unix timestamp.
|
||||
func parseGeminiTimestamp(timestamp string) int64 {
|
||||
if timestamp == "" {
|
||||
return 0
|
||||
}
|
||||
t, err := time.Parse(time.RFC3339, timestamp)
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
return t.Unix()
|
||||
}
|
||||
|
||||
// extractBatchIDFromName extracts the batch ID from the full resource name.
|
||||
// e.g., "batches/abc123" -> "abc123"
|
||||
func extractBatchIDFromName(name string) string {
|
||||
if name == "" {
|
||||
return ""
|
||||
}
|
||||
parts := strings.Split(name, "/")
|
||||
return parts[len(parts)-1]
|
||||
}
|
||||
|
||||
// downloadBatchResultsFile downloads and parses a batch results file from Gemini.
|
||||
// Returns the parsed result items from the JSONL file and any parse errors encountered.
|
||||
func (provider *GeminiProvider) downloadBatchResultsFile(ctx context.Context, key schemas.Key, fileName string) ([]schemas.BatchResultItem, []schemas.BatchError, *schemas.BifrostError) {
|
||||
// Create request to download the file
|
||||
req := fasthttp.AcquireRequest()
|
||||
resp := fasthttp.AcquireResponse()
|
||||
defer fasthttp.ReleaseRequest(req)
|
||||
defer fasthttp.ReleaseResponse(resp)
|
||||
|
||||
// Build download URL - use the download endpoint with alt=media
|
||||
// The base URL is like https://generativelanguage.googleapis.com/v1beta
|
||||
// We need to change it to https://generativelanguage.googleapis.com/download/v1beta
|
||||
baseURL := strings.Replace(provider.networkConfig.BaseURL, "/v1beta", "/download/v1beta", 1)
|
||||
|
||||
// Ensure fileName has proper format
|
||||
fileID := fileName
|
||||
if !strings.HasPrefix(fileID, "files/") {
|
||||
fileID = "files/" + fileID
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s/%s:download?alt=media", baseURL, fileID)
|
||||
|
||||
provider.logger.Debug("gemini batch results file download url: " + url)
|
||||
providerUtils.SetExtraHeaders(ctx, req, provider.networkConfig.ExtraHeaders, nil)
|
||||
req.SetRequestURI(url)
|
||||
req.Header.SetMethod(http.MethodGet)
|
||||
if key.Value.GetValue() != "" {
|
||||
req.Header.Set("x-goog-api-key", key.Value.GetValue())
|
||||
}
|
||||
|
||||
// Make request
|
||||
_, bifrostErr, wait := providerUtils.MakeRequestWithContext(ctx, provider.client, req, resp)
|
||||
defer wait()
|
||||
if bifrostErr != nil {
|
||||
return nil, nil, bifrostErr
|
||||
}
|
||||
|
||||
// Handle error response
|
||||
if resp.StatusCode() != fasthttp.StatusOK {
|
||||
return nil, nil, parseGeminiError(resp)
|
||||
}
|
||||
|
||||
body, err := providerUtils.CheckAndDecodeBody(resp)
|
||||
if err != nil {
|
||||
return nil, nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err)
|
||||
}
|
||||
|
||||
// Parse JSONL content - each line is a separate JSON object
|
||||
// Use streaming parser to avoid string conversion and collect parse errors
|
||||
results := make([]schemas.BatchResultItem, 0)
|
||||
|
||||
parseResult := providerUtils.ParseJSONL(body, func(line []byte) error {
|
||||
var resultLine GeminiBatchFileResultLine
|
||||
if err := sonic.Unmarshal(line, &resultLine); err != nil {
|
||||
provider.logger.Warn("gemini batch results file parse error: " + err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
customID := resultLine.Key
|
||||
if customID == "" {
|
||||
customID = fmt.Sprintf("request-%d", len(results))
|
||||
}
|
||||
|
||||
resultItem := schemas.BatchResultItem{
|
||||
CustomID: customID,
|
||||
}
|
||||
|
||||
if resultLine.Error != nil {
|
||||
resultItem.Error = &schemas.BatchResultError{
|
||||
Code: fmt.Sprintf("%d", resultLine.Error.Code),
|
||||
Message: resultLine.Error.Message,
|
||||
}
|
||||
} else if resultLine.Response != nil {
|
||||
// Convert the response to a map for the Body field
|
||||
respBody := make(map[string]interface{})
|
||||
if len(resultLine.Response.Candidates) > 0 {
|
||||
candidate := resultLine.Response.Candidates[0]
|
||||
if candidate.Content != nil && len(candidate.Content.Parts) > 0 {
|
||||
var textParts []string
|
||||
for _, part := range candidate.Content.Parts {
|
||||
if part.Text != "" {
|
||||
textParts = append(textParts, part.Text)
|
||||
}
|
||||
}
|
||||
if len(textParts) > 0 {
|
||||
respBody["text"] = strings.Join(textParts, "")
|
||||
}
|
||||
}
|
||||
respBody["finish_reason"] = string(candidate.FinishReason)
|
||||
}
|
||||
if resultLine.Response.UsageMetadata != nil {
|
||||
respBody["usage"] = map[string]interface{}{
|
||||
"prompt_tokens": resultLine.Response.UsageMetadata.PromptTokenCount,
|
||||
"completion_tokens": resultLine.Response.UsageMetadata.CandidatesTokenCount,
|
||||
"total_tokens": resultLine.Response.UsageMetadata.TotalTokenCount,
|
||||
}
|
||||
}
|
||||
|
||||
resultItem.Response = &schemas.BatchResultResponse{
|
||||
StatusCode: 200,
|
||||
Body: respBody,
|
||||
}
|
||||
}
|
||||
|
||||
results = append(results, resultItem)
|
||||
return nil
|
||||
})
|
||||
|
||||
return results, parseResult.Errors, nil
|
||||
}
|
||||
|
||||
// extractGeminiUsageMetadata extracts usage metadata (as ints) from Gemini response
|
||||
func extractGeminiUsageMetadata(geminiResponse *GenerateContentResponse) (int, int, int) {
|
||||
var inputTokens, outputTokens, totalTokens int
|
||||
if geminiResponse.UsageMetadata != nil {
|
||||
usageMetadata := geminiResponse.UsageMetadata
|
||||
inputTokens = int(usageMetadata.PromptTokenCount)
|
||||
outputTokens = int(usageMetadata.CandidatesTokenCount)
|
||||
totalTokens = int(usageMetadata.TotalTokenCount)
|
||||
}
|
||||
return inputTokens, outputTokens, totalTokens
|
||||
}
|
||||
558
core/providers/gemini/chat.go
Normal file
558
core/providers/gemini/chat.go
Normal file
@@ -0,0 +1,558 @@
|
||||
package gemini
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// ToGeminiChatCompletionRequest converts a BifrostChatRequest to Gemini's generation request format for chat completion
|
||||
func ToGeminiChatCompletionRequest(bifrostReq *schemas.BifrostChatRequest) (*GeminiGenerationRequest, error) {
|
||||
if bifrostReq == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Create the base Gemini generation request
|
||||
geminiReq := &GeminiGenerationRequest{
|
||||
Model: bifrostReq.Model,
|
||||
}
|
||||
|
||||
// Convert parameters to generation config
|
||||
if bifrostReq.Params != nil {
|
||||
geminiReq.ExtraParams = bifrostReq.Params.ExtraParams
|
||||
var err error
|
||||
geminiReq.GenerationConfig, err = convertParamsToGenerationConfig(bifrostReq.Params, []string{}, bifrostReq.Model)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Handle tool-related parameters
|
||||
if len(bifrostReq.Params.Tools) > 0 {
|
||||
geminiReq.Tools = convertBifrostToolsToGemini(bifrostReq.Params.Tools)
|
||||
|
||||
// Convert tool choice to tool config
|
||||
if bifrostReq.Params.ToolChoice != nil {
|
||||
geminiReq.ToolConfig = convertToolChoiceToToolConfig(bifrostReq.Params.ToolChoice)
|
||||
}
|
||||
}
|
||||
|
||||
// Handle extra parameters
|
||||
if bifrostReq.Params.ExtraParams != nil {
|
||||
// Safety settings
|
||||
if safetySettings, ok := schemas.SafeExtractFromMap(bifrostReq.Params.ExtraParams, "safety_settings"); ok {
|
||||
delete(geminiReq.ExtraParams, "safety_settings")
|
||||
if settings, ok := SafeExtractSafetySettings(safetySettings); ok {
|
||||
geminiReq.SafetySettings = settings
|
||||
}
|
||||
}
|
||||
|
||||
// Cached content
|
||||
if cachedContent, ok := schemas.SafeExtractString(bifrostReq.Params.ExtraParams["cached_content"]); ok {
|
||||
delete(geminiReq.ExtraParams, "cached_content")
|
||||
geminiReq.CachedContent = cachedContent
|
||||
}
|
||||
|
||||
// Labels
|
||||
if labels, ok := schemas.SafeExtractFromMap(bifrostReq.Params.ExtraParams, "labels"); ok {
|
||||
delete(geminiReq.ExtraParams, "labels")
|
||||
if labelMap, ok := schemas.SafeExtractStringMap(labels); ok {
|
||||
geminiReq.Labels = labelMap
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// Convert chat completion messages to Gemini format
|
||||
contents, systemInstruction := convertBifrostMessagesToGemini(bifrostReq.Input)
|
||||
if systemInstruction != nil {
|
||||
geminiReq.SystemInstruction = systemInstruction
|
||||
}
|
||||
geminiReq.Contents = contents
|
||||
return geminiReq, nil
|
||||
}
|
||||
|
||||
// ToBifrostChatResponse converts a GenerateContentResponse to a BifrostChatResponse
|
||||
func (response *GenerateContentResponse) ToBifrostChatResponse() *schemas.BifrostChatResponse {
|
||||
bifrostResp := &schemas.BifrostChatResponse{
|
||||
ID: response.ResponseID,
|
||||
Model: response.ModelVersion,
|
||||
Object: "chat.completion",
|
||||
}
|
||||
|
||||
// Set creation timestamp if available
|
||||
if !response.CreateTime.IsZero() {
|
||||
bifrostResp.Created = int(response.CreateTime.Unix())
|
||||
}
|
||||
|
||||
// Handle empty candidates (filtered/malformed responses)
|
||||
if len(response.Candidates) == 0 {
|
||||
finishReason := ConvertGeminiFinishReasonToBifrost(FinishReasonMalformedFunctionCall)
|
||||
return createErrorResponse(response, finishReason, false)
|
||||
}
|
||||
|
||||
candidate := response.Candidates[0]
|
||||
|
||||
// Check for filtered finish reasons that indicate errors
|
||||
if isErrorFinishReason(candidate.FinishReason) {
|
||||
finishReason := ConvertGeminiFinishReasonToBifrost(candidate.FinishReason)
|
||||
return createErrorResponse(response, finishReason, false)
|
||||
}
|
||||
|
||||
// Collect all content and tool calls into a single message
|
||||
var toolCalls []schemas.ChatAssistantMessageToolCall
|
||||
var contentBlocks []schemas.ChatContentBlock
|
||||
var reasoningDetails []schemas.ChatReasoningDetails
|
||||
var contentStr *string
|
||||
|
||||
// Process candidate content to extract text, tool calls, and reasoning
|
||||
if candidate.Content != nil && len(candidate.Content.Parts) > 0 {
|
||||
for _, part := range candidate.Content.Parts {
|
||||
// Handle thought/reasoning text separately - add to reasoning details
|
||||
if part.Text != "" && part.Thought {
|
||||
reasoningDetails = append(reasoningDetails, schemas.ChatReasoningDetails{
|
||||
Index: len(reasoningDetails),
|
||||
Type: schemas.BifrostReasoningDetailsTypeText,
|
||||
Text: &part.Text,
|
||||
})
|
||||
continue
|
||||
}
|
||||
// Handle regular text
|
||||
if part.Text != "" {
|
||||
contentBlocks = append(contentBlocks, schemas.ChatContentBlock{
|
||||
Type: schemas.ChatContentBlockTypeText,
|
||||
Text: &part.Text,
|
||||
})
|
||||
// Add thought signature to reasoning details if present with text
|
||||
if len(part.ThoughtSignature) > 0 {
|
||||
thoughtSig := base64.StdEncoding.EncodeToString(part.ThoughtSignature)
|
||||
reasoningDetails = append(reasoningDetails, schemas.ChatReasoningDetails{
|
||||
Index: len(reasoningDetails),
|
||||
Type: schemas.BifrostReasoningDetailsTypeEncrypted,
|
||||
Signature: &thoughtSig,
|
||||
})
|
||||
}
|
||||
}
|
||||
if part.FunctionCall != nil {
|
||||
function := schemas.ChatAssistantMessageToolCallFunction{
|
||||
Name: &part.FunctionCall.Name,
|
||||
}
|
||||
|
||||
if len(part.FunctionCall.Args) > 0 {
|
||||
function.Arguments = string(part.FunctionCall.Args)
|
||||
}
|
||||
|
||||
callID := part.FunctionCall.Name
|
||||
if part.FunctionCall.ID != "" {
|
||||
callID = part.FunctionCall.ID
|
||||
}
|
||||
|
||||
// Embed thought signature into CallID if present (matches responses.go pattern)
|
||||
if len(part.ThoughtSignature) > 0 && !strings.Contains(callID, thoughtSignatureSeparator) {
|
||||
encoded := base64.RawURLEncoding.EncodeToString(part.ThoughtSignature)
|
||||
callID = fmt.Sprintf("%s%s%s", callID, thoughtSignatureSeparator, encoded)
|
||||
}
|
||||
|
||||
toolCall := schemas.ChatAssistantMessageToolCall{
|
||||
Index: uint16(len(toolCalls)),
|
||||
Type: schemas.Ptr(string(schemas.ChatToolChoiceTypeFunction)),
|
||||
ID: &callID,
|
||||
Function: function,
|
||||
}
|
||||
|
||||
toolCalls = append(toolCalls, toolCall)
|
||||
|
||||
// Also add to reasoning details for backward compatibility
|
||||
if len(part.ThoughtSignature) > 0 {
|
||||
thoughtSig := base64.StdEncoding.EncodeToString(part.ThoughtSignature)
|
||||
// Extract base ID without signature for reasoning detail lookup
|
||||
baseCallID := callID
|
||||
if strings.Contains(callID, thoughtSignatureSeparator) {
|
||||
parts := strings.SplitN(callID, thoughtSignatureSeparator, 2)
|
||||
if len(parts) == 2 {
|
||||
baseCallID = parts[0]
|
||||
}
|
||||
}
|
||||
reasoningDetails = append(reasoningDetails, schemas.ChatReasoningDetails{
|
||||
Index: len(reasoningDetails),
|
||||
Type: schemas.BifrostReasoningDetailsTypeEncrypted,
|
||||
Signature: &thoughtSig,
|
||||
ID: schemas.Ptr(fmt.Sprintf("tool_call_%s", baseCallID)),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
if part.FunctionResponse != nil {
|
||||
// Extract the output from the response
|
||||
output := extractFunctionResponseOutput(part.FunctionResponse)
|
||||
|
||||
// Add as text content block
|
||||
if output != "" {
|
||||
contentBlocks = append(contentBlocks, schemas.ChatContentBlock{
|
||||
Type: schemas.ChatContentBlockTypeText,
|
||||
Text: &output,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Handle code execution results
|
||||
if part.CodeExecutionResult != nil {
|
||||
output := part.CodeExecutionResult.Output
|
||||
if part.CodeExecutionResult.Outcome != OutcomeOK {
|
||||
output = "Error: " + output
|
||||
}
|
||||
if output != "" {
|
||||
contentBlocks = append(contentBlocks, schemas.ChatContentBlock{
|
||||
Type: schemas.ChatContentBlockTypeText,
|
||||
Text: &output,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Handle executable code
|
||||
if part.ExecutableCode != nil {
|
||||
codeContent := "```" + part.ExecutableCode.Language + "\n" + part.ExecutableCode.Code + "\n```"
|
||||
contentBlocks = append(contentBlocks, schemas.ChatContentBlock{
|
||||
Type: schemas.ChatContentBlockTypeText,
|
||||
Text: &codeContent,
|
||||
})
|
||||
}
|
||||
|
||||
// Handle standalone thought signature (not associated with function call or text)
|
||||
if len(part.ThoughtSignature) > 0 && part.FunctionCall == nil && part.Text == "" {
|
||||
thoughtSig := base64.StdEncoding.EncodeToString(part.ThoughtSignature)
|
||||
reasoningDetails = append(reasoningDetails, schemas.ChatReasoningDetails{
|
||||
Index: len(reasoningDetails),
|
||||
Type: schemas.BifrostReasoningDetailsTypeEncrypted,
|
||||
Signature: &thoughtSig,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Build the choice with message
|
||||
message := &schemas.ChatMessage{
|
||||
Role: schemas.ChatMessageRoleAssistant,
|
||||
}
|
||||
|
||||
if len(contentBlocks) == 1 && contentBlocks[0].Type == schemas.ChatContentBlockTypeText {
|
||||
contentStr = contentBlocks[0].Text
|
||||
contentBlocks = nil
|
||||
}
|
||||
|
||||
message.Content = &schemas.ChatMessageContent{
|
||||
ContentStr: contentStr,
|
||||
ContentBlocks: contentBlocks,
|
||||
}
|
||||
|
||||
if len(toolCalls) > 0 || len(reasoningDetails) > 0 {
|
||||
message.ChatAssistantMessage = &schemas.ChatAssistantMessage{
|
||||
ToolCalls: toolCalls,
|
||||
ReasoningDetails: reasoningDetails,
|
||||
}
|
||||
}
|
||||
|
||||
// Convert finish reason to Bifrost format.
|
||||
// Gemini uses "STOP" for both normal text completions and tool call responses —
|
||||
// it has no dedicated finish reason for tool calls. Override to "tool_calls" when
|
||||
// tool calls are present so downstream consumers see a uniform signal.
|
||||
finishReason := ConvertGeminiFinishReasonToBifrost(candidate.FinishReason)
|
||||
if len(toolCalls) > 0 && finishReason == "stop" {
|
||||
finishReason = "tool_calls"
|
||||
}
|
||||
|
||||
bifrostResp.Choices = append(bifrostResp.Choices, schemas.BifrostResponseChoice{
|
||||
Index: 0,
|
||||
FinishReason: &finishReason,
|
||||
LogProbs: ConvertGeminiLogprobsResultToBifrost(candidate.LogprobsResult),
|
||||
ChatNonStreamResponseChoice: &schemas.ChatNonStreamResponseChoice{
|
||||
Message: message,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// Set usage information
|
||||
bifrostResp.Usage = ConvertGeminiUsageMetadataToChatUsage(response.UsageMetadata)
|
||||
|
||||
return bifrostResp
|
||||
}
|
||||
|
||||
// GeminiStreamState tracks tool-call index across streaming chunks.
|
||||
type GeminiStreamState struct {
|
||||
nextToolCallIndex int
|
||||
hadToolCalls bool // true if any tool calls were seen in this stream
|
||||
}
|
||||
|
||||
// NewGeminiStreamState returns initialised stream state for one streaming response.
|
||||
func NewGeminiStreamState() *GeminiStreamState {
|
||||
return &GeminiStreamState{}
|
||||
}
|
||||
|
||||
// ToBifrostChatCompletionStream converts a Gemini streaming response to a Bifrost Chat Completion Stream response
|
||||
// Returns the response, error (if any), and a boolean indicating if this is the last chunk
|
||||
func (response *GenerateContentResponse) ToBifrostChatCompletionStream(state *GeminiStreamState) (*schemas.BifrostChatResponse, *schemas.BifrostError, bool) {
|
||||
if response == nil {
|
||||
return nil, nil, false
|
||||
}
|
||||
|
||||
if state == nil {
|
||||
state = NewGeminiStreamState()
|
||||
}
|
||||
|
||||
// Handle empty candidates (filtered/malformed responses)
|
||||
if len(response.Candidates) == 0 {
|
||||
finishReason := ConvertGeminiFinishReasonToBifrost(FinishReasonMalformedFunctionCall)
|
||||
return createErrorResponse(response, finishReason, true), nil, true
|
||||
}
|
||||
|
||||
candidate := response.Candidates[0]
|
||||
|
||||
// Check for filtered finish reasons that indicate errors
|
||||
if isErrorFinishReason(candidate.FinishReason) {
|
||||
finishReason := ConvertGeminiFinishReasonToBifrost(candidate.FinishReason)
|
||||
return createErrorResponse(response, finishReason, true), nil, true
|
||||
}
|
||||
|
||||
// Determine if this is the last chunk based on finish reason and usage metadata
|
||||
isLastChunk := candidate.FinishReason != "" && response.UsageMetadata != nil
|
||||
|
||||
// Create the streaming response
|
||||
streamResponse := &schemas.BifrostChatResponse{
|
||||
ID: response.ResponseID,
|
||||
Model: response.ModelVersion,
|
||||
Object: "chat.completion.chunk",
|
||||
}
|
||||
|
||||
// Set creation timestamp if available
|
||||
if !response.CreateTime.IsZero() {
|
||||
streamResponse.Created = int(response.CreateTime.Unix())
|
||||
}
|
||||
|
||||
// Build delta content
|
||||
delta := &schemas.ChatStreamResponseChoiceDelta{}
|
||||
|
||||
// Process content parts
|
||||
if candidate.Content != nil && len(candidate.Content.Parts) > 0 {
|
||||
// Set role from the first chunk (Gemini uses "model" for assistant)
|
||||
if candidate.Content.Role != "" {
|
||||
role := candidate.Content.Role
|
||||
if role == string(RoleModel) {
|
||||
role = string(schemas.ChatMessageRoleAssistant)
|
||||
}
|
||||
delta.Role = &role
|
||||
}
|
||||
|
||||
var textContent string
|
||||
var toolCalls []schemas.ChatAssistantMessageToolCall
|
||||
var reasoningDetails []schemas.ChatReasoningDetails
|
||||
|
||||
for _, part := range candidate.Content.Parts {
|
||||
switch {
|
||||
case part.Text != "" && part.Thought:
|
||||
// Thought/reasoning content - add to reasoning details
|
||||
reasoningDetails = append(reasoningDetails, schemas.ChatReasoningDetails{
|
||||
Index: len(reasoningDetails),
|
||||
Type: schemas.BifrostReasoningDetailsTypeText,
|
||||
Text: &part.Text,
|
||||
})
|
||||
|
||||
case part.Text != "":
|
||||
// Regular text content
|
||||
textContent += part.Text
|
||||
|
||||
case part.FunctionCall != nil:
|
||||
// Function call
|
||||
jsonArgs := ""
|
||||
if len(part.FunctionCall.Args) > 0 {
|
||||
jsonArgs = string(part.FunctionCall.Args)
|
||||
}
|
||||
|
||||
// Use ID if available, otherwise use function name
|
||||
callID := part.FunctionCall.Name
|
||||
if part.FunctionCall.ID != "" {
|
||||
callID = part.FunctionCall.ID
|
||||
}
|
||||
|
||||
// Embed thought signature into CallID if present
|
||||
if len(part.ThoughtSignature) > 0 && !strings.Contains(callID, thoughtSignatureSeparator) {
|
||||
encoded := base64.RawURLEncoding.EncodeToString(part.ThoughtSignature)
|
||||
callID = fmt.Sprintf("%s%s%s", callID, thoughtSignatureSeparator, encoded)
|
||||
}
|
||||
|
||||
toolCallIdx := state.nextToolCallIndex
|
||||
state.nextToolCallIndex++
|
||||
|
||||
toolCall := schemas.ChatAssistantMessageToolCall{
|
||||
Index: uint16(toolCallIdx),
|
||||
Type: schemas.Ptr(string(schemas.ChatToolTypeFunction)),
|
||||
ID: &callID,
|
||||
Function: schemas.ChatAssistantMessageToolCallFunction{
|
||||
Name: &part.FunctionCall.Name,
|
||||
Arguments: jsonArgs,
|
||||
},
|
||||
}
|
||||
|
||||
toolCalls = append(toolCalls, toolCall)
|
||||
|
||||
// Also add thought signature to reasoning details if present
|
||||
if len(part.ThoughtSignature) > 0 {
|
||||
thoughtSig := base64.StdEncoding.EncodeToString(part.ThoughtSignature)
|
||||
// Extract base ID without signature for reasoning detail lookup
|
||||
baseCallID := callID
|
||||
if strings.Contains(callID, thoughtSignatureSeparator) {
|
||||
parts := strings.SplitN(callID, thoughtSignatureSeparator, 2)
|
||||
if len(parts) == 2 {
|
||||
baseCallID = parts[0]
|
||||
}
|
||||
}
|
||||
reasoningDetails = append(reasoningDetails, schemas.ChatReasoningDetails{
|
||||
Index: len(reasoningDetails),
|
||||
Type: schemas.BifrostReasoningDetailsTypeEncrypted,
|
||||
Signature: &thoughtSig,
|
||||
ID: schemas.Ptr(fmt.Sprintf("tool_call_%s", baseCallID)),
|
||||
})
|
||||
}
|
||||
|
||||
case part.FunctionResponse != nil:
|
||||
// Extract the output from the response and add to text content
|
||||
output := extractFunctionResponseOutput(part.FunctionResponse)
|
||||
if output != "" {
|
||||
textContent += output
|
||||
}
|
||||
case part.CodeExecutionResult != nil:
|
||||
output := part.CodeExecutionResult.Output
|
||||
if part.CodeExecutionResult.Outcome != OutcomeOK {
|
||||
output = "Error: " + output
|
||||
}
|
||||
if output != "" {
|
||||
textContent += output
|
||||
}
|
||||
case part.ExecutableCode != nil:
|
||||
codeContent := "```" + part.ExecutableCode.Language + "\n" + part.ExecutableCode.Code + "\n```"
|
||||
textContent += codeContent
|
||||
}
|
||||
|
||||
// Handle thought signature separately (not part of the switch since it can co-exist with other types)
|
||||
if len(part.ThoughtSignature) > 0 && part.FunctionCall == nil {
|
||||
thoughtSig := base64.StdEncoding.EncodeToString(part.ThoughtSignature)
|
||||
reasoningDetails = append(reasoningDetails, schemas.ChatReasoningDetails{
|
||||
Index: len(reasoningDetails),
|
||||
Type: schemas.BifrostReasoningDetailsTypeEncrypted,
|
||||
Signature: &thoughtSig,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Set text content if present
|
||||
if textContent != "" {
|
||||
delta.Content = &textContent
|
||||
}
|
||||
|
||||
// Set reasoning details if present
|
||||
if len(reasoningDetails) > 0 {
|
||||
delta.ReasoningDetails = reasoningDetails
|
||||
}
|
||||
|
||||
// Set tool calls if present
|
||||
if len(toolCalls) > 0 {
|
||||
delta.ToolCalls = toolCalls
|
||||
state.hadToolCalls = true
|
||||
}
|
||||
}
|
||||
|
||||
// Check if delta has any content - if not and it's not the last chunk, skip it
|
||||
hasDeltaContent := delta.Role != nil || delta.Content != nil || len(delta.ToolCalls) > 0 || len(delta.ReasoningDetails) > 0
|
||||
if !hasDeltaContent && !isLastChunk {
|
||||
return nil, nil, false
|
||||
}
|
||||
|
||||
// Build the choice
|
||||
var finishReason *string
|
||||
if isLastChunk && candidate.FinishReason != "" {
|
||||
reason := ConvertGeminiFinishReasonToBifrost(candidate.FinishReason)
|
||||
// Gemini uses "STOP" for both text completions and tool call responses.
|
||||
// Override to "tool_calls" when tool calls were seen in this stream for uniformity.
|
||||
if (len(delta.ToolCalls) > 0 || state.hadToolCalls) && reason == "stop" {
|
||||
reason = "tool_calls"
|
||||
}
|
||||
finishReason = &reason
|
||||
}
|
||||
|
||||
choice := schemas.BifrostResponseChoice{
|
||||
Index: int(candidate.Index),
|
||||
FinishReason: finishReason,
|
||||
LogProbs: ConvertGeminiLogprobsResultToBifrost(candidate.LogprobsResult),
|
||||
ChatStreamResponseChoice: &schemas.ChatStreamResponseChoice{
|
||||
Delta: delta,
|
||||
},
|
||||
}
|
||||
|
||||
streamResponse.Choices = []schemas.BifrostResponseChoice{choice}
|
||||
|
||||
// Add usage information if this is the last chunk
|
||||
if isLastChunk && response.UsageMetadata != nil {
|
||||
streamResponse.Usage = ConvertGeminiUsageMetadataToChatUsage(response.UsageMetadata)
|
||||
}
|
||||
|
||||
return streamResponse, nil, isLastChunk
|
||||
}
|
||||
|
||||
// isErrorFinishReason checks if a finish reason indicates a filtered or error response
|
||||
func isErrorFinishReason(reason FinishReason) bool {
|
||||
return reason == FinishReasonSafety ||
|
||||
reason == FinishReasonRecitation ||
|
||||
reason == FinishReasonMalformedFunctionCall ||
|
||||
reason == FinishReasonBlocklist ||
|
||||
reason == FinishReasonProhibitedContent ||
|
||||
reason == FinishReasonSPII ||
|
||||
reason == FinishReasonImageSafety ||
|
||||
reason == FinishReasonUnexpectedToolCall ||
|
||||
reason == FinishReasonMissingThoughtSignature ||
|
||||
reason == FinishReasonMalformedResponse ||
|
||||
reason == FinishReasonImageProhibitedContent ||
|
||||
reason == FinishReasonImageRecitation ||
|
||||
reason == FinishReasonTooManyToolCalls ||
|
||||
reason == FinishReasonNoImage
|
||||
}
|
||||
|
||||
// createErrorResponse creates a complete BifrostChatResponse for error cases
|
||||
func createErrorResponse(response *GenerateContentResponse, finishReason string, isStream bool) *schemas.BifrostChatResponse {
|
||||
var choice schemas.BifrostResponseChoice
|
||||
if isStream {
|
||||
choice = schemas.BifrostResponseChoice{
|
||||
Index: 0,
|
||||
FinishReason: &finishReason,
|
||||
ChatStreamResponseChoice: &schemas.ChatStreamResponseChoice{
|
||||
Delta: &schemas.ChatStreamResponseChoiceDelta{},
|
||||
},
|
||||
}
|
||||
} else {
|
||||
choice = schemas.BifrostResponseChoice{
|
||||
Index: 0,
|
||||
FinishReason: &finishReason,
|
||||
ChatNonStreamResponseChoice: &schemas.ChatNonStreamResponseChoice{
|
||||
Message: &schemas.ChatMessage{
|
||||
Role: schemas.ChatMessageRoleAssistant,
|
||||
Content: &schemas.ChatMessageContent{},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
objectType := "chat.completion"
|
||||
if isStream {
|
||||
objectType = "chat.completion.chunk"
|
||||
}
|
||||
|
||||
errorResp := &schemas.BifrostChatResponse{
|
||||
ID: response.ResponseID,
|
||||
Model: response.ModelVersion,
|
||||
Object: objectType,
|
||||
Choices: []schemas.BifrostResponseChoice{choice},
|
||||
Usage: ConvertGeminiUsageMetadataToChatUsage(response.UsageMetadata),
|
||||
}
|
||||
|
||||
if !response.CreateTime.IsZero() {
|
||||
errorResp.Created = int(response.CreateTime.Unix())
|
||||
}
|
||||
|
||||
return errorResp
|
||||
}
|
||||
80
core/providers/gemini/count_tokens.go
Normal file
80
core/providers/gemini/count_tokens.go
Normal file
@@ -0,0 +1,80 @@
|
||||
package gemini
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// ToBifrostCountTokensResponse converts a Gemini count tokens response to Bifrost format.
|
||||
func (resp *GeminiCountTokensResponse) ToBifrostCountTokensResponse(model string) *schemas.BifrostCountTokensResponse {
|
||||
if resp == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Sum prompt tokens and map modality-specific counts
|
||||
inputTokens := 0
|
||||
inputDetails := &schemas.ResponsesResponseInputTokens{}
|
||||
|
||||
for _, m := range resp.PromptTokensDetails {
|
||||
if m == nil {
|
||||
continue
|
||||
}
|
||||
inputTokens += int(m.TokenCount)
|
||||
mod := strings.ToLower(string(m.Modality))
|
||||
// handle audio modality
|
||||
if strings.Contains(mod, "audio") {
|
||||
inputDetails.AudioTokens += int(m.TokenCount)
|
||||
}
|
||||
}
|
||||
|
||||
// Set cached tokens from top-level field if present
|
||||
if resp.CachedContentTokenCount != 0 {
|
||||
inputDetails.CachedReadTokens = int(resp.CachedContentTokenCount)
|
||||
} else if resp.CacheTokensDetails != nil {
|
||||
// If cache tokens details present, sum them
|
||||
cachedSum := 0
|
||||
for _, m := range resp.CacheTokensDetails {
|
||||
if m == nil {
|
||||
continue
|
||||
}
|
||||
cachedSum += int(m.TokenCount)
|
||||
if strings.Contains(strings.ToLower(string(m.Modality)), "audio") {
|
||||
// also populate audio tokens from cache into AudioTokens (additive)
|
||||
inputDetails.AudioTokens += int(m.TokenCount)
|
||||
}
|
||||
}
|
||||
inputDetails.CachedReadTokens = cachedSum
|
||||
}
|
||||
|
||||
total := int(resp.TotalTokens)
|
||||
|
||||
return &schemas.BifrostCountTokensResponse{
|
||||
Model: model,
|
||||
Object: "response.input_tokens",
|
||||
InputTokens: inputTokens,
|
||||
InputTokensDetails: inputDetails,
|
||||
TotalTokens: &total,
|
||||
ExtraFields: schemas.BifrostResponseExtraFields{},
|
||||
}
|
||||
}
|
||||
|
||||
// ToGeminiCountTokensResponse converts a Bifrost count tokens response to Gemini format.
|
||||
func ToGeminiCountTokensResponse(bifrostResp *schemas.BifrostCountTokensResponse) *GeminiCountTokensResponse {
|
||||
if bifrostResp == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
response := &GeminiCountTokensResponse{
|
||||
TotalTokens: int32(bifrostResp.InputTokens),
|
||||
}
|
||||
|
||||
// Map cached content token count if available
|
||||
if bifrostResp.InputTokensDetails != nil && bifrostResp.InputTokensDetails.CachedReadTokens > 0 {
|
||||
response.CachedContentTokenCount = int32(bifrostResp.InputTokensDetails.CachedReadTokens)
|
||||
} else {
|
||||
response.CachedContentTokenCount = 0
|
||||
}
|
||||
|
||||
return response
|
||||
}
|
||||
247
core/providers/gemini/embedding.go
Normal file
247
core/providers/gemini/embedding.go
Normal file
@@ -0,0 +1,247 @@
|
||||
package gemini
|
||||
|
||||
import (
|
||||
"github.com/maximhq/bifrost/core/providers/utils"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// ToGeminiEmbeddingRequest converts a BifrostRequest with embedding input to Gemini's batch embedding request format
|
||||
// GeminiGenerationRequest contains requests array for batch embed content endpoint
|
||||
func ToGeminiEmbeddingRequest(bifrostReq *schemas.BifrostEmbeddingRequest) *GeminiBatchEmbeddingRequest {
|
||||
if bifrostReq == nil || bifrostReq.Input == nil || (bifrostReq.Input.Text == nil && bifrostReq.Input.Texts == nil) {
|
||||
return nil
|
||||
}
|
||||
|
||||
embeddingInput := bifrostReq.Input
|
||||
|
||||
// Collect all texts to embed
|
||||
var texts []string
|
||||
if embeddingInput.Text != nil {
|
||||
texts = append(texts, *embeddingInput.Text)
|
||||
}
|
||||
if len(embeddingInput.Texts) > 0 {
|
||||
texts = append(texts, embeddingInput.Texts...)
|
||||
}
|
||||
|
||||
if len(texts) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Create batch embedding request with one request per text
|
||||
batchRequest := &GeminiBatchEmbeddingRequest{
|
||||
Requests: make([]GeminiEmbeddingRequest, len(texts)),
|
||||
}
|
||||
if bifrostReq.Params != nil {
|
||||
batchRequest.ExtraParams = bifrostReq.Params.ExtraParams
|
||||
}
|
||||
|
||||
// Create individual embedding requests for each text
|
||||
for i, text := range texts {
|
||||
embeddingReq := GeminiEmbeddingRequest{
|
||||
Model: "models/" + bifrostReq.Model,
|
||||
Content: &Content{
|
||||
Parts: []*Part{
|
||||
{
|
||||
Text: text,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Add parameters if available
|
||||
if bifrostReq.Params != nil {
|
||||
if bifrostReq.Params.Dimensions != nil {
|
||||
embeddingReq.OutputDimensionality = bifrostReq.Params.Dimensions
|
||||
}
|
||||
|
||||
// Handle extra parameters
|
||||
if bifrostReq.Params.ExtraParams != nil {
|
||||
if taskType, ok := schemas.SafeExtractStringPointer(bifrostReq.Params.ExtraParams["taskType"]); ok {
|
||||
delete(batchRequest.ExtraParams, "taskType")
|
||||
embeddingReq.TaskType = taskType
|
||||
}
|
||||
if title, ok := schemas.SafeExtractStringPointer(bifrostReq.Params.ExtraParams["title"]); ok {
|
||||
delete(batchRequest.ExtraParams, "title")
|
||||
embeddingReq.Title = title
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
batchRequest.Requests[i] = embeddingReq
|
||||
}
|
||||
|
||||
return batchRequest
|
||||
}
|
||||
|
||||
// ToGeminiEmbeddingResponse converts a BifrostResponse with embedding data to Gemini's embedding response format
|
||||
func ToGeminiEmbeddingResponse(bifrostResp *schemas.BifrostEmbeddingResponse) *GeminiEmbeddingResponse {
|
||||
if bifrostResp == nil || len(bifrostResp.Data) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
geminiResp := &GeminiEmbeddingResponse{
|
||||
Embeddings: make([]GeminiEmbedding, len(bifrostResp.Data)),
|
||||
}
|
||||
|
||||
// Convert each embedding from Bifrost format to Gemini format
|
||||
for i, embedding := range bifrostResp.Data {
|
||||
var values []float64
|
||||
|
||||
// Extract embedding values from BifrostEmbeddingResponse
|
||||
if embedding.Embedding.EmbeddingArray != nil {
|
||||
values = append([]float64(nil), embedding.Embedding.EmbeddingArray...)
|
||||
} else if len(embedding.Embedding.Embedding2DArray) > 0 {
|
||||
// If it's a 2D array, take the first array
|
||||
values = append([]float64(nil), embedding.Embedding.Embedding2DArray[0]...)
|
||||
}
|
||||
|
||||
geminiEmbedding := GeminiEmbedding{
|
||||
Values: values,
|
||||
}
|
||||
|
||||
// Add statistics if available (token count from usage metadata)
|
||||
if bifrostResp.Usage != nil {
|
||||
geminiEmbedding.Statistics = &ContentEmbeddingStatistics{
|
||||
TokenCount: int32(bifrostResp.Usage.PromptTokens),
|
||||
}
|
||||
}
|
||||
|
||||
geminiResp.Embeddings[i] = geminiEmbedding
|
||||
}
|
||||
|
||||
// Set metadata if available (for Vertex API compatibility)
|
||||
if bifrostResp.Usage != nil {
|
||||
geminiResp.Metadata = &EmbedContentMetadata{
|
||||
BillableCharacterCount: int32(bifrostResp.Usage.PromptTokens),
|
||||
}
|
||||
}
|
||||
|
||||
return geminiResp
|
||||
}
|
||||
|
||||
// ToBifrostEmbeddingResponse converts a Gemini embedding response to BifrostEmbeddingResponse format
|
||||
func ToBifrostEmbeddingResponse(geminiResp *GeminiEmbeddingResponse, model string) *schemas.BifrostEmbeddingResponse {
|
||||
if geminiResp == nil || len(geminiResp.Embeddings) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
bifrostResp := &schemas.BifrostEmbeddingResponse{
|
||||
Data: make([]schemas.EmbeddingData, len(geminiResp.Embeddings)),
|
||||
Model: model,
|
||||
Object: "list",
|
||||
}
|
||||
|
||||
// Convert each embedding from Gemini format to Bifrost format
|
||||
for i, geminiEmbedding := range geminiResp.Embeddings {
|
||||
embeddingData := schemas.EmbeddingData{
|
||||
Index: i,
|
||||
Object: "embedding",
|
||||
Embedding: schemas.EmbeddingStruct{
|
||||
EmbeddingArray: geminiEmbedding.Values,
|
||||
},
|
||||
}
|
||||
|
||||
bifrostResp.Data[i] = embeddingData
|
||||
}
|
||||
|
||||
// Convert usage metadata if available
|
||||
if geminiResp.Metadata != nil || (len(geminiResp.Embeddings) > 0 && geminiResp.Embeddings[0].Statistics != nil) {
|
||||
bifrostResp.Usage = &schemas.BifrostLLMUsage{}
|
||||
|
||||
// Use statistics from the first embedding if available
|
||||
if geminiResp.Embeddings[0].Statistics != nil {
|
||||
bifrostResp.Usage.PromptTokens = int(geminiResp.Embeddings[0].Statistics.TokenCount)
|
||||
} else if geminiResp.Metadata != nil {
|
||||
// Fall back to metadata if statistics are not available
|
||||
bifrostResp.Usage.PromptTokens = int(geminiResp.Metadata.BillableCharacterCount)
|
||||
}
|
||||
|
||||
// Set total tokens same as prompt tokens for embeddings
|
||||
bifrostResp.Usage.TotalTokens = bifrostResp.Usage.PromptTokens
|
||||
}
|
||||
|
||||
return bifrostResp
|
||||
}
|
||||
|
||||
// ToBifrostEmbeddingRequest converts a GeminiGenerationRequest to BifrostEmbeddingRequest format
|
||||
func (request *GeminiGenerationRequest) ToBifrostEmbeddingRequest(ctx *schemas.BifrostContext) *schemas.BifrostEmbeddingRequest {
|
||||
if request == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
provider, model := schemas.ParseModelString(request.Model, utils.CheckAndSetDefaultProvider(ctx, schemas.Gemini))
|
||||
|
||||
// Create the embedding request
|
||||
bifrostReq := &schemas.BifrostEmbeddingRequest{
|
||||
Provider: provider,
|
||||
Model: model,
|
||||
Fallbacks: schemas.ParseFallbacks(request.Fallbacks),
|
||||
}
|
||||
|
||||
// SDK batch embedding request contains multiple embedding requests with same parameters but different text fields.
|
||||
if len(request.Requests) > 0 {
|
||||
var texts []string
|
||||
for _, req := range request.Requests {
|
||||
if req.Content != nil && len(req.Content.Parts) > 0 {
|
||||
for _, part := range req.Content.Parts {
|
||||
if part != nil && part.Text != "" {
|
||||
texts = append(texts, part.Text)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(texts) > 0 {
|
||||
bifrostReq.Input = &schemas.EmbeddingInput{}
|
||||
if len(texts) == 1 {
|
||||
bifrostReq.Input.Text = &texts[0]
|
||||
} else {
|
||||
bifrostReq.Input.Texts = texts
|
||||
}
|
||||
}
|
||||
|
||||
embeddingRequest := request.Requests[0]
|
||||
|
||||
// Convert parameters
|
||||
if embeddingRequest.OutputDimensionality != nil || embeddingRequest.TaskType != nil || embeddingRequest.Title != nil {
|
||||
bifrostReq.Params = &schemas.EmbeddingParameters{}
|
||||
|
||||
if embeddingRequest.OutputDimensionality != nil {
|
||||
bifrostReq.Params.Dimensions = embeddingRequest.OutputDimensionality
|
||||
}
|
||||
|
||||
// Handle extra parameters
|
||||
if embeddingRequest.TaskType != nil || embeddingRequest.Title != nil {
|
||||
bifrostReq.Params.ExtraParams = make(map[string]interface{})
|
||||
if embeddingRequest.TaskType != nil {
|
||||
bifrostReq.Params.ExtraParams["taskType"] = embeddingRequest.TaskType
|
||||
}
|
||||
if embeddingRequest.Title != nil {
|
||||
bifrostReq.Params.ExtraParams["title"] = embeddingRequest.Title
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Generation-style requests (e.g., non-Imagen :predict) carry text in contents[].parts[].
|
||||
// If no SDK requests[] were provided, derive embedding input from contents.
|
||||
if bifrostReq.Input == nil {
|
||||
var texts []string
|
||||
for _, content := range request.Contents {
|
||||
for _, part := range content.Parts {
|
||||
if part != nil && part.Text != "" {
|
||||
texts = append(texts, part.Text)
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(texts) > 0 {
|
||||
bifrostReq.Input = &schemas.EmbeddingInput{}
|
||||
if len(texts) == 1 {
|
||||
bifrostReq.Input.Text = &texts[0]
|
||||
} else {
|
||||
bifrostReq.Input.Texts = texts
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return bifrostReq
|
||||
}
|
||||
79
core/providers/gemini/errors.go
Normal file
79
core/providers/gemini/errors.go
Normal file
@@ -0,0 +1,79 @@
|
||||
package gemini
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// ToGeminiError derives a GeminiGenerationError from a BifrostError
|
||||
func ToGeminiError(bifrostErr *schemas.BifrostError) *GeminiGenerationError {
|
||||
if bifrostErr == nil {
|
||||
return nil
|
||||
}
|
||||
code := 500
|
||||
status := ""
|
||||
if bifrostErr.Error != nil && bifrostErr.Error.Type != nil {
|
||||
status = *bifrostErr.Error.Type
|
||||
}
|
||||
message := ""
|
||||
if bifrostErr.Error != nil && bifrostErr.Error.Message != "" {
|
||||
message = bifrostErr.Error.Message
|
||||
}
|
||||
if bifrostErr.StatusCode != nil {
|
||||
code = *bifrostErr.StatusCode
|
||||
}
|
||||
return &GeminiGenerationError{
|
||||
Error: &GeminiGenerationErrorStruct{
|
||||
Code: code,
|
||||
Message: message,
|
||||
Status: status,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// parseGeminiError parses Gemini error responses
|
||||
func parseGeminiError(resp *fasthttp.Response) *schemas.BifrostError {
|
||||
// Try to parse as []GeminiGenerationError
|
||||
var errorResps []GeminiGenerationError
|
||||
bifrostErr := providerUtils.HandleProviderAPIError(resp, &errorResps)
|
||||
if len(errorResps) > 0 {
|
||||
var message string
|
||||
var firstError *GeminiGenerationErrorStruct
|
||||
for _, errorResp := range errorResps {
|
||||
if errorResp.Error != nil {
|
||||
if firstError == nil {
|
||||
firstError = errorResp.Error
|
||||
}
|
||||
message = message + errorResp.Error.Message + "\n"
|
||||
}
|
||||
}
|
||||
// Trim trailing newline
|
||||
message = strings.TrimSuffix(message, "\n")
|
||||
if bifrostErr.Error == nil {
|
||||
bifrostErr.Error = &schemas.ErrorField{}
|
||||
}
|
||||
// Set Code from first error if available
|
||||
if firstError != nil {
|
||||
bifrostErr.Error.Code = schemas.Ptr(strconv.Itoa(firstError.Code))
|
||||
}
|
||||
// Set Message to trimmed concatenated message
|
||||
bifrostErr.Error.Message = message
|
||||
return bifrostErr
|
||||
}
|
||||
|
||||
// Try to parse as GeminiGenerationError
|
||||
var errorResp GeminiGenerationError
|
||||
bifrostErr = providerUtils.HandleProviderAPIError(resp, &errorResp)
|
||||
if errorResp.Error != nil {
|
||||
if bifrostErr.Error == nil {
|
||||
bifrostErr.Error = &schemas.ErrorField{}
|
||||
}
|
||||
bifrostErr.Error.Code = schemas.Ptr(strconv.Itoa(errorResp.Error.Code))
|
||||
bifrostErr.Error.Message = errorResp.Error.Message
|
||||
}
|
||||
return bifrostErr
|
||||
}
|
||||
145
core/providers/gemini/files.go
Normal file
145
core/providers/gemini/files.go
Normal file
@@ -0,0 +1,145 @@
|
||||
package gemini
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// Gemini Files API types
|
||||
// The Gemini Files API allows uploading files for use with multimodal models.
|
||||
|
||||
// GeminiFileResponse represents a file object from Gemini's API.
|
||||
type GeminiFileResponse struct {
|
||||
Name string `json:"name"` // Resource name (e.g., "files/abc123")
|
||||
DisplayName string `json:"displayName"` // User-provided display name
|
||||
MimeType string `json:"mimeType"` // MIME type of the file
|
||||
SizeBytes string `json:"sizeBytes"` // Size in bytes (as string)
|
||||
CreateTime string `json:"createTime"` // RFC3339 timestamp
|
||||
UpdateTime string `json:"updateTime"` // RFC3339 timestamp
|
||||
ExpirationTime string `json:"expirationTime,omitempty"` // RFC3339 timestamp when file will be deleted
|
||||
SHA256Hash string `json:"sha256Hash"` // Base64 encoded SHA256 hash
|
||||
URI string `json:"uri"` // URI for accessing the file
|
||||
State string `json:"state"` // "PROCESSING", "ACTIVE", "FAILED"
|
||||
VideoMetadata *GeminiFileVideoMetadata `json:"videoMetadata,omitempty"`
|
||||
}
|
||||
|
||||
// GeminiFileVideoMetadata contains video-specific metadata.
|
||||
type GeminiFileVideoMetadata struct {
|
||||
VideoDuration string `json:"videoDuration"` // Duration in seconds
|
||||
}
|
||||
|
||||
// GeminiFileListResponse represents the response from listing files.
|
||||
type GeminiFileListResponse struct {
|
||||
Files []GeminiFileResponse `json:"files"`
|
||||
NextPageToken string `json:"nextPageToken,omitempty"`
|
||||
}
|
||||
|
||||
// ToBifrostFileStatus converts Gemini file state to Bifrost status.
|
||||
func ToBifrostFileStatus(state string) schemas.FileStatus {
|
||||
switch state {
|
||||
case "PROCESSING":
|
||||
return schemas.FileStatusProcessing
|
||||
case "ACTIVE":
|
||||
return schemas.FileStatusProcessed
|
||||
case "FAILED":
|
||||
return schemas.FileStatusError
|
||||
default:
|
||||
return schemas.FileStatus(strings.ToLower(state))
|
||||
}
|
||||
}
|
||||
|
||||
// ToGeminiFileListResponse converts a Bifrost file list response to Gemini format.
|
||||
func ToGeminiFileListResponse(resp *schemas.BifrostFileListResponse) *GeminiFileListResponse {
|
||||
files := make([]GeminiFileResponse, len(resp.Data))
|
||||
for i, f := range resp.Data {
|
||||
updateAt := f.UpdatedAt
|
||||
if updateAt == 0 {
|
||||
updateAt = f.CreatedAt
|
||||
}
|
||||
files[i] = GeminiFileResponse{
|
||||
Name: f.ID,
|
||||
DisplayName: f.Filename,
|
||||
SizeBytes: fmt.Sprintf("%d", f.Bytes),
|
||||
CreateTime: formatGeminiTimestamp(f.CreatedAt),
|
||||
UpdateTime: formatGeminiTimestamp(updateAt),
|
||||
State: toGeminiFileState(f.Status),
|
||||
ExpirationTime: formatGeminiTimestamp(safeDerefInt64(f.ExpiresAt)),
|
||||
}
|
||||
}
|
||||
result := &GeminiFileListResponse{Files: files}
|
||||
if resp.After != nil && *resp.After != "" {
|
||||
result.NextPageToken = *resp.After
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// ToGeminiFileRetrieveResponse converts a Bifrost file retrieve response to Gemini format.
|
||||
func ToGeminiFileRetrieveResponse(resp *schemas.BifrostFileRetrieveResponse) *GeminiFileResponse {
|
||||
updateAt := resp.UpdatedAt
|
||||
if updateAt == 0 {
|
||||
updateAt = resp.CreatedAt
|
||||
}
|
||||
return &GeminiFileResponse{
|
||||
Name: resp.ID,
|
||||
DisplayName: resp.Filename,
|
||||
SizeBytes: fmt.Sprintf("%d", resp.Bytes),
|
||||
CreateTime: formatGeminiTimestamp(resp.CreatedAt),
|
||||
UpdateTime: formatGeminiTimestamp(updateAt),
|
||||
State: toGeminiFileState(resp.Status),
|
||||
URI: resp.StorageURI,
|
||||
ExpirationTime: formatGeminiTimestamp(safeDerefInt64(resp.ExpiresAt)),
|
||||
}
|
||||
}
|
||||
|
||||
// toGeminiFileState converts Bifrost file status to Gemini state.
|
||||
func toGeminiFileState(status schemas.FileStatus) string {
|
||||
switch status {
|
||||
case schemas.FileStatusProcessing:
|
||||
return "PROCESSING"
|
||||
case schemas.FileStatusProcessed:
|
||||
return "ACTIVE"
|
||||
case schemas.FileStatusError:
|
||||
return "FAILED"
|
||||
default:
|
||||
return strings.ToUpper(string(status))
|
||||
}
|
||||
}
|
||||
|
||||
// formatGeminiTimestamp converts Unix timestamp to Gemini RFC3339 format.
|
||||
func formatGeminiTimestamp(unixTime int64) string {
|
||||
if unixTime == 0 {
|
||||
return ""
|
||||
}
|
||||
return time.Unix(unixTime, 0).UTC().Format(time.RFC3339)
|
||||
}
|
||||
|
||||
// safeDerefInt64 safely dereferences an int64 pointer.
|
||||
func safeDerefInt64(ptr *int64) int64 {
|
||||
if ptr == nil {
|
||||
return 0
|
||||
}
|
||||
return *ptr
|
||||
}
|
||||
|
||||
// ToGeminiFileUploadResponse converts a Bifrost file upload response to Gemini format.
|
||||
func ToGeminiFileUploadResponse(resp *schemas.BifrostFileUploadResponse) map[string]interface{} {
|
||||
file := map[string]interface{}{
|
||||
"name": resp.ID,
|
||||
"displayName": resp.Filename,
|
||||
"mimeType": "application/octet-stream",
|
||||
"sizeBytes": fmt.Sprintf("%d", resp.Bytes),
|
||||
"createTime": formatGeminiTimestamp(resp.CreatedAt),
|
||||
"updateTime": formatGeminiTimestamp(resp.CreatedAt),
|
||||
"state": toGeminiFileState(resp.Status),
|
||||
"uri": resp.StorageURI,
|
||||
}
|
||||
if exp := formatGeminiTimestamp(safeDerefInt64(resp.ExpiresAt)); exp != "" {
|
||||
file["expirationTime"] = exp
|
||||
}
|
||||
return map[string]interface{}{
|
||||
"file": file,
|
||||
}
|
||||
}
|
||||
4306
core/providers/gemini/gemini.go
Normal file
4306
core/providers/gemini/gemini.go
Normal file
File diff suppressed because it is too large
Load Diff
62
core/providers/gemini/gemini_stream_reader_test.go
Normal file
62
core/providers/gemini/gemini_stream_reader_test.go
Normal file
@@ -0,0 +1,62 @@
|
||||
package gemini
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"io"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestReadNextSSEDataLine_SkipInlineDataOnGzipReader(t *testing.T) {
|
||||
var compressed bytes.Buffer
|
||||
gz := gzip.NewWriter(&compressed)
|
||||
payload := "data: {\"candidates\":[{\"content\":{\"parts\":[{\"inlineData\":{\"data\":\"abc\"}}]}}]}\n" +
|
||||
"data: {\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"ok\"}]}}]}\n"
|
||||
if _, err := gz.Write([]byte(payload)); err != nil {
|
||||
t.Fatalf("failed to write gzip payload: %v", err)
|
||||
}
|
||||
if err := gz.Close(); err != nil {
|
||||
t.Fatalf("failed to close gzip writer: %v", err)
|
||||
}
|
||||
|
||||
reader, err := gzip.NewReader(bytes.NewReader(compressed.Bytes()))
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create gzip reader: %v", err)
|
||||
}
|
||||
defer reader.Close()
|
||||
|
||||
line, err := readNextSSEDataLine(bufio.NewReaderSize(reader, 64*1024), true)
|
||||
if err != nil {
|
||||
t.Fatalf("expected next non-inline line, got error: %v", err)
|
||||
}
|
||||
|
||||
want := []byte(`{"candidates":[{"content":{"parts":[{"text":"ok"}]}}]}`)
|
||||
if !bytes.Equal(line, want) {
|
||||
t.Fatalf("expected %q, got %q", string(want), string(line))
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadNextSSEDataLine_SkipInlineDataContinuedLine(t *testing.T) {
|
||||
longInline := bytes.Repeat([]byte("x"), 70*1024)
|
||||
var stream bytes.Buffer
|
||||
stream.WriteString("data: {\"candidates\":[{\"content\":{\"parts\":[{\"inlineData\":{\"data\":\"")
|
||||
stream.Write(longInline)
|
||||
stream.WriteString("\"}}]}}]}\n")
|
||||
stream.WriteString("data: {\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"ok\"}]}}]}\n")
|
||||
|
||||
line, err := readNextSSEDataLine(bufio.NewReaderSize(bytes.NewReader(stream.Bytes()), 64*1024), true)
|
||||
if err != nil {
|
||||
t.Fatalf("expected next non-inline line, got error: %v", err)
|
||||
}
|
||||
|
||||
want := []byte(`{"candidates":[{"content":{"parts":[{"text":"ok"}]}}]}`)
|
||||
if !bytes.Equal(line, want) {
|
||||
t.Fatalf("expected %q, got %q", string(want), string(line))
|
||||
}
|
||||
|
||||
_, err = readNextSSEDataLine(bufio.NewReaderSize(bytes.NewReader(nil), 64*1024), true)
|
||||
if err != io.EOF {
|
||||
t.Fatalf("expected EOF on empty reader, got %v", err)
|
||||
}
|
||||
}
|
||||
2988
core/providers/gemini/gemini_test.go
Normal file
2988
core/providers/gemini/gemini_test.go
Normal file
File diff suppressed because it is too large
Load Diff
1095
core/providers/gemini/images.go
Normal file
1095
core/providers/gemini/images.go
Normal file
File diff suppressed because it is too large
Load Diff
61
core/providers/gemini/list_models_single_payload_test.go
Normal file
61
core/providers/gemini/list_models_single_payload_test.go
Normal file
@@ -0,0 +1,61 @@
|
||||
package gemini
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type testNoopLogger struct{}
|
||||
|
||||
func (testNoopLogger) Debug(string, ...any) {}
|
||||
func (testNoopLogger) Info(string, ...any) {}
|
||||
func (testNoopLogger) Warn(string, ...any) {}
|
||||
func (testNoopLogger) Error(string, ...any) {}
|
||||
func (testNoopLogger) Fatal(string, ...any) {}
|
||||
func (testNoopLogger) SetLevel(schemas.LogLevel) {}
|
||||
func (testNoopLogger) SetOutputType(schemas.LoggerOutputType) {}
|
||||
func (testNoopLogger) LogHTTPRequest(schemas.LogLevel, string) schemas.LogEventBuilder {
|
||||
return schemas.NoopLogEvent
|
||||
}
|
||||
|
||||
func TestListModelsByKey_ParsesSingleModelPayload(t *testing.T) {
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
http.Error(w, "unexpected method", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
if r.URL.Path != "/models/gemini-2.5-pro" {
|
||||
http.Error(w, "unexpected path", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte(`{"name":"models/gemini-2.5-pro","displayName":"Gemini 2.5 Pro","description":"test","inputTokenLimit":1048576,"outputTokenLimit":8192,"supportedGenerationMethods":["generateContent"]}`))
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
provider := NewGeminiProvider(&schemas.ProviderConfig{
|
||||
NetworkConfig: schemas.NetworkConfig{BaseURL: ts.URL},
|
||||
}, testNoopLogger{})
|
||||
|
||||
ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
|
||||
ctx.SetValue(schemas.BifrostContextKeyURLPath, "/models/gemini-2.5-pro")
|
||||
|
||||
key := schemas.Key{Value: *schemas.NewEnvVar("dummy-key")}
|
||||
// Unfiltered=true bypasses the allowed/alias/blacklist filter pipeline so
|
||||
// this test can focus on the single-model-payload parsing code path in
|
||||
// listModelsByKey (gemini.go:215-220).
|
||||
resp, err := provider.listModelsByKey(ctx, key, &schemas.BifrostListModelsRequest{Provider: schemas.Gemini, Unfiltered: true})
|
||||
require.Nil(t, err)
|
||||
require.NotNil(t, resp)
|
||||
require.Len(t, resp.Data, 1)
|
||||
assert.Equal(t, "gemini/gemini-2.5-pro", resp.Data[0].ID)
|
||||
require.NotNil(t, resp.Data[0].Name)
|
||||
assert.Equal(t, "Gemini 2.5 Pro", *resp.Data[0].Name)
|
||||
}
|
||||
105
core/providers/gemini/models.go
Normal file
105
core/providers/gemini/models.go
Normal file
@@ -0,0 +1,105 @@
|
||||
package gemini
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
func toGeminiModelResourceName(modelID string) string {
|
||||
if strings.HasPrefix(modelID, "models/") {
|
||||
return modelID
|
||||
}
|
||||
if idx := strings.Index(modelID, "/"); idx >= 0 && idx+1 < len(modelID) {
|
||||
return "models/" + modelID[idx+1:]
|
||||
}
|
||||
return "models/" + modelID
|
||||
}
|
||||
|
||||
func (response *GeminiListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, allowedModels schemas.WhiteList, blacklistedModels schemas.BlackList, aliases map[string]string, unfiltered bool) *schemas.BifrostListModelsResponse {
|
||||
if response == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
bifrostResponse := &schemas.BifrostListModelsResponse{
|
||||
Data: make([]schemas.Model, 0, len(response.Models)),
|
||||
}
|
||||
|
||||
pipeline := &providerUtils.ListModelsPipeline{
|
||||
AllowedModels: allowedModels,
|
||||
BlacklistedModels: blacklistedModels,
|
||||
Aliases: aliases,
|
||||
Unfiltered: unfiltered,
|
||||
ProviderKey: providerKey,
|
||||
MatchFns: providerUtils.DefaultMatchFns(),
|
||||
}
|
||||
if pipeline.ShouldEarlyExit() {
|
||||
return bifrostResponse
|
||||
}
|
||||
|
||||
included := make(map[string]bool)
|
||||
|
||||
for _, model := range response.Models {
|
||||
contextLength := model.InputTokenLimit + model.OutputTokenLimit
|
||||
// Gemini returns model names with a "models/" prefix — strip it before filtering
|
||||
// so that allowedModels entries like "gemini-1.5-pro" match correctly.
|
||||
modelName := strings.TrimPrefix(model.Name, "models/")
|
||||
|
||||
for _, result := range pipeline.FilterModel(modelName) {
|
||||
entry := schemas.Model{
|
||||
ID: string(providerKey) + "/" + result.ResolvedID,
|
||||
Name: schemas.Ptr(model.DisplayName),
|
||||
Description: schemas.Ptr(model.Description),
|
||||
ContextLength: schemas.Ptr(int(contextLength)),
|
||||
MaxInputTokens: schemas.Ptr(model.InputTokenLimit),
|
||||
MaxOutputTokens: schemas.Ptr(model.OutputTokenLimit),
|
||||
SupportedMethods: model.SupportedGenerationMethods,
|
||||
}
|
||||
if result.AliasValue != "" {
|
||||
entry.Alias = schemas.Ptr(result.AliasValue)
|
||||
}
|
||||
bifrostResponse.Data = append(bifrostResponse.Data, entry)
|
||||
included[strings.ToLower(result.ResolvedID)] = true
|
||||
}
|
||||
}
|
||||
|
||||
bifrostResponse.Data = append(bifrostResponse.Data,
|
||||
pipeline.BackfillModels(included)...)
|
||||
|
||||
return bifrostResponse
|
||||
}
|
||||
|
||||
func ToGeminiListModelsResponse(resp *schemas.BifrostListModelsResponse) *GeminiListModelsResponse {
|
||||
if resp == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
geminiResponse := &GeminiListModelsResponse{
|
||||
Models: make([]GeminiModel, 0, len(resp.Data)),
|
||||
NextPageToken: resp.NextPageToken,
|
||||
}
|
||||
|
||||
for _, model := range resp.Data {
|
||||
geminiModel := GeminiModel{
|
||||
Name: toGeminiModelResourceName(model.ID),
|
||||
SupportedGenerationMethods: model.SupportedMethods,
|
||||
}
|
||||
if model.Name != nil {
|
||||
geminiModel.DisplayName = *model.Name
|
||||
}
|
||||
if model.Description != nil {
|
||||
geminiModel.Description = *model.Description
|
||||
}
|
||||
if model.MaxInputTokens != nil {
|
||||
geminiModel.InputTokenLimit = *model.MaxInputTokens
|
||||
}
|
||||
if model.MaxOutputTokens != nil {
|
||||
geminiModel.OutputTokenLimit = *model.MaxOutputTokens
|
||||
}
|
||||
|
||||
geminiResponse.Models = append(geminiResponse.Models, geminiModel)
|
||||
}
|
||||
|
||||
return geminiResponse
|
||||
}
|
||||
41
core/providers/gemini/models_test.go
Normal file
41
core/providers/gemini/models_test.go
Normal file
@@ -0,0 +1,41 @@
|
||||
package gemini
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestToGeminiModelResourceName(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
want string
|
||||
}{
|
||||
{name: "already native", input: "models/gemini-2.5-pro", want: "models/gemini-2.5-pro"},
|
||||
{name: "provider prefixed", input: "gemini/gemini-2.5-pro", want: "models/gemini-2.5-pro"},
|
||||
{name: "bare model", input: "gemini-2.5-pro", want: "models/gemini-2.5-pro"},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
assert.Equal(t, tc.want, toGeminiModelResourceName(tc.input))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestToGeminiListModelsResponse_UsesNativeModelResourceName(t *testing.T) {
|
||||
resp := &schemas.BifrostListModelsResponse{
|
||||
Data: []schemas.Model{
|
||||
{ID: "gemini/gemini-2.5-pro"},
|
||||
{ID: "models/gemini-2.5-flash"},
|
||||
},
|
||||
}
|
||||
|
||||
converted := ToGeminiListModelsResponse(resp)
|
||||
if assert.Len(t, converted.Models, 2) {
|
||||
assert.Equal(t, "models/gemini-2.5-pro", converted.Models[0].Name)
|
||||
assert.Equal(t, "models/gemini-2.5-flash", converted.Models[1].Name)
|
||||
}
|
||||
}
|
||||
56
core/providers/gemini/payload_ordering_test.go
Normal file
56
core/providers/gemini/payload_ordering_test.go
Normal file
@@ -0,0 +1,56 @@
|
||||
package gemini
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
|
||||
schemas "github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestPayloadOrdering_GeminiGenerationRequest(t *testing.T) {
|
||||
req := &GeminiGenerationRequest{
|
||||
Model: "gemini-2.0-flash",
|
||||
Contents: []Content{
|
||||
{
|
||||
Parts: []*Part{{Text: "hello"}},
|
||||
Role: "user",
|
||||
},
|
||||
},
|
||||
GenerationConfig: GenerationConfig{
|
||||
Temperature: schemas.Ptr(float64(0.7)),
|
||||
},
|
||||
Tools: []Tool{
|
||||
{
|
||||
FunctionDeclarations: []*FunctionDeclaration{
|
||||
{
|
||||
Name: "get_weather",
|
||||
Description: "Get weather",
|
||||
Parameters: &Schema{
|
||||
Type: "OBJECT",
|
||||
Properties: map[string]*Schema{
|
||||
"location": {Type: "STRING"},
|
||||
},
|
||||
Required: []string{"location"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result, err := providerUtils.MarshalSorted(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
golden := `{"model":"gemini-2.0-flash","contents":[{"parts":[{"text":"hello"}],"role":"user"}],"generationConfig":{"temperature":0.7},"tools":[{"functionDeclarations":[{"description":"Get weather","name":"get_weather","parameters":{"properties":{"location":{"type":"STRING"}},"required":["location"],"type":"OBJECT"}}]}]}`
|
||||
|
||||
assert.Equal(t, golden, string(result), "payload field ordering changed — if intentional, update the golden string")
|
||||
|
||||
// Determinism: 100 iterations must produce identical bytes
|
||||
for i := 0; i < 100; i++ {
|
||||
iter, err := providerUtils.MarshalSorted(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, string(result), string(iter), "non-deterministic marshal output on iteration %d", i)
|
||||
}
|
||||
}
|
||||
3652
core/providers/gemini/responses.go
Normal file
3652
core/providers/gemini/responses.go
Normal file
File diff suppressed because it is too large
Load Diff
200
core/providers/gemini/speech.go
Normal file
200
core/providers/gemini/speech.go
Normal file
@@ -0,0 +1,200 @@
|
||||
package gemini
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/maximhq/bifrost/core/providers/utils"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// ToBifrostSpeechRequest converts a GeminiGenerationRequest to a BifrostSpeechRequest
|
||||
func (request *GeminiGenerationRequest) ToBifrostSpeechRequest(ctx *schemas.BifrostContext) *schemas.BifrostSpeechRequest {
|
||||
provider, model := schemas.ParseModelString(request.Model, utils.CheckAndSetDefaultProvider(ctx, schemas.Gemini))
|
||||
|
||||
bifrostReq := &schemas.BifrostSpeechRequest{
|
||||
Provider: provider,
|
||||
Model: model,
|
||||
}
|
||||
|
||||
// Extract text input from contents
|
||||
var textInput string
|
||||
for _, content := range request.Contents {
|
||||
for _, part := range content.Parts {
|
||||
if part.Text != "" {
|
||||
textInput += part.Text
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bifrostReq.Input = &schemas.SpeechInput{
|
||||
Input: textInput,
|
||||
}
|
||||
|
||||
// Convert generation config to parameters
|
||||
if request.GenerationConfig.SpeechConfig != nil || len(request.GenerationConfig.ResponseModalities) > 0 {
|
||||
bifrostReq.Params = &schemas.SpeechParameters{}
|
||||
|
||||
// Extract voice config from speech config
|
||||
if request.GenerationConfig.SpeechConfig != nil {
|
||||
// Handle single-speaker voice config
|
||||
if request.GenerationConfig.SpeechConfig.VoiceConfig != nil {
|
||||
bifrostReq.Params.VoiceConfig = &schemas.SpeechVoiceInput{}
|
||||
|
||||
if request.GenerationConfig.SpeechConfig.VoiceConfig.PrebuiltVoiceConfig != nil {
|
||||
voiceName := request.GenerationConfig.SpeechConfig.VoiceConfig.PrebuiltVoiceConfig.VoiceName
|
||||
bifrostReq.Params.VoiceConfig.Voice = &voiceName
|
||||
}
|
||||
} else if request.GenerationConfig.SpeechConfig.MultiSpeakerVoiceConfig != nil {
|
||||
// Handle multi-speaker voice config
|
||||
// Convert to Bifrost's MultiVoiceConfig format
|
||||
if len(request.GenerationConfig.SpeechConfig.MultiSpeakerVoiceConfig.SpeakerVoiceConfigs) > 0 {
|
||||
bifrostReq.Params.VoiceConfig = &schemas.SpeechVoiceInput{}
|
||||
multiVoiceConfig := make([]schemas.VoiceConfig, 0, len(request.GenerationConfig.SpeechConfig.MultiSpeakerVoiceConfig.SpeakerVoiceConfigs))
|
||||
|
||||
for _, speakerConfig := range request.GenerationConfig.SpeechConfig.MultiSpeakerVoiceConfig.SpeakerVoiceConfigs {
|
||||
if speakerConfig.VoiceConfig != nil && speakerConfig.VoiceConfig.PrebuiltVoiceConfig != nil {
|
||||
multiVoiceConfig = append(multiVoiceConfig, schemas.VoiceConfig{
|
||||
Speaker: speakerConfig.Speaker,
|
||||
Voice: speakerConfig.VoiceConfig.PrebuiltVoiceConfig.VoiceName,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
bifrostReq.Params.VoiceConfig.MultiVoiceConfig = multiVoiceConfig
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Store response modalities in extra params if needed
|
||||
if len(request.GenerationConfig.ResponseModalities) > 0 {
|
||||
if bifrostReq.Params.ExtraParams == nil {
|
||||
bifrostReq.Params.ExtraParams = make(map[string]interface{})
|
||||
}
|
||||
modalities := make([]string, len(request.GenerationConfig.ResponseModalities))
|
||||
for i, mod := range request.GenerationConfig.ResponseModalities {
|
||||
modalities[i] = string(mod)
|
||||
}
|
||||
bifrostReq.Params.ExtraParams["response_modalities"] = modalities
|
||||
}
|
||||
}
|
||||
|
||||
return bifrostReq
|
||||
}
|
||||
|
||||
// ToGeminiSpeechRequest converts a BifrostSpeechRequest to a GeminiGenerationRequest
|
||||
func ToGeminiSpeechRequest(bifrostReq *schemas.BifrostSpeechRequest) (*GeminiGenerationRequest, error) {
|
||||
if bifrostReq == nil {
|
||||
return nil, fmt.Errorf("bifrostReq is nil")
|
||||
}
|
||||
// Here we confirm if the response_format is wav or empty string
|
||||
// If its anything else, we will return an error
|
||||
if bifrostReq.Params != nil && bifrostReq.Params.ResponseFormat != "" && bifrostReq.Params.ResponseFormat != "wav" {
|
||||
return nil, fmt.Errorf("gemini does not support response_format: %s. Only wav or empty string is supported which defaults to wav", bifrostReq.Params.ResponseFormat)
|
||||
}
|
||||
// Create the base Gemini generation request
|
||||
geminiReq := &GeminiGenerationRequest{
|
||||
Model: bifrostReq.Model,
|
||||
}
|
||||
// Convert parameters to generation config
|
||||
geminiReq.GenerationConfig.ResponseModalities = []Modality{ModalityAudio}
|
||||
// Convert speech input to Gemini format
|
||||
if bifrostReq.Input != nil && bifrostReq.Input.Input != "" {
|
||||
geminiReq.Contents = []Content{
|
||||
{
|
||||
Parts: []*Part{
|
||||
{
|
||||
Text: bifrostReq.Input.Input,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
// Add speech config to generation config if voice config is provided
|
||||
if bifrostReq.Params != nil && bifrostReq.Params.VoiceConfig != nil {
|
||||
// Handle both single voice and multi-voice configurations
|
||||
if bifrostReq.Params.VoiceConfig.Voice != nil || len(bifrostReq.Params.VoiceConfig.MultiVoiceConfig) > 0 {
|
||||
addSpeechConfigToGenerationConfig(&geminiReq.GenerationConfig, bifrostReq.Params.VoiceConfig)
|
||||
}
|
||||
geminiReq.ExtraParams = bifrostReq.Params.ExtraParams
|
||||
}
|
||||
}
|
||||
return geminiReq, nil
|
||||
}
|
||||
|
||||
// ToBifrostSpeechResponse converts a GenerateContentResponse to a BifrostSpeechResponse
|
||||
func (response *GenerateContentResponse) ToBifrostSpeechResponse(ctx context.Context) (*schemas.BifrostSpeechResponse, error) {
|
||||
bifrostResp := &schemas.BifrostSpeechResponse{}
|
||||
|
||||
// Process candidates to extract audio content
|
||||
if len(response.Candidates) > 0 {
|
||||
candidate := response.Candidates[0]
|
||||
if candidate.Content != nil && len(candidate.Content.Parts) > 0 {
|
||||
var audioData []byte
|
||||
// Extract audio data from all parts
|
||||
for _, part := range candidate.Content.Parts {
|
||||
if part.InlineData != nil && len(part.InlineData.Data) > 0 {
|
||||
// Check if this is audio data
|
||||
if strings.HasPrefix(part.InlineData.MIMEType, "audio/") {
|
||||
decodedData, err := decodeBase64StringToBytes(part.InlineData.Data)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode base64 audio data: %v", err)
|
||||
}
|
||||
audioData = append(audioData, decodedData...)
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(audioData) > 0 {
|
||||
responseFormat := ctx.Value(BifrostContextKeyResponseFormat).(string)
|
||||
// Gemini returns PCM audio (s16le, 24000 Hz, mono)
|
||||
// Convert to WAV for standard playable output format
|
||||
if responseFormat == "wav" {
|
||||
wavData, err := utils.ConvertPCMToWAV(audioData, utils.DefaultGeminiPCMConfig())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to convert PCM to WAV: %v", err)
|
||||
}
|
||||
bifrostResp.Audio = wavData
|
||||
} else {
|
||||
bifrostResp.Audio = audioData
|
||||
}
|
||||
}
|
||||
|
||||
// Set usage information
|
||||
if response.UsageMetadata != nil {
|
||||
bifrostResp.Usage = convertGeminiUsageMetadataToSpeechUsage(response.UsageMetadata)
|
||||
}
|
||||
}
|
||||
}
|
||||
return bifrostResp, nil
|
||||
}
|
||||
|
||||
// ToGeminiSpeechResponse converts a BifrostSpeechResponse to Gemini's GenerateContentResponse
|
||||
func ToGeminiSpeechResponse(bifrostResp *schemas.BifrostSpeechResponse) *GenerateContentResponse {
|
||||
if bifrostResp == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
genaiResp := &GenerateContentResponse{}
|
||||
|
||||
candidate := &Candidate{
|
||||
Content: &Content{
|
||||
Parts: []*Part{
|
||||
{
|
||||
InlineData: &Blob{
|
||||
Data: encodeBytesToBase64String(bifrostResp.Audio),
|
||||
MIMEType: utils.DetectAudioMimeType(bifrostResp.Audio),
|
||||
},
|
||||
},
|
||||
},
|
||||
Role: string(RoleModel),
|
||||
},
|
||||
}
|
||||
|
||||
// Set usage metadata if present
|
||||
if bifrostResp.Usage != nil {
|
||||
genaiResp.UsageMetadata = convertBifrostSpeechUsageToGeminiUsageMetadata(bifrostResp.Usage)
|
||||
}
|
||||
|
||||
genaiResp.Candidates = []*Candidate{candidate}
|
||||
return genaiResp
|
||||
}
|
||||
233
core/providers/gemini/transcription.go
Normal file
233
core/providers/gemini/transcription.go
Normal file
@@ -0,0 +1,233 @@
|
||||
package gemini
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/maximhq/bifrost/core/providers/utils"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// ToBifrostTranscriptionRequest converts a GeminiGenerationRequest to a BifrostTranscriptionRequest
|
||||
func (request *GeminiGenerationRequest) ToBifrostTranscriptionRequest(ctx *schemas.BifrostContext) (*schemas.BifrostTranscriptionRequest, error) {
|
||||
provider, model := schemas.ParseModelString(request.Model, utils.CheckAndSetDefaultProvider(ctx, schemas.Gemini))
|
||||
|
||||
bifrostReq := &schemas.BifrostTranscriptionRequest{
|
||||
Provider: provider,
|
||||
Model: model,
|
||||
}
|
||||
|
||||
// Extract audio data and prompt from contents
|
||||
var promptText string
|
||||
var audioData []byte
|
||||
var audioMimeType string
|
||||
|
||||
for _, content := range request.Contents {
|
||||
for _, part := range content.Parts {
|
||||
// Extract text prompt
|
||||
if part.Text != "" {
|
||||
if promptText != "" {
|
||||
promptText += " "
|
||||
}
|
||||
promptText += part.Text
|
||||
}
|
||||
|
||||
// Extract audio data from inline data
|
||||
if part.InlineData != nil && strings.HasPrefix(strings.ToLower(part.InlineData.MIMEType), "audio/") {
|
||||
decodedData, err := decodeBase64StringToBytes(part.InlineData.Data)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode base64 audio data: %v", err)
|
||||
}
|
||||
audioData = append(audioData, decodedData...)
|
||||
if audioMimeType == "" {
|
||||
audioMimeType = part.InlineData.MIMEType
|
||||
}
|
||||
}
|
||||
|
||||
// Extract audio data from file data (would need to be fetched separately in real scenario)
|
||||
// For now, we just note the file URI in extra params
|
||||
if part.FileData != nil && strings.HasPrefix(strings.ToLower(part.FileData.MIMEType), "audio/") {
|
||||
if bifrostReq.Params == nil {
|
||||
bifrostReq.Params = &schemas.TranscriptionParameters{}
|
||||
}
|
||||
if bifrostReq.Params.ExtraParams == nil {
|
||||
bifrostReq.Params.ExtraParams = make(map[string]interface{})
|
||||
}
|
||||
bifrostReq.Params.ExtraParams["file_uri"] = part.FileData.FileURI
|
||||
if audioMimeType == "" {
|
||||
audioMimeType = part.FileData.MIMEType
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Set the audio input
|
||||
bifrostReq.Input = &schemas.TranscriptionInput{
|
||||
File: audioData,
|
||||
}
|
||||
|
||||
// Set parameters
|
||||
if bifrostReq.Params == nil {
|
||||
bifrostReq.Params = &schemas.TranscriptionParameters{}
|
||||
}
|
||||
|
||||
// Set prompt if provided
|
||||
if promptText != "" {
|
||||
bifrostReq.Params.Prompt = &promptText
|
||||
}
|
||||
|
||||
// Handle safety settings from request
|
||||
if len(request.SafetySettings) > 0 {
|
||||
if bifrostReq.Params.ExtraParams == nil {
|
||||
bifrostReq.Params.ExtraParams = make(map[string]interface{})
|
||||
}
|
||||
bifrostReq.Params.ExtraParams["safety_settings"] = request.SafetySettings
|
||||
}
|
||||
|
||||
// Handle cached content
|
||||
if request.CachedContent != "" {
|
||||
if bifrostReq.Params.ExtraParams == nil {
|
||||
bifrostReq.Params.ExtraParams = make(map[string]interface{})
|
||||
}
|
||||
bifrostReq.Params.ExtraParams["cached_content"] = request.CachedContent
|
||||
}
|
||||
|
||||
// Handle labels
|
||||
if len(request.Labels) > 0 {
|
||||
if bifrostReq.Params.ExtraParams == nil {
|
||||
bifrostReq.Params.ExtraParams = make(map[string]interface{})
|
||||
}
|
||||
bifrostReq.Params.ExtraParams["labels"] = request.Labels
|
||||
}
|
||||
|
||||
return bifrostReq, nil
|
||||
}
|
||||
|
||||
func ToGeminiTranscriptionRequest(bifrostReq *schemas.BifrostTranscriptionRequest) *GeminiGenerationRequest {
|
||||
if bifrostReq == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Create the base Gemini generation request
|
||||
geminiReq := &GeminiGenerationRequest{
|
||||
Model: bifrostReq.Model,
|
||||
}
|
||||
|
||||
// Convert parameters to generation config
|
||||
if bifrostReq.Params != nil {
|
||||
geminiReq.ExtraParams = bifrostReq.Params.ExtraParams
|
||||
// Handle extra parameters
|
||||
if bifrostReq.Params.ExtraParams != nil {
|
||||
// Safety settings
|
||||
if safetySettings, ok := schemas.SafeExtractFromMap(bifrostReq.Params.ExtraParams, "safety_settings"); ok {
|
||||
delete(geminiReq.ExtraParams, "safety_settings")
|
||||
if settings, ok := SafeExtractSafetySettings(safetySettings); ok {
|
||||
geminiReq.SafetySettings = settings
|
||||
}
|
||||
}
|
||||
|
||||
// Cached content
|
||||
if cachedContent, ok := schemas.SafeExtractString(bifrostReq.Params.ExtraParams["cached_content"]); ok {
|
||||
delete(geminiReq.ExtraParams, "cached_content")
|
||||
geminiReq.CachedContent = cachedContent
|
||||
}
|
||||
|
||||
// Labels
|
||||
if labels, ok := schemas.SafeExtractFromMap(bifrostReq.Params.ExtraParams, "labels"); ok {
|
||||
if labelMap, ok := schemas.SafeExtractStringMap(labels); ok {
|
||||
delete(geminiReq.ExtraParams, "labels")
|
||||
geminiReq.Labels = labelMap
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Determine the prompt text
|
||||
var prompt string
|
||||
if bifrostReq.Params != nil && bifrostReq.Params.Prompt != nil {
|
||||
prompt = *bifrostReq.Params.Prompt
|
||||
} else {
|
||||
prompt = "Generate a transcript of the speech."
|
||||
}
|
||||
|
||||
// Create parts for the transcription request
|
||||
parts := []*Part{
|
||||
{
|
||||
Text: prompt,
|
||||
},
|
||||
}
|
||||
|
||||
// Add audio file if present
|
||||
if len(bifrostReq.Input.File) > 0 {
|
||||
parts = append(parts, &Part{
|
||||
InlineData: &Blob{
|
||||
MIMEType: utils.DetectAudioMimeType(bifrostReq.Input.File),
|
||||
Data: encodeBytesToBase64String(bifrostReq.Input.File),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
geminiReq.Contents = []Content{
|
||||
{
|
||||
Parts: parts,
|
||||
},
|
||||
}
|
||||
|
||||
return geminiReq
|
||||
}
|
||||
|
||||
// ToBifrostTranscriptionResponse converts a GenerateContentResponse to a BifrostTranscriptionResponse
|
||||
func (response *GenerateContentResponse) ToBifrostTranscriptionResponse() *schemas.BifrostTranscriptionResponse {
|
||||
bifrostResp := &schemas.BifrostTranscriptionResponse{}
|
||||
|
||||
// Process candidates to extract text content
|
||||
if len(response.Candidates) > 0 {
|
||||
candidate := response.Candidates[0]
|
||||
if candidate.Content != nil && len(candidate.Content.Parts) > 0 {
|
||||
var textContent string
|
||||
|
||||
// Extract text content from all parts
|
||||
for _, part := range candidate.Content.Parts {
|
||||
if part.Text != "" {
|
||||
textContent += part.Text
|
||||
}
|
||||
}
|
||||
|
||||
if textContent != "" {
|
||||
bifrostResp.Text = textContent
|
||||
bifrostResp.Task = schemas.Ptr("transcribe")
|
||||
|
||||
// Set usage information with modality details
|
||||
bifrostResp.Usage = convertGeminiUsageMetadataToTranscriptionUsage(response.UsageMetadata)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return bifrostResp
|
||||
}
|
||||
|
||||
// ToGeminiTranscriptionResponse converts a BifrostTranscriptionResponse to Gemini's GenerateContentResponse
|
||||
func ToGeminiTranscriptionResponse(bifrostResp *schemas.BifrostTranscriptionResponse) *GenerateContentResponse {
|
||||
if bifrostResp == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
genaiResp := &GenerateContentResponse{}
|
||||
|
||||
candidate := &Candidate{
|
||||
Content: &Content{
|
||||
Parts: []*Part{
|
||||
{
|
||||
Text: bifrostResp.Text,
|
||||
},
|
||||
},
|
||||
Role: string(RoleModel),
|
||||
},
|
||||
}
|
||||
|
||||
// Set usage metadata from transcription usage with modality details
|
||||
genaiResp.UsageMetadata = convertBifrostTranscriptionUsageToGeminiUsageMetadata(bifrostResp.Usage)
|
||||
|
||||
genaiResp.Candidates = []*Candidate{candidate}
|
||||
return genaiResp
|
||||
}
|
||||
2745
core/providers/gemini/types.go
Normal file
2745
core/providers/gemini/types.go
Normal file
File diff suppressed because it is too large
Load Diff
2523
core/providers/gemini/utils.go
Normal file
2523
core/providers/gemini/utils.go
Normal file
File diff suppressed because it is too large
Load Diff
595
core/providers/gemini/videos.go
Normal file
595
core/providers/gemini/videos.go
Normal file
@@ -0,0 +1,595 @@
|
||||
package gemini
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/bytedance/sonic"
|
||||
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
const defaultVideoContentType = "video/mp4"
|
||||
|
||||
// sizeToAspectRatio converts OpenAI-style size strings to Gemini aspect ratios.
|
||||
// Gemini supports 16:9 and 9:16. Returns default value if no mapping exists.
|
||||
func sizeToAspectRatio(size string) string {
|
||||
switch size {
|
||||
case "1280x720", "1792x1024":
|
||||
return "16:9"
|
||||
case "720x1280", "1024x1792":
|
||||
return "9:16"
|
||||
default:
|
||||
return "16:9"
|
||||
}
|
||||
}
|
||||
|
||||
func addVideoURLOutput(uri, contentType string) *schemas.VideoOutput {
|
||||
if uri == "" {
|
||||
return nil
|
||||
}
|
||||
if strings.TrimSpace(contentType) == "" {
|
||||
contentType = defaultVideoContentType
|
||||
}
|
||||
return &schemas.VideoOutput{
|
||||
Type: schemas.VideoOutputTypeURL,
|
||||
URL: schemas.Ptr(uri),
|
||||
ContentType: contentType,
|
||||
}
|
||||
}
|
||||
|
||||
func addVideoBase64Output(base64Value, contentType string) *schemas.VideoOutput {
|
||||
if base64Value == "" {
|
||||
return nil
|
||||
}
|
||||
if strings.TrimSpace(contentType) == "" {
|
||||
contentType = defaultVideoContentType
|
||||
}
|
||||
return &schemas.VideoOutput{
|
||||
Type: schemas.VideoOutputTypeBase64,
|
||||
Base64Data: schemas.Ptr(base64Value),
|
||||
ContentType: contentType,
|
||||
}
|
||||
}
|
||||
|
||||
func parseVideoDataURL(data string) (mimeType string, base64Payload string, ok bool) {
|
||||
if !strings.HasPrefix(data, "data:") {
|
||||
return "", "", false
|
||||
}
|
||||
parts := strings.SplitN(data, ",", 2)
|
||||
if len(parts) != 2 {
|
||||
return "", "", false
|
||||
}
|
||||
header := parts[0]
|
||||
payload := parts[1]
|
||||
if payload == "" {
|
||||
return "", "", false
|
||||
}
|
||||
header = strings.TrimPrefix(header, "data:")
|
||||
if before, _, found := strings.Cut(header, ";"); found {
|
||||
return before, payload, true
|
||||
}
|
||||
return header, payload, true
|
||||
}
|
||||
|
||||
// ToGeminiVideoGenerationRequest converts a Bifrost video generation request to Gemini REST API format
|
||||
// This creates the request body for POST /models/{model}:predictLongRunning
|
||||
func ToGeminiVideoGenerationRequest(bifrostReq *schemas.BifrostVideoGenerationRequest) (*GeminiVideoGenerationRequest, error) {
|
||||
if bifrostReq == nil || bifrostReq.Input == nil {
|
||||
return nil, fmt.Errorf("bifrost request or input is nil")
|
||||
}
|
||||
|
||||
// Create the instance with prompt
|
||||
instance := &GeminiVideoGenerationInstance{
|
||||
Prompt: bifrostReq.Input.Prompt,
|
||||
}
|
||||
|
||||
// Handle input reference (image for image-to-video)
|
||||
if bifrostReq.Input.InputReference != nil && *bifrostReq.Input.InputReference != "" {
|
||||
// extract mime type and base64 string from input reference
|
||||
sanitizedURL, err := schemas.SanitizeImageURL(*bifrostReq.Input.InputReference)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid input reference: %w", err)
|
||||
}
|
||||
urlInfo := schemas.ExtractURLTypeInfo(sanitizedURL)
|
||||
|
||||
image := &VideoImageData{}
|
||||
|
||||
if urlInfo.DataURLWithoutPrefix != nil {
|
||||
image.BytesBase64Encoded = urlInfo.DataURLWithoutPrefix
|
||||
}
|
||||
image.MimeType = schemas.Ptr("image/png")
|
||||
if urlInfo.MediaType != nil {
|
||||
image.MimeType = urlInfo.MediaType
|
||||
}
|
||||
|
||||
instance.Image = image
|
||||
}
|
||||
|
||||
if bifrostReq.Params != nil && bifrostReq.Params.VideoURI != nil {
|
||||
instance.Video = &VideoGenerationVideoInput{
|
||||
URI: bifrostReq.Params.VideoURI,
|
||||
}
|
||||
}
|
||||
|
||||
req := &GeminiVideoGenerationRequest{
|
||||
Instances: []GeminiVideoGenerationInstance{*instance},
|
||||
}
|
||||
|
||||
// Map parameters if provided
|
||||
if bifrostReq.Params != nil {
|
||||
params := &VideoGenerationParameters{}
|
||||
|
||||
// Extract all video generation parameters from ExtraParams
|
||||
if bifrostReq.Params.NegativePrompt != nil {
|
||||
params.NegativePrompt = bifrostReq.Params.NegativePrompt
|
||||
}
|
||||
|
||||
if bifrostReq.Params.Seconds != nil {
|
||||
seconds, err := strconv.Atoi(*bifrostReq.Params.Seconds)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid seconds value: %w", err)
|
||||
}
|
||||
params.DurationSeconds = &seconds
|
||||
}
|
||||
|
||||
if bifrostReq.Params.Seed != nil {
|
||||
params.Seed = bifrostReq.Params.Seed
|
||||
}
|
||||
|
||||
if bifrostReq.Params.Audio != nil {
|
||||
params.GenerateAudio = bifrostReq.Params.Audio
|
||||
}
|
||||
|
||||
if bifrostReq.Params.ExtraParams != nil {
|
||||
req.ExtraParams = bifrostReq.Params.ExtraParams
|
||||
if aspectRatio, ok := schemas.SafeExtractStringPointer(bifrostReq.Params.ExtraParams["aspectRatio"]); ok {
|
||||
params.AspectRatio = aspectRatio
|
||||
}
|
||||
if resolution, ok := schemas.SafeExtractStringPointer(bifrostReq.Params.ExtraParams["resolution"]); ok {
|
||||
params.Resolution = resolution
|
||||
}
|
||||
|
||||
if sampleCount, ok := schemas.SafeExtractIntPointer(bifrostReq.Params.ExtraParams["sampleCount"]); ok {
|
||||
params.SampleCount = sampleCount
|
||||
}
|
||||
|
||||
if personGeneration, ok := schemas.SafeExtractStringPointer(bifrostReq.Params.ExtraParams["personGeneration"]); ok {
|
||||
params.PersonGeneration = personGeneration
|
||||
}
|
||||
|
||||
if numberOfVideos, ok := schemas.SafeExtractIntPointer(bifrostReq.Params.ExtraParams["numberOfVideos"]); ok {
|
||||
params.NumberOfVideos = numberOfVideos
|
||||
}
|
||||
if storageURI, ok := schemas.SafeExtractStringPointer(bifrostReq.Params.ExtraParams["storageURI"]); ok {
|
||||
params.StorageURI = storageURI
|
||||
}
|
||||
if compressionQuality, ok := schemas.SafeExtractStringPointer(bifrostReq.Params.ExtraParams["compressionQuality"]); ok {
|
||||
params.CompressionQuality = compressionQuality
|
||||
}
|
||||
if enhancePrompt, ok := schemas.SafeExtractBoolPointer(bifrostReq.Params.ExtraParams["enhancePrompt"]); ok {
|
||||
params.EnhancePrompt = enhancePrompt
|
||||
}
|
||||
if resizeMode, ok := schemas.SafeExtractStringPointer(bifrostReq.Params.ExtraParams["resizeMode"]); ok {
|
||||
params.ResizeMode = resizeMode
|
||||
}
|
||||
if referenceImages, ok := bifrostReq.Params.ExtraParams["referenceImages"]; ok {
|
||||
if referenceImages, ok := referenceImages.([]VideoReferenceImage); ok && referenceImages != nil {
|
||||
params.ReferenceImages = referenceImages
|
||||
} else if data, err := providerUtils.MarshalSorted(referenceImages); err == nil {
|
||||
var referenceImages []VideoReferenceImage
|
||||
if sonic.Unmarshal(data, &referenceImages) == nil {
|
||||
params.ReferenceImages = referenceImages
|
||||
}
|
||||
}
|
||||
}
|
||||
if lastFrame, ok := bifrostReq.Params.ExtraParams["lastFrame"]; ok {
|
||||
if lastFrame, ok := lastFrame.(*VideoImageData); ok {
|
||||
params.LastFrame = lastFrame
|
||||
} else if data, err := providerUtils.MarshalSorted(lastFrame); err == nil {
|
||||
var lastFrame VideoImageData
|
||||
if sonic.Unmarshal(data, &lastFrame) == nil {
|
||||
params.LastFrame = &lastFrame
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Convert size to aspect ratio if size is provided and aspect ratio is not already set
|
||||
if params.AspectRatio == nil && bifrostReq.Params.Size != "" {
|
||||
aspectRatio := sizeToAspectRatio(bifrostReq.Params.Size)
|
||||
if aspectRatio != "" {
|
||||
params.AspectRatio = &aspectRatio
|
||||
}
|
||||
}
|
||||
|
||||
req.Parameters = params
|
||||
}
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
// ToBifrostVideoGenerationResponse converts Gemini operation response to Bifrost format
|
||||
func ToBifrostVideoGenerationResponse(operation *GenerateVideosOperation, model string) (*schemas.BifrostVideoGenerationResponse, *schemas.BifrostError) {
|
||||
if operation == nil {
|
||||
return nil, providerUtils.NewBifrostOperationError("operation is nil", nil)
|
||||
}
|
||||
|
||||
response := &schemas.BifrostVideoGenerationResponse{
|
||||
ID: operation.Name,
|
||||
Object: "video",
|
||||
CreatedAt: time.Now().Unix(),
|
||||
}
|
||||
if model != "" {
|
||||
response.Model = model
|
||||
}
|
||||
|
||||
// Set status based on operation state
|
||||
if !operation.Done {
|
||||
response.Status = schemas.VideoStatusInProgress
|
||||
if operation.Metadata != nil {
|
||||
if p := providerUtils.GetJSONField([]byte(operation.Metadata), "progress"); p.Exists() {
|
||||
progress := p.Float()
|
||||
response.Progress = &progress
|
||||
}
|
||||
}
|
||||
} else if operation.Error != nil {
|
||||
response.Status = schemas.VideoStatusFailed
|
||||
code := providerUtils.GetJSONField(operation.Error, "code").String()
|
||||
message := providerUtils.GetJSONField(operation.Error, "message").String()
|
||||
if code == "" {
|
||||
code = "video_generation_failed"
|
||||
}
|
||||
if message == "" {
|
||||
message = string(operation.Error)
|
||||
}
|
||||
response.Error = &schemas.VideoCreateError{
|
||||
Code: code,
|
||||
Message: message,
|
||||
}
|
||||
} else if operation.Response != nil {
|
||||
// Check new response format with content filtering support
|
||||
if genVideoResp := operation.Response.GenerateVideoResponse; genVideoResp != nil {
|
||||
// Check for content filtering
|
||||
if genVideoResp.RAIMediaFilteredCount > 0 {
|
||||
response.Status = schemas.VideoStatusFailed
|
||||
response.ContentFilter = &schemas.ContentFilterInfo{
|
||||
FilteredCount: int(genVideoResp.RAIMediaFilteredCount),
|
||||
Reasons: genVideoResp.RAIMediaFilteredReasons,
|
||||
}
|
||||
errorMsg := "Content filtered by safety policies"
|
||||
if len(genVideoResp.RAIMediaFilteredReasons) > 0 {
|
||||
errorMsg = genVideoResp.RAIMediaFilteredReasons[0]
|
||||
}
|
||||
response.Error = &schemas.VideoCreateError{
|
||||
Code: "content_filtered",
|
||||
Message: errorMsg,
|
||||
}
|
||||
} else {
|
||||
response.Status = schemas.VideoStatusCompleted
|
||||
|
||||
// Collect all generated videos from multiple possible locations.
|
||||
var videos []schemas.VideoOutput
|
||||
|
||||
// Priority 1: GeneratedSamples
|
||||
if len(genVideoResp.GeneratedSamples) > 0 {
|
||||
for _, sample := range genVideoResp.GeneratedSamples {
|
||||
if sample == nil || sample.Video == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if sample.Video.URI != "" {
|
||||
videoOutput := addVideoURLOutput(sample.Video.URI, sample.Video.MIMEType)
|
||||
if videoOutput != nil {
|
||||
videos = append(videos, *videoOutput)
|
||||
}
|
||||
}
|
||||
if len(sample.Video.VideoBytes) > 0 {
|
||||
videoOutput := addVideoBase64Output(
|
||||
base64.StdEncoding.EncodeToString(sample.Video.VideoBytes),
|
||||
sample.Video.MIMEType,
|
||||
)
|
||||
if videoOutput != nil {
|
||||
videos = append(videos, *videoOutput)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(videos) > 0 {
|
||||
response.Videos = videos
|
||||
}
|
||||
}
|
||||
} else if len(operation.Response.GeneratedVideos) > 0 {
|
||||
// Backward compatibility for older response shapes
|
||||
response.Status = schemas.VideoStatusCompleted
|
||||
var videos []schemas.VideoOutput
|
||||
for _, genVideo := range operation.Response.GeneratedVideos {
|
||||
if genVideo == nil || genVideo.Video == nil {
|
||||
continue
|
||||
}
|
||||
if genVideo.Video.URI != "" {
|
||||
videoOutput := addVideoURLOutput(genVideo.Video.URI, genVideo.Video.MIMEType)
|
||||
if videoOutput != nil {
|
||||
videos = append(videos, *videoOutput)
|
||||
}
|
||||
}
|
||||
if len(genVideo.Video.VideoBytes) > 0 {
|
||||
videoOutput := addVideoBase64Output(
|
||||
base64.StdEncoding.EncodeToString(genVideo.Video.VideoBytes),
|
||||
genVideo.Video.MIMEType,
|
||||
)
|
||||
if videoOutput != nil {
|
||||
videos = append(videos, *videoOutput)
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(videos) > 0 {
|
||||
response.Videos = videos
|
||||
}
|
||||
} else if len(operation.Response.Videos) > 0 {
|
||||
response.Status = schemas.VideoStatusCompleted
|
||||
var videos []schemas.VideoOutput
|
||||
for _, video := range operation.Response.Videos {
|
||||
if video.GCSURI != nil && *video.GCSURI != "" {
|
||||
mimeType := defaultVideoContentType
|
||||
if video.MIMEType != nil && *video.MIMEType != "" {
|
||||
mimeType = *video.MIMEType
|
||||
}
|
||||
videoOutput := addVideoURLOutput(*video.GCSURI, mimeType)
|
||||
if videoOutput != nil {
|
||||
videos = append(videos, *videoOutput)
|
||||
}
|
||||
} else if video.BytesBase64Encoded != nil && *video.BytesBase64Encoded != "" {
|
||||
mimeType := defaultVideoContentType
|
||||
if video.MIMEType != nil && *video.MIMEType != "" {
|
||||
mimeType = *video.MIMEType
|
||||
}
|
||||
videoOutput := addVideoBase64Output(*video.BytesBase64Encoded, mimeType)
|
||||
if videoOutput != nil {
|
||||
videos = append(videos, *videoOutput)
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(videos) > 0 {
|
||||
response.Videos = videos
|
||||
}
|
||||
} else {
|
||||
response.Status = schemas.VideoStatusCompleted
|
||||
}
|
||||
} else {
|
||||
response.Status = schemas.VideoStatusCompleted
|
||||
}
|
||||
|
||||
// Try to extract timestamps from metadata
|
||||
if operation.Metadata != nil {
|
||||
if ct := providerUtils.GetJSONField([]byte(operation.Metadata), "createTime"); ct.Exists() {
|
||||
if t, err := time.Parse(time.RFC3339, ct.String()); err == nil {
|
||||
response.CreatedAt = t.Unix()
|
||||
}
|
||||
}
|
||||
if ut := providerUtils.GetJSONField([]byte(operation.Metadata), "updateTime"); ut.Exists() {
|
||||
if t, err := time.Parse(time.RFC3339, ut.String()); err == nil && operation.Done {
|
||||
response.CompletedAt = schemas.Ptr(t.Unix())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return response, nil
|
||||
}
|
||||
|
||||
func (request *GeminiVideoGenerationRequest) ToBifrostVideoGenerationRequest(ctx *schemas.BifrostContext) (*schemas.BifrostVideoGenerationRequest, error) {
|
||||
if request == nil || len(request.Instances) == 0 {
|
||||
return nil, fmt.Errorf("request is nil or has no instances")
|
||||
}
|
||||
|
||||
// Use the first instance for the main input
|
||||
instance := request.Instances[0]
|
||||
|
||||
provider, model := schemas.ParseModelString(request.Model, providerUtils.CheckAndSetDefaultProvider(ctx, schemas.Gemini))
|
||||
|
||||
bifrostReq := &schemas.BifrostVideoGenerationRequest{
|
||||
Provider: provider,
|
||||
Model: model,
|
||||
Input: &schemas.VideoGenerationInput{
|
||||
Prompt: instance.Prompt,
|
||||
},
|
||||
}
|
||||
|
||||
// Handle image input for image-to-video
|
||||
if instance.Image != nil && instance.Image.BytesBase64Encoded != nil && *instance.Image.BytesBase64Encoded != "" {
|
||||
// attach mime type and base64 string to input reference
|
||||
mimeType := "image/png"
|
||||
if instance.Image.MimeType != nil && *instance.Image.MimeType != "" {
|
||||
mimeType = *instance.Image.MimeType
|
||||
}
|
||||
bifrostReq.Input.InputReference = schemas.Ptr(fmt.Sprintf("data:%s;base64,%s", mimeType, *instance.Image.BytesBase64Encoded))
|
||||
}
|
||||
|
||||
// Helper to ensure params are initialized
|
||||
ensureParams := func() {
|
||||
if bifrostReq.Params == nil {
|
||||
bifrostReq.Params = &schemas.VideoGenerationParameters{
|
||||
ExtraParams: make(map[string]any),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Handle reference images
|
||||
if len(instance.ReferenceImages) > 0 {
|
||||
ensureParams()
|
||||
bifrostReq.Params.ExtraParams["referenceImages"] = instance.ReferenceImages
|
||||
}
|
||||
|
||||
// Handle video URI
|
||||
if instance.Video != nil && instance.Video.URI != nil {
|
||||
ensureParams()
|
||||
bifrostReq.Params.VideoURI = instance.Video.URI
|
||||
}
|
||||
|
||||
// Handle last frame
|
||||
if instance.LastFrame != nil {
|
||||
ensureParams()
|
||||
bifrostReq.Params.ExtraParams["lastFrame"] = instance.LastFrame
|
||||
}
|
||||
|
||||
// Map parameters if provided
|
||||
if request.Parameters != nil {
|
||||
ensureParams()
|
||||
params := bifrostReq.Params
|
||||
|
||||
if request.Parameters.NegativePrompt != nil {
|
||||
params.NegativePrompt = request.Parameters.NegativePrompt
|
||||
}
|
||||
if request.Parameters.DurationSeconds != nil {
|
||||
seconds := strconv.Itoa(*request.Parameters.DurationSeconds)
|
||||
params.Seconds = &seconds
|
||||
}
|
||||
if request.Parameters.Seed != nil {
|
||||
params.Seed = request.Parameters.Seed
|
||||
}
|
||||
if request.Parameters.GenerateAudio != nil {
|
||||
params.Audio = request.Parameters.GenerateAudio
|
||||
}
|
||||
if request.Parameters.AspectRatio != nil {
|
||||
params.ExtraParams["aspectRatio"] = *request.Parameters.AspectRatio
|
||||
}
|
||||
if request.Parameters.Resolution != nil {
|
||||
params.ExtraParams["resolution"] = *request.Parameters.Resolution
|
||||
}
|
||||
if request.Parameters.SampleCount != nil {
|
||||
params.ExtraParams["sampleCount"] = *request.Parameters.SampleCount
|
||||
}
|
||||
if request.Parameters.PersonGeneration != nil {
|
||||
params.ExtraParams["personGeneration"] = *request.Parameters.PersonGeneration
|
||||
}
|
||||
if request.Parameters.NumberOfVideos != nil {
|
||||
params.ExtraParams["numberOfVideos"] = *request.Parameters.NumberOfVideos
|
||||
}
|
||||
if request.Parameters.StorageURI != nil {
|
||||
params.ExtraParams["storageURI"] = *request.Parameters.StorageURI
|
||||
}
|
||||
if request.Parameters.CompressionQuality != nil {
|
||||
params.ExtraParams["compressionQuality"] = *request.Parameters.CompressionQuality
|
||||
}
|
||||
if request.Parameters.EnhancePrompt != nil {
|
||||
params.ExtraParams["enhancePrompt"] = *request.Parameters.EnhancePrompt
|
||||
}
|
||||
if request.Parameters.ResizeMode != nil {
|
||||
params.ExtraParams["resizeMode"] = *request.Parameters.ResizeMode
|
||||
}
|
||||
}
|
||||
|
||||
return bifrostReq, nil
|
||||
}
|
||||
|
||||
func ToGeminiVideoGenerationResponse(response *schemas.BifrostVideoGenerationResponse) *GenerateVideosOperation {
|
||||
if response == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
decodedID := response.ID
|
||||
if decoded, err := url.PathUnescape(decodedID); err == nil {
|
||||
decodedID = decoded
|
||||
}
|
||||
|
||||
// if id is in gemini or vertex format, set name in format models/model/operations/operation_id:provider
|
||||
// else make the id in gemini format
|
||||
if !(strings.HasPrefix(decodedID, "models/") && strings.Contains(decodedID, response.Model) && strings.Contains(decodedID, "operations/")) {
|
||||
// url encode model
|
||||
encodedModel := url.PathEscape(response.Model)
|
||||
decodedID = "models/" + encodedModel + "/operations/" + decodedID
|
||||
}
|
||||
operation := &GenerateVideosOperation{
|
||||
Name: decodedID,
|
||||
}
|
||||
|
||||
switch response.Status {
|
||||
case schemas.VideoStatusCompleted:
|
||||
operation.Done = true
|
||||
if len(response.Videos) > 0 {
|
||||
generatedSamples := make([]*GeneratedVideo, 0, len(response.Videos))
|
||||
for _, output := range response.Videos {
|
||||
var video *Video
|
||||
|
||||
switch output.Type {
|
||||
case schemas.VideoOutputTypeURL:
|
||||
if output.URL == nil || *output.URL == "" {
|
||||
continue
|
||||
}
|
||||
video = &Video{
|
||||
URI: *output.URL,
|
||||
}
|
||||
if output.ContentType != "" {
|
||||
video.MIMEType = output.ContentType
|
||||
}
|
||||
case schemas.VideoOutputTypeBase64:
|
||||
if output.Base64Data == nil || *output.Base64Data == "" {
|
||||
continue
|
||||
}
|
||||
base64Payload := *output.Base64Data
|
||||
mimeType := output.ContentType
|
||||
if parsedMimeType, payload, ok := parseVideoDataURL(*output.Base64Data); ok {
|
||||
base64Payload = payload
|
||||
if mimeType == "" {
|
||||
mimeType = parsedMimeType
|
||||
}
|
||||
}
|
||||
decoded, err := base64.StdEncoding.DecodeString(base64Payload)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if mimeType == "" {
|
||||
mimeType = defaultVideoContentType
|
||||
}
|
||||
video = &Video{
|
||||
VideoBytes: decoded,
|
||||
MIMEType: mimeType,
|
||||
}
|
||||
default:
|
||||
continue
|
||||
}
|
||||
|
||||
if video == nil {
|
||||
continue
|
||||
}
|
||||
generatedSamples = append(generatedSamples, &GeneratedVideo{
|
||||
Video: video,
|
||||
})
|
||||
}
|
||||
if len(generatedSamples) > 0 {
|
||||
operation.Response = &GenerateVideosOperationResponse{
|
||||
GenerateVideoResponse: &GenerateVideoResponse{
|
||||
GeneratedSamples: generatedSamples,
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
case schemas.VideoStatusFailed:
|
||||
operation.Done = true
|
||||
// Check if this is a content filtering case
|
||||
if response.ContentFilter != nil && response.ContentFilter.FilteredCount > 0 {
|
||||
operation.Response = &GenerateVideosOperationResponse{
|
||||
GenerateVideoResponse: &GenerateVideoResponse{
|
||||
RAIMediaFilteredCount: int32(response.ContentFilter.FilteredCount),
|
||||
RAIMediaFilteredReasons: response.ContentFilter.Reasons,
|
||||
},
|
||||
}
|
||||
} else if response.Error != nil {
|
||||
errBytes, _ := providerUtils.MarshalSorted(map[string]any{
|
||||
"message": response.Error.Message,
|
||||
"code": response.Error.Code,
|
||||
})
|
||||
operation.Error = json.RawMessage(errBytes)
|
||||
}
|
||||
default:
|
||||
operation.Done = false
|
||||
}
|
||||
|
||||
return operation
|
||||
}
|
||||
Reference in New Issue
Block a user