Files
bifrost/core/providers/vertex/vertex.go
Beyhan Oğur 880f412e2c first commit
2026-04-26 21:52:23 +03:00

3214 lines
133 KiB
Go

package vertex
import (
"context"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"regexp"
"strings"
"sync"
"time"
"github.com/valyala/fasthttp"
"golang.org/x/oauth2"
"golang.org/x/oauth2/google"
"github.com/bytedance/sonic"
"github.com/maximhq/bifrost/core/providers/anthropic"
"github.com/maximhq/bifrost/core/providers/gemini"
"github.com/maximhq/bifrost/core/providers/openai"
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
schemas "github.com/maximhq/bifrost/core/schemas"
)
// VertexError is the JSON error envelope decoded from non-200 Vertex API
// response bodies. The payload is a single top-level "error" object.
type VertexError struct {
// Error carries the fields of the top-level "error" object.
Error struct {
Code int `json:"code"` // value of the "code" field
Message string `json:"message"` // value of the "message" field; used as the provider API error message
Status string `json:"status"` // value of the "status" field
} `json:"error"`
}
// Package-level state and precompiled patterns used across the Vertex provider.
var (
	// vertexClientPool is a cache of authenticated Vertex HTTP clients, so that a
	// client does not have to be created and authenticated for every request.
	// Entries are keyed by getClientKey; sync.Map gives atomic operations
	// without explicit locking.
	vertexClientPool sync.Map

	// vertexLocationsPathRe matches the /locations/{region} segment of a Vertex
	// API path (used for region replacement).
	vertexLocationsPathRe = regexp.MustCompile(`/locations/[^/]+`)

	// vertexProjectsPathRe matches the /projects/{project} segment of a Vertex
	// API path.
	vertexProjectsPathRe = regexp.MustCompile(`/projects/[^/]+`)

	// vertexBodyProjectsRe matches projects/{project} inside JSON body values.
	// The leading capture group is the character that precedes it: a JSON quote
	// ("projects/X) or a slash when it appears mid-path (/projects/X).
	vertexBodyProjectsRe = regexp.MustCompile(`(["/])projects/[^/"]+`)

	// vertexShortModelRe matches a quoted short-form model name such as
	// "models/X" in a JSON body; the capture group holds the unquoted name,
	// which needs expanding to the full Vertex resource path.
	vertexShortModelRe = regexp.MustCompile(`"(models/[^/"]+)"`)
)

// maxStreamPassthroughCaptureBytes is a 1 MiB byte limit.
// NOTE(review): judging by the name it caps how much of a passthrough stream
// is captured; the code that uses it is not in this part of the file — confirm.
const maxStreamPassthroughCaptureBytes = 1024 * 1024
// getClientKey derives the cache key under which an authenticated client is
// stored for the given credentials. The key is the lowercase hex encoding of
// the SHA-256 digest of authCredentials, so the raw credentials never appear
// in the key itself.
func getClientKey(authCredentials string) string {
	digest := sha256.Sum256([]byte(authCredentials))
	encoded := make([]byte, hex.EncodedLen(len(digest)))
	hex.Encode(encoded, digest[:])
	return string(encoded)
}
// removeVertexClient removes a specific client from the pool.
// This should be called when:
// - API returns authentication/authorization errors (401, 403)
// - Auth client creation fails
// - Network errors that might indicate credential issues
// This ensures we don't keep using potentially invalid clients.
func removeVertexClient(authCredentials string) {
clientKey := getClientKey(authCredentials)
vertexClientPool.Delete(clientKey)
}
// VertexProvider implements the Provider interface for Google's Vertex AI API.
// It holds two fasthttp clients (one for unary calls, one for streaming calls)
// together with the network settings and raw request/response echo flags
// copied from the provider configuration by NewVertexProvider.
type VertexProvider struct {
logger schemas.Logger // Logger for provider operations
client *fasthttp.Client // HTTP client for unary API requests (ReadTimeout bounds overall response)
streamingClient *fasthttp.Client // HTTP client for streaming API requests (no ReadTimeout; idle governed by NewIdleTimeoutReader)
networkConfig schemas.NetworkConfig // Network configuration including extra headers
sendBackRawRequest bool // Whether to include raw request in BifrostResponse
sendBackRawResponse bool // Whether to include raw response in BifrostResponse
}
// NewVertexProvider builds a VertexProvider from the given configuration.
//
// It first applies the configuration defaults, then derives a single request
// timeout from NetworkConfig.DefaultRequestTimeoutInSeconds and uses it as the
// read, write and connection-wait timeout of the unary fasthttp client. Proxy,
// dialer and TLS settings are layered onto that client in this order, and the
// streaming client is derived from the fully configured unary client.
//
// The returned error is always nil.
func NewVertexProvider(config *schemas.ProviderConfig, logger schemas.Logger) (*VertexProvider, error) {
	config.CheckAndSetDefaults()

	timeout := time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds) * time.Second

	unaryClient := &fasthttp.Client{
		ReadTimeout:         timeout,
		WriteTimeout:        timeout,
		MaxConnsPerHost:     config.NetworkConfig.MaxConnsPerHost,
		MaxIdleConnDuration: 30 * time.Second,
		MaxConnWaitTimeout:  timeout,
		MaxConnDuration:     time.Second * time.Duration(schemas.DefaultMaxConnDurationInSeconds),
		ConnPoolStrategy:    fasthttp.FIFO,
	}

	// Each helper returns the client to keep using; the order is proxy, dialer, TLS.
	unaryClient = providerUtils.ConfigureProxy(unaryClient, config.ProxyConfig, logger)
	unaryClient = providerUtils.ConfigureDialer(unaryClient)
	unaryClient = providerUtils.ConfigureTLS(unaryClient, config.NetworkConfig, logger)

	vertexProvider := &VertexProvider{
		logger:              logger,
		client:              unaryClient,
		streamingClient:     providerUtils.BuildStreamingClient(unaryClient),
		networkConfig:       config.NetworkConfig,
		sendBackRawRequest:  config.SendBackRawRequest,
		sendBackRawResponse: config.SendBackRawResponse,
	}
	return vertexProvider, nil
}
// cloudPlatformScope is the OAuth2 scope requested for every token source
// created by getAuthTokenSource.
const cloudPlatformScope = "https://www.googleapis.com/auth/cloud-platform"
// getAuthTokenSource builds the OAuth2 token source used to authorize Vertex AI
// requests for the given key.
//
// When the key carries no auth credentials, the default credentials found in
// the environment are used. Otherwise the credentials are treated as a JSON
// document: its "type" field is read and must be one of the whitelisted
// Google credential types before the credentials are loaded with that type.
//
// An error is returned when default credentials cannot be found, the JSON
// cannot be parsed, the type is missing or not whitelisted, or the credentials
// cannot be constructed.
func getAuthTokenSource(key schemas.Key) (oauth2.TokenSource, error) {
	rawCredentials := key.VertexKeyConfig.AuthCredentials.GetValue()

	// No explicit credentials configured: fall back to the environment.
	if rawCredentials == "" {
		defaultCreds, err := google.FindDefaultCredentials(context.Background(), cloudPlatformScope)
		if err != nil {
			return nil, fmt.Errorf("failed to find default credentials in environment: %w", err)
		}
		return defaultCreds.TokenSource, nil
	}

	credentialsJSON := []byte(rawCredentials)

	// Decode only the "type" discriminator; the rest is left to the google package.
	var header struct {
		Type string `json:"type"`
	}
	if err := sonic.Unmarshal(credentialsJSON, &header); err != nil {
		return nil, fmt.Errorf("failed to parse auth credentials JSON: %w", err)
	}
	if header.Type == "" {
		return nil, fmt.Errorf("invalid google auth credentials: missing 'type'")
	}

	// Security whitelist: only these credential types may be loaded.
	permittedTypes := [...]google.CredentialsType{
		google.ServiceAccount,
		google.ImpersonatedServiceAccount,
		google.AuthorizedUser,
		google.ExternalAccount,
		google.ExternalAccountAuthorizedUser,
	}
	var (
		credType  google.CredentialsType
		permitted bool
	)
	for _, candidate := range permittedTypes {
		if string(candidate) == header.Type {
			credType, permitted = candidate, true
			break
		}
	}
	if !permitted {
		return nil, fmt.Errorf("unsupported or restricted credential type: %s", header.Type)
	}

	loaded, err := google.CredentialsFromJSONWithType(context.Background(), credentialsJSON, credType, cloudPlatformScope)
	if err != nil {
		return nil, fmt.Errorf("failed to create credentials from auth credentials JSON: %w", err)
	}
	return loaded.TokenSource, nil
}
// GetProviderKey returns the provider identifier for Vertex.
// It always returns schemas.Vertex and does not read any provider state.
func (provider *VertexProvider) GetProviderKey() schemas.ModelProvider {
return schemas.Vertex
}
// listModelsByKey performs a list models request for a single key.
// It returns the list-models response, or a BifrostError if the request fails.
//
// The logic is:
// 1. If the key has no region configured, fail with a configuration error.
// 2. Unless request.Unfiltered is set:
//    - return an empty list when the key has neither allowed models nor aliases,
//      or when its blacklist blocks everything;
//    - return the configured deployments/allowedModels directly (no API call
//      needed) when any are configured.
// 3. Otherwise, fetch every page of the publishers.models.list API endpoint
//    (Model Garden) for each supported publisher and aggregate the results.
//
// Only OAuth2 token auth is used here; the error messages note that API key
// auth is not supported for list models.
func (provider *VertexProvider) listModelsByKey(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) {
region := key.VertexKeyConfig.Region.GetValue()
if region == "" {
return nil, providerUtils.NewConfigurationError("region is not set in key config")
}
deployments := key.Aliases
allowedModels := key.Models
// Nothing can be listed for this key: either nothing is allowed/aliased, or everything is blacklisted.
// Note the precedence: (empty allow-list AND no deployments) OR block-all blacklist.
if !request.Unfiltered && (allowedModels.IsEmpty() && len(deployments) == 0 || key.BlacklistedModels.IsBlockAll()) {
return &schemas.BifrostListModelsResponse{Data: make([]schemas.Model, 0)}, nil
}
// If deployments or allowedModels are configured, return those directly without API call
// Skip this fast path when Unfiltered is set so the full Vertex catalog can be retrieved
if !request.Unfiltered && (len(deployments) > 0 || allowedModels.IsRestricted()) {
return buildResponseFromConfig(deployments, allowedModels, key.BlacklistedModels), nil
}
// No deployments configured - fetch from Model Garden API
// The "global" region uses the un-prefixed host; every other region uses {region}-aiplatform.googleapis.com.
var host string
if region == "global" {
host = "aiplatform.googleapis.com"
} else {
host = fmt.Sprintf("%s-aiplatform.googleapis.com", region)
}
// Accumulate all publisher models from paginated requests
var allPublisherModels []VertexPublisherModel
var rawRequests []interface{}
var rawResponses []interface{}
pageToken := ""
// Getting oauth2 token
// The token is fetched once and reused for every publisher and every page below.
tokenSource, err := getAuthTokenSource(key)
if err != nil {
return nil, providerUtils.NewBifrostOperationError("error creating auth token source (api key auth not supported for list models)", err)
}
token, err := tokenSource.Token()
if err != nil {
return nil, providerUtils.NewBifrostOperationError("error getting token (api key auth not supported for list models)", err)
}
// Iterate over all supported Vertex publishers to include Google, Anthropic, and Mistral models
publishers := []string{"google", "anthropic", "mistralai"}
for _, publisher := range publishers {
// Pagination restarts from the first page for each publisher.
pageToken = ""
// Loop through all pages until no nextPageToken is returned
for {
// Build URL for publishers.models.list endpoint (Model Garden)
// Format: https://{region}-aiplatform.googleapis.com/v1beta1/publishers/{publisher}/models
requestURL := fmt.Sprintf("https://%s/v1beta1/publishers/%s/models?pageSize=%d", host, publisher, MaxPageSize)
if pageToken != "" {
requestURL = fmt.Sprintf("%s&pageToken=%s", requestURL, url.QueryEscape(pageToken))
}
// Create HTTP request for listing models
// req/resp are acquired per page and released by hand on every exit path below
// (no defer, because this runs inside a loop).
req := fasthttp.AcquireRequest()
resp := fasthttp.AcquireResponse()
req.Header.SetMethod(http.MethodGet)
req.SetRequestURI(requestURL)
req.Header.SetContentType("application/json")
providerUtils.SetExtraHeaders(ctx, req, provider.networkConfig.ExtraHeaders, nil)
req.Header.Set("Authorization", "Bearer "+token.AccessToken)
// wait is invoked before req/resp are released on every path.
// NOTE(review): presumably it blocks until the request has fully finished with req/resp — confirm in MakeRequestWithContext.
_, bifrostErr, wait := providerUtils.MakeRequestWithContext(ctx, provider.client, req, resp)
if bifrostErr != nil {
wait()
// Copy the body before resp is released back to the pool.
respBody := append([]byte(nil), resp.Body()...)
fasthttp.ReleaseRequest(req)
fasthttp.ReleaseResponse(resp)
// Non-Google publishers may not be available in all regions; skip on error
// (any request-level error for a non-Google publisher moves on to the next publisher).
if publisher != "google" {
break
}
return nil, providerUtils.EnrichError(ctx, bifrostErr, nil, respBody, provider.sendBackRawRequest, provider.sendBackRawResponse)
}
ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerUtils.ExtractProviderResponseHeaders(resp))
// Handle error response
if resp.StatusCode() != fasthttp.StatusOK {
// Auth failures evict the pooled client for these credentials (this also runs for a non-Google 403 that is skipped below).
if resp.StatusCode() == fasthttp.StatusUnauthorized || resp.StatusCode() == fasthttp.StatusForbidden {
removeVertexClient(key.VertexKeyConfig.AuthCredentials.GetValue())
}
// Non-Google publishers may not be available in all regions;
// skip only on 403/404 which indicate regional unavailability.
// Surface other errors (401, 429, 5xx) so they aren't silently swallowed.
if publisher != "google" && (resp.StatusCode() == fasthttp.StatusForbidden || resp.StatusCode() == fasthttp.StatusNotFound) {
wait()
fasthttp.ReleaseRequest(req)
fasthttp.ReleaseResponse(resp)
break
}
// Capture body and status before resp is released.
respBody := append([]byte(nil), resp.Body()...)
statusCode := resp.StatusCode()
wait()
fasthttp.ReleaseRequest(req)
fasthttp.ReleaseResponse(resp)
var errorResp VertexError
if err := sonic.Unmarshal(respBody, &errorResp); err != nil {
return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err), nil, respBody, provider.sendBackRawRequest, provider.sendBackRawResponse)
}
return nil, providerUtils.EnrichError(ctx, providerUtils.NewProviderAPIError(errorResp.Error.Message, nil, statusCode, nil, nil), nil, respBody, provider.sendBackRawRequest, provider.sendBackRawResponse)
}
// Parse Vertex's publisher models response
var vertexResponse VertexListPublisherModelsResponse
rawRequest, rawResponse, bifrostErr := providerUtils.HandleProviderResponse(resp.Body(), &vertexResponse, nil, providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse))
if bifrostErr != nil {
respBody := append([]byte(nil), resp.Body()...)
wait()
fasthttp.ReleaseRequest(req)
fasthttp.ReleaseResponse(resp)
return nil, providerUtils.EnrichError(ctx, bifrostErr, nil, respBody, provider.sendBackRawRequest, provider.sendBackRawResponse)
}
// Raw request/response are collected per page, one slice entry each.
if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) {
rawRequests = append(rawRequests, rawRequest)
}
if providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) {
rawResponses = append(rawResponses, rawResponse)
}
// Accumulate models from this page
allPublisherModels = append(allPublisherModels, vertexResponse.PublisherModels...)
wait()
fasthttp.ReleaseRequest(req)
fasthttp.ReleaseResponse(resp)
// Check if there are more pages
if vertexResponse.NextPageToken == "" {
break
}
pageToken = vertexResponse.NextPageToken
}
}
// Create aggregated response from all pages
aggregatedResponse := &VertexListPublisherModelsResponse{
PublisherModels: allPublisherModels,
}
response := aggregatedResponse.ToBifrostListModelsResponse(key.Models, key.BlacklistedModels, key.Aliases, request.Unfiltered)
if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) {
response.ExtraFields.RawRequest = rawRequests
}
if providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) {
response.ExtraFields.RawResponse = rawResponses
}
return response, nil
}
// ListModels lists the models available through Vertex for the given keys.
// The per-key work is done by listModelsByKey; fanning out over keys and
// merging the results is delegated to providerUtils.HandleMultipleListModelsRequests.
// On failure the response is nil and the helper's error is returned unchanged.
func (provider *VertexProvider) ListModels(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) {
	merged, listErr := providerUtils.HandleMultipleListModelsRequests(ctx, keys, request, provider.listModelsByKey)
	if listErr == nil {
		return merged, nil
	}
	return nil, listErr
}
// TextCompletion always fails: text completion is not supported by the Vertex provider.
// It returns a nil response and an unsupported-operation error for
// schemas.TextCompletionRequest; ctx, key and request are not used.
func (provider *VertexProvider) TextCompletion(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (*schemas.BifrostTextCompletionResponse, *schemas.BifrostError) {
	unsupported := providerUtils.NewUnsupportedOperationError(schemas.TextCompletionRequest, provider.GetProviderKey())
	return nil, unsupported
}
// TextCompletionStream always fails: streaming text completion is not supported
// by the Vertex provider. No request is sent and no channel is created; it
// returns a nil channel and an unsupported-operation error for
// schemas.TextCompletionStreamRequest. All arguments are unused.
func (provider *VertexProvider) TextCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostTextCompletionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
	unsupported := providerUtils.NewUnsupportedOperationError(schemas.TextCompletionStreamRequest, provider.GetProviderKey())
	return nil, unsupported
}
// ChatCompletion performs a (non-streaming) chat completion request to the Vertex API.
// It supports both text and image content in messages.
// Returns a BifrostResponse containing the completion results or an error if the request fails.
//
// The request body, endpoint URL and response decoding all depend on the model family:
//   - Anthropic models: Anthropic request format, publishers/anthropic ...:rawPredict
//   - Gemini models and all-digit model IDs (custom fine-tuned endpoints): Gemini
//     request format, ...:generateContent
//   - Mistral models: OpenAI request format, publishers/mistralai ...:rawPredict
//   - everything else: OpenAI request format, endpoints/openapi/chat/completions
//
// Authentication uses the key's API key as a "key" query parameter when one is set
// for Gemini or fine-tuned models; otherwise an OAuth2 bearer token is attached.
func (provider *VertexProvider) ChatCompletion(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError) {
jsonBody, bifrostErr := providerUtils.CheckContextAndGetRequestBody(
ctx,
request,
func() (providerUtils.RequestBodyWithExtraParams, error) {
// Format messages for Vertex API, preserving key order for prompt caching
var rawBody []byte
var extraParams map[string]interface{}
var err error
if schemas.IsAnthropicModel(request.Model) {
// Use centralized Anthropic converter
reqBody, convErr := anthropic.ToAnthropicChatRequest(ctx, request)
if convErr != nil {
return nil, convErr
}
if reqBody == nil {
return nil, fmt.Errorf("chat completion input is not provided")
}
extraParams = reqBody.GetExtraParams()
// Add provider-aware beta headers for Vertex
anthropic.AddMissingBetaHeadersToContext(ctx, reqBody, schemas.Vertex)
// Marshal to JSON bytes, preserving struct field order
rawBody, err = providerUtils.MarshalSorted(reqBody)
if err != nil {
return nil, fmt.Errorf("failed to marshal request body: %w", err)
}
// Add anthropic_version if not present (using sjson to preserve order)
if !providerUtils.JSONFieldExists(rawBody, "anthropic_version") {
rawBody, err = providerUtils.SetJSONField(rawBody, "anthropic_version", DefaultVertexAnthropicVersion)
if err != nil {
return nil, fmt.Errorf("failed to set anthropic_version: %w", err)
}
}
// Inject beta headers into body as anthropic_beta (Vertex uses body field, not HTTP header)
if betaHeaders := anthropic.FilterBetaHeadersForProvider(anthropic.MergeBetaHeaders(provider.networkConfig.ExtraHeaders, ctx), schemas.Vertex, provider.networkConfig.BetaHeaderOverrides); len(betaHeaders) > 0 {
rawBody, err = providerUtils.SetJSONField(rawBody, "anthropic_beta", betaHeaders)
if err != nil {
return nil, fmt.Errorf("failed to set anthropic_beta: %w", err)
}
}
// Remove model field (it's in URL for Vertex)
rawBody, err = providerUtils.DeleteJSONField(rawBody, "model")
if err != nil {
return nil, fmt.Errorf("failed to delete model field: %w", err)
}
} else if schemas.IsGeminiModel(request.Model) || schemas.IsAllDigitsASCII(request.Model) {
// Gemini models and all-digit model IDs (fine-tuned endpoints) share the Gemini request format.
reqBody, err := gemini.ToGeminiChatCompletionRequest(request)
if err != nil {
return nil, err
}
if reqBody == nil {
return nil, fmt.Errorf("chat completion input is not provided")
}
extraParams = reqBody.GetExtraParams()
// Strip unsupported fields for Vertex Gemini
stripVertexGeminiUnsupportedFields(reqBody)
// Marshal to JSON bytes
rawBody, err = providerUtils.MarshalSorted(reqBody)
if err != nil {
return nil, fmt.Errorf("failed to marshal request body: %w", err)
}
} else {
// Use centralized OpenAI converter for the remaining models
// (neither Anthropic nor Gemini/all-digit IDs)
reqBody := openai.ToOpenAIChatRequest(ctx, request)
if reqBody == nil {
return nil, fmt.Errorf("chat completion input is not provided")
}
extraParams = reqBody.GetExtraParams()
// Marshal to JSON bytes
rawBody, err = providerUtils.MarshalSorted(reqBody)
if err != nil {
return nil, fmt.Errorf("failed to marshal request body: %w", err)
}
}
// Remove region field if present
rawBody, err = providerUtils.DeleteJSONField(rawBody, "region")
if err != nil {
return nil, fmt.Errorf("failed to delete region field: %w", err)
}
return &VertexRawRequestBody{RawBody: rawBody, ExtraParams: extraParams}, nil
},
)
if bifrostErr != nil {
return nil, bifrostErr
}
// Project ID and region are validated only after the body has been built.
projectID := key.VertexKeyConfig.ProjectID.GetValue()
if projectID == "" {
return nil, providerUtils.NewConfigurationError("project ID is not set")
}
region := key.VertexKeyConfig.Region.GetValue()
if region == "" {
return nil, providerUtils.NewConfigurationError("region is not set in key config")
}
// Remap unsupported tool versions for Vertex (handles raw passthrough bodies)
if schemas.IsAnthropicModel(request.Model) && jsonBody != nil {
remappedBody, remapErr := anthropic.RemapRawToolVersionsForProvider(jsonBody, schemas.Vertex)
if remapErr != nil {
return nil, providerUtils.NewBifrostOperationError(remapErr.Error(), nil)
}
jsonBody = remappedBody
}
// Auth query passes the API key in the query string; it is only populated
// for fine-tuned (all-digit) and Gemini models, and only when the key has a value.
authQuery := ""
// Determine the URL based on model type
// In every branch the "global" region uses the un-prefixed host, other regions use {region}-aiplatform.googleapis.com.
var completeURL string
if schemas.IsAllDigitsASCII(request.Model) {
// Custom Fine-tuned models use OpenAPI endpoint
projectNumber := key.VertexKeyConfig.ProjectNumber.GetValue()
if projectNumber == "" {
return nil, providerUtils.NewConfigurationError("project number is not set for fine-tuned models")
}
if key.Value.GetValue() != "" {
authQuery = fmt.Sprintf("key=%s", url.QueryEscape(key.Value.GetValue()))
}
if region == "global" {
completeURL = fmt.Sprintf("https://aiplatform.googleapis.com/v1beta1/projects/%s/locations/global/endpoints/%s:generateContent", projectNumber, request.Model)
} else {
completeURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1beta1/projects/%s/locations/%s/endpoints/%s:generateContent", region, projectNumber, region, request.Model)
}
} else if schemas.IsAnthropicModel(request.Model) {
// Claude models use Anthropic publisher
if region == "global" {
completeURL = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/anthropic/models/%s:rawPredict", projectID, request.Model)
} else {
completeURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:rawPredict", region, projectID, region, request.Model)
}
} else if schemas.IsMistralModel(request.Model) {
// Mistral models use mistralai publisher with rawPredict
if region == "global" {
completeURL = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/mistralai/models/%s:rawPredict", projectID, request.Model)
} else {
completeURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/mistralai/models/%s:rawPredict", region, projectID, region, request.Model)
}
} else if schemas.IsGeminiModel(request.Model) {
// Gemini models support api key
if key.Value.GetValue() != "" {
authQuery = fmt.Sprintf("key=%s", url.QueryEscape(key.Value.GetValue()))
}
if region == "global" {
completeURL = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:generateContent", projectID, request.Model)
} else {
completeURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:generateContent", region, projectID, region, request.Model)
}
} else {
// All other models go to the OpenAI-compatible chat completions endpoint.
if region == "global" {
completeURL = fmt.Sprintf("https://aiplatform.googleapis.com/v1beta1/projects/%s/locations/global/endpoints/openapi/chat/completions", projectID)
} else {
completeURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1beta1/projects/%s/locations/%s/endpoints/openapi/chat/completions", region, projectID, region)
}
}
// Create HTTP request (unary; sent with provider.client, not the streaming client)
req := fasthttp.AcquireRequest()
resp := fasthttp.AcquireResponse()
defer fasthttp.ReleaseRequest(req)
// respOwned tracks whether this function must release resp on return;
// it is cleared on the large-response path below.
respOwned := true
defer func() {
if respOwned {
fasthttp.ReleaseResponse(resp)
}
}()
req.Header.SetMethod(http.MethodPost)
req.Header.SetContentType("application/json")
// Skip anthropic-beta from context headers — Anthropic models on Vertex use the
// anthropic_beta body field instead, and other model families don't use it.
providerUtils.SetExtraHeaders(ctx, req, provider.networkConfig.ExtraHeaders, []string{anthropic.AnthropicBetaHeader})
// If auth query is set, add it to the URL
// Otherwise, get the oauth2 token and set the Authorization header
if authQuery != "" {
completeURL = fmt.Sprintf("%s?%s", completeURL, authQuery)
} else {
// Getting oauth2 token
tokenSource, err := getAuthTokenSource(key)
if err != nil {
return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err)
}
token, err := tokenSource.Token()
if err != nil {
return nil, providerUtils.NewBifrostOperationError("error getting token", err)
}
req.Header.Set("Authorization", "Bearer "+token.AccessToken)
}
req.SetRequestURI(completeURL)
// jsonBody is only set as the request body when the large-payload helper did not already supply one.
usedLargePayloadBody := providerUtils.ApplyLargePayloadRequestBody(ctx, req)
if !usedLargePayloadBody {
req.SetBody(jsonBody)
}
// Make the request with optional large response streaming
activeClient := providerUtils.PrepareResponseStreaming(ctx, provider.client, resp)
latency, bifrostErr, wait := providerUtils.MakeRequestWithContext(ctx, activeClient, req, resp)
// Deferred after the Release defers above, so wait runs before req/resp are released.
defer wait()
if bifrostErr != nil {
return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
}
if usedLargePayloadBody {
providerUtils.DrainLargePayloadRemainder(ctx)
}
ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerUtils.ExtractProviderResponseHeaders(resp))
if resp.StatusCode() != fasthttp.StatusOK {
providerUtils.MaterializeStreamErrorBody(ctx, resp)
// Remove client from pool for authentication/authorization errors
if resp.StatusCode() == fasthttp.StatusUnauthorized || resp.StatusCode() == fasthttp.StatusForbidden {
removeVertexClient(key.VertexKeyConfig.AuthCredentials.GetValue())
}
return nil, providerUtils.EnrichError(ctx, parseVertexError(resp), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
}
responseBody, isLargeResp, decodeErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, provider.logger)
if decodeErr != nil {
return nil, providerUtils.EnrichError(ctx, decodeErr, jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
}
if isLargeResp {
// Large responses are not decoded here: only model, latency and headers are returned,
// and resp is intentionally not released by this function.
// NOTE(review): presumably the large-response path keeps reading from resp after return — confirm who releases it.
respOwned = false
return &schemas.BifrostChatResponse{
Model: request.Model,
ExtraFields: schemas.BifrostResponseExtraFields{
Latency: latency.Milliseconds(),
ProviderResponseHeaders: providerUtils.ExtractProviderResponseHeaders(resp),
},
}, nil
}
// Decode the body with the converter that matches the model family used to build the request.
if schemas.IsAnthropicModel(request.Model) {
// Create response object from pool
anthropicResponse := anthropic.AcquireAnthropicMessageResponse()
defer anthropic.ReleaseAnthropicMessageResponse(anthropicResponse)
rawRequest, rawResponse, bifrostErr := providerUtils.HandleProviderResponse(responseBody, anthropicResponse, jsonBody, providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse))
if bifrostErr != nil {
return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonBody, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse)
}
// Create final response
response := anthropicResponse.ToBifrostChatResponse(ctx)
// Unlike the other branches, this replaces ExtraFields wholesale rather than setting individual fields.
response.ExtraFields = schemas.BifrostResponseExtraFields{
Latency: latency.Milliseconds(),
ProviderResponseHeaders: providerUtils.ExtractProviderResponseHeaders(resp),
}
// Set raw request if enabled
if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) {
response.ExtraFields.RawRequest = rawRequest
}
// Set raw response if enabled
if providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) {
response.ExtraFields.RawResponse = rawResponse
}
return response, nil
} else if schemas.IsGeminiModel(request.Model) || schemas.IsAllDigitsASCII(request.Model) {
geminiResponse := gemini.GenerateContentResponse{}
rawRequest, rawResponse, bifrostErr := providerUtils.HandleProviderResponse(responseBody, &geminiResponse, jsonBody, providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse))
if bifrostErr != nil {
return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonBody, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse)
}
response := geminiResponse.ToBifrostChatResponse()
response.ExtraFields.Latency = latency.Milliseconds()
response.ExtraFields.ProviderResponseHeaders = providerUtils.ExtractProviderResponseHeaders(resp)
if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) {
response.ExtraFields.RawRequest = rawRequest
}
if providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) {
response.ExtraFields.RawResponse = rawResponse
}
return response, nil
} else {
// Remaining models: the body is decoded directly into a BifrostChatResponse.
response := &schemas.BifrostChatResponse{}
// Use enhanced response handler with pre-allocated response
rawRequest, rawResponse, bifrostErr := providerUtils.HandleProviderResponse(responseBody, response, jsonBody, providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse))
if bifrostErr != nil {
return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonBody, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse)
}
response.ExtraFields.Latency = latency.Milliseconds()
response.ExtraFields.ProviderResponseHeaders = providerUtils.ExtractProviderResponseHeaders(resp)
// Set raw request if enabled
if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) {
response.ExtraFields.RawRequest = rawRequest
}
// Set raw response if enabled
if providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) {
response.ExtraFields.RawResponse = rawResponse
}
return response, nil
}
}
// ChatCompletionStream performs a streaming chat completion request to the Vertex API.
// It supports both OpenAI-style streaming (for non-Claude models) and Anthropic-style streaming (for Claude models).
// Returns a channel of BifrostStreamChunk objects for streaming results or an error if the request fails.
func (provider *VertexProvider) ChatCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
providerName := provider.GetProviderKey()
projectID := key.VertexKeyConfig.ProjectID.GetValue()
if projectID == "" {
return nil, providerUtils.NewConfigurationError("project ID is not set")
}
region := key.VertexKeyConfig.Region.GetValue()
if region == "" {
return nil, providerUtils.NewConfigurationError("region is not set in key config")
}
if schemas.IsAnthropicModel(request.Model) {
// Use Anthropic-style streaming for Claude models
jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody(
ctx,
request,
func() (providerUtils.RequestBodyWithExtraParams, error) {
var extraParams map[string]interface{}
reqBody, convErr := anthropic.ToAnthropicChatRequest(ctx, request)
if convErr != nil {
return nil, convErr
}
if reqBody == nil {
return nil, fmt.Errorf("chat completion input is not provided")
}
extraParams = reqBody.GetExtraParams()
reqBody.Stream = new(true)
// Add provider-aware beta headers for Vertex
anthropic.AddMissingBetaHeadersToContext(ctx, reqBody, schemas.Vertex)
// Marshal to JSON bytes, preserving struct field order for prompt caching
rawBody, err := providerUtils.MarshalSorted(reqBody)
if err != nil {
return nil, fmt.Errorf("failed to marshal request body: %w", err)
}
// Add anthropic_version if not present (using sjson to preserve order)
if !providerUtils.JSONFieldExists(rawBody, "anthropic_version") {
rawBody, err = providerUtils.SetJSONField(rawBody, "anthropic_version", DefaultVertexAnthropicVersion)
if err != nil {
return nil, fmt.Errorf("failed to set anthropic_version: %w", err)
}
}
// Inject beta headers into body as anthropic_beta (Vertex uses body field, not HTTP header)
if betaHeaders := anthropic.FilterBetaHeadersForProvider(anthropic.MergeBetaHeaders(provider.networkConfig.ExtraHeaders, ctx), schemas.Vertex, provider.networkConfig.BetaHeaderOverrides); len(betaHeaders) > 0 {
rawBody, err = providerUtils.SetJSONField(rawBody, "anthropic_beta", betaHeaders)
if err != nil {
return nil, fmt.Errorf("failed to set anthropic_beta: %w", err)
}
}
// Remove model and region fields (using sjson to preserve order)
rawBody, err = providerUtils.DeleteJSONField(rawBody, "model")
if err != nil {
return nil, fmt.Errorf("failed to delete model field: %w", err)
}
rawBody, err = providerUtils.DeleteJSONField(rawBody, "region")
if err != nil {
return nil, fmt.Errorf("failed to delete region field: %w", err)
}
return &VertexRawRequestBody{RawBody: rawBody, ExtraParams: extraParams}, nil
},
)
if bifrostErr != nil {
return nil, bifrostErr
}
// Remap unsupported tool versions for Vertex streaming (handles raw passthrough bodies)
if jsonData != nil {
var remapErr error
jsonData, remapErr = anthropic.RemapRawToolVersionsForProvider(jsonData, schemas.Vertex)
if remapErr != nil {
return nil, providerUtils.NewBifrostOperationError(remapErr.Error(), nil)
}
}
var completeURL string
if region == "global" {
completeURL = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/anthropic/models/%s:streamRawPredict", projectID, request.Model)
} else {
completeURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:streamRawPredict", region, projectID, region, request.Model)
}
// Prepare headers for Vertex Anthropic
headers := map[string]string{
"Content-Type": "application/json",
"Accept": "text/event-stream",
"Cache-Control": "no-cache",
}
// Adding authorization header
tokenSource, err := getAuthTokenSource(key)
if err != nil {
return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err)
}
token, err := tokenSource.Token()
if err != nil {
return nil, providerUtils.NewBifrostOperationError("error getting token", err)
}
headers["Authorization"] = "Bearer " + token.AccessToken
// Use shared Anthropic streaming logic
return anthropic.HandleAnthropicChatCompletionStreaming(
ctx,
provider.streamingClient,
completeURL,
jsonData,
headers,
provider.networkConfig.ExtraHeaders,
provider.networkConfig.BetaHeaderOverrides,
providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest),
providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse),
providerName,
postHookRunner,
nil,
provider.logger,
postHookSpanFinalizer,
)
} else if schemas.IsGeminiModel(request.Model) || schemas.IsAllDigitsASCII(request.Model) {
// Use Gemini-style streaming for Gemini models
jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody(
ctx,
request,
func() (providerUtils.RequestBodyWithExtraParams, error) {
reqBody, err := gemini.ToGeminiChatCompletionRequest(request)
if err != nil {
return nil, err
}
if reqBody == nil {
return nil, fmt.Errorf("chat completion input is not provided")
}
// Strip unsupported fields for Vertex Gemini
stripVertexGeminiUnsupportedFields(reqBody)
return reqBody, nil
},
)
if bifrostErr != nil {
return nil, bifrostErr
}
// Auth query is used to pass the API key in the query string
authQuery := ""
if key.Value.GetValue() != "" {
authQuery = fmt.Sprintf("key=%s", url.QueryEscape(key.Value.GetValue()))
}
// For custom/fine-tuned models, validate projectNumber is set
projectNumber := key.VertexKeyConfig.ProjectNumber.GetValue()
if schemas.IsAllDigitsASCII(request.Model) && projectNumber == "" {
return nil, providerUtils.NewConfigurationError("project number is not set for fine-tuned models")
}
// Construct the URL for Gemini streaming
completeURL := getCompleteURLForGeminiEndpoint(request.Model, region, projectID, projectNumber, ":streamGenerateContent")
// Add alt=sse parameter
if authQuery != "" {
completeURL = fmt.Sprintf("%s?alt=sse&%s", completeURL, authQuery)
} else {
completeURL = fmt.Sprintf("%s?alt=sse", completeURL)
}
// Prepare headers for Vertex Gemini
headers := map[string]string{
"Accept": "text/event-stream",
"Cache-Control": "no-cache",
}
// If no auth query, use OAuth2 token
if authQuery == "" {
tokenSource, err := getAuthTokenSource(key)
if err != nil {
return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err)
}
token, err := tokenSource.Token()
if err != nil {
return nil, providerUtils.NewBifrostOperationError("error getting token", err)
}
headers["Authorization"] = "Bearer " + token.AccessToken
}
// Use shared streaming logic from Gemini
return gemini.HandleGeminiChatCompletionStream(
ctx,
provider.streamingClient,
completeURL,
jsonData,
headers,
provider.networkConfig.ExtraHeaders,
providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest),
providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse),
provider.GetProviderKey(),
request.Model,
postHookRunner,
nil,
provider.logger,
postHookSpanFinalizer,
)
} else {
var authHeader map[string]string
// Auth query is used for fine-tuned models to pass the API key in the query string
authQuery := ""
// Determine the URL based on model type
var completeURL string
if schemas.IsMistralModel(request.Model) {
// Mistral models use mistralai publisher with streamRawPredict
if region == "global" {
completeURL = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/mistralai/models/%s:streamRawPredict", projectID, request.Model)
} else {
completeURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/mistralai/models/%s:streamRawPredict", region, projectID, region, request.Model)
}
} else {
// Other models use OpenAPI endpoint for gemini models
if key.Value.GetValue() != "" {
authQuery = fmt.Sprintf("key=%s", url.QueryEscape(key.Value.GetValue()))
}
if region == "global" {
completeURL = fmt.Sprintf("https://aiplatform.googleapis.com/v1beta1/projects/%s/locations/global/endpoints/openapi/chat/completions", projectID)
} else {
completeURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1beta1/projects/%s/locations/%s/endpoints/openapi/chat/completions", region, projectID, region)
}
}
if authQuery != "" {
completeURL = fmt.Sprintf("%s?%s", completeURL, authQuery)
} else {
// Getting oauth2 token
tokenSource, err := getAuthTokenSource(key)
if err != nil {
return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err)
}
token, err := tokenSource.Token()
if err != nil {
return nil, providerUtils.NewBifrostOperationError("error getting token", err)
}
authHeader = map[string]string{
"Authorization": "Bearer " + token.AccessToken,
}
}
// Use shared OpenAI streaming logic
return openai.HandleOpenAIChatCompletionStreaming(
ctx,
provider.streamingClient,
completeURL,
request,
authHeader,
provider.networkConfig.ExtraHeaders,
providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest),
providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse),
providerName,
postHookRunner,
nil,
nil,
nil,
nil,
nil,
provider.logger,
postHookSpanFinalizer,
)
}
}
// Responses executes a non-streaming Responses API request against Vertex AI.
//
// Routing depends on request.Model:
//   - Anthropic models are posted to the anthropic publisher's :rawPredict
//     endpoint and authenticate with an OAuth2 bearer token.
//   - Gemini models and all-digit model IDs (custom/fine-tuned models, which
//     additionally require a project number) are posted to :generateContent.
//     A non-empty key.Value is sent as the "key" query parameter; otherwise an
//     OAuth2 bearer token is used.
//   - Every other model is converted to a chat request, run through
//     ChatCompletion, and the chat response is converted back.
//
// Both native paths require a project ID and a region in the key's Vertex
// config. A 401/403 reply removes the cached Vertex client for the key's
// credentials. When the reply is detected as a large response, only latency and
// provider headers are returned and the fasthttp response is deliberately not
// released by this method.
func (provider *VertexProvider) Responses(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) {
	switch {
	case schemas.IsAnthropicModel(request.Model):
		payload, bifrostErr := getRequestBodyForAnthropicResponses(ctx, request, request.Model, false, false, provider.networkConfig.BetaHeaderOverrides, provider.networkConfig.ExtraHeaders)
		if bifrostErr != nil {
			return nil, bifrostErr
		}

		project := key.VertexKeyConfig.ProjectID.GetValue()
		if project == "" {
			return nil, providerUtils.NewConfigurationError("project ID is not set")
		}
		location := key.VertexKeyConfig.Region.GetValue()
		if location == "" {
			return nil, providerUtils.NewConfigurationError("region is not set in key config")
		}

		// Claude models are served by the anthropic publisher; the "global"
		// location uses the un-prefixed host.
		var endpoint string
		if location != "global" {
			endpoint = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1beta1/projects/%s/locations/%s/publishers/anthropic/models/%s:rawPredict", location, project, location, request.Model)
		} else {
			endpoint = fmt.Sprintf("https://aiplatform.googleapis.com/v1beta1/projects/%s/locations/global/publishers/anthropic/models/%s:rawPredict", project, request.Model)
		}

		// Pooled request/response pair for a single (non-streaming) call.
		httpReq := fasthttp.AcquireRequest()
		httpResp := fasthttp.AcquireResponse()
		defer fasthttp.ReleaseRequest(httpReq)
		ownsResp := true
		defer func() {
			if !ownsResp {
				return
			}
			fasthttp.ReleaseResponse(httpResp)
		}()

		httpReq.Header.SetMethod(http.MethodPost)
		httpReq.Header.SetContentType("application/json")
		providerUtils.SetExtraHeaders(ctx, httpReq, provider.networkConfig.ExtraHeaders, []string{anthropic.AnthropicBetaHeader})

		// OAuth2 bearer authentication.
		credentials, err := getAuthTokenSource(key)
		if err != nil {
			return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err)
		}
		oauthToken, err := credentials.Token()
		if err != nil {
			return nil, providerUtils.NewBifrostOperationError("error getting token", err)
		}
		httpReq.Header.Set("Authorization", "Bearer "+oauthToken.AccessToken)
		httpReq.SetRequestURI(endpoint)

		// Only attach the serialized body when the large-payload path did not
		// already provide one.
		largeBodyApplied := providerUtils.ApplyLargePayloadRequestBody(ctx, httpReq)
		if !largeBodyApplied {
			httpReq.SetBody(payload)
		}

		// Send the request, with optional large-response streaming.
		client := providerUtils.PrepareResponseStreaming(ctx, provider.client, httpResp)
		elapsed, bifrostErr, waitFn := providerUtils.MakeRequestWithContext(ctx, client, httpReq, httpResp)
		defer waitFn()
		if bifrostErr != nil {
			return nil, providerUtils.EnrichError(ctx, bifrostErr, payload, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
		}
		if largeBodyApplied {
			providerUtils.DrainLargePayloadRemainder(ctx)
		}
		ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerUtils.ExtractProviderResponseHeaders(httpResp))

		if httpResp.StatusCode() != fasthttp.StatusOK {
			providerUtils.MaterializeStreamErrorBody(ctx, httpResp)
			switch httpResp.StatusCode() {
			case fasthttp.StatusUnauthorized, fasthttp.StatusForbidden:
				// Authentication/authorization failure: drop the pooled client.
				removeVertexClient(key.VertexKeyConfig.AuthCredentials.GetValue())
			}
			return nil, providerUtils.EnrichError(ctx, parseVertexError(httpResp), payload, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
		}

		decoded, tooLarge, finalizeErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, httpResp, provider.logger)
		if finalizeErr != nil {
			return nil, providerUtils.EnrichError(ctx, finalizeErr, payload, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
		}
		if tooLarge {
			// The response object must outlive this call, so skip the deferred release.
			ownsResp = false
			extra := schemas.BifrostResponseExtraFields{
				Latency:                 elapsed.Milliseconds(),
				ProviderResponseHeaders: providerUtils.ExtractProviderResponseHeaders(httpResp),
			}
			return &schemas.BifrostResponsesResponse{ExtraFields: extra}, nil
		}

		// Decode into a pooled Anthropic message response.
		parsed := anthropic.AcquireAnthropicMessageResponse()
		defer anthropic.ReleaseAnthropicMessageResponse(parsed)

		wantRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest)
		wantRawResponse := providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse)
		rawRequest, rawResponse, bifrostErr := providerUtils.HandleProviderResponse(decoded, parsed, payload, wantRawRequest, wantRawResponse)
		if bifrostErr != nil {
			return nil, providerUtils.EnrichError(ctx, bifrostErr, payload, decoded, provider.sendBackRawRequest, provider.sendBackRawResponse)
		}

		// Assemble the final response; ExtraFields is replaced wholesale here.
		result := parsed.ToBifrostResponsesResponse(ctx)
		result.ExtraFields = schemas.BifrostResponseExtraFields{
			Latency:                 elapsed.Milliseconds(),
			ProviderResponseHeaders: providerUtils.ExtractProviderResponseHeaders(httpResp),
		}
		if wantRawRequest {
			result.ExtraFields.RawRequest = rawRequest
		}
		if wantRawResponse {
			result.ExtraFields.RawResponse = rawResponse
		}
		return result, nil

	case schemas.IsGeminiModel(request.Model) || schemas.IsAllDigitsASCII(request.Model):
		// Converts the Bifrost request to Gemini's native shape and drops the
		// fields Vertex Gemini does not accept.
		buildBody := func() (providerUtils.RequestBodyWithExtraParams, error) {
			geminiReq, convErr := gemini.ToGeminiResponsesRequest(request)
			if convErr != nil {
				return nil, convErr
			}
			if geminiReq == nil {
				return nil, fmt.Errorf("responses input is not provided")
			}
			stripVertexGeminiUnsupportedFields(geminiReq)
			return geminiReq, nil
		}
		payload, bifrostErr := providerUtils.CheckContextAndGetRequestBody(ctx, request, buildBody)
		if bifrostErr != nil {
			return nil, bifrostErr
		}

		project := key.VertexKeyConfig.ProjectID.GetValue()
		if project == "" {
			return nil, providerUtils.NewConfigurationError("project ID is not set")
		}
		location := key.VertexKeyConfig.Region.GetValue()
		if location == "" {
			return nil, providerUtils.NewConfigurationError("region is not set in key config")
		}

		// A configured API key travels in the query string instead of a bearer token.
		var authQuery string
		if apiKey := key.Value.GetValue(); apiKey != "" {
			authQuery = fmt.Sprintf("key=%s", url.QueryEscape(apiKey))
		}

		// Custom/fine-tuned models (all-digit IDs) need the project number.
		projectNumber := key.VertexKeyConfig.ProjectNumber.GetValue()
		if schemas.IsAllDigitsASCII(request.Model) && projectNumber == "" {
			return nil, providerUtils.NewConfigurationError("project number is not set for fine-tuned models")
		}

		endpoint := getCompleteURLForGeminiEndpoint(request.Model, location, project, projectNumber, ":generateContent")

		// Pooled request/response pair for a single (non-streaming) call.
		httpReq := fasthttp.AcquireRequest()
		httpResp := fasthttp.AcquireResponse()
		defer fasthttp.ReleaseRequest(httpReq)
		ownsResp := true
		defer func() {
			if !ownsResp {
				return
			}
			fasthttp.ReleaseResponse(httpResp)
		}()

		httpReq.Header.SetMethod(http.MethodPost)
		httpReq.Header.SetContentType("application/json")
		providerUtils.SetExtraHeaders(ctx, httpReq, provider.networkConfig.ExtraHeaders, nil)

		if authQuery == "" {
			// No API key: fall back to an OAuth2 bearer token.
			credentials, err := getAuthTokenSource(key)
			if err != nil {
				return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err)
			}
			oauthToken, err := credentials.Token()
			if err != nil {
				return nil, providerUtils.NewBifrostOperationError("error getting token", err)
			}
			httpReq.Header.Set("Authorization", "Bearer "+oauthToken.AccessToken)
		} else {
			endpoint = fmt.Sprintf("%s?%s", endpoint, authQuery)
		}
		httpReq.SetRequestURI(endpoint)

		// Only attach the serialized body when the large-payload path did not
		// already provide one.
		largeBodyApplied := providerUtils.ApplyLargePayloadRequestBody(ctx, httpReq)
		if !largeBodyApplied {
			httpReq.SetBody(payload)
		}

		// Send the request, with optional large-response streaming.
		client := providerUtils.PrepareResponseStreaming(ctx, provider.client, httpResp)
		elapsed, bifrostErr, waitFn := providerUtils.MakeRequestWithContext(ctx, client, httpReq, httpResp)
		defer waitFn()
		if bifrostErr != nil {
			return nil, providerUtils.EnrichError(ctx, bifrostErr, payload, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
		}
		if largeBodyApplied {
			providerUtils.DrainLargePayloadRemainder(ctx)
		}
		ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerUtils.ExtractProviderResponseHeaders(httpResp))

		if httpResp.StatusCode() != fasthttp.StatusOK {
			providerUtils.MaterializeStreamErrorBody(ctx, httpResp)
			switch httpResp.StatusCode() {
			case fasthttp.StatusUnauthorized, fasthttp.StatusForbidden:
				// Authentication/authorization failure: drop the pooled client.
				removeVertexClient(key.VertexKeyConfig.AuthCredentials.GetValue())
			}
			return nil, providerUtils.EnrichError(ctx, parseVertexError(httpResp), payload, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
		}

		decoded, tooLarge, finalizeErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, httpResp, provider.logger)
		if finalizeErr != nil {
			return nil, providerUtils.EnrichError(ctx, finalizeErr, payload, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
		}
		if tooLarge {
			// The response object must outlive this call, so skip the deferred release.
			ownsResp = false
			extra := schemas.BifrostResponseExtraFields{
				Latency:                 elapsed.Milliseconds(),
				ProviderResponseHeaders: providerUtils.ExtractProviderResponseHeaders(httpResp),
			}
			return &schemas.BifrostResponsesResponse{ExtraFields: extra}, nil
		}

		parsed := &gemini.GenerateContentResponse{}
		wantRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest)
		wantRawResponse := providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse)
		rawRequest, rawResponse, bifrostErr := providerUtils.HandleProviderResponse(decoded, parsed, payload, wantRawRequest, wantRawResponse)
		if bifrostErr != nil {
			return nil, providerUtils.EnrichError(ctx, bifrostErr, payload, decoded, provider.sendBackRawRequest, provider.sendBackRawResponse)
		}

		// ExtraFields is updated field by field so anything the converter set is kept.
		result := parsed.ToResponsesBifrostResponsesResponse()
		result.ExtraFields.Latency = elapsed.Milliseconds()
		result.ExtraFields.ProviderResponseHeaders = providerUtils.ExtractProviderResponseHeaders(httpResp)
		if wantRawResponse {
			result.ExtraFields.RawResponse = rawResponse
		}
		if wantRawRequest {
			result.ExtraFields.RawRequest = rawRequest
		}
		return result, nil

	default:
		// No native Responses path for this model: go through chat completion.
		chatResponse, bifrostErr := provider.ChatCompletion(ctx, key, request.ToChatRequest())
		if bifrostErr != nil {
			return nil, bifrostErr
		}
		return chatResponse.ToBifrostResponsesResponse(), nil
	}
}
// ResponsesStream executes a streaming Responses API request against Vertex AI
// and returns the chunk channel produced by the shared stream handler for the
// selected backend.
//
// Routing depends on request.Model:
//   - Anthropic models stream from the anthropic publisher's :streamRawPredict
//     endpoint using an OAuth2 bearer token.
//   - Gemini models and all-digit model IDs (custom/fine-tuned models, which
//     additionally require a project number) stream from
//     :streamGenerateContent with alt=sse. A non-empty key.Value is sent as the
//     "key" query parameter; otherwise an OAuth2 bearer token is used.
//   - Every other model is flagged in ctx as a responses-to-chat-completion
//     fallback and delegated to ChatCompletionStream.
//
// Both native paths require a region and a project ID in the key's Vertex config.
func (provider *VertexProvider) ResponsesStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
	switch {
	case schemas.IsAnthropicModel(request.Model):
		location := key.VertexKeyConfig.Region.GetValue()
		if location == "" {
			return nil, providerUtils.NewConfigurationError("region is not set in key config")
		}
		project := key.VertexKeyConfig.ProjectID.GetValue()
		if project == "" {
			return nil, providerUtils.NewConfigurationError("project ID is not set")
		}

		payload, bifrostErr := getRequestBodyForAnthropicResponses(ctx, request, request.Model, true, false, provider.networkConfig.BetaHeaderOverrides, provider.networkConfig.ExtraHeaders)
		if bifrostErr != nil {
			return nil, bifrostErr
		}

		// The "global" location uses the un-prefixed host.
		var streamURL string
		if location != "global" {
			streamURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:streamRawPredict", location, project, location, request.Model)
		} else {
			streamURL = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/anthropic/models/%s:streamRawPredict", project, request.Model)
		}

		// OAuth2 bearer authentication.
		credentials, err := getAuthTokenSource(key)
		if err != nil {
			return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err)
		}
		oauthToken, err := credentials.Token()
		if err != nil {
			return nil, providerUtils.NewBifrostOperationError("error getting token", err)
		}

		// Headers for the Vertex Anthropic SSE request.
		sseHeaders := map[string]string{
			"Content-Type":  "application/json",
			"Accept":        "text/event-stream",
			"Cache-Control": "no-cache",
			"Authorization": "Bearer " + oauthToken.AccessToken,
		}

		// Delegate to the shared Anthropic streaming implementation.
		return anthropic.HandleAnthropicResponsesStream(
			ctx,
			provider.streamingClient,
			streamURL,
			payload,
			sseHeaders,
			provider.networkConfig.ExtraHeaders,
			provider.networkConfig.BetaHeaderOverrides,
			providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest),
			providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse),
			provider.GetProviderKey(),
			postHookRunner,
			nil,
			provider.logger,
			postHookSpanFinalizer,
		)

	case schemas.IsGeminiModel(request.Model) || schemas.IsAllDigitsASCII(request.Model):
		location := key.VertexKeyConfig.Region.GetValue()
		if location == "" {
			return nil, providerUtils.NewConfigurationError("region is not set in key config")
		}
		project := key.VertexKeyConfig.ProjectID.GetValue()
		if project == "" {
			return nil, providerUtils.NewConfigurationError("project ID is not set")
		}

		// Converts the Bifrost request to Gemini's native shape and drops the
		// fields Vertex Gemini does not accept.
		buildBody := func() (providerUtils.RequestBodyWithExtraParams, error) {
			geminiReq, convErr := gemini.ToGeminiResponsesRequest(request)
			if convErr != nil {
				return nil, convErr
			}
			if geminiReq == nil {
				return nil, fmt.Errorf("responses input is not provided")
			}
			stripVertexGeminiUnsupportedFields(geminiReq)
			return geminiReq, nil
		}
		payload, bifrostErr := providerUtils.CheckContextAndGetRequestBody(ctx, request, buildBody)
		if bifrostErr != nil {
			return nil, bifrostErr
		}

		// A configured API key travels in the query string instead of a bearer token.
		var authQuery string
		if apiKey := key.Value.GetValue(); apiKey != "" {
			authQuery = fmt.Sprintf("key=%s", url.QueryEscape(apiKey))
		}

		// Custom/fine-tuned models (all-digit IDs) need the project number.
		projectNumber := key.VertexKeyConfig.ProjectNumber.GetValue()
		if schemas.IsAllDigitsASCII(request.Model) && projectNumber == "" {
			return nil, providerUtils.NewConfigurationError("project number is not set for fine-tuned models")
		}

		streamURL := getCompleteURLForGeminiEndpoint(request.Model, location, project, projectNumber, ":streamGenerateContent")

		// Headers for the Vertex Gemini SSE request.
		sseHeaders := map[string]string{
			"Accept":        "text/event-stream",
			"Cache-Control": "no-cache",
		}

		// alt=sse is always appended; the credential goes either in the query
		// string (API key) or in the Authorization header (OAuth2 token).
		if authQuery != "" {
			streamURL = fmt.Sprintf("%s?alt=sse&%s", streamURL, authQuery)
		} else {
			streamURL = fmt.Sprintf("%s?alt=sse", streamURL)
			credentials, err := getAuthTokenSource(key)
			if err != nil {
				return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err)
			}
			oauthToken, err := credentials.Token()
			if err != nil {
				return nil, providerUtils.NewBifrostOperationError("error getting token", err)
			}
			sseHeaders["Authorization"] = "Bearer " + oauthToken.AccessToken
		}

		// Delegate to the shared Gemini streaming implementation.
		return gemini.HandleGeminiResponsesStream(
			ctx,
			provider.streamingClient,
			streamURL,
			payload,
			sseHeaders,
			provider.networkConfig.ExtraHeaders,
			providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest),
			providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse),
			provider.GetProviderKey(),
			request.Model,
			postHookRunner,
			nil,
			provider.logger,
			postHookSpanFinalizer,
		)

	default:
		// No native Responses stream for this model: mark the fallback in the
		// context and reuse the chat-completion stream.
		ctx.SetValue(schemas.BifrostContextKeyIsResponsesToChatCompletionFallback, true)
		chatRequest := request.ToChatRequest()
		return provider.ChatCompletionStream(ctx, postHookRunner, postHookSpanFinalizer, key, chatRequest)
	}
}
// Embedding generates embeddings for the given input text(s) using Vertex AI's
// native :predict endpoint.
//
// The key's Vertex config must provide a project ID and a region; all-digit
// model IDs (custom/fine-tuned models) additionally require a project number.
// Authentication always uses an OAuth2 bearer token obtained from the key's
// token source.
//
// On a non-200 reply the message is taken from the "error.message" field of the
// Vertex error body when present ("Unknown error" otherwise), and a 401/403
// removes the cached Vertex client for the key's credentials. When the reply is
// detected as a large response, only latency and provider headers are returned
// and the fasthttp response is deliberately not released by this method.
//
// Returns the converted embedding response, or a BifrostError describing the
// configuration, transport, provider or decoding failure.
func (provider *VertexProvider) Embedding(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostEmbeddingRequest) (*schemas.BifrostEmbeddingResponse, *schemas.BifrostError) {
	projectID := key.VertexKeyConfig.ProjectID.GetValue()
	if projectID == "" {
		return nil, providerUtils.NewConfigurationError("project ID is not set")
	}
	region := key.VertexKeyConfig.Region.GetValue()
	if region == "" {
		return nil, providerUtils.NewConfigurationError("region is not set in key config")
	}

	jsonBody, bifrostErr := providerUtils.CheckContextAndGetRequestBody(
		ctx,
		request,
		func() (providerUtils.RequestBodyWithExtraParams, error) {
			return ToVertexEmbeddingRequest(request), nil
		},
	)
	if bifrostErr != nil {
		return nil, bifrostErr
	}

	// For custom/fine-tuned models, validate projectNumber is set.
	projectNumber := key.VertexKeyConfig.ProjectNumber.GetValue()
	if schemas.IsAllDigitsASCII(request.Model) && projectNumber == "" {
		return nil, providerUtils.NewConfigurationError("project number is not set for fine-tuned models")
	}

	// Build the native Vertex embedding API endpoint. The local is not called
	// "url" so it cannot be confused with the net/url package.
	endpoint := getCompleteURLForGeminiEndpoint(request.Model, region, projectID, projectNumber, ":predict")

	// Pooled request/response pair for a single (non-streaming) call.
	req := fasthttp.AcquireRequest()
	resp := fasthttp.AcquireResponse()
	defer fasthttp.ReleaseRequest(req)
	respOwned := true
	defer func() {
		if respOwned {
			fasthttp.ReleaseResponse(resp)
		}
	}()

	req.Header.SetMethod(http.MethodPost)
	req.SetRequestURI(endpoint)
	req.Header.SetContentType("application/json")
	// Set any extra headers from network config.
	providerUtils.SetExtraHeaders(ctx, req, provider.networkConfig.ExtraHeaders, nil)

	// OAuth2 bearer authentication.
	tokenSource, err := getAuthTokenSource(key)
	if err != nil {
		return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err)
	}
	token, err := tokenSource.Token()
	if err != nil {
		return nil, providerUtils.NewBifrostOperationError("error getting token", err)
	}
	req.Header.Set("Authorization", "Bearer "+token.AccessToken)

	// Only attach the serialized body when the large-payload path did not
	// already provide one.
	usedLargePayloadBody := providerUtils.ApplyLargePayloadRequestBody(ctx, req)
	if !usedLargePayloadBody {
		req.SetBody(jsonBody)
	}

	// Make the request with optional large response streaming.
	activeClient := providerUtils.PrepareResponseStreaming(ctx, provider.client, resp)
	latency, bifrostErr, wait := providerUtils.MakeRequestWithContext(ctx, activeClient, req, resp)
	defer wait()
	if bifrostErr != nil {
		return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
	}
	if usedLargePayloadBody {
		providerUtils.DrainLargePayloadRemainder(ctx)
	}
	ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerUtils.ExtractProviderResponseHeaders(resp))

	if resp.StatusCode() != fasthttp.StatusOK {
		providerUtils.MaterializeStreamErrorBody(ctx, resp)
		// Remove client from pool for authentication/authorization errors.
		if resp.StatusCode() == fasthttp.StatusUnauthorized || resp.StatusCode() == fasthttp.StatusForbidden {
			removeVertexClient(key.VertexKeyConfig.AuthCredentials.GetValue())
		}

		errBody := resp.Body()
		errorMessage := "Unknown error"
		if len(errBody) > 0 {
			// Vertex reports failures as {"error": {"message": "..."}}.
			var vertexError map[string]interface{}
			if err := sonic.Unmarshal(errBody, &vertexError); err != nil {
				return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err), jsonBody, errBody, provider.sendBackRawRequest, provider.sendBackRawResponse)
			}
			// A missing key yields a nil interface, so the assertions simply fail.
			if errorMap, ok := vertexError["error"].(map[string]interface{}); ok {
				if msgStr, ok := errorMap["message"].(string); ok {
					errorMessage = msgStr
				}
			}
		}
		return nil, providerUtils.EnrichError(ctx, providerUtils.NewProviderAPIError(errorMessage, nil, resp.StatusCode(), nil, nil), jsonBody, errBody, provider.sendBackRawRequest, provider.sendBackRawResponse)
	}

	responseBody, isLargeResp, decodeErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, provider.logger)
	if decodeErr != nil {
		return nil, providerUtils.EnrichError(ctx, decodeErr, jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
	}
	if isLargeResp {
		// The response object must outlive this call, so skip the deferred release.
		respOwned = false
		return &schemas.BifrostEmbeddingResponse{
			ExtraFields: schemas.BifrostResponseExtraFields{
				Latency:                 latency.Milliseconds(),
				ProviderResponseHeaders: providerUtils.ExtractProviderResponseHeaders(resp),
			},
		}, nil
	}

	// Parse Vertex's native embedding response using the typed response.
	var vertexResponse VertexEmbeddingResponse
	if err := sonic.Unmarshal(responseBody, &vertexResponse); err != nil {
		return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err), jsonBody, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse)
	}

	// Use the centralized Vertex converter, then fill in ExtraFields.
	bifrostResponse := vertexResponse.ToBifrostEmbeddingResponse()
	bifrostResponse.ExtraFields.Latency = latency.Milliseconds()
	bifrostResponse.ExtraFields.ProviderResponseHeaders = providerUtils.ExtractProviderResponseHeaders(resp)

	// Set raw response if enabled. Parse the finalized responseBody (the same
	// bytes the typed parse above succeeded on) rather than resp.Body(), which
	// is fasthttp's raw body and may still carry a content encoding.
	if providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) {
		var rawResponseMap map[string]interface{}
		if err := sonic.Unmarshal(responseBody, &rawResponseMap); err != nil {
			return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderRawResponseUnmarshal, err), jsonBody, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse)
		}
		bifrostResponse.ExtraFields.RawResponse = rawResponseMap
	}
	return bifrostResponse, nil
}
// Speech is unsupported on Vertex: it performs no request and always returns
// an unsupported-operation error for schemas.SpeechRequest.
func (provider *VertexProvider) Speech(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostSpeechRequest) (*schemas.BifrostSpeechResponse, *schemas.BifrostError) {
	providerKey := provider.GetProviderKey()
	return nil, providerUtils.NewUnsupportedOperationError(schemas.SpeechRequest, providerKey)
}
// Rerank scores request.Documents through the Vertex Discovery Engine ranking
// API (POST https://discoveryengine.googleapis.com/v1/{rankingConfig}:rank).
//
// The whitespace-trimmed project ID from the key's Vertex config is required;
// it is used to resolve the rerank options and is also sent as the
// X-Goog-User-Project header. Authentication uses an OAuth2 bearer token.
//
// On a non-200 reply, the Discovery Engine error message replaces the generic
// Vertex-parsed error only when the latter is missing, blank, "Unknown error",
// or the response-unmarshal sentinel. A 401/403 removes the cached Vertex
// client for the key's credentials. When the reply is detected as a large
// response, only the model, latency and provider headers are returned and the
// fasthttp response is deliberately not released by this method.
func (provider *VertexProvider) Rerank(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostRerankRequest) (*schemas.BifrostRerankResponse, *schemas.BifrostError) {
	project := strings.TrimSpace(key.VertexKeyConfig.ProjectID.GetValue())
	if project == "" {
		return nil, providerUtils.NewConfigurationError("project ID is not set")
	}

	rerankOpts, optsErr := getVertexRerankOptions(project, request.Params)
	if optsErr != nil {
		return nil, providerUtils.NewConfigurationError(optsErr.Error())
	}

	buildBody := func() (providerUtils.RequestBodyWithExtraParams, error) {
		return ToVertexRankRequest(request, rerankOpts)
	}
	payload, bifrostErr := providerUtils.CheckContextAndGetRequestBody(ctx, request, buildBody)
	if bifrostErr != nil {
		return nil, bifrostErr
	}

	rankURL := fmt.Sprintf("https://discoveryengine.googleapis.com/v1/%s:rank", rerankOpts.RankingConfig)

	// Pooled request/response pair for a single call.
	httpReq := fasthttp.AcquireRequest()
	httpResp := fasthttp.AcquireResponse()
	defer fasthttp.ReleaseRequest(httpReq)
	ownsResp := true
	defer func() {
		if !ownsResp {
			return
		}
		fasthttp.ReleaseResponse(httpResp)
	}()

	httpReq.Header.SetMethod(http.MethodPost)
	httpReq.SetRequestURI(rankURL)
	httpReq.Header.SetContentType("application/json")
	httpReq.Header.Set("X-Goog-User-Project", project)
	providerUtils.SetExtraHeaders(ctx, httpReq, provider.networkConfig.ExtraHeaders, nil)

	// OAuth2 bearer authentication.
	credentials, err := getAuthTokenSource(key)
	if err != nil {
		return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err)
	}
	oauthToken, err := credentials.Token()
	if err != nil {
		return nil, providerUtils.NewBifrostOperationError("error getting token", err)
	}
	httpReq.Header.Set("Authorization", "Bearer "+oauthToken.AccessToken)

	// Only attach the serialized body when the large-payload path did not
	// already provide one.
	largeBodyApplied := providerUtils.ApplyLargePayloadRequestBody(ctx, httpReq)
	if !largeBodyApplied {
		httpReq.SetBody(payload)
	}

	// Send the request, with optional large-response streaming.
	client := providerUtils.PrepareResponseStreaming(ctx, provider.client, httpResp)
	elapsed, bifrostErr, waitFn := providerUtils.MakeRequestWithContext(ctx, client, httpReq, httpResp)
	defer waitFn()
	if bifrostErr != nil {
		return nil, providerUtils.EnrichError(ctx, bifrostErr, payload, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
	}
	if largeBodyApplied {
		providerUtils.DrainLargePayloadRemainder(ctx)
	}
	ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerUtils.ExtractProviderResponseHeaders(httpResp))

	if httpResp.StatusCode() != fasthttp.StatusOK {
		providerUtils.MaterializeStreamErrorBody(ctx, httpResp)
		switch httpResp.StatusCode() {
		case fasthttp.StatusUnauthorized, fasthttp.StatusForbidden:
			// Authentication/authorization failure: drop the pooled client.
			removeVertexClient(key.VertexKeyConfig.AuthCredentials.GetValue())
		}

		detail := parseDiscoveryEngineErrorMessage(httpResp.Body())
		apiErr := parseVertexError(httpResp)
		if strings.TrimSpace(detail) != "" {
			// Keep the Vertex-parsed error only when it carries a real message.
			keepParsed := apiErr != nil &&
				apiErr.Error != nil &&
				strings.TrimSpace(apiErr.Error.Message) != "" &&
				apiErr.Error.Message != "Unknown error" &&
				apiErr.Error.Message != schemas.ErrProviderResponseUnmarshal
			if !keepParsed {
				apiErr = providerUtils.NewProviderAPIError(detail, nil, httpResp.StatusCode(), nil, nil)
			}
		}
		return nil, providerUtils.EnrichError(ctx, apiErr, payload, httpResp.Body(), provider.sendBackRawRequest, provider.sendBackRawResponse)
	}

	decoded, tooLarge, finalizeErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, httpResp, provider.logger)
	if finalizeErr != nil {
		return nil, providerUtils.EnrichError(ctx, finalizeErr, payload, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
	}
	if tooLarge {
		// The response object must outlive this call, so skip the deferred release.
		ownsResp = false
		extra := schemas.BifrostResponseExtraFields{
			Latency:                 elapsed.Milliseconds(),
			ProviderResponseHeaders: providerUtils.ExtractProviderResponseHeaders(httpResp),
		}
		return &schemas.BifrostRerankResponse{
			Model:       request.Model,
			ExtraFields: extra,
		}, nil
	}

	ranked := &VertexRankResponse{}
	wantRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest)
	wantRawResponse := providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse)
	rawRequest, rawResponse, bifrostErr := providerUtils.HandleProviderResponse(decoded, ranked, payload, wantRawRequest, wantRawResponse)
	if bifrostErr != nil {
		return nil, providerUtils.EnrichError(ctx, bifrostErr, payload, decoded, provider.sendBackRawRequest, provider.sendBackRawResponse)
	}

	// Documents are echoed back only when the caller explicitly asked for them.
	includeDocs := false
	if params := request.Params; params != nil && params.ReturnDocuments != nil {
		includeDocs = *params.ReturnDocuments
	}

	result, convErr := ranked.ToBifrostRerankResponse(request.Documents, includeDocs)
	if convErr != nil {
		return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError("error converting rerank response", convErr), payload, decoded, provider.sendBackRawRequest, provider.sendBackRawResponse)
	}

	result.ExtraFields.Latency = elapsed.Milliseconds()
	result.ExtraFields.ProviderResponseHeaders = providerUtils.ExtractProviderResponseHeaders(httpResp)
	if wantRawRequest {
		result.ExtraFields.RawRequest = rawRequest
	}
	if wantRawResponse {
		result.ExtraFields.RawResponse = rawResponse
	}
	return result, nil
}
// OCR is not implemented for Vertex. It reports an unsupported-operation
// error built for schemas.OCRRequest and this provider's key.
func (provider *VertexProvider) OCR(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostOCRRequest) (*schemas.BifrostOCRResponse, *schemas.BifrostError) {
	providerKey := provider.GetProviderKey()
	return nil, providerUtils.NewUnsupportedOperationError(schemas.OCRRequest, providerKey)
}
// SpeechStream is not implemented for Vertex. No stream channel is created;
// the call reports an unsupported-operation error built for
// schemas.SpeechStreamRequest and this provider's key.
func (provider *VertexProvider) SpeechStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostSpeechRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
	providerKey := provider.GetProviderKey()
	return nil, providerUtils.NewUnsupportedOperationError(schemas.SpeechStreamRequest, providerKey)
}
// Transcription is not implemented for Vertex. It reports an
// unsupported-operation error built for schemas.TranscriptionRequest and this
// provider's key.
func (provider *VertexProvider) Transcription(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (*schemas.BifrostTranscriptionResponse, *schemas.BifrostError) {
	providerKey := provider.GetProviderKey()
	return nil, providerUtils.NewUnsupportedOperationError(schemas.TranscriptionRequest, providerKey)
}
// TranscriptionStream is not implemented for Vertex. No stream channel is
// created; the call reports an unsupported-operation error built for
// schemas.TranscriptionStreamRequest and this provider's key.
func (provider *VertexProvider) TranscriptionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
	providerKey := provider.GetProviderKey()
	return nil, providerUtils.NewUnsupportedOperationError(schemas.TranscriptionStreamRequest, providerKey)
}
// ImageGeneration generates images through Vertex AI.
//
// Three kinds of model name are accepted: Gemini models, Imagen models, and
// all-digit names (handled as custom fine-tuned endpoints). Any other model is
// rejected with a configuration error before a request is built.
//
// Gemini and all-digit models use the Gemini request/response converters and
// the :generateContent method; Imagen models use the Imagen converters and the
// :predict method. If the key carries a non-empty value it is sent as the
// "key" query parameter; otherwise an OAuth2 bearer token obtained through
// getAuthTokenSource is sent in the Authorization header.
//
// When FinalizeResponseWithLargeDetection reports a large response, resp is
// not released by this function and the returned value carries only latency
// and provider response headers.
func (provider *VertexProvider) ImageGeneration(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostImageGenerationRequest) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) {
	// Classify the model up front; these flags select the body format, the
	// URL and the response decoder below.
	fineTuned := schemas.IsAllDigitsASCII(request.Model)
	geminiModel := schemas.IsGeminiModel(request.Model)
	imagenModel := schemas.IsImagenModel(request.Model)
	if !(geminiModel || fineTuned || imagenModel) {
		return nil, providerUtils.NewConfigurationError(fmt.Sprintf("image generation is only supported for Gemini and Imagen models, got: %s", request.Model))
	}
	// All-digit (fine-tuned) models share the Gemini converters.
	geminiFormat := geminiModel || fineTuned

	// buildBody converts the Bifrost request into the provider JSON body.
	buildBody := func() (providerUtils.RequestBodyWithExtraParams, error) {
		var (
			encoded []byte
			extras  map[string]interface{}
			err     error
		)
		switch {
		case geminiFormat:
			geminiReq := gemini.ToGeminiImageGenerationRequest(request)
			if geminiReq == nil {
				return nil, fmt.Errorf("image generation input is not provided")
			}
			extras = geminiReq.GetExtraParams()
			// Extra params are captured first, then unsupported fields are stripped.
			stripVertexGeminiUnsupportedFields(geminiReq)
			if encoded, err = providerUtils.MarshalSorted(geminiReq); err != nil {
				return nil, fmt.Errorf("failed to marshal request body: %w", err)
			}
		case imagenModel:
			imagenReq := gemini.ToImagenImageGenerationRequest(request)
			if imagenReq == nil {
				return nil, fmt.Errorf("image generation input is not provided")
			}
			extras = imagenReq.GetExtraParams()
			if encoded, err = providerUtils.MarshalSorted(imagenReq); err != nil {
				return nil, fmt.Errorf("failed to marshal request body: %w", err)
			}
		}
		// Remove the "region" field from the serialized body if present.
		if encoded, err = providerUtils.DeleteJSONField(encoded, "region"); err != nil {
			return nil, fmt.Errorf("failed to delete region field: %w", err)
		}
		return &VertexRawRequestBody{RawBody: encoded, ExtraParams: extras}, nil
	}

	payload, bifrostErr := providerUtils.CheckContextAndGetRequestBody(ctx, request, buildBody)
	if bifrostErr != nil {
		return nil, bifrostErr
	}

	projectID := key.VertexKeyConfig.ProjectID.GetValue()
	if projectID == "" {
		return nil, providerUtils.NewConfigurationError("project ID is not set")
	}
	region := key.VertexKeyConfig.Region.GetValue()
	if region == "" {
		return nil, providerUtils.NewConfigurationError("region is not set in key config")
	}

	// Pick the endpoint. The "global" region uses the un-prefixed host.
	var endpoint string
	switch {
	case fineTuned:
		// Fine-tuned endpoints are addressed by project number on v1beta1.
		projectNumber := key.VertexKeyConfig.ProjectNumber.GetValue()
		if projectNumber == "" {
			return nil, providerUtils.NewConfigurationError("project number is not set for fine-tuned models")
		}
		if region == "global" {
			endpoint = fmt.Sprintf("https://aiplatform.googleapis.com/v1beta1/projects/%s/locations/global/endpoints/%s:generateContent", projectNumber, request.Model)
		} else {
			endpoint = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1beta1/projects/%s/locations/%s/endpoints/%s:generateContent", region, projectNumber, region, request.Model)
		}
	case imagenModel:
		// Imagen is a published Google model served through :predict.
		if region == "global" {
			endpoint = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:predict", projectID, request.Model)
		} else {
			endpoint = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:predict", region, projectID, region, request.Model)
		}
	case geminiModel:
		if region == "global" {
			endpoint = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:generateContent", projectID, request.Model)
		} else {
			endpoint = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:generateContent", region, projectID, region, request.Model)
		}
	}

	// A non-empty key value is passed as the "key" query parameter.
	authQuery := ""
	if apiKey := key.Value.GetValue(); apiKey != "" {
		authQuery = fmt.Sprintf("key=%s", url.QueryEscape(apiKey))
	}

	req := fasthttp.AcquireRequest()
	resp := fasthttp.AcquireResponse()
	defer fasthttp.ReleaseRequest(req)
	// resp is released on return unless the large-response path below clears
	// this flag.
	releaseResp := true
	defer func() {
		if releaseResp {
			fasthttp.ReleaseResponse(resp)
		}
	}()

	req.Header.SetMethod(http.MethodPost)
	req.Header.SetContentType("application/json")
	providerUtils.SetExtraHeaders(ctx, req, provider.networkConfig.ExtraHeaders, nil)

	// Authenticate with the query-string key when available, otherwise with
	// an OAuth2 bearer token.
	if authQuery != "" {
		endpoint = fmt.Sprintf("%s?%s", endpoint, authQuery)
	} else {
		tokenSource, err := getAuthTokenSource(key)
		if err != nil {
			return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err)
		}
		token, err := tokenSource.Token()
		if err != nil {
			return nil, providerUtils.NewBifrostOperationError("error getting token", err)
		}
		req.Header.Set("Authorization", "Bearer "+token.AccessToken)
	}
	req.SetRequestURI(endpoint)

	// Only set the in-memory body when the large-payload helper did not
	// already supply one.
	largePayload := providerUtils.ApplyLargePayloadRequestBody(ctx, req)
	if !largePayload {
		req.SetBody(payload)
	}

	// Issue the request, with optional large-response streaming.
	client := providerUtils.PrepareResponseStreaming(ctx, provider.client, resp)
	latency, bifrostErr, wait := providerUtils.MakeRequestWithContext(ctx, client, req, resp)
	defer wait()
	if bifrostErr != nil {
		return nil, providerUtils.EnrichError(ctx, bifrostErr, payload, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
	}
	if largePayload {
		providerUtils.DrainLargePayloadRemainder(ctx)
	}
	ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerUtils.ExtractProviderResponseHeaders(resp))

	if resp.StatusCode() != fasthttp.StatusOK {
		providerUtils.MaterializeStreamErrorBody(ctx, resp)
		// On 401/403, evict the cached client for these credentials.
		switch resp.StatusCode() {
		case fasthttp.StatusUnauthorized, fasthttp.StatusForbidden:
			removeVertexClient(key.VertexKeyConfig.AuthCredentials.GetValue())
		}
		return nil, providerUtils.EnrichError(ctx, parseVertexError(resp), payload, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
	}

	body, large, decodeErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, provider.logger)
	if decodeErr != nil {
		return nil, providerUtils.EnrichError(ctx, decodeErr, payload, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
	}
	if large {
		// Skip the deferred release of resp on this path.
		releaseResp = false
		extra := schemas.BifrostResponseExtraFields{
			Latency:                 latency.Milliseconds(),
			ProviderResponseHeaders: providerUtils.ExtractProviderResponseHeaders(resp),
		}
		return &schemas.BifrostImageGenerationResponse{ExtraFields: extra}, nil
	}

	wantRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest)
	wantRawResponse := providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse)

	if geminiFormat {
		var decoded gemini.GenerateContentResponse
		rawRequest, rawResponse, parseErr := providerUtils.HandleProviderResponse(body, &decoded, payload, wantRawRequest, wantRawResponse)
		if parseErr != nil {
			return nil, providerUtils.EnrichError(ctx, parseErr, payload, body, provider.sendBackRawRequest, provider.sendBackRawResponse)
		}
		result, convErr := decoded.ToBifrostImageGenerationResponse()
		if convErr != nil {
			return nil, providerUtils.EnrichError(ctx, convErr, payload, body, provider.sendBackRawRequest, provider.sendBackRawResponse)
		}
		result.ExtraFields.Latency = latency.Milliseconds()
		result.ExtraFields.ProviderResponseHeaders = providerUtils.ExtractProviderResponseHeaders(resp)
		if wantRawRequest {
			result.ExtraFields.RawRequest = rawRequest
		}
		if wantRawResponse {
			result.ExtraFields.RawResponse = rawResponse
		}
		return result, nil
	}

	// Remaining case: Imagen response format.
	var decoded gemini.GeminiImagenResponse
	rawRequest, rawResponse, parseErr := providerUtils.HandleProviderResponse(body, &decoded, payload, wantRawRequest, wantRawResponse)
	if parseErr != nil {
		return nil, providerUtils.EnrichError(ctx, parseErr, payload, body, provider.sendBackRawRequest, provider.sendBackRawResponse)
	}
	result := decoded.ToBifrostImageGenerationResponse()
	result.ExtraFields.Latency = latency.Milliseconds()
	result.ExtraFields.ProviderResponseHeaders = providerUtils.ExtractProviderResponseHeaders(resp)
	if wantRawRequest {
		result.ExtraFields.RawRequest = rawRequest
	}
	if wantRawResponse {
		result.ExtraFields.RawResponse = rawResponse
	}
	return result, nil
}
// ImageGenerationStream is not implemented for Vertex. No stream channel is
// created; the call reports an unsupported-operation error built for
// schemas.ImageGenerationStreamRequest and this provider's key.
func (provider *VertexProvider) ImageGenerationStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostImageGenerationRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
	providerKey := provider.GetProviderKey()
	return nil, providerUtils.NewUnsupportedOperationError(schemas.ImageGenerationStreamRequest, providerKey)
}
// ImageEdit edits images through Vertex AI and returns the resulting images
// as a BifrostImageGenerationResponse.
//
// Three kinds of model name are accepted: Gemini models, Imagen models, and
// all-digit names (handled as custom fine-tuned endpoints). Any other model is
// rejected with a configuration error before a request is built.
//
// Gemini and all-digit models use the Gemini edit converter and the
// :generateContent method; Imagen models use the Imagen edit converter and the
// :predict method. If the key carries a non-empty value it is sent as the
// "key" query parameter; otherwise an OAuth2 bearer token obtained through
// getAuthTokenSource is sent in the Authorization header.
//
// When FinalizeResponseWithLargeDetection reports a large response, resp is
// not released by this function and the returned value carries only latency
// and provider response headers.
func (provider *VertexProvider) ImageEdit(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostImageEditRequest) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) {
	// Classify the model up front; these flags select the body format, the
	// URL and the response decoder below.
	fineTuned := schemas.IsAllDigitsASCII(request.Model)
	geminiModel := schemas.IsGeminiModel(request.Model)
	imagenModel := schemas.IsImagenModel(request.Model)
	if !(geminiModel || fineTuned || imagenModel) {
		return nil, providerUtils.NewConfigurationError(fmt.Sprintf("image edit is only supported for Gemini and Imagen models, got: %s", request.Model))
	}
	// All-digit (fine-tuned) models share the Gemini converters.
	geminiFormat := geminiModel || fineTuned

	// buildBody converts the Bifrost edit request into the provider JSON body.
	buildBody := func() (providerUtils.RequestBodyWithExtraParams, error) {
		var (
			encoded []byte
			extras  map[string]interface{}
			err     error
		)
		switch {
		case geminiFormat:
			geminiReq := gemini.ToGeminiImageEditRequest(request)
			if geminiReq == nil {
				return nil, fmt.Errorf("image edit input is not provided")
			}
			extras = geminiReq.GetExtraParams()
			// Extra params are captured first, then unsupported fields are stripped.
			stripVertexGeminiUnsupportedFields(geminiReq)
			if encoded, err = providerUtils.MarshalSorted(geminiReq); err != nil {
				return nil, fmt.Errorf("failed to marshal request body: %w", err)
			}
		case imagenModel:
			imagenReq := gemini.ToImagenImageEditRequest(request)
			if imagenReq == nil {
				return nil, fmt.Errorf("image edit input is not provided")
			}
			extras = imagenReq.GetExtraParams()
			if encoded, err = providerUtils.MarshalSorted(imagenReq); err != nil {
				return nil, fmt.Errorf("failed to marshal request body: %w", err)
			}
		}
		// Remove the "region" field from the serialized body if present.
		if encoded, err = providerUtils.DeleteJSONField(encoded, "region"); err != nil {
			return nil, fmt.Errorf("failed to delete region field: %w", err)
		}
		return &VertexRawRequestBody{RawBody: encoded, ExtraParams: extras}, nil
	}

	payload, bifrostErr := providerUtils.CheckContextAndGetRequestBody(ctx, request, buildBody)
	if bifrostErr != nil {
		return nil, bifrostErr
	}

	projectID := key.VertexKeyConfig.ProjectID.GetValue()
	if projectID == "" {
		return nil, providerUtils.NewConfigurationError("project ID is not set")
	}
	region := key.VertexKeyConfig.Region.GetValue()
	if region == "" {
		return nil, providerUtils.NewConfigurationError("region is not set in key config")
	}

	// A non-empty key value is passed as the "key" query parameter.
	var authQuery string
	if apiKey := key.Value.GetValue(); apiKey != "" {
		authQuery = fmt.Sprintf("key=%s", url.QueryEscape(apiKey))
	}

	// Pick the endpoint. The "global" region uses the un-prefixed host.
	var endpoint string
	switch {
	case fineTuned:
		// Fine-tuned endpoints are addressed by project number on v1beta1.
		projectNumber := key.VertexKeyConfig.ProjectNumber.GetValue()
		if projectNumber == "" {
			return nil, providerUtils.NewConfigurationError("project number is not set for fine-tuned models")
		}
		if region == "global" {
			endpoint = fmt.Sprintf("https://aiplatform.googleapis.com/v1beta1/projects/%s/locations/global/endpoints/%s:generateContent", projectNumber, request.Model)
		} else {
			endpoint = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1beta1/projects/%s/locations/%s/endpoints/%s:generateContent", region, projectNumber, region, request.Model)
		}
	case imagenModel:
		if region == "global" {
			endpoint = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:predict", projectID, request.Model)
		} else {
			endpoint = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:predict", region, projectID, region, request.Model)
		}
	case geminiModel:
		if region == "global" {
			endpoint = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:generateContent", projectID, request.Model)
		} else {
			endpoint = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:generateContent", region, projectID, region, request.Model)
		}
	}

	req := fasthttp.AcquireRequest()
	resp := fasthttp.AcquireResponse()
	defer fasthttp.ReleaseRequest(req)
	// resp is released on return unless the large-response path below clears
	// this flag.
	releaseResp := true
	defer func() {
		if releaseResp {
			fasthttp.ReleaseResponse(resp)
		}
	}()

	req.Header.SetMethod(http.MethodPost)
	req.Header.SetContentType("application/json")
	providerUtils.SetExtraHeaders(ctx, req, provider.networkConfig.ExtraHeaders, nil)

	// Authenticate with the query-string key when available, otherwise with
	// an OAuth2 bearer token.
	if authQuery != "" {
		endpoint = fmt.Sprintf("%s?%s", endpoint, authQuery)
	} else {
		tokenSource, err := getAuthTokenSource(key)
		if err != nil {
			return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err)
		}
		token, err := tokenSource.Token()
		if err != nil {
			return nil, providerUtils.NewBifrostOperationError("error getting token", err)
		}
		req.Header.Set("Authorization", "Bearer "+token.AccessToken)
	}
	req.SetRequestURI(endpoint)

	// Only set the in-memory body when the large-payload helper did not
	// already supply one.
	largePayload := providerUtils.ApplyLargePayloadRequestBody(ctx, req)
	if !largePayload {
		req.SetBody(payload)
	}

	// Issue the request, with optional large-response streaming.
	client := providerUtils.PrepareResponseStreaming(ctx, provider.client, resp)
	latency, bifrostErr, wait := providerUtils.MakeRequestWithContext(ctx, client, req, resp)
	defer wait()
	if bifrostErr != nil {
		return nil, providerUtils.EnrichError(ctx, bifrostErr, payload, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
	}
	if largePayload {
		providerUtils.DrainLargePayloadRemainder(ctx)
	}
	ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerUtils.ExtractProviderResponseHeaders(resp))

	if resp.StatusCode() != fasthttp.StatusOK {
		providerUtils.MaterializeStreamErrorBody(ctx, resp)
		// On 401/403, evict the cached client for these credentials.
		switch resp.StatusCode() {
		case fasthttp.StatusUnauthorized, fasthttp.StatusForbidden:
			removeVertexClient(key.VertexKeyConfig.AuthCredentials.GetValue())
		}
		return nil, providerUtils.EnrichError(ctx, parseVertexError(resp), payload, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
	}

	body, large, decodeErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, provider.logger)
	if decodeErr != nil {
		return nil, providerUtils.EnrichError(ctx, decodeErr, payload, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
	}
	if large {
		// Skip the deferred release of resp on this path.
		releaseResp = false
		extra := schemas.BifrostResponseExtraFields{
			Latency:                 latency.Milliseconds(),
			ProviderResponseHeaders: providerUtils.ExtractProviderResponseHeaders(resp),
		}
		return &schemas.BifrostImageGenerationResponse{ExtraFields: extra}, nil
	}

	wantRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest)
	wantRawResponse := providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse)

	if geminiFormat {
		var decoded gemini.GenerateContentResponse
		rawRequest, rawResponse, parseErr := providerUtils.HandleProviderResponse(body, &decoded, payload, wantRawRequest, wantRawResponse)
		if parseErr != nil {
			return nil, providerUtils.EnrichError(ctx, parseErr, payload, body, provider.sendBackRawRequest, provider.sendBackRawResponse)
		}
		result, convErr := decoded.ToBifrostImageGenerationResponse()
		if convErr != nil {
			return nil, providerUtils.EnrichError(ctx, convErr, payload, body, provider.sendBackRawRequest, provider.sendBackRawResponse)
		}
		result.ExtraFields.Latency = latency.Milliseconds()
		result.ExtraFields.ProviderResponseHeaders = providerUtils.ExtractProviderResponseHeaders(resp)
		if wantRawRequest {
			result.ExtraFields.RawRequest = rawRequest
		}
		if wantRawResponse {
			result.ExtraFields.RawResponse = rawResponse
		}
		return result, nil
	}

	// Remaining case: Imagen response format.
	var decoded gemini.GeminiImagenResponse
	rawRequest, rawResponse, parseErr := providerUtils.HandleProviderResponse(body, &decoded, payload, wantRawRequest, wantRawResponse)
	if parseErr != nil {
		return nil, providerUtils.EnrichError(ctx, parseErr, payload, body, provider.sendBackRawRequest, provider.sendBackRawResponse)
	}
	result := decoded.ToBifrostImageGenerationResponse()
	result.ExtraFields.Latency = latency.Milliseconds()
	result.ExtraFields.ProviderResponseHeaders = providerUtils.ExtractProviderResponseHeaders(resp)
	if wantRawRequest {
		result.ExtraFields.RawRequest = rawRequest
	}
	if wantRawResponse {
		result.ExtraFields.RawResponse = rawResponse
	}
	return result, nil
}
// ImageEditStream is not implemented for Vertex. No stream channel is
// created; the call reports an unsupported-operation error built for
// schemas.ImageEditStreamRequest and this provider's key.
func (provider *VertexProvider) ImageEditStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostImageEditRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
	providerKey := provider.GetProviderKey()
	return nil, providerUtils.NewUnsupportedOperationError(schemas.ImageEditStreamRequest, providerKey)
}
// ImageVariation is not implemented for Vertex. It reports an
// unsupported-operation error built for schemas.ImageVariationRequest and this
// provider's key.
func (provider *VertexProvider) ImageVariation(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostImageVariationRequest) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) {
	providerKey := provider.GetProviderKey()
	return nil, providerUtils.NewUnsupportedOperationError(schemas.ImageVariationRequest, providerKey)
}
// VideoGeneration starts an asynchronous video generation job on Vertex AI.
//
// Only Veo models and all-digit model names (custom/fine-tuned endpoints) are
// accepted; any other model yields a configuration error. The request body is
// produced by the Gemini video converter and posted to the model's
// :predictLongRunning endpoint. If the key carries a non-empty value it is
// sent as the "key" query parameter; otherwise an OAuth2 bearer token obtained
// through getAuthTokenSource is sent in the Authorization header.
//
// On success the returned response describes the created operation. Its ID has
// the provider suffix appended via AddVideoIDProviderSuffix, and latency,
// provider response headers and (when enabled) the raw request/response are
// attached to ExtraFields.
//
// Every failure after the HTTP call, including body decoding, JSON parsing and
// response conversion, is passed through EnrichError so the raw request and
// response can be attached when enabled, matching the other methods on this
// provider.
func (provider *VertexProvider) VideoGeneration(ctx *schemas.BifrostContext, key schemas.Key, bifrostReq *schemas.BifrostVideoGenerationRequest) (*schemas.BifrostVideoGenerationResponse, *schemas.BifrostError) {
	providerName := provider.GetProviderKey()

	// Only Veo models and all-digit (fine-tuned endpoint) IDs are supported.
	if !schemas.IsVeoModel(bifrostReq.Model) && !schemas.IsAllDigitsASCII(bifrostReq.Model) {
		return nil, providerUtils.NewConfigurationError(fmt.Sprintf("video generation is only supported for Veo models in Vertex, got: %s", bifrostReq.Model))
	}

	// Convert the Bifrost request to the Gemini wire format (reusing the
	// Gemini converter).
	jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody(
		ctx,
		bifrostReq,
		func() (providerUtils.RequestBodyWithExtraParams, error) {
			return gemini.ToGeminiVideoGenerationRequest(bifrostReq)
		},
	)
	if bifrostErr != nil {
		return nil, bifrostErr
	}

	projectID := key.VertexKeyConfig.ProjectID.GetValue()
	if projectID == "" {
		return nil, providerUtils.NewConfigurationError("project ID is not set")
	}
	region := key.VertexKeyConfig.Region.GetValue()
	if region == "" {
		return nil, providerUtils.NewConfigurationError("region is not set in key config")
	}

	// A non-empty key value is passed as the "key" query parameter.
	authQuery := ""
	if apiKey := key.Value.GetValue(); apiKey != "" {
		authQuery = fmt.Sprintf("key=%s", url.QueryEscape(apiKey))
	}

	// Custom/fine-tuned models are addressed by project number, so it must be set.
	projectNumber := key.VertexKeyConfig.ProjectNumber.GetValue()
	if schemas.IsAllDigitsASCII(bifrostReq.Model) && projectNumber == "" {
		return nil, providerUtils.NewConfigurationError("project number is not set for fine-tuned models")
	}

	// Build the :predictLongRunning URL for this model.
	completeURL := getCompleteURLForGeminiEndpoint(bifrostReq.Model, region, projectID, projectNumber, ":predictLongRunning")

	req := fasthttp.AcquireRequest()
	resp := fasthttp.AcquireResponse()
	defer fasthttp.ReleaseRequest(req)
	defer fasthttp.ReleaseResponse(resp)

	req.Header.SetMethod(http.MethodPost)
	req.Header.SetContentType("application/json")
	providerUtils.SetExtraHeaders(ctx, req, provider.networkConfig.ExtraHeaders, nil)

	// Authenticate with the query-string key when available, otherwise with
	// an OAuth2 bearer token.
	if authQuery != "" {
		completeURL = fmt.Sprintf("%s?%s", completeURL, authQuery)
	} else {
		tokenSource, err := getAuthTokenSource(key)
		if err != nil {
			return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err)
		}
		token, err := tokenSource.Token()
		if err != nil {
			return nil, providerUtils.NewBifrostOperationError("error getting token", err)
		}
		req.Header.Set("Authorization", "Bearer "+token.AccessToken)
	}

	req.SetRequestURI(completeURL)
	req.SetBody(jsonData)

	latency, bifrostErr, wait := providerUtils.MakeRequestWithContext(ctx, provider.client, req, resp)
	defer wait()
	if bifrostErr != nil {
		return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
	}
	ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerUtils.ExtractProviderResponseHeaders(resp))

	if resp.StatusCode() != fasthttp.StatusOK {
		// On 401/403, evict the cached client for these credentials.
		if resp.StatusCode() == fasthttp.StatusUnauthorized || resp.StatusCode() == fasthttp.StatusForbidden {
			removeVertexClient(key.VertexKeyConfig.AuthCredentials.GetValue())
		}
		return nil, providerUtils.EnrichError(ctx, parseVertexError(resp), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
	}

	// Decode the body before parsing it.
	body, err := providerUtils.CheckAndDecodeBody(resp)
	if err != nil {
		return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
	}

	var operation gemini.GenerateVideosOperation
	rawRequest, rawResponse, bifrostErr := providerUtils.HandleProviderResponse(body, &operation, jsonData, providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse))
	if bifrostErr != nil {
		return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonData, body, provider.sendBackRawRequest, provider.sendBackRawResponse)
	}

	// Convert the operation to a Bifrost response using the Gemini converter.
	bifrostResp, bifrostErr := gemini.ToBifrostVideoGenerationResponse(&operation, bifrostReq.Model)
	if bifrostErr != nil {
		return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonData, body, provider.sendBackRawRequest, provider.sendBackRawResponse)
	}

	bifrostResp.ID = providerUtils.AddVideoIDProviderSuffix(bifrostResp.ID, providerName)
	bifrostResp.ExtraFields.Latency = latency.Milliseconds()
	bifrostResp.ExtraFields.ProviderResponseHeaders = providerUtils.ExtractProviderResponseHeaders(resp)
	if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) {
		bifrostResp.ExtraFields.RawRequest = rawRequest
	}
	if providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) {
		bifrostResp.ExtraFields.RawResponse = rawResponse
	}
	return bifrostResp, nil
}
// VideoRetrieve fetches the current state of a video generation operation
// from Vertex AI using the model's :fetchPredictOperation endpoint.
//
// bifrostReq.ID is expected to be an operation name carrying the provider
// suffix. After the suffix is stripped it should look like
//
//	projects/PROJECT_ID/locations/REGION/publishers/google/models/MODEL_ID/operations/OPERATION_ID
//
// Everything before "/operations/" is used as the model path in the URL, and
// the full operation name is sent as "operationName" in the JSON body. An ID
// with no "/operations/" segment, or with nothing in front of it, is rejected
// with an "invalid operation ID format" error instead of being sent to the API.
//
// If the key carries a non-empty value it is sent as the "key" query
// parameter; otherwise an OAuth2 bearer token obtained through
// getAuthTokenSource is sent in the Authorization header.
//
// The response body is decoded with CheckAndDecodeBody before parsing, and
// failures after the HTTP call are passed through EnrichError so the raw
// request and response can be attached when enabled. On success the ID gets the
// provider suffix re-applied, and latency, provider response headers and (when
// enabled) the raw request/response are attached to ExtraFields.
func (provider *VertexProvider) VideoRetrieve(ctx *schemas.BifrostContext, key schemas.Key, bifrostReq *schemas.BifrostVideoRetrieveRequest) (*schemas.BifrostVideoGenerationResponse, *schemas.BifrostError) {
	sendBackRawResponse := providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse)
	sendBackRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest)

	region := key.VertexKeyConfig.Region.GetValue()
	if region == "" {
		return nil, providerUtils.NewConfigurationError("region is not set in key config")
	}

	// The "global" region uses the un-prefixed host.
	baseURL := "https://aiplatform.googleapis.com/v1"
	if region != "global" {
		baseURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1", region)
	}

	// Split the operation name at "/operations/" to recover the model path.
	// idx == 0 would give an empty model path and a malformed URL, so it is
	// rejected together with the not-found case.
	taskID := providerUtils.StripVideoIDProviderSuffix(bifrostReq.ID, provider.GetProviderKey())
	idx := strings.Index(taskID, "/operations/")
	if idx <= 0 {
		return nil, providerUtils.NewBifrostOperationError("invalid operation ID format", nil)
	}
	modelPath := taskID[:idx]

	// https://REGION-aiplatform.googleapis.com/v1/{modelPath}:fetchPredictOperation
	completeURL := fmt.Sprintf("%s/%s:fetchPredictOperation", baseURL, modelPath)

	// A non-empty key value is passed as the "key" query parameter.
	authQuery := ""
	if apiKey := key.Value.GetValue(); apiKey != "" {
		authQuery = fmt.Sprintf("key=%s", url.QueryEscape(apiKey))
	}

	// The request body carries the full operation name.
	jsonBody, err := providerUtils.SetJSONField([]byte(`{}`), "operationName", taskID)
	if err != nil {
		return nil, providerUtils.NewBifrostOperationError("failed to marshal request", err)
	}

	req := fasthttp.AcquireRequest()
	resp := fasthttp.AcquireResponse()
	defer fasthttp.ReleaseRequest(req)
	defer fasthttp.ReleaseResponse(resp)

	req.Header.SetMethod(http.MethodPost)
	req.Header.SetContentType("application/json")
	providerUtils.SetExtraHeaders(ctx, req, provider.networkConfig.ExtraHeaders, nil)

	// Authenticate with the query-string key when available, otherwise with
	// an OAuth2 bearer token.
	if authQuery != "" {
		completeURL = fmt.Sprintf("%s?%s", completeURL, authQuery)
	} else {
		tokenSource, err := getAuthTokenSource(key)
		if err != nil {
			return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err)
		}
		token, err := tokenSource.Token()
		if err != nil {
			return nil, providerUtils.NewBifrostOperationError("error getting token", err)
		}
		req.Header.Set("Authorization", "Bearer "+token.AccessToken)
	}

	req.SetRequestURI(completeURL)
	req.SetBody(jsonBody)

	latency, bifrostErr, wait := providerUtils.MakeRequestWithContext(ctx, provider.client, req, resp)
	defer wait()
	if bifrostErr != nil {
		return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonBody, nil, sendBackRawRequest, sendBackRawResponse)
	}
	ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerUtils.ExtractProviderResponseHeaders(resp))

	if resp.StatusCode() != fasthttp.StatusOK {
		// On 401/403, evict the cached client for these credentials.
		if resp.StatusCode() == fasthttp.StatusUnauthorized || resp.StatusCode() == fasthttp.StatusForbidden {
			removeVertexClient(key.VertexKeyConfig.AuthCredentials.GetValue())
		}
		return nil, providerUtils.EnrichError(ctx, parseVertexError(resp), jsonBody, nil, sendBackRawRequest, sendBackRawResponse)
	}

	// Decode the body before parsing, as VideoGeneration does, rather than
	// reading resp.Body() directly.
	body, err := providerUtils.CheckAndDecodeBody(resp)
	if err != nil {
		return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err), jsonBody, nil, sendBackRawRequest, sendBackRawResponse)
	}

	var operation gemini.GenerateVideosOperation
	rawRequest, rawResponse, bifrostErr := providerUtils.HandleProviderResponse(body, &operation, jsonBody, sendBackRawRequest, sendBackRawResponse)
	if bifrostErr != nil {
		return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonBody, body, sendBackRawRequest, sendBackRawResponse)
	}

	bifrostResp, bifrostErr := gemini.ToBifrostVideoGenerationResponse(&operation, "")
	if bifrostErr != nil {
		return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonBody, body, sendBackRawRequest, sendBackRawResponse)
	}

	bifrostResp.ID = providerUtils.AddVideoIDProviderSuffix(bifrostResp.ID, provider.GetProviderKey())
	bifrostResp.ExtraFields.Latency = latency.Milliseconds()
	bifrostResp.ExtraFields.ProviderResponseHeaders = providerUtils.ExtractProviderResponseHeaders(resp)
	if sendBackRawRequest {
		bifrostResp.ExtraFields.RawRequest = rawRequest
	}
	if sendBackRawResponse {
		bifrostResp.ExtraFields.RawResponse = rawResponse
	}
	return bifrostResp, nil
}
// VideoDownload downloads the content of a previously generated video.
//
// It first calls VideoRetrieve for request.ID, requires the resulting status to
// be completed, and then materializes the first entry of the returned videos:
//   - base64 outputs are decoded in-process;
//   - URL outputs are fetched with an HTTP GET, authenticated with the API key
//     as a "key" query parameter when the key has a value, or with an OAuth2
//     bearer token otherwise.
//
// The content type defaults to "video/mp4" and is only overridden by a
// non-empty value from the video metadata or the download response.
// An error is returned when the request has no ID, the video is not ready,
// no video output is present, the output type is not recognised, or the
// decode/download step fails.
func (provider *VertexProvider) VideoDownload(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostVideoDownloadRequest) (*schemas.BifrostVideoDownloadResponse, *schemas.BifrostError) {
	if request == nil || request.ID == "" {
		return nil, providerUtils.NewBifrostOperationError("video_id is required", nil)
	}
	// Retrieve the operation first to learn where the video content lives.
	bifrostVideoRetrieveRequest := &schemas.BifrostVideoRetrieveRequest{
		Provider: request.Provider,
		ID:       request.ID,
	}
	videoResp, bifrostErr := provider.VideoRetrieve(ctx, key, bifrostVideoRetrieveRequest)
	if bifrostErr != nil {
		return nil, bifrostErr
	}
	if videoResp.Status != schemas.VideoStatusCompleted {
		return nil, providerUtils.NewBifrostOperationError(
			fmt.Sprintf("video not ready, current status: %s", videoResp.Status),
			nil)
	}
	if len(videoResp.Videos) == 0 {
		return nil, providerUtils.NewBifrostOperationError("video URL not available", nil)
	}

	var content []byte
	var latency time.Duration
	var providerResponseHeaders map[string]string
	contentType := "video/mp4"

	switch {
	case videoResp.Videos[0].Type == schemas.VideoOutputTypeBase64 && videoResp.Videos[0].Base64Data != nil:
		// Inline (base64-encoded) video: decode locally, no network round trip.
		startTime := time.Now()
		decoded, err := base64.StdEncoding.DecodeString(*videoResp.Videos[0].Base64Data)
		if err != nil {
			return nil, providerUtils.NewBifrostOperationError("failed to decode base64 video data", err)
		}
		content = decoded
		// Keep the "video/mp4" default when the metadata carries no content type.
		if ct := videoResp.Videos[0].ContentType; ct != "" {
			contentType = ct
		}
		latency = time.Since(startTime)

	case videoResp.Videos[0].Type == schemas.VideoOutputTypeURL && videoResp.Videos[0].URL != nil:
		// Regular URL: fetch the bytes from the HTTP endpoint.
		req := fasthttp.AcquireRequest()
		resp := fasthttp.AcquireResponse()
		defer fasthttp.ReleaseRequest(req)
		defer fasthttp.ReleaseResponse(resp)
		providerUtils.SetExtraHeaders(ctx, req, provider.networkConfig.ExtraHeaders, nil)
		req.Header.SetMethod(http.MethodGet)

		// Authenticate the download: API key in the query string when one is
		// configured, OAuth2 bearer token otherwise.
		uri := *videoResp.Videos[0].URL
		if apiKey := key.Value.GetValue(); apiKey != "" {
			authQuery := fmt.Sprintf("key=%s", url.QueryEscape(apiKey))
			if strings.Contains(uri, "?") {
				uri += "&" + authQuery
			} else {
				uri += "?" + authQuery
			}
		} else {
			tokenSource, err := getAuthTokenSource(key)
			if err != nil {
				return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err)
			}
			token, err := tokenSource.Token()
			if err != nil {
				return nil, providerUtils.NewBifrostOperationError("error getting token", err)
			}
			req.Header.Set("Authorization", "Bearer "+token.AccessToken)
		}
		req.SetRequestURI(uri)

		var bifrostErr *schemas.BifrostError
		var wait func()
		latency, bifrostErr, wait = providerUtils.MakeRequestWithContext(ctx, provider.client, req, resp)
		defer wait()
		if bifrostErr != nil {
			return nil, bifrostErr
		}
		ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerUtils.ExtractProviderResponseHeaders(resp))
		if resp.StatusCode() != fasthttp.StatusOK {
			return nil, providerUtils.NewBifrostOperationError(
				fmt.Sprintf("failed to download video: HTTP %d", resp.StatusCode()),
				nil)
		}
		body, err := providerUtils.CheckAndDecodeBody(resp)
		if err != nil {
			return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err)
		}
		// Keep the "video/mp4" default when the response reports no content type.
		if ct := string(resp.Header.ContentType()); ct != "" {
			contentType = ct
		}
		// Copy the body: resp is released when this function returns.
		content = append([]byte(nil), body...)
		providerResponseHeaders = providerUtils.ExtractProviderResponseHeaders(resp)

	default:
		return nil, providerUtils.NewBifrostOperationError("invalid video output type", nil)
	}

	bifrostResp := &schemas.BifrostVideoDownloadResponse{
		VideoID:     request.ID,
		Content:     content,
		ContentType: contentType,
	}
	bifrostResp.ExtraFields.Latency = latency.Milliseconds()
	bifrostResp.ExtraFields.ProviderResponseHeaders = providerResponseHeaders
	return bifrostResp, nil
}
// VideoDelete always fails for Vertex: it returns a nil response and an
// unsupported-operation error tagged with this provider's key.
func (provider *VertexProvider) VideoDelete(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostVideoDeleteRequest) (*schemas.BifrostVideoDeleteResponse, *schemas.BifrostError) {
	providerKey := provider.GetProviderKey()
	return nil, providerUtils.NewUnsupportedOperationError(schemas.VideoDeleteRequest, providerKey)
}
// VideoList always fails for Vertex: it returns a nil response and an
// unsupported-operation error tagged with this provider's key.
func (provider *VertexProvider) VideoList(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostVideoListRequest) (*schemas.BifrostVideoListResponse, *schemas.BifrostError) {
	providerKey := provider.GetProviderKey()
	return nil, providerUtils.NewUnsupportedOperationError(schemas.VideoListRequest, providerKey)
}
// VideoRemix always fails for Vertex: it returns a nil response and an
// unsupported-operation error tagged with this provider's key.
func (provider *VertexProvider) VideoRemix(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostVideoRemixRequest) (*schemas.BifrostVideoGenerationResponse, *schemas.BifrostError) {
	providerKey := provider.GetProviderKey()
	return nil, providerUtils.NewUnsupportedOperationError(schemas.VideoRemixRequest, providerKey)
}
// stripVertexGeminiUnsupportedFields blanks out fields that Vertex AI's Gemini
// API does not accept. For every part of every content entry in the request it
// resets the ID of the function call and of the function response, when present.
// The request is modified in place.
func stripVertexGeminiUnsupportedFields(requestBody *gemini.GeminiGenerationRequest) {
	for ci := range requestBody.Contents {
		parts := requestBody.Contents[ci].Parts
		for pi := range parts {
			part := parts[pi]
			if call := part.FunctionCall; call != nil {
				// function_call.id is rejected by Vertex
				call.ID = ""
			}
			if result := part.FunctionResponse; result != nil {
				// function_response.id is rejected by Vertex
				result.ID = ""
			}
		}
	}
}
// BatchCreate always fails for Vertex AI: it ignores its arguments and returns
// a nil response with an unsupported-operation error tagged with this provider's key.
func (provider *VertexProvider) BatchCreate(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostBatchCreateRequest) (*schemas.BifrostBatchCreateResponse, *schemas.BifrostError) {
	providerKey := provider.GetProviderKey()
	return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchCreateRequest, providerKey)
}
// BatchList always fails for Vertex AI: it returns a nil response and an
// unsupported-operation error tagged with this provider's key.
func (provider *VertexProvider) BatchList(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchListRequest) (*schemas.BifrostBatchListResponse, *schemas.BifrostError) {
	providerKey := provider.GetProviderKey()
	return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchListRequest, providerKey)
}
// BatchRetrieve always fails for Vertex AI: it returns a nil response and an
// unsupported-operation error tagged with this provider's key.
func (provider *VertexProvider) BatchRetrieve(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchRetrieveRequest) (*schemas.BifrostBatchRetrieveResponse, *schemas.BifrostError) {
	providerKey := provider.GetProviderKey()
	return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchRetrieveRequest, providerKey)
}
// BatchCancel always fails for Vertex AI: it returns a nil response and an
// unsupported-operation error tagged with this provider's key.
func (provider *VertexProvider) BatchCancel(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchCancelRequest) (*schemas.BifrostBatchCancelResponse, *schemas.BifrostError) {
	providerKey := provider.GetProviderKey()
	return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchCancelRequest, providerKey)
}
// BatchDelete always fails for Vertex AI: it returns a nil response and an
// unsupported-operation error tagged with this provider's key.
func (provider *VertexProvider) BatchDelete(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchDeleteRequest) (*schemas.BifrostBatchDeleteResponse, *schemas.BifrostError) {
	providerKey := provider.GetProviderKey()
	return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchDeleteRequest, providerKey)
}
// BatchResults always fails for Vertex AI: it returns a nil response and an
// unsupported-operation error tagged with this provider's key.
func (provider *VertexProvider) BatchResults(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchResultsRequest) (*schemas.BifrostBatchResultsResponse, *schemas.BifrostError) {
	providerKey := provider.GetProviderKey()
	return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchResultsRequest, providerKey)
}
// FileUpload has no Vertex AI implementation yet; Vertex AI uses Google Cloud
// Storage (GCS) for batch input/output files. The call returns a nil response
// and an unsupported-operation error tagged with this provider's key.
func (provider *VertexProvider) FileUpload(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostFileUploadRequest) (*schemas.BifrostFileUploadResponse, *schemas.BifrostError) {
	providerKey := provider.GetProviderKey()
	return nil, providerUtils.NewUnsupportedOperationError(schemas.FileUploadRequest, providerKey)
}
// FileList has no Vertex AI implementation yet: it returns a nil response and
// an unsupported-operation error tagged with this provider's key.
func (provider *VertexProvider) FileList(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostFileListRequest) (*schemas.BifrostFileListResponse, *schemas.BifrostError) {
	providerKey := provider.GetProviderKey()
	return nil, providerUtils.NewUnsupportedOperationError(schemas.FileListRequest, providerKey)
}
// FileRetrieve has no Vertex AI implementation yet: it returns a nil response
// and an unsupported-operation error tagged with this provider's key.
func (provider *VertexProvider) FileRetrieve(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostFileRetrieveRequest) (*schemas.BifrostFileRetrieveResponse, *schemas.BifrostError) {
	providerKey := provider.GetProviderKey()
	return nil, providerUtils.NewUnsupportedOperationError(schemas.FileRetrieveRequest, providerKey)
}
// FileDelete has no Vertex AI implementation yet: it returns a nil response
// and an unsupported-operation error tagged with this provider's key.
func (provider *VertexProvider) FileDelete(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostFileDeleteRequest) (*schemas.BifrostFileDeleteResponse, *schemas.BifrostError) {
	providerKey := provider.GetProviderKey()
	return nil, providerUtils.NewUnsupportedOperationError(schemas.FileDeleteRequest, providerKey)
}
// FileContent has no Vertex AI implementation yet: it returns a nil response
// and an unsupported-operation error tagged with this provider's key.
func (provider *VertexProvider) FileContent(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostFileContentRequest) (*schemas.BifrostFileContentResponse, *schemas.BifrostError) {
	providerKey := provider.GetProviderKey()
	return nil, providerUtils.NewUnsupportedOperationError(schemas.FileContentRequest, providerKey)
}
// CountTokens counts the tokens of the provided request content.
//
// Anthropic models are sent to the publishers/anthropic count-tokens:rawPredict
// endpoint; Gemini models and all-digit (fine-tuned) model IDs are sent to the
// Gemini ":countTokens" endpoint. Any other model yields a configuration error.
//
// For the Gemini path the "toolConfig", "generationConfig" and
// "systemInstruction" fields are stripped from the JSON body before sending.
// Authentication uses the API key as a "key" query parameter on the Gemini path
// when the key has a value, and an OAuth2 bearer token in every other case.
//
// Errors are returned for a nil request, a missing project ID or region, a
// missing project number for all-digit model IDs, token acquisition failures,
// transport failures, non-200 responses, and response decode failures. On a
// 401/403 response the cached Vertex client for the key's credentials is
// removed. When the response is detected as large, a response carrying only
// latency and provider headers is returned and ownership of the underlying
// fasthttp response is not released here.
func (provider *VertexProvider) CountTokens(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostResponsesRequest) (*schemas.BifrostCountTokensResponse, *schemas.BifrostError) {
	if request == nil {
		return nil, providerUtils.NewBifrostOperationError("request is required", nil)
	}

	var (
		jsonBody   []byte
		bifrostErr *schemas.BifrostError
	)
	isAnthropic := schemas.IsAnthropicModel(request.Model)

	if isAnthropic {
		jsonBody, bifrostErr = getRequestBodyForAnthropicResponses(ctx, request, request.Model, false, true, provider.networkConfig.BetaHeaderOverrides, provider.networkConfig.ExtraHeaders)
		if bifrostErr != nil {
			return nil, bifrostErr
		}
	} else {
		jsonBody, bifrostErr = providerUtils.CheckContextAndGetRequestBody(
			ctx,
			request,
			func() (providerUtils.RequestBodyWithExtraParams, error) {
				return gemini.ToGeminiResponsesRequest(request)
			},
		)
		if bifrostErr != nil {
			return nil, bifrostErr
		}
		// Skip field-stripping when large payload mode is active — jsonBody is nil
		// and the raw body will stream directly from the ingress reader.
		if jsonBody != nil {
			// Delete fields directly from the JSON bytes, preserving key ordering.
			// Stripping is best-effort: if a delete fails, keep the body as it was
			// instead of replacing it with whatever the failed call returned.
			for _, field := range []string{"toolConfig", "generationConfig", "systemInstruction"} {
				if stripped, err := providerUtils.DeleteJSONField(jsonBody, field); err == nil {
					jsonBody = stripped
				}
			}
		}
	}

	projectID := key.VertexKeyConfig.ProjectID.GetValue()
	if projectID == "" {
		return nil, providerUtils.NewConfigurationError("project ID is not set")
	}
	region := key.VertexKeyConfig.Region.GetValue()
	if region == "" {
		return nil, providerUtils.NewConfigurationError("region is not set in key config")
	}

	// Resolve the endpoint (and, for Gemini, the optional API-key query) by model family.
	authQuery := ""
	var completeURL string
	if isAnthropic {
		if region == "global" {
			completeURL = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/anthropic/models/count-tokens:rawPredict", projectID)
		} else {
			completeURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/anthropic/models/count-tokens:rawPredict", region, projectID, region)
		}
	} else if schemas.IsGeminiModel(request.Model) || schemas.IsAllDigitsASCII(request.Model) {
		if key.Value.GetValue() != "" {
			authQuery = fmt.Sprintf("key=%s", url.QueryEscape(key.Value.GetValue()))
		}
		projectNumber := key.VertexKeyConfig.ProjectNumber.GetValue()
		if schemas.IsAllDigitsASCII(request.Model) && projectNumber == "" {
			return nil, providerUtils.NewConfigurationError("project number is not set for fine-tuned models")
		}
		completeURL = getCompleteURLForGeminiEndpoint(request.Model, region, projectID, projectNumber, ":countTokens")
	}
	if completeURL == "" {
		return nil, providerUtils.NewConfigurationError(fmt.Sprintf("count tokens is not supported for model: %s", request.Model))
	}

	req := fasthttp.AcquireRequest()
	resp := fasthttp.AcquireResponse()
	defer fasthttp.ReleaseRequest(req)
	// resp is released here unless the large-response path below gives up ownership.
	respOwned := true
	defer func() {
		if respOwned {
			fasthttp.ReleaseResponse(resp)
		}
	}()

	req.Header.SetMethod(http.MethodPost)
	req.Header.SetContentType("application/json")
	providerUtils.SetExtraHeaders(ctx, req, provider.networkConfig.ExtraHeaders, []string{anthropic.AnthropicBetaHeader})

	if authQuery != "" {
		completeURL = fmt.Sprintf("%s?%s", completeURL, authQuery)
	} else {
		tokenSource, err := getAuthTokenSource(key)
		if err != nil {
			return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err)
		}
		token, err := tokenSource.Token()
		if err != nil {
			return nil, providerUtils.NewBifrostOperationError("error getting token", err)
		}
		req.Header.Set("Authorization", "Bearer "+token.AccessToken)
	}
	req.SetRequestURI(completeURL)

	usedLargePayloadBody := providerUtils.ApplyLargePayloadRequestBody(ctx, req)
	if !usedLargePayloadBody {
		req.SetBody(jsonBody)
	}

	// Make the request with optional large response streaming.
	activeClient := providerUtils.PrepareResponseStreaming(ctx, provider.client, resp)
	latency, bifrostErr, wait := providerUtils.MakeRequestWithContext(ctx, activeClient, req, resp)
	defer wait()
	if bifrostErr != nil {
		return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
	}
	if usedLargePayloadBody {
		providerUtils.DrainLargePayloadRemainder(ctx)
	}
	ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerUtils.ExtractProviderResponseHeaders(resp))

	if resp.StatusCode() != fasthttp.StatusOK {
		providerUtils.MaterializeStreamErrorBody(ctx, resp)
		// Drop the cached client on auth failures.
		if resp.StatusCode() == fasthttp.StatusUnauthorized || resp.StatusCode() == fasthttp.StatusForbidden {
			removeVertexClient(key.VertexKeyConfig.AuthCredentials.GetValue())
		}
		return nil, providerUtils.EnrichError(ctx, parseVertexError(resp), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
	}

	responseBody, isLargeResp, decodeErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, provider.logger)
	if decodeErr != nil {
		return nil, providerUtils.EnrichError(ctx, decodeErr, jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
	}
	if isLargeResp {
		respOwned = false
		return &schemas.BifrostCountTokensResponse{
			ExtraFields: schemas.BifrostResponseExtraFields{
				Latency:                 latency.Milliseconds(),
				ProviderResponseHeaders: providerUtils.ExtractProviderResponseHeaders(resp),
			},
		}, nil
	}

	if isAnthropic {
		anthropicResponse := &anthropic.AnthropicCountTokensResponse{}
		rawRequest, rawResponse, bifrostErr := providerUtils.HandleProviderResponse(responseBody, anthropicResponse, jsonBody, providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse))
		if bifrostErr != nil {
			return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonBody, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse)
		}
		response := anthropicResponse.ToBifrostCountTokensResponse(request.Model)
		response.ExtraFields.Latency = latency.Milliseconds()
		response.ExtraFields.ProviderResponseHeaders = providerUtils.ExtractProviderResponseHeaders(resp)
		if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) {
			response.ExtraFields.RawRequest = rawRequest
		}
		if providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) {
			response.ExtraFields.RawResponse = rawResponse
		}
		return response, nil
	}

	vertexResponse := VertexCountTokensResponse{}
	rawRequest, rawResponse, bifrostErr := providerUtils.HandleProviderResponse(responseBody, &vertexResponse, jsonBody, providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse))
	if bifrostErr != nil {
		return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonBody, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse)
	}
	response := vertexResponse.ToBifrostCountTokensResponse(request.Model)
	response.ExtraFields.Latency = latency.Milliseconds()
	response.ExtraFields.ProviderResponseHeaders = providerUtils.ExtractProviderResponseHeaders(resp)
	if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) {
		response.ExtraFields.RawRequest = rawRequest
	}
	if providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) {
		response.ExtraFields.RawResponse = rawResponse
	}
	return response, nil
}
// ContainerCreate always fails for Vertex: it returns a nil response and an
// unsupported-operation error tagged with this provider's key.
func (provider *VertexProvider) ContainerCreate(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostContainerCreateRequest) (*schemas.BifrostContainerCreateResponse, *schemas.BifrostError) {
	providerKey := provider.GetProviderKey()
	return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerCreateRequest, providerKey)
}
// ContainerList always fails for Vertex: it returns a nil response and an
// unsupported-operation error tagged with this provider's key.
func (provider *VertexProvider) ContainerList(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostContainerListRequest) (*schemas.BifrostContainerListResponse, *schemas.BifrostError) {
	providerKey := provider.GetProviderKey()
	return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerListRequest, providerKey)
}
// ContainerRetrieve always fails for Vertex: it returns a nil response and an
// unsupported-operation error tagged with this provider's key.
func (provider *VertexProvider) ContainerRetrieve(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostContainerRetrieveRequest) (*schemas.BifrostContainerRetrieveResponse, *schemas.BifrostError) {
	providerKey := provider.GetProviderKey()
	return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerRetrieveRequest, providerKey)
}
// ContainerDelete always fails for Vertex: it returns a nil response and an
// unsupported-operation error tagged with this provider's key.
func (provider *VertexProvider) ContainerDelete(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostContainerDeleteRequest) (*schemas.BifrostContainerDeleteResponse, *schemas.BifrostError) {
	providerKey := provider.GetProviderKey()
	return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerDeleteRequest, providerKey)
}
// ContainerFileCreate always fails for Vertex: it returns a nil response and
// an unsupported-operation error tagged with this provider's key.
func (provider *VertexProvider) ContainerFileCreate(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostContainerFileCreateRequest) (*schemas.BifrostContainerFileCreateResponse, *schemas.BifrostError) {
	providerKey := provider.GetProviderKey()
	return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerFileCreateRequest, providerKey)
}
// ContainerFileList always fails for Vertex: it returns a nil response and an
// unsupported-operation error tagged with this provider's key.
func (provider *VertexProvider) ContainerFileList(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostContainerFileListRequest) (*schemas.BifrostContainerFileListResponse, *schemas.BifrostError) {
	providerKey := provider.GetProviderKey()
	return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerFileListRequest, providerKey)
}
// ContainerFileRetrieve always fails for Vertex: it returns a nil response and
// an unsupported-operation error tagged with this provider's key.
func (provider *VertexProvider) ContainerFileRetrieve(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostContainerFileRetrieveRequest) (*schemas.BifrostContainerFileRetrieveResponse, *schemas.BifrostError) {
	providerKey := provider.GetProviderKey()
	return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerFileRetrieveRequest, providerKey)
}
// ContainerFileContent always fails for Vertex: it returns a nil response and
// an unsupported-operation error tagged with this provider's key.
func (provider *VertexProvider) ContainerFileContent(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostContainerFileContentRequest) (*schemas.BifrostContainerFileContentResponse, *schemas.BifrostError) {
	providerKey := provider.GetProviderKey()
	return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerFileContentRequest, providerKey)
}
// ContainerFileDelete always fails for Vertex: it returns a nil response and
// an unsupported-operation error tagged with this provider's key.
func (provider *VertexProvider) ContainerFileDelete(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostContainerFileDeleteRequest) (*schemas.BifrostContainerFileDeleteResponse, *schemas.BifrostError) {
	providerKey := provider.GetProviderKey()
	return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerFileDeleteRequest, providerKey)
}
// Passthrough forwards an arbitrary client request to the Vertex AI v1 API and
// returns the provider's status code, headers and body unchanged apart from the
// rewrites described below.
//
// Request rewriting:
//   - leading "/v1" segments are stripped from the path so the version is not
//     duplicated;
//   - when the path contains "/locations/", its region (and project, if a
//     "/projects/" segment is present) is replaced with the key's configured
//     values; otherwise "/projects/{project}/locations/{region}" is prepended;
//   - for JSON bodies, project and location segments are replaced with the
//     key's values and short "models/X" names are expanded to the full
//     publishers/google resource path.
//
// Authentication uses the API key as a "key" query parameter only when the key
// has a value and the path targets "publishers/google"; every other case uses
// an OAuth2 bearer token. Client-supplied Authorization and
// Proxy-Authorization headers are never forwarded. An empty key region is
// treated as "global".
//
// Content-Encoding and Content-Length are removed from the returned headers.
// On a 401/403 response, and on token acquisition failures, the cached Vertex
// client for the key's credentials is removed.
func (provider *VertexProvider) Passthrough(
	ctx *schemas.BifrostContext,
	key schemas.Key,
	req *schemas.BifrostPassthroughRequest,
) (*schemas.BifrostPassthroughResponse, *schemas.BifrostError) {
	projectID := strings.TrimSpace(key.VertexKeyConfig.ProjectID.GetValue())
	if projectID == "" {
		return nil, providerUtils.NewConfigurationError("project ID is not set")
	}

	keyRegion := key.VertexKeyConfig.Region.GetValue()
	if keyRegion == "" {
		keyRegion = "global"
	}

	var baseURL string
	if keyRegion == "global" {
		baseURL = "https://aiplatform.googleapis.com/v1"
	} else {
		baseURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1", keyRegion)
	}

	// Normalize path: remove leading /v1 or /v1/ to avoid duplicate version segments (e.g. /v1/v1/...)
	path := req.Path
	for strings.HasPrefix(path, "/v1/") || path == "/v1" {
		path = strings.TrimPrefix(path, "/v1/")
		path = strings.TrimPrefix(path, "/v1")
		if path != "" && !strings.HasPrefix(path, "/") {
			path = "/" + path
		}
	}

	// Replace region in path with key's configured region (client path may have different region)
	if strings.Contains(path, "/locations/") {
		path = vertexLocationsPathRe.ReplaceAllString(path, "/locations/"+keyRegion)
		if strings.Contains(path, "/projects/") {
			path = vertexProjectsPathRe.ReplaceAllString(path, "/projects/"+projectID)
		}
	} else {
		// No location in the client path: prefix it with projects/{project}/locations/{region}.
		path = fmt.Sprintf("/projects/%s/locations/%s%s", projectID, keyRegion, path)
	}

	requestURL := baseURL + path
	if req.RawQuery != "" {
		requestURL += "?" + req.RawQuery
	}

	// Only use API key for Google publisher endpoints; Anthropic/Mistral/OpenAPI-style paths require OAuth.
	authQuery := ""
	if key.Value.GetValue() != "" && strings.Contains(path, "publishers/google") {
		authQuery = fmt.Sprintf("key=%s", url.QueryEscape(key.Value.GetValue()))
	}

	// Prepare fasthttp request
	fasthttpReq := fasthttp.AcquireRequest()
	resp := fasthttp.AcquireResponse()
	defer fasthttp.ReleaseResponse(resp)
	defer fasthttp.ReleaseRequest(fasthttpReq)

	fasthttpReq.Header.SetMethod(req.Method)

	// If auth query is set, add it to the URL; otherwise use OAuth2
	if authQuery != "" {
		if strings.Contains(requestURL, "?") {
			requestURL += "&" + authQuery
		} else {
			requestURL += "?" + authQuery
		}
	} else {
		tokenSource, err := getAuthTokenSource(key)
		if err != nil {
			removeVertexClient(key.VertexKeyConfig.AuthCredentials.GetValue())
			return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err)
		}
		token, err := tokenSource.Token()
		if err != nil {
			removeVertexClient(key.VertexKeyConfig.AuthCredentials.GetValue())
			return nil, providerUtils.NewBifrostOperationError("error getting token", err)
		}
		fasthttpReq.Header.Set("Authorization", "Bearer "+token.AccessToken)
	}

	fasthttpReq.SetRequestURI(requestURL)

	// Set extra headers from provider network config
	providerUtils.SetExtraHeaders(ctx, fasthttpReq, provider.networkConfig.ExtraHeaders, nil)

	// Set safe headers from client request; credentials are always ours, never the client's.
	for k, v := range req.SafeHeaders {
		if strings.EqualFold(k, "authorization") || strings.EqualFold(k, "proxy-authorization") {
			continue
		}
		fasthttpReq.Header.Set(k, v)
	}

	if len(req.Body) > 0 && strings.Contains(strings.ToLower(string(fasthttpReq.Header.ContentType())), "application/json") {
		// Replace fully-qualified model paths that have placeholder project/location
		// e.g. "projects/None/locations/None/publishers/..." -> "projects/real-id/locations/real-region/..."
		bodyStr := vertexBodyProjectsRe.ReplaceAllString(string(req.Body), "${1}projects/"+projectID)
		// The locations regex only stops at '/', so inside JSON a match can run
		// past the closing quote of a string value that ends with the location
		// (e.g. "projects/p/locations/l"). Replace only the location segment and
		// keep everything from the quote onwards so the JSON stays intact.
		bodyStr = vertexLocationsPathRe.ReplaceAllStringFunc(bodyStr, func(match string) string {
			quote := strings.IndexByte(match, '"')
			switch {
			case quote < 0:
				return "/locations/" + keyRegion
			case quote == len("/locations/"):
				// No location segment before the quote; nothing to replace.
				return match
			default:
				return "/locations/" + keyRegion + match[quote:]
			}
		})
		// Expand short-form model names: "models/X" -> "projects/P/locations/L/publishers/google/models/X"
		bodyStr = vertexShortModelRe.ReplaceAllString(bodyStr,
			fmt.Sprintf(`"projects/%s/locations/%s/publishers/google/$1"`, projectID, keyRegion))
		fasthttpReq.SetBodyString(bodyStr)
	} else if len(req.Body) > 0 {
		fasthttpReq.SetBody(req.Body)
	}

	// Execute request
	latency, bifrostErr, wait := providerUtils.MakeRequestWithContext(ctx, provider.client, fasthttpReq, resp)
	defer wait()
	if bifrostErr != nil {
		return nil, bifrostErr
	}

	// Remove client from pool for authentication/authorization errors
	if resp.StatusCode() == fasthttp.StatusUnauthorized || resp.StatusCode() == fasthttp.StatusForbidden {
		removeVertexClient(key.VertexKeyConfig.AuthCredentials.GetValue())
	}

	headers := providerUtils.ExtractProviderResponseHeaders(resp)
	ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, headers)

	body, err := providerUtils.CheckAndDecodeBody(resp)
	if err != nil {
		return nil, providerUtils.NewBifrostOperationError("failed to decode response body", err)
	}
	// Copy the body: resp goes back to the fasthttp pool when this function
	// returns, so the caller must not be handed a slice into its buffer.
	body = append([]byte(nil), body...)

	// The body is returned decoded, so the original encoding/length no longer apply.
	for k := range headers {
		if strings.EqualFold(k, "Content-Encoding") || strings.EqualFold(k, "Content-Length") {
			delete(headers, k)
		}
	}

	bifrostResponse := &schemas.BifrostPassthroughResponse{
		StatusCode: resp.StatusCode(),
		Headers:    headers,
		Body:       body,
	}
	bifrostResponse.ExtraFields.ProviderResponseHeaders = headers
	bifrostResponse.ExtraFields.Latency = latency.Milliseconds()
	if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) {
		providerUtils.ParseAndSetRawRequestIfJSON(fasthttpReq, &bifrostResponse.ExtraFields)
	}
	return bifrostResponse, nil
}
// PassthroughStream forwards an arbitrary client request to the Vertex AI v1
// API and streams the provider's response body back as raw chunks on the
// returned channel.
//
// The path, auth and JSON-body rewriting mirror Passthrough: leading "/v1"
// segments are stripped, project/region in the path and JSON body are replaced
// with the key's configured values (an empty region means "global"), short
// "models/X" names are expanded, and the API key is only used for
// "publishers/google" paths while everything else uses an OAuth2 bearer token.
//
// Each chunk carries the response status code, headers and a copy of the bytes
// read. The stream finishes successfully either on EOF or as soon as the
// terminal detector recognises a terminal marker in a chunk; in both cases the
// post-hook runner receives a final response whose body is the captured stream,
// or nil if more than maxStreamPassthroughCaptureBytes were received. Read
// errors are reported through ProcessAndSendError, and context cancellation or
// deadline expiry is reported by the deferred handlers before the channel is
// closed.
//
// An error is returned (and no channel) when the project ID is missing, a
// token cannot be obtained, the request fails, or the provider returns no
// stream body.
func (provider *VertexProvider) PassthroughStream(
	ctx *schemas.BifrostContext,
	postHookRunner schemas.PostHookRunner,
	postHookSpanFinalizer func(context.Context),
	key schemas.Key,
	req *schemas.BifrostPassthroughRequest,
) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
	projectID := strings.TrimSpace(key.VertexKeyConfig.ProjectID.GetValue())
	if projectID == "" {
		return nil, providerUtils.NewConfigurationError("project ID is not set")
	}

	keyRegion := key.VertexKeyConfig.Region.GetValue()
	if keyRegion == "" {
		keyRegion = "global"
	}

	var baseURL string
	if keyRegion == "global" {
		baseURL = "https://aiplatform.googleapis.com/v1"
	} else {
		baseURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1", keyRegion)
	}

	// Normalize path: remove leading /v1 or /v1/ to avoid duplicate version segments.
	path := req.Path
	for strings.HasPrefix(path, "/v1/") || path == "/v1" {
		path = strings.TrimPrefix(path, "/v1/")
		path = strings.TrimPrefix(path, "/v1")
		if path != "" && !strings.HasPrefix(path, "/") {
			path = "/" + path
		}
	}

	// Replace region and project in path with key's configured values.
	if strings.Contains(path, "/locations/") {
		path = vertexLocationsPathRe.ReplaceAllString(path, "/locations/"+keyRegion)
		if strings.Contains(path, "/projects/") {
			path = vertexProjectsPathRe.ReplaceAllString(path, "/projects/"+projectID)
		}
	} else {
		path = fmt.Sprintf("/projects/%s/locations/%s%s", projectID, keyRegion, path)
	}

	requestURL := baseURL + path
	if req.RawQuery != "" {
		requestURL += "?" + req.RawQuery
	}

	fasthttpReq := fasthttp.AcquireRequest()
	resp := fasthttp.AcquireResponse()
	resp.StreamBody = true
	defer fasthttp.ReleaseRequest(fasthttpReq)

	fasthttpReq.Header.SetMethod(req.Method)

	// Only use API key for Google publisher endpoints; Anthropic/Mistral/OpenAPI-style paths require OAuth.
	authQuery := ""
	if key.Value.GetValue() != "" && strings.Contains(path, "publishers/google") {
		authQuery = fmt.Sprintf("key=%s", url.QueryEscape(key.Value.GetValue()))
	}
	if authQuery != "" {
		if strings.Contains(requestURL, "?") {
			requestURL += "&" + authQuery
		} else {
			requestURL += "?" + authQuery
		}
	} else {
		tokenSource, err := getAuthTokenSource(key)
		if err != nil {
			removeVertexClient(key.VertexKeyConfig.AuthCredentials.GetValue())
			providerUtils.ReleaseStreamingResponse(resp)
			return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err)
		}
		token, err := tokenSource.Token()
		if err != nil {
			removeVertexClient(key.VertexKeyConfig.AuthCredentials.GetValue())
			providerUtils.ReleaseStreamingResponse(resp)
			return nil, providerUtils.NewBifrostOperationError("error getting token", err)
		}
		fasthttpReq.Header.Set("Authorization", "Bearer "+token.AccessToken)
	}

	fasthttpReq.SetRequestURI(requestURL)
	providerUtils.SetExtraHeaders(ctx, fasthttpReq, provider.networkConfig.ExtraHeaders, nil)

	// Forward the client's safe headers; credentials are always ours, never the client's.
	for k, v := range req.SafeHeaders {
		if strings.EqualFold(k, "authorization") || strings.EqualFold(k, "proxy-authorization") {
			continue
		}
		fasthttpReq.Header.Set(k, v)
	}

	if len(req.Body) > 0 && strings.Contains(strings.ToLower(string(fasthttpReq.Header.ContentType())), "application/json") {
		bodyStr := vertexBodyProjectsRe.ReplaceAllString(string(req.Body), "${1}projects/"+projectID)
		// The locations regex only stops at '/', so inside JSON a match can run
		// past the closing quote of a string value that ends with the location
		// (e.g. "projects/p/locations/l"). Replace only the location segment and
		// keep everything from the quote onwards so the JSON stays intact.
		bodyStr = vertexLocationsPathRe.ReplaceAllStringFunc(bodyStr, func(match string) string {
			quote := strings.IndexByte(match, '"')
			switch {
			case quote < 0:
				return "/locations/" + keyRegion
			case quote == len("/locations/"):
				// No location segment before the quote; nothing to replace.
				return match
			default:
				return "/locations/" + keyRegion + match[quote:]
			}
		})
		// Expand short-form model names: "models/X" -> "projects/P/locations/L/publishers/google/models/X"
		bodyStr = vertexShortModelRe.ReplaceAllString(bodyStr,
			fmt.Sprintf(`"projects/%s/locations/%s/publishers/google/$1"`, projectID, keyRegion))
		fasthttpReq.SetBodyString(bodyStr)
	} else if len(req.Body) > 0 {
		fasthttpReq.SetBody(req.Body)
	}

	activeClient := providerUtils.PrepareResponseStreaming(ctx, provider.streamingClient, resp)
	if err := activeClient.Do(fasthttpReq, resp); err != nil {
		providerUtils.ReleaseStreamingResponse(resp)
		if errors.Is(err, context.Canceled) {
			return nil, &schemas.BifrostError{
				IsBifrostError: false,
				Error: &schemas.ErrorField{
					Type:    schemas.Ptr(schemas.RequestCancelled),
					Message: schemas.ErrRequestCancelled,
					Error:   err,
				},
			}
		}
		if errors.Is(err, fasthttp.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) {
			return nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err)
		}
		return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err)
	}

	// Remove client from pool for authentication/authorization errors.
	if resp.StatusCode() == fasthttp.StatusUnauthorized || resp.StatusCode() == fasthttp.StatusForbidden {
		removeVertexClient(key.VertexKeyConfig.AuthCredentials.GetValue())
	}

	headers := providerUtils.ExtractProviderResponseHeaders(resp)
	ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, headers)

	bodyStream := resp.BodyStream()
	if bodyStream == nil {
		providerUtils.ReleaseStreamingResponse(resp)
		return nil, providerUtils.NewBifrostOperationError(
			"provider returned an empty stream body",
			fmt.Errorf("provider returned an empty stream body"))
	}

	// Set stream idle timeout from provider config.
	providerUtils.SetStreamIdleTimeoutIfEmpty(ctx, provider.networkConfig.StreamIdleTimeoutInSeconds)
	// Wrap body with idle timeout to detect stalled streams.
	rawBodyStream := bodyStream
	bodyStream, stopIdleTimeout := providerUtils.NewIdleTimeoutReader(bodyStream, rawBodyStream, providerUtils.GetStreamIdleTimeout(ctx))
	// Cancellation must close the raw stream to unblock reads.
	stopCancellation := providerUtils.SetupStreamCancellation(ctx, rawBodyStream, provider.logger)

	extraFields := schemas.BifrostResponseExtraFields{}
	statusCode := resp.StatusCode()
	if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) {
		providerUtils.ParseAndSetRawRequestIfJSON(fasthttpReq, &extraFields)
	}

	ch := make(chan *schemas.BifrostStreamChunk, schemas.DefaultStreamBufferSize)
	go func() {
		// Deferred calls run in reverse order: stop cancellation and idle
		// watchers, release the response, report cancel/timeout and close the
		// channel, then make sure the span finalizer has been called.
		defer providerUtils.EnsureStreamFinalizerCalled(ctx, postHookSpanFinalizer)
		defer func() {
			if ctx.Err() == context.Canceled {
				providerUtils.HandleStreamCancellation(ctx, postHookRunner, ch, provider.logger, postHookSpanFinalizer)
			} else if ctx.Err() == context.DeadlineExceeded {
				providerUtils.HandleStreamTimeout(ctx, postHookRunner, ch, provider.logger, postHookSpanFinalizer)
			}
			close(ch)
		}()
		defer providerUtils.ReleaseStreamingResponse(resp)
		defer stopIdleTimeout()
		defer stopCancellation()

		streamStart := time.Now()
		fullResponseBody := make([]byte, 0, maxStreamPassthroughCaptureBytes)
		fullResponseBodyTruncated := false
		terminalDetector := &providerUtils.StreamTerminalDetector{}

		// finishStream marks the stream as ended and hands the captured body
		// (nil when the capture limit was exceeded) to the post-hook runner.
		finishStream := func() {
			ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
			extraFields.Latency = time.Since(streamStart).Milliseconds()
			var capturedBody []byte
			if !fullResponseBodyTruncated {
				capturedBody = append([]byte(nil), fullResponseBody...)
			}
			finalResp := &schemas.BifrostResponse{
				PassthroughResponse: &schemas.BifrostPassthroughResponse{
					StatusCode:  statusCode,
					Headers:     headers,
					Body:        capturedBody,
					ExtraFields: extraFields,
				},
			}
			postHookRunner(ctx, finalResp, nil)
		}

		buf := make([]byte, 4096)
		for {
			n, readErr := bodyStream.Read(buf)
			if n > 0 {
				// buf is reused on the next read, so each chunk gets its own copy.
				chunk := make([]byte, n)
				copy(chunk, buf[:n])

				// Capture up to maxStreamPassthroughCaptureBytes for the final
				// response; once the limit is exceeded the capture is abandoned.
				if !fullResponseBodyTruncated {
					remaining := maxStreamPassthroughCaptureBytes - len(fullResponseBody)
					if n <= remaining {
						fullResponseBody = append(fullResponseBody, chunk...)
					} else {
						fullResponseBody = append(fullResponseBody, chunk[:remaining]...)
						fullResponseBodyTruncated = true
					}
				}

				select {
				case ch <- &schemas.BifrostStreamChunk{
					BifrostPassthroughResponse: &schemas.BifrostPassthroughResponse{
						StatusCode:  statusCode,
						Headers:     headers,
						Body:        chunk,
						ExtraFields: extraFields,
					},
				}:
				case <-ctx.Done():
					return
				}

				// Vertex streamGenerateContent passthrough can emit terminal markers
				// (finishReason) before the underlying HTTP body is closed.
				// Finalize as success once this appears to avoid hanging clients.
				if terminalDetector.ObserveChunk(chunk) {
					finishStream()
					return
				}
			}
			if readErr == io.EOF {
				finishStream()
				return
			}
			if readErr != nil {
				if ctx.Err() != nil {
					return // let defer handle cancel/timeout
				}
				ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
				extraFields.Latency = time.Since(streamStart).Milliseconds()
				providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, ch, provider.logger, postHookSpanFinalizer)
				return
			}
		}
	}()
	return ch, nil
}