first commit
This commit is contained in:
256
core/providers/vertex/utils.go
Normal file
256
core/providers/vertex/utils.go
Normal file
@@ -0,0 +1,256 @@
|
||||
package vertex
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/maximhq/bifrost/core/providers/anthropic"
|
||||
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
func getRequestBodyForAnthropicResponses(ctx *schemas.BifrostContext, request *schemas.BifrostResponsesRequest, deployment string, isStreaming bool, isCountTokens bool, betaHeaderOverrides map[string]bool, providerExtraHeaders map[string]string) ([]byte, *schemas.BifrostError) {
|
||||
// Large payload mode: body streams directly from the LP reader — skip all body building
|
||||
// (matches CheckContextAndGetRequestBody guard).
|
||||
if providerUtils.IsLargePayloadPassthroughEnabled(ctx) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var jsonBody []byte
|
||||
var err error
|
||||
|
||||
// Check if raw request body should be used
|
||||
if useRawBody, ok := ctx.Value(schemas.BifrostContextKeyUseRawRequestBody).(bool); ok && useRawBody {
|
||||
jsonBody = request.GetRawRequestBody()
|
||||
|
||||
if isCountTokens {
|
||||
jsonBody, err = providerUtils.DeleteJSONField(jsonBody, "max_tokens")
|
||||
if err != nil {
|
||||
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err)
|
||||
}
|
||||
jsonBody, err = providerUtils.DeleteJSONField(jsonBody, "temperature")
|
||||
if err != nil {
|
||||
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err)
|
||||
}
|
||||
jsonBody, err = providerUtils.SetJSONField(jsonBody, "model", deployment)
|
||||
if err != nil {
|
||||
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err)
|
||||
}
|
||||
} else {
|
||||
// Add max_tokens if not present
|
||||
if !providerUtils.JSONFieldExists(jsonBody, "max_tokens") {
|
||||
jsonBody, err = providerUtils.SetJSONField(jsonBody, "max_tokens", providerUtils.GetMaxOutputTokensOrDefault(deployment, anthropic.AnthropicDefaultMaxTokens))
|
||||
if err != nil {
|
||||
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err)
|
||||
}
|
||||
}
|
||||
jsonBody, err = providerUtils.DeleteJSONField(jsonBody, "model")
|
||||
if err != nil {
|
||||
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err)
|
||||
}
|
||||
// Add stream if streaming
|
||||
if isStreaming {
|
||||
jsonBody, err = providerUtils.SetJSONField(jsonBody, "stream", true)
|
||||
if err != nil {
|
||||
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
jsonBody, err = providerUtils.DeleteJSONField(jsonBody, "region")
|
||||
if err != nil {
|
||||
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err)
|
||||
}
|
||||
jsonBody, err = providerUtils.DeleteJSONField(jsonBody, "fallbacks")
|
||||
if err != nil {
|
||||
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err)
|
||||
}
|
||||
|
||||
// Remap unsupported tool versions for Vertex (e.g., web_search_20260209 → web_search_20250305)
|
||||
jsonBody, err = anthropic.RemapRawToolVersionsForProvider(jsonBody, schemas.Vertex)
|
||||
if err != nil {
|
||||
return nil, providerUtils.NewBifrostOperationError(err.Error(), nil)
|
||||
}
|
||||
|
||||
// Add anthropic_version if not present
|
||||
if !providerUtils.JSONFieldExists(jsonBody, "anthropic_version") {
|
||||
jsonBody, err = providerUtils.SetJSONField(jsonBody, "anthropic_version", DefaultVertexAnthropicVersion)
|
||||
if err != nil {
|
||||
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Validate tools are supported by Vertex
|
||||
if request.Params != nil && request.Params.Tools != nil {
|
||||
if toolErr := anthropic.ValidateToolsForProvider(request.Params.Tools, schemas.Vertex); toolErr != nil {
|
||||
return nil, providerUtils.NewBifrostOperationError(toolErr.Error(), nil)
|
||||
}
|
||||
}
|
||||
|
||||
// Convert request to Anthropic format
|
||||
reqBody, convErr := anthropic.ToAnthropicResponsesRequest(ctx, request)
|
||||
if convErr != nil {
|
||||
return nil, providerUtils.NewBifrostOperationError(schemas.ErrRequestBodyConversion, convErr)
|
||||
}
|
||||
if reqBody == nil {
|
||||
return nil, providerUtils.NewBifrostOperationError("request body is not provided", nil)
|
||||
}
|
||||
reqBody.Model = deployment
|
||||
|
||||
if isStreaming {
|
||||
reqBody.Stream = schemas.Ptr(true)
|
||||
}
|
||||
|
||||
reqBody.SetStripCacheControlScope(true)
|
||||
|
||||
// Add provider-aware beta headers
|
||||
anthropic.AddMissingBetaHeadersToContext(ctx, reqBody, schemas.Vertex)
|
||||
|
||||
// Marshal struct to JSON bytes
|
||||
jsonBody, err = providerUtils.MarshalSorted(reqBody)
|
||||
if err != nil {
|
||||
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err)
|
||||
}
|
||||
|
||||
// Add anthropic_version if not present (using sjson to preserve order)
|
||||
if !providerUtils.JSONFieldExists(jsonBody, "anthropic_version") {
|
||||
jsonBody, err = providerUtils.SetJSONField(jsonBody, "anthropic_version", DefaultVertexAnthropicVersion)
|
||||
if err != nil {
|
||||
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err)
|
||||
}
|
||||
}
|
||||
|
||||
if isCountTokens {
|
||||
jsonBody, err = providerUtils.DeleteJSONField(jsonBody, "max_tokens")
|
||||
if err != nil {
|
||||
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err)
|
||||
}
|
||||
jsonBody, err = providerUtils.DeleteJSONField(jsonBody, "temperature")
|
||||
if err != nil {
|
||||
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err)
|
||||
}
|
||||
} else {
|
||||
// Remove model field for Vertex API (it's in URL)
|
||||
jsonBody, err = providerUtils.DeleteJSONField(jsonBody, "model")
|
||||
if err != nil {
|
||||
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err)
|
||||
}
|
||||
}
|
||||
|
||||
jsonBody, err = providerUtils.DeleteJSONField(jsonBody, "region")
|
||||
if err != nil {
|
||||
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err)
|
||||
}
|
||||
}
|
||||
|
||||
if betaHeaders := anthropic.FilterBetaHeadersForProvider(anthropic.MergeBetaHeaders(providerExtraHeaders, ctx), schemas.Vertex, betaHeaderOverrides); len(betaHeaders) > 0 {
|
||||
jsonBody, err = providerUtils.SetJSONField(jsonBody, "anthropic_beta", betaHeaders)
|
||||
if err != nil {
|
||||
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err)
|
||||
}
|
||||
}
|
||||
|
||||
return jsonBody, nil
|
||||
}
|
||||
|
||||
// getCompleteURLForGeminiEndpoint constructs the complete URL for the Gemini endpoint, for both streaming and non-streaming requests
|
||||
// for custom/fine-tuned models, it uses the projectNumber
|
||||
// for gemini models, it uses the projectID
|
||||
func getCompleteURLForGeminiEndpoint(deployment string, region string, projectID string, projectNumber string, method string) string {
|
||||
var url string
|
||||
if schemas.IsAllDigitsASCII(deployment) {
|
||||
// Custom/fine-tuned models use projectNumber
|
||||
if region == "global" {
|
||||
url = fmt.Sprintf("https://aiplatform.googleapis.com/v1beta1/projects/%s/locations/global/endpoints/%s%s", projectNumber, deployment, method)
|
||||
} else {
|
||||
url = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1beta1/projects/%s/locations/%s/endpoints/%s%s", region, projectNumber, region, deployment, method)
|
||||
}
|
||||
} else {
|
||||
// Gemini models use projectID
|
||||
if region == "global" {
|
||||
url = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s%s", projectID, deployment, method)
|
||||
} else {
|
||||
url = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s%s", region, projectID, region, deployment, method)
|
||||
}
|
||||
}
|
||||
return url
|
||||
}
|
||||
|
||||
// buildResponseFromConfig builds a list models response from configured deployments and allowedModels.
|
||||
// This is used when the user has explicitly configured which models they want to use.
|
||||
func buildResponseFromConfig(deployments map[string]string, allowedModels schemas.WhiteList, blacklistedModels schemas.BlackList) *schemas.BifrostListModelsResponse {
|
||||
response := &schemas.BifrostListModelsResponse{
|
||||
Data: make([]schemas.Model, 0),
|
||||
}
|
||||
|
||||
if blacklistedModels.IsBlockAll() {
|
||||
return response
|
||||
}
|
||||
|
||||
addedModelIDs := make(map[string]bool)
|
||||
|
||||
restrictAllowed := allowedModels.IsRestricted()
|
||||
|
||||
// First add models from deployments (filtered by allowedModels when set)
|
||||
for alias, deploymentValue := range deployments {
|
||||
if restrictAllowed && !allowedModels.Contains(alias) {
|
||||
continue
|
||||
}
|
||||
if blacklistedModels.IsBlocked(alias) {
|
||||
continue
|
||||
}
|
||||
modelID := string(schemas.Vertex) + "/" + alias
|
||||
if addedModelIDs[modelID] {
|
||||
continue
|
||||
}
|
||||
|
||||
modelName := providerUtils.ToDisplayName(alias)
|
||||
modelEntry := schemas.Model{
|
||||
ID: modelID,
|
||||
Name: schemas.Ptr(modelName),
|
||||
Alias: schemas.Ptr(deploymentValue),
|
||||
}
|
||||
|
||||
response.Data = append(response.Data, modelEntry)
|
||||
addedModelIDs[modelID] = true
|
||||
}
|
||||
|
||||
// Then add models from allowedModels that aren't already in deployments (only when restricted)
|
||||
if !restrictAllowed {
|
||||
return response
|
||||
}
|
||||
for _, allowedModel := range allowedModels {
|
||||
modelID := string(schemas.Vertex) + "/" + allowedModel
|
||||
if addedModelIDs[modelID] {
|
||||
continue
|
||||
}
|
||||
if blacklistedModels.IsBlocked(allowedModel) {
|
||||
continue
|
||||
}
|
||||
|
||||
modelName := providerUtils.ToDisplayName(allowedModel)
|
||||
modelEntry := schemas.Model{
|
||||
ID: modelID,
|
||||
Name: schemas.Ptr(modelName),
|
||||
}
|
||||
|
||||
response.Data = append(response.Data, modelEntry)
|
||||
addedModelIDs[modelID] = true
|
||||
}
|
||||
|
||||
return response
|
||||
}
|
||||
|
||||
// extractModelIDFromName extracts the model ID from a full resource name.
|
||||
// Format: "publishers/google/models/gemini-1.5-pro" -> "gemini-1.5-pro"
|
||||
func extractModelIDFromName(name string) string {
|
||||
parts := strings.Split(name, "/")
|
||||
if len(parts) >= 4 && parts[2] == "models" {
|
||||
return parts[3]
|
||||
}
|
||||
// Fallback: return last segment
|
||||
if len(parts) > 0 {
|
||||
return parts[len(parts)-1]
|
||||
}
|
||||
return ""
|
||||
}
|
||||
Reference in New Issue
Block a user