first commit

This commit is contained in:
Beyhan Oğur
2026-04-26 21:52:23 +03:00
commit 880f412e2c
2662 changed files with 866266 additions and 0 deletions

View File

@@ -0,0 +1,27 @@
package vertex
import (
"github.com/maximhq/bifrost/core/schemas"
)
// ToBifrostCountTokensResponse converts a Vertex countTokens response into a
// Bifrost count-tokens response for the given model.
//
// The Vertex TotalTokens value is reported both as InputTokens and as
// TotalTokens. CachedReadTokens is only populated when the response carries a
// positive CachedContentTokenCount. A nil receiver yields nil.
func (resp *VertexCountTokensResponse) ToBifrostCountTokensResponse(model string) *schemas.BifrostCountTokensResponse {
	if resp == nil {
		return nil
	}

	// countTokens returns a single figure; it serves as both the prompt count
	// and the overall total.
	tokenCount := int(resp.TotalTokens)

	details := new(schemas.ResponsesResponseInputTokens)
	if cached := resp.CachedContentTokenCount; cached > 0 {
		details.CachedReadTokens = int(cached)
	}

	out := &schemas.BifrostCountTokensResponse{
		Model:              model,
		Object:             "response.input_tokens",
		InputTokens:        tokenCount,
		InputTokensDetails: details,
	}
	out.TotalTokens = &tokenCount
	return out
}

View File

@@ -0,0 +1,115 @@
package vertex
import (
"github.com/maximhq/bifrost/core/schemas"
)
// ToVertexEmbeddingRequest converts a Bifrost embedding request to Vertex AI format.
//
// It returns nil when the request or its input is nil, or when neither
// Input.Text nor Input.Texts is set. Input.Text takes precedence over
// Input.Texts. One instance is created per input text.
//
// The following keys are lifted out of Params.ExtraParams into typed fields and
// removed from the passthrough ExtraParams of the returned request:
//   - "task_type" and "title": applied to every instance
//   - "autoTruncate": becomes Parameters.AutoTruncate (defaults to true)
//
// Params.Dimensions, when set, becomes Parameters.OutputDimensionality and any
// "dimensions" passthrough key is dropped.
//
// The caller's Params.ExtraParams map is never modified; the returned request
// owns a shallow copy.
func ToVertexEmbeddingRequest(bifrostReq *schemas.BifrostEmbeddingRequest) *VertexEmbeddingRequest {
	if bifrostReq == nil || bifrostReq.Input == nil || (bifrostReq.Input.Text == nil && bifrostReq.Input.Texts == nil) {
		return nil
	}

	vertexReq := &VertexEmbeddingRequest{}
	params := bifrostReq.Params

	// Optional per-instance fields, resolved once and shared by all instances.
	var taskType, title *string
	if params != nil {
		// Copy the passthrough map before deleting from it. Aliasing the
		// caller's map would both mutate the caller's request and, because the
		// keys disappear after the first lookup, leave every instance but the
		// first without task_type/title.
		if params.ExtraParams != nil {
			vertexReq.ExtraParams = make(map[string]interface{}, len(params.ExtraParams))
			for key, value := range params.ExtraParams {
				vertexReq.ExtraParams[key] = value
			}
		}
		if taskTypeStr, ok := schemas.SafeExtractStringPointer(params.ExtraParams["task_type"]); ok {
			delete(vertexReq.ExtraParams, "task_type")
			taskType = taskTypeStr
		}
		if titleStr, ok := schemas.SafeExtractStringPointer(params.ExtraParams["title"]); ok {
			delete(vertexReq.ExtraParams, "title")
			title = titleStr
		}
	}

	var texts []string
	if bifrostReq.Input.Text != nil {
		texts = []string{*bifrostReq.Input.Text}
	} else {
		texts = bifrostReq.Input.Texts
	}

	// Create one instance per text.
	instances := make([]VertexEmbeddingInstance, 0, len(texts))
	for _, text := range texts {
		instances = append(instances, VertexEmbeddingInstance{
			Content:  text,
			TaskType: taskType,
			Title:    title,
		})
	}
	vertexReq.Instances = instances

	// Parameters are only sent when the caller supplied Params.
	if params != nil {
		parameters := &VertexEmbeddingParameters{}

		// autoTruncate defaults to true unless explicitly overridden.
		autoTruncate := true
		if params.ExtraParams != nil {
			if autoTruncateVal, ok := schemas.SafeExtractBool(params.ExtraParams["autoTruncate"]); ok {
				delete(vertexReq.ExtraParams, "autoTruncate")
				autoTruncate = autoTruncateVal
			}
		}
		parameters.AutoTruncate = &autoTruncate

		if params.Dimensions != nil {
			delete(vertexReq.ExtraParams, "dimensions")
			parameters.OutputDimensionality = params.Dimensions
		}

		vertexReq.Parameters = parameters
	}

	return vertexReq
}
// ToBifrostEmbeddingResponse converts a Vertex AI embedding response to Bifrost format.
//
// Predictions without embedding values are skipped; the remaining entries keep
// the index of their position in the Vertex predictions list. Usage is only
// set when at least one kept prediction carries statistics, in which case its
// token count is added to both PromptTokens and TotalTokens. A nil receiver or
// an empty predictions list yields nil.
func (response *VertexEmbeddingResponse) ToBifrostEmbeddingResponse() *schemas.BifrostEmbeddingResponse {
	if response == nil || len(response.Predictions) == 0 {
		return nil
	}

	var (
		data  = make([]schemas.EmbeddingData, 0, len(response.Predictions))
		usage *schemas.BifrostLLMUsage
	)

	for idx := range response.Predictions {
		emb := response.Predictions[idx].Embeddings
		if emb == nil || len(emb.Values) == 0 {
			continue
		}

		// Accumulate token usage lazily so Usage stays nil when Vertex sent
		// no statistics at all.
		if stats := emb.Statistics; stats != nil {
			if usage == nil {
				usage = &schemas.BifrostLLMUsage{}
			}
			usage.TotalTokens += stats.TokenCount
			usage.PromptTokens += stats.TokenCount
		}

		// Copy the vector so the result does not share memory with the
		// provider response.
		vector := make([]float64, len(emb.Values))
		copy(vector, emb.Values)

		data = append(data, schemas.EmbeddingData{
			Object:    "embedding",
			Embedding: schemas.EmbeddingStruct{EmbeddingArray: vector},
			Index:     idx,
		})
	}

	return &schemas.BifrostEmbeddingResponse{
		Object:      "list",
		Data:        data,
		Usage:       usage,
		ExtraFields: schemas.BifrostResponseExtraFields{},
	}
}

View File

@@ -0,0 +1,88 @@
package vertex
import (
"errors"
"strings"
"github.com/bytedance/sonic"
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
"github.com/maximhq/bifrost/core/schemas"
"github.com/valyala/fasthttp"
)
// parseVertexError converts an error HTTP response from Vertex AI into a BifrostError.
//
// The body is decoded and then classified in this order:
//  1. decode failure        -> operation error (ErrProviderResponseDecode)
//  2. empty body            -> ErrProviderResponseEmpty with the HTTP status
//  3. HTML body             -> ErrProviderResponseHTML with the raw body attached
//  4. OpenAI-style error    -> its error message
//  5. array of VertexError  -> message of the first element ("Unknown error" if empty)
//  6. single VertexError    -> its message, when non-empty
//  7. VertexValidationError -> message of the first detail entry
//
// If the body matches none of the JSON shapes an unmarshal operation error is
// returned; if it parses but carries no usable message, "Unknown error" is used.
// Errors built from a JSON shape carry the raw response (parsed JSON when
// possible, otherwise the body as a string) in ExtraFields.RawResponse.
func parseVertexError(resp *fasthttp.Response) *schemas.BifrostError {
	decodedBody, err := providerUtils.CheckAndDecodeBody(resp)
	if err != nil {
		return providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err)
	}

	// Check for empty response.
	if len(strings.TrimSpace(string(decodedBody))) == 0 {
		return &schemas.BifrostError{
			IsBifrostError: false,
			StatusCode:     schemas.Ptr(resp.StatusCode()),
			Error: &schemas.ErrorField{
				Message: schemas.ErrProviderResponseEmpty,
			},
		}
	}

	// Check for HTML error response before attempting JSON parsing.
	if providerUtils.IsHTMLResponse(resp, decodedBody) {
		return &schemas.BifrostError{
			IsBifrostError: false,
			StatusCode:     schemas.Ptr(resp.StatusCode()),
			Error: &schemas.ErrorField{
				Message: schemas.ErrProviderResponseHTML,
				Error:   errors.New(string(decodedBody)),
			},
			ExtraFields: schemas.BifrostErrorExtraFields{
				RawResponse: string(decodedBody),
			},
		}
	}

	// createError builds a provider API error for message and attaches the raw
	// response, preferring the parsed JSON form over the plain string.
	createError := func(message string) *schemas.BifrostError {
		bifrostErr := providerUtils.NewProviderAPIError(message, nil, resp.StatusCode(), nil, nil)
		var rawResponse interface{}
		if err := sonic.Unmarshal(decodedBody, &rawResponse); err != nil {
			rawResponse = string(decodedBody)
		}
		bifrostErr.ExtraFields.RawResponse = rawResponse
		return bifrostErr
	}

	// OpenAI error format: only accepted when it actually carries an error field.
	var openAIErr schemas.BifrostError
	if err := sonic.Unmarshal(decodedBody, &openAIErr); err == nil && openAIErr.Error != nil {
		return createError(openAIErr.Error.Message)
	}

	// Vertex error format delivered as an array.
	var vertexErrs []VertexError
	if err := sonic.Unmarshal(decodedBody, &vertexErrs); err == nil {
		if len(vertexErrs) > 0 {
			return createError(vertexErrs[0].Error.Message)
		}
		return createError("Unknown error")
	}

	// Single Vertex error object. Decoding into a struct succeeds for any JSON
	// object, so a successful decode alone does not prove the body had this
	// shape; require an actual message before accepting it.
	var single VertexError
	singleErr := sonic.Unmarshal(decodedBody, &single)
	if singleErr == nil && single.Error.Message != "" {
		return createError(single.Error.Message)
	}

	// VertexValidationError format (validation errors from Mistral endpoint).
	var validationErr VertexValidationError
	validationUnmarshalErr := sonic.Unmarshal(decodedBody, &validationErr)
	if validationUnmarshalErr == nil && len(validationErr.Detail) > 0 {
		return createError(validationErr.Detail[0].Msg)
	}

	// Nothing could be decoded at all: report the unmarshal failure.
	if singleErr != nil && validationUnmarshalErr != nil {
		return providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, validationUnmarshalErr)
	}

	// The body is valid JSON but contains no message in any known location.
	return createError("Unknown error")
}

View File

@@ -0,0 +1,197 @@
package vertex
import (
"strings"
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
"github.com/maximhq/bifrost/core/schemas"
)
// VertexRankRequest represents the Discovery Engine rank API request.
//
// Query and Records are always serialized. Model, TopN,
// IgnoreRecordDetailsInResponse and UserLabels are tagged omitempty and are
// left out of the JSON body when nil or empty.
type VertexRankRequest struct {
Model *string `json:"model,omitempty"`
Query string `json:"query"`
Records []VertexRankRecord `json:"records"`
TopN *int `json:"topN,omitempty"`
IgnoreRecordDetailsInResponse *bool `json:"ignoreRecordDetailsInResponse,omitempty"`
UserLabels map[string]string `json:"userLabels,omitempty"`
}
// GetExtraParams implements providerUtils.RequestBodyWithExtraParams.
// A rank request carries no passthrough parameters, so the result is always nil.
func (req *VertexRankRequest) GetExtraParams() map[string]interface{} {
	return nil
}
// Defaults and limits used when building Discovery Engine rank requests.
const (
// vertexDefaultRankingConfigID is the ranking config ID used when the caller supplies no override.
vertexDefaultRankingConfigID = "default_ranking_config"
// vertexDefaultRerankModel is the model sent when the rerank request names none.
vertexDefaultRerankModel = "semantic-ranker-default@latest"
// vertexMaxRerankRecordsPerQuery is the largest number of documents accepted in one rerank request.
vertexMaxRerankRecordsPerQuery = 200
// vertexSyntheticRecordPrefix prefixes the generated record IDs ("idx:<document index>")
// that map ranked records back to the caller's documents.
vertexSyntheticRecordPrefix = "idx:"
)
// VertexRankRecord represents a record for ranking.
// ID is always serialized; Title and Content are optional and omitted when nil.
type VertexRankRecord struct {
ID string `json:"id"`
Title *string `json:"title,omitempty"`
Content *string `json:"content,omitempty"`
}
// VertexRankResponse represents the Discovery Engine rank API response.
// Records holds one entry per ranked record returned by the API.
type VertexRankResponse struct {
Records []VertexRankedRecord `json:"records"`
}
// VertexRankedRecord represents a ranked record in response.
// ID identifies the record and Score is its relevance score; Title and Content
// are optional pointers and may be nil.
type VertexRankedRecord struct {
ID string `json:"id"`
Score float64 `json:"score"`
Title *string `json:"title,omitempty"`
Content *string `json:"content,omitempty"`
}
// vertexRerankOptions holds the provider-specific rerank settings resolved
// from a request's extra parameters.
type vertexRerankOptions struct {
// RankingConfig is the ranking config resource name.
RankingConfig string
// IgnoreRecordDetailsInResponse is copied into the rank request's field of the same name.
IgnoreRecordDetailsInResponse bool
// UserLabels is copied into the rank request when non-empty.
UserLabels map[string]string
}
// ToBifrostListModelsResponse converts a Vertex AI list-models response
// (custom deployed models) to Bifrost's format.
//
// For every deployed model the last path segment of its endpoint is used as
// the custom model ID and passed through a ListModelsPipeline configured with
// allowedModels, blacklistedModels, aliases and unfiltered. Each pipeline
// result is emitted at most once, keyed by its lower-cased resolved ID. The
// pipeline's BackfillModels output is appended afterwards and the page token
// is carried over.
//
// A nil receiver yields nil. When the pipeline reports an early exit, a
// response with an empty Data slice is returned.
func (response *VertexListModelsResponse) ToBifrostListModelsResponse(allowedModels schemas.WhiteList, blacklistedModels schemas.BlackList, aliases map[string]string, unfiltered bool) *schemas.BifrostListModelsResponse {
	if response == nil {
		return nil
	}

	out := &schemas.BifrostListModelsResponse{
		Data: make([]schemas.Model, 0, len(response.Models)),
	}

	pipeline := &providerUtils.ListModelsPipeline{
		AllowedModels:     allowedModels,
		BlacklistedModels: blacklistedModels,
		Aliases:           aliases,
		Unfiltered:        unfiltered,
		ProviderKey:       schemas.Vertex,
		MatchFns:          providerUtils.DefaultMatchFns(),
	}
	if pipeline.ShouldEarlyExit() {
		return out
	}

	// seen tracks lower-cased resolved IDs that are already in the output.
	seen := make(map[string]bool)

	for _, model := range response.Models {
		for _, deployed := range model.DeployedModels {
			// The custom model ID is the final path segment of the endpoint,
			// ignoring a single trailing slash.
			endpoint := strings.TrimSuffix(deployed.Endpoint, "/")
			customModelID := endpoint[strings.LastIndex(endpoint, "/")+1:]
			if customModelID == "" {
				continue
			}

			for _, result := range pipeline.FilterModel(customModelID) {
				key := strings.ToLower(result.ResolvedID)
				if seen[key] {
					continue
				}
				seen[key] = true

				entry := schemas.Model{
					ID:          string(schemas.Vertex) + "/" + result.ResolvedID,
					Name:        schemas.Ptr(model.DisplayName),
					Description: schemas.Ptr(model.Description),
					Created:     schemas.Ptr(model.VersionCreateTime.Unix()),
				}
				if result.AliasValue != "" {
					entry.Alias = schemas.Ptr(result.AliasValue)
				}
				out.Data = append(out.Data, entry)
			}
		}
	}

	out.Data = append(out.Data, pipeline.BackfillModels(seen)...)
	out.NextPageToken = response.NextPageToken

	return out
}
// ToBifrostListModelsResponse converts a Vertex AI publisher models response to Bifrost's format.
// This is for foundation models from the Model Garden (publishers.models.list endpoint).
//
// Each publisher model's ID is extracted from its name and passed through a
// ListModelsPipeline configured with allowedModels, blacklistedModels, aliases
// and unfiltered. The display name is the deploy action's ModelDisplayName
// when present and non-empty, otherwise the resolved model ID. The pipeline's
// BackfillModels output is appended afterwards and the page token is carried
// over.
//
// A nil receiver yields nil. When the pipeline reports an early exit, a
// response with an empty Data slice is returned.
//
// NOTE(review): results are appended without checking whether the resolved ID
// was already emitted — confirm duplicates cannot occur here.
func (response *VertexListPublisherModelsResponse) ToBifrostListModelsResponse(allowedModels schemas.WhiteList, blacklistedModels schemas.BlackList, aliases map[string]string, unfiltered bool) *schemas.BifrostListModelsResponse {
	if response == nil {
		return nil
	}

	out := &schemas.BifrostListModelsResponse{
		Data: make([]schemas.Model, 0, len(response.PublisherModels)),
	}

	pipeline := &providerUtils.ListModelsPipeline{
		AllowedModels:     allowedModels,
		BlacklistedModels: blacklistedModels,
		Aliases:           aliases,
		Unfiltered:        unfiltered,
		ProviderKey:       schemas.Vertex,
		MatchFns:          providerUtils.DefaultMatchFns(),
	}
	if pipeline.ShouldEarlyExit() {
		return out
	}

	seen := make(map[string]bool)

	for _, model := range response.PublisherModels {
		// Name has the form "publishers/google/models/gemini-1.5-pro".
		modelID := extractModelIDFromName(model.Name)
		if modelID == "" {
			continue
		}

		// Display name advertised by the deploy action, if any.
		deployName := ""
		if actions := model.SupportedActions; actions != nil && actions.Deploy != nil {
			deployName = actions.Deploy.ModelDisplayName
		}

		for _, result := range pipeline.FilterModel(modelID) {
			name := deployName
			if name == "" {
				name = result.ResolvedID
			}

			entry := schemas.Model{
				ID:   string(schemas.Vertex) + "/" + result.ResolvedID,
				Name: schemas.Ptr(name),
			}
			if result.AliasValue != "" {
				entry.Alias = schemas.Ptr(result.AliasValue)
			}

			out.Data = append(out.Data, entry)
			seen[strings.ToLower(result.ResolvedID)] = true
		}
	}

	out.Data = append(out.Data, pipeline.BackfillModels(seen)...)
	out.NextPageToken = response.NextPageToken

	return out
}

View File

@@ -0,0 +1,290 @@
package vertex
import (
"fmt"
"sort"
"strconv"
"strings"
"github.com/bytedance/sonic"
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
"github.com/maximhq/bifrost/core/schemas"
)
// buildVertexRankingConfig resolves the ranking config resource name for a
// rerank call.
//
// projectID is required (after trimming whitespace). rankingConfigOverride may be:
//   - empty: the default ranking config of the project is used
//   - a full resource name starting with "projects/": returned as-is
//   - a bare config ID (no "/"): placed under the project's global location
//
// A trailing ":rank" on the override is stripped. Any other override that
// contains a "/" is rejected with an error.
func buildVertexRankingConfig(projectID, rankingConfigOverride string) (string, error) {
	project := strings.TrimSpace(projectID)
	if project == "" {
		return "", fmt.Errorf("project ID is required for ranking config")
	}

	configID := strings.TrimSpace(rankingConfigOverride)
	switch {
	case configID == "":
		configID = vertexDefaultRankingConfigID
	default:
		// Callers may pass the method path; keep only the resource part.
		configID = strings.TrimSuffix(configID, ":rank")
		if strings.HasPrefix(configID, "projects/") {
			return configID, nil
		}
		if strings.Contains(configID, "/") {
			return "", fmt.Errorf("invalid ranking_config %q: must be resource name or config ID", rankingConfigOverride)
		}
	}

	return fmt.Sprintf("projects/%s/locations/global/rankingConfigs/%s", project, configID), nil
}
// getVertexRerankOptions resolves the Vertex-specific rerank options from the
// request parameters.
//
// Recognised ExtraParams keys:
//   - "ranking_config" (string): override passed to buildVertexRankingConfig
//   - "ignore_record_details_in_response" (bool): defaults to true
//   - "user_labels" (string map): copied into the options
//
// A key that is present with the wrong type yields an error. Missing params or
// ExtraParams result in the default ranking config for projectID.
func getVertexRerankOptions(projectID string, params *schemas.RerankParameters) (*vertexRerankOptions, error) {
	// Reading from a nil map is safe, so the "no params" case shares the
	// lookups below and simply finds nothing.
	var extra map[string]interface{}
	if params != nil {
		extra = params.ExtraParams
	}

	override := ""
	if raw, present := extra["ranking_config"]; present {
		value, isString := schemas.SafeExtractString(raw)
		if !isString {
			return nil, fmt.Errorf("invalid ranking_config: expected string")
		}
		override = value
	}

	rankingConfig, err := buildVertexRankingConfig(projectID, override)
	if err != nil {
		return nil, err
	}

	resolved := &vertexRerankOptions{
		RankingConfig:                 rankingConfig,
		IgnoreRecordDetailsInResponse: true,
	}

	if raw, present := extra["ignore_record_details_in_response"]; present {
		value, isBool := schemas.SafeExtractBool(raw)
		if !isBool {
			return nil, fmt.Errorf("invalid ignore_record_details_in_response: expected bool")
		}
		resolved.IgnoreRecordDetailsInResponse = value
	}

	if raw, present := extra["user_labels"]; present {
		labels, isStringMap := schemas.SafeExtractStringMap(raw)
		if !isStringMap {
			return nil, fmt.Errorf("invalid user_labels: expected map[string]string")
		}
		resolved.UserLabels = labels
	}

	return resolved, nil
}
// ToVertexRankRequest converts a Bifrost rerank request to Discovery Engine rank API format.
//
// Each document becomes a record whose ID is the synthetic "idx:<position>"
// value, whose content is the document text and whose title is taken from a
// non-blank string "title" entry in the document's Meta. TopN must be at least
// 1 and is clamped to the number of documents. A blank model falls back to the
// default rerank model.
//
// An error is returned for a nil request, nil options, no documents, too many
// documents, or a TopN below 1.
func ToVertexRankRequest(bifrostReq *schemas.BifrostRerankRequest, options *vertexRerankOptions) (*VertexRankRequest, error) {
	switch {
	case bifrostReq == nil:
		return nil, fmt.Errorf("bifrost rerank request is nil")
	case options == nil:
		return nil, fmt.Errorf("vertex rerank options are nil")
	}

	docCount := len(bifrostReq.Documents)
	switch {
	case docCount == 0:
		return nil, fmt.Errorf("documents are required for rerank request")
	case docCount > vertexMaxRerankRecordsPerQuery:
		return nil, fmt.Errorf("vertex rerank supports up to %d records per request", vertexMaxRerankRecordsPerQuery)
	}

	records := make([]VertexRankRecord, 0, docCount)
	for position := range bifrostReq.Documents {
		doc := bifrostReq.Documents[position]
		text := doc.Text

		rec := VertexRankRecord{
			// Synthetic ID so the response can be mapped back by position.
			ID:      vertexSyntheticRecordPrefix + strconv.Itoa(position),
			Content: &text,
		}
		// Looking up a key in a nil Meta map is safe and finds nothing.
		if rawTitle, present := doc.Meta["title"]; present {
			if title, isString := schemas.SafeExtractString(rawTitle); isString && strings.TrimSpace(title) != "" {
				rec.Title = &title
			}
		}
		records = append(records, rec)
	}

	rankRequest := &VertexRankRequest{
		Query:   bifrostReq.Query,
		Records: records,
	}

	if params := bifrostReq.Params; params != nil && params.TopN != nil {
		limit := *params.TopN
		if limit < 1 {
			return nil, fmt.Errorf("top_n must be at least 1")
		}
		if limit > docCount {
			limit = docCount
		}
		rankRequest.TopN = &limit
	}

	modelName := strings.TrimSpace(bifrostReq.Model)
	if modelName == "" {
		modelName = vertexDefaultRerankModel
	}
	rankRequest.Model = &modelName

	ignoreDetails := options.IgnoreRecordDetailsInResponse
	rankRequest.IgnoreRecordDetailsInResponse = &ignoreDetails

	if len(options.UserLabels) > 0 {
		rankRequest.UserLabels = options.UserLabels
	}

	return rankRequest, nil
}
// ToBifrostRerankRequest converts a Discovery Engine rank request to Bifrost format.
//
// The provider and model are parsed from req.Model when it is set; otherwise
// the default provider for ctx is used and the model is left empty. Every
// record becomes a document carrying its own copy of the record ID, the record
// content as text and, when present, the title under Meta["title"].
// IgnoreRecordDetailsInResponse and UserLabels are passed through as
// ExtraParams under "ignore_record_details_in_response" and "user_labels".
//
// A nil receiver yields nil.
func (req *VertexRankRequest) ToBifrostRerankRequest(ctx *schemas.BifrostContext) *schemas.BifrostRerankRequest {
	if req == nil {
		return nil
	}

	var (
		provider schemas.ModelProvider
		model    string
	)
	defaultProvider := providerUtils.CheckAndSetDefaultProvider(ctx, schemas.Vertex)
	if req.Model != nil {
		provider, model = schemas.ParseModelString(*req.Model, defaultProvider)
	} else {
		provider = defaultProvider
	}

	bifrostReq := &schemas.BifrostRerankRequest{
		Provider: provider,
		Model:    model,
		Query:    req.Query,
		Params:   &schemas.RerankParameters{},
	}

	// Convert records to documents. Documents stays nil when there are no
	// records, matching the zero value of the request.
	if count := len(req.Records); count > 0 {
		bifrostReq.Documents = make([]schemas.RerankDocument, 0, count)
	}
	for i := range req.Records {
		record := &req.Records[i]

		// Give each document its own ID variable. Taking the address of a
		// range variable's field would make every document share one ID under
		// pre-Go 1.22 loop semantics.
		id := record.ID
		doc := schemas.RerankDocument{ID: &id}

		if record.Content != nil {
			doc.Text = *record.Content
		}
		if record.Title != nil {
			doc.Meta = map[string]interface{}{
				"title": *record.Title,
			}
		}
		bifrostReq.Documents = append(bifrostReq.Documents, doc)
	}

	// TopN is passed through unchanged (nil when not set).
	bifrostReq.Params.TopN = req.TopN

	// Pass provider-specific fields as ExtraParams; the map is only attached
	// when it has at least one entry.
	extraParams := make(map[string]interface{})
	if req.IgnoreRecordDetailsInResponse != nil {
		extraParams["ignore_record_details_in_response"] = *req.IgnoreRecordDetailsInResponse
	}
	if len(req.UserLabels) > 0 {
		extraParams["user_labels"] = req.UserLabels
	}
	if len(extraParams) > 0 {
		bifrostReq.Params.ExtraParams = extraParams
	}

	return bifrostReq
}
// parseVertexSyntheticRecordIndex recovers the document index encoded in a
// synthetic record ID of the form "<prefix><index>".
//
// It returns an error when recordID lacks the synthetic prefix, when the
// remainder is not an integer, or when the index falls outside [0, maxDocs).
func parseVertexSyntheticRecordIndex(recordID string, maxDocs int) (int, error) {
	if !strings.HasPrefix(recordID, vertexSyntheticRecordPrefix) {
		return 0, fmt.Errorf("invalid record id %q: expected prefix %q", recordID, vertexSyntheticRecordPrefix)
	}

	position, err := strconv.Atoi(recordID[len(vertexSyntheticRecordPrefix):])
	switch {
	case err != nil:
		return 0, fmt.Errorf("invalid record id %q: %w", recordID, err)
	case position < 0, position >= maxDocs:
		return 0, fmt.Errorf("record id %q maps to out-of-range index %d", recordID, position)
	}

	return position, nil
}
// ToBifrostRerankResponse converts a Discovery Engine rank response to Bifrost format.
//
// Every ranked record is mapped back to a position in documents through its
// synthetic record ID. An invalid ID, an out-of-range index or two records
// mapping to the same index is an error. When returnDocuments is true each
// result carries a copy of the matching document. Results are ordered by
// descending relevance score, ties broken by ascending index.
func (response *VertexRankResponse) ToBifrostRerankResponse(documents []schemas.RerankDocument, returnDocuments bool) (*schemas.BifrostRerankResponse, error) {
	if response == nil {
		return nil, fmt.Errorf("vertex rerank response is nil")
	}

	ranked := make([]schemas.RerankResult, 0, len(response.Records))
	claimed := make(map[int]bool, len(response.Records))

	for i := range response.Records {
		rec := &response.Records[i]

		position, err := parseVertexSyntheticRecordIndex(rec.ID, len(documents))
		if err != nil {
			return nil, err
		}
		if claimed[position] {
			return nil, fmt.Errorf("duplicate record id mapping for index %d", position)
		}
		claimed[position] = true

		entry := schemas.RerankResult{
			Index:          position,
			RelevanceScore: rec.Score,
		}
		if returnDocuments {
			// Copy so the result does not point into the caller's slice.
			docCopy := documents[position]
			entry.Document = &docCopy
		}
		ranked = append(ranked, entry)
	}

	sort.SliceStable(ranked, func(a, b int) bool {
		left, right := ranked[a], ranked[b]
		if left.RelevanceScore != right.RelevanceScore {
			return left.RelevanceScore > right.RelevanceScore
		}
		return left.Index < right.Index
	})

	return &schemas.BifrostRerankResponse{Results: ranked}, nil
}
// parseDiscoveryEngineErrorMessage extracts a human-readable message from a
// Discovery Engine error body.
//
// When the body is a JSON object whose "error" member is an object with a
// non-blank string "message", that message is returned unchanged. Otherwise
// the whitespace-trimmed raw body is returned, which is "" for an empty or
// blank body.
func parseDiscoveryEngineErrorMessage(responseBody []byte) string {
	if len(responseBody) == 0 {
		return ""
	}

	var payload map[string]interface{}
	if sonic.Unmarshal(responseBody, &payload) == nil {
		// A missing "error" key yields a nil interface, which fails the type
		// assertion just like a non-object value does.
		if errObj, isObject := payload["error"].(map[string]interface{}); isObject {
			if msg, isString := schemas.SafeExtractString(errObj["message"]); isString && strings.TrimSpace(msg) != "" {
				return msg
			}
		}
	}

	return strings.TrimSpace(string(responseBody))
}

View File

@@ -0,0 +1,205 @@
package vertex
import (
"context"
"testing"
"github.com/maximhq/bifrost/core/schemas"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestBuildVertexRankingConfig checks the default, bare-ID and full resource
// name forms of the override, and that a partial resource path is rejected.
func TestBuildVertexRankingConfig(t *testing.T) {
	t.Parallel()

	accepted := []struct {
		override string
		want     string
	}{
		{
			override: "",
			want:     "projects/demo-project/locations/global/rankingConfigs/default_ranking_config",
		},
		{
			override: "custom_rank",
			want:     "projects/demo-project/locations/global/rankingConfigs/custom_rank",
		},
		{
			override: "projects/other/locations/global/rankingConfigs/custom_rank:rank",
			want:     "projects/other/locations/global/rankingConfigs/custom_rank",
		},
	}
	for _, tc := range accepted {
		got, err := buildVertexRankingConfig("demo-project", tc.override)
		require.NoError(t, err)
		assert.Equal(t, tc.want, got)
	}

	_, err := buildVertexRankingConfig("demo-project", "locations/global/rankingConfigs/custom_rank")
	require.Error(t, err)
}
func TestToVertexRankRequest(t *testing.T) {
t.Parallel()
req, err := ToVertexRankRequest(
&schemas.BifrostRerankRequest{
Query: "capital of france",
Documents: []schemas.RerankDocument{
{Text: "Paris is the capital of France.", Meta: map[string]interface{}{"title": "Doc A"}},
{Text: "Berlin is the capital of Germany."},
},
Params: &schemas.RerankParameters{
TopN: schemas.Ptr(10),
},
},
&vertexRerankOptions{
RankingConfig: "projects/p/locations/global/rankingConfigs/default_ranking_config",
IgnoreRecordDetailsInResponse: true,
},
)
require.NoError(t, err)
require.NotNil(t, req)
require.NotNil(t, req.Model)
assert.Equal(t, "semantic-ranker-default@latest", *req.Model)
require.Len(t, req.Records, 2)
assert.Equal(t, "idx:0", req.Records[0].ID)
assert.Equal(t, "idx:1", req.Records[1].ID)
require.NotNil(t, req.Records[0].Title)
assert.Equal(t, "Doc A", *req.Records[0].Title)
require.NotNil(t, req.TopN)
assert.Equal(t, 2, *req.TopN, "topN should be clamped to document count")
require.NotNil(t, req.IgnoreRecordDetailsInResponse)
assert.True(t, *req.IgnoreRecordDetailsInResponse)
}
// TestToVertexRankRequestTooManyRecords checks that one document more than the
// per-request limit is rejected.
func TestToVertexRankRequestTooManyRecords(t *testing.T) {
	t.Parallel()

	// One past the limit.
	tooMany := make([]schemas.RerankDocument, vertexMaxRerankRecordsPerQuery+1)
	for i := range tooMany {
		tooMany[i].Text = "doc"
	}

	input := &schemas.BifrostRerankRequest{
		Query:     "q",
		Documents: tooMany,
	}
	opts := &vertexRerankOptions{
		RankingConfig:                 "projects/p/locations/global/rankingConfigs/default_ranking_config",
		IgnoreRecordDetailsInResponse: true,
	}

	_, err := ToVertexRankRequest(input, opts)
	require.Error(t, err)
	assert.Contains(t, err.Error(), "supports up to")
}
func TestGetVertexRerankOptions(t *testing.T) {
t.Parallel()
options, err := getVertexRerankOptions("project-x", &schemas.RerankParameters{
ExtraParams: map[string]interface{}{
"ranking_config": "custom_rank",
"ignore_record_details_in_response": false,
"user_labels": map[string]interface{}{
"env": "test",
},
},
})
require.NoError(t, err)
assert.Equal(t, "projects/project-x/locations/global/rankingConfigs/custom_rank", options.RankingConfig)
assert.False(t, options.IgnoreRecordDetailsInResponse)
assert.Equal(t, map[string]string{"env": "test"}, options.UserLabels)
}
// TestVertexRankResponseToBifrostRerankResponse checks that results are sorted
// by descending score with ties broken by ascending index, and that documents
// are attached when requested.
func TestVertexRankResponseToBifrostRerankResponse(t *testing.T) {
	t.Parallel()

	documents := []schemas.RerankDocument{
		{Text: "doc-0"},
		{Text: "doc-1"},
		{Text: "doc-2"},
	}
	// Indices 0 and 1 tie on score; index 2 ranks last.
	vertexResp := &VertexRankResponse{
		Records: []VertexRankedRecord{
			{ID: "idx:2", Score: 0.12},
			{ID: "idx:1", Score: 0.91},
			{ID: "idx:0", Score: 0.91},
		},
	}

	got, err := vertexResp.ToBifrostRerankResponse(documents, true)
	require.NoError(t, err)
	require.NotNil(t, got)
	require.Len(t, got.Results, 3)

	assert.Equal(t, 0, got.Results[0].Index)
	assert.Equal(t, 1, got.Results[1].Index)
	assert.Equal(t, 2, got.Results[2].Index)

	require.NotNil(t, got.Results[0].Document)
	assert.Equal(t, "doc-0", got.Results[0].Document.Text)
}
// TestVertexRankRequestToBifrostRerankRequest checks the full conversion of a
// rank request: provider/model, documents, TopN and extra params.
func TestVertexRankRequestToBifrostRerankRequest(t *testing.T) {
	t.Parallel()

	vertexReq := &VertexRankRequest{
		Model: schemas.Ptr("semantic-ranker-default@latest"),
		Query: "capital of france",
		Records: []VertexRankRecord{
			{
				ID:      "rec-1",
				Content: schemas.Ptr("Paris is the capital of France."),
				Title:   schemas.Ptr("Doc A"),
			},
			{
				ID:      "rec-2",
				Content: schemas.Ptr("Berlin is the capital of Germany."),
			},
		},
		TopN:                          schemas.Ptr(5),
		IgnoreRecordDetailsInResponse: schemas.Ptr(true),
		UserLabels:                    map[string]string{"env": "test"},
	}

	ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
	got := vertexReq.ToBifrostRerankRequest(ctx)

	require.NotNil(t, got)
	assert.Equal(t, schemas.Vertex, got.Provider)
	assert.Equal(t, "semantic-ranker-default@latest", got.Model)
	assert.Equal(t, "capital of france", got.Query)
	require.Len(t, got.Documents, 2)

	// First document has ID, content, and title in meta.
	first := got.Documents[0]
	require.NotNil(t, first.ID)
	assert.Equal(t, "rec-1", *first.ID)
	assert.Equal(t, "Paris is the capital of France.", first.Text)
	require.NotNil(t, first.Meta)
	assert.Equal(t, "Doc A", first.Meta["title"])

	// Second document has no title.
	second := got.Documents[1]
	require.NotNil(t, second.ID)
	assert.Equal(t, "rec-2", *second.ID)
	assert.Equal(t, "Berlin is the capital of Germany.", second.Text)
	assert.Nil(t, second.Meta)

	// TopN.
	require.NotNil(t, got.Params)
	require.NotNil(t, got.Params.TopN)
	assert.Equal(t, 5, *got.Params.TopN)

	// ExtraParams.
	require.NotNil(t, got.Params.ExtraParams)
	assert.Equal(t, true, got.Params.ExtraParams["ignore_record_details_in_response"])
	assert.Equal(t, map[string]string{"env": "test"}, got.Params.ExtraParams["user_labels"])
}
// TestVertexRankRequestToBifrostRerankRequestNil checks that converting a nil
// request returns nil.
func TestVertexRankRequestToBifrostRerankRequestNil(t *testing.T) {
	t.Parallel()

	nilReq := (*VertexRankRequest)(nil)
	got := nilReq.ToBifrostRerankRequest(nil)
	assert.Nil(t, got)
}
// TestVertexRankResponseToBifrostRerankResponseInvalidID checks that a record
// ID without the synthetic prefix is reported as an error.
func TestVertexRankResponseToBifrostRerankResponseInvalidID(t *testing.T) {
	t.Parallel()

	vertexResp := &VertexRankResponse{
		Records: []VertexRankedRecord{
			{ID: "bad-id", Score: 0.9},
		},
	}
	documents := []schemas.RerankDocument{{Text: "doc"}}

	_, err := vertexResp.ToBifrostRerankResponse(documents, false)
	require.Error(t, err)
	assert.Contains(t, err.Error(), "invalid record id")
}

View File

@@ -0,0 +1,245 @@
package vertex
import (
"time"
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
)
// DefaultVertexAnthropicVersion is a default version string.
// NOTE(review): presumably the anthropic_version value sent when calling
// Anthropic models through Vertex AI — confirm against the callers. It is not
// related to the embedding types declared further down in this file.
const (
DefaultVertexAnthropicVersion = "vertex-2023-10-16"
)
// PhoneticEncoding represents the phonetic encoding of a phrase.
// It is a string type; the constants below are the values used on the wire.
type PhoneticEncoding string
// Supported PhoneticEncoding values.
const (
PhoneticEncodingUnspecified PhoneticEncoding = "PHONETIC_ENCODING_UNSPECIFIED"
PhoneticEncodingIPA PhoneticEncoding = "PHONETIC_ENCODING_IPA"
PhoneticEncodingXSAMPA PhoneticEncoding = "PHONETIC_ENCODING_X_SAMPA"
PhoneticEncodingJapaneseYomigana PhoneticEncoding = "PHONETIC_ENCODING_JAPANESE_YOMIGANA"
PhoneticEncodingPinyin PhoneticEncoding = "PHONETIC_ENCODING_PINYIN"
)
// CustomPronunciationParams represents pronunciation customization for a phrase.
// Every field is tagged omitempty, so empty values are left out of the JSON.
type CustomPronunciationParams struct {
Phrase string `json:"phrase,omitempty"`
PhoneticEncoding PhoneticEncoding `json:"phoneticEncoding,omitempty"`
Pronunciation string `json:"pronunciation,omitempty"`
}
// CustomPronunciations represents a collection of pronunciation customizations.
// An empty or nil Pronunciations slice is omitted from the JSON.
type CustomPronunciations struct {
Pronunciations []CustomPronunciationParams `json:"pronunciations,omitempty"`
}
// Turn represents a multi-speaker turn: the text spoken by one speaker.
// Both fields are omitted from the JSON when empty.
type Turn struct {
Speaker string `json:"speaker,omitempty"`
Text string `json:"text,omitempty"`
}
// MultiSpeakerMarkup represents a collection of turns for multi-speaker synthesis.
// An empty or nil Turns slice is omitted from the JSON.
type MultiSpeakerMarkup struct {
Turns []Turn `json:"turns,omitempty"`
}
// VertexSynthesisInput contains text input to be synthesized.
// All fields are optional pointers tagged omitempty, so only the ones that are
// set are serialized.
type VertexSynthesisInput struct {
Text *string `json:"text,omitempty"`
Markup *string `json:"markup,omitempty"`
SSML *string `json:"ssml,omitempty"`
MultiSpeakerMarkup *MultiSpeakerMarkup `json:"multiSpeakerMarkup,omitempty"`
Prompt *string `json:"prompt,omitempty"`
CustomPronunciations *CustomPronunciations `json:"customPronunciations,omitempty"`
}
// VertexVoiceSelectionParams represents voice selection parameters for TTS synthesis.
// Empty strings are omitted from the JSON.
type VertexVoiceSelectionParams struct {
LanguageCode string `json:"languageCode,omitempty"`
Name string `json:"name,omitempty"`
SsmlGender string `json:"ssmlGender,omitempty"`
}
// VertexAudioConfig represents audio configuration for TTS synthesis.
// Every field is tagged omitempty, so zero values are never sent: a
// SpeakingRate, Pitch, VolumeGainDB or SampleRateHertz of exactly 0 is
// indistinguishable from "not set".
type VertexAudioConfig struct {
AudioEncoding string `json:"audioEncoding,omitempty"`
SpeakingRate float64 `json:"speakingRate,omitempty"`
Pitch float64 `json:"pitch,omitempty"`
VolumeGainDB float64 `json:"volumeGainDb,omitempty"`
SampleRateHertz int `json:"sampleRateHertz,omitempty"`
EffectsProfileID []string `json:"effectsProfileId,omitempty"`
}
// VertexRequestBody wraps a request body held as a generic map together with
// passthrough extra parameters. Both fields are excluded from default JSON
// encoding (`json:"-"`); the type's MarshalJSON serializes RequestBody only.
type VertexRequestBody struct {
RequestBody map[string]interface{} `json:"-"`
ExtraParams map[string]interface{} `json:"-"`
}
// GetExtraParams returns the ExtraParams map as stored (no copy is made).
func (r *VertexRequestBody) GetExtraParams() map[string]interface{} {
return r.ExtraParams
}
// MarshalJSON implements custom JSON marshalling for VertexRequestBody.
// It marshals the RequestBody field directly without wrapping, delegating to
// providerUtils.MarshalSorted. ExtraParams is not part of the output.
func (r *VertexRequestBody) MarshalJSON() ([]byte, error) {
return providerUtils.MarshalSorted(r.RequestBody)
}
// VertexRawRequestBody holds pre-serialized JSON bytes to preserve key ordering
// for LLM prompt caching. This avoids the map[string]interface{} round-trip that
// destroys key order.
// Both fields are excluded from default JSON encoding (`json:"-"`); the type's
// MarshalJSON emits RawBody.
type VertexRawRequestBody struct {
RawBody []byte `json:"-"`
ExtraParams map[string]interface{} `json:"-"`
}
// GetExtraParams returns the ExtraParams map as stored (no copy is made).
func (r *VertexRawRequestBody) GetExtraParams() map[string]interface{} {
return r.ExtraParams
}
// MarshalJSON returns the pre-serialized JSON bytes directly, preserving key order.
// RawBody is returned as-is: it is neither copied nor validated here, and the
// returned error is always nil.
func (r *VertexRawRequestBody) MarshalJSON() ([]byte, error) {
return r.RawBody, nil
}
// VertexAdvancedVoiceOptions represents advanced voice options for TTS synthesis.
// LowLatencyJourneySynthesis is tagged omitempty, so false is never serialized.
type VertexAdvancedVoiceOptions struct {
LowLatencyJourneySynthesis bool `json:"lowLatencyJourneySynthesis,omitempty"`
}
// VertexEmbeddingInstance represents a single embedding instance in the request.
// Content is always serialized; TaskType (JSON key "task_type") and Title are
// omitted when nil.
type VertexEmbeddingInstance struct {
Content string `json:"content"` // The text to generate embeddings for
TaskType *string `json:"task_type,omitempty"` // Intended downstream application (optional)
Title *string `json:"title,omitempty"` // Used to help the model produce better embeddings (optional)
}
// VertexEmbeddingParameters represents the parameters for the embedding request.
// Both fields are pointers so that "unset" can be told apart from false/0; nil
// fields are omitted from the JSON.
type VertexEmbeddingParameters struct {
AutoTruncate *bool `json:"autoTruncate,omitempty"` // When true, input text will be truncated (defaults to true)
OutputDimensionality *int `json:"outputDimensionality,omitempty"` // Output embedding size (optional)
}
// VertexEmbeddingRequest represents the complete embedding request to Vertex AI.
// ExtraParams is tagged `json:"-"` and therefore never part of the encoded body;
// it is exposed to callers through GetExtraParams.
type VertexEmbeddingRequest struct {
	Instances   []VertexEmbeddingInstance  `json:"instances"`            // List of embedding instances
	Parameters  *VertexEmbeddingParameters `json:"parameters,omitempty"` // Optional parameters
	ExtraParams map[string]interface{}     `json:"-"`                    // Optional: Extra parameters
}

// GetExtraParams returns the extra parameters attached to the request.
// It returns nil when called on a nil receiver instead of panicking.
func (r *VertexEmbeddingRequest) GetExtraParams() map[string]interface{} {
	if r == nil {
		return nil
	}
	return r.ExtraParams
}
// VertexEmbeddingStatistics represents statistics computed from the input text.
// It is carried in the "statistics" object of VertexEmbeddingValues.
type VertexEmbeddingStatistics struct {
Truncated bool `json:"truncated"` // Whether the input text was truncated
TokenCount int `json:"token_count"` // Number of tokens in the input text
}
// VertexEmbeddingValues represents the embedding result for one input:
// the vector itself plus optional statistics about the input text.
type VertexEmbeddingValues struct {
Values []float64 `json:"values"` // The embedding vector (list of floats)
Statistics *VertexEmbeddingStatistics `json:"statistics"` // Statistics about the input text; pointer, so nil-check before use
}
// VertexEmbeddingPrediction represents a single prediction in the response.
// It is one element of VertexEmbeddingResponse.Predictions.
type VertexEmbeddingPrediction struct {
Embeddings *VertexEmbeddingValues `json:"embeddings"` // The embedding result; pointer, so nil-check before use
}
// VertexEmbeddingResponse represents the complete embedding response from Vertex AI.
// The payload is a single "predictions" array.
type VertexEmbeddingResponse struct {
Predictions []VertexEmbeddingPrediction `json:"predictions"` // List of embedding predictions
}
// ================================ Model Types ================================
// MaxPageSize is a page-size limit of 100.
// NOTE(review): presumably the page size requested when listing models — confirm against the caller.
const MaxPageSize = 100
// VertexModel is a single model entry of VertexListModelsResponse.
// Field names map one-to-one onto the camelCase JSON keys in the struct tags.
type VertexModel struct {
Name string `json:"name"`
VersionId string `json:"versionId"`
VersionAliases []string `json:"versionAliases"`
VersionCreateTime time.Time `json:"versionCreateTime"`
DisplayName string `json:"displayName"`
Description string `json:"description"`
DeployedModels []VertexDeployedModel `json:"deployedModels"` // Deployment references for this model
Labels VertexModelLabels `json:"labels"` // Typed view of the "labels" object
}
// VertexListModelsResponse is one page of a model listing: the models on this
// page plus the token carried in "nextPageToken".
type VertexListModelsResponse struct {
Models []VertexModel `json:"models"`
NextPageToken string `json:"nextPageToken"` // NOTE(review): presumably empty on the last page — confirm against the API
}
// VertexDeployedModel is one element of VertexModel.DeployedModels and
// identifies a deployment by checkpoint, deployment ID and endpoint.
type VertexDeployedModel struct {
CheckpointID string `json:"checkpointId"`
DeploymentID string `json:"deploymentId"`
Endpoint string `json:"endpoint"`
}
// VertexModelLabels is the typed form of a model's "labels" object.
// Only the three hyphenated keys listed in the struct tags are decoded;
// any other label present in the payload is ignored.
type VertexModelLabels struct {
GoogleVertexLLMTuningBaseModelId string `json:"google-vertex-llm-tuning-base-model-id"`
GoogleVertexLLMTuningJobId string `json:"google-vertex-llm-tuning-job-id"`
TuneType string `json:"tune-type"`
}
// ================================ Publisher Model Types ================================
// These types are for the publishers.models.list endpoint (Model Garden)

// VertexPublisherModel is a single entry of VertexListPublisherModelsResponse.
type VertexPublisherModel struct {
Name string `json:"name"`
VersionID string `json:"versionId"`
OpenSourceCategory string `json:"openSourceCategory"`
LaunchStage string `json:"launchStage"`
VersionState string `json:"versionState"`
PublisherModelTemplate string `json:"publisherModelTemplate"`
SupportedActions *VertexPublisherModelActions `json:"supportedActions"` // Pointer, so nil-check before use
}
// VertexPublisherModelActions is the "supportedActions" object of a publisher
// model. Every action is a pointer, so a nil field means that key was absent
// (or null) in the decoded payload.
type VertexPublisherModelActions struct {
OpenGenerationAIStudio *VertexPublisherModelURI `json:"openGenerationAiStudio"`
OpenGenie *VertexPublisherModelURI `json:"openGenie"`
OpenPromptTuningPipeline *VertexPublisherModelURI `json:"openPromptTuningPipeline"`
OpenNotebook *VertexPublisherModelURI `json:"openNotebook"`
OpenFineTuningPipeline *VertexPublisherModelURI `json:"openFineTuningPipeline"`
Deploy *VertexPublisherModelDeploy `json:"deploy"`
OpenEvaluationPipeline *VertexPublisherModelURI `json:"openEvaluationPipeline"`
}
// VertexPublisherModelURI is the value type of the URI-style actions in
// VertexPublisherModelActions; it carries a single "uri" string.
type VertexPublisherModelURI struct {
URI string `json:"uri"`
}
// VertexPublisherModelDeploy is the value of the "deploy" action in
// VertexPublisherModelActions.
type VertexPublisherModelDeploy struct {
ModelDisplayName string `json:"modelDisplayName"`
Title string `json:"title"`
}
// VertexListPublisherModelsResponse is one page of a publisher-model listing:
// the models on this page plus the token carried in "nextPageToken".
type VertexListPublisherModelsResponse struct {
PublisherModels []VertexPublisherModel `json:"publisherModels"`
NextPageToken string `json:"nextPageToken"` // NOTE(review): presumably empty on the last page — confirm against the API
}
// ==================== ERROR TYPES ====================

// VertexValidationError represents validation errors
// returned by the Vertex Mistral endpoint.
// The payload is a "detail" array with one entry per validation failure.
type VertexValidationError struct {
Detail []struct {
Input any `json:"input"` // can be number, object, or array
Loc []any `json:"loc"` // location of the error (can contain strings and numeric indices)
Msg string `json:"msg"` // error message
Type string `json:"type"` // error type (e.g., "extra_forbidden", "missing")
} `json:"detail"`
}
// VertexCountTokensResponse models the response payload for Vertex's Gemini-style countTokens.
// Its JSON keys are camelCase ("totalTokens", "cachedContentTokenCount"), unlike
// the snake_case keys used by some other payloads in this package.
type VertexCountTokensResponse struct {
TotalTokens int32 `json:"totalTokens,omitempty"` // Reported to Bifrost as both input and total tokens
CachedContentTokenCount int32 `json:"cachedContentTokenCount,omitempty"` // Reported as cached-read tokens when > 0
}

View File

@@ -0,0 +1,256 @@
package vertex
import (
"fmt"
"strings"
"github.com/maximhq/bifrost/core/providers/anthropic"
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
"github.com/maximhq/bifrost/core/schemas"
)
// getRequestBodyForAnthropicResponses builds the JSON request body for an
// Anthropic model served through Vertex, for a Responses request or, when
// isCountTokens is true, a count-tokens request.
//
// Two paths exist:
//   - Raw passthrough: taken when the context value
//     schemas.BifrostContextKeyUseRawRequestBody is true. The caller's raw JSON
//     (request.GetRawRequestBody) is patched field by field.
//   - Conversion: the request is converted with
//     anthropic.ToAnthropicResponsesRequest, marshalled with
//     providerUtils.MarshalSorted, and then patched.
//
// In both paths "region" is removed and "anthropic_version" is set to
// DefaultVertexAnthropicVersion when missing. Regular requests drop "model";
// count-tokens requests keep "model" (set to deployment) and drop "max_tokens"
// and "temperature". Finally, if any beta headers survive
// anthropic.FilterBetaHeadersForProvider, they are written to "anthropic_beta".
//
// Parameters:
//   - ctx: request context; read for the large-payload and raw-body flags and
//     passed to the anthropic helpers that manage beta headers.
//   - request: the Bifrost Responses request to translate.
//   - deployment: the model/deployment identifier written into "model" where kept.
//   - isStreaming: adds "stream": true (regular requests in the raw path; always
//     in the conversion path).
//   - isCountTokens: selects the count-tokens body shape described above.
//   - betaHeaderOverrides, providerExtraHeaders: inputs to the beta-header
//     merge/filter step.
//
// Returns (nil, nil) when large-payload passthrough is enabled, otherwise the
// body bytes or a BifrostError describing the failed step.
func getRequestBodyForAnthropicResponses(ctx *schemas.BifrostContext, request *schemas.BifrostResponsesRequest, deployment string, isStreaming bool, isCountTokens bool, betaHeaderOverrides map[string]bool, providerExtraHeaders map[string]string) ([]byte, *schemas.BifrostError) {
// Large payload mode: body streams directly from the LP reader — skip all body building
// (matches CheckContextAndGetRequestBody guard).
if providerUtils.IsLargePayloadPassthroughEnabled(ctx) {
return nil, nil
}
var jsonBody []byte
var err error
// Check if raw request body should be used
if useRawBody, ok := ctx.Value(schemas.BifrostContextKeyUseRawRequestBody).(bool); ok && useRawBody {
jsonBody = request.GetRawRequestBody()
if isCountTokens {
// Count-tokens: drop generation-only fields and keep "model", pointing at the deployment.
// NOTE(review): isStreaming is not consulted on this branch.
jsonBody, err = providerUtils.DeleteJSONField(jsonBody, "max_tokens")
if err != nil {
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err)
}
jsonBody, err = providerUtils.DeleteJSONField(jsonBody, "temperature")
if err != nil {
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err)
}
jsonBody, err = providerUtils.SetJSONField(jsonBody, "model", deployment)
if err != nil {
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err)
}
} else {
// Add max_tokens if not present
if !providerUtils.JSONFieldExists(jsonBody, "max_tokens") {
jsonBody, err = providerUtils.SetJSONField(jsonBody, "max_tokens", providerUtils.GetMaxOutputTokensOrDefault(deployment, anthropic.AnthropicDefaultMaxTokens))
if err != nil {
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err)
}
}
// Regular requests do not carry "model" in the body
jsonBody, err = providerUtils.DeleteJSONField(jsonBody, "model")
if err != nil {
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err)
}
// Add stream if streaming
if isStreaming {
jsonBody, err = providerUtils.SetJSONField(jsonBody, "stream", true)
if err != nil {
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err)
}
}
}
// Remove "region" and "fallbacks" from the raw body before it is sent
jsonBody, err = providerUtils.DeleteJSONField(jsonBody, "region")
if err != nil {
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err)
}
jsonBody, err = providerUtils.DeleteJSONField(jsonBody, "fallbacks")
if err != nil {
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err)
}
// Remap unsupported tool versions for Vertex (e.g., web_search_20260209 → web_search_20250305)
jsonBody, err = anthropic.RemapRawToolVersionsForProvider(jsonBody, schemas.Vertex)
if err != nil {
// The remap error text is used as the operation error message itself
return nil, providerUtils.NewBifrostOperationError(err.Error(), nil)
}
// Add anthropic_version if not present
if !providerUtils.JSONFieldExists(jsonBody, "anthropic_version") {
jsonBody, err = providerUtils.SetJSONField(jsonBody, "anthropic_version", DefaultVertexAnthropicVersion)
if err != nil {
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err)
}
}
} else {
// Validate tools are supported by Vertex
if request.Params != nil && request.Params.Tools != nil {
if toolErr := anthropic.ValidateToolsForProvider(request.Params.Tools, schemas.Vertex); toolErr != nil {
return nil, providerUtils.NewBifrostOperationError(toolErr.Error(), nil)
}
}
// Convert request to Anthropic format
reqBody, convErr := anthropic.ToAnthropicResponsesRequest(ctx, request)
if convErr != nil {
return nil, providerUtils.NewBifrostOperationError(schemas.ErrRequestBodyConversion, convErr)
}
if reqBody == nil {
return nil, providerUtils.NewBifrostOperationError("request body is not provided", nil)
}
reqBody.Model = deployment
// NOTE(review): Stream is set whenever isStreaming is true, including for count-tokens requests.
if isStreaming {
reqBody.Stream = schemas.Ptr(true)
}
reqBody.SetStripCacheControlScope(true)
// Add provider-aware beta headers
anthropic.AddMissingBetaHeadersToContext(ctx, reqBody, schemas.Vertex)
// Marshal struct to JSON bytes
jsonBody, err = providerUtils.MarshalSorted(reqBody)
if err != nil {
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err)
}
// Add anthropic_version if not present (using sjson to preserve order)
if !providerUtils.JSONFieldExists(jsonBody, "anthropic_version") {
jsonBody, err = providerUtils.SetJSONField(jsonBody, "anthropic_version", DefaultVertexAnthropicVersion)
if err != nil {
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err)
}
}
if isCountTokens {
// Count-tokens: drop generation-only fields; "model" (the deployment) stays in the body
jsonBody, err = providerUtils.DeleteJSONField(jsonBody, "max_tokens")
if err != nil {
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err)
}
jsonBody, err = providerUtils.DeleteJSONField(jsonBody, "temperature")
if err != nil {
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err)
}
} else {
// Remove model field for Vertex API (it's in URL)
jsonBody, err = providerUtils.DeleteJSONField(jsonBody, "model")
if err != nil {
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err)
}
}
jsonBody, err = providerUtils.DeleteJSONField(jsonBody, "region")
if err != nil {
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err)
}
}
// Both paths: merge provider-level and context beta headers, filter them for Vertex,
// and write whatever remains into the body as "anthropic_beta".
if betaHeaders := anthropic.FilterBetaHeadersForProvider(anthropic.MergeBetaHeaders(providerExtraHeaders, ctx), schemas.Vertex, betaHeaderOverrides); len(betaHeaders) > 0 {
jsonBody, err = providerUtils.SetJSONField(jsonBody, "anthropic_beta", betaHeaders)
if err != nil {
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err)
}
}
return jsonBody, nil
}
// getCompleteURLForGeminiEndpoint returns the full request URL for a Gemini-style
// endpoint; method is appended verbatim, so it serves streaming and
// non-streaming calls alike.
//
// An all-digit deployment is treated as a custom/fine-tuned endpoint and is
// addressed under v1beta1 with projectNumber. Any other deployment is treated
// as a Google publisher model and is addressed under v1 with projectID.
// The "global" region uses the un-prefixed aiplatform host; every other region
// uses the "<region>-aiplatform" host.
func getCompleteURLForGeminiEndpoint(deployment string, region string, projectID string, projectNumber string, method string) string {
	host := "https://aiplatform.googleapis.com"
	if region != "global" {
		host = fmt.Sprintf("https://%s-aiplatform.googleapis.com", region)
	}

	if schemas.IsAllDigitsASCII(deployment) {
		// Custom/fine-tuned models use projectNumber
		return fmt.Sprintf("%s/v1beta1/projects/%s/locations/%s/endpoints/%s%s", host, projectNumber, region, deployment, method)
	}

	// Gemini models use projectID
	return fmt.Sprintf("%s/v1/projects/%s/locations/%s/publishers/google/models/%s%s", host, projectID, region, deployment, method)
}
// buildResponseFromConfig assembles a list-models response purely from
// configuration, for setups where the user has spelled out which models to expose.
//
// Deployment aliases come first (kept only if allowed, when the allow-list is
// restricted, and not blacklisted); each carries its deployment value as Alias.
// When the allow-list is restricted, allowed models that were not already added
// via deployments follow, without an Alias. A block-all blacklist yields an
// empty response. Model IDs are prefixed with "<vertex provider>/".
func buildResponseFromConfig(deployments map[string]string, allowedModels schemas.WhiteList, blacklistedModels schemas.BlackList) *schemas.BifrostListModelsResponse {
	result := &schemas.BifrostListModelsResponse{
		Data: make([]schemas.Model, 0),
	}
	if blacklistedModels.IsBlockAll() {
		return result
	}

	idPrefix := string(schemas.Vertex) + "/"
	seen := make(map[string]bool)
	restricted := allowedModels.IsRestricted()

	// Pass 1: deployments (filtered by allowedModels when set)
	for alias, target := range deployments {
		if restricted && !allowedModels.Contains(alias) {
			continue
		}
		if blacklistedModels.IsBlocked(alias) {
			continue
		}
		id := idPrefix + alias
		if seen[id] {
			continue
		}
		displayName := providerUtils.ToDisplayName(alias)
		result.Data = append(result.Data, schemas.Model{
			ID:    id,
			Name:  schemas.Ptr(displayName),
			Alias: schemas.Ptr(target),
		})
		seen[id] = true
	}

	// Pass 2: allowed models not covered by a deployment (only when restricted)
	if restricted {
		for _, name := range allowedModels {
			id := idPrefix + name
			if seen[id] {
				continue
			}
			if blacklistedModels.IsBlocked(name) {
				continue
			}
			displayName := providerUtils.ToDisplayName(name)
			result.Data = append(result.Data, schemas.Model{
				ID:   id,
				Name: schemas.Ptr(displayName),
			})
			seen[id] = true
		}
	}

	return result
}
// extractModelIDFromName pulls the model ID out of a full resource name.
// For the "publishers/<publisher>/models/<id>" layout, i.e. when the third
// path segment is "models", the fourth segment is returned:
// "publishers/google/models/gemini-1.5-pro" -> "gemini-1.5-pro".
// Any other input yields its last "/"-separated segment.
func extractModelIDFromName(name string) string {
	segments := strings.Split(name, "/")
	switch {
	case len(segments) >= 4 && segments[2] == "models":
		return segments[3]
	case len(segments) == 0:
		return ""
	default:
		// Fallback: last segment
		return segments[len(segments)-1]
	}
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,238 @@
package vertex_test
import (
"encoding/json"
"strings"
"testing"
"github.com/maximhq/bifrost/core/providers/anthropic"
"github.com/maximhq/bifrost/core/schemas"
)
// TestVertex_AnthropicModel_CachingDeterminism checks that the conversion Vertex
// delegates to, anthropic.ToAnthropicChatRequest(), yields deterministic JSON
// suitable for prompt caching: two tool schemas with the same properties, but
// with "type"/"description" in a different order inside each property
// definition, must marshal to byte-identical JSON. The order of the properties
// themselves must be kept as the user declared it.
func TestVertex_AnthropicModel_CachingDeterminism(t *testing.T) {
	// newRequest wraps a property map in a single-tool chat request.
	newRequest := func(properties *schemas.OrderedMap) *schemas.BifrostChatRequest {
		tool := schemas.ChatTool{
			Type: schemas.ChatToolTypeFunction,
			Function: &schemas.ChatToolFunction{
				Name: "test",
				Parameters: &schemas.ToolFunctionParameters{
					Type:       "object",
					Properties: properties,
				},
			},
		}
		userMessage := schemas.ChatMessage{
			Role:    schemas.ChatMessageRoleUser,
			Content: &schemas.ChatMessageContent{ContentStr: schemas.Ptr("test")},
		}
		return &schemas.BifrostChatRequest{
			Provider: schemas.Vertex,
			Model:    "claude-sonnet-4-20250514",
			Input:    []schemas.ChatMessage{userMessage},
			Params:   &schemas.ChatParameters{Tools: []schemas.ChatTool{tool}},
		}
	}

	// stringProp builds a string-typed property definition; typeFirst decides
	// whether "type" or "description" is the first key.
	stringProp := func(description string, typeFirst bool) *schemas.OrderedMap {
		if typeFirst {
			return schemas.NewOrderedMapFromPairs(
				schemas.KV("type", "string"),
				schemas.KV("description", description),
			)
		}
		return schemas.NewOrderedMapFromPairs(
			schemas.KV("description", description),
			schemas.KV("type", "string"),
		)
	}
	schemaProps := func(typeFirst bool) *schemas.OrderedMap {
		return schemas.NewOrderedMapFromPairs(
			schemas.KV("chain_of_thought", stringProp("Reasoning", typeFirst)),
			schemas.KV("answer", stringProp("The answer", typeFirst)),
		)
	}

	typeFirstProps := schemaProps(true)         // Version A: type before description
	descriptionFirstProps := schemaProps(false) // Version B: description before type

	// Vertex delegates Anthropic models to anthropic.ToAnthropicChatRequest()
	ctx, cancel := schemas.NewBifrostContextWithCancel(nil)
	defer cancel()

	convertedA, err := anthropic.ToAnthropicChatRequest(ctx, newRequest(typeFirstProps))
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	convertedB, err := anthropic.ToAnthropicChatRequest(ctx, newRequest(descriptionFirstProps))
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	encodedA, err := schemas.Marshal(convertedA.Tools[0].InputSchema)
	if err != nil {
		t.Fatalf("failed to marshal params A: %v", err)
	}
	encodedB, err := schemas.Marshal(convertedB.Tools[0].InputSchema)
	if err != nil {
		t.Fatalf("failed to marshal params B: %v", err)
	}

	// Caching: both encodings must match byte for byte
	if string(encodedA) != string(encodedB) {
		t.Errorf("caching broken via Vertex→Anthropic path: same schema produced different JSON\nA: %s\nB: %s", encodedA, encodedB)
	}

	// CoT: user-declared property order must survive
	propertyOrder := convertedA.Tools[0].InputSchema.Properties.Keys()
	orderKept := len(propertyOrder) == 2 && propertyOrder[0] == "chain_of_thought" && propertyOrder[1] == "answer"
	if !orderKept {
		t.Errorf("expected property order [chain_of_thought, answer], got %v", propertyOrder)
	}
}
// TestVertex_AnthropicModel_PreservesPropertyOrder checks that the
// Vertex→Anthropic delegation path keeps tool schema properties in the order
// the user declared them.
func TestVertex_AnthropicModel_PreservesPropertyOrder(t *testing.T) {
	// ofType builds a minimal property definition holding only a "type" key.
	ofType := func(kind string) *schemas.OrderedMap {
		return schemas.NewOrderedMapFromPairs(schemas.KV("type", kind))
	}

	toolParams := &schemas.ToolFunctionParameters{
		Type: "object",
		Properties: schemas.NewOrderedMapFromPairs(
			schemas.KV("chain_of_thought", ofType("string")),
			schemas.KV("answer", ofType("string")),
			schemas.KV("citations", ofType("array")),
			schemas.KV("is_unanswered", ofType("boolean")),
		),
		Required: []string{"answer", "is_unanswered"},
	}
	request := &schemas.BifrostChatRequest{
		Provider: schemas.Vertex,
		Model:    "claude-sonnet-4-20250514",
		Input: []schemas.ChatMessage{{
			Role:    schemas.ChatMessageRoleUser,
			Content: &schemas.ChatMessageContent{ContentStr: schemas.Ptr("test")},
		}},
		Params: &schemas.ChatParameters{
			Tools: []schemas.ChatTool{{
				Type: schemas.ChatToolTypeFunction,
				Function: &schemas.ChatToolFunction{
					Name:        "AnswerResponseModel",
					Description: schemas.Ptr("Extract answer"),
					Parameters:  toolParams,
				},
			}},
		},
	}

	ctx, cancel := schemas.NewBifrostContextWithCancel(nil)
	defer cancel()

	converted, err := anthropic.ToAnthropicChatRequest(ctx, request)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	want := []string{"chain_of_thought", "answer", "citations", "is_unanswered"}
	got := converted.Tools[0].InputSchema.Properties.Keys()
	if len(got) != len(want) {
		t.Fatalf("expected %d properties, got %d: %v", len(want), len(got), got)
	}
	for idx := range want {
		if got[idx] != want[idx] {
			t.Errorf("property %d: expected %q, got %q (full order: %v)", idx, want[idx], got[idx], got)
		}
	}
}
// TestVertex_ToolInputKeyOrderPreservation verifies that Vertex→Anthropic delegation
// preserves the original key ordering of tool call arguments for prompt caching.
// It sends multiple parallel tool calls with a different key ordering per block
// and checks that every expected key is present and in its original position.
func TestVertex_ToolInputKeyOrderPreservation(t *testing.T) {
	bifrostReq := &schemas.BifrostChatRequest{
		Provider: schemas.Vertex,
		Model:    "claude-sonnet-4-20250514",
		Input: []schemas.ChatMessage{
			{
				Role:    schemas.ChatMessageRoleUser,
				Content: &schemas.ChatMessageContent{ContentStr: schemas.Ptr("test")},
			},
			{
				Role: schemas.ChatMessageRoleAssistant,
				ChatAssistantMessage: &schemas.ChatAssistantMessage{
					ToolCalls: []schemas.ChatAssistantMessageToolCall{
						{
							Index: 0,
							Type:  schemas.Ptr("function"),
							ID:    schemas.Ptr("toolu_vrtx_001"),
							Function: schemas.ChatAssistantMessageToolCallFunction{
								Name:      schemas.Ptr("bash"),
								Arguments: `{"description":"Find references quickly","timeout":30000,"command":"grep -r auth_injector ."}`,
							},
						},
						{
							Index: 1,
							Type:  schemas.Ptr("function"),
							ID:    schemas.Ptr("toolu_vrtx_002"),
							Function: schemas.ChatAssistantMessageToolCallFunction{
								Name:      schemas.Ptr("bash"),
								Arguments: `{"command":"git diff main...HEAD --stat","description":"Show diff"}`,
							},
						},
						{
							Index: 2,
							Type:  schemas.Ptr("function"),
							ID:    schemas.Ptr("toolu_vrtx_003"),
							Function: schemas.ChatAssistantMessageToolCallFunction{
								Name:      schemas.Ptr("bash"),
								Arguments: `{"command":"git log main..HEAD","description":"Show commits"}`,
							},
						},
					},
				},
			},
		},
	}

	ctx, cancel := schemas.NewBifrostContextWithCancel(nil)
	defer cancel()

	result, err := anthropic.ToAnthropicChatRequest(ctx, bifrostReq)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	// Collect the serialized input of every tool_use content block
	var toolInputs []string
	for _, msg := range result.Messages {
		for _, block := range msg.Content.ContentBlocks {
			if block.Type != "tool_use" {
				continue
			}
			jsonBytes, err := json.Marshal(block.Input)
			if err != nil {
				t.Fatalf("failed to marshal tool_use input: %v", err)
			}
			toolInputs = append(toolInputs, string(jsonBytes))
		}
	}

	// Expected key order per block, matching the Arguments literals above
	// (deliberately NOT alphabetical for block 0).
	expectedOrders := [][]string{
		{"description", "timeout", "command"},
		{"command", "description"},
		{"command", "description"},
	}
	if len(toolInputs) != len(expectedOrders) {
		t.Fatalf("expected %d tool_use blocks, got %d", len(expectedOrders), len(toolInputs))
	}
	for i, keys := range expectedOrders {
		assertJSONKeyOrder(t, i, toolInputs[i], keys)
	}
}

// assertJSONKeyOrder reports a test error unless every key in keys occurs in
// serialized as a quoted JSON string and the first occurrences appear in the
// given order. A missing key is reported explicitly: strings.Index returns -1
// for it, which would otherwise satisfy a plain "a < b" comparison.
func assertJSONKeyOrder(t *testing.T, blockIndex int, serialized string, keys []string) {
	t.Helper()
	previous := -1
	for _, key := range keys {
		position := strings.Index(serialized, `"`+key+`"`)
		if position < 0 {
			t.Errorf("block %d: key %q missing in: %s", blockIndex, key, serialized)
			return
		}
		if position < previous {
			t.Errorf("block %d: key order not preserved, expected %s in: %s", blockIndex, strings.Join(keys, " < "), serialized)
			return
		}
		previous = position
	}
}

View File

@@ -0,0 +1,81 @@
package vertex_test
import (
"os"
"strings"
"testing"
"github.com/maximhq/bifrost/core/internal/llmtests"
"github.com/maximhq/bifrost/core/schemas"
)
// TestVertex runs the comprehensive llmtests suite against the Vertex provider.
// It is skipped unless VERTEX_API_KEY is set, or both VERTEX_PROJECT_ID and
// VERTEX_CREDENTIALS are set. Rerank scenarios run only when
// VERTEX_RERANK_MODEL names a model.
func TestVertex(t *testing.T) {
	t.Parallel()

	// env reads an environment variable with surrounding whitespace removed.
	env := func(key string) string { return strings.TrimSpace(os.Getenv(key)) }

	hasAPIKey := env("VERTEX_API_KEY") != ""
	hasProjectAndCredentials := env("VERTEX_PROJECT_ID") != "" && env("VERTEX_CREDENTIALS") != ""
	if !hasAPIKey && !hasProjectAndCredentials {
		t.Skip("Skipping Vertex tests because VERTEX_API_KEY is not set and VERTEX_PROJECT_ID or VERTEX_CREDENTIALS is not set")
	}

	client, ctx, cancel, err := llmtests.SetupTest()
	if err != nil {
		t.Fatalf("Error initializing test setup: %v", err)
	}
	defer cancel()
	defer client.Shutdown()

	rerankModel := env("VERTEX_RERANK_MODEL")

	scenarios := llmtests.TestScenarios{
		TextCompletion:               false, // Not supported
		SimpleChat:                   true,
		CompletionStream:             true,
		MultiTurnConversation:        true,
		ToolCalls:                    true,
		ToolCallsStreaming:           true,
		MultipleToolCalls:            true,
		MultipleToolCallsStreaming:   true,
		End2EndToolCalling:           true,
		AutomaticFunctionCall:        true,
		ImageURL:                     false,
		ImageBase64:                  true,
		ImageGeneration:              true,
		ImageGenerationStream:        false,
		ImageEdit:                    true,
		VideoGeneration:              false, // disabled for now because of long running operations
		VideoRetrieve:                false,
		VideoRemix:                   false,
		VideoDownload:                false,
		VideoList:                    false,
		VideoDelete:                  false,
		MultipleImages:               true,
		CompleteEnd2End:              true,
		FileBase64:                   true,
		Embedding:                    true,
		Rerank:                       rerankModel != "",
		Reasoning:                    true,
		PromptCaching:                true,
		ListModels:                   false,
		CountTokens:                  true,
		StructuredOutputs:            true, // Structured outputs with nullable enum support
		InterleavedThinking:          true,
		EagerInputStreaming:          true, // fine-grained-tool-streaming-2025-05-14 (GA on Vertex)
		ServerToolsViaOpenAIEndpoint: true, // web_search only on Vertex per Table 20 (web_fetch/code_execution skip)
	}

	testConfig := llmtests.ComprehensiveTestConfig{
		Provider:             schemas.Vertex,
		ChatModel:            "gemini-2.5-pro",
		PromptCachingModel:   "claude-sonnet-4-5",
		VisionModel:          "claude-sonnet-4-5",
		TextModel:            "", // Vertex doesn't support text completion in newer models
		EmbeddingModel:       "text-multilingual-embedding-002",
		RerankModel:          rerankModel,
		ReasoningModel:       "claude-4.5-haiku",
		ImageGenerationModel: "gemini-2.5-flash-image",
		ImageEditModel:       "imagen-3.0-capability-001",
		VideoGenerationModel: "veo-3.1-generate-preview",
		Scenarios:            scenarios,
	}

	// The suite runs inside a named subtest; t.Run does not return until that
	// subtest (and any parallel subtests it starts) has finished, so the deferred
	// Shutdown/cancel above only fire afterwards.
	t.Run("VertexTests", func(t *testing.T) {
		llmtests.RunAllComprehensiveTests(t, client, ctx, testConfig)
	})
}