3214 lines
133 KiB
Go
3214 lines
133 KiB
Go
package vertex
|
|
|
|
import (
|
|
"context"
|
|
"crypto/sha256"
|
|
"encoding/base64"
|
|
"encoding/hex"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"net/url"
|
|
"regexp"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/valyala/fasthttp"
|
|
"golang.org/x/oauth2"
|
|
"golang.org/x/oauth2/google"
|
|
|
|
"github.com/bytedance/sonic"
|
|
"github.com/maximhq/bifrost/core/providers/anthropic"
|
|
"github.com/maximhq/bifrost/core/providers/gemini"
|
|
"github.com/maximhq/bifrost/core/providers/openai"
|
|
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
|
|
schemas "github.com/maximhq/bifrost/core/schemas"
|
|
)
|
|
|
|
// VertexError mirrors the JSON error envelope returned by Vertex AI endpoints:
// a top-level "error" object carrying a numeric code, a human-readable message
// and a status string. Callers decode failed (non-200) response bodies into it
// and surface Error.Message to the user.
type VertexError struct {
	Error struct {
		Code    int    `json:"code"`
		Message string `json:"message"`
		Status  string `json:"status"`
	} `json:"error"`
}
|
|
|
|
// vertexClientPool caches authenticated Vertex HTTP clients so they are not
// created and authenticated again for every request. Entries are keyed by
// getClientKey (a SHA-256 digest of the auth credentials); sync.Map gives
// atomic access without explicit locking.
var vertexClientPool sync.Map

// Regular expressions used to rewrite Vertex resource paths in request URLs and
// JSON bodies. They are compiled once, at package initialisation.
var (
	// vertexLocationsPathRe matches /locations/{region} in Vertex API paths for region replacement.
	vertexLocationsPathRe = regexp.MustCompile(`/locations/[^/]+`)

	// vertexProjectsPathRe matches /projects/{project} in Vertex API paths
	// (presumably for project replacement, mirroring vertexLocationsPathRe).
	vertexProjectsPathRe = regexp.MustCompile(`/projects/[^/]+`)

	// vertexBodyProjectsRe matches projects/{project} inside JSON string values, where
	// the path may start right after the opening quote ("projects/X) or mid-path
	// (/projects/X). Group 1 captures the preceding '"' or '/' so a replacement such
	// as "${1}projects/new" keeps it.
	vertexBodyProjectsRe = regexp.MustCompile(`(["/])projects/[^/"]+`)

	// vertexShortModelRe matches quoted short-form model names like "models/X" in JSON
	// bodies that need expanding to the full Vertex resource path. Group 1 is the
	// unquoted models/X value.
	vertexShortModelRe = regexp.MustCompile(`"(models/[^/"]+)"`)
)

// maxStreamPassthroughCaptureBytes is 1 MiB. NOTE(review): presumably the cap on how
// many bytes of a passthrough stream are captured; its use is outside this view — confirm.
const maxStreamPassthroughCaptureBytes = 1024 * 1024
|
|
|
|
// getClientKey derives the vertexClientPool cache key for a set of auth credentials.
// The credentials are hashed with SHA-256 and returned as lowercase hex, so the raw
// secret itself is never used as a map key.
func getClientKey(authCredentials string) string {
	digest := sha256.New()
	// hash.Hash.Write never returns an error, so the result is safe to discard.
	_, _ = digest.Write([]byte(authCredentials))
	return hex.EncodeToString(digest.Sum(nil))
}
|
|
|
|
// removeVertexClient removes a specific client from the pool.
|
|
// This should be called when:
|
|
// - API returns authentication/authorization errors (401, 403)
|
|
// - Auth client creation fails
|
|
// - Network errors that might indicate credential issues
|
|
// This ensures we don't keep using potentially invalid clients.
|
|
func removeVertexClient(authCredentials string) {
|
|
clientKey := getClientKey(authCredentials)
|
|
vertexClientPool.Delete(clientKey)
|
|
}
|
|
|
|
// VertexProvider implements the Provider interface for Google's Vertex AI API.
|
|
type VertexProvider struct {
|
|
logger schemas.Logger // Logger for provider operations
|
|
client *fasthttp.Client // HTTP client for unary API requests (ReadTimeout bounds overall response)
|
|
streamingClient *fasthttp.Client // HTTP client for streaming API requests (no ReadTimeout; idle governed by NewIdleTimeoutReader)
|
|
networkConfig schemas.NetworkConfig // Network configuration including extra headers
|
|
sendBackRawRequest bool // Whether to include raw request in BifrostResponse
|
|
sendBackRawResponse bool // Whether to include raw response in BifrostResponse
|
|
}
|
|
|
|
// NewVertexProvider creates a new Vertex provider instance.
|
|
// It initializes the HTTP client with the provided configuration and sets up response pools.
|
|
// The client is configured with timeouts, concurrency limits, and optional proxy settings.
|
|
func NewVertexProvider(config *schemas.ProviderConfig, logger schemas.Logger) (*VertexProvider, error) {
|
|
config.CheckAndSetDefaults()
|
|
requestTimeout := time.Second * time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds)
|
|
client := &fasthttp.Client{
|
|
ReadTimeout: requestTimeout,
|
|
WriteTimeout: requestTimeout,
|
|
MaxConnsPerHost: config.NetworkConfig.MaxConnsPerHost,
|
|
MaxIdleConnDuration: 30 * time.Second,
|
|
MaxConnWaitTimeout: requestTimeout,
|
|
MaxConnDuration: time.Second * time.Duration(schemas.DefaultMaxConnDurationInSeconds),
|
|
ConnPoolStrategy: fasthttp.FIFO,
|
|
}
|
|
client = providerUtils.ConfigureProxy(client, config.ProxyConfig, logger)
|
|
client = providerUtils.ConfigureDialer(client)
|
|
client = providerUtils.ConfigureTLS(client, config.NetworkConfig, logger)
|
|
streamingClient := providerUtils.BuildStreamingClient(client)
|
|
return &VertexProvider{
|
|
logger: logger,
|
|
client: client,
|
|
streamingClient: streamingClient,
|
|
networkConfig: config.NetworkConfig,
|
|
sendBackRawRequest: config.SendBackRawRequest,
|
|
sendBackRawResponse: config.SendBackRawResponse,
|
|
}, nil
|
|
}
|
|
|
|
const cloudPlatformScope = "https://www.googleapis.com/auth/cloud-platform"
|
|
|
|
// getAuthTokenSource returns an authenticated token source for Vertex AI API requests.
|
|
// It uses the default credentials if no auth credentials are provided.
|
|
// It uses the JWT config if auth credentials are provided.
|
|
// It returns an error if the token source creation fails.
|
|
func getAuthTokenSource(key schemas.Key) (oauth2.TokenSource, error) {
|
|
authCredentials := key.VertexKeyConfig.AuthCredentials
|
|
var tokenSource oauth2.TokenSource
|
|
if authCredentials.GetValue() == "" {
|
|
creds, err := google.FindDefaultCredentials(context.Background(), cloudPlatformScope)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to find default credentials in environment: %w", err)
|
|
}
|
|
tokenSource = creds.TokenSource
|
|
} else {
|
|
jsonData := []byte(authCredentials.GetValue())
|
|
|
|
// Peek at the JSON to detect the "type" field
|
|
var meta struct {
|
|
Type string `json:"type"`
|
|
}
|
|
if err := sonic.Unmarshal(jsonData, &meta); err != nil {
|
|
return nil, fmt.Errorf("failed to parse auth credentials JSON: %w", err)
|
|
}
|
|
|
|
// Map string to google.CredentialsType with a security whitelist
|
|
var credType google.CredentialsType
|
|
switch meta.Type {
|
|
case string(google.ServiceAccount):
|
|
credType = google.ServiceAccount
|
|
case string(google.ImpersonatedServiceAccount):
|
|
credType = google.ImpersonatedServiceAccount
|
|
case string(google.AuthorizedUser):
|
|
credType = google.AuthorizedUser
|
|
case string(google.ExternalAccount):
|
|
credType = google.ExternalAccount
|
|
case string(google.ExternalAccountAuthorizedUser):
|
|
credType = google.ExternalAccountAuthorizedUser
|
|
case "":
|
|
return nil, fmt.Errorf("invalid google auth credentials: missing 'type'")
|
|
default:
|
|
return nil, fmt.Errorf("unsupported or restricted credential type: %s", meta.Type)
|
|
}
|
|
|
|
conf, err := google.CredentialsFromJSONWithType(context.Background(), jsonData, credType, cloudPlatformScope)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create credentials from auth credentials JSON: %w", err)
|
|
}
|
|
tokenSource = conf.TokenSource
|
|
}
|
|
return tokenSource, nil
|
|
}
|
|
|
|
// GetProviderKey returns the provider identifier for Vertex.
|
|
func (provider *VertexProvider) GetProviderKey() schemas.ModelProvider {
|
|
return schemas.Vertex
|
|
}
|
|
|
|
// listModelsByKey performs a list models request for a single key.
|
|
// Returns the response and latency, or an error if the request fails.
|
|
//
|
|
// The logic is:
|
|
// 1. If deployments or allowedModels are configured, return those (no API call needed)
|
|
// 2. Otherwise, fetch from the publishers.models.list API endpoint (Model Garden)
|
|
func (provider *VertexProvider) listModelsByKey(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) {
|
|
region := key.VertexKeyConfig.Region.GetValue()
|
|
if region == "" {
|
|
return nil, providerUtils.NewConfigurationError("region is not set in key config")
|
|
}
|
|
|
|
deployments := key.Aliases
|
|
allowedModels := key.Models
|
|
|
|
if !request.Unfiltered && (allowedModels.IsEmpty() && len(deployments) == 0 || key.BlacklistedModels.IsBlockAll()) {
|
|
return &schemas.BifrostListModelsResponse{Data: make([]schemas.Model, 0)}, nil
|
|
}
|
|
|
|
// If deployments or allowedModels are configured, return those directly without API call
|
|
// Skip this fast path when Unfiltered is set so the full Vertex catalog can be retrieved
|
|
if !request.Unfiltered && (len(deployments) > 0 || allowedModels.IsRestricted()) {
|
|
return buildResponseFromConfig(deployments, allowedModels, key.BlacklistedModels), nil
|
|
}
|
|
|
|
// No deployments configured - fetch from Model Garden API
|
|
var host string
|
|
if region == "global" {
|
|
host = "aiplatform.googleapis.com"
|
|
} else {
|
|
host = fmt.Sprintf("%s-aiplatform.googleapis.com", region)
|
|
}
|
|
|
|
// Accumulate all publisher models from paginated requests
|
|
var allPublisherModels []VertexPublisherModel
|
|
var rawRequests []interface{}
|
|
var rawResponses []interface{}
|
|
pageToken := ""
|
|
|
|
// Getting oauth2 token
|
|
tokenSource, err := getAuthTokenSource(key)
|
|
if err != nil {
|
|
return nil, providerUtils.NewBifrostOperationError("error creating auth token source (api key auth not supported for list models)", err)
|
|
}
|
|
token, err := tokenSource.Token()
|
|
if err != nil {
|
|
return nil, providerUtils.NewBifrostOperationError("error getting token (api key auth not supported for list models)", err)
|
|
}
|
|
|
|
// Iterate over all supported Vertex publishers to include Google, Anthropic, and Mistral models
|
|
publishers := []string{"google", "anthropic", "mistralai"}
|
|
for _, publisher := range publishers {
|
|
pageToken = ""
|
|
// Loop through all pages until no nextPageToken is returned
|
|
for {
|
|
// Build URL for publishers.models.list endpoint (Model Garden)
|
|
// Format: https://{region}-aiplatform.googleapis.com/v1beta1/publishers/{publisher}/models
|
|
requestURL := fmt.Sprintf("https://%s/v1beta1/publishers/%s/models?pageSize=%d", host, publisher, MaxPageSize)
|
|
if pageToken != "" {
|
|
requestURL = fmt.Sprintf("%s&pageToken=%s", requestURL, url.QueryEscape(pageToken))
|
|
}
|
|
|
|
// Create HTTP request for listing models
|
|
req := fasthttp.AcquireRequest()
|
|
resp := fasthttp.AcquireResponse()
|
|
|
|
req.Header.SetMethod(http.MethodGet)
|
|
req.SetRequestURI(requestURL)
|
|
req.Header.SetContentType("application/json")
|
|
providerUtils.SetExtraHeaders(ctx, req, provider.networkConfig.ExtraHeaders, nil)
|
|
req.Header.Set("Authorization", "Bearer "+token.AccessToken)
|
|
|
|
_, bifrostErr, wait := providerUtils.MakeRequestWithContext(ctx, provider.client, req, resp)
|
|
if bifrostErr != nil {
|
|
wait()
|
|
respBody := append([]byte(nil), resp.Body()...)
|
|
fasthttp.ReleaseRequest(req)
|
|
fasthttp.ReleaseResponse(resp)
|
|
// Non-Google publishers may not be available in all regions; skip on error
|
|
if publisher != "google" {
|
|
break
|
|
}
|
|
return nil, providerUtils.EnrichError(ctx, bifrostErr, nil, respBody, provider.sendBackRawRequest, provider.sendBackRawResponse)
|
|
}
|
|
ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerUtils.ExtractProviderResponseHeaders(resp))
|
|
|
|
// Handle error response
|
|
if resp.StatusCode() != fasthttp.StatusOK {
|
|
if resp.StatusCode() == fasthttp.StatusUnauthorized || resp.StatusCode() == fasthttp.StatusForbidden {
|
|
removeVertexClient(key.VertexKeyConfig.AuthCredentials.GetValue())
|
|
}
|
|
|
|
// Non-Google publishers may not be available in all regions;
|
|
// skip only on 403/404 which indicate regional unavailability.
|
|
// Surface other errors (401, 429, 5xx) so they aren't silently swallowed.
|
|
if publisher != "google" && (resp.StatusCode() == fasthttp.StatusForbidden || resp.StatusCode() == fasthttp.StatusNotFound) {
|
|
wait()
|
|
fasthttp.ReleaseRequest(req)
|
|
fasthttp.ReleaseResponse(resp)
|
|
break
|
|
}
|
|
|
|
respBody := append([]byte(nil), resp.Body()...)
|
|
statusCode := resp.StatusCode()
|
|
wait()
|
|
fasthttp.ReleaseRequest(req)
|
|
fasthttp.ReleaseResponse(resp)
|
|
|
|
var errorResp VertexError
|
|
if err := sonic.Unmarshal(respBody, &errorResp); err != nil {
|
|
return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err), nil, respBody, provider.sendBackRawRequest, provider.sendBackRawResponse)
|
|
}
|
|
return nil, providerUtils.EnrichError(ctx, providerUtils.NewProviderAPIError(errorResp.Error.Message, nil, statusCode, nil, nil), nil, respBody, provider.sendBackRawRequest, provider.sendBackRawResponse)
|
|
}
|
|
|
|
// Parse Vertex's publisher models response
|
|
var vertexResponse VertexListPublisherModelsResponse
|
|
rawRequest, rawResponse, bifrostErr := providerUtils.HandleProviderResponse(resp.Body(), &vertexResponse, nil, providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse))
|
|
if bifrostErr != nil {
|
|
respBody := append([]byte(nil), resp.Body()...)
|
|
wait()
|
|
fasthttp.ReleaseRequest(req)
|
|
fasthttp.ReleaseResponse(resp)
|
|
return nil, providerUtils.EnrichError(ctx, bifrostErr, nil, respBody, provider.sendBackRawRequest, provider.sendBackRawResponse)
|
|
}
|
|
if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) {
|
|
rawRequests = append(rawRequests, rawRequest)
|
|
}
|
|
if providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) {
|
|
rawResponses = append(rawResponses, rawResponse)
|
|
}
|
|
|
|
// Accumulate models from this page
|
|
allPublisherModels = append(allPublisherModels, vertexResponse.PublisherModels...)
|
|
|
|
wait()
|
|
fasthttp.ReleaseRequest(req)
|
|
fasthttp.ReleaseResponse(resp)
|
|
|
|
// Check if there are more pages
|
|
if vertexResponse.NextPageToken == "" {
|
|
break
|
|
}
|
|
pageToken = vertexResponse.NextPageToken
|
|
}
|
|
}
|
|
|
|
// Create aggregated response from all pages
|
|
aggregatedResponse := &VertexListPublisherModelsResponse{
|
|
PublisherModels: allPublisherModels,
|
|
}
|
|
|
|
response := aggregatedResponse.ToBifrostListModelsResponse(key.Models, key.BlacklistedModels, key.Aliases, request.Unfiltered)
|
|
|
|
if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) {
|
|
response.ExtraFields.RawRequest = rawRequests
|
|
}
|
|
|
|
if providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) {
|
|
response.ExtraFields.RawResponse = rawResponses
|
|
}
|
|
|
|
return response, nil
|
|
}
|
|
|
|
// ListModels performs a list models request to Vertex's API.
|
|
// Requests are made concurrently for improved performance.
|
|
func (provider *VertexProvider) ListModels(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) {
|
|
finalResponse, bifrostErr := providerUtils.HandleMultipleListModelsRequests(
|
|
ctx,
|
|
keys,
|
|
request,
|
|
provider.listModelsByKey,
|
|
)
|
|
if bifrostErr != nil {
|
|
return nil, bifrostErr
|
|
}
|
|
|
|
return finalResponse, nil
|
|
}
|
|
|
|
// TextCompletion is not supported by the Vertex provider.
|
|
// Returns an error indicating that text completion is not available.
|
|
func (provider *VertexProvider) TextCompletion(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (*schemas.BifrostTextCompletionResponse, *schemas.BifrostError) {
|
|
return nil, providerUtils.NewUnsupportedOperationError(schemas.TextCompletionRequest, provider.GetProviderKey())
|
|
}
|
|
|
|
// TextCompletionStream performs a streaming text completion request to Vertex's API.
|
|
// It formats the request, sends it to Vertex, and processes the response.
|
|
// Returns a channel of BifrostStreamChunk objects or an error if the request fails.
|
|
func (provider *VertexProvider) TextCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostTextCompletionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
|
|
return nil, providerUtils.NewUnsupportedOperationError(schemas.TextCompletionStreamRequest, provider.GetProviderKey())
|
|
}
|
|
|
|
// ChatCompletion performs a chat completion request to the Vertex API.
|
|
// It supports both text and image content in messages.
|
|
// Returns a BifrostResponse containing the completion results or an error if the request fails.
|
|
func (provider *VertexProvider) ChatCompletion(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError) {
|
|
jsonBody, bifrostErr := providerUtils.CheckContextAndGetRequestBody(
|
|
ctx,
|
|
request,
|
|
func() (providerUtils.RequestBodyWithExtraParams, error) {
|
|
// Format messages for Vertex API, preserving key order for prompt caching
|
|
var rawBody []byte
|
|
var extraParams map[string]interface{}
|
|
var err error
|
|
|
|
if schemas.IsAnthropicModel(request.Model) {
|
|
// Use centralized Anthropic converter
|
|
reqBody, convErr := anthropic.ToAnthropicChatRequest(ctx, request)
|
|
if convErr != nil {
|
|
return nil, convErr
|
|
}
|
|
if reqBody == nil {
|
|
return nil, fmt.Errorf("chat completion input is not provided")
|
|
}
|
|
extraParams = reqBody.GetExtraParams()
|
|
// Add provider-aware beta headers for Vertex
|
|
anthropic.AddMissingBetaHeadersToContext(ctx, reqBody, schemas.Vertex)
|
|
// Marshal to JSON bytes, preserving struct field order
|
|
rawBody, err = providerUtils.MarshalSorted(reqBody)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to marshal request body: %w", err)
|
|
}
|
|
// Add anthropic_version if not present (using sjson to preserve order)
|
|
if !providerUtils.JSONFieldExists(rawBody, "anthropic_version") {
|
|
rawBody, err = providerUtils.SetJSONField(rawBody, "anthropic_version", DefaultVertexAnthropicVersion)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to set anthropic_version: %w", err)
|
|
}
|
|
}
|
|
// Inject beta headers into body as anthropic_beta (Vertex uses body field, not HTTP header)
|
|
if betaHeaders := anthropic.FilterBetaHeadersForProvider(anthropic.MergeBetaHeaders(provider.networkConfig.ExtraHeaders, ctx), schemas.Vertex, provider.networkConfig.BetaHeaderOverrides); len(betaHeaders) > 0 {
|
|
rawBody, err = providerUtils.SetJSONField(rawBody, "anthropic_beta", betaHeaders)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to set anthropic_beta: %w", err)
|
|
}
|
|
}
|
|
// Remove model field (it's in URL for Vertex)
|
|
rawBody, err = providerUtils.DeleteJSONField(rawBody, "model")
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to delete model field: %w", err)
|
|
}
|
|
} else if schemas.IsGeminiModel(request.Model) || schemas.IsAllDigitsASCII(request.Model) {
|
|
reqBody, err := gemini.ToGeminiChatCompletionRequest(request)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if reqBody == nil {
|
|
return nil, fmt.Errorf("chat completion input is not provided")
|
|
}
|
|
extraParams = reqBody.GetExtraParams()
|
|
// Strip unsupported fields for Vertex Gemini
|
|
stripVertexGeminiUnsupportedFields(reqBody)
|
|
// Marshal to JSON bytes
|
|
rawBody, err = providerUtils.MarshalSorted(reqBody)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to marshal request body: %w", err)
|
|
}
|
|
} else {
|
|
// Use centralized OpenAI converter for non-Claude models
|
|
reqBody := openai.ToOpenAIChatRequest(ctx, request)
|
|
if reqBody == nil {
|
|
return nil, fmt.Errorf("chat completion input is not provided")
|
|
}
|
|
extraParams = reqBody.GetExtraParams()
|
|
// Marshal to JSON bytes
|
|
rawBody, err = providerUtils.MarshalSorted(reqBody)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to marshal request body: %w", err)
|
|
}
|
|
}
|
|
|
|
// Remove region field if present
|
|
rawBody, err = providerUtils.DeleteJSONField(rawBody, "region")
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to delete region field: %w", err)
|
|
}
|
|
return &VertexRawRequestBody{RawBody: rawBody, ExtraParams: extraParams}, nil
|
|
},
|
|
)
|
|
if bifrostErr != nil {
|
|
return nil, bifrostErr
|
|
}
|
|
|
|
projectID := key.VertexKeyConfig.ProjectID.GetValue()
|
|
if projectID == "" {
|
|
return nil, providerUtils.NewConfigurationError("project ID is not set")
|
|
}
|
|
|
|
region := key.VertexKeyConfig.Region.GetValue()
|
|
if region == "" {
|
|
return nil, providerUtils.NewConfigurationError("region is not set in key config")
|
|
}
|
|
|
|
// Remap unsupported tool versions for Vertex (handles raw passthrough bodies)
|
|
if schemas.IsAnthropicModel(request.Model) && jsonBody != nil {
|
|
remappedBody, remapErr := anthropic.RemapRawToolVersionsForProvider(jsonBody, schemas.Vertex)
|
|
if remapErr != nil {
|
|
return nil, providerUtils.NewBifrostOperationError(remapErr.Error(), nil)
|
|
}
|
|
jsonBody = remappedBody
|
|
}
|
|
|
|
// Auth query is used for fine-tuned models to pass the API key in the query string
|
|
authQuery := ""
|
|
// Determine the URL based on model type
|
|
var completeURL string
|
|
if schemas.IsAllDigitsASCII(request.Model) {
|
|
// Custom Fine-tuned models use OpenAPI endpoint
|
|
projectNumber := key.VertexKeyConfig.ProjectNumber.GetValue()
|
|
if projectNumber == "" {
|
|
return nil, providerUtils.NewConfigurationError("project number is not set for fine-tuned models")
|
|
}
|
|
if key.Value.GetValue() != "" {
|
|
authQuery = fmt.Sprintf("key=%s", url.QueryEscape(key.Value.GetValue()))
|
|
}
|
|
if region == "global" {
|
|
completeURL = fmt.Sprintf("https://aiplatform.googleapis.com/v1beta1/projects/%s/locations/global/endpoints/%s:generateContent", projectNumber, request.Model)
|
|
} else {
|
|
completeURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1beta1/projects/%s/locations/%s/endpoints/%s:generateContent", region, projectNumber, region, request.Model)
|
|
}
|
|
} else if schemas.IsAnthropicModel(request.Model) {
|
|
// Claude models use Anthropic publisher
|
|
if region == "global" {
|
|
completeURL = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/anthropic/models/%s:rawPredict", projectID, request.Model)
|
|
} else {
|
|
completeURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:rawPredict", region, projectID, region, request.Model)
|
|
}
|
|
} else if schemas.IsMistralModel(request.Model) {
|
|
// Mistral models use mistralai publisher with rawPredict
|
|
if region == "global" {
|
|
completeURL = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/mistralai/models/%s:rawPredict", projectID, request.Model)
|
|
} else {
|
|
completeURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/mistralai/models/%s:rawPredict", region, projectID, region, request.Model)
|
|
}
|
|
} else if schemas.IsGeminiModel(request.Model) {
|
|
// Gemini models support api key
|
|
if key.Value.GetValue() != "" {
|
|
authQuery = fmt.Sprintf("key=%s", url.QueryEscape(key.Value.GetValue()))
|
|
}
|
|
if region == "global" {
|
|
completeURL = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:generateContent", projectID, request.Model)
|
|
} else {
|
|
completeURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:generateContent", region, projectID, region, request.Model)
|
|
}
|
|
} else {
|
|
if region == "global" {
|
|
completeURL = fmt.Sprintf("https://aiplatform.googleapis.com/v1beta1/projects/%s/locations/global/endpoints/openapi/chat/completions", projectID)
|
|
} else {
|
|
completeURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1beta1/projects/%s/locations/%s/endpoints/openapi/chat/completions", region, projectID, region)
|
|
}
|
|
}
|
|
|
|
// Create HTTP request for streaming
|
|
req := fasthttp.AcquireRequest()
|
|
resp := fasthttp.AcquireResponse()
|
|
defer fasthttp.ReleaseRequest(req)
|
|
respOwned := true
|
|
defer func() {
|
|
if respOwned {
|
|
fasthttp.ReleaseResponse(resp)
|
|
}
|
|
}()
|
|
|
|
req.Header.SetMethod(http.MethodPost)
|
|
req.Header.SetContentType("application/json")
|
|
// Skip anthropic-beta from context headers — Anthropic models on Vertex use the
|
|
// anthropic_beta body field instead, and other model families don't use it.
|
|
providerUtils.SetExtraHeaders(ctx, req, provider.networkConfig.ExtraHeaders, []string{anthropic.AnthropicBetaHeader})
|
|
|
|
// If auth query is set, add it to the URL
|
|
// Otherwise, get the oauth2 token and set the Authorization header
|
|
if authQuery != "" {
|
|
completeURL = fmt.Sprintf("%s?%s", completeURL, authQuery)
|
|
} else {
|
|
// Getting oauth2 token
|
|
tokenSource, err := getAuthTokenSource(key)
|
|
if err != nil {
|
|
return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err)
|
|
}
|
|
token, err := tokenSource.Token()
|
|
if err != nil {
|
|
return nil, providerUtils.NewBifrostOperationError("error getting token", err)
|
|
}
|
|
req.Header.Set("Authorization", "Bearer "+token.AccessToken)
|
|
}
|
|
|
|
req.SetRequestURI(completeURL)
|
|
usedLargePayloadBody := providerUtils.ApplyLargePayloadRequestBody(ctx, req)
|
|
if !usedLargePayloadBody {
|
|
req.SetBody(jsonBody)
|
|
}
|
|
|
|
// Make the request with optional large response streaming
|
|
activeClient := providerUtils.PrepareResponseStreaming(ctx, provider.client, resp)
|
|
latency, bifrostErr, wait := providerUtils.MakeRequestWithContext(ctx, activeClient, req, resp)
|
|
defer wait()
|
|
if bifrostErr != nil {
|
|
return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
|
|
}
|
|
if usedLargePayloadBody {
|
|
providerUtils.DrainLargePayloadRemainder(ctx)
|
|
}
|
|
ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerUtils.ExtractProviderResponseHeaders(resp))
|
|
|
|
if resp.StatusCode() != fasthttp.StatusOK {
|
|
providerUtils.MaterializeStreamErrorBody(ctx, resp)
|
|
// Remove client from pool for authentication/authorization errors
|
|
if resp.StatusCode() == fasthttp.StatusUnauthorized || resp.StatusCode() == fasthttp.StatusForbidden {
|
|
removeVertexClient(key.VertexKeyConfig.AuthCredentials.GetValue())
|
|
}
|
|
return nil, providerUtils.EnrichError(ctx, parseVertexError(resp), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
|
|
}
|
|
|
|
responseBody, isLargeResp, decodeErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, provider.logger)
|
|
if decodeErr != nil {
|
|
return nil, providerUtils.EnrichError(ctx, decodeErr, jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
|
|
}
|
|
if isLargeResp {
|
|
respOwned = false
|
|
return &schemas.BifrostChatResponse{
|
|
Model: request.Model,
|
|
ExtraFields: schemas.BifrostResponseExtraFields{
|
|
Latency: latency.Milliseconds(),
|
|
ProviderResponseHeaders: providerUtils.ExtractProviderResponseHeaders(resp),
|
|
},
|
|
}, nil
|
|
}
|
|
|
|
if schemas.IsAnthropicModel(request.Model) {
|
|
// Create response object from pool
|
|
anthropicResponse := anthropic.AcquireAnthropicMessageResponse()
|
|
defer anthropic.ReleaseAnthropicMessageResponse(anthropicResponse)
|
|
|
|
rawRequest, rawResponse, bifrostErr := providerUtils.HandleProviderResponse(responseBody, anthropicResponse, jsonBody, providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse))
|
|
if bifrostErr != nil {
|
|
return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonBody, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse)
|
|
}
|
|
|
|
// Create final response
|
|
response := anthropicResponse.ToBifrostChatResponse(ctx)
|
|
|
|
response.ExtraFields = schemas.BifrostResponseExtraFields{
|
|
Latency: latency.Milliseconds(),
|
|
ProviderResponseHeaders: providerUtils.ExtractProviderResponseHeaders(resp),
|
|
}
|
|
|
|
// Set raw request if enabled
|
|
if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) {
|
|
response.ExtraFields.RawRequest = rawRequest
|
|
}
|
|
|
|
// Set raw response if enabled
|
|
if providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) {
|
|
response.ExtraFields.RawResponse = rawResponse
|
|
}
|
|
|
|
return response, nil
|
|
} else if schemas.IsGeminiModel(request.Model) || schemas.IsAllDigitsASCII(request.Model) {
|
|
geminiResponse := gemini.GenerateContentResponse{}
|
|
|
|
rawRequest, rawResponse, bifrostErr := providerUtils.HandleProviderResponse(responseBody, &geminiResponse, jsonBody, providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse))
|
|
if bifrostErr != nil {
|
|
return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonBody, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse)
|
|
}
|
|
|
|
response := geminiResponse.ToBifrostChatResponse()
|
|
response.ExtraFields.Latency = latency.Milliseconds()
|
|
response.ExtraFields.ProviderResponseHeaders = providerUtils.ExtractProviderResponseHeaders(resp)
|
|
|
|
if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) {
|
|
response.ExtraFields.RawRequest = rawRequest
|
|
}
|
|
|
|
if providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) {
|
|
response.ExtraFields.RawResponse = rawResponse
|
|
}
|
|
|
|
return response, nil
|
|
} else {
|
|
response := &schemas.BifrostChatResponse{}
|
|
|
|
// Use enhanced response handler with pre-allocated response
|
|
rawRequest, rawResponse, bifrostErr := providerUtils.HandleProviderResponse(responseBody, response, jsonBody, providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse))
|
|
if bifrostErr != nil {
|
|
return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonBody, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse)
|
|
}
|
|
|
|
response.ExtraFields.Latency = latency.Milliseconds()
|
|
response.ExtraFields.ProviderResponseHeaders = providerUtils.ExtractProviderResponseHeaders(resp)
|
|
|
|
// Set raw request if enabled
|
|
if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) {
|
|
response.ExtraFields.RawRequest = rawRequest
|
|
}
|
|
|
|
// Set raw response if enabled
|
|
if providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) {
|
|
response.ExtraFields.RawResponse = rawResponse
|
|
}
|
|
|
|
return response, nil
|
|
}
|
|
}
|
|
|
|
// ChatCompletionStream performs a streaming chat completion request to the Vertex API.
|
|
// It supports both OpenAI-style streaming (for non-Claude models) and Anthropic-style streaming (for Claude models).
|
|
// Returns a channel of BifrostStreamChunk objects for streaming results or an error if the request fails.
|
|
func (provider *VertexProvider) ChatCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
|
|
providerName := provider.GetProviderKey()
|
|
projectID := key.VertexKeyConfig.ProjectID.GetValue()
|
|
if projectID == "" {
|
|
return nil, providerUtils.NewConfigurationError("project ID is not set")
|
|
}
|
|
|
|
region := key.VertexKeyConfig.Region.GetValue()
|
|
if region == "" {
|
|
return nil, providerUtils.NewConfigurationError("region is not set in key config")
|
|
}
|
|
|
|
if schemas.IsAnthropicModel(request.Model) {
|
|
// Use Anthropic-style streaming for Claude models
|
|
jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody(
|
|
ctx,
|
|
request,
|
|
func() (providerUtils.RequestBodyWithExtraParams, error) {
|
|
var extraParams map[string]interface{}
|
|
reqBody, convErr := anthropic.ToAnthropicChatRequest(ctx, request)
|
|
if convErr != nil {
|
|
return nil, convErr
|
|
}
|
|
if reqBody == nil {
|
|
return nil, fmt.Errorf("chat completion input is not provided")
|
|
}
|
|
extraParams = reqBody.GetExtraParams()
|
|
reqBody.Stream = new(true)
|
|
// Add provider-aware beta headers for Vertex
|
|
anthropic.AddMissingBetaHeadersToContext(ctx, reqBody, schemas.Vertex)
|
|
|
|
// Marshal to JSON bytes, preserving struct field order for prompt caching
|
|
rawBody, err := providerUtils.MarshalSorted(reqBody)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to marshal request body: %w", err)
|
|
}
|
|
|
|
// Add anthropic_version if not present (using sjson to preserve order)
|
|
if !providerUtils.JSONFieldExists(rawBody, "anthropic_version") {
|
|
rawBody, err = providerUtils.SetJSONField(rawBody, "anthropic_version", DefaultVertexAnthropicVersion)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to set anthropic_version: %w", err)
|
|
}
|
|
}
|
|
// Inject beta headers into body as anthropic_beta (Vertex uses body field, not HTTP header)
|
|
if betaHeaders := anthropic.FilterBetaHeadersForProvider(anthropic.MergeBetaHeaders(provider.networkConfig.ExtraHeaders, ctx), schemas.Vertex, provider.networkConfig.BetaHeaderOverrides); len(betaHeaders) > 0 {
|
|
rawBody, err = providerUtils.SetJSONField(rawBody, "anthropic_beta", betaHeaders)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to set anthropic_beta: %w", err)
|
|
}
|
|
}
|
|
|
|
// Remove model and region fields (using sjson to preserve order)
|
|
rawBody, err = providerUtils.DeleteJSONField(rawBody, "model")
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to delete model field: %w", err)
|
|
}
|
|
rawBody, err = providerUtils.DeleteJSONField(rawBody, "region")
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to delete region field: %w", err)
|
|
}
|
|
return &VertexRawRequestBody{RawBody: rawBody, ExtraParams: extraParams}, nil
|
|
},
|
|
)
|
|
if bifrostErr != nil {
|
|
return nil, bifrostErr
|
|
}
|
|
|
|
// Remap unsupported tool versions for Vertex streaming (handles raw passthrough bodies)
|
|
if jsonData != nil {
|
|
var remapErr error
|
|
jsonData, remapErr = anthropic.RemapRawToolVersionsForProvider(jsonData, schemas.Vertex)
|
|
if remapErr != nil {
|
|
return nil, providerUtils.NewBifrostOperationError(remapErr.Error(), nil)
|
|
}
|
|
}
|
|
|
|
var completeURL string
|
|
if region == "global" {
|
|
completeURL = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/anthropic/models/%s:streamRawPredict", projectID, request.Model)
|
|
} else {
|
|
completeURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:streamRawPredict", region, projectID, region, request.Model)
|
|
}
|
|
|
|
// Prepare headers for Vertex Anthropic
|
|
headers := map[string]string{
|
|
"Content-Type": "application/json",
|
|
"Accept": "text/event-stream",
|
|
"Cache-Control": "no-cache",
|
|
}
|
|
|
|
// Adding authorization header
|
|
tokenSource, err := getAuthTokenSource(key)
|
|
if err != nil {
|
|
return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err)
|
|
}
|
|
token, err := tokenSource.Token()
|
|
if err != nil {
|
|
return nil, providerUtils.NewBifrostOperationError("error getting token", err)
|
|
}
|
|
headers["Authorization"] = "Bearer " + token.AccessToken
|
|
|
|
// Use shared Anthropic streaming logic
|
|
return anthropic.HandleAnthropicChatCompletionStreaming(
|
|
ctx,
|
|
provider.streamingClient,
|
|
completeURL,
|
|
jsonData,
|
|
headers,
|
|
provider.networkConfig.ExtraHeaders,
|
|
provider.networkConfig.BetaHeaderOverrides,
|
|
providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest),
|
|
providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse),
|
|
providerName,
|
|
postHookRunner,
|
|
nil,
|
|
provider.logger,
|
|
postHookSpanFinalizer,
|
|
)
|
|
} else if schemas.IsGeminiModel(request.Model) || schemas.IsAllDigitsASCII(request.Model) {
|
|
// Use Gemini-style streaming for Gemini models
|
|
jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody(
|
|
ctx,
|
|
request,
|
|
func() (providerUtils.RequestBodyWithExtraParams, error) {
|
|
reqBody, err := gemini.ToGeminiChatCompletionRequest(request)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if reqBody == nil {
|
|
return nil, fmt.Errorf("chat completion input is not provided")
|
|
}
|
|
// Strip unsupported fields for Vertex Gemini
|
|
stripVertexGeminiUnsupportedFields(reqBody)
|
|
return reqBody, nil
|
|
},
|
|
)
|
|
if bifrostErr != nil {
|
|
return nil, bifrostErr
|
|
}
|
|
|
|
// Auth query is used to pass the API key in the query string
|
|
authQuery := ""
|
|
if key.Value.GetValue() != "" {
|
|
authQuery = fmt.Sprintf("key=%s", url.QueryEscape(key.Value.GetValue()))
|
|
}
|
|
|
|
// For custom/fine-tuned models, validate projectNumber is set
|
|
projectNumber := key.VertexKeyConfig.ProjectNumber.GetValue()
|
|
if schemas.IsAllDigitsASCII(request.Model) && projectNumber == "" {
|
|
return nil, providerUtils.NewConfigurationError("project number is not set for fine-tuned models")
|
|
}
|
|
|
|
// Construct the URL for Gemini streaming
|
|
completeURL := getCompleteURLForGeminiEndpoint(request.Model, region, projectID, projectNumber, ":streamGenerateContent")
|
|
|
|
// Add alt=sse parameter
|
|
if authQuery != "" {
|
|
completeURL = fmt.Sprintf("%s?alt=sse&%s", completeURL, authQuery)
|
|
} else {
|
|
completeURL = fmt.Sprintf("%s?alt=sse", completeURL)
|
|
}
|
|
|
|
// Prepare headers for Vertex Gemini
|
|
headers := map[string]string{
|
|
"Accept": "text/event-stream",
|
|
"Cache-Control": "no-cache",
|
|
}
|
|
|
|
// If no auth query, use OAuth2 token
|
|
if authQuery == "" {
|
|
tokenSource, err := getAuthTokenSource(key)
|
|
if err != nil {
|
|
return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err)
|
|
}
|
|
token, err := tokenSource.Token()
|
|
if err != nil {
|
|
return nil, providerUtils.NewBifrostOperationError("error getting token", err)
|
|
}
|
|
headers["Authorization"] = "Bearer " + token.AccessToken
|
|
}
|
|
|
|
// Use shared streaming logic from Gemini
|
|
return gemini.HandleGeminiChatCompletionStream(
|
|
ctx,
|
|
provider.streamingClient,
|
|
completeURL,
|
|
jsonData,
|
|
headers,
|
|
provider.networkConfig.ExtraHeaders,
|
|
providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest),
|
|
providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse),
|
|
provider.GetProviderKey(),
|
|
request.Model,
|
|
postHookRunner,
|
|
nil,
|
|
provider.logger,
|
|
postHookSpanFinalizer,
|
|
)
|
|
} else {
|
|
var authHeader map[string]string
|
|
// Auth query is used for fine-tuned models to pass the API key in the query string
|
|
authQuery := ""
|
|
// Determine the URL based on model type
|
|
var completeURL string
|
|
if schemas.IsMistralModel(request.Model) {
|
|
// Mistral models use mistralai publisher with streamRawPredict
|
|
if region == "global" {
|
|
completeURL = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/mistralai/models/%s:streamRawPredict", projectID, request.Model)
|
|
} else {
|
|
completeURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/mistralai/models/%s:streamRawPredict", region, projectID, region, request.Model)
|
|
}
|
|
} else {
|
|
// Other models use OpenAPI endpoint for gemini models
|
|
if key.Value.GetValue() != "" {
|
|
authQuery = fmt.Sprintf("key=%s", url.QueryEscape(key.Value.GetValue()))
|
|
}
|
|
if region == "global" {
|
|
completeURL = fmt.Sprintf("https://aiplatform.googleapis.com/v1beta1/projects/%s/locations/global/endpoints/openapi/chat/completions", projectID)
|
|
} else {
|
|
completeURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1beta1/projects/%s/locations/%s/endpoints/openapi/chat/completions", region, projectID, region)
|
|
}
|
|
}
|
|
|
|
if authQuery != "" {
|
|
completeURL = fmt.Sprintf("%s?%s", completeURL, authQuery)
|
|
} else {
|
|
// Getting oauth2 token
|
|
tokenSource, err := getAuthTokenSource(key)
|
|
if err != nil {
|
|
return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err)
|
|
}
|
|
token, err := tokenSource.Token()
|
|
if err != nil {
|
|
return nil, providerUtils.NewBifrostOperationError("error getting token", err)
|
|
}
|
|
authHeader = map[string]string{
|
|
"Authorization": "Bearer " + token.AccessToken,
|
|
}
|
|
}
|
|
|
|
// Use shared OpenAI streaming logic
|
|
return openai.HandleOpenAIChatCompletionStreaming(
|
|
ctx,
|
|
provider.streamingClient,
|
|
completeURL,
|
|
request,
|
|
authHeader,
|
|
provider.networkConfig.ExtraHeaders,
|
|
providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest),
|
|
providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse),
|
|
providerName,
|
|
postHookRunner,
|
|
nil,
|
|
nil,
|
|
nil,
|
|
nil,
|
|
nil,
|
|
provider.logger,
|
|
postHookSpanFinalizer,
|
|
)
|
|
}
|
|
}
|
|
|
|
// Responses performs a responses request to the Vertex API.
|
|
func (provider *VertexProvider) Responses(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostResponsesRequest) (*schemas.BifrostResponsesResponse, *schemas.BifrostError) {
|
|
if schemas.IsAnthropicModel(request.Model) {
|
|
jsonBody, bifrostErr := getRequestBodyForAnthropicResponses(ctx, request, request.Model, false, false, provider.networkConfig.BetaHeaderOverrides, provider.networkConfig.ExtraHeaders)
|
|
if bifrostErr != nil {
|
|
return nil, bifrostErr
|
|
}
|
|
|
|
projectID := key.VertexKeyConfig.ProjectID.GetValue()
|
|
if projectID == "" {
|
|
return nil, providerUtils.NewConfigurationError("project ID is not set")
|
|
}
|
|
|
|
region := key.VertexKeyConfig.Region.GetValue()
|
|
if region == "" {
|
|
return nil, providerUtils.NewConfigurationError("region is not set in key config")
|
|
}
|
|
|
|
// Claude models use Anthropic publisher
|
|
var url string
|
|
if region == "global" {
|
|
url = fmt.Sprintf("https://aiplatform.googleapis.com/v1beta1/projects/%s/locations/global/publishers/anthropic/models/%s:rawPredict", projectID, request.Model)
|
|
} else {
|
|
url = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1beta1/projects/%s/locations/%s/publishers/anthropic/models/%s:rawPredict", region, projectID, region, request.Model)
|
|
}
|
|
|
|
// Create HTTP request for streaming
|
|
req := fasthttp.AcquireRequest()
|
|
resp := fasthttp.AcquireResponse()
|
|
defer fasthttp.ReleaseRequest(req)
|
|
respOwned := true
|
|
defer func() {
|
|
if respOwned {
|
|
fasthttp.ReleaseResponse(resp)
|
|
}
|
|
}()
|
|
|
|
req.Header.SetMethod(http.MethodPost)
|
|
req.Header.SetContentType("application/json")
|
|
providerUtils.SetExtraHeaders(ctx, req, provider.networkConfig.ExtraHeaders, []string{anthropic.AnthropicBetaHeader})
|
|
|
|
// Getting oauth2 token
|
|
tokenSource, err := getAuthTokenSource(key)
|
|
if err != nil {
|
|
return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err)
|
|
}
|
|
token, err := tokenSource.Token()
|
|
if err != nil {
|
|
return nil, providerUtils.NewBifrostOperationError("error getting token", err)
|
|
}
|
|
req.Header.Set("Authorization", "Bearer "+token.AccessToken)
|
|
|
|
req.SetRequestURI(url)
|
|
usedLargePayloadBody := providerUtils.ApplyLargePayloadRequestBody(ctx, req)
|
|
if !usedLargePayloadBody {
|
|
req.SetBody(jsonBody)
|
|
}
|
|
|
|
// Make the request with optional large response streaming
|
|
activeClient := providerUtils.PrepareResponseStreaming(ctx, provider.client, resp)
|
|
latency, bifrostErr, wait := providerUtils.MakeRequestWithContext(ctx, activeClient, req, resp)
|
|
defer wait()
|
|
if bifrostErr != nil {
|
|
return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
|
|
}
|
|
if usedLargePayloadBody {
|
|
providerUtils.DrainLargePayloadRemainder(ctx)
|
|
}
|
|
ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerUtils.ExtractProviderResponseHeaders(resp))
|
|
|
|
if resp.StatusCode() != fasthttp.StatusOK {
|
|
providerUtils.MaterializeStreamErrorBody(ctx, resp)
|
|
// Remove client from pool for authentication/authorization errors
|
|
if resp.StatusCode() == fasthttp.StatusUnauthorized || resp.StatusCode() == fasthttp.StatusForbidden {
|
|
removeVertexClient(key.VertexKeyConfig.AuthCredentials.GetValue())
|
|
}
|
|
return nil, providerUtils.EnrichError(ctx, parseVertexError(resp), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
|
|
}
|
|
|
|
responseBody, isLargeResp, decodeErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, provider.logger)
|
|
if decodeErr != nil {
|
|
return nil, providerUtils.EnrichError(ctx, decodeErr, jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
|
|
}
|
|
if isLargeResp {
|
|
respOwned = false
|
|
return &schemas.BifrostResponsesResponse{
|
|
ExtraFields: schemas.BifrostResponseExtraFields{
|
|
Latency: latency.Milliseconds(),
|
|
ProviderResponseHeaders: providerUtils.ExtractProviderResponseHeaders(resp),
|
|
},
|
|
}, nil
|
|
}
|
|
|
|
// Create response object from pool
|
|
anthropicResponse := anthropic.AcquireAnthropicMessageResponse()
|
|
defer anthropic.ReleaseAnthropicMessageResponse(anthropicResponse)
|
|
|
|
rawRequest, rawResponse, bifrostErr := providerUtils.HandleProviderResponse(responseBody, anthropicResponse, jsonBody, providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse))
|
|
if bifrostErr != nil {
|
|
return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonBody, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse)
|
|
}
|
|
|
|
// Create final response
|
|
response := anthropicResponse.ToBifrostResponsesResponse(ctx)
|
|
|
|
response.ExtraFields = schemas.BifrostResponseExtraFields{
|
|
Latency: latency.Milliseconds(),
|
|
}
|
|
|
|
response.ExtraFields.ProviderResponseHeaders = providerUtils.ExtractProviderResponseHeaders(resp)
|
|
// Set raw request if enabled
|
|
if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) {
|
|
response.ExtraFields.RawRequest = rawRequest
|
|
}
|
|
|
|
// Set raw response if enabled
|
|
if providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) {
|
|
response.ExtraFields.RawResponse = rawResponse
|
|
}
|
|
|
|
return response, nil
|
|
} else if schemas.IsGeminiModel(request.Model) || schemas.IsAllDigitsASCII(request.Model) {
|
|
jsonBody, bifrostErr := providerUtils.CheckContextAndGetRequestBody(
|
|
ctx,
|
|
request,
|
|
func() (providerUtils.RequestBodyWithExtraParams, error) {
|
|
reqBody, err := gemini.ToGeminiResponsesRequest(request)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if reqBody == nil {
|
|
return nil, fmt.Errorf("responses input is not provided")
|
|
}
|
|
// Strip unsupported fields for Vertex Gemini
|
|
stripVertexGeminiUnsupportedFields(reqBody)
|
|
return reqBody, nil
|
|
},
|
|
)
|
|
if bifrostErr != nil {
|
|
return nil, bifrostErr
|
|
}
|
|
|
|
projectID := key.VertexKeyConfig.ProjectID.GetValue()
|
|
if projectID == "" {
|
|
return nil, providerUtils.NewConfigurationError("project ID is not set")
|
|
}
|
|
|
|
region := key.VertexKeyConfig.Region.GetValue()
|
|
if region == "" {
|
|
return nil, providerUtils.NewConfigurationError("region is not set in key config")
|
|
}
|
|
|
|
authQuery := ""
|
|
if key.Value.GetValue() != "" {
|
|
authQuery = fmt.Sprintf("key=%s", url.QueryEscape(key.Value.GetValue()))
|
|
}
|
|
|
|
// For custom/fine-tuned models, validate projectNumber is set
|
|
projectNumber := key.VertexKeyConfig.ProjectNumber.GetValue()
|
|
if schemas.IsAllDigitsASCII(request.Model) && projectNumber == "" {
|
|
return nil, providerUtils.NewConfigurationError("project number is not set for fine-tuned models")
|
|
}
|
|
|
|
url := getCompleteURLForGeminiEndpoint(request.Model, region, projectID, projectNumber, ":generateContent")
|
|
|
|
// Create HTTP request for streaming
|
|
req := fasthttp.AcquireRequest()
|
|
resp := fasthttp.AcquireResponse()
|
|
defer fasthttp.ReleaseRequest(req)
|
|
respOwned := true
|
|
defer func() {
|
|
if respOwned {
|
|
fasthttp.ReleaseResponse(resp)
|
|
}
|
|
}()
|
|
|
|
req.Header.SetMethod(http.MethodPost)
|
|
req.Header.SetContentType("application/json")
|
|
providerUtils.SetExtraHeaders(ctx, req, provider.networkConfig.ExtraHeaders, nil)
|
|
|
|
// If auth query is set, add it to the URL
|
|
// Otherwise, get the oauth2 token and set the Authorization header
|
|
if authQuery != "" {
|
|
url = fmt.Sprintf("%s?%s", url, authQuery)
|
|
} else {
|
|
// Getting oauth2 token
|
|
tokenSource, err := getAuthTokenSource(key)
|
|
if err != nil {
|
|
return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err)
|
|
}
|
|
token, err := tokenSource.Token()
|
|
if err != nil {
|
|
return nil, providerUtils.NewBifrostOperationError("error getting token", err)
|
|
}
|
|
req.Header.Set("Authorization", "Bearer "+token.AccessToken)
|
|
}
|
|
|
|
req.SetRequestURI(url)
|
|
usedLargePayloadBody := providerUtils.ApplyLargePayloadRequestBody(ctx, req)
|
|
if !usedLargePayloadBody {
|
|
req.SetBody(jsonBody)
|
|
}
|
|
|
|
// Make the request with optional large response streaming
|
|
activeClient := providerUtils.PrepareResponseStreaming(ctx, provider.client, resp)
|
|
latency, bifrostErr, wait := providerUtils.MakeRequestWithContext(ctx, activeClient, req, resp)
|
|
defer wait()
|
|
if bifrostErr != nil {
|
|
return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
|
|
}
|
|
if usedLargePayloadBody {
|
|
providerUtils.DrainLargePayloadRemainder(ctx)
|
|
}
|
|
ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerUtils.ExtractProviderResponseHeaders(resp))
|
|
|
|
if resp.StatusCode() != fasthttp.StatusOK {
|
|
providerUtils.MaterializeStreamErrorBody(ctx, resp)
|
|
// Remove client from pool for authentication/authorization errors
|
|
if resp.StatusCode() == fasthttp.StatusUnauthorized || resp.StatusCode() == fasthttp.StatusForbidden {
|
|
removeVertexClient(key.VertexKeyConfig.AuthCredentials.GetValue())
|
|
}
|
|
return nil, providerUtils.EnrichError(ctx, parseVertexError(resp), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
|
|
}
|
|
|
|
responseBody, isLargeResp, decodeErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, provider.logger)
|
|
if decodeErr != nil {
|
|
return nil, providerUtils.EnrichError(ctx, decodeErr, jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
|
|
}
|
|
if isLargeResp {
|
|
respOwned = false
|
|
return &schemas.BifrostResponsesResponse{
|
|
ExtraFields: schemas.BifrostResponseExtraFields{
|
|
Latency: latency.Milliseconds(),
|
|
ProviderResponseHeaders: providerUtils.ExtractProviderResponseHeaders(resp),
|
|
},
|
|
}, nil
|
|
}
|
|
|
|
geminiResponse := &gemini.GenerateContentResponse{}
|
|
|
|
rawRequest, rawResponse, bifrostErr := providerUtils.HandleProviderResponse(responseBody, geminiResponse, jsonBody, providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse))
|
|
if bifrostErr != nil {
|
|
return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonBody, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse)
|
|
}
|
|
|
|
response := geminiResponse.ToResponsesBifrostResponsesResponse()
|
|
response.ExtraFields.Latency = latency.Milliseconds()
|
|
response.ExtraFields.ProviderResponseHeaders = providerUtils.ExtractProviderResponseHeaders(resp)
|
|
|
|
// Set raw response if enabled
|
|
if providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) {
|
|
response.ExtraFields.RawResponse = rawResponse
|
|
}
|
|
|
|
if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) {
|
|
response.ExtraFields.RawRequest = rawRequest
|
|
}
|
|
|
|
return response, nil
|
|
} else {
|
|
chatResponse, err := provider.ChatCompletion(ctx, key, request.ToChatRequest())
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
response := chatResponse.ToBifrostResponsesResponse()
|
|
return response, nil
|
|
}
|
|
}
|
|
|
|
// ResponsesStream performs a streaming responses request to the Vertex API.
|
|
func (provider *VertexProvider) ResponsesStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
|
|
if schemas.IsAnthropicModel(request.Model) {
|
|
region := key.VertexKeyConfig.Region.GetValue()
|
|
if region == "" {
|
|
return nil, providerUtils.NewConfigurationError("region is not set in key config")
|
|
}
|
|
|
|
projectID := key.VertexKeyConfig.ProjectID.GetValue()
|
|
if projectID == "" {
|
|
return nil, providerUtils.NewConfigurationError("project ID is not set")
|
|
}
|
|
|
|
jsonBody, bifrostErr := getRequestBodyForAnthropicResponses(ctx, request, request.Model, true, false, provider.networkConfig.BetaHeaderOverrides, provider.networkConfig.ExtraHeaders)
|
|
if bifrostErr != nil {
|
|
return nil, bifrostErr
|
|
}
|
|
|
|
var url string
|
|
if region == "global" {
|
|
url = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/anthropic/models/%s:streamRawPredict", projectID, request.Model)
|
|
} else {
|
|
url = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:streamRawPredict", region, projectID, region, request.Model)
|
|
}
|
|
|
|
// Prepare headers for Vertex Anthropic
|
|
headers := map[string]string{
|
|
"Content-Type": "application/json",
|
|
"Accept": "text/event-stream",
|
|
"Cache-Control": "no-cache",
|
|
}
|
|
|
|
// Adding authorization header
|
|
tokenSource, err := getAuthTokenSource(key)
|
|
if err != nil {
|
|
return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err)
|
|
}
|
|
token, err := tokenSource.Token()
|
|
if err != nil {
|
|
return nil, providerUtils.NewBifrostOperationError("error getting token", err)
|
|
}
|
|
headers["Authorization"] = "Bearer " + token.AccessToken
|
|
|
|
// Use shared streaming logic from Anthropic
|
|
return anthropic.HandleAnthropicResponsesStream(
|
|
ctx,
|
|
provider.streamingClient,
|
|
url,
|
|
jsonBody,
|
|
headers,
|
|
provider.networkConfig.ExtraHeaders,
|
|
provider.networkConfig.BetaHeaderOverrides,
|
|
providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest),
|
|
providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse),
|
|
provider.GetProviderKey(),
|
|
postHookRunner,
|
|
nil,
|
|
provider.logger,
|
|
postHookSpanFinalizer,
|
|
)
|
|
} else if schemas.IsGeminiModel(request.Model) || schemas.IsAllDigitsASCII(request.Model) {
|
|
region := key.VertexKeyConfig.Region.GetValue()
|
|
if region == "" {
|
|
return nil, providerUtils.NewConfigurationError("region is not set in key config")
|
|
}
|
|
|
|
projectID := key.VertexKeyConfig.ProjectID.GetValue()
|
|
if projectID == "" {
|
|
return nil, providerUtils.NewConfigurationError("project ID is not set")
|
|
}
|
|
|
|
// Use Gemini-style streaming for Gemini models
|
|
jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody(
|
|
ctx,
|
|
request,
|
|
func() (providerUtils.RequestBodyWithExtraParams, error) {
|
|
reqBody, err := gemini.ToGeminiResponsesRequest(request)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if reqBody == nil {
|
|
return nil, fmt.Errorf("responses input is not provided")
|
|
}
|
|
// Strip unsupported fields for Vertex Gemini
|
|
stripVertexGeminiUnsupportedFields(reqBody)
|
|
return reqBody, nil
|
|
},
|
|
)
|
|
if bifrostErr != nil {
|
|
return nil, bifrostErr
|
|
}
|
|
|
|
// Auth query is used to pass the API key in the query string
|
|
authQuery := ""
|
|
if key.Value.GetValue() != "" {
|
|
authQuery = fmt.Sprintf("key=%s", url.QueryEscape(key.Value.GetValue()))
|
|
}
|
|
|
|
// For custom/fine-tuned models, validate projectNumber is set
|
|
projectNumber := key.VertexKeyConfig.ProjectNumber.GetValue()
|
|
if schemas.IsAllDigitsASCII(request.Model) && projectNumber == "" {
|
|
return nil, providerUtils.NewConfigurationError("project number is not set for fine-tuned models")
|
|
}
|
|
|
|
// Construct the URL for Gemini streaming
|
|
completeURL := getCompleteURLForGeminiEndpoint(request.Model, region, projectID, projectNumber, ":streamGenerateContent")
|
|
// Add alt=sse parameter
|
|
if authQuery != "" {
|
|
completeURL = fmt.Sprintf("%s?alt=sse&%s", completeURL, authQuery)
|
|
} else {
|
|
completeURL = fmt.Sprintf("%s?alt=sse", completeURL)
|
|
}
|
|
|
|
// Prepare headers for Vertex Gemini
|
|
headers := map[string]string{
|
|
"Accept": "text/event-stream",
|
|
"Cache-Control": "no-cache",
|
|
}
|
|
|
|
// If no auth query, use OAuth2 token
|
|
if authQuery == "" {
|
|
tokenSource, err := getAuthTokenSource(key)
|
|
if err != nil {
|
|
return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err)
|
|
}
|
|
token, err := tokenSource.Token()
|
|
if err != nil {
|
|
return nil, providerUtils.NewBifrostOperationError("error getting token", err)
|
|
}
|
|
headers["Authorization"] = "Bearer " + token.AccessToken
|
|
}
|
|
|
|
// Use shared streaming logic from Gemini
|
|
return gemini.HandleGeminiResponsesStream(
|
|
ctx,
|
|
provider.streamingClient,
|
|
completeURL,
|
|
jsonData,
|
|
headers,
|
|
provider.networkConfig.ExtraHeaders,
|
|
providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest),
|
|
providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse),
|
|
provider.GetProviderKey(),
|
|
request.Model,
|
|
postHookRunner,
|
|
nil,
|
|
provider.logger,
|
|
postHookSpanFinalizer,
|
|
)
|
|
} else {
|
|
ctx.SetValue(schemas.BifrostContextKeyIsResponsesToChatCompletionFallback, true)
|
|
return provider.ChatCompletionStream(
|
|
ctx,
|
|
postHookRunner,
|
|
postHookSpanFinalizer,
|
|
key,
|
|
request.ToChatRequest(),
|
|
)
|
|
}
|
|
}
|
|
|
|
// Embedding generates embeddings for the given input text(s) using Vertex AI.
|
|
// All Vertex AI embedding models use the same response format regardless of the model type.
|
|
// Returns a BifrostResponse containing the embedding(s) and any error that occurred.
|
|
func (provider *VertexProvider) Embedding(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostEmbeddingRequest) (*schemas.BifrostEmbeddingResponse, *schemas.BifrostError) {
|
|
projectID := key.VertexKeyConfig.ProjectID.GetValue()
|
|
if projectID == "" {
|
|
return nil, providerUtils.NewConfigurationError("project ID is not set")
|
|
}
|
|
|
|
region := key.VertexKeyConfig.Region.GetValue()
|
|
if region == "" {
|
|
return nil, providerUtils.NewConfigurationError("region is not set in key config")
|
|
}
|
|
|
|
jsonBody, bifrostErr := providerUtils.CheckContextAndGetRequestBody(
|
|
ctx,
|
|
request,
|
|
func() (providerUtils.RequestBodyWithExtraParams, error) {
|
|
return ToVertexEmbeddingRequest(request), nil
|
|
},
|
|
)
|
|
if bifrostErr != nil {
|
|
return nil, bifrostErr
|
|
}
|
|
|
|
// For custom/fine-tuned models, validate projectNumber is set
|
|
projectNumber := key.VertexKeyConfig.ProjectNumber.GetValue()
|
|
if schemas.IsAllDigitsASCII(request.Model) && projectNumber == "" {
|
|
return nil, providerUtils.NewConfigurationError("project number is not set for fine-tuned models")
|
|
}
|
|
|
|
// Build the native Vertex embedding API endpoint
|
|
url := getCompleteURLForGeminiEndpoint(request.Model, region, projectID, projectNumber, ":predict")
|
|
|
|
// Create HTTP request for streaming
|
|
req := fasthttp.AcquireRequest()
|
|
resp := fasthttp.AcquireResponse()
|
|
defer fasthttp.ReleaseRequest(req)
|
|
respOwned := true
|
|
defer func() {
|
|
if respOwned {
|
|
fasthttp.ReleaseResponse(resp)
|
|
}
|
|
}()
|
|
|
|
req.Header.SetMethod(http.MethodPost)
|
|
req.SetRequestURI(url)
|
|
req.Header.SetContentType("application/json")
|
|
|
|
// Set any extra headers from network config
|
|
providerUtils.SetExtraHeaders(ctx, req, provider.networkConfig.ExtraHeaders, nil)
|
|
|
|
// Getting oauth2 token
|
|
tokenSource, err := getAuthTokenSource(key)
|
|
if err != nil {
|
|
return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err)
|
|
}
|
|
token, err := tokenSource.Token()
|
|
if err != nil {
|
|
return nil, providerUtils.NewBifrostOperationError("error getting token", err)
|
|
}
|
|
req.Header.Set("Authorization", "Bearer "+token.AccessToken)
|
|
|
|
usedLargePayloadBody := providerUtils.ApplyLargePayloadRequestBody(ctx, req)
|
|
if !usedLargePayloadBody {
|
|
req.SetBody(jsonBody)
|
|
}
|
|
|
|
// Make the request with optional large response streaming
|
|
activeClient := providerUtils.PrepareResponseStreaming(ctx, provider.client, resp)
|
|
latency, bifrostErr, wait := providerUtils.MakeRequestWithContext(ctx, activeClient, req, resp)
|
|
defer wait()
|
|
if bifrostErr != nil {
|
|
return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
|
|
}
|
|
if usedLargePayloadBody {
|
|
providerUtils.DrainLargePayloadRemainder(ctx)
|
|
}
|
|
ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerUtils.ExtractProviderResponseHeaders(resp))
|
|
|
|
if resp.StatusCode() != fasthttp.StatusOK {
|
|
providerUtils.MaterializeStreamErrorBody(ctx, resp)
|
|
// Remove client from pool for authentication/authorization errors
|
|
if resp.StatusCode() == fasthttp.StatusUnauthorized || resp.StatusCode() == fasthttp.StatusForbidden {
|
|
removeVertexClient(key.VertexKeyConfig.AuthCredentials.GetValue())
|
|
}
|
|
|
|
errBody := resp.Body()
|
|
|
|
// Extract error message from Vertex's error format
|
|
errorMessage := "Unknown error"
|
|
if len(errBody) > 0 {
|
|
// Try to parse Vertex's error format
|
|
var vertexError map[string]interface{}
|
|
if err := sonic.Unmarshal(errBody, &vertexError); err != nil {
|
|
return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err), jsonBody, errBody, provider.sendBackRawRequest, provider.sendBackRawResponse)
|
|
}
|
|
|
|
if errorObj, exists := vertexError["error"]; exists {
|
|
if errorMap, ok := errorObj.(map[string]interface{}); ok {
|
|
if message, exists := errorMap["message"]; exists {
|
|
if msgStr, ok := message.(string); ok {
|
|
errorMessage = msgStr
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
return nil, providerUtils.EnrichError(ctx, providerUtils.NewProviderAPIError(errorMessage, nil, resp.StatusCode(), nil, nil), jsonBody, errBody, provider.sendBackRawRequest, provider.sendBackRawResponse)
|
|
}
|
|
|
|
responseBody, isLargeResp, decodeErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, provider.logger)
|
|
if decodeErr != nil {
|
|
return nil, providerUtils.EnrichError(ctx, decodeErr, jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
|
|
}
|
|
if isLargeResp {
|
|
respOwned = false
|
|
return &schemas.BifrostEmbeddingResponse{
|
|
ExtraFields: schemas.BifrostResponseExtraFields{
|
|
Latency: latency.Milliseconds(),
|
|
ProviderResponseHeaders: providerUtils.ExtractProviderResponseHeaders(resp),
|
|
},
|
|
}, nil
|
|
}
|
|
|
|
// Parse Vertex's native embedding response using typed response
|
|
var vertexResponse VertexEmbeddingResponse
|
|
if err := sonic.Unmarshal(responseBody, &vertexResponse); err != nil {
|
|
return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err), jsonBody, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse)
|
|
}
|
|
|
|
// Use centralized Vertex converter
|
|
bifrostResponse := vertexResponse.ToBifrostEmbeddingResponse()
|
|
|
|
// Set ExtraFields
|
|
bifrostResponse.ExtraFields.Latency = latency.Milliseconds()
|
|
bifrostResponse.ExtraFields.ProviderResponseHeaders = providerUtils.ExtractProviderResponseHeaders(resp)
|
|
|
|
// Set raw response if enabled
|
|
if providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) {
|
|
var rawResponseMap map[string]interface{}
|
|
if err := sonic.Unmarshal(resp.Body(), &rawResponseMap); err != nil {
|
|
return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError(schemas.ErrProviderRawResponseUnmarshal, err), jsonBody, resp.Body(), provider.sendBackRawRequest, provider.sendBackRawResponse)
|
|
}
|
|
bifrostResponse.ExtraFields.RawResponse = rawResponseMap
|
|
}
|
|
|
|
return bifrostResponse, nil
|
|
}
|
|
|
|
// Speech is not supported by the Vertex provider.
|
|
func (provider *VertexProvider) Speech(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostSpeechRequest) (*schemas.BifrostSpeechResponse, *schemas.BifrostError) {
|
|
return nil, providerUtils.NewUnsupportedOperationError(schemas.SpeechRequest, provider.GetProviderKey())
|
|
}
|
|
|
|
// Rerank performs a rerank request using Vertex Discovery Engine ranking API.
|
|
func (provider *VertexProvider) Rerank(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostRerankRequest) (*schemas.BifrostRerankResponse, *schemas.BifrostError) {
|
|
projectID := strings.TrimSpace(key.VertexKeyConfig.ProjectID.GetValue())
|
|
if projectID == "" {
|
|
return nil, providerUtils.NewConfigurationError("project ID is not set")
|
|
}
|
|
|
|
options, err := getVertexRerankOptions(projectID, request.Params)
|
|
if err != nil {
|
|
return nil, providerUtils.NewConfigurationError(err.Error())
|
|
}
|
|
|
|
jsonBody, bifrostErr := providerUtils.CheckContextAndGetRequestBody(
|
|
ctx,
|
|
request,
|
|
func() (providerUtils.RequestBodyWithExtraParams, error) {
|
|
return ToVertexRankRequest(request, options)
|
|
},
|
|
)
|
|
if bifrostErr != nil {
|
|
return nil, bifrostErr
|
|
}
|
|
|
|
completeURL := fmt.Sprintf("https://discoveryengine.googleapis.com/v1/%s:rank", options.RankingConfig)
|
|
|
|
req := fasthttp.AcquireRequest()
|
|
resp := fasthttp.AcquireResponse()
|
|
defer fasthttp.ReleaseRequest(req)
|
|
respOwned := true
|
|
defer func() {
|
|
if respOwned {
|
|
fasthttp.ReleaseResponse(resp)
|
|
}
|
|
}()
|
|
|
|
req.Header.SetMethod(http.MethodPost)
|
|
req.SetRequestURI(completeURL)
|
|
req.Header.SetContentType("application/json")
|
|
req.Header.Set("X-Goog-User-Project", projectID)
|
|
|
|
providerUtils.SetExtraHeaders(ctx, req, provider.networkConfig.ExtraHeaders, nil)
|
|
|
|
tokenSource, err := getAuthTokenSource(key)
|
|
if err != nil {
|
|
return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err)
|
|
}
|
|
token, err := tokenSource.Token()
|
|
if err != nil {
|
|
return nil, providerUtils.NewBifrostOperationError("error getting token", err)
|
|
}
|
|
req.Header.Set("Authorization", "Bearer "+token.AccessToken)
|
|
|
|
usedLargePayloadBody := providerUtils.ApplyLargePayloadRequestBody(ctx, req)
|
|
if !usedLargePayloadBody {
|
|
req.SetBody(jsonBody)
|
|
}
|
|
|
|
// Make the request with optional large response streaming
|
|
activeClient := providerUtils.PrepareResponseStreaming(ctx, provider.client, resp)
|
|
latency, bifrostErr, wait := providerUtils.MakeRequestWithContext(ctx, activeClient, req, resp)
|
|
defer wait()
|
|
if bifrostErr != nil {
|
|
return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
|
|
}
|
|
if usedLargePayloadBody {
|
|
providerUtils.DrainLargePayloadRemainder(ctx)
|
|
}
|
|
ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerUtils.ExtractProviderResponseHeaders(resp))
|
|
|
|
if resp.StatusCode() != fasthttp.StatusOK {
|
|
providerUtils.MaterializeStreamErrorBody(ctx, resp)
|
|
if resp.StatusCode() == fasthttp.StatusUnauthorized || resp.StatusCode() == fasthttp.StatusForbidden {
|
|
removeVertexClient(key.VertexKeyConfig.AuthCredentials.GetValue())
|
|
}
|
|
|
|
errorMessage := parseDiscoveryEngineErrorMessage(resp.Body())
|
|
parsedError := parseVertexError(resp)
|
|
|
|
if strings.TrimSpace(errorMessage) != "" {
|
|
shouldOverride := parsedError == nil ||
|
|
parsedError.Error == nil ||
|
|
strings.TrimSpace(parsedError.Error.Message) == "" ||
|
|
parsedError.Error.Message == "Unknown error" ||
|
|
parsedError.Error.Message == schemas.ErrProviderResponseUnmarshal
|
|
|
|
if shouldOverride {
|
|
parsedError = providerUtils.NewProviderAPIError(errorMessage, nil, resp.StatusCode(), nil, nil)
|
|
}
|
|
}
|
|
|
|
return nil, providerUtils.EnrichError(ctx, parsedError, jsonBody, resp.Body(), provider.sendBackRawRequest, provider.sendBackRawResponse)
|
|
}
|
|
|
|
responseBody, isLargeResp, decodeErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, provider.logger)
|
|
if decodeErr != nil {
|
|
return nil, providerUtils.EnrichError(ctx, decodeErr, jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
|
|
}
|
|
if isLargeResp {
|
|
respOwned = false
|
|
return &schemas.BifrostRerankResponse{
|
|
Model: request.Model,
|
|
ExtraFields: schemas.BifrostResponseExtraFields{
|
|
Latency: latency.Milliseconds(),
|
|
ProviderResponseHeaders: providerUtils.ExtractProviderResponseHeaders(resp),
|
|
},
|
|
}, nil
|
|
}
|
|
|
|
vertexResponse := &VertexRankResponse{}
|
|
rawRequest, rawResponse, bifrostErr := providerUtils.HandleProviderResponse(responseBody, vertexResponse, jsonBody, providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse))
|
|
if bifrostErr != nil {
|
|
return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonBody, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse)
|
|
}
|
|
|
|
returnDocuments := request.Params != nil && request.Params.ReturnDocuments != nil && *request.Params.ReturnDocuments
|
|
bifrostResponse, err := vertexResponse.ToBifrostRerankResponse(request.Documents, returnDocuments)
|
|
if err != nil {
|
|
return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError("error converting rerank response", err), jsonBody, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse)
|
|
}
|
|
|
|
bifrostResponse.ExtraFields.Latency = latency.Milliseconds()
|
|
bifrostResponse.ExtraFields.ProviderResponseHeaders = providerUtils.ExtractProviderResponseHeaders(resp)
|
|
|
|
if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) {
|
|
bifrostResponse.ExtraFields.RawRequest = rawRequest
|
|
}
|
|
|
|
if providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) {
|
|
bifrostResponse.ExtraFields.RawResponse = rawResponse
|
|
}
|
|
|
|
return bifrostResponse, nil
|
|
}
|
|
|
|
// OCR is not supported by the Vertex provider.
|
|
func (provider *VertexProvider) OCR(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostOCRRequest) (*schemas.BifrostOCRResponse, *schemas.BifrostError) {
|
|
return nil, providerUtils.NewUnsupportedOperationError(schemas.OCRRequest, provider.GetProviderKey())
|
|
}
|
|
|
|
// SpeechStream is not supported by the Vertex provider.
|
|
func (provider *VertexProvider) SpeechStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostSpeechRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
|
|
return nil, providerUtils.NewUnsupportedOperationError(schemas.SpeechStreamRequest, provider.GetProviderKey())
|
|
}
|
|
|
|
// Transcription is not supported by the Vertex provider.
|
|
func (provider *VertexProvider) Transcription(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (*schemas.BifrostTranscriptionResponse, *schemas.BifrostError) {
|
|
return nil, providerUtils.NewUnsupportedOperationError(schemas.TranscriptionRequest, provider.GetProviderKey())
|
|
}
|
|
|
|
// TranscriptionStream is not supported by the Vertex provider.
|
|
func (provider *VertexProvider) TranscriptionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
|
|
return nil, providerUtils.NewUnsupportedOperationError(schemas.TranscriptionStreamRequest, provider.GetProviderKey())
|
|
}
|
|
|
|
func (provider *VertexProvider) ImageGeneration(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostImageGenerationRequest) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) {
|
|
// Validate model type before processing
|
|
if !schemas.IsGeminiModel(request.Model) && !schemas.IsAllDigitsASCII(request.Model) && !schemas.IsImagenModel(request.Model) {
|
|
return nil, providerUtils.NewConfigurationError(fmt.Sprintf("image generation is only supported for Gemini and Imagen models, got: %s", request.Model))
|
|
}
|
|
|
|
jsonBody, bifrostErr := providerUtils.CheckContextAndGetRequestBody(
|
|
ctx,
|
|
request,
|
|
func() (providerUtils.RequestBodyWithExtraParams, error) {
|
|
var rawBody []byte
|
|
var extraParams map[string]interface{}
|
|
var err error
|
|
|
|
if schemas.IsGeminiModel(request.Model) || schemas.IsAllDigitsASCII(request.Model) {
|
|
reqBody := gemini.ToGeminiImageGenerationRequest(request)
|
|
if reqBody == nil {
|
|
return nil, fmt.Errorf("image generation input is not provided")
|
|
}
|
|
extraParams = reqBody.GetExtraParams()
|
|
// Strip unsupported fields for Vertex Gemini
|
|
stripVertexGeminiUnsupportedFields(reqBody)
|
|
// Marshal to JSON bytes, preserving key order
|
|
rawBody, err = providerUtils.MarshalSorted(reqBody)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to marshal request body: %w", err)
|
|
}
|
|
} else if schemas.IsImagenModel(request.Model) {
|
|
reqBody := gemini.ToImagenImageGenerationRequest(request)
|
|
if reqBody == nil {
|
|
return nil, fmt.Errorf("image generation input is not provided")
|
|
}
|
|
extraParams = reqBody.GetExtraParams()
|
|
// Marshal to JSON bytes, preserving key order
|
|
rawBody, err = providerUtils.MarshalSorted(reqBody)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to marshal request body: %w", err)
|
|
}
|
|
}
|
|
|
|
// Remove region field if present
|
|
rawBody, err = providerUtils.DeleteJSONField(rawBody, "region")
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to delete region field: %w", err)
|
|
}
|
|
return &VertexRawRequestBody{RawBody: rawBody, ExtraParams: extraParams}, nil
|
|
},
|
|
)
|
|
if bifrostErr != nil {
|
|
return nil, bifrostErr
|
|
}
|
|
|
|
projectID := key.VertexKeyConfig.ProjectID.GetValue()
|
|
if projectID == "" {
|
|
return nil, providerUtils.NewConfigurationError("project ID is not set")
|
|
}
|
|
|
|
region := key.VertexKeyConfig.Region.GetValue()
|
|
if region == "" {
|
|
return nil, providerUtils.NewConfigurationError("region is not set in key config")
|
|
}
|
|
|
|
// Auth query is used for fine-tuned models to pass the API key in the query string
|
|
authQuery := ""
|
|
// Determine the URL based on model type
|
|
var completeURL string
|
|
if schemas.IsAllDigitsASCII(request.Model) {
|
|
// Custom Fine-tuned models use OpenAPI endpoint
|
|
projectNumber := key.VertexKeyConfig.ProjectNumber.GetValue()
|
|
if projectNumber == "" {
|
|
return nil, providerUtils.NewConfigurationError("project number is not set for fine-tuned models")
|
|
}
|
|
if value := key.Value.GetValue(); value != "" {
|
|
authQuery = fmt.Sprintf("key=%s", url.QueryEscape(value))
|
|
}
|
|
if region == "global" {
|
|
completeURL = fmt.Sprintf("https://aiplatform.googleapis.com/v1beta1/projects/%s/locations/global/endpoints/%s:generateContent", projectNumber, request.Model)
|
|
} else {
|
|
completeURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1beta1/projects/%s/locations/%s/endpoints/%s:generateContent", region, projectNumber, region, request.Model)
|
|
}
|
|
|
|
} else if schemas.IsImagenModel(request.Model) {
|
|
// Imagen models are published models, use publishers/google/models path
|
|
if value := key.Value.GetValue(); value != "" {
|
|
authQuery = fmt.Sprintf("key=%s", url.QueryEscape(value))
|
|
}
|
|
if region == "global" {
|
|
completeURL = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:predict", projectID, request.Model)
|
|
} else {
|
|
completeURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:predict", region, projectID, region, request.Model)
|
|
}
|
|
} else if schemas.IsGeminiModel(request.Model) {
|
|
if value := key.Value.GetValue(); value != "" {
|
|
authQuery = fmt.Sprintf("key=%s", url.QueryEscape(value))
|
|
}
|
|
if region == "global" {
|
|
completeURL = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:generateContent", projectID, request.Model)
|
|
} else {
|
|
completeURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:generateContent", region, projectID, region, request.Model)
|
|
}
|
|
}
|
|
|
|
// Create HTTP request for image generation
|
|
req := fasthttp.AcquireRequest()
|
|
resp := fasthttp.AcquireResponse()
|
|
defer fasthttp.ReleaseRequest(req)
|
|
respOwned := true
|
|
defer func() {
|
|
if respOwned {
|
|
fasthttp.ReleaseResponse(resp)
|
|
}
|
|
}()
|
|
|
|
req.Header.SetMethod(http.MethodPost)
|
|
req.Header.SetContentType("application/json")
|
|
providerUtils.SetExtraHeaders(ctx, req, provider.networkConfig.ExtraHeaders, nil)
|
|
|
|
// If auth query is set, add it to the URL
|
|
// Otherwise, get the oauth2 token and set the Authorization header
|
|
if authQuery != "" {
|
|
completeURL = fmt.Sprintf("%s?%s", completeURL, authQuery)
|
|
} else {
|
|
// Getting oauth2 token
|
|
tokenSource, err := getAuthTokenSource(key)
|
|
if err != nil {
|
|
return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err)
|
|
}
|
|
token, err := tokenSource.Token()
|
|
if err != nil {
|
|
return nil, providerUtils.NewBifrostOperationError("error getting token", err)
|
|
}
|
|
req.Header.Set("Authorization", "Bearer "+token.AccessToken)
|
|
}
|
|
|
|
req.SetRequestURI(completeURL)
|
|
usedLargePayloadBody := providerUtils.ApplyLargePayloadRequestBody(ctx, req)
|
|
if !usedLargePayloadBody {
|
|
req.SetBody(jsonBody)
|
|
}
|
|
|
|
// Make the request with optional large response streaming
|
|
activeClient := providerUtils.PrepareResponseStreaming(ctx, provider.client, resp)
|
|
latency, bifrostErr, wait := providerUtils.MakeRequestWithContext(ctx, activeClient, req, resp)
|
|
defer wait()
|
|
if bifrostErr != nil {
|
|
return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
|
|
}
|
|
if usedLargePayloadBody {
|
|
providerUtils.DrainLargePayloadRemainder(ctx)
|
|
}
|
|
ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerUtils.ExtractProviderResponseHeaders(resp))
|
|
|
|
if resp.StatusCode() != fasthttp.StatusOK {
|
|
providerUtils.MaterializeStreamErrorBody(ctx, resp)
|
|
// Remove client from pool for authentication/authorization errors
|
|
if resp.StatusCode() == fasthttp.StatusUnauthorized || resp.StatusCode() == fasthttp.StatusForbidden {
|
|
removeVertexClient(key.VertexKeyConfig.AuthCredentials.GetValue())
|
|
}
|
|
return nil, providerUtils.EnrichError(ctx, parseVertexError(resp), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
|
|
}
|
|
|
|
responseBody, isLargeResp, decodeErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, provider.logger)
|
|
if decodeErr != nil {
|
|
return nil, providerUtils.EnrichError(ctx, decodeErr, jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
|
|
}
|
|
if isLargeResp {
|
|
respOwned = false
|
|
return &schemas.BifrostImageGenerationResponse{
|
|
ExtraFields: schemas.BifrostResponseExtraFields{
|
|
Latency: latency.Milliseconds(),
|
|
ProviderResponseHeaders: providerUtils.ExtractProviderResponseHeaders(resp),
|
|
},
|
|
}, nil
|
|
}
|
|
|
|
if schemas.IsGeminiModel(request.Model) || schemas.IsAllDigitsASCII(request.Model) {
|
|
geminiResponse := gemini.GenerateContentResponse{}
|
|
|
|
rawRequest, rawResponse, bifrostErr := providerUtils.HandleProviderResponse(responseBody, &geminiResponse, jsonBody, providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse))
|
|
if bifrostErr != nil {
|
|
return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonBody, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse)
|
|
}
|
|
|
|
response, err := geminiResponse.ToBifrostImageGenerationResponse()
|
|
if err != nil {
|
|
return nil, providerUtils.EnrichError(ctx, err, jsonBody, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse)
|
|
}
|
|
|
|
response.ExtraFields.Latency = latency.Milliseconds()
|
|
response.ExtraFields.ProviderResponseHeaders = providerUtils.ExtractProviderResponseHeaders(resp)
|
|
|
|
if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) {
|
|
response.ExtraFields.RawRequest = rawRequest
|
|
}
|
|
|
|
if providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) {
|
|
response.ExtraFields.RawResponse = rawResponse
|
|
}
|
|
|
|
return response, nil
|
|
} else {
|
|
// Handle Imagen responses
|
|
imagenResponse := gemini.GeminiImagenResponse{}
|
|
|
|
rawRequest, rawResponse, bifrostErr := providerUtils.HandleProviderResponse(responseBody, &imagenResponse, jsonBody, providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse))
|
|
if bifrostErr != nil {
|
|
return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonBody, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse)
|
|
}
|
|
|
|
response := imagenResponse.ToBifrostImageGenerationResponse()
|
|
response.ExtraFields.Latency = latency.Milliseconds()
|
|
response.ExtraFields.ProviderResponseHeaders = providerUtils.ExtractProviderResponseHeaders(resp)
|
|
|
|
if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) {
|
|
response.ExtraFields.RawRequest = rawRequest
|
|
}
|
|
|
|
if providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) {
|
|
response.ExtraFields.RawResponse = rawResponse
|
|
}
|
|
|
|
return response, nil
|
|
}
|
|
}
|
|
|
|
// ImageGenerationStream is not supported by the Vertex provider.
|
|
func (provider *VertexProvider) ImageGenerationStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostImageGenerationRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
|
|
return nil, providerUtils.NewUnsupportedOperationError(schemas.ImageGenerationStreamRequest, provider.GetProviderKey())
|
|
}
|
|
|
|
// ImageEdit edits images for the given input text(s) using Vertex AI.
|
|
// Returns a BifrostResponse containing the images and any error that occurred.
|
|
func (provider *VertexProvider) ImageEdit(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostImageEditRequest) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) {
|
|
// Validate model type before processing
|
|
if !schemas.IsGeminiModel(request.Model) && !schemas.IsAllDigitsASCII(request.Model) && !schemas.IsImagenModel(request.Model) {
|
|
return nil, providerUtils.NewConfigurationError(fmt.Sprintf("image edit is only supported for Gemini and Imagen models, got: %s", request.Model))
|
|
}
|
|
|
|
jsonBody, bifrostErr := providerUtils.CheckContextAndGetRequestBody(
|
|
ctx,
|
|
request,
|
|
func() (providerUtils.RequestBodyWithExtraParams, error) {
|
|
var rawBody []byte
|
|
var extraParams map[string]interface{}
|
|
var err error
|
|
|
|
if schemas.IsGeminiModel(request.Model) || schemas.IsAllDigitsASCII(request.Model) {
|
|
reqBody := gemini.ToGeminiImageEditRequest(request)
|
|
if reqBody == nil {
|
|
return nil, fmt.Errorf("image edit input is not provided")
|
|
}
|
|
extraParams = reqBody.GetExtraParams()
|
|
// Strip unsupported fields for Vertex Gemini
|
|
stripVertexGeminiUnsupportedFields(reqBody)
|
|
// Marshal to JSON bytes, preserving key order
|
|
rawBody, err = providerUtils.MarshalSorted(reqBody)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to marshal request body: %w", err)
|
|
}
|
|
} else if schemas.IsImagenModel(request.Model) {
|
|
reqBody := gemini.ToImagenImageEditRequest(request)
|
|
if reqBody == nil {
|
|
return nil, fmt.Errorf("image edit input is not provided")
|
|
}
|
|
extraParams = reqBody.GetExtraParams()
|
|
// Marshal to JSON bytes, preserving key order
|
|
rawBody, err = providerUtils.MarshalSorted(reqBody)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to marshal request body: %w", err)
|
|
}
|
|
}
|
|
|
|
// Remove region field if present
|
|
rawBody, err = providerUtils.DeleteJSONField(rawBody, "region")
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to delete region field: %w", err)
|
|
}
|
|
return &VertexRawRequestBody{RawBody: rawBody, ExtraParams: extraParams}, nil
|
|
},
|
|
)
|
|
if bifrostErr != nil {
|
|
return nil, bifrostErr
|
|
}
|
|
|
|
projectID := key.VertexKeyConfig.ProjectID.GetValue()
|
|
if projectID == "" {
|
|
return nil, providerUtils.NewConfigurationError("project ID is not set")
|
|
}
|
|
|
|
region := key.VertexKeyConfig.Region.GetValue()
|
|
if region == "" {
|
|
return nil, providerUtils.NewConfigurationError("region is not set in key config")
|
|
}
|
|
|
|
authQuery := ""
|
|
if value := key.Value.GetValue(); value != "" {
|
|
authQuery = fmt.Sprintf("key=%s", url.QueryEscape(value))
|
|
}
|
|
|
|
var completeURL string
|
|
if schemas.IsAllDigitsASCII(request.Model) {
|
|
projectNumber := key.VertexKeyConfig.ProjectNumber.GetValue()
|
|
if projectNumber == "" {
|
|
return nil, providerUtils.NewConfigurationError("project number is not set for fine-tuned models")
|
|
}
|
|
if region == "global" {
|
|
completeURL = fmt.Sprintf("https://aiplatform.googleapis.com/v1beta1/projects/%s/locations/global/endpoints/%s:generateContent", projectNumber, request.Model)
|
|
} else {
|
|
completeURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1beta1/projects/%s/locations/%s/endpoints/%s:generateContent", region, projectNumber, region, request.Model)
|
|
}
|
|
} else if schemas.IsImagenModel(request.Model) {
|
|
if region == "global" {
|
|
completeURL = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:predict", projectID, request.Model)
|
|
} else {
|
|
completeURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:predict", region, projectID, region, request.Model)
|
|
}
|
|
} else if schemas.IsGeminiModel(request.Model) {
|
|
if region == "global" {
|
|
completeURL = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:generateContent", projectID, request.Model)
|
|
} else {
|
|
completeURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:generateContent", region, projectID, region, request.Model)
|
|
}
|
|
}
|
|
|
|
req := fasthttp.AcquireRequest()
|
|
resp := fasthttp.AcquireResponse()
|
|
defer fasthttp.ReleaseRequest(req)
|
|
respOwned := true
|
|
defer func() {
|
|
if respOwned {
|
|
fasthttp.ReleaseResponse(resp)
|
|
}
|
|
}()
|
|
|
|
req.Header.SetMethod(http.MethodPost)
|
|
req.Header.SetContentType("application/json")
|
|
providerUtils.SetExtraHeaders(ctx, req, provider.networkConfig.ExtraHeaders, nil)
|
|
|
|
// If auth query is set, add it to the URL
|
|
// Otherwise, get the oauth2 token and set the Authorization header
|
|
if authQuery != "" {
|
|
completeURL = fmt.Sprintf("%s?%s", completeURL, authQuery)
|
|
} else {
|
|
// Getting oauth2 token
|
|
tokenSource, err := getAuthTokenSource(key)
|
|
if err != nil {
|
|
return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err)
|
|
}
|
|
token, err := tokenSource.Token()
|
|
if err != nil {
|
|
return nil, providerUtils.NewBifrostOperationError("error getting token", err)
|
|
}
|
|
req.Header.Set("Authorization", "Bearer "+token.AccessToken)
|
|
}
|
|
|
|
req.SetRequestURI(completeURL)
|
|
usedLargePayloadBody := providerUtils.ApplyLargePayloadRequestBody(ctx, req)
|
|
if !usedLargePayloadBody {
|
|
req.SetBody(jsonBody)
|
|
}
|
|
|
|
// Make the request with optional large response streaming
|
|
activeClient := providerUtils.PrepareResponseStreaming(ctx, provider.client, resp)
|
|
latency, bifrostErr, wait := providerUtils.MakeRequestWithContext(ctx, activeClient, req, resp)
|
|
defer wait()
|
|
if bifrostErr != nil {
|
|
return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
|
|
}
|
|
if usedLargePayloadBody {
|
|
providerUtils.DrainLargePayloadRemainder(ctx)
|
|
}
|
|
ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerUtils.ExtractProviderResponseHeaders(resp))
|
|
|
|
if resp.StatusCode() != fasthttp.StatusOK {
|
|
providerUtils.MaterializeStreamErrorBody(ctx, resp)
|
|
if resp.StatusCode() == fasthttp.StatusUnauthorized || resp.StatusCode() == fasthttp.StatusForbidden {
|
|
removeVertexClient(key.VertexKeyConfig.AuthCredentials.GetValue())
|
|
}
|
|
return nil, providerUtils.EnrichError(ctx, parseVertexError(resp), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
|
|
}
|
|
|
|
responseBody, isLargeResp, decodeErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, provider.logger)
|
|
if decodeErr != nil {
|
|
return nil, providerUtils.EnrichError(ctx, decodeErr, jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
|
|
}
|
|
if isLargeResp {
|
|
respOwned = false
|
|
return &schemas.BifrostImageGenerationResponse{
|
|
ExtraFields: schemas.BifrostResponseExtraFields{
|
|
Latency: latency.Milliseconds(),
|
|
ProviderResponseHeaders: providerUtils.ExtractProviderResponseHeaders(resp),
|
|
},
|
|
}, nil
|
|
}
|
|
|
|
if schemas.IsGeminiModel(request.Model) || schemas.IsAllDigitsASCII(request.Model) {
|
|
geminiResponse := gemini.GenerateContentResponse{}
|
|
|
|
rawRequest, rawResponse, bifrostErr := providerUtils.HandleProviderResponse(responseBody, &geminiResponse, jsonBody, providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse))
|
|
if bifrostErr != nil {
|
|
return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonBody, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse)
|
|
}
|
|
|
|
response, err := geminiResponse.ToBifrostImageGenerationResponse()
|
|
if err != nil {
|
|
return nil, providerUtils.EnrichError(ctx, err, jsonBody, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse)
|
|
}
|
|
|
|
response.ExtraFields.Latency = latency.Milliseconds()
|
|
response.ExtraFields.ProviderResponseHeaders = providerUtils.ExtractProviderResponseHeaders(resp)
|
|
|
|
if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) {
|
|
response.ExtraFields.RawRequest = rawRequest
|
|
}
|
|
|
|
if providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) {
|
|
response.ExtraFields.RawResponse = rawResponse
|
|
}
|
|
|
|
return response, nil
|
|
} else {
|
|
// Handle Imagen responses
|
|
imagenResponse := gemini.GeminiImagenResponse{}
|
|
|
|
rawRequest, rawResponse, bifrostErr := providerUtils.HandleProviderResponse(responseBody, &imagenResponse, jsonBody, providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse))
|
|
if bifrostErr != nil {
|
|
return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonBody, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse)
|
|
}
|
|
|
|
response := imagenResponse.ToBifrostImageGenerationResponse()
|
|
response.ExtraFields.Latency = latency.Milliseconds()
|
|
response.ExtraFields.ProviderResponseHeaders = providerUtils.ExtractProviderResponseHeaders(resp)
|
|
|
|
if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) {
|
|
response.ExtraFields.RawRequest = rawRequest
|
|
}
|
|
if providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) {
|
|
response.ExtraFields.RawResponse = rawResponse
|
|
}
|
|
|
|
return response, nil
|
|
}
|
|
}
|
|
|
|
// ImageEditStream is not supported by the Vertex provider.
|
|
func (provider *VertexProvider) ImageEditStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostImageEditRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
|
|
return nil, providerUtils.NewUnsupportedOperationError(schemas.ImageEditStreamRequest, provider.GetProviderKey())
|
|
}
|
|
|
|
// ImageVariation is not supported by the Vertex provider.
|
|
func (provider *VertexProvider) ImageVariation(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostImageVariationRequest) (*schemas.BifrostImageGenerationResponse, *schemas.BifrostError) {
|
|
return nil, providerUtils.NewUnsupportedOperationError(schemas.ImageVariationRequest, provider.GetProviderKey())
|
|
}
|
|
|
|
// VideoGeneration generates a video using Vertex AI's Gemini models.
|
|
// Only Gemini models support video generation in Vertex AI.
|
|
// Uses the predictLongRunning endpoint for async video generation.
|
|
func (provider *VertexProvider) VideoGeneration(ctx *schemas.BifrostContext, key schemas.Key, bifrostReq *schemas.BifrostVideoGenerationRequest) (*schemas.BifrostVideoGenerationResponse, *schemas.BifrostError) {
|
|
providerName := provider.GetProviderKey()
|
|
|
|
// Only Gemini models support video generation in Vertex
|
|
if !schemas.IsVeoModel(bifrostReq.Model) && !schemas.IsAllDigitsASCII(bifrostReq.Model) {
|
|
return nil, providerUtils.NewConfigurationError(fmt.Sprintf("video generation is only supported for Veo models in Vertex, got: %s", bifrostReq.Model))
|
|
}
|
|
|
|
// Convert Bifrost request to Gemini format (reusing Gemini converters)
|
|
jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody(
|
|
ctx,
|
|
bifrostReq,
|
|
func() (providerUtils.RequestBodyWithExtraParams, error) {
|
|
return gemini.ToGeminiVideoGenerationRequest(bifrostReq)
|
|
},
|
|
)
|
|
if bifrostErr != nil {
|
|
return nil, bifrostErr
|
|
}
|
|
|
|
projectID := key.VertexKeyConfig.ProjectID.GetValue()
|
|
if projectID == "" {
|
|
return nil, providerUtils.NewConfigurationError("project ID is not set")
|
|
}
|
|
|
|
region := key.VertexKeyConfig.Region.GetValue()
|
|
if region == "" {
|
|
return nil, providerUtils.NewConfigurationError("region is not set in key config")
|
|
}
|
|
|
|
// Auth query is used to pass the API key in the query string
|
|
authQuery := ""
|
|
if key.Value.GetValue() != "" {
|
|
authQuery = fmt.Sprintf("key=%s", url.QueryEscape(key.Value.GetValue()))
|
|
}
|
|
|
|
// For custom/fine-tuned models, validate projectNumber is set
|
|
projectNumber := key.VertexKeyConfig.ProjectNumber.GetValue()
|
|
if schemas.IsAllDigitsASCII(bifrostReq.Model) && projectNumber == "" {
|
|
return nil, providerUtils.NewConfigurationError("project number is not set for fine-tuned models")
|
|
}
|
|
|
|
// Construct the URL for Gemini video generation using predictLongRunning
|
|
completeURL := getCompleteURLForGeminiEndpoint(bifrostReq.Model, region, projectID, projectNumber, ":predictLongRunning")
|
|
|
|
// Create HTTP request
|
|
req := fasthttp.AcquireRequest()
|
|
resp := fasthttp.AcquireResponse()
|
|
defer fasthttp.ReleaseRequest(req)
|
|
defer fasthttp.ReleaseResponse(resp)
|
|
|
|
req.Header.SetMethod(http.MethodPost)
|
|
req.Header.SetContentType("application/json")
|
|
providerUtils.SetExtraHeaders(ctx, req, provider.networkConfig.ExtraHeaders, nil)
|
|
|
|
// If auth query is set, add it to the URL
|
|
// Otherwise, get the oauth2 token and set the Authorization header
|
|
if authQuery != "" {
|
|
completeURL = fmt.Sprintf("%s?%s", completeURL, authQuery)
|
|
} else {
|
|
tokenSource, err := getAuthTokenSource(key)
|
|
if err != nil {
|
|
return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err)
|
|
}
|
|
token, err := tokenSource.Token()
|
|
if err != nil {
|
|
return nil, providerUtils.NewBifrostOperationError("error getting token", err)
|
|
}
|
|
req.Header.Set("Authorization", "Bearer "+token.AccessToken)
|
|
}
|
|
|
|
req.SetRequestURI(completeURL)
|
|
req.SetBody(jsonData)
|
|
|
|
latency, bifrostErr, wait := providerUtils.MakeRequestWithContext(ctx, provider.client, req, resp)
|
|
defer wait()
|
|
if bifrostErr != nil {
|
|
return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
|
|
}
|
|
ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerUtils.ExtractProviderResponseHeaders(resp))
|
|
|
|
// Handle error response
|
|
if resp.StatusCode() != fasthttp.StatusOK {
|
|
if resp.StatusCode() == fasthttp.StatusUnauthorized || resp.StatusCode() == fasthttp.StatusForbidden {
|
|
removeVertexClient(key.VertexKeyConfig.AuthCredentials.GetValue())
|
|
}
|
|
return nil, providerUtils.EnrichError(ctx, parseVertexError(resp), jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
|
|
}
|
|
|
|
// Parse response
|
|
body, err := providerUtils.CheckAndDecodeBody(resp)
|
|
if err != nil {
|
|
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err)
|
|
}
|
|
|
|
var operation gemini.GenerateVideosOperation
|
|
rawRequest, rawResponse, bifrostErr := providerUtils.HandleProviderResponse(body, &operation, jsonData, providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse))
|
|
if bifrostErr != nil {
|
|
return nil, bifrostErr
|
|
}
|
|
// Convert to Bifrost response using Gemini converter
|
|
bifrostResp, bifrostErr := gemini.ToBifrostVideoGenerationResponse(&operation, bifrostReq.Model)
|
|
if bifrostErr != nil {
|
|
return nil, bifrostErr
|
|
}
|
|
bifrostResp.ID = providerUtils.AddVideoIDProviderSuffix(bifrostResp.ID, providerName)
|
|
|
|
bifrostResp.ExtraFields.Latency = latency.Milliseconds()
|
|
bifrostResp.ExtraFields.ProviderResponseHeaders = providerUtils.ExtractProviderResponseHeaders(resp)
|
|
|
|
if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) {
|
|
bifrostResp.ExtraFields.RawRequest = rawRequest
|
|
}
|
|
if providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) {
|
|
bifrostResp.ExtraFields.RawResponse = rawResponse
|
|
}
|
|
|
|
return bifrostResp, nil
|
|
}
|
|
|
|
// VideoRetrieve retrieves the status of a video generation operation.
|
|
// Uses the fetchPredictOperation endpoint for Vertex AI.
|
|
func (provider *VertexProvider) VideoRetrieve(ctx *schemas.BifrostContext, key schemas.Key, bifrostReq *schemas.BifrostVideoRetrieveRequest) (*schemas.BifrostVideoGenerationResponse, *schemas.BifrostError) {
|
|
sendBackRawResponse := providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse)
|
|
sendBackRawRequest := providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest)
|
|
|
|
region := key.VertexKeyConfig.Region.GetValue()
|
|
if region == "" {
|
|
return nil, providerUtils.NewConfigurationError("region is not set in key config")
|
|
}
|
|
|
|
// Construct base URL based on region
|
|
var baseURL string
|
|
if region == "global" {
|
|
baseURL = "https://aiplatform.googleapis.com/v1"
|
|
} else {
|
|
baseURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1", region)
|
|
}
|
|
|
|
// Construct the URL for fetching the operation status
|
|
// The operation name (bifrostReq.ID) already contains the full path:
|
|
// projects/PROJECT_ID/locations/REGION/publishers/google/models/MODEL_ID/operations/OPERATION_ID
|
|
// We need to extract the model path from it to construct the fetchPredictOperation endpoint
|
|
// Extract: projects/.../models/MODEL_ID from the operation name
|
|
taskID := providerUtils.StripVideoIDProviderSuffix(bifrostReq.ID, provider.GetProviderKey())
|
|
var modelPath string
|
|
if idx := strings.Index(taskID, "/operations/"); idx != -1 {
|
|
modelPath = taskID[:idx]
|
|
} else {
|
|
return nil, providerUtils.NewBifrostOperationError("invalid operation ID format", nil)
|
|
}
|
|
|
|
// Construct the URL: https://REGION-aiplatform.googleapis.com/v1/{modelPath}:fetchPredictOperation
|
|
completeURL := fmt.Sprintf("%s/%s:fetchPredictOperation", baseURL, modelPath)
|
|
|
|
// Auth query is used to pass the API key in the query string
|
|
authQuery := ""
|
|
if key.Value.GetValue() != "" {
|
|
authQuery = fmt.Sprintf("key=%s", url.QueryEscape(key.Value.GetValue()))
|
|
}
|
|
|
|
// Create request body with operation name (using sjson to avoid map marshaling)
|
|
jsonBody, err := providerUtils.SetJSONField([]byte(`{}`), "operationName", taskID)
|
|
if err != nil {
|
|
return nil, providerUtils.NewBifrostOperationError("failed to marshal request", err)
|
|
}
|
|
|
|
// Create HTTP request
|
|
req := fasthttp.AcquireRequest()
|
|
resp := fasthttp.AcquireResponse()
|
|
defer fasthttp.ReleaseRequest(req)
|
|
defer fasthttp.ReleaseResponse(resp)
|
|
|
|
req.Header.SetMethod(http.MethodPost)
|
|
req.Header.SetContentType("application/json")
|
|
providerUtils.SetExtraHeaders(ctx, req, provider.networkConfig.ExtraHeaders, nil)
|
|
|
|
// If auth query is set, add it to the URL
|
|
// Otherwise, get the oauth2 token and set the Authorization header
|
|
if authQuery != "" {
|
|
completeURL = fmt.Sprintf("%s?%s", completeURL, authQuery)
|
|
} else {
|
|
tokenSource, err := getAuthTokenSource(key)
|
|
if err != nil {
|
|
return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err)
|
|
}
|
|
token, err := tokenSource.Token()
|
|
if err != nil {
|
|
return nil, providerUtils.NewBifrostOperationError("error getting token", err)
|
|
}
|
|
req.Header.Set("Authorization", "Bearer "+token.AccessToken)
|
|
}
|
|
|
|
req.SetRequestURI(completeURL)
|
|
req.SetBody(jsonBody)
|
|
|
|
latency, bifrostErr, wait := providerUtils.MakeRequestWithContext(ctx, provider.client, req, resp)
|
|
defer wait()
|
|
if bifrostErr != nil {
|
|
return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonBody, nil, sendBackRawRequest, sendBackRawResponse)
|
|
}
|
|
ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerUtils.ExtractProviderResponseHeaders(resp))
|
|
|
|
// Handle error response
|
|
if resp.StatusCode() != fasthttp.StatusOK {
|
|
if resp.StatusCode() == fasthttp.StatusUnauthorized || resp.StatusCode() == fasthttp.StatusForbidden {
|
|
removeVertexClient(key.VertexKeyConfig.AuthCredentials.GetValue())
|
|
}
|
|
return nil, providerUtils.EnrichError(ctx, parseVertexError(resp), jsonBody, nil, sendBackRawRequest, sendBackRawResponse)
|
|
}
|
|
|
|
// Parse response
|
|
var operation gemini.GenerateVideosOperation
|
|
_, rawResponse, bifrostErr := providerUtils.HandleProviderResponse(resp.Body(), &operation, jsonBody, sendBackRawRequest, sendBackRawResponse)
|
|
if bifrostErr != nil {
|
|
return nil, bifrostErr
|
|
}
|
|
|
|
bifrostResp, bifrostErr := gemini.ToBifrostVideoGenerationResponse(&operation, "")
|
|
if bifrostErr != nil {
|
|
return nil, bifrostErr
|
|
}
|
|
bifrostResp.ID = providerUtils.AddVideoIDProviderSuffix(bifrostResp.ID, provider.GetProviderKey())
|
|
bifrostResp.ExtraFields.Latency = latency.Milliseconds()
|
|
bifrostResp.ExtraFields.ProviderResponseHeaders = providerUtils.ExtractProviderResponseHeaders(resp)
|
|
|
|
if sendBackRawResponse {
|
|
bifrostResp.ExtraFields.RawResponse = rawResponse
|
|
}
|
|
|
|
return bifrostResp, nil
|
|
}
|
|
|
|
// VideoDownload downloads the generated video content.
|
|
// First retrieves the video status to get the URL, then downloads the content.
|
|
// Handles both regular URLs and data URLs (base64-encoded videos).
|
|
func (provider *VertexProvider) VideoDownload(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostVideoDownloadRequest) (*schemas.BifrostVideoDownloadResponse, *schemas.BifrostError) {
|
|
if request == nil || request.ID == "" {
|
|
return nil, providerUtils.NewBifrostOperationError("video_id is required", nil)
|
|
}
|
|
// Retrieve operation first to get the video URL
|
|
bifrostVideoRetrieveRequest := &schemas.BifrostVideoRetrieveRequest{
|
|
Provider: request.Provider,
|
|
ID: request.ID,
|
|
}
|
|
videoResp, bifrostErr := provider.VideoRetrieve(ctx, key, bifrostVideoRetrieveRequest)
|
|
if bifrostErr != nil {
|
|
return nil, bifrostErr
|
|
}
|
|
if videoResp.Status != schemas.VideoStatusCompleted {
|
|
return nil, providerUtils.NewBifrostOperationError(
|
|
fmt.Sprintf("video not ready, current status: %s", videoResp.Status),
|
|
nil)
|
|
}
|
|
if len(videoResp.Videos) == 0 {
|
|
return nil, providerUtils.NewBifrostOperationError("video URL not available", nil)
|
|
}
|
|
var content []byte
|
|
var latency time.Duration
|
|
var providerResponseHeaders map[string]string
|
|
contentType := "video/mp4"
|
|
// Check if it's a data URL (base64-encoded video)
|
|
if videoResp.Videos[0].Type == schemas.VideoOutputTypeBase64 && videoResp.Videos[0].Base64Data != nil {
|
|
// Decode base64 content
|
|
startTime := time.Now()
|
|
decoded, err := base64.StdEncoding.DecodeString(*videoResp.Videos[0].Base64Data)
|
|
if err != nil {
|
|
return nil, providerUtils.NewBifrostOperationError("failed to decode base64 video data", err)
|
|
}
|
|
content = decoded
|
|
contentType = videoResp.Videos[0].ContentType
|
|
latency = time.Since(startTime)
|
|
} else if videoResp.Videos[0].Type == schemas.VideoOutputTypeURL && videoResp.Videos[0].URL != nil {
|
|
// Regular URL - fetch from HTTP endpoint
|
|
req := fasthttp.AcquireRequest()
|
|
resp := fasthttp.AcquireResponse()
|
|
defer fasthttp.ReleaseRequest(req)
|
|
defer fasthttp.ReleaseResponse(resp)
|
|
providerUtils.SetExtraHeaders(ctx, req, provider.networkConfig.ExtraHeaders, nil)
|
|
req.SetRequestURI(*videoResp.Videos[0].URL)
|
|
req.Header.SetMethod(http.MethodGet)
|
|
// Add authentication for Vertex video downloads
|
|
authQuery := ""
|
|
if key.Value.GetValue() != "" {
|
|
authQuery = fmt.Sprintf("key=%s", url.QueryEscape(key.Value.GetValue()))
|
|
}
|
|
if authQuery != "" {
|
|
uri := *videoResp.Videos[0].URL
|
|
if strings.Contains(uri, "?") {
|
|
uri += "&" + authQuery
|
|
} else {
|
|
uri += "?" + authQuery
|
|
}
|
|
req.SetRequestURI(uri)
|
|
} else {
|
|
tokenSource, err := getAuthTokenSource(key)
|
|
if err != nil {
|
|
return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err)
|
|
}
|
|
token, err := tokenSource.Token()
|
|
if err != nil {
|
|
return nil, providerUtils.NewBifrostOperationError("error getting token", err)
|
|
}
|
|
req.Header.Set("Authorization", "Bearer "+token.AccessToken)
|
|
}
|
|
var bifrostErr *schemas.BifrostError
|
|
var wait func()
|
|
latency, bifrostErr, wait = providerUtils.MakeRequestWithContext(ctx, provider.client, req, resp)
|
|
defer wait()
|
|
if bifrostErr != nil {
|
|
return nil, bifrostErr
|
|
}
|
|
ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerUtils.ExtractProviderResponseHeaders(resp))
|
|
if resp.StatusCode() != fasthttp.StatusOK {
|
|
return nil, providerUtils.NewBifrostOperationError(
|
|
fmt.Sprintf("failed to download video: HTTP %d", resp.StatusCode()),
|
|
nil)
|
|
}
|
|
body, err := providerUtils.CheckAndDecodeBody(resp)
|
|
if err != nil {
|
|
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err)
|
|
}
|
|
contentType = string(resp.Header.ContentType())
|
|
content = append([]byte(nil), body...)
|
|
providerResponseHeaders = providerUtils.ExtractProviderResponseHeaders(resp)
|
|
} else {
|
|
return nil, providerUtils.NewBifrostOperationError("invalid video output type", nil)
|
|
}
|
|
|
|
bifrostResp := &schemas.BifrostVideoDownloadResponse{
|
|
VideoID: request.ID,
|
|
Content: content,
|
|
ContentType: contentType,
|
|
}
|
|
|
|
bifrostResp.ExtraFields.Latency = latency.Milliseconds()
|
|
bifrostResp.ExtraFields.ProviderResponseHeaders = providerResponseHeaders
|
|
|
|
return bifrostResp, nil
|
|
}
|
|
|
|
// VideoDelete is not supported by the Vertex provider.
|
|
func (provider *VertexProvider) VideoDelete(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostVideoDeleteRequest) (*schemas.BifrostVideoDeleteResponse, *schemas.BifrostError) {
|
|
return nil, providerUtils.NewUnsupportedOperationError(schemas.VideoDeleteRequest, provider.GetProviderKey())
|
|
}
|
|
|
|
// VideoList is not supported by the Vertex provider.
|
|
func (provider *VertexProvider) VideoList(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostVideoListRequest) (*schemas.BifrostVideoListResponse, *schemas.BifrostError) {
|
|
return nil, providerUtils.NewUnsupportedOperationError(schemas.VideoListRequest, provider.GetProviderKey())
|
|
}
|
|
|
|
// VideoRemix is not supported by the Vertex provider.
|
|
func (provider *VertexProvider) VideoRemix(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostVideoRemixRequest) (*schemas.BifrostVideoGenerationResponse, *schemas.BifrostError) {
|
|
return nil, providerUtils.NewUnsupportedOperationError(schemas.VideoRemixRequest, provider.GetProviderKey())
|
|
}
|
|
|
|
// stripVertexGeminiUnsupportedFields removes fields that are not supported by Vertex AI's Gemini API.
|
|
// Specifically, it removes the "id" field from function_call and function_response objects in contents.
|
|
func stripVertexGeminiUnsupportedFields(requestBody *gemini.GeminiGenerationRequest) {
|
|
for _, content := range requestBody.Contents {
|
|
for _, part := range content.Parts {
|
|
// Remove id from function_call
|
|
if part.FunctionCall != nil {
|
|
part.FunctionCall.ID = ""
|
|
}
|
|
// Remove id from function_response
|
|
if part.FunctionResponse != nil {
|
|
part.FunctionResponse.ID = ""
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// BatchCreate is not supported by Vertex AI provider.
|
|
func (provider *VertexProvider) BatchCreate(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostBatchCreateRequest) (*schemas.BifrostBatchCreateResponse, *schemas.BifrostError) {
|
|
return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchCreateRequest, provider.GetProviderKey())
|
|
}
|
|
|
|
// BatchList is not supported by Vertex AI provider.
|
|
func (provider *VertexProvider) BatchList(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchListRequest) (*schemas.BifrostBatchListResponse, *schemas.BifrostError) {
|
|
return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchListRequest, provider.GetProviderKey())
|
|
}
|
|
|
|
// BatchRetrieve is not supported by Vertex AI provider.
|
|
func (provider *VertexProvider) BatchRetrieve(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchRetrieveRequest) (*schemas.BifrostBatchRetrieveResponse, *schemas.BifrostError) {
|
|
return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchRetrieveRequest, provider.GetProviderKey())
|
|
}
|
|
|
|
// BatchCancel is not supported by Vertex AI provider.
|
|
func (provider *VertexProvider) BatchCancel(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchCancelRequest) (*schemas.BifrostBatchCancelResponse, *schemas.BifrostError) {
|
|
return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchCancelRequest, provider.GetProviderKey())
|
|
}
|
|
|
|
// BatchDelete is not supported by Vertex AI provider.
|
|
func (provider *VertexProvider) BatchDelete(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchDeleteRequest) (*schemas.BifrostBatchDeleteResponse, *schemas.BifrostError) {
|
|
return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchDeleteRequest, provider.GetProviderKey())
|
|
}
|
|
|
|
// BatchResults is not supported by Vertex AI provider.
|
|
func (provider *VertexProvider) BatchResults(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostBatchResultsRequest) (*schemas.BifrostBatchResultsResponse, *schemas.BifrostError) {
|
|
return nil, providerUtils.NewUnsupportedOperationError(schemas.BatchResultsRequest, provider.GetProviderKey())
|
|
}
|
|
|
|
// FileUpload is not yet implemented for Vertex AI provider.
|
|
// Vertex AI uses Google Cloud Storage (GCS) for batch input/output files.
|
|
func (provider *VertexProvider) FileUpload(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostFileUploadRequest) (*schemas.BifrostFileUploadResponse, *schemas.BifrostError) {
|
|
return nil, providerUtils.NewUnsupportedOperationError(schemas.FileUploadRequest, provider.GetProviderKey())
|
|
}
|
|
|
|
// FileList is not yet implemented for Vertex AI provider.
|
|
func (provider *VertexProvider) FileList(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostFileListRequest) (*schemas.BifrostFileListResponse, *schemas.BifrostError) {
|
|
return nil, providerUtils.NewUnsupportedOperationError(schemas.FileListRequest, provider.GetProviderKey())
|
|
}
|
|
|
|
// FileRetrieve is not yet implemented for Vertex AI provider.
|
|
func (provider *VertexProvider) FileRetrieve(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostFileRetrieveRequest) (*schemas.BifrostFileRetrieveResponse, *schemas.BifrostError) {
|
|
return nil, providerUtils.NewUnsupportedOperationError(schemas.FileRetrieveRequest, provider.GetProviderKey())
|
|
}
|
|
|
|
// FileDelete is not yet implemented for Vertex AI provider.
|
|
func (provider *VertexProvider) FileDelete(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostFileDeleteRequest) (*schemas.BifrostFileDeleteResponse, *schemas.BifrostError) {
|
|
return nil, providerUtils.NewUnsupportedOperationError(schemas.FileDeleteRequest, provider.GetProviderKey())
|
|
}
|
|
|
|
// FileContent is not yet implemented for Vertex AI provider.
|
|
func (provider *VertexProvider) FileContent(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostFileContentRequest) (*schemas.BifrostFileContentResponse, *schemas.BifrostError) {
|
|
return nil, providerUtils.NewUnsupportedOperationError(schemas.FileContentRequest, provider.GetProviderKey())
|
|
}
|
|
|
|
// CountTokens counts the number of tokens in the provided content using Vertex AI's countTokens endpoint.
|
|
// Supports Gemini models with both text and image content.
|
|
func (provider *VertexProvider) CountTokens(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostResponsesRequest) (*schemas.BifrostCountTokensResponse, *schemas.BifrostError) {
|
|
var (
|
|
jsonBody []byte
|
|
bifrostErr *schemas.BifrostError
|
|
)
|
|
|
|
if schemas.IsAnthropicModel(request.Model) {
|
|
jsonBody, bifrostErr = getRequestBodyForAnthropicResponses(ctx, request, request.Model, false, true, provider.networkConfig.BetaHeaderOverrides, provider.networkConfig.ExtraHeaders)
|
|
if bifrostErr != nil {
|
|
return nil, bifrostErr
|
|
}
|
|
} else {
|
|
jsonBody, bifrostErr = providerUtils.CheckContextAndGetRequestBody(
|
|
ctx,
|
|
request,
|
|
func() (providerUtils.RequestBodyWithExtraParams, error) {
|
|
return gemini.ToGeminiResponsesRequest(request)
|
|
},
|
|
)
|
|
if bifrostErr != nil {
|
|
return nil, bifrostErr
|
|
}
|
|
|
|
// Skip field-stripping when large payload mode is active — jsonBody is nil
|
|
// and the raw body will stream directly from the ingress reader.
|
|
if jsonBody != nil {
|
|
// Use sjson to delete fields directly from JSON bytes, preserving key ordering
|
|
jsonBody, _ = providerUtils.DeleteJSONField(jsonBody, "toolConfig")
|
|
jsonBody, _ = providerUtils.DeleteJSONField(jsonBody, "generationConfig")
|
|
jsonBody, _ = providerUtils.DeleteJSONField(jsonBody, "systemInstruction")
|
|
}
|
|
}
|
|
|
|
projectID := key.VertexKeyConfig.ProjectID.GetValue()
|
|
if projectID == "" {
|
|
return nil, providerUtils.NewConfigurationError("project ID is not set")
|
|
}
|
|
|
|
region := key.VertexKeyConfig.Region.GetValue()
|
|
if region == "" {
|
|
return nil, providerUtils.NewConfigurationError("region is not set in key config")
|
|
}
|
|
|
|
authQuery := ""
|
|
var completeURL string
|
|
|
|
if schemas.IsAnthropicModel(request.Model) {
|
|
if region == "global" {
|
|
completeURL = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/anthropic/models/count-tokens:rawPredict", projectID)
|
|
} else {
|
|
completeURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/anthropic/models/count-tokens:rawPredict", region, projectID, region)
|
|
}
|
|
} else if schemas.IsGeminiModel(request.Model) || schemas.IsAllDigitsASCII(request.Model) {
|
|
if key.Value.GetValue() != "" {
|
|
authQuery = fmt.Sprintf("key=%s", url.QueryEscape(key.Value.GetValue()))
|
|
}
|
|
|
|
projectNumber := key.VertexKeyConfig.ProjectNumber.GetValue()
|
|
if schemas.IsAllDigitsASCII(request.Model) && projectNumber == "" {
|
|
return nil, providerUtils.NewConfigurationError("project number is not set for fine-tuned models")
|
|
}
|
|
|
|
completeURL = getCompleteURLForGeminiEndpoint(request.Model, region, projectID, projectNumber, ":countTokens")
|
|
}
|
|
|
|
if completeURL == "" {
|
|
return nil, providerUtils.NewConfigurationError(fmt.Sprintf("count tokens is not supported for model: %s", request.Model))
|
|
}
|
|
|
|
req := fasthttp.AcquireRequest()
|
|
resp := fasthttp.AcquireResponse()
|
|
defer fasthttp.ReleaseRequest(req)
|
|
respOwned := true
|
|
defer func() {
|
|
if respOwned {
|
|
fasthttp.ReleaseResponse(resp)
|
|
}
|
|
}()
|
|
|
|
req.Header.SetMethod(http.MethodPost)
|
|
req.Header.SetContentType("application/json")
|
|
providerUtils.SetExtraHeaders(ctx, req, provider.networkConfig.ExtraHeaders, []string{anthropic.AnthropicBetaHeader})
|
|
|
|
if authQuery != "" {
|
|
completeURL = fmt.Sprintf("%s?%s", completeURL, authQuery)
|
|
} else {
|
|
tokenSource, err := getAuthTokenSource(key)
|
|
if err != nil {
|
|
return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err)
|
|
}
|
|
token, err := tokenSource.Token()
|
|
if err != nil {
|
|
return nil, providerUtils.NewBifrostOperationError("error getting token", err)
|
|
}
|
|
req.Header.Set("Authorization", "Bearer "+token.AccessToken)
|
|
}
|
|
|
|
req.SetRequestURI(completeURL)
|
|
usedLargePayloadBody := providerUtils.ApplyLargePayloadRequestBody(ctx, req)
|
|
if !usedLargePayloadBody {
|
|
req.SetBody(jsonBody)
|
|
}
|
|
|
|
// Make the request with optional large response streaming
|
|
activeClient := providerUtils.PrepareResponseStreaming(ctx, provider.client, resp)
|
|
latency, bifrostErr, wait := providerUtils.MakeRequestWithContext(ctx, activeClient, req, resp)
|
|
defer wait()
|
|
if bifrostErr != nil {
|
|
return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
|
|
}
|
|
if usedLargePayloadBody {
|
|
providerUtils.DrainLargePayloadRemainder(ctx)
|
|
}
|
|
ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, providerUtils.ExtractProviderResponseHeaders(resp))
|
|
|
|
if resp.StatusCode() != fasthttp.StatusOK {
|
|
providerUtils.MaterializeStreamErrorBody(ctx, resp)
|
|
if resp.StatusCode() == fasthttp.StatusUnauthorized || resp.StatusCode() == fasthttp.StatusForbidden {
|
|
removeVertexClient(key.VertexKeyConfig.AuthCredentials.GetValue())
|
|
}
|
|
return nil, providerUtils.EnrichError(ctx, parseVertexError(resp), jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
|
|
}
|
|
|
|
responseBody, isLargeResp, decodeErr := providerUtils.FinalizeResponseWithLargeDetection(ctx, resp, provider.logger)
|
|
if decodeErr != nil {
|
|
return nil, providerUtils.EnrichError(ctx, decodeErr, jsonBody, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
|
|
}
|
|
if isLargeResp {
|
|
respOwned = false
|
|
return &schemas.BifrostCountTokensResponse{
|
|
ExtraFields: schemas.BifrostResponseExtraFields{
|
|
Latency: latency.Milliseconds(),
|
|
ProviderResponseHeaders: providerUtils.ExtractProviderResponseHeaders(resp),
|
|
},
|
|
}, nil
|
|
}
|
|
|
|
if schemas.IsAnthropicModel(request.Model) {
|
|
anthropicResponse := &anthropic.AnthropicCountTokensResponse{}
|
|
|
|
rawRequest, rawResponse, bifrostErr := providerUtils.HandleProviderResponse(responseBody, anthropicResponse, jsonBody, providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse))
|
|
if bifrostErr != nil {
|
|
return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonBody, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse)
|
|
}
|
|
|
|
response := anthropicResponse.ToBifrostCountTokensResponse(request.Model)
|
|
response.ExtraFields.Latency = latency.Milliseconds()
|
|
response.ExtraFields.ProviderResponseHeaders = providerUtils.ExtractProviderResponseHeaders(resp)
|
|
|
|
if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) {
|
|
response.ExtraFields.RawRequest = rawRequest
|
|
}
|
|
|
|
if providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) {
|
|
response.ExtraFields.RawResponse = rawResponse
|
|
}
|
|
|
|
return response, nil
|
|
}
|
|
|
|
vertexResponse := VertexCountTokensResponse{}
|
|
|
|
rawRequest, rawResponse, bifrostErr := providerUtils.HandleProviderResponse(responseBody, &vertexResponse, jsonBody, providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest), providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse))
|
|
if bifrostErr != nil {
|
|
return nil, providerUtils.EnrichError(ctx, bifrostErr, jsonBody, responseBody, provider.sendBackRawRequest, provider.sendBackRawResponse)
|
|
}
|
|
|
|
response := vertexResponse.ToBifrostCountTokensResponse(request.Model)
|
|
response.ExtraFields.Latency = latency.Milliseconds()
|
|
response.ExtraFields.ProviderResponseHeaders = providerUtils.ExtractProviderResponseHeaders(resp)
|
|
|
|
if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) {
|
|
response.ExtraFields.RawRequest = rawRequest
|
|
}
|
|
|
|
if providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) {
|
|
response.ExtraFields.RawResponse = rawResponse
|
|
}
|
|
|
|
return response, nil
|
|
}
|
|
|
|
// ContainerCreate is not supported by the Vertex provider.
|
|
func (provider *VertexProvider) ContainerCreate(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostContainerCreateRequest) (*schemas.BifrostContainerCreateResponse, *schemas.BifrostError) {
|
|
return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerCreateRequest, provider.GetProviderKey())
|
|
}
|
|
|
|
// ContainerList is not supported by the Vertex provider.
|
|
func (provider *VertexProvider) ContainerList(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostContainerListRequest) (*schemas.BifrostContainerListResponse, *schemas.BifrostError) {
|
|
return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerListRequest, provider.GetProviderKey())
|
|
}
|
|
|
|
// ContainerRetrieve is not supported by the Vertex provider.
|
|
func (provider *VertexProvider) ContainerRetrieve(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostContainerRetrieveRequest) (*schemas.BifrostContainerRetrieveResponse, *schemas.BifrostError) {
|
|
return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerRetrieveRequest, provider.GetProviderKey())
|
|
}
|
|
|
|
// ContainerDelete is not supported by the Vertex provider.
|
|
func (provider *VertexProvider) ContainerDelete(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostContainerDeleteRequest) (*schemas.BifrostContainerDeleteResponse, *schemas.BifrostError) {
|
|
return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerDeleteRequest, provider.GetProviderKey())
|
|
}
|
|
|
|
// ContainerFileCreate is not supported by the Vertex provider.
|
|
func (provider *VertexProvider) ContainerFileCreate(_ *schemas.BifrostContext, _ schemas.Key, _ *schemas.BifrostContainerFileCreateRequest) (*schemas.BifrostContainerFileCreateResponse, *schemas.BifrostError) {
|
|
return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerFileCreateRequest, provider.GetProviderKey())
|
|
}
|
|
|
|
// ContainerFileList is not supported by the Vertex provider.
|
|
func (provider *VertexProvider) ContainerFileList(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostContainerFileListRequest) (*schemas.BifrostContainerFileListResponse, *schemas.BifrostError) {
|
|
return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerFileListRequest, provider.GetProviderKey())
|
|
}
|
|
|
|
// ContainerFileRetrieve is not supported by the Vertex provider.
|
|
func (provider *VertexProvider) ContainerFileRetrieve(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostContainerFileRetrieveRequest) (*schemas.BifrostContainerFileRetrieveResponse, *schemas.BifrostError) {
|
|
return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerFileRetrieveRequest, provider.GetProviderKey())
|
|
}
|
|
|
|
// ContainerFileContent is not supported by the Vertex provider.
|
|
func (provider *VertexProvider) ContainerFileContent(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostContainerFileContentRequest) (*schemas.BifrostContainerFileContentResponse, *schemas.BifrostError) {
|
|
return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerFileContentRequest, provider.GetProviderKey())
|
|
}
|
|
|
|
// ContainerFileDelete is not supported by the Vertex provider.
|
|
func (provider *VertexProvider) ContainerFileDelete(_ *schemas.BifrostContext, _ []schemas.Key, _ *schemas.BifrostContainerFileDeleteRequest) (*schemas.BifrostContainerFileDeleteResponse, *schemas.BifrostError) {
|
|
return nil, providerUtils.NewUnsupportedOperationError(schemas.ContainerFileDeleteRequest, provider.GetProviderKey())
|
|
}
|
|
|
|
func (provider *VertexProvider) Passthrough(
|
|
ctx *schemas.BifrostContext,
|
|
key schemas.Key,
|
|
req *schemas.BifrostPassthroughRequest,
|
|
) (*schemas.BifrostPassthroughResponse, *schemas.BifrostError) {
|
|
projectID := strings.TrimSpace(key.VertexKeyConfig.ProjectID.GetValue())
|
|
if projectID == "" {
|
|
return nil, providerUtils.NewConfigurationError("project ID is not set")
|
|
}
|
|
|
|
keyRegion := key.VertexKeyConfig.Region.GetValue()
|
|
if keyRegion == "" {
|
|
keyRegion = "global"
|
|
}
|
|
|
|
var baseURL string
|
|
if keyRegion == "global" {
|
|
baseURL = "https://aiplatform.googleapis.com/v1"
|
|
} else {
|
|
baseURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1", keyRegion)
|
|
}
|
|
|
|
// Normalize path: remove leading /v1 or /v1/ to avoid duplicate version segments (e.g. /v1/v1/...)
|
|
path := req.Path
|
|
for strings.HasPrefix(path, "/v1/") || path == "/v1" {
|
|
path = strings.TrimPrefix(path, "/v1/")
|
|
path = strings.TrimPrefix(path, "/v1")
|
|
if path != "" && !strings.HasPrefix(path, "/") {
|
|
path = "/" + path
|
|
}
|
|
}
|
|
|
|
// Replace region in path with key's configured region (client path may have different region)
|
|
if strings.Contains(path, "/locations/") {
|
|
path = vertexLocationsPathRe.ReplaceAllString(path, "/locations/"+keyRegion)
|
|
if strings.Contains(path, "/projects/") {
|
|
path = vertexProjectsPathRe.ReplaceAllString(path, "/projects/"+projectID)
|
|
}
|
|
} else {
|
|
// add projects/%s/locations/%s/publishers/google to path
|
|
path = fmt.Sprintf("/projects/%s/locations/%s%s", projectID, keyRegion, path)
|
|
}
|
|
|
|
requestURL := baseURL + path
|
|
if req.RawQuery != "" {
|
|
requestURL += "?" + req.RawQuery
|
|
}
|
|
|
|
// Only use API key for Google publisher endpoints; Anthropic/Mistral/OpenAPI-style paths require OAuth.
|
|
authQuery := ""
|
|
if key.Value.GetValue() != "" && strings.Contains(path, "publishers/google") {
|
|
authQuery = fmt.Sprintf("key=%s", url.QueryEscape(key.Value.GetValue()))
|
|
}
|
|
|
|
// Prepare fasthttp request
|
|
fasthttpReq := fasthttp.AcquireRequest()
|
|
resp := fasthttp.AcquireResponse()
|
|
defer fasthttp.ReleaseResponse(resp)
|
|
defer fasthttp.ReleaseRequest(fasthttpReq)
|
|
|
|
fasthttpReq.Header.SetMethod(req.Method)
|
|
|
|
// If auth query is set, add it to the URL; otherwise use OAuth2
|
|
if authQuery != "" {
|
|
if strings.Contains(requestURL, "?") {
|
|
requestURL += "&" + authQuery
|
|
} else {
|
|
requestURL += "?" + authQuery
|
|
}
|
|
} else {
|
|
tokenSource, err := getAuthTokenSource(key)
|
|
if err != nil {
|
|
removeVertexClient(key.VertexKeyConfig.AuthCredentials.GetValue())
|
|
return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err)
|
|
}
|
|
token, err := tokenSource.Token()
|
|
if err != nil {
|
|
removeVertexClient(key.VertexKeyConfig.AuthCredentials.GetValue())
|
|
return nil, providerUtils.NewBifrostOperationError("error getting token", err)
|
|
}
|
|
fasthttpReq.Header.Set("Authorization", "Bearer "+token.AccessToken)
|
|
}
|
|
|
|
fasthttpReq.SetRequestURI(requestURL)
|
|
|
|
// Set extra headers from provider network config
|
|
providerUtils.SetExtraHeaders(ctx, fasthttpReq, provider.networkConfig.ExtraHeaders, nil)
|
|
|
|
// Set safe headers from client request
|
|
for k, v := range req.SafeHeaders {
|
|
if strings.EqualFold(k, "authorization") || strings.EqualFold(k, "proxy-authorization") {
|
|
continue
|
|
}
|
|
fasthttpReq.Header.Set(k, v)
|
|
}
|
|
|
|
if len(req.Body) > 0 && strings.Contains(strings.ToLower(string(fasthttpReq.Header.ContentType())), "application/json") {
|
|
region := keyRegion
|
|
// Replace fully-qualified model paths that have placeholder project/location
|
|
// e.g. "projects/None/locations/None/publishers/..." -> "projects/real-id/locations/real-region/..."
|
|
body := req.Body
|
|
bodyStr := vertexBodyProjectsRe.ReplaceAllString(string(body), "${1}projects/"+projectID)
|
|
bodyStr = vertexLocationsPathRe.ReplaceAllString(bodyStr, "/locations/"+region)
|
|
// Expand short-form model names: "models/X" -> "projects/P/locations/L/publishers/google/models/X"
|
|
bodyStr = vertexShortModelRe.ReplaceAllString(bodyStr,
|
|
fmt.Sprintf(`"projects/%s/locations/%s/publishers/google/$1"`, projectID, keyRegion))
|
|
fasthttpReq.SetBodyString(bodyStr)
|
|
} else if len(req.Body) > 0 {
|
|
fasthttpReq.SetBody(req.Body)
|
|
}
|
|
|
|
// Execute request
|
|
latency, bifrostErr, wait := providerUtils.MakeRequestWithContext(ctx, provider.client, fasthttpReq, resp)
|
|
defer wait()
|
|
if bifrostErr != nil {
|
|
return nil, bifrostErr
|
|
}
|
|
|
|
// Remove client from pool for authentication/authorization errors
|
|
if resp.StatusCode() == fasthttp.StatusUnauthorized || resp.StatusCode() == fasthttp.StatusForbidden {
|
|
removeVertexClient(key.VertexKeyConfig.AuthCredentials.GetValue())
|
|
}
|
|
|
|
headers := providerUtils.ExtractProviderResponseHeaders(resp)
|
|
ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, headers)
|
|
|
|
body, err := providerUtils.CheckAndDecodeBody(resp)
|
|
if err != nil {
|
|
return nil, providerUtils.NewBifrostOperationError("failed to decode response body", err)
|
|
}
|
|
for k := range headers {
|
|
if strings.EqualFold(k, "Content-Encoding") || strings.EqualFold(k, "Content-Length") {
|
|
delete(headers, k)
|
|
}
|
|
}
|
|
bifrostResponse := &schemas.BifrostPassthroughResponse{
|
|
StatusCode: resp.StatusCode(),
|
|
Headers: headers,
|
|
Body: body,
|
|
}
|
|
|
|
bifrostResponse.ExtraFields.ProviderResponseHeaders = headers
|
|
bifrostResponse.ExtraFields.Latency = latency.Milliseconds()
|
|
|
|
if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) {
|
|
providerUtils.ParseAndSetRawRequestIfJSON(fasthttpReq, &bifrostResponse.ExtraFields)
|
|
}
|
|
|
|
return bifrostResponse, nil
|
|
}
|
|
|
|
func (provider *VertexProvider) PassthroughStream(
|
|
ctx *schemas.BifrostContext,
|
|
postHookRunner schemas.PostHookRunner,
|
|
postHookSpanFinalizer func(context.Context),
|
|
key schemas.Key,
|
|
req *schemas.BifrostPassthroughRequest,
|
|
) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
|
|
projectID := strings.TrimSpace(key.VertexKeyConfig.ProjectID.GetValue())
|
|
if projectID == "" {
|
|
return nil, providerUtils.NewConfigurationError("project ID is not set")
|
|
}
|
|
|
|
keyRegion := key.VertexKeyConfig.Region.GetValue()
|
|
if keyRegion == "" {
|
|
keyRegion = "global"
|
|
}
|
|
|
|
var baseURL string
|
|
if keyRegion == "global" {
|
|
baseURL = "https://aiplatform.googleapis.com/v1"
|
|
} else {
|
|
baseURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1", keyRegion)
|
|
}
|
|
|
|
// Normalize path: remove leading /v1 or /v1/ to avoid duplicate version segments.
|
|
path := req.Path
|
|
for strings.HasPrefix(path, "/v1/") || path == "/v1" {
|
|
path = strings.TrimPrefix(path, "/v1/")
|
|
path = strings.TrimPrefix(path, "/v1")
|
|
if path != "" && !strings.HasPrefix(path, "/") {
|
|
path = "/" + path
|
|
}
|
|
}
|
|
|
|
// Replace region and project in path with key's configured values.
|
|
if strings.Contains(path, "/locations/") {
|
|
path = vertexLocationsPathRe.ReplaceAllString(path, "/locations/"+keyRegion)
|
|
if strings.Contains(path, "/projects/") {
|
|
path = vertexProjectsPathRe.ReplaceAllString(path, "/projects/"+projectID)
|
|
}
|
|
} else {
|
|
path = fmt.Sprintf("/projects/%s/locations/%s%s", projectID, keyRegion, path)
|
|
}
|
|
|
|
requestURL := baseURL + path
|
|
if req.RawQuery != "" {
|
|
requestURL += "?" + req.RawQuery
|
|
}
|
|
|
|
fasthttpReq := fasthttp.AcquireRequest()
|
|
resp := fasthttp.AcquireResponse()
|
|
resp.StreamBody = true
|
|
defer fasthttp.ReleaseRequest(fasthttpReq)
|
|
|
|
fasthttpReq.Header.SetMethod(req.Method)
|
|
|
|
// Only use API key for Google publisher endpoints; Anthropic/Mistral/OpenAPI-style paths require OAuth.
|
|
authQuery := ""
|
|
if key.Value.GetValue() != "" && strings.Contains(path, "publishers/google") {
|
|
authQuery = fmt.Sprintf("key=%s", url.QueryEscape(key.Value.GetValue()))
|
|
}
|
|
|
|
if authQuery != "" {
|
|
if strings.Contains(requestURL, "?") {
|
|
requestURL += "&" + authQuery
|
|
} else {
|
|
requestURL += "?" + authQuery
|
|
}
|
|
} else {
|
|
tokenSource, err := getAuthTokenSource(key)
|
|
if err != nil {
|
|
removeVertexClient(key.VertexKeyConfig.AuthCredentials.GetValue())
|
|
providerUtils.ReleaseStreamingResponse(resp)
|
|
return nil, providerUtils.NewBifrostOperationError("error creating auth token source", err)
|
|
}
|
|
token, err := tokenSource.Token()
|
|
if err != nil {
|
|
removeVertexClient(key.VertexKeyConfig.AuthCredentials.GetValue())
|
|
providerUtils.ReleaseStreamingResponse(resp)
|
|
return nil, providerUtils.NewBifrostOperationError("error getting token", err)
|
|
}
|
|
fasthttpReq.Header.Set("Authorization", "Bearer "+token.AccessToken)
|
|
}
|
|
|
|
fasthttpReq.SetRequestURI(requestURL)
|
|
|
|
providerUtils.SetExtraHeaders(ctx, fasthttpReq, provider.networkConfig.ExtraHeaders, nil)
|
|
|
|
for k, v := range req.SafeHeaders {
|
|
if strings.EqualFold(k, "authorization") || strings.EqualFold(k, "proxy-authorization") {
|
|
continue
|
|
}
|
|
fasthttpReq.Header.Set(k, v)
|
|
}
|
|
|
|
if len(req.Body) > 0 && strings.Contains(strings.ToLower(string(fasthttpReq.Header.ContentType())), "application/json") {
|
|
bodyStr := vertexBodyProjectsRe.ReplaceAllString(string(req.Body), "${1}projects/"+projectID)
|
|
bodyStr = vertexLocationsPathRe.ReplaceAllString(bodyStr, "/locations/"+keyRegion)
|
|
bodyStr = vertexShortModelRe.ReplaceAllString(bodyStr,
|
|
fmt.Sprintf(`"projects/%s/locations/%s/publishers/google/$1"`, projectID, keyRegion))
|
|
fasthttpReq.SetBodyString(bodyStr)
|
|
} else if len(req.Body) > 0 {
|
|
fasthttpReq.SetBody(req.Body)
|
|
}
|
|
|
|
activeClient := providerUtils.PrepareResponseStreaming(ctx, provider.streamingClient, resp)
|
|
if err := activeClient.Do(fasthttpReq, resp); err != nil {
|
|
providerUtils.ReleaseStreamingResponse(resp)
|
|
if errors.Is(err, context.Canceled) {
|
|
return nil, &schemas.BifrostError{
|
|
IsBifrostError: false,
|
|
Error: &schemas.ErrorField{
|
|
Type: schemas.Ptr(schemas.RequestCancelled),
|
|
Message: schemas.ErrRequestCancelled,
|
|
Error: err,
|
|
},
|
|
}
|
|
}
|
|
if errors.Is(err, fasthttp.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) {
|
|
return nil, providerUtils.NewBifrostTimeoutError(schemas.ErrProviderRequestTimedOut, err)
|
|
}
|
|
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderDoRequest, err)
|
|
}
|
|
|
|
if resp.StatusCode() == fasthttp.StatusUnauthorized || resp.StatusCode() == fasthttp.StatusForbidden {
|
|
removeVertexClient(key.VertexKeyConfig.AuthCredentials.GetValue())
|
|
}
|
|
|
|
headers := providerUtils.ExtractProviderResponseHeaders(resp)
|
|
ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, headers)
|
|
|
|
bodyStream := resp.BodyStream()
|
|
if bodyStream == nil {
|
|
providerUtils.ReleaseStreamingResponse(resp)
|
|
return nil, providerUtils.NewBifrostOperationError(
|
|
"provider returned an empty stream body",
|
|
fmt.Errorf("provider returned an empty stream body"))
|
|
}
|
|
|
|
// Set stream idle timeout from provider config.
|
|
providerUtils.SetStreamIdleTimeoutIfEmpty(ctx, provider.networkConfig.StreamIdleTimeoutInSeconds)
|
|
|
|
// Wrap body with idle timeout to detect stalled streams.
|
|
rawBodyStream := bodyStream
|
|
bodyStream, stopIdleTimeout := providerUtils.NewIdleTimeoutReader(bodyStream, rawBodyStream, providerUtils.GetStreamIdleTimeout(ctx))
|
|
|
|
// Cancellation must close the raw stream to unblock reads.
|
|
stopCancellation := providerUtils.SetupStreamCancellation(ctx, rawBodyStream, provider.logger)
|
|
|
|
extraFields := schemas.BifrostResponseExtraFields{}
|
|
statusCode := resp.StatusCode()
|
|
|
|
if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) {
|
|
providerUtils.ParseAndSetRawRequestIfJSON(fasthttpReq, &extraFields)
|
|
}
|
|
|
|
ch := make(chan *schemas.BifrostStreamChunk, schemas.DefaultStreamBufferSize)
|
|
go func() {
|
|
defer providerUtils.EnsureStreamFinalizerCalled(ctx, postHookSpanFinalizer)
|
|
defer func() {
|
|
if ctx.Err() == context.Canceled {
|
|
providerUtils.HandleStreamCancellation(ctx, postHookRunner, ch, provider.logger, postHookSpanFinalizer)
|
|
} else if ctx.Err() == context.DeadlineExceeded {
|
|
providerUtils.HandleStreamTimeout(ctx, postHookRunner, ch, provider.logger, postHookSpanFinalizer)
|
|
}
|
|
close(ch)
|
|
}()
|
|
defer providerUtils.ReleaseStreamingResponse(resp)
|
|
defer stopIdleTimeout()
|
|
defer stopCancellation()
|
|
streamStart := time.Now()
|
|
|
|
fullResponseBody := make([]byte, 0, maxStreamPassthroughCaptureBytes)
|
|
fullResponseBodyTruncated := false
|
|
terminalDetector := &providerUtils.StreamTerminalDetector{}
|
|
buf := make([]byte, 4096)
|
|
for {
|
|
n, readErr := bodyStream.Read(buf)
|
|
if n > 0 {
|
|
chunk := make([]byte, n)
|
|
copy(chunk, buf[:n])
|
|
if !fullResponseBodyTruncated {
|
|
remaining := maxStreamPassthroughCaptureBytes - len(fullResponseBody)
|
|
if remaining > 0 {
|
|
if n <= remaining {
|
|
fullResponseBody = append(fullResponseBody, chunk...)
|
|
} else {
|
|
fullResponseBody = append(fullResponseBody, chunk[:remaining]...)
|
|
fullResponseBodyTruncated = true
|
|
}
|
|
} else {
|
|
fullResponseBodyTruncated = true
|
|
}
|
|
}
|
|
select {
|
|
case ch <- &schemas.BifrostStreamChunk{
|
|
BifrostPassthroughResponse: &schemas.BifrostPassthroughResponse{
|
|
StatusCode: statusCode,
|
|
Headers: headers,
|
|
Body: chunk,
|
|
ExtraFields: extraFields,
|
|
},
|
|
}:
|
|
case <-ctx.Done():
|
|
return
|
|
}
|
|
|
|
// Vertex streamGenerateContent passthrough can emit terminal markers
|
|
// (finishReason) before the underlying HTTP body is closed.
|
|
// Finalize as success once this appears to avoid hanging clients.
|
|
if terminalDetector.ObserveChunk(chunk) {
|
|
ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
|
|
extraFields.Latency = time.Since(streamStart).Milliseconds()
|
|
var capturedBody []byte
|
|
if !fullResponseBodyTruncated {
|
|
capturedBody = append([]byte(nil), fullResponseBody...)
|
|
}
|
|
finalResp := &schemas.BifrostResponse{
|
|
PassthroughResponse: &schemas.BifrostPassthroughResponse{
|
|
StatusCode: statusCode,
|
|
Headers: headers,
|
|
Body: capturedBody,
|
|
ExtraFields: extraFields,
|
|
},
|
|
}
|
|
postHookRunner(ctx, finalResp, nil)
|
|
return
|
|
}
|
|
}
|
|
if readErr == io.EOF {
|
|
ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
|
|
extraFields.Latency = time.Since(streamStart).Milliseconds()
|
|
var capturedBody []byte
|
|
if !fullResponseBodyTruncated {
|
|
capturedBody = append([]byte(nil), fullResponseBody...)
|
|
}
|
|
finalResp := &schemas.BifrostResponse{
|
|
PassthroughResponse: &schemas.BifrostPassthroughResponse{
|
|
StatusCode: statusCode,
|
|
Headers: headers,
|
|
Body: capturedBody,
|
|
ExtraFields: extraFields,
|
|
},
|
|
}
|
|
postHookRunner(ctx, finalResp, nil)
|
|
return
|
|
}
|
|
if readErr != nil {
|
|
if ctx.Err() != nil {
|
|
return // let defer handle cancel/timeout
|
|
}
|
|
ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
|
|
extraFields.Latency = time.Since(streamStart).Milliseconds()
|
|
providerUtils.ProcessAndSendError(ctx, postHookRunner, readErr, ch, provider.logger, postHookSpanFinalizer)
|
|
return
|
|
}
|
|
}
|
|
}()
|
|
return ch, nil
|
|
}
|