first commit
This commit is contained in:
27
core/providers/vertex/count_tokens.go
Normal file
27
core/providers/vertex/count_tokens.go
Normal file
@@ -0,0 +1,27 @@
|
||||
package vertex
|
||||
|
||||
import (
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
func (resp *VertexCountTokensResponse) ToBifrostCountTokensResponse(model string) *schemas.BifrostCountTokensResponse {
|
||||
if resp == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
inputDetails := &schemas.ResponsesResponseInputTokens{}
|
||||
inputTokens := int(resp.TotalTokens) // Vertex response typically represents prompt tokens for countTokens
|
||||
total := int(resp.TotalTokens)
|
||||
|
||||
if resp.CachedContentTokenCount > 0 {
|
||||
inputDetails.CachedReadTokens = int(resp.CachedContentTokenCount)
|
||||
}
|
||||
|
||||
return &schemas.BifrostCountTokensResponse{
|
||||
Model: model,
|
||||
Object: "response.input_tokens",
|
||||
InputTokens: inputTokens,
|
||||
InputTokensDetails: inputDetails,
|
||||
TotalTokens: &total,
|
||||
}
|
||||
}
|
||||
115
core/providers/vertex/embedding.go
Normal file
115
core/providers/vertex/embedding.go
Normal file
@@ -0,0 +1,115 @@
|
||||
package vertex
|
||||
|
||||
import (
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// ToVertexEmbeddingRequest converts a Bifrost embedding request to Vertex AI format
|
||||
func ToVertexEmbeddingRequest(bifrostReq *schemas.BifrostEmbeddingRequest) *VertexEmbeddingRequest {
|
||||
if bifrostReq == nil || bifrostReq.Input == nil || (bifrostReq.Input.Text == nil && bifrostReq.Input.Texts == nil) {
|
||||
return nil
|
||||
}
|
||||
// Create the request
|
||||
vertexReq := &VertexEmbeddingRequest{}
|
||||
if bifrostReq.Params != nil {
|
||||
vertexReq.ExtraParams = bifrostReq.Params.ExtraParams
|
||||
}
|
||||
var texts []string
|
||||
if bifrostReq.Input.Text != nil {
|
||||
texts = []string{*bifrostReq.Input.Text}
|
||||
} else {
|
||||
texts = bifrostReq.Input.Texts
|
||||
}
|
||||
|
||||
// Create instances for each text
|
||||
instances := make([]VertexEmbeddingInstance, 0, len(texts))
|
||||
for _, text := range texts {
|
||||
instance := VertexEmbeddingInstance{
|
||||
Content: text,
|
||||
}
|
||||
|
||||
// Add optional task_type and title from params
|
||||
if bifrostReq.Params != nil {
|
||||
if taskTypeStr, ok := schemas.SafeExtractStringPointer(bifrostReq.Params.ExtraParams["task_type"]); ok {
|
||||
delete(vertexReq.ExtraParams, "task_type")
|
||||
instance.TaskType = taskTypeStr
|
||||
}
|
||||
if title, ok := schemas.SafeExtractStringPointer(bifrostReq.Params.ExtraParams["title"]); ok {
|
||||
delete(vertexReq.ExtraParams, "title")
|
||||
instance.Title = title
|
||||
}
|
||||
}
|
||||
|
||||
instances = append(instances, instance)
|
||||
}
|
||||
vertexReq.Instances = instances
|
||||
// Add parameters if present
|
||||
if bifrostReq.Params != nil {
|
||||
parameters := &VertexEmbeddingParameters{}
|
||||
|
||||
// Set autoTruncate (defaults to true)
|
||||
autoTruncate := true
|
||||
if bifrostReq.Params.ExtraParams != nil {
|
||||
if autoTruncateVal, ok := schemas.SafeExtractBool(bifrostReq.Params.ExtraParams["autoTruncate"]); ok {
|
||||
delete(vertexReq.ExtraParams, "autoTruncate")
|
||||
autoTruncate = autoTruncateVal
|
||||
}
|
||||
}
|
||||
parameters.AutoTruncate = &autoTruncate
|
||||
|
||||
// Add outputDimensionality if specified
|
||||
if bifrostReq.Params.Dimensions != nil {
|
||||
delete(vertexReq.ExtraParams, "dimensions")
|
||||
parameters.OutputDimensionality = bifrostReq.Params.Dimensions
|
||||
}
|
||||
|
||||
vertexReq.Parameters = parameters
|
||||
}
|
||||
|
||||
return vertexReq
|
||||
}
|
||||
|
||||
// ToBifrostEmbeddingResponse converts a Vertex AI embedding response to Bifrost format
|
||||
func (response *VertexEmbeddingResponse) ToBifrostEmbeddingResponse() *schemas.BifrostEmbeddingResponse {
|
||||
if response == nil || len(response.Predictions) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Convert predictions to Bifrost embeddings
|
||||
embeddings := make([]schemas.EmbeddingData, 0, len(response.Predictions))
|
||||
var usage *schemas.BifrostLLMUsage
|
||||
|
||||
for i, prediction := range response.Predictions {
|
||||
if prediction.Embeddings == nil || len(prediction.Embeddings.Values) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// Create embedding object
|
||||
embedding := schemas.EmbeddingData{
|
||||
Object: "embedding",
|
||||
Embedding: schemas.EmbeddingStruct{
|
||||
EmbeddingArray: append([]float64(nil), prediction.Embeddings.Values...),
|
||||
},
|
||||
Index: i,
|
||||
}
|
||||
|
||||
// Extract statistics if available
|
||||
if prediction.Embeddings.Statistics != nil {
|
||||
if usage == nil {
|
||||
usage = &schemas.BifrostLLMUsage{}
|
||||
}
|
||||
usage.TotalTokens += prediction.Embeddings.Statistics.TokenCount
|
||||
usage.PromptTokens += prediction.Embeddings.Statistics.TokenCount
|
||||
}
|
||||
|
||||
embeddings = append(embeddings, embedding)
|
||||
}
|
||||
|
||||
return &schemas.BifrostEmbeddingResponse{
|
||||
Object: "list",
|
||||
Data: embeddings,
|
||||
Usage: usage,
|
||||
ExtraFields: schemas.BifrostResponseExtraFields{
|
||||
},
|
||||
}
|
||||
}
|
||||
88
core/providers/vertex/errors.go
Normal file
88
core/providers/vertex/errors.go
Normal file
@@ -0,0 +1,88 @@
|
||||
package vertex
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strings"
|
||||
|
||||
"github.com/bytedance/sonic"
|
||||
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
func parseVertexError(resp *fasthttp.Response) *schemas.BifrostError {
|
||||
var openAIErr schemas.BifrostError
|
||||
var vertexErr []VertexError
|
||||
|
||||
decodedBody, err := providerUtils.CheckAndDecodeBody(resp)
|
||||
if err != nil {
|
||||
bifrostErr := providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, err)
|
||||
return bifrostErr
|
||||
}
|
||||
|
||||
// Check for empty response
|
||||
trimmed := strings.TrimSpace(string(decodedBody))
|
||||
if len(trimmed) == 0 {
|
||||
bifrostErr := &schemas.BifrostError{
|
||||
IsBifrostError: false,
|
||||
StatusCode: schemas.Ptr(resp.StatusCode()),
|
||||
Error: &schemas.ErrorField{
|
||||
Message: schemas.ErrProviderResponseEmpty,
|
||||
},
|
||||
}
|
||||
return bifrostErr
|
||||
}
|
||||
|
||||
// Check for HTML error response before attempting JSON parsing
|
||||
if providerUtils.IsHTMLResponse(resp, decodedBody) {
|
||||
bifrostErr := &schemas.BifrostError{
|
||||
IsBifrostError: false,
|
||||
StatusCode: schemas.Ptr(resp.StatusCode()),
|
||||
Error: &schemas.ErrorField{
|
||||
Message: schemas.ErrProviderResponseHTML,
|
||||
Error: errors.New(string(decodedBody)),
|
||||
},
|
||||
ExtraFields: schemas.BifrostErrorExtraFields{
|
||||
RawResponse: string(decodedBody),
|
||||
},
|
||||
}
|
||||
return bifrostErr
|
||||
}
|
||||
|
||||
createError := func(message string) *schemas.BifrostError {
|
||||
bifrostErr := providerUtils.NewProviderAPIError(message, nil, resp.StatusCode(), nil, nil)
|
||||
var rawResponse interface{}
|
||||
if err := sonic.Unmarshal(decodedBody, &rawResponse); err != nil {
|
||||
rawResponse = string(decodedBody)
|
||||
}
|
||||
bifrostErr.ExtraFields.RawResponse = rawResponse
|
||||
return bifrostErr
|
||||
}
|
||||
|
||||
if err := sonic.Unmarshal(decodedBody, &openAIErr); err != nil || openAIErr.Error == nil {
|
||||
// Try Vertex error format if OpenAI format fails or is incomplete
|
||||
if err := sonic.Unmarshal(decodedBody, &vertexErr); err != nil {
|
||||
//try with single Vertex error format
|
||||
var vertexErr VertexError
|
||||
if err := sonic.Unmarshal(decodedBody, &vertexErr); err != nil {
|
||||
// Try VertexValidationError format (validation errors from Mistral endpoint)
|
||||
var validationErr VertexValidationError
|
||||
if err := sonic.Unmarshal(decodedBody, &validationErr); err != nil {
|
||||
bifrostErr := providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err)
|
||||
return bifrostErr
|
||||
}
|
||||
if len(validationErr.Detail) > 0 {
|
||||
return createError(validationErr.Detail[0].Msg)
|
||||
}
|
||||
return createError("Unknown error")
|
||||
}
|
||||
return createError(vertexErr.Error.Message)
|
||||
}
|
||||
if len(vertexErr) > 0 {
|
||||
return createError(vertexErr[0].Error.Message)
|
||||
}
|
||||
return createError("Unknown error")
|
||||
}
|
||||
// OpenAI error format succeeded with valid Error field
|
||||
return createError(openAIErr.Error.Message)
|
||||
}
|
||||
197
core/providers/vertex/models.go
Normal file
197
core/providers/vertex/models.go
Normal file
@@ -0,0 +1,197 @@
|
||||
package vertex
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// VertexRankRequest represents the Discovery Engine rank API request.
|
||||
type VertexRankRequest struct {
|
||||
Model *string `json:"model,omitempty"`
|
||||
Query string `json:"query"`
|
||||
Records []VertexRankRecord `json:"records"`
|
||||
TopN *int `json:"topN,omitempty"`
|
||||
IgnoreRecordDetailsInResponse *bool `json:"ignoreRecordDetailsInResponse,omitempty"`
|
||||
UserLabels map[string]string `json:"userLabels,omitempty"`
|
||||
}
|
||||
|
||||
// GetExtraParams implements providerUtils.RequestBodyWithExtraParams.
|
||||
func (*VertexRankRequest) GetExtraParams() map[string]interface{} {
|
||||
return nil
|
||||
}
|
||||
|
||||
const (
|
||||
vertexDefaultRankingConfigID = "default_ranking_config"
|
||||
vertexDefaultRerankModel = "semantic-ranker-default@latest"
|
||||
vertexMaxRerankRecordsPerQuery = 200
|
||||
vertexSyntheticRecordPrefix = "idx:"
|
||||
)
|
||||
|
||||
// VertexRankRecord represents a record for ranking.
|
||||
type VertexRankRecord struct {
|
||||
ID string `json:"id"`
|
||||
Title *string `json:"title,omitempty"`
|
||||
Content *string `json:"content,omitempty"`
|
||||
}
|
||||
|
||||
// VertexRankResponse represents the Discovery Engine rank API response.
|
||||
type VertexRankResponse struct {
|
||||
Records []VertexRankedRecord `json:"records"`
|
||||
}
|
||||
|
||||
// VertexRankedRecord represents a ranked record in response.
|
||||
type VertexRankedRecord struct {
|
||||
ID string `json:"id"`
|
||||
Score float64 `json:"score"`
|
||||
Title *string `json:"title,omitempty"`
|
||||
Content *string `json:"content,omitempty"`
|
||||
}
|
||||
|
||||
type vertexRerankOptions struct {
|
||||
RankingConfig string
|
||||
IgnoreRecordDetailsInResponse bool
|
||||
UserLabels map[string]string
|
||||
}
|
||||
|
||||
// ToBifrostListModelsResponse converts a Vertex AI list models response to Bifrost's format.
|
||||
// It processes both custom models (from the API response) and non-custom models (from deployments and allowedModels).
|
||||
//
|
||||
// Custom models are those with digit-only deployment values, extracted from the API response.
|
||||
// Non-custom models are those with non-digit characters in their deployment values or model names.
|
||||
//
|
||||
// The function performs three passes:
|
||||
// 1. First pass: Process all models from the Vertex AI API response (custom models)
|
||||
// 2. Second pass: Add non-custom models from deployments that aren't already in the list
|
||||
// 3. Third pass: Add non-custom models from allowedModels that aren't in deployments or already added
|
||||
//
|
||||
// Filtering logic:
|
||||
// - If allowedModels is empty, all models are allowed
|
||||
// - If allowedModels is non-empty, only models/deployments with keys in allowedModels are included
|
||||
// - Deployments map is used to match model IDs to aliases and filter accordingly
|
||||
func (response *VertexListModelsResponse) ToBifrostListModelsResponse(allowedModels schemas.WhiteList, blacklistedModels schemas.BlackList, aliases map[string]string, unfiltered bool) *schemas.BifrostListModelsResponse {
|
||||
if response == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
bifrostResponse := &schemas.BifrostListModelsResponse{
|
||||
Data: make([]schemas.Model, 0, len(response.Models)),
|
||||
}
|
||||
|
||||
pipeline := &providerUtils.ListModelsPipeline{
|
||||
AllowedModels: allowedModels,
|
||||
BlacklistedModels: blacklistedModels,
|
||||
Aliases: aliases,
|
||||
Unfiltered: unfiltered,
|
||||
ProviderKey: schemas.Vertex,
|
||||
MatchFns: providerUtils.DefaultMatchFns(),
|
||||
}
|
||||
if pipeline.ShouldEarlyExit() {
|
||||
return bifrostResponse
|
||||
}
|
||||
|
||||
included := make(map[string]bool)
|
||||
|
||||
// Process all models from the Vertex AI API response (custom deployed models).
|
||||
// The model ID is extracted from the endpoint URL last segment.
|
||||
for _, model := range response.Models {
|
||||
if len(model.DeployedModels) == 0 {
|
||||
continue
|
||||
}
|
||||
for _, deployedModel := range model.DeployedModels {
|
||||
endpoint := strings.TrimSuffix(deployedModel.Endpoint, "/")
|
||||
parts := strings.Split(endpoint, "/")
|
||||
if len(parts) == 0 {
|
||||
continue
|
||||
}
|
||||
customModelID := parts[len(parts)-1]
|
||||
if customModelID == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, result := range pipeline.FilterModel(customModelID) {
|
||||
resolvedKey := strings.ToLower(result.ResolvedID)
|
||||
if included[resolvedKey] {
|
||||
continue
|
||||
}
|
||||
modelEntry := schemas.Model{
|
||||
ID: string(schemas.Vertex) + "/" + result.ResolvedID,
|
||||
Name: schemas.Ptr(model.DisplayName),
|
||||
Description: schemas.Ptr(model.Description),
|
||||
Created: schemas.Ptr(model.VersionCreateTime.Unix()),
|
||||
}
|
||||
if result.AliasValue != "" {
|
||||
modelEntry.Alias = schemas.Ptr(result.AliasValue)
|
||||
}
|
||||
bifrostResponse.Data = append(bifrostResponse.Data, modelEntry)
|
||||
included[resolvedKey] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bifrostResponse.Data = append(bifrostResponse.Data,
|
||||
pipeline.BackfillModels(included)...)
|
||||
|
||||
bifrostResponse.NextPageToken = response.NextPageToken
|
||||
|
||||
return bifrostResponse
|
||||
}
|
||||
|
||||
// ToBifrostListModelsResponse converts a Vertex AI publisher models response to Bifrost's format.
|
||||
// This is for foundation models from the Model Garden (publishers.models.list endpoint).
|
||||
func (response *VertexListPublisherModelsResponse) ToBifrostListModelsResponse(allowedModels schemas.WhiteList, blacklistedModels schemas.BlackList, aliases map[string]string, unfiltered bool) *schemas.BifrostListModelsResponse {
|
||||
if response == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
bifrostResponse := &schemas.BifrostListModelsResponse{
|
||||
Data: make([]schemas.Model, 0, len(response.PublisherModels)),
|
||||
}
|
||||
|
||||
pipeline := &providerUtils.ListModelsPipeline{
|
||||
AllowedModels: allowedModels,
|
||||
BlacklistedModels: blacklistedModels,
|
||||
Aliases: aliases,
|
||||
Unfiltered: unfiltered,
|
||||
ProviderKey: schemas.Vertex,
|
||||
MatchFns: providerUtils.DefaultMatchFns(),
|
||||
}
|
||||
if pipeline.ShouldEarlyExit() {
|
||||
return bifrostResponse
|
||||
}
|
||||
|
||||
included := make(map[string]bool)
|
||||
|
||||
for _, model := range response.PublisherModels {
|
||||
// Extract model ID from name (format: "publishers/google/models/gemini-1.5-pro")
|
||||
modelID := extractModelIDFromName(model.Name)
|
||||
if modelID == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, result := range pipeline.FilterModel(modelID) {
|
||||
// Extract display name from supported actions if available
|
||||
displayName := result.ResolvedID
|
||||
if model.SupportedActions != nil && model.SupportedActions.Deploy != nil && model.SupportedActions.Deploy.ModelDisplayName != "" {
|
||||
displayName = model.SupportedActions.Deploy.ModelDisplayName
|
||||
}
|
||||
modelEntry := schemas.Model{
|
||||
ID: string(schemas.Vertex) + "/" + result.ResolvedID,
|
||||
Name: schemas.Ptr(displayName),
|
||||
}
|
||||
if result.AliasValue != "" {
|
||||
modelEntry.Alias = schemas.Ptr(result.AliasValue)
|
||||
}
|
||||
bifrostResponse.Data = append(bifrostResponse.Data, modelEntry)
|
||||
included[strings.ToLower(result.ResolvedID)] = true
|
||||
}
|
||||
}
|
||||
|
||||
bifrostResponse.Data = append(bifrostResponse.Data,
|
||||
pipeline.BackfillModels(included)...)
|
||||
|
||||
bifrostResponse.NextPageToken = response.NextPageToken
|
||||
|
||||
return bifrostResponse
|
||||
}
|
||||
290
core/providers/vertex/rerank.go
Normal file
290
core/providers/vertex/rerank.go
Normal file
@@ -0,0 +1,290 @@
|
||||
package vertex
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/bytedance/sonic"
|
||||
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
func buildVertexRankingConfig(projectID, rankingConfigOverride string) (string, error) {
|
||||
projectID = strings.TrimSpace(projectID)
|
||||
if projectID == "" {
|
||||
return "", fmt.Errorf("project ID is required for ranking config")
|
||||
}
|
||||
|
||||
override := strings.TrimSpace(rankingConfigOverride)
|
||||
if override == "" {
|
||||
return fmt.Sprintf("projects/%s/locations/global/rankingConfigs/%s", projectID, vertexDefaultRankingConfigID), nil
|
||||
}
|
||||
|
||||
override = strings.TrimSuffix(override, ":rank")
|
||||
if strings.HasPrefix(override, "projects/") {
|
||||
return override, nil
|
||||
}
|
||||
if strings.Contains(override, "/") {
|
||||
return "", fmt.Errorf("invalid ranking_config %q: must be resource name or config ID", rankingConfigOverride)
|
||||
}
|
||||
return fmt.Sprintf("projects/%s/locations/global/rankingConfigs/%s", projectID, override), nil
|
||||
}
|
||||
|
||||
func getVertexRerankOptions(projectID string, params *schemas.RerankParameters) (*vertexRerankOptions, error) {
|
||||
options := &vertexRerankOptions{
|
||||
IgnoreRecordDetailsInResponse: true,
|
||||
}
|
||||
|
||||
if params == nil || params.ExtraParams == nil {
|
||||
rankingConfig, err := buildVertexRankingConfig(projectID, "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
options.RankingConfig = rankingConfig
|
||||
return options, nil
|
||||
}
|
||||
|
||||
extraParams := params.ExtraParams
|
||||
|
||||
rankingConfigOverride := ""
|
||||
if rawRankingConfig, exists := extraParams["ranking_config"]; exists {
|
||||
rankingConfig, ok := schemas.SafeExtractString(rawRankingConfig)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid ranking_config: expected string")
|
||||
}
|
||||
rankingConfigOverride = rankingConfig
|
||||
}
|
||||
|
||||
rankingConfig, err := buildVertexRankingConfig(projectID, rankingConfigOverride)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
options.RankingConfig = rankingConfig
|
||||
|
||||
if rawIgnoreRecordDetails, exists := extraParams["ignore_record_details_in_response"]; exists {
|
||||
ignoreRecordDetailsInResponse, ok := schemas.SafeExtractBool(rawIgnoreRecordDetails)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid ignore_record_details_in_response: expected bool")
|
||||
}
|
||||
options.IgnoreRecordDetailsInResponse = ignoreRecordDetailsInResponse
|
||||
}
|
||||
|
||||
if rawUserLabels, exists := extraParams["user_labels"]; exists {
|
||||
userLabels, ok := schemas.SafeExtractStringMap(rawUserLabels)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid user_labels: expected map[string]string")
|
||||
}
|
||||
options.UserLabels = userLabels
|
||||
}
|
||||
|
||||
return options, nil
|
||||
}
|
||||
|
||||
// ToVertexRankRequest converts a Bifrost rerank request to Discovery Engine rank API format.
|
||||
func ToVertexRankRequest(bifrostReq *schemas.BifrostRerankRequest, options *vertexRerankOptions) (*VertexRankRequest, error) {
|
||||
if bifrostReq == nil {
|
||||
return nil, fmt.Errorf("bifrost rerank request is nil")
|
||||
}
|
||||
if options == nil {
|
||||
return nil, fmt.Errorf("vertex rerank options are nil")
|
||||
}
|
||||
if len(bifrostReq.Documents) == 0 {
|
||||
return nil, fmt.Errorf("documents are required for rerank request")
|
||||
}
|
||||
if len(bifrostReq.Documents) > vertexMaxRerankRecordsPerQuery {
|
||||
return nil, fmt.Errorf("vertex rerank supports up to %d records per request", vertexMaxRerankRecordsPerQuery)
|
||||
}
|
||||
|
||||
rankRequest := &VertexRankRequest{
|
||||
Query: bifrostReq.Query,
|
||||
Records: make([]VertexRankRecord, len(bifrostReq.Documents)),
|
||||
}
|
||||
|
||||
for i, doc := range bifrostReq.Documents {
|
||||
recordID := fmt.Sprintf("%s%d", vertexSyntheticRecordPrefix, i)
|
||||
content := doc.Text
|
||||
record := VertexRankRecord{
|
||||
ID: recordID,
|
||||
Content: &content,
|
||||
}
|
||||
|
||||
if doc.Meta != nil {
|
||||
if rawTitle, exists := doc.Meta["title"]; exists {
|
||||
if title, ok := schemas.SafeExtractString(rawTitle); ok && strings.TrimSpace(title) != "" {
|
||||
record.Title = &title
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
rankRequest.Records[i] = record
|
||||
}
|
||||
|
||||
if bifrostReq.Params != nil && bifrostReq.Params.TopN != nil {
|
||||
topN := *bifrostReq.Params.TopN
|
||||
if topN < 1 {
|
||||
return nil, fmt.Errorf("top_n must be at least 1")
|
||||
}
|
||||
if topN > len(bifrostReq.Documents) {
|
||||
topN = len(bifrostReq.Documents)
|
||||
}
|
||||
rankRequest.TopN = &topN
|
||||
}
|
||||
|
||||
trimmedModel := strings.TrimSpace(bifrostReq.Model)
|
||||
if trimmedModel == "" {
|
||||
trimmedModel = vertexDefaultRerankModel
|
||||
}
|
||||
rankRequest.Model = &trimmedModel
|
||||
|
||||
ignoreRecordDetailsInResponse := options.IgnoreRecordDetailsInResponse
|
||||
rankRequest.IgnoreRecordDetailsInResponse = &ignoreRecordDetailsInResponse
|
||||
|
||||
if len(options.UserLabels) > 0 {
|
||||
rankRequest.UserLabels = options.UserLabels
|
||||
}
|
||||
|
||||
return rankRequest, nil
|
||||
}
|
||||
|
||||
// ToBifrostRerankRequest converts a Discovery Engine rank request to Bifrost format.
|
||||
func (req *VertexRankRequest) ToBifrostRerankRequest(ctx *schemas.BifrostContext) *schemas.BifrostRerankRequest {
|
||||
if req == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var provider schemas.ModelProvider
|
||||
var model string
|
||||
if req.Model != nil {
|
||||
provider, model = schemas.ParseModelString(*req.Model, providerUtils.CheckAndSetDefaultProvider(ctx, schemas.Vertex))
|
||||
} else {
|
||||
provider = providerUtils.CheckAndSetDefaultProvider(ctx, schemas.Vertex)
|
||||
}
|
||||
|
||||
bifrostReq := &schemas.BifrostRerankRequest{
|
||||
Provider: provider,
|
||||
Model: model,
|
||||
Query: req.Query,
|
||||
Params: &schemas.RerankParameters{},
|
||||
}
|
||||
|
||||
// Convert records to documents
|
||||
for _, record := range req.Records {
|
||||
doc := schemas.RerankDocument{
|
||||
ID: &record.ID,
|
||||
}
|
||||
if record.Content != nil {
|
||||
doc.Text = *record.Content
|
||||
}
|
||||
if record.Title != nil {
|
||||
doc.Meta = map[string]interface{}{
|
||||
"title": *record.Title,
|
||||
}
|
||||
}
|
||||
bifrostReq.Documents = append(bifrostReq.Documents, doc)
|
||||
}
|
||||
|
||||
// Extract TopN
|
||||
if req.TopN != nil {
|
||||
bifrostReq.Params.TopN = req.TopN
|
||||
}
|
||||
|
||||
// Pass extra fields as ExtraParams
|
||||
extraParams := make(map[string]interface{})
|
||||
if req.IgnoreRecordDetailsInResponse != nil {
|
||||
extraParams["ignore_record_details_in_response"] = *req.IgnoreRecordDetailsInResponse
|
||||
}
|
||||
if len(req.UserLabels) > 0 {
|
||||
extraParams["user_labels"] = req.UserLabels
|
||||
}
|
||||
if len(extraParams) > 0 {
|
||||
bifrostReq.Params.ExtraParams = extraParams
|
||||
}
|
||||
|
||||
return bifrostReq
|
||||
}
|
||||
|
||||
func parseVertexSyntheticRecordIndex(recordID string, maxDocs int) (int, error) {
|
||||
if !strings.HasPrefix(recordID, vertexSyntheticRecordPrefix) {
|
||||
return 0, fmt.Errorf("invalid record id %q: expected prefix %q", recordID, vertexSyntheticRecordPrefix)
|
||||
}
|
||||
indexStr := strings.TrimPrefix(recordID, vertexSyntheticRecordPrefix)
|
||||
index, err := strconv.Atoi(indexStr)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("invalid record id %q: %w", recordID, err)
|
||||
}
|
||||
if index < 0 || index >= maxDocs {
|
||||
return 0, fmt.Errorf("record id %q maps to out-of-range index %d", recordID, index)
|
||||
}
|
||||
return index, nil
|
||||
}
|
||||
|
||||
// ToBifrostRerankResponse converts a Discovery Engine rank response to Bifrost format.
|
||||
func (response *VertexRankResponse) ToBifrostRerankResponse(documents []schemas.RerankDocument, returnDocuments bool) (*schemas.BifrostRerankResponse, error) {
|
||||
if response == nil {
|
||||
return nil, fmt.Errorf("vertex rerank response is nil")
|
||||
}
|
||||
|
||||
results := make([]schemas.RerankResult, 0, len(response.Records))
|
||||
seenIndices := make(map[int]struct{}, len(response.Records))
|
||||
|
||||
for _, record := range response.Records {
|
||||
index, err := parseVertexSyntheticRecordIndex(record.ID, len(documents))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if _, seen := seenIndices[index]; seen {
|
||||
return nil, fmt.Errorf("duplicate record id mapping for index %d", index)
|
||||
}
|
||||
seenIndices[index] = struct{}{}
|
||||
|
||||
result := schemas.RerankResult{
|
||||
Index: index,
|
||||
RelevanceScore: record.Score,
|
||||
}
|
||||
|
||||
if returnDocuments {
|
||||
doc := documents[index]
|
||||
result.Document = &doc
|
||||
}
|
||||
|
||||
results = append(results, result)
|
||||
}
|
||||
|
||||
sort.SliceStable(results, func(i, j int) bool {
|
||||
if results[i].RelevanceScore == results[j].RelevanceScore {
|
||||
return results[i].Index < results[j].Index
|
||||
}
|
||||
return results[i].RelevanceScore > results[j].RelevanceScore
|
||||
})
|
||||
|
||||
return &schemas.BifrostRerankResponse{
|
||||
Results: results,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func parseDiscoveryEngineErrorMessage(responseBody []byte) string {
|
||||
if len(responseBody) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
var errorResponse map[string]interface{}
|
||||
if err := sonic.Unmarshal(responseBody, &errorResponse); err == nil {
|
||||
if rawError, exists := errorResponse["error"]; exists {
|
||||
if errorMap, ok := rawError.(map[string]interface{}); ok {
|
||||
if message, ok := schemas.SafeExtractString(errorMap["message"]); ok && strings.TrimSpace(message) != "" {
|
||||
return message
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
rawString := strings.TrimSpace(string(responseBody))
|
||||
if rawString == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
return rawString
|
||||
}
|
||||
205
core/providers/vertex/rerank_test.go
Normal file
205
core/providers/vertex/rerank_test.go
Normal file
@@ -0,0 +1,205 @@
|
||||
package vertex
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestBuildVertexRankingConfig(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
config, err := buildVertexRankingConfig("demo-project", "")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "projects/demo-project/locations/global/rankingConfigs/default_ranking_config", config)
|
||||
|
||||
config, err = buildVertexRankingConfig("demo-project", "custom_rank")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "projects/demo-project/locations/global/rankingConfigs/custom_rank", config)
|
||||
|
||||
config, err = buildVertexRankingConfig("demo-project", "projects/other/locations/global/rankingConfigs/custom_rank:rank")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "projects/other/locations/global/rankingConfigs/custom_rank", config)
|
||||
|
||||
_, err = buildVertexRankingConfig("demo-project", "locations/global/rankingConfigs/custom_rank")
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestToVertexRankRequest(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
req, err := ToVertexRankRequest(
|
||||
&schemas.BifrostRerankRequest{
|
||||
Query: "capital of france",
|
||||
Documents: []schemas.RerankDocument{
|
||||
{Text: "Paris is the capital of France.", Meta: map[string]interface{}{"title": "Doc A"}},
|
||||
{Text: "Berlin is the capital of Germany."},
|
||||
},
|
||||
Params: &schemas.RerankParameters{
|
||||
TopN: schemas.Ptr(10),
|
||||
},
|
||||
},
|
||||
&vertexRerankOptions{
|
||||
RankingConfig: "projects/p/locations/global/rankingConfigs/default_ranking_config",
|
||||
IgnoreRecordDetailsInResponse: true,
|
||||
},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, req)
|
||||
|
||||
require.NotNil(t, req.Model)
|
||||
assert.Equal(t, "semantic-ranker-default@latest", *req.Model)
|
||||
require.Len(t, req.Records, 2)
|
||||
assert.Equal(t, "idx:0", req.Records[0].ID)
|
||||
assert.Equal(t, "idx:1", req.Records[1].ID)
|
||||
require.NotNil(t, req.Records[0].Title)
|
||||
assert.Equal(t, "Doc A", *req.Records[0].Title)
|
||||
require.NotNil(t, req.TopN)
|
||||
assert.Equal(t, 2, *req.TopN, "topN should be clamped to document count")
|
||||
require.NotNil(t, req.IgnoreRecordDetailsInResponse)
|
||||
assert.True(t, *req.IgnoreRecordDetailsInResponse)
|
||||
}
|
||||
|
||||
func TestToVertexRankRequestTooManyRecords(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
docs := make([]schemas.RerankDocument, 201)
|
||||
for i := range docs {
|
||||
docs[i] = schemas.RerankDocument{Text: "doc"}
|
||||
}
|
||||
|
||||
_, err := ToVertexRankRequest(
|
||||
&schemas.BifrostRerankRequest{
|
||||
Query: "q",
|
||||
Documents: docs,
|
||||
},
|
||||
&vertexRerankOptions{
|
||||
RankingConfig: "projects/p/locations/global/rankingConfigs/default_ranking_config",
|
||||
IgnoreRecordDetailsInResponse: true,
|
||||
},
|
||||
)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "supports up to")
|
||||
}
|
||||
|
||||
func TestGetVertexRerankOptions(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
options, err := getVertexRerankOptions("project-x", &schemas.RerankParameters{
|
||||
ExtraParams: map[string]interface{}{
|
||||
"ranking_config": "custom_rank",
|
||||
"ignore_record_details_in_response": false,
|
||||
"user_labels": map[string]interface{}{
|
||||
"env": "test",
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "projects/project-x/locations/global/rankingConfigs/custom_rank", options.RankingConfig)
|
||||
assert.False(t, options.IgnoreRecordDetailsInResponse)
|
||||
assert.Equal(t, map[string]string{"env": "test"}, options.UserLabels)
|
||||
}
|
||||
|
||||
func TestVertexRankResponseToBifrostRerankResponse(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
docs := []schemas.RerankDocument{
|
||||
{Text: "doc-0"},
|
||||
{Text: "doc-1"},
|
||||
{Text: "doc-2"},
|
||||
}
|
||||
|
||||
response, err := (&VertexRankResponse{
|
||||
Records: []VertexRankedRecord{
|
||||
{ID: "idx:2", Score: 0.12},
|
||||
{ID: "idx:1", Score: 0.91},
|
||||
{ID: "idx:0", Score: 0.91},
|
||||
},
|
||||
}).ToBifrostRerankResponse(docs, true)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, response)
|
||||
require.Len(t, response.Results, 3)
|
||||
|
||||
assert.Equal(t, 0, response.Results[0].Index)
|
||||
assert.Equal(t, 1, response.Results[1].Index)
|
||||
assert.Equal(t, 2, response.Results[2].Index)
|
||||
require.NotNil(t, response.Results[0].Document)
|
||||
assert.Equal(t, "doc-0", response.Results[0].Document.Text)
|
||||
}
|
||||
|
||||
func TestVertexRankRequestToBifrostRerankRequest(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
topN := 5
|
||||
model := "semantic-ranker-default@latest"
|
||||
ignoreDetails := true
|
||||
title := "Doc A"
|
||||
content1 := "Paris is the capital of France."
|
||||
content2 := "Berlin is the capital of Germany."
|
||||
|
||||
req := &VertexRankRequest{
|
||||
Model: &model,
|
||||
Query: "capital of france",
|
||||
Records: []VertexRankRecord{
|
||||
{ID: "rec-1", Content: &content1, Title: &title},
|
||||
{ID: "rec-2", Content: &content2},
|
||||
},
|
||||
TopN: &topN,
|
||||
IgnoreRecordDetailsInResponse: &ignoreDetails,
|
||||
UserLabels: map[string]string{"env": "test"},
|
||||
}
|
||||
|
||||
bifrostCtx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
|
||||
result := req.ToBifrostRerankRequest(bifrostCtx)
|
||||
|
||||
require.NotNil(t, result)
|
||||
assert.Equal(t, schemas.Vertex, result.Provider)
|
||||
assert.Equal(t, "semantic-ranker-default@latest", result.Model)
|
||||
assert.Equal(t, "capital of france", result.Query)
|
||||
require.Len(t, result.Documents, 2)
|
||||
|
||||
// First document has ID, content, and title in meta
|
||||
require.NotNil(t, result.Documents[0].ID)
|
||||
assert.Equal(t, "rec-1", *result.Documents[0].ID)
|
||||
assert.Equal(t, "Paris is the capital of France.", result.Documents[0].Text)
|
||||
require.NotNil(t, result.Documents[0].Meta)
|
||||
assert.Equal(t, "Doc A", result.Documents[0].Meta["title"])
|
||||
|
||||
// Second document has no title
|
||||
require.NotNil(t, result.Documents[1].ID)
|
||||
assert.Equal(t, "rec-2", *result.Documents[1].ID)
|
||||
assert.Equal(t, "Berlin is the capital of Germany.", result.Documents[1].Text)
|
||||
assert.Nil(t, result.Documents[1].Meta)
|
||||
|
||||
// TopN
|
||||
require.NotNil(t, result.Params)
|
||||
require.NotNil(t, result.Params.TopN)
|
||||
assert.Equal(t, 5, *result.Params.TopN)
|
||||
|
||||
// ExtraParams
|
||||
require.NotNil(t, result.Params.ExtraParams)
|
||||
assert.Equal(t, true, result.Params.ExtraParams["ignore_record_details_in_response"])
|
||||
assert.Equal(t, map[string]string{"env": "test"}, result.Params.ExtraParams["user_labels"])
|
||||
}
|
||||
|
||||
func TestVertexRankRequestToBifrostRerankRequestNil(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var req *VertexRankRequest
|
||||
assert.Nil(t, req.ToBifrostRerankRequest(nil))
|
||||
}
|
||||
|
||||
func TestVertexRankResponseToBifrostRerankResponseInvalidID(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, err := (&VertexRankResponse{
|
||||
Records: []VertexRankedRecord{
|
||||
{ID: "bad-id", Score: 0.9},
|
||||
},
|
||||
}).ToBifrostRerankResponse([]schemas.RerankDocument{{Text: "doc"}}, false)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "invalid record id")
|
||||
}
|
||||
245
core/providers/vertex/types.go
Normal file
245
core/providers/vertex/types.go
Normal file
@@ -0,0 +1,245 @@
|
||||
package vertex
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
|
||||
)
|
||||
|
||||
// Vertex AI Embedding API types
|
||||
|
||||
const (
|
||||
DefaultVertexAnthropicVersion = "vertex-2023-10-16"
|
||||
)
|
||||
|
||||
// PhoneticEncoding represents the phonetic encoding of a phrase.
|
||||
type PhoneticEncoding string
|
||||
|
||||
const (
|
||||
PhoneticEncodingUnspecified PhoneticEncoding = "PHONETIC_ENCODING_UNSPECIFIED"
|
||||
PhoneticEncodingIPA PhoneticEncoding = "PHONETIC_ENCODING_IPA"
|
||||
PhoneticEncodingXSAMPA PhoneticEncoding = "PHONETIC_ENCODING_X_SAMPA"
|
||||
PhoneticEncodingJapaneseYomigana PhoneticEncoding = "PHONETIC_ENCODING_JAPANESE_YOMIGANA"
|
||||
PhoneticEncodingPinyin PhoneticEncoding = "PHONETIC_ENCODING_PINYIN"
|
||||
)
|
||||
|
||||
// CustomPronunciationParams represents pronunciation customization for a phrase.
|
||||
type CustomPronunciationParams struct {
|
||||
Phrase string `json:"phrase,omitempty"`
|
||||
PhoneticEncoding PhoneticEncoding `json:"phoneticEncoding,omitempty"`
|
||||
Pronunciation string `json:"pronunciation,omitempty"`
|
||||
}
|
||||
|
||||
// CustomPronunciations represents a collection of pronunciation customizations.
|
||||
type CustomPronunciations struct {
|
||||
Pronunciations []CustomPronunciationParams `json:"pronunciations,omitempty"`
|
||||
}
|
||||
|
||||
// Turn represents a multi-speaker turn.
|
||||
type Turn struct {
|
||||
Speaker string `json:"speaker,omitempty"`
|
||||
Text string `json:"text,omitempty"`
|
||||
}
|
||||
|
||||
// MultiSpeakerMarkup represents a collection of turns for multi-speaker synthesis.
|
||||
type MultiSpeakerMarkup struct {
|
||||
Turns []Turn `json:"turns,omitempty"`
|
||||
}
|
||||
|
||||
// VertexSynthesisInput contains text input to be synthesized.
|
||||
type VertexSynthesisInput struct {
|
||||
Text *string `json:"text,omitempty"`
|
||||
Markup *string `json:"markup,omitempty"`
|
||||
SSML *string `json:"ssml,omitempty"`
|
||||
MultiSpeakerMarkup *MultiSpeakerMarkup `json:"multiSpeakerMarkup,omitempty"`
|
||||
Prompt *string `json:"prompt,omitempty"`
|
||||
CustomPronunciations *CustomPronunciations `json:"customPronunciations,omitempty"`
|
||||
}
|
||||
|
||||
// VertexVoiceSelectionParams represents voice selection parameters for TTS synthesis.
|
||||
type VertexVoiceSelectionParams struct {
|
||||
LanguageCode string `json:"languageCode,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
SsmlGender string `json:"ssmlGender,omitempty"`
|
||||
}
|
||||
|
||||
// VertexAudioConfig represents audio configuration for TTS synthesis.
|
||||
type VertexAudioConfig struct {
|
||||
AudioEncoding string `json:"audioEncoding,omitempty"`
|
||||
SpeakingRate float64 `json:"speakingRate,omitempty"`
|
||||
Pitch float64 `json:"pitch,omitempty"`
|
||||
VolumeGainDB float64 `json:"volumeGainDb,omitempty"`
|
||||
SampleRateHertz int `json:"sampleRateHertz,omitempty"`
|
||||
EffectsProfileID []string `json:"effectsProfileId,omitempty"`
|
||||
}
|
||||
|
||||
type VertexRequestBody struct {
|
||||
RequestBody map[string]interface{} `json:"-"`
|
||||
ExtraParams map[string]interface{} `json:"-"`
|
||||
}
|
||||
|
||||
func (r *VertexRequestBody) GetExtraParams() map[string]interface{} {
|
||||
return r.ExtraParams
|
||||
}
|
||||
|
||||
// MarshalJSON implements custom JSON marshalling for VertexRequestBody.
|
||||
// It marshals the RequestBody field directly without wrapping.
|
||||
func (r *VertexRequestBody) MarshalJSON() ([]byte, error) {
|
||||
return providerUtils.MarshalSorted(r.RequestBody)
|
||||
}
|
||||
|
||||
// VertexRawRequestBody holds pre-serialized JSON bytes to preserve key ordering
|
||||
// for LLM prompt caching. This avoids the map[string]interface{} round-trip that
|
||||
// destroys key order.
|
||||
type VertexRawRequestBody struct {
|
||||
RawBody []byte `json:"-"`
|
||||
ExtraParams map[string]interface{} `json:"-"`
|
||||
}
|
||||
|
||||
func (r *VertexRawRequestBody) GetExtraParams() map[string]interface{} {
|
||||
return r.ExtraParams
|
||||
}
|
||||
|
||||
// MarshalJSON returns the pre-serialized JSON bytes directly, preserving key order.
|
||||
func (r *VertexRawRequestBody) MarshalJSON() ([]byte, error) {
|
||||
return r.RawBody, nil
|
||||
}
|
||||
|
||||
// VertexAdvancedVoiceOptions represents advanced voice options for TTS synthesis.
|
||||
type VertexAdvancedVoiceOptions struct {
|
||||
LowLatencyJourneySynthesis bool `json:"lowLatencyJourneySynthesis,omitempty"`
|
||||
}
|
||||
|
||||
// VertexEmbeddingInstance represents a single embedding instance in the request
|
||||
type VertexEmbeddingInstance struct {
|
||||
Content string `json:"content"` // The text to generate embeddings for
|
||||
TaskType *string `json:"task_type,omitempty"` // Intended downstream application (optional)
|
||||
Title *string `json:"title,omitempty"` // Used to help the model produce better embeddings (optional)
|
||||
}
|
||||
|
||||
// VertexEmbeddingParameters represents the parameters for the embedding request
|
||||
type VertexEmbeddingParameters struct {
|
||||
AutoTruncate *bool `json:"autoTruncate,omitempty"` // When true, input text will be truncated (defaults to true)
|
||||
OutputDimensionality *int `json:"outputDimensionality,omitempty"` // Output embedding size (optional)
|
||||
}
|
||||
|
||||
// VertexEmbeddingRequest represents the complete embedding request to Vertex AI
|
||||
type VertexEmbeddingRequest struct {
|
||||
Instances []VertexEmbeddingInstance `json:"instances"` // List of embedding instances
|
||||
Parameters *VertexEmbeddingParameters `json:"parameters,omitempty"` // Optional parameters
|
||||
ExtraParams map[string]interface{} `json:"-"` // Optional: Extra parameters
|
||||
}
|
||||
|
||||
func (r *VertexEmbeddingRequest) GetExtraParams() map[string]interface{} {
|
||||
return r.ExtraParams
|
||||
}
|
||||
|
||||
// VertexEmbeddingStatistics represents statistics computed from the input text
|
||||
type VertexEmbeddingStatistics struct {
|
||||
Truncated bool `json:"truncated"` // Whether the input text was truncated
|
||||
TokenCount int `json:"token_count"` // Number of tokens in the input text
|
||||
}
|
||||
|
||||
// VertexEmbeddingValues represents the embedding result
|
||||
type VertexEmbeddingValues struct {
|
||||
Values []float64 `json:"values"` // The embedding vector (list of floats)
|
||||
Statistics *VertexEmbeddingStatistics `json:"statistics"` // Statistics about the input text
|
||||
}
|
||||
|
||||
// VertexEmbeddingPrediction represents a single prediction in the response
|
||||
type VertexEmbeddingPrediction struct {
|
||||
Embeddings *VertexEmbeddingValues `json:"embeddings"` // The embedding result
|
||||
}
|
||||
|
||||
// VertexEmbeddingResponse represents the complete embedding response from Vertex AI
|
||||
type VertexEmbeddingResponse struct {
|
||||
Predictions []VertexEmbeddingPrediction `json:"predictions"` // List of embedding predictions
|
||||
}
|
||||
|
||||
// ================================ Model Types ================================
|
||||
|
||||
const MaxPageSize = 100
|
||||
|
||||
type VertexModel struct {
|
||||
Name string `json:"name"`
|
||||
VersionId string `json:"versionId"`
|
||||
VersionAliases []string `json:"versionAliases"`
|
||||
VersionCreateTime time.Time `json:"versionCreateTime"`
|
||||
DisplayName string `json:"displayName"`
|
||||
Description string `json:"description"`
|
||||
DeployedModels []VertexDeployedModel `json:"deployedModels"`
|
||||
Labels VertexModelLabels `json:"labels"`
|
||||
}
|
||||
|
||||
type VertexListModelsResponse struct {
|
||||
Models []VertexModel `json:"models"`
|
||||
NextPageToken string `json:"nextPageToken"`
|
||||
}
|
||||
|
||||
type VertexDeployedModel struct {
|
||||
CheckpointID string `json:"checkpointId"`
|
||||
DeploymentID string `json:"deploymentId"`
|
||||
Endpoint string `json:"endpoint"`
|
||||
}
|
||||
|
||||
type VertexModelLabels struct {
|
||||
GoogleVertexLLMTuningBaseModelId string `json:"google-vertex-llm-tuning-base-model-id"`
|
||||
GoogleVertexLLMTuningJobId string `json:"google-vertex-llm-tuning-job-id"`
|
||||
TuneType string `json:"tune-type"`
|
||||
}
|
||||
|
||||
// ================================ Publisher Model Types ================================
|
||||
// These types are for the publishers.models.list endpoint (Model Garden)
|
||||
|
||||
type VertexPublisherModel struct {
|
||||
Name string `json:"name"`
|
||||
VersionID string `json:"versionId"`
|
||||
OpenSourceCategory string `json:"openSourceCategory"`
|
||||
LaunchStage string `json:"launchStage"`
|
||||
VersionState string `json:"versionState"`
|
||||
PublisherModelTemplate string `json:"publisherModelTemplate"`
|
||||
SupportedActions *VertexPublisherModelActions `json:"supportedActions"`
|
||||
}
|
||||
|
||||
type VertexPublisherModelActions struct {
|
||||
OpenGenerationAIStudio *VertexPublisherModelURI `json:"openGenerationAiStudio"`
|
||||
OpenGenie *VertexPublisherModelURI `json:"openGenie"`
|
||||
OpenPromptTuningPipeline *VertexPublisherModelURI `json:"openPromptTuningPipeline"`
|
||||
OpenNotebook *VertexPublisherModelURI `json:"openNotebook"`
|
||||
OpenFineTuningPipeline *VertexPublisherModelURI `json:"openFineTuningPipeline"`
|
||||
Deploy *VertexPublisherModelDeploy `json:"deploy"`
|
||||
OpenEvaluationPipeline *VertexPublisherModelURI `json:"openEvaluationPipeline"`
|
||||
}
|
||||
|
||||
type VertexPublisherModelURI struct {
|
||||
URI string `json:"uri"`
|
||||
}
|
||||
|
||||
type VertexPublisherModelDeploy struct {
|
||||
ModelDisplayName string `json:"modelDisplayName"`
|
||||
Title string `json:"title"`
|
||||
}
|
||||
|
||||
type VertexListPublisherModelsResponse struct {
|
||||
PublisherModels []VertexPublisherModel `json:"publisherModels"`
|
||||
NextPageToken string `json:"nextPageToken"`
|
||||
}
|
||||
|
||||
// ==================== ERROR TYPES ====================
|
||||
// VertexValidationError represents validation errors
|
||||
// returned by the Vertex Mistral endpoint
|
||||
type VertexValidationError struct {
|
||||
Detail []struct {
|
||||
Input any `json:"input"` // can be number, object, or array
|
||||
Loc []any `json:"loc"` // location of the error (can contain strings and numeric indices)
|
||||
Msg string `json:"msg"` // error message
|
||||
Type string `json:"type"` // error type (e.g., "extra_forbidden", "missing")
|
||||
} `json:"detail"`
|
||||
}
|
||||
|
||||
// VertexCountTokensResponse models the response payload for Vertex's Gemini-style countTokens.
|
||||
// Vertex uses camelCase unlike other request json body.
|
||||
type VertexCountTokensResponse struct {
|
||||
TotalTokens int32 `json:"totalTokens,omitempty"`
|
||||
CachedContentTokenCount int32 `json:"cachedContentTokenCount,omitempty"`
|
||||
}
|
||||
256
core/providers/vertex/utils.go
Normal file
256
core/providers/vertex/utils.go
Normal file
@@ -0,0 +1,256 @@
|
||||
package vertex
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/maximhq/bifrost/core/providers/anthropic"
|
||||
providerUtils "github.com/maximhq/bifrost/core/providers/utils"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
func getRequestBodyForAnthropicResponses(ctx *schemas.BifrostContext, request *schemas.BifrostResponsesRequest, deployment string, isStreaming bool, isCountTokens bool, betaHeaderOverrides map[string]bool, providerExtraHeaders map[string]string) ([]byte, *schemas.BifrostError) {
|
||||
// Large payload mode: body streams directly from the LP reader — skip all body building
|
||||
// (matches CheckContextAndGetRequestBody guard).
|
||||
if providerUtils.IsLargePayloadPassthroughEnabled(ctx) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var jsonBody []byte
|
||||
var err error
|
||||
|
||||
// Check if raw request body should be used
|
||||
if useRawBody, ok := ctx.Value(schemas.BifrostContextKeyUseRawRequestBody).(bool); ok && useRawBody {
|
||||
jsonBody = request.GetRawRequestBody()
|
||||
|
||||
if isCountTokens {
|
||||
jsonBody, err = providerUtils.DeleteJSONField(jsonBody, "max_tokens")
|
||||
if err != nil {
|
||||
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err)
|
||||
}
|
||||
jsonBody, err = providerUtils.DeleteJSONField(jsonBody, "temperature")
|
||||
if err != nil {
|
||||
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err)
|
||||
}
|
||||
jsonBody, err = providerUtils.SetJSONField(jsonBody, "model", deployment)
|
||||
if err != nil {
|
||||
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err)
|
||||
}
|
||||
} else {
|
||||
// Add max_tokens if not present
|
||||
if !providerUtils.JSONFieldExists(jsonBody, "max_tokens") {
|
||||
jsonBody, err = providerUtils.SetJSONField(jsonBody, "max_tokens", providerUtils.GetMaxOutputTokensOrDefault(deployment, anthropic.AnthropicDefaultMaxTokens))
|
||||
if err != nil {
|
||||
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err)
|
||||
}
|
||||
}
|
||||
jsonBody, err = providerUtils.DeleteJSONField(jsonBody, "model")
|
||||
if err != nil {
|
||||
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err)
|
||||
}
|
||||
// Add stream if streaming
|
||||
if isStreaming {
|
||||
jsonBody, err = providerUtils.SetJSONField(jsonBody, "stream", true)
|
||||
if err != nil {
|
||||
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
jsonBody, err = providerUtils.DeleteJSONField(jsonBody, "region")
|
||||
if err != nil {
|
||||
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err)
|
||||
}
|
||||
jsonBody, err = providerUtils.DeleteJSONField(jsonBody, "fallbacks")
|
||||
if err != nil {
|
||||
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err)
|
||||
}
|
||||
|
||||
// Remap unsupported tool versions for Vertex (e.g., web_search_20260209 → web_search_20250305)
|
||||
jsonBody, err = anthropic.RemapRawToolVersionsForProvider(jsonBody, schemas.Vertex)
|
||||
if err != nil {
|
||||
return nil, providerUtils.NewBifrostOperationError(err.Error(), nil)
|
||||
}
|
||||
|
||||
// Add anthropic_version if not present
|
||||
if !providerUtils.JSONFieldExists(jsonBody, "anthropic_version") {
|
||||
jsonBody, err = providerUtils.SetJSONField(jsonBody, "anthropic_version", DefaultVertexAnthropicVersion)
|
||||
if err != nil {
|
||||
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Validate tools are supported by Vertex
|
||||
if request.Params != nil && request.Params.Tools != nil {
|
||||
if toolErr := anthropic.ValidateToolsForProvider(request.Params.Tools, schemas.Vertex); toolErr != nil {
|
||||
return nil, providerUtils.NewBifrostOperationError(toolErr.Error(), nil)
|
||||
}
|
||||
}
|
||||
|
||||
// Convert request to Anthropic format
|
||||
reqBody, convErr := anthropic.ToAnthropicResponsesRequest(ctx, request)
|
||||
if convErr != nil {
|
||||
return nil, providerUtils.NewBifrostOperationError(schemas.ErrRequestBodyConversion, convErr)
|
||||
}
|
||||
if reqBody == nil {
|
||||
return nil, providerUtils.NewBifrostOperationError("request body is not provided", nil)
|
||||
}
|
||||
reqBody.Model = deployment
|
||||
|
||||
if isStreaming {
|
||||
reqBody.Stream = schemas.Ptr(true)
|
||||
}
|
||||
|
||||
reqBody.SetStripCacheControlScope(true)
|
||||
|
||||
// Add provider-aware beta headers
|
||||
anthropic.AddMissingBetaHeadersToContext(ctx, reqBody, schemas.Vertex)
|
||||
|
||||
// Marshal struct to JSON bytes
|
||||
jsonBody, err = providerUtils.MarshalSorted(reqBody)
|
||||
if err != nil {
|
||||
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err)
|
||||
}
|
||||
|
||||
// Add anthropic_version if not present (using sjson to preserve order)
|
||||
if !providerUtils.JSONFieldExists(jsonBody, "anthropic_version") {
|
||||
jsonBody, err = providerUtils.SetJSONField(jsonBody, "anthropic_version", DefaultVertexAnthropicVersion)
|
||||
if err != nil {
|
||||
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err)
|
||||
}
|
||||
}
|
||||
|
||||
if isCountTokens {
|
||||
jsonBody, err = providerUtils.DeleteJSONField(jsonBody, "max_tokens")
|
||||
if err != nil {
|
||||
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err)
|
||||
}
|
||||
jsonBody, err = providerUtils.DeleteJSONField(jsonBody, "temperature")
|
||||
if err != nil {
|
||||
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err)
|
||||
}
|
||||
} else {
|
||||
// Remove model field for Vertex API (it's in URL)
|
||||
jsonBody, err = providerUtils.DeleteJSONField(jsonBody, "model")
|
||||
if err != nil {
|
||||
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err)
|
||||
}
|
||||
}
|
||||
|
||||
jsonBody, err = providerUtils.DeleteJSONField(jsonBody, "region")
|
||||
if err != nil {
|
||||
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err)
|
||||
}
|
||||
}
|
||||
|
||||
if betaHeaders := anthropic.FilterBetaHeadersForProvider(anthropic.MergeBetaHeaders(providerExtraHeaders, ctx), schemas.Vertex, betaHeaderOverrides); len(betaHeaders) > 0 {
|
||||
jsonBody, err = providerUtils.SetJSONField(jsonBody, "anthropic_beta", betaHeaders)
|
||||
if err != nil {
|
||||
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderRequestMarshal, err)
|
||||
}
|
||||
}
|
||||
|
||||
return jsonBody, nil
|
||||
}
|
||||
|
||||
// getCompleteURLForGeminiEndpoint constructs the complete URL for the Gemini endpoint, for both streaming and non-streaming requests
|
||||
// for custom/fine-tuned models, it uses the projectNumber
|
||||
// for gemini models, it uses the projectID
|
||||
func getCompleteURLForGeminiEndpoint(deployment string, region string, projectID string, projectNumber string, method string) string {
|
||||
var url string
|
||||
if schemas.IsAllDigitsASCII(deployment) {
|
||||
// Custom/fine-tuned models use projectNumber
|
||||
if region == "global" {
|
||||
url = fmt.Sprintf("https://aiplatform.googleapis.com/v1beta1/projects/%s/locations/global/endpoints/%s%s", projectNumber, deployment, method)
|
||||
} else {
|
||||
url = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1beta1/projects/%s/locations/%s/endpoints/%s%s", region, projectNumber, region, deployment, method)
|
||||
}
|
||||
} else {
|
||||
// Gemini models use projectID
|
||||
if region == "global" {
|
||||
url = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s%s", projectID, deployment, method)
|
||||
} else {
|
||||
url = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s%s", region, projectID, region, deployment, method)
|
||||
}
|
||||
}
|
||||
return url
|
||||
}
|
||||
|
||||
// buildResponseFromConfig builds a list models response from configured deployments and allowedModels.
|
||||
// This is used when the user has explicitly configured which models they want to use.
|
||||
func buildResponseFromConfig(deployments map[string]string, allowedModels schemas.WhiteList, blacklistedModels schemas.BlackList) *schemas.BifrostListModelsResponse {
|
||||
response := &schemas.BifrostListModelsResponse{
|
||||
Data: make([]schemas.Model, 0),
|
||||
}
|
||||
|
||||
if blacklistedModels.IsBlockAll() {
|
||||
return response
|
||||
}
|
||||
|
||||
addedModelIDs := make(map[string]bool)
|
||||
|
||||
restrictAllowed := allowedModels.IsRestricted()
|
||||
|
||||
// First add models from deployments (filtered by allowedModels when set)
|
||||
for alias, deploymentValue := range deployments {
|
||||
if restrictAllowed && !allowedModels.Contains(alias) {
|
||||
continue
|
||||
}
|
||||
if blacklistedModels.IsBlocked(alias) {
|
||||
continue
|
||||
}
|
||||
modelID := string(schemas.Vertex) + "/" + alias
|
||||
if addedModelIDs[modelID] {
|
||||
continue
|
||||
}
|
||||
|
||||
modelName := providerUtils.ToDisplayName(alias)
|
||||
modelEntry := schemas.Model{
|
||||
ID: modelID,
|
||||
Name: schemas.Ptr(modelName),
|
||||
Alias: schemas.Ptr(deploymentValue),
|
||||
}
|
||||
|
||||
response.Data = append(response.Data, modelEntry)
|
||||
addedModelIDs[modelID] = true
|
||||
}
|
||||
|
||||
// Then add models from allowedModels that aren't already in deployments (only when restricted)
|
||||
if !restrictAllowed {
|
||||
return response
|
||||
}
|
||||
for _, allowedModel := range allowedModels {
|
||||
modelID := string(schemas.Vertex) + "/" + allowedModel
|
||||
if addedModelIDs[modelID] {
|
||||
continue
|
||||
}
|
||||
if blacklistedModels.IsBlocked(allowedModel) {
|
||||
continue
|
||||
}
|
||||
|
||||
modelName := providerUtils.ToDisplayName(allowedModel)
|
||||
modelEntry := schemas.Model{
|
||||
ID: modelID,
|
||||
Name: schemas.Ptr(modelName),
|
||||
}
|
||||
|
||||
response.Data = append(response.Data, modelEntry)
|
||||
addedModelIDs[modelID] = true
|
||||
}
|
||||
|
||||
return response
|
||||
}
|
||||
|
||||
// extractModelIDFromName extracts the model ID from a full resource name.
|
||||
// Format: "publishers/google/models/gemini-1.5-pro" -> "gemini-1.5-pro"
|
||||
func extractModelIDFromName(name string) string {
|
||||
parts := strings.Split(name, "/")
|
||||
if len(parts) >= 4 && parts[2] == "models" {
|
||||
return parts[3]
|
||||
}
|
||||
// Fallback: return last segment
|
||||
if len(parts) > 0 {
|
||||
return parts[len(parts)-1]
|
||||
}
|
||||
return ""
|
||||
}
|
||||
3213
core/providers/vertex/vertex.go
Normal file
3213
core/providers/vertex/vertex.go
Normal file
File diff suppressed because it is too large
Load Diff
238
core/providers/vertex/vertex_caching_test.go
Normal file
238
core/providers/vertex/vertex_caching_test.go
Normal file
@@ -0,0 +1,238 @@
|
||||
package vertex_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/maximhq/bifrost/core/providers/anthropic"
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
// TestVertex_AnthropicModel_CachingDeterminism verifies that Vertex's delegation
|
||||
// to anthropic.ToAnthropicChatRequest() produces deterministic JSON for prompt caching.
|
||||
// Two schemas with the same properties but different structural key order within
|
||||
// property definitions must produce byte-identical JSON after normalization.
|
||||
func TestVertex_AnthropicModel_CachingDeterminism(t *testing.T) {
|
||||
makeReq := func(props *schemas.OrderedMap) *schemas.BifrostChatRequest {
|
||||
return &schemas.BifrostChatRequest{
|
||||
Provider: schemas.Vertex,
|
||||
Model: "claude-sonnet-4-20250514",
|
||||
Input: []schemas.ChatMessage{{
|
||||
Role: schemas.ChatMessageRoleUser,
|
||||
Content: &schemas.ChatMessageContent{ContentStr: schemas.Ptr("test")},
|
||||
}},
|
||||
Params: &schemas.ChatParameters{
|
||||
Tools: []schemas.ChatTool{{
|
||||
Type: schemas.ChatToolTypeFunction,
|
||||
Function: &schemas.ChatToolFunction{
|
||||
Name: "test",
|
||||
Parameters: &schemas.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: props,
|
||||
},
|
||||
},
|
||||
}},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Version A: type before description
|
||||
propsA := schemas.NewOrderedMapFromPairs(
|
||||
schemas.KV("chain_of_thought", schemas.NewOrderedMapFromPairs(
|
||||
schemas.KV("type", "string"),
|
||||
schemas.KV("description", "Reasoning"),
|
||||
)),
|
||||
schemas.KV("answer", schemas.NewOrderedMapFromPairs(
|
||||
schemas.KV("type", "string"),
|
||||
schemas.KV("description", "The answer"),
|
||||
)),
|
||||
)
|
||||
|
||||
// Version B: description before type (different structural order)
|
||||
propsB := schemas.NewOrderedMapFromPairs(
|
||||
schemas.KV("chain_of_thought", schemas.NewOrderedMapFromPairs(
|
||||
schemas.KV("description", "Reasoning"),
|
||||
schemas.KV("type", "string"),
|
||||
)),
|
||||
schemas.KV("answer", schemas.NewOrderedMapFromPairs(
|
||||
schemas.KV("description", "The answer"),
|
||||
schemas.KV("type", "string"),
|
||||
)),
|
||||
)
|
||||
|
||||
// Vertex delegates Anthropic models to anthropic.ToAnthropicChatRequest()
|
||||
ctx, cancel := schemas.NewBifrostContextWithCancel(nil)
|
||||
defer cancel()
|
||||
resultA, err := anthropic.ToAnthropicChatRequest(ctx, makeReq(propsA))
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
resultB, err := anthropic.ToAnthropicChatRequest(ctx, makeReq(propsB))
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
jsonA, err := schemas.Marshal(resultA.Tools[0].InputSchema)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to marshal params A: %v", err)
|
||||
}
|
||||
jsonB, err := schemas.Marshal(resultB.Tools[0].InputSchema)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to marshal params B: %v", err)
|
||||
}
|
||||
|
||||
// Caching: byte-identical JSON
|
||||
if string(jsonA) != string(jsonB) {
|
||||
t.Errorf("caching broken via Vertex→Anthropic path: same schema produced different JSON\nA: %s\nB: %s", jsonA, jsonB)
|
||||
}
|
||||
|
||||
// CoT: property order preserved
|
||||
keys := resultA.Tools[0].InputSchema.Properties.Keys()
|
||||
if len(keys) != 2 || keys[0] != "chain_of_thought" || keys[1] != "answer" {
|
||||
t.Errorf("expected property order [chain_of_thought, answer], got %v", keys)
|
||||
}
|
||||
}
|
||||
|
||||
// TestVertex_AnthropicModel_PreservesPropertyOrder verifies that the
|
||||
// Vertex→Anthropic delegation path preserves user-defined property ordering.
|
||||
func TestVertex_AnthropicModel_PreservesPropertyOrder(t *testing.T) {
|
||||
bifrostReq := &schemas.BifrostChatRequest{
|
||||
Provider: schemas.Vertex,
|
||||
Model: "claude-sonnet-4-20250514",
|
||||
Input: []schemas.ChatMessage{{
|
||||
Role: schemas.ChatMessageRoleUser,
|
||||
Content: &schemas.ChatMessageContent{ContentStr: schemas.Ptr("test")},
|
||||
}},
|
||||
Params: &schemas.ChatParameters{
|
||||
Tools: []schemas.ChatTool{{
|
||||
Type: schemas.ChatToolTypeFunction,
|
||||
Function: &schemas.ChatToolFunction{
|
||||
Name: "AnswerResponseModel",
|
||||
Description: schemas.Ptr("Extract answer"),
|
||||
Parameters: &schemas.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: schemas.NewOrderedMapFromPairs(
|
||||
schemas.KV("chain_of_thought", schemas.NewOrderedMapFromPairs(schemas.KV("type", "string"))),
|
||||
schemas.KV("answer", schemas.NewOrderedMapFromPairs(schemas.KV("type", "string"))),
|
||||
schemas.KV("citations", schemas.NewOrderedMapFromPairs(schemas.KV("type", "array"))),
|
||||
schemas.KV("is_unanswered", schemas.NewOrderedMapFromPairs(schemas.KV("type", "boolean"))),
|
||||
),
|
||||
Required: []string{"answer", "is_unanswered"},
|
||||
},
|
||||
},
|
||||
}},
|
||||
},
|
||||
}
|
||||
|
||||
ctx, cancel := schemas.NewBifrostContextWithCancel(nil)
|
||||
defer cancel()
|
||||
result, err := anthropic.ToAnthropicChatRequest(ctx, bifrostReq)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
keys := result.Tools[0].InputSchema.Properties.Keys()
|
||||
expected := []string{"chain_of_thought", "answer", "citations", "is_unanswered"}
|
||||
if len(keys) != len(expected) {
|
||||
t.Fatalf("expected %d properties, got %d: %v", len(expected), len(keys), keys)
|
||||
}
|
||||
for i, k := range expected {
|
||||
if keys[i] != k {
|
||||
t.Errorf("property %d: expected %q, got %q (full order: %v)", i, k, keys[i], keys)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestVertex_ToolInputKeyOrderPreservation verifies that tool call arguments
|
||||
// preserve their original key ordering through the Vertex→Anthropic delegation path.
|
||||
// TestVertex_ToolInputKeyOrderPreservation verifies that Vertex→Anthropic delegation
|
||||
// preserves the original key ordering of tool call arguments for prompt caching.
|
||||
// Tests multiple parallel tool calls with different key orderings per block.
|
||||
func TestVertex_ToolInputKeyOrderPreservation(t *testing.T) {
|
||||
bifrostReq := &schemas.BifrostChatRequest{
|
||||
Provider: schemas.Vertex,
|
||||
Model: "claude-sonnet-4-20250514",
|
||||
Input: []schemas.ChatMessage{
|
||||
{
|
||||
Role: schemas.ChatMessageRoleUser,
|
||||
Content: &schemas.ChatMessageContent{ContentStr: schemas.Ptr("test")},
|
||||
},
|
||||
{
|
||||
Role: schemas.ChatMessageRoleAssistant,
|
||||
ChatAssistantMessage: &schemas.ChatAssistantMessage{
|
||||
ToolCalls: []schemas.ChatAssistantMessageToolCall{
|
||||
{
|
||||
Index: 0,
|
||||
Type: schemas.Ptr("function"),
|
||||
ID: schemas.Ptr("toolu_vrtx_001"),
|
||||
Function: schemas.ChatAssistantMessageToolCallFunction{
|
||||
Name: schemas.Ptr("bash"),
|
||||
Arguments: `{"description":"Find references quickly","timeout":30000,"command":"grep -r auth_injector ."}`,
|
||||
},
|
||||
},
|
||||
{
|
||||
Index: 1,
|
||||
Type: schemas.Ptr("function"),
|
||||
ID: schemas.Ptr("toolu_vrtx_002"),
|
||||
Function: schemas.ChatAssistantMessageToolCallFunction{
|
||||
Name: schemas.Ptr("bash"),
|
||||
Arguments: `{"command":"git diff main...HEAD --stat","description":"Show diff"}`,
|
||||
},
|
||||
},
|
||||
{
|
||||
Index: 2,
|
||||
Type: schemas.Ptr("function"),
|
||||
ID: schemas.Ptr("toolu_vrtx_003"),
|
||||
Function: schemas.ChatAssistantMessageToolCallFunction{
|
||||
Name: schemas.Ptr("bash"),
|
||||
Arguments: `{"command":"git log main..HEAD","description":"Show commits"}`,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
ctx, cancel := schemas.NewBifrostContextWithCancel(nil)
|
||||
defer cancel()
|
||||
result, err := anthropic.ToAnthropicChatRequest(ctx, bifrostReq)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// Collect all tool_use content blocks
|
||||
var toolUseBlocks []struct{ jsonStr string }
|
||||
for _, msg := range result.Messages {
|
||||
for _, block := range msg.Content.ContentBlocks {
|
||||
if block.Type == "tool_use" {
|
||||
jsonBytes, _ := json.Marshal(block.Input)
|
||||
toolUseBlocks = append(toolUseBlocks, struct{ jsonStr string }{string(jsonBytes)})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(toolUseBlocks) != 3 {
|
||||
t.Fatalf("expected 3 tool_use blocks, got %d", len(toolUseBlocks))
|
||||
}
|
||||
|
||||
// Block 0: keys should be description, timeout, command (NOT alphabetical)
|
||||
s0 := toolUseBlocks[0].jsonStr
|
||||
if !(strings.Index(s0, "description") < strings.Index(s0, "timeout") &&
|
||||
strings.Index(s0, "timeout") < strings.Index(s0, "command")) {
|
||||
t.Errorf("block 0: key order not preserved, expected description < timeout < command in: %s", s0)
|
||||
}
|
||||
|
||||
// Block 1: keys should be command, description (NOT alphabetical)
|
||||
s1 := toolUseBlocks[1].jsonStr
|
||||
if !(strings.Index(s1, "command") < strings.Index(s1, "description")) {
|
||||
t.Errorf("block 1: key order not preserved, expected command < description in: %s", s1)
|
||||
}
|
||||
|
||||
// Block 2: keys should be command, description
|
||||
s2 := toolUseBlocks[2].jsonStr
|
||||
if !(strings.Index(s2, "command") < strings.Index(s2, "description")) {
|
||||
t.Errorf("block 2: key order not preserved, expected command < description in: %s", s2)
|
||||
}
|
||||
}
|
||||
81
core/providers/vertex/vertex_test.go
Normal file
81
core/providers/vertex/vertex_test.go
Normal file
@@ -0,0 +1,81 @@
|
||||
package vertex_test
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/maximhq/bifrost/core/internal/llmtests"
|
||||
|
||||
"github.com/maximhq/bifrost/core/schemas"
|
||||
)
|
||||
|
||||
func TestVertex(t *testing.T) {
|
||||
t.Parallel()
|
||||
if strings.TrimSpace(os.Getenv("VERTEX_API_KEY")) == "" && (strings.TrimSpace(os.Getenv("VERTEX_PROJECT_ID")) == "" || strings.TrimSpace(os.Getenv("VERTEX_CREDENTIALS")) == "") {
|
||||
t.Skip("Skipping Vertex tests because VERTEX_API_KEY is not set and VERTEX_PROJECT_ID or VERTEX_CREDENTIALS is not set")
|
||||
}
|
||||
|
||||
client, ctx, cancel, err := llmtests.SetupTest()
|
||||
if err != nil {
|
||||
t.Fatalf("Error initializing test setup: %v", err)
|
||||
}
|
||||
defer cancel()
|
||||
defer client.Shutdown()
|
||||
|
||||
rerankModel := strings.TrimSpace(os.Getenv("VERTEX_RERANK_MODEL"))
|
||||
|
||||
testConfig := llmtests.ComprehensiveTestConfig{
|
||||
Provider: schemas.Vertex,
|
||||
ChatModel: "gemini-2.5-pro",
|
||||
PromptCachingModel: "claude-sonnet-4-5",
|
||||
VisionModel: "claude-sonnet-4-5",
|
||||
TextModel: "", // Vertex doesn't support text completion in newer models
|
||||
EmbeddingModel: "text-multilingual-embedding-002",
|
||||
RerankModel: rerankModel,
|
||||
ReasoningModel: "claude-4.5-haiku",
|
||||
ImageGenerationModel: "gemini-2.5-flash-image",
|
||||
ImageEditModel: "imagen-3.0-capability-001",
|
||||
VideoGenerationModel: "veo-3.1-generate-preview",
|
||||
Scenarios: llmtests.TestScenarios{
|
||||
TextCompletion: false, // Not supported
|
||||
SimpleChat: true,
|
||||
CompletionStream: true,
|
||||
MultiTurnConversation: true,
|
||||
ToolCalls: true,
|
||||
ToolCallsStreaming: true,
|
||||
MultipleToolCalls: true,
|
||||
MultipleToolCallsStreaming: true,
|
||||
End2EndToolCalling: true,
|
||||
AutomaticFunctionCall: true,
|
||||
ImageURL: false,
|
||||
ImageBase64: true,
|
||||
ImageGeneration: true,
|
||||
ImageGenerationStream: false,
|
||||
ImageEdit: true,
|
||||
VideoGeneration: false, // disabled for now because of long running operations
|
||||
VideoRetrieve: false,
|
||||
VideoRemix: false,
|
||||
VideoDownload: false,
|
||||
VideoList: false,
|
||||
VideoDelete: false,
|
||||
MultipleImages: true,
|
||||
CompleteEnd2End: true,
|
||||
FileBase64: true,
|
||||
Embedding: true,
|
||||
Rerank: rerankModel != "",
|
||||
Reasoning: true,
|
||||
PromptCaching: true,
|
||||
ListModels: false,
|
||||
CountTokens: true,
|
||||
StructuredOutputs: true, // Structured outputs with nullable enum support
|
||||
InterleavedThinking: true,
|
||||
EagerInputStreaming: true, // fine-grained-tool-streaming-2025-05-14 (GA on Vertex)
|
||||
ServerToolsViaOpenAIEndpoint: true, // web_search only on Vertex per Table 20 (web_fetch/code_execution skip)
|
||||
},
|
||||
}
|
||||
|
||||
t.Run("VertexTests", func(t *testing.T) {
|
||||
llmtests.RunAllComprehensiveTests(t, client, ctx, testConfig)
|
||||
})
|
||||
}
|
||||
Reference in New Issue
Block a user