Files
bifrost/core/providers/vertex/utils.go
Beyhan Oğur 880f412e2c first commit
2026-04-26 21:52:23 +03:00

257 lines
9.4 KiB
Go

package vertex
import (
"fmt"
"strings"
"github.com/maximhq/bifrost/core/providers/anthropic"
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
"github.com/maximhq/bifrost/core/schemas"
)
// getRequestBodyForAnthropicResponses produces the JSON body sent to Vertex for
// an Anthropic-format Responses request.
//
// It returns (nil, nil) when large payload passthrough is enabled on ctx; no
// body is built in that mode. Otherwise the body comes from one of two sources:
//
//   - the caller's raw request body, when BifrostContextKeyUseRawRequestBody is
//     set to true on ctx, patched in place field by field; or
//   - request converted with anthropic.ToAnthropicResponsesRequest and
//     marshalled with providerUtils.MarshalSorted.
//
// In both modes "region" is removed and "anthropic_version" is filled in with
// DefaultVertexAnthropicVersion when absent. For count-tokens requests
// "max_tokens" and "temperature" are removed and "model" carries deployment;
// for every other request "model" is removed from the body. Finally, when the
// merged and filtered beta headers are non-empty they are written to
// "anthropic_beta".
func getRequestBodyForAnthropicResponses(ctx *schemas.BifrostContext, request *schemas.BifrostResponsesRequest, deployment string, isStreaming bool, isCountTokens bool, betaHeaderOverrides map[string]bool, providerExtraHeaders map[string]string) ([]byte, *schemas.BifrostError) {
	// Large payload mode: the body streams directly from the LP reader, so
	// there is nothing to build (matches CheckContextAndGetRequestBody guard).
	if providerUtils.IsLargePayloadPassthroughEnabled(ctx) {
		return nil, nil
	}

	// marshalFailure is the shared exit for every JSON set/delete/marshal error.
	marshalFailure := func(cause error) ([]byte, *schemas.BifrostError) {
		return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, cause)
	}

	var (
		payload []byte
		opErr   error
	)

	rawRequested, isBool := ctx.Value(schemas.BifrostContextKeyUseRawRequestBody).(bool)
	if isBool && rawRequested {
		// Raw mode: start from the caller-supplied bytes and patch fields.
		payload = request.GetRawRequestBody()

		switch {
		case isCountTokens:
			if payload, opErr = providerUtils.DeleteJSONField(payload, "max_tokens"); opErr != nil {
				return marshalFailure(opErr)
			}
			if payload, opErr = providerUtils.DeleteJSONField(payload, "temperature"); opErr != nil {
				return marshalFailure(opErr)
			}
			if payload, opErr = providerUtils.SetJSONField(payload, "model", deployment); opErr != nil {
				return marshalFailure(opErr)
			}
		default:
			// Supply a max_tokens value only when the caller did not send one.
			if !providerUtils.JSONFieldExists(payload, "max_tokens") {
				fallbackMax := providerUtils.GetMaxOutputTokensOrDefault(deployment, anthropic.AnthropicDefaultMaxTokens)
				if payload, opErr = providerUtils.SetJSONField(payload, "max_tokens", fallbackMax); opErr != nil {
					return marshalFailure(opErr)
				}
			}
			if payload, opErr = providerUtils.DeleteJSONField(payload, "model"); opErr != nil {
				return marshalFailure(opErr)
			}
			if isStreaming {
				if payload, opErr = providerUtils.SetJSONField(payload, "stream", true); opErr != nil {
					return marshalFailure(opErr)
				}
			}
		}

		if payload, opErr = providerUtils.DeleteJSONField(payload, "region"); opErr != nil {
			return marshalFailure(opErr)
		}
		if payload, opErr = providerUtils.DeleteJSONField(payload, "fallbacks"); opErr != nil {
			return marshalFailure(opErr)
		}

		// Remap unsupported tool versions for Vertex
		// (e.g., web_search_20260209 → web_search_20250305).
		if payload, opErr = anthropic.RemapRawToolVersionsForProvider(payload, schemas.Vertex); opErr != nil {
			return nil, providerUtils.NewBifrostOperationError(opErr.Error(), nil)
		}

		if !providerUtils.JSONFieldExists(payload, "anthropic_version") {
			if payload, opErr = providerUtils.SetJSONField(payload, "anthropic_version", DefaultVertexAnthropicVersion); opErr != nil {
				return marshalFailure(opErr)
			}
		}
	} else {
		// Converted mode: reject tools Vertex does not support before converting.
		if params := request.Params; params != nil && params.Tools != nil {
			if toolErr := anthropic.ValidateToolsForProvider(params.Tools, schemas.Vertex); toolErr != nil {
				return nil, providerUtils.NewBifrostOperationError(toolErr.Error(), nil)
			}
		}

		anthropicReq, convErr := anthropic.ToAnthropicResponsesRequest(ctx, request)
		if convErr != nil {
			return nil, providerUtils.NewBifrostOperationError(schemas.ErrRequestBodyConversion, convErr)
		}
		if anthropicReq == nil {
			return nil, providerUtils.NewBifrostOperationError("request body is not provided", nil)
		}

		anthropicReq.Model = deployment
		if isStreaming {
			anthropicReq.Stream = schemas.Ptr(true)
		}
		anthropicReq.SetStripCacheControlScope(true)

		// Add provider-aware beta headers to the context before marshalling.
		anthropic.AddMissingBetaHeadersToContext(ctx, anthropicReq, schemas.Vertex)

		if payload, opErr = providerUtils.MarshalSorted(anthropicReq); opErr != nil {
			return marshalFailure(opErr)
		}

		if !providerUtils.JSONFieldExists(payload, "anthropic_version") {
			if payload, opErr = providerUtils.SetJSONField(payload, "anthropic_version", DefaultVertexAnthropicVersion); opErr != nil {
				return marshalFailure(opErr)
			}
		}

		if !isCountTokens {
			// Remove model field for Vertex API (it's in URL).
			if payload, opErr = providerUtils.DeleteJSONField(payload, "model"); opErr != nil {
				return marshalFailure(opErr)
			}
		} else {
			if payload, opErr = providerUtils.DeleteJSONField(payload, "max_tokens"); opErr != nil {
				return marshalFailure(opErr)
			}
			if payload, opErr = providerUtils.DeleteJSONField(payload, "temperature"); opErr != nil {
				return marshalFailure(opErr)
			}
		}

		if payload, opErr = providerUtils.DeleteJSONField(payload, "region"); opErr != nil {
			return marshalFailure(opErr)
		}
	}

	// Beta features travel in the body as "anthropic_beta" when any remain
	// after merging and provider filtering.
	mergedBeta := anthropic.MergeBetaHeaders(providerExtraHeaders, ctx)
	betaHeaders := anthropic.FilterBetaHeadersForProvider(mergedBeta, schemas.Vertex, betaHeaderOverrides)
	if len(betaHeaders) > 0 {
		if payload, opErr = providerUtils.SetJSONField(payload, "anthropic_beta", betaHeaders); opErr != nil {
			return marshalFailure(opErr)
		}
	}

	return payload, nil
}
// getCompleteURLForGeminiEndpoint returns the complete Vertex AI URL for a
// Gemini request, streaming or not; method is appended verbatim after the
// endpoint or model segment.
//
// When schemas.IsAllDigitsASCII(deployment) is true the deployment is treated
// as a custom/fine-tuned endpoint: the v1beta1 "endpoints" path is used and the
// project is identified by projectNumber. Otherwise the v1
// "publishers/google/models" path is used with projectID. The "global" region
// uses the bare aiplatform.googleapis.com host; any other region is used both
// as the host prefix and as the location segment.
func getCompleteURLForGeminiEndpoint(deployment string, region string, projectID string, projectNumber string, method string) string {
	customEndpoint := schemas.IsAllDigitsASCII(deployment)
	globalRegion := region == "global"

	switch {
	case customEndpoint && globalRegion:
		return fmt.Sprintf("https://aiplatform.googleapis.com/v1beta1/projects/%s/locations/global/endpoints/%s%s", projectNumber, deployment, method)
	case customEndpoint:
		return fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1beta1/projects/%s/locations/%s/endpoints/%s%s", region, projectNumber, region, deployment, method)
	case globalRegion:
		return fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s%s", projectID, deployment, method)
	default:
		return fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s%s", region, projectID, region, deployment, method)
	}
}
// buildResponseFromConfig assembles a list-models response purely from the
// configured deployments and allowedModels, for the case where the user has
// explicitly configured which models they want to use.
//
// The result always carries a non-nil Data slice. It is empty when
// blacklistedModels.IsBlockAll() is true. Otherwise:
//
//  1. Every deployment alias is listed with its deployment value as Alias,
//     skipping aliases that are not in allowedModels (only when
//     allowedModels.IsRestricted() is true) and aliases that are blocked.
//  2. When allowedModels is restricted, each allowed model that was not already
//     listed and is not blocked is appended without an Alias.
//
// Model IDs are "<schemas.Vertex>/<name>". deployments is a map, so the order
// of the deployment-derived entries is not deterministic.
func buildResponseFromConfig(deployments map[string]string, allowedModels schemas.WhiteList, blacklistedModels schemas.BlackList) *schemas.BifrostListModelsResponse {
	result := &schemas.BifrostListModelsResponse{Data: []schemas.Model{}}
	if blacklistedModels.IsBlockAll() {
		return result
	}

	idPrefix := string(schemas.Vertex) + "/"
	seen := make(map[string]struct{})
	onlyAllowed := allowedModels.IsRestricted()

	// Pass 1: configured deployments, filtered by the allow list when set.
	for alias, target := range deployments {
		if onlyAllowed && !allowedModels.Contains(alias) {
			continue
		}
		if blacklistedModels.IsBlocked(alias) {
			continue
		}
		id := idPrefix + alias
		if _, dup := seen[id]; dup {
			continue
		}
		seen[id] = struct{}{}
		result.Data = append(result.Data, schemas.Model{
			ID:    id,
			Name:  schemas.Ptr(providerUtils.ToDisplayName(alias)),
			Alias: schemas.Ptr(target),
		})
	}

	// Pass 2 only applies when an explicit allow list is in force.
	if !onlyAllowed {
		return result
	}

	// Pass 2: allowed models that have no deployment entry of their own.
	for _, name := range allowedModels {
		id := idPrefix + name
		if _, dup := seen[id]; dup {
			continue
		}
		if blacklistedModels.IsBlocked(name) {
			continue
		}
		seen[id] = struct{}{}
		result.Data = append(result.Data, schemas.Model{
			ID:   id,
			Name: schemas.Ptr(providerUtils.ToDisplayName(name)),
		})
	}

	return result
}
// extractModelIDFromName extracts the model ID from a full resource name.
// Format: "publishers/google/models/gemini-1.5-pro" -> "gemini-1.5-pro"
//
// When the name has at least four "/"-separated segments and the third one is
// "models", the fourth segment is returned. Any other input falls back to its
// last segment, so a name without "/" is returned unchanged and a name ending
// in "/" (or the empty string) yields "".
func extractModelIDFromName(name string) string {
	parts := strings.Split(name, "/")
	if len(parts) >= 4 && parts[2] == "models" {
		return parts[3]
	}
	// strings.Split with a non-empty separator always returns at least one
	// element, so the last segment always exists and needs no length guard.
	return parts[len(parts)-1]
}