Files
bifrost/core/providers/gemini/batch.go
Beyhan Oğur 880f412e2c first commit
2026-04-26 21:52:23 +03:00

371 lines
11 KiB
Go

package gemini
import (
"context"
"fmt"
"net/http"
"strings"
"time"
"github.com/bytedance/sonic"
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
"github.com/maximhq/bifrost/core/schemas"
"github.com/valyala/fasthttp"
)
// ToBifrostBatchStatus maps a Gemini batch job state string onto the
// corresponding Bifrost batch status.
//
// Pending and running jobs are both reported as in progress. A state with no
// explicit mapping is passed through unchanged, converted to
// schemas.BatchStatus.
func ToBifrostBatchStatus(geminiState string) schemas.BatchStatus {
	switch {
	case geminiState == GeminiBatchStatePending || geminiState == GeminiBatchStateRunning:
		return schemas.BatchStatusInProgress
	case geminiState == GeminiBatchStateSucceeded:
		return schemas.BatchStatusCompleted
	case geminiState == GeminiBatchStateFailed:
		return schemas.BatchStatusFailed
	case geminiState == GeminiBatchStateCancelling:
		return schemas.BatchStatusCancelling
	case geminiState == GeminiBatchStateCancelled:
		return schemas.BatchStatusCancelled
	case geminiState == GeminiBatchStateExpired:
		return schemas.BatchStatusExpired
	}
	// No explicit mapping: surface the raw Gemini state to the caller.
	return schemas.BatchStatus(geminiState)
}
// ToGeminiBatchStatus maps a Bifrost batch status onto the corresponding
// Gemini batch job state.
//
// Validating, in-progress and finalizing batches are all reported as running;
// completed and ended batches are both reported as succeeded. Any status
// without an explicit mapping yields GeminiBatchStateUnspecified.
func ToGeminiBatchStatus(status schemas.BatchStatus) string {
	geminiState := GeminiBatchStateUnspecified
	switch status {
	case schemas.BatchStatusValidating,
		schemas.BatchStatusInProgress,
		schemas.BatchStatusFinalizing:
		geminiState = GeminiBatchStateRunning
	case schemas.BatchStatusCompleted,
		schemas.BatchStatusEnded:
		geminiState = GeminiBatchStateSucceeded
	case schemas.BatchStatusFailed:
		geminiState = GeminiBatchStateFailed
	case schemas.BatchStatusCancelling:
		geminiState = GeminiBatchStateCancelling
	case schemas.BatchStatusCancelled:
		geminiState = GeminiBatchStateCancelled
	case schemas.BatchStatusExpired:
		geminiState = GeminiBatchStateExpired
	}
	return geminiState
}
// ToGeminiBatchJobResponse renders a Bifrost batch create response in the
// Gemini batch job response format.
//
// It returns nil when resp is nil. The create response carries a single
// timestamp, so both CreateTime and UpdateTime are derived from
// resp.CreatedAt.
func ToGeminiBatchJobResponse(resp *schemas.BifrostBatchCreateResponse) *GeminiBatchJobResponse {
	if resp == nil {
		return nil
	}

	// Use the succeeded counter, falling back to the completed counter when
	// the former is zero.
	succeeded := resp.RequestCounts.Succeeded
	if succeeded == 0 {
		succeeded = resp.RequestCounts.Completed
	}

	// Requests that neither succeeded nor failed count as pending; the value
	// is clamped so it never goes negative.
	pending := resp.RequestCounts.Total - succeeded - resp.RequestCounts.Failed
	if pending < 0 {
		pending = 0
	}

	// A non-empty operation name replaces the batch ID as the resource name,
	// both on the response and on its metadata.
	name := resp.ID
	if resp.OperationName != nil && *resp.OperationName != "" {
		name = *resp.OperationName
	}

	metadata := &GeminiBatchMetadata{
		Name:       name,
		Type:       "type.googleapis.com/google.ai.generativelanguage.v1beta.BatchPredictionJob",
		CreateTime: formatGeminiTimestamp(resp.CreatedAt),
		UpdateTime: formatGeminiTimestamp(resp.CreatedAt),
		State:      ToGeminiBatchStatus(resp.Status),
		BatchStats: &GeminiBatchStats{
			RequestCount:           resp.RequestCounts.Total,
			PendingRequestCount:    pending,
			SuccessfulRequestCount: succeeded,
		},
	}
	if resp.InputFileID != "" {
		metadata.InputConfig = &GeminiBatchMetadataInputConfig{FileName: resp.InputFileID}
	}

	out := &GeminiBatchJobResponse{
		Name:     name,
		Metadata: metadata,
	}

	// The output file, when known, is exposed both as the destination and in
	// the metadata's output section.
	if outputFile := resp.OutputFileID; outputFile != nil && *outputFile != "" {
		out.Dest = &GeminiBatchDest{FileName: *outputFile}
		metadata.Output = &GeminiBatchMetadataOutputConfig{ResponsesFile: *outputFile}
	}

	// Only terminal statuses mark the job as done.
	switch resp.Status {
	case schemas.BatchStatusCompleted,
		schemas.BatchStatusEnded,
		schemas.BatchStatusFailed,
		schemas.BatchStatusExpired,
		schemas.BatchStatusCancelled:
		out.Done = true
	}

	return out
}
// ToGeminiBatchRetrieveResponse renders a Bifrost batch retrieve response in
// the Gemini batch job response format.
//
// It returns nil when resp is nil. CreateTime and UpdateTime are both derived
// from resp.CreatedAt; EndTime is filled in only when a positive terminal
// timestamp is available.
func ToGeminiBatchRetrieveResponse(resp *schemas.BifrostBatchRetrieveResponse) *GeminiBatchJobResponse {
	if resp == nil {
		return nil
	}

	// Use the succeeded counter, falling back to the completed counter when
	// the former is zero.
	succeeded := resp.RequestCounts.Succeeded
	if succeeded == 0 {
		succeeded = resp.RequestCounts.Completed
	}

	// A non-zero pending counter is used as-is. Otherwise, when there is a
	// non-zero total, pending is derived as total - processed - failed and
	// clamped at zero.
	pending := resp.RequestCounts.Pending
	if pending == 0 && resp.RequestCounts.Total > 0 {
		processed := resp.RequestCounts.Completed
		if processed == 0 {
			processed = succeeded
		}
		pending = max(0, resp.RequestCounts.Total-processed-resp.RequestCounts.Failed)
	}

	// A non-empty operation name replaces the batch ID as the resource name,
	// both on the response and on its metadata.
	name := resp.ID
	if resp.OperationName != nil && *resp.OperationName != "" {
		name = *resp.OperationName
	}

	metadata := &GeminiBatchMetadata{
		Name:       name,
		Type:       "type.googleapis.com/google.ai.generativelanguage.v1beta.BatchPredictionJob",
		CreateTime: formatGeminiTimestamp(resp.CreatedAt),
		UpdateTime: formatGeminiTimestamp(resp.CreatedAt),
		State:      ToGeminiBatchStatus(resp.Status),
		BatchStats: &GeminiBatchStats{
			RequestCount:           resp.RequestCounts.Total,
			PendingRequestCount:    pending,
			SuccessfulRequestCount: succeeded,
		},
	}
	out := &GeminiBatchJobResponse{
		Name:     name,
		Metadata: metadata,
	}

	// An explicit Done flag wins; without one, done-ness is inferred from
	// whether the status is terminal.
	if resp.Done != nil {
		out.Done = *resp.Done
	} else {
		switch resp.Status {
		case schemas.BatchStatusCompleted,
			schemas.BatchStatusEnded,
			schemas.BatchStatusFailed,
			schemas.BatchStatusExpired,
			schemas.BatchStatusCancelled:
			out.Done = true
		default:
			out.Done = false
		}
	}

	if resp.InputFileID != "" {
		metadata.InputConfig = &GeminiBatchMetadataInputConfig{FileName: resp.InputFileID}
	}

	// The output file, when known, is exposed both as the destination and in
	// the metadata's output section.
	if outputFile := resp.OutputFileID; outputFile != nil && *outputFile != "" {
		out.Dest = &GeminiBatchDest{FileName: *outputFile}
		metadata.Output = &GeminiBatchMetadataOutputConfig{ResponsesFile: *outputFile}
	}

	// Pick the first terminal timestamp that is set, in priority order:
	// completed, failed, expired, cancelled. A set-but-non-positive value
	// does not fall through to the next candidate.
	var endedAt int64
	switch {
	case resp.CompletedAt != nil:
		endedAt = *resp.CompletedAt
	case resp.FailedAt != nil:
		endedAt = *resp.FailedAt
	case resp.ExpiredAt != nil:
		endedAt = *resp.ExpiredAt
	case resp.CancelledAt != nil:
		endedAt = *resp.CancelledAt
	}
	if endedAt > 0 {
		metadata.EndTime = formatGeminiTimestamp(endedAt)
	}

	return out
}
// ToGeminiBatchListResponse renders a Bifrost batch list response in the
// Gemini list format.
//
// It returns nil when resp is nil. Each entry of resp.Data is converted with
// ToGeminiBatchRetrieveResponse; Operations is always a non-nil slice, even
// when there are no entries. NextPageToken is set only when resp.NextCursor
// is non-nil.
func ToGeminiBatchListResponse(resp *schemas.BifrostBatchListResponse) *GeminiBatchListResponse {
	if resp == nil {
		return nil
	}

	out := &GeminiBatchListResponse{
		Operations: make([]GeminiBatchJobResponse, 0, len(resp.Data)),
	}

	// Index into the slice so each element is converted in place rather than
	// through a per-iteration copy.
	for idx := range resp.Data {
		converted := ToGeminiBatchRetrieveResponse(&resp.Data[idx])
		if converted == nil {
			continue
		}
		out.Operations = append(out.Operations, *converted)
	}

	if cursor := resp.NextCursor; cursor != nil {
		out.NextPageToken = *cursor
	}

	return out
}
// parseGeminiTimestamp converts an RFC 3339 timestamp string into Unix
// seconds.
//
// It returns 0 for an empty string and for any value that does not parse as
// RFC 3339; the parse error itself is not reported.
func parseGeminiTimestamp(timestamp string) int64 {
	if len(timestamp) == 0 {
		return 0
	}
	if parsed, err := time.Parse(time.RFC3339, timestamp); err == nil {
		return parsed.Unix()
	}
	return 0
}
// extractBatchIDFromName returns the final slash-separated segment of a
// resource name, e.g. "batches/abc123" -> "abc123".
//
// A name without any slash is returned unchanged, an empty name yields an
// empty string, and a name ending in a slash yields an empty string.
func extractBatchIDFromName(name string) string {
	// LastIndex reports -1 when there is no slash, which makes the slice
	// below start at 0 and return the whole name.
	lastSlash := strings.LastIndex(name, "/")
	return name[lastSlash+1:]
}
// downloadBatchResultsFile downloads a batch results file from Gemini and
// parses it as JSONL, converting each line into a schemas.BatchResultItem.
//
// Parameters:
//   - ctx: passed to the extra-header helper and to the request helper.
//   - key: when its value is non-empty it is sent as the x-goog-api-key header.
//   - fileName: the results file to fetch, with or without the "files/" prefix.
//
// Returns the converted result items, parseResult.Errors as reported by
// providerUtils.ParseJSONL (presumably one entry per line that failed to
// parse — confirm against that helper), and a *schemas.BifrostError when the
// request fails, the response status is not 200, or the body cannot be
// decoded. Both slices are nil whenever the BifrostError is non-nil.
func (provider *GeminiProvider) downloadBatchResultsFile(ctx context.Context, key schemas.Key, fileName string) ([]schemas.BatchResultItem, []schemas.BatchError, *schemas.BifrostError) {
// Create request to download the file
// The request/response pair comes from fasthttp's pool and is handed back
// when this function returns.
req := fasthttp.AcquireRequest()
resp := fasthttp.AcquireResponse()
defer fasthttp.ReleaseRequest(req)
defer fasthttp.ReleaseResponse(resp)
// Build download URL - use the download endpoint with alt=media
// The base URL is like https://generativelanguage.googleapis.com/v1beta
// We need to change it to https://generativelanguage.googleapis.com/download/v1beta
// Only the first "/v1beta" occurrence is rewritten; a base URL without that
// segment is used unchanged.
baseURL := strings.Replace(provider.networkConfig.BaseURL, "/v1beta", "/download/v1beta", 1)
// Ensure fileName has proper format
fileID := fileName
if !strings.HasPrefix(fileID, "files/") {
fileID = "files/" + fileID
}
url := fmt.Sprintf("%s/%s:download?alt=media", baseURL, fileID)
provider.logger.Debug("gemini batch results file download url: " + url)
// NOTE(review): presumably copies the configured extra headers onto the
// request — confirm against providerUtils.SetExtraHeaders.
providerUtils.SetExtraHeaders(ctx, req, provider.networkConfig.ExtraHeaders, nil)
req.SetRequestURI(url)
req.Header.SetMethod(http.MethodGet)
// The API key header is omitted entirely when the key has no value.
if key.Value.GetValue() != "" {
req.Header.Set("x-goog-api-key", key.Value.GetValue())
}
// Make request
// wait is deferred before the error check, so it runs on every return path.
// Deferred calls run in reverse order, which means wait() runs before the
// response and request are released back to the pool.
_, bifrostErr, wait := providerUtils.MakeRequestWithContext(ctx, provider.client, req, resp)
defer wait()
if bifrostErr != nil {
return nil, nil, bifrostErr
}
// Handle error response
if resp.StatusCode() != fasthttp.StatusOK {
return nil, nil, parseGeminiError(resp)
}
// NOTE(review): presumably handles content-encoding of the response body —
// confirm against providerUtils.CheckAndDecodeBody.
body, err := providerUtils.CheckAndDecodeBody(resp)
if err != nil {
return nil, nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err)
}
// Parse JSONL content - each line is a separate JSON object
// Use streaming parser to avoid string conversion and collect parse errors
results := make([]schemas.BatchResultItem, 0)
parseResult := providerUtils.ParseJSONL(body, func(line []byte) error {
var resultLine GeminiBatchFileResultLine
// A line that fails to unmarshal is logged and reported to the parser via
// the returned error; it adds nothing to results.
if err := sonic.Unmarshal(line, &resultLine); err != nil {
provider.logger.Warn("gemini batch results file parse error: " + err.Error())
return err
}
// Lines without a key get a synthetic ID based on how many results have
// been collected so far.
customID := resultLine.Key
if customID == "" {
customID = fmt.Sprintf("request-%d", len(results))
}
resultItem := schemas.BatchResultItem{
CustomID: customID,
}
// An error on the line takes precedence over a response. A line with
// neither produces an item carrying only the custom ID.
if resultLine.Error != nil {
resultItem.Error = &schemas.BatchResultError{
Code: fmt.Sprintf("%d", resultLine.Error.Code),
Message: resultLine.Error.Message,
}
} else if resultLine.Response != nil {
// Convert the response to a map for the Body field
respBody := make(map[string]interface{})
// Only the first candidate is inspected.
if len(resultLine.Response.Candidates) > 0 {
candidate := resultLine.Response.Candidates[0]
if candidate.Content != nil && len(candidate.Content.Parts) > 0 {
// Non-empty text parts are concatenated with no separator; "text" is
// left out of the body when no part carries text.
var textParts []string
for _, part := range candidate.Content.Parts {
if part.Text != "" {
textParts = append(textParts, part.Text)
}
}
if len(textParts) > 0 {
respBody["text"] = strings.Join(textParts, "")
}
}
// finish_reason is set whenever a candidate exists, even without content.
respBody["finish_reason"] = string(candidate.FinishReason)
}
if resultLine.Response.UsageMetadata != nil {
respBody["usage"] = map[string]interface{}{
"prompt_tokens": resultLine.Response.UsageMetadata.PromptTokenCount,
"completion_tokens": resultLine.Response.UsageMetadata.CandidatesTokenCount,
"total_tokens": resultLine.Response.UsageMetadata.TotalTokenCount,
}
}
// The status code is a fixed 200 for every line that carries a response.
resultItem.Response = &schemas.BatchResultResponse{
StatusCode: 200,
Body: respBody,
}
}
results = append(results, resultItem)
return nil
})
return results, parseResult.Errors, nil
}
// extractGeminiUsageMetadata returns the prompt, candidates and total token
// counts from a Gemini response, each converted to int, in that order.
//
// All three values are zero when geminiResponse is nil or carries no usage
// metadata. The nil check on the response itself prevents a nil-pointer
// panic when a caller has no response to report on.
func extractGeminiUsageMetadata(geminiResponse *GenerateContentResponse) (int, int, int) {
	if geminiResponse == nil || geminiResponse.UsageMetadata == nil {
		return 0, 0, 0
	}
	usage := geminiResponse.UsageMetadata
	return int(usage.PromptTokenCount), int(usage.CandidatesTokenCount), int(usage.TotalTokenCount)
}